fix(vim.iter): enable optimizations for arrays (lists with holes) (#28781)

The optimizations that vim.iter uses for array-like tables don't require
that the source table has no holes. The only thing that needs to change
is the determination if a table is "list-like": rather than requiring
consecutive, integer keys, we can simply test for (positive) integer
keys only, and remove any holes in the original array when we make a
copy for the iterator.
This commit is contained in:
Gregory Anders
2024-05-17 14:17:25 -05:00
committed by GitHub
parent aec4938a21
commit 4c0d18c197
3 changed files with 82 additions and 79 deletions

View File

@ -3830,6 +3830,7 @@ chained to create iterator "pipelines": the output of each pipeline stage is
input to the next stage. The first stage depends on the type passed to
`vim.iter()`:
• List tables (arrays, |lua-list|) yield only the value of each element.
• Holes (nil values) are allowed.
• Use |Iter:enumerate()| to also pass the index to the next stage.
• Or initialize with ipairs(): `vim.iter(ipairs(…))`.
• Non-list tables (|lua-dict|) yield both the key and value of each element.
@ -4287,9 +4288,9 @@ Iter:totable() *Iter:totable()*
Collect the iterator into a table.
The resulting table depends on the initial source in the iterator
pipeline. List-like tables and function iterators will be collected into a
list-like table. If multiple values are returned from the final stage in
the iterator pipeline, each value will be included in a table.
pipeline. Array-like tables and function iterators will be collected into
an array-like table. If multiple values are returned from the final stage
in the iterator pipeline, each value will be included in a table.
Examples: >lua
vim.iter(string.gmatch('100 20 50', '%d+')):map(tonumber):totable()
@ -4302,7 +4303,7 @@ Iter:totable() *Iter:totable()*
-- { { 'a', 1 }, { 'c', 3 } }
<
The generated table is a list-like table with consecutive, numeric
The generated table is an array-like table with consecutive, numeric
indices. To create a map-like table with arbitrary keys, use
|Iter:fold()|.

View File

@ -7,6 +7,7 @@
--- `vim.iter()`:
---
--- - List tables (arrays, |lua-list|) yield only the value of each element.
--- - Holes (nil values) are allowed.
--- - Use |Iter:enumerate()| to also pass the index to the next stage.
--- - Or initialize with ipairs(): `vim.iter(ipairs(…))`.
--- - Non-list tables (|lua-dict|) yield both the key and value of each element.
@ -80,13 +81,13 @@ end
--- Special case implementations for iterators on list tables.
---@nodoc
---@class ListIter : Iter
---@class ArrayIter : Iter
---@field _table table Underlying table data
---@field _head number Index to the front of a table iterator
---@field _tail number Index to the end of a table iterator (exclusive)
local ListIter = {}
ListIter.__index = setmetatable(ListIter, Iter)
ListIter.__call = function(self)
local ArrayIter = {}
ArrayIter.__index = setmetatable(ArrayIter, Iter)
ArrayIter.__call = function(self)
return self:next()
end
@ -110,36 +111,34 @@ end
local function sanitize(t)
if type(t) == 'table' and getmetatable(t) == packedmt then
-- Remove length tag
-- Remove length tag and metatable
t.n = nil
setmetatable(t, nil)
end
return t
end
--- Flattens a single list-like table. Errors if it attempts to flatten a
--- Flattens a single array-like table. Errors if it attempts to flatten a
--- dict-like table
---@param v table table which should be flattened
---@param t table table which should be flattened
---@param max_depth number depth to which the table should be flattened
---@param depth number current iteration depth
---@param result table output table that contains flattened result
---@return table|nil flattened table if it can be flattened, otherwise nil
local function flatten(v, max_depth, depth, result)
if depth < max_depth and type(v) == 'table' then
local i = 0
for _ in pairs(v) do
i = i + 1
if v[i] == nil then
local function flatten(t, max_depth, depth, result)
if depth < max_depth and type(t) == 'table' then
for k, v in pairs(t) do
if type(k) ~= 'number' or k <= 0 or math.floor(k) ~= k then
-- short-circuit: this is not a list like table
return nil
end
if flatten(v[i], max_depth, depth + 1, result) == nil then
if flatten(v, max_depth, depth + 1, result) == nil then
return nil
end
end
else
result[#result + 1] = v
elseif t ~= nil then
result[#result + 1] = t
end
return result
@ -198,7 +197,7 @@ function Iter:filter(f)
end
---@private
function ListIter:filter(f)
function ArrayIter:filter(f)
local inc = self._head < self._tail and 1 or -1
local n = self._head
for i = self._head, self._tail - inc, inc do
@ -233,11 +232,11 @@ end
---@return Iter
---@diagnostic disable-next-line:unused-local
function Iter:flatten(depth) -- luacheck: no unused args
error('flatten() requires a list-like table')
error('flatten() requires an array-like table')
end
---@private
function ListIter:flatten(depth)
function ArrayIter:flatten(depth)
depth = depth or 1
local inc = self._head < self._tail and 1 or -1
local target = {}
@ -247,7 +246,7 @@ function ListIter:flatten(depth)
-- exit early if we try to flatten a dict-like table
if flattened == nil then
error('flatten() requires a list-like table')
error('flatten() requires an array-like table')
end
for _, v in pairs(flattened) do
@ -327,7 +326,7 @@ function Iter:map(f)
end
---@private
function ListIter:map(f)
function ArrayIter:map(f)
local inc = self._head < self._tail and 1 or -1
local n = self._head
for i = self._head, self._tail - inc, inc do
@ -360,7 +359,7 @@ function Iter:each(f)
end
---@private
function ListIter:each(f)
function ArrayIter:each(f)
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
f(unpack(self._table[i]))
@ -371,7 +370,7 @@ end
--- Collect the iterator into a table.
---
--- The resulting table depends on the initial source in the iterator pipeline.
--- List-like tables and function iterators will be collected into a list-like
--- Array-like tables and function iterators will be collected into an array-like
--- table. If multiple values are returned from the final stage in the iterator
--- pipeline, each value will be included in a table.
---
@ -388,7 +387,7 @@ end
--- -- { { 'a', 1 }, { 'c', 3 } }
--- ```
---
--- The generated table is a list-like table with consecutive, numeric indices.
--- The generated table is an array-like table with consecutive, numeric indices.
--- To create a map-like table with arbitrary keys, use |Iter:fold()|.
---
---
@ -408,12 +407,12 @@ function Iter:totable()
end
---@private
function ListIter:totable()
if self.next ~= ListIter.next or self._head >= self._tail then
function ArrayIter:totable()
if self.next ~= ArrayIter.next or self._head >= self._tail then
return Iter.totable(self)
end
local needs_sanitize = getmetatable(self._table[1]) == packedmt
local needs_sanitize = getmetatable(self._table[self._head]) == packedmt
-- Reindex and sanitize.
local len = self._tail - self._head
@ -493,7 +492,7 @@ function Iter:fold(init, f)
end
---@private
function ListIter:fold(init, f)
function ArrayIter:fold(init, f)
local acc = init
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
@ -525,7 +524,7 @@ function Iter:next()
end
---@private
function ListIter:next()
function ArrayIter:next()
if self._head ~= self._tail then
local v = self._table[self._head]
local inc = self._head < self._tail and 1 or -1
@ -548,11 +547,11 @@ end
---
---@return Iter
function Iter:rev()
error('rev() requires a list-like table')
error('rev() requires an array-like table')
end
---@private
function ListIter:rev()
function ArrayIter:rev()
local inc = self._head < self._tail and 1 or -1
self._head, self._tail = self._tail - inc, self._head - inc
return self
@ -576,11 +575,11 @@ end
---
---@return any
function Iter:peek()
error('peek() requires a list-like table')
error('peek() requires an array-like table')
end
---@private
function ListIter:peek()
function ArrayIter:peek()
if self._head ~= self._tail then
return self._table[self._head]
end
@ -657,11 +656,11 @@ end
---@return any
---@diagnostic disable-next-line: unused-local
function Iter:rfind(f) -- luacheck: no unused args
error('rfind() requires a list-like table')
error('rfind() requires an array-like table')
end
---@private
function ListIter:rfind(f)
function ArrayIter:rfind(f)
if type(f) ~= 'function' then
local val = f
f = function(v)
@ -709,10 +708,10 @@ function Iter:take(n)
end
---@private
function ListIter:take(n)
local inc = self._head < self._tail and 1 or -1
function ArrayIter:take(n)
local inc = self._head < self._tail and n or -n
local cmp = self._head < self._tail and math.min or math.max
self._tail = cmp(self._tail, self._head + n * inc)
self._tail = cmp(self._tail, self._head + inc)
return self
end
@ -730,11 +729,11 @@ end
---
---@return any
function Iter:pop()
error('pop() requires a list-like table')
error('pop() requires an array-like table')
end
--- @nodoc
function ListIter:pop()
function ArrayIter:pop()
if self._head ~= self._tail then
local inc = self._head < self._tail and 1 or -1
self._tail = self._tail - inc
@ -760,11 +759,11 @@ end
---
---@return any
function Iter:rpeek()
error('rpeek() requires a list-like table')
error('rpeek() requires an array-like table')
end
---@nodoc
function ListIter:rpeek()
function ArrayIter:rpeek()
if self._head ~= self._tail then
local inc = self._head < self._tail and 1 or -1
return self._table[self._tail - inc]
@ -793,7 +792,7 @@ function Iter:skip(n)
end
---@private
function ListIter:skip(n)
function ArrayIter:skip(n)
local inc = self._head < self._tail and n or -n
self._head = self._head + inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
@ -818,11 +817,11 @@ end
---@return Iter
---@diagnostic disable-next-line: unused-local
function Iter:rskip(n) -- luacheck: no unused args
error('rskip() requires a list-like table')
error('rskip() requires an array-like table')
end
---@private
function ListIter:rskip(n)
function ArrayIter:rskip(n)
local inc = self._head < self._tail and n or -n
self._tail = self._tail - inc
if (inc > 0 and self._head > self._tail) or (inc < 0 and self._head < self._tail) then
@ -870,11 +869,11 @@ end
---@return Iter
---@diagnostic disable-next-line: unused-local
function Iter:slice(first, last) -- luacheck: no unused args
error('slice() requires a list-like table')
error('slice() requires an array-like table')
end
---@private
function ListIter:slice(first, last)
function ArrayIter:slice(first, last)
return self:skip(math.max(0, first - 1)):rskip(math.max(0, self._tail - last - 1))
end
@ -955,7 +954,7 @@ function Iter:last()
end
---@private
function ListIter:last()
function ArrayIter:last()
local inc = self._head < self._tail and 1 or -1
local v = self._table[self._tail - inc]
self._head = self._tail
@ -1000,7 +999,7 @@ function Iter:enumerate()
end
---@private
function ListIter:enumerate()
function ArrayIter:enumerate()
local inc = self._head < self._tail and 1 or -1
for i = self._head, self._tail - inc, inc do
local v = self._table[i]
@ -1030,17 +1029,14 @@ function Iter.new(src, ...)
local t = {}
-- O(n): scan the source table to decide if it is a list (consecutive integer indices 1…n).
local count = 0
for _ in pairs(src) do
count = count + 1
local v = src[count]
if v == nil then
-- O(n): scan the source table to decide if it is an array (only positive integer indices).
for k, v in pairs(src) do
if type(k) ~= 'number' or k <= 0 or math.floor(k) ~= k then
return Iter.new(pairs(src))
end
t[count] = v
t[#t + 1] = v
end
return ListIter.new(t)
return ArrayIter.new(t)
end
if type(src) == 'function' then
@ -1068,17 +1064,18 @@ function Iter.new(src, ...)
return it
end
--- Create a new ListIter
--- Create a new ArrayIter
---
---@param t table List-like table. Caller guarantees that this table is a valid list.
---@param t table Array-like table. Caller guarantees that this table is a valid array. Can have
--- holes (nil values).
---@return Iter
---@private
function ListIter.new(t)
function ArrayIter.new(t)
local it = {}
it._table = t
it._head = 1
it._tail = #t + 1
setmetatable(it, ListIter)
setmetatable(it, ArrayIter)
return it
end

View File

@ -117,6 +117,9 @@ describe('vim.iter', function()
eq({ { 1, 1 }, { 2, 4 }, { 3, 9 } }, it:totable())
end
-- Holes in array-like tables are removed
eq({ 1, 2, 3 }, vim.iter({ 1, nil, 2, nil, 3 }):totable())
do
local it = vim.iter(string.gmatch('1,4,lol,17,blah,2,9,3', '%d+')):map(tonumber)
eq({ 1, 4, 17, 2, 9, 3 }, it:totable())
@ -142,7 +145,7 @@ describe('vim.iter', function()
eq({ 3, 2, 1 }, vim.iter({ 1, 2, 3 }):rev():totable())
local it = vim.iter(string.gmatch('abc', '%w'))
matches('rev%(%) requires a list%-like table', pcall_err(it.rev, it))
matches('rev%(%) requires an array%-like table', pcall_err(it.rev, it))
end)
it('skip()', function()
@ -181,7 +184,7 @@ describe('vim.iter', function()
end
local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
matches('rskip%(%) requires a list%-like table', pcall_err(it.rskip, it, 0))
matches('rskip%(%) requires an array%-like table', pcall_err(it.rskip, it, 0))
end)
it('slice()', function()
@ -195,7 +198,7 @@ describe('vim.iter', function()
eq({ 8, 9, 10 }, vim.iter(q):slice(8, 11):totable())
local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
matches('slice%(%) requires a list%-like table', pcall_err(it.slice, it, 1, 3))
matches('slice%(%) requires an array%-like table', pcall_err(it.slice, it, 1, 3))
end)
it('nth()', function()
@ -234,7 +237,7 @@ describe('vim.iter', function()
end
local it = vim.iter(vim.gsplit('a|b|c|d', '|'))
matches('rskip%(%) requires a list%-like table', pcall_err(it.nth, it, -1))
matches('rskip%(%) requires an array%-like table', pcall_err(it.nth, it, -1))
end)
it('take()', function()
@ -356,7 +359,7 @@ describe('vim.iter', function()
do
local it = vim.iter(vim.gsplit('hi', ''))
matches('peek%(%) requires a list%-like table', pcall_err(it.peek, it))
matches('peek%(%) requires an array%-like table', pcall_err(it.peek, it))
end
end)
@ -417,7 +420,7 @@ describe('vim.iter', function()
do
local it = vim.iter(vim.gsplit('AbCdE', ''))
matches('rfind%(%) requires a list%-like table', pcall_err(it.rfind, it, 'E'))
matches('rfind%(%) requires an array%-like table', pcall_err(it.rfind, it, 'E'))
end
end)
@ -434,7 +437,7 @@ describe('vim.iter', function()
do
local it = vim.iter(vim.gsplit('hi', ''))
matches('pop%(%) requires a list%-like table', pcall_err(it.pop, it))
matches('pop%(%) requires an array%-like table', pcall_err(it.pop, it))
end
end)
@ -448,7 +451,7 @@ describe('vim.iter', function()
do
local it = vim.iter(vim.gsplit('hi', ''))
matches('rpeek%(%) requires a list%-like table', pcall_err(it.rpeek, it))
matches('rpeek%(%) requires an array%-like table', pcall_err(it.rpeek, it))
end
end)
@ -482,18 +485,20 @@ describe('vim.iter', function()
local m = { a = 1, b = { 2, 3 }, d = { 4 } }
local it = vim.iter(m)
local flat_err = 'flatten%(%) requires a list%-like table'
local flat_err = 'flatten%(%) requires an array%-like table'
matches(flat_err, pcall_err(it.flatten, it))
-- cases from the documentation
local simple_example = { 1, { 2 }, { { 3 } } }
eq({ 1, 2, { 3 } }, vim.iter(simple_example):flatten():totable())
local not_list_like = vim.iter({ [2] = 2 })
matches(flat_err, pcall_err(not_list_like.flatten, not_list_like))
local not_list_like = { [2] = 2 }
eq({ 2 }, vim.iter(not_list_like):flatten():totable())
local also_not_list_like = vim.iter({ nil, 2 })
matches(flat_err, pcall_err(not_list_like.flatten, also_not_list_like))
local also_not_list_like = { nil, 2 }
eq({ 2 }, vim.iter(also_not_list_like):flatten():totable())
eq({ 1, 2, 3 }, vim.iter({ nil, { 1, nil, 2 }, 3 }):flatten():totable())
local nested_non_lists = vim.iter({ 1, { { a = 2 } }, { { nil } }, { 3 } })
eq({ 1, { a = 2 }, { nil }, 3 }, nested_non_lists:flatten():totable())