refactor(treesitter): simplify injection retrieval #33104

Simplify the logic for retrieving the injection ranges for the language
tree. The trees are now also sorted by starting position, regardless of
whether they are part of a combined injection or not. This would be
helpful if ranges are ever to be stored in an interval tree or other
kind of sorted tree structure.
This commit is contained in:
Riley Bruins
2025-03-28 04:38:47 -07:00
committed by GitHub
parent 18fa61049a
commit 75cbd9a8ae
2 changed files with 44 additions and 70 deletions

View File

@ -868,35 +868,42 @@ end
---@alias vim.treesitter.languagetree.Injection table<string,table<integer,vim.treesitter.languagetree.InjectionElem>>
---@param t table<integer,vim.treesitter.languagetree.Injection>
---@param tree_index integer
---@param t vim.treesitter.languagetree.Injection
---@param pattern integer
---@param lang string
---@param combined boolean
---@param ranges Range6[]
local function add_injection(t, tree_index, pattern, lang, combined, ranges)
---@param result table<string,Range6[][]>
local function add_injection(t, pattern, lang, combined, ranges, result)
if #ranges == 0 then
-- Make sure not to add an empty range set as this is interpreted to mean the whole buffer.
return
end
-- Each tree index should be isolated from the other nodes.
if not t[tree_index] then
t[tree_index] = {}
if not result[lang] then
result[lang] = {}
end
if not t[tree_index][lang] then
t[tree_index][lang] = {}
if not combined then
table.insert(result[lang], ranges)
return
end
-- Key this by pattern. If combined is set to true all captures of this pattern
if not t[lang] then
t[lang] = {}
end
-- Key this by pattern. For combined injections, all captures of this pattern
-- will be parsed by treesitter as the same "source".
-- If combined is false, each "region" will be parsed as a single source.
if not t[tree_index][lang][pattern] then
t[tree_index][lang][pattern] = { combined = combined, regions = {} }
if not t[lang][pattern] then
local regions = {}
t[lang][pattern] = regions
table.insert(result[lang], regions)
end
table.insert(t[tree_index][lang][pattern].regions, ranges)
for _, range in ipairs(ranges) do
table.insert(t[lang][pattern], range)
end
end
-- TODO(clason): replace by refactored `ts.has_parser` API (without side effects)
@ -964,19 +971,6 @@ function LanguageTree:_get_injection(match, metadata)
return lang, combined, ranges
end
--- Can't use vim.tbl_flatten since a range is just a table.
---@param regions Range6[][]
---@return Range6[]
local function combine_regions(regions)
local result = {} ---@type Range6[]
for _, region in ipairs(regions) do
for _, range in ipairs(region) do
result[#result + 1] = range
end
end
return result
end
--- Gets language injection regions by language.
---
--- This is where most of the injection processing occurs.
@ -993,13 +987,16 @@ function LanguageTree:_get_injections(range, thread_state)
return {}
end
---@type table<integer,vim.treesitter.languagetree.Injection>
local injections = {}
local start = vim.uv.hrtime()
---@type table<string,Range6[][]>
local result = {}
local full_scan = range == true or self._injection_query.has_combined_injections
for index, tree in pairs(self._trees) do
for _, tree in pairs(self._trees) do
---@type vim.treesitter.languagetree.Injection
local injections = {}
local root_node = tree:root()
local start_line, end_line ---@type integer, integer
if full_scan then
@ -1013,7 +1010,7 @@ function LanguageTree:_get_injections(range, thread_state)
do
local lang, combined, ranges = self:_get_injection(match, metadata)
if lang then
add_injection(injections, index, pattern, lang, combined, ranges)
add_injection(injections, pattern, lang, combined, ranges, result)
else
self:_log('match from injection query failed for pattern', pattern)
end
@ -1025,29 +1022,6 @@ function LanguageTree:_get_injections(range, thread_state)
end
end
---@type table<string,Range6[][]>
local result = {}
-- Generate a map by lang of node lists.
-- Each list is a set of ranges that should be parsed together.
for _, lang_map in pairs(injections) do
for lang, patterns in pairs(lang_map) do
if not result[lang] then
result[lang] = {}
end
for _, entry in pairs(patterns) do
if entry.combined then
table.insert(result[lang], combine_regions(entry.regions))
else
for _, ranges in pairs(entry.regions) do
table.insert(result[lang], ranges)
end
end
end
end
end
if full_scan then
self._processed_injection_range = entire_document_range
else

View File

@ -575,22 +575,22 @@ int x = INT_MAX;
eq(5, exec_lua('return #parser:children().c:trees()'))
eq({
{ 0, 0, 7, 0 }, -- root tree
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 14, 3, 17 }, -- VALUE 123
{ 4, 15, 4, 18 }, -- VALUE1 123
{ 5, 15, 5, 18 }, -- VALUE2 123
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges())
n.feed('ggo<esc>')
eq(5, exec_lua('return #parser:children().c:trees()'))
eq({
{ 0, 0, 8, 0 }, -- root tree
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 4, 14, 4, 17 }, -- VALUE 123
{ 5, 15, 5, 18 }, -- VALUE1 123
{ 6, 15, 6, 18 }, -- VALUE2 123
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges())
end)
end)
@ -613,11 +613,11 @@ int x = INT_MAX;
eq(2, exec_lua('return #parser:children().c:trees()'))
eq({
{ 0, 0, 7, 0 }, -- root tree
{ 1, 26, 2, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 14, 5, 18 }, -- VALUE 123
-- VALUE1 123
-- VALUE2 123
{ 1, 26, 2, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges())
n.feed('ggo<esc>')
@ -625,11 +625,11 @@ int x = INT_MAX;
eq(2, exec_lua('return #parser:children().c:trees()'))
eq({
{ 0, 0, 8, 0 }, -- root tree
{ 4, 14, 6, 18 }, -- VALUE 123
-- VALUE1 123
-- VALUE2 123
{ 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
-- VALUE 123
{ 4, 14, 6, 18 }, -- VALUE1 123
-- VALUE2 123
}, get_ranges())
n.feed('7ggI//<esc>')
@ -638,10 +638,10 @@ int x = INT_MAX;
eq(2, exec_lua('return #parser:children().c:trees()'))
eq({
{ 0, 0, 8, 0 }, -- root tree
{ 4, 14, 5, 18 }, -- VALUE 123
-- VALUE1 123
{ 2, 26, 3, 66 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
-- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
-- VALUE 123
{ 4, 14, 5, 18 }, -- VALUE1 123
}, get_ranges())
end)
@ -794,22 +794,22 @@ int x = INT_MAX;
eq(5, exec_lua('return #parser:children().c:trees()'))
eq({
{ 0, 0, 7, 0 }, -- root tree
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 14, 3, 17 }, -- VALUE 123
{ 4, 15, 4, 18 }, -- VALUE1 123
{ 5, 15, 5, 18 }, -- VALUE2 123
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges())
n.feed('ggo<esc>')
eq(5, exec_lua('return #parser:children().c:trees()'))
eq({
{ 0, 0, 8, 0 }, -- root tree
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 4, 14, 4, 17 }, -- VALUE 123
{ 5, 15, 5, 18 }, -- VALUE1 123
{ 6, 15, 6, 18 }, -- VALUE2 123
{ 2, 26, 2, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 29, 3, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges())
end)
end)
@ -831,11 +831,11 @@ int x = INT_MAX;
eq('table', exec_lua('return type(parser:children().c)'))
eq({
{ 0, 0, 7, 0 }, -- root tree
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
{ 3, 16, 3, 16 }, -- VALUE 123
{ 4, 17, 4, 17 }, -- VALUE1 123
{ 5, 17, 5, 17 }, -- VALUE2 123
{ 1, 26, 1, 63 }, -- READ_STRING(x, y) (char *)read_string((x), (size_t)(y))
{ 2, 29, 2, 66 }, -- READ_STRING_OK(x, y) (char *)read_string((x), (size_t)(y))
}, get_ranges())
end)
it('should list all directives', function()