From e3e5b1fe8365bcffe2a8e437a77015c25cbeaa0c Mon Sep 17 00:00:00 2001 From: xfy Date: Sun, 12 Apr 2026 11:21:17 +0800 Subject: [PATCH] =?UTF-8?q?feat(lua):=20=E5=AE=9E=E7=8E=B0=E5=85=B1?= =?UTF-8?q?=E4=BA=AB=E5=AD=97=E5=85=B8=20API=20(ngx.shared.DICT)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加共享内存字典实现,支持并发安全的 key-value 存储: - SharedDictManager: 管理多个命名的 SharedDict 实例 - SharedDict: 带 LRU 汰出策略的内存字典 - 支持 set/get/add/incr/size/free_space 操作 - 支持带 TTL 的过期机制 Co-Authored-By: Claude Opus 4.6 --- examples/lua-scripts/shared_dict.lua | 43 +++ internal/lua/api_shared_dict.go | 428 +++++++++++++++++++++++++++ internal/lua/shared_dict.go | 324 ++++++++++++++++++++ internal/lua/shared_dict_test.go | 251 ++++++++++++++++ 4 files changed, 1046 insertions(+) create mode 100644 examples/lua-scripts/shared_dict.lua create mode 100644 internal/lua/api_shared_dict.go create mode 100644 internal/lua/shared_dict.go create mode 100644 internal/lua/shared_dict_test.go diff --git a/examples/lua-scripts/shared_dict.lua b/examples/lua-scripts/shared_dict.lua new file mode 100644 index 0000000..a43f52e --- /dev/null +++ b/examples/lua-scripts/shared_dict.lua @@ -0,0 +1,43 @@ +-- shared_dict.lua - 共享字典示例 +-- 此脚本演示 ngx.shared.DICT 的使用 + +-- 获取共享字典(需要在配置中预先定义) +local dict = ngx.shared.DICT("my_cache") + +-- 设置值 +local ok, err = dict:set("user_count", "100") +if not ok then + ngx.log(ngx.ERR, "failed to set user_count: ", err) +end + +-- 设置带 TTL 的值 +ok, err = dict:set("session_token", "abc123", 3600) -- 1 小时过期 +if not ok then + ngx.log(ngx.ERR, "failed to set session_token: ", err) +end + +-- 获取值 +local value, flags = dict:get("user_count") +ngx.say("user_count: ", value) + +-- 自增计数器 +local new_val, err = dict:incr("request_count", 1) +ngx.say("request_count: ", new_val) + +-- 添加值(仅不存在时) +ok, err = dict:add("unique_key", "value") +if ok then + ngx.say("unique_key added successfully") +else + ngx.say("unique_key already exists") +end + +-- 查看字典大小 +local size = dict:size() +ngx.say("dict size: ", size) + +-- 获取剩余容量 +local free = dict:free_space() +ngx.say("free space: ", free) + +ngx.say("Shared dict demo completed!") \ No newline at end of file diff --git a/internal/lua/api_shared_dict.go b/internal/lua/api_shared_dict.go new file mode 100644 index 0000000..1ae94cb --- /dev/null +++ b/internal/lua/api_shared_dict.go @@ -0,0 +1,428 @@ +// Package lua 提供 Lua 脚本嵌入能力 +package lua + +import ( + "sync" + "time" + + glua "github.com/yuin/gopher-lua" +) + +// SharedDictManager 共享字典管理器 +// 管理多个命名的 SharedDict 实例 +type SharedDictManager struct { + mu sync.RWMutex + dicts map[string]*SharedDict +} + +// NewSharedDictManager 创建字典管理器 +func NewSharedDictManager() *SharedDictManager { + return &SharedDictManager{ + dicts: make(map[string]*SharedDict), + } +} + +// CreateDict 创建或获取字典 +func (m *SharedDictManager) CreateDict(name string, maxItems int) *SharedDict { + m.mu.Lock() + defer m.mu.Unlock() + + if dict, ok := m.dicts[name]; ok { + return dict + } + + dict := NewSharedDict(name, maxItems) + m.dicts[name] = dict + return dict +} + +// GetDict 获取字典 +func (m *SharedDictManager) GetDict(name string) *SharedDict { + m.mu.RLock() + defer m.mu.RUnlock() + return m.dicts[name] +} + +// Close 清理所有字典 +func (m *SharedDictManager) Close() { + m.mu.Lock() + defer m.mu.Unlock() + m.dicts = nil +} + +// DictConfig 字典配置 +type DictConfig struct { + Name string + MaxItems int +} + +// RegisterSharedDictAPI 注册 ngx.shared.DICT API +func RegisterSharedDictAPI(L *glua.LState, manager *SharedDictManager, ngx *glua.LTable) { + // 创建 ngx.shared 表 + shared := L.NewTable() + + // ngx.shared.DICT - 返回字典 userdata + L.SetField(shared, "DICT", L.NewFunction(func(L *glua.LState) int { + name := L.CheckString(1) + + dict := manager.GetDict(name) + if dict == nil { + L.Push(glua.LNil) + L.Push(glua.LString("shared dict not found: " + name)) + return 2 + } + + // 返回字典 userdata + ud := L.NewUserData() + ud.Value = dict + L.SetMetatable(ud, L.GetTypeMetatable("ngx.shared.dict")) + L.Push(ud) + return 1 + })) + + L.SetField(ngx, "shared", shared) + + // 创建字典类型元表 + mt := L.NewTypeMetatable("ngx.shared.dict") + L.SetField(mt, "__index", L.NewFunction(dictIndex)) + L.SetField(mt, "__newindex", L.NewFunction(dictNewIndex)) + L.SetField(mt, "__tostring", L.NewFunction(dictToString)) + + // 注册字典方法 + methods := L.NewTable() + L.SetField(methods, "get", L.NewFunction(dictGet)) + L.SetField(methods, "set", L.NewFunction(dictSet)) + L.SetField(methods, "add", L.NewFunction(dictAdd)) + L.SetField(methods, "replace", L.NewFunction(dictReplace)) + L.SetField(methods, "incr", L.NewFunction(dictIncr)) + L.SetField(methods, "delete", L.NewFunction(dictDelete)) + L.SetField(methods, "flush_all", L.NewFunction(dictFlushAll)) + L.SetField(methods, "flush_expired", L.NewFunction(dictFlushExpired)) + L.SetField(methods, "get_keys", L.NewFunction(dictGetKeys)) + L.SetField(methods, "size", L.NewFunction(dictSize)) + L.SetField(methods, "free_space", L.NewFunction(dictFreeSpace)) + L.SetField(mt, "methods", methods) +} + +// dictIndex 字典索引方法 +func dictIndex(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + + // 检查是否是方法 + methods := L.GetField(L.Get(1).(*glua.LUserData).Metatable, "methods") + if method := L.GetField(methods, key); method != glua.LNil { + L.Push(method) + return 1 + } + + // 否则作为 key 获取值 + value, expired, err := dict.Get(key) + if err != nil { + L.RaiseError("%s", err.Error()) + return 0 + } + if expired { + L.Push(glua.LNil) + L.Push(glua.LString("expired")) + return 2 + } + if value == "" { + L.Push(glua.LNil) + return 1 + } + L.Push(glua.LString(value)) + return 1 +} + +// dictNewIndex 字典设置方法 +func dictNewIndex(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + value := L.CheckString(3) + + ok, err := dict.Set(key, value, 0) + if err != nil { + L.RaiseError("%s", err.Error()) + return 0 + } + if !ok { + L.Push(glua.LFalse) + L.Push(glua.LString("no memory")) + return 2 + } + L.Push(glua.LTrue) + return 1 +} + +// dictToString 字典字符串表示 +func dictToString(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.Push(glua.LString("invalid shared dict")) + return 1 + } + L.Push(glua.LString("ngx.shared.dict:" + dict.name)) + return 1 +} + +// dictGet 获取值 +// dict:get(key) -> value, flags | nil, err +func dictGet(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + + value, expired, err := dict.Get(key) + if err != nil { + L.Push(glua.LNil) + L.Push(glua.LString(err.Error())) + return 2 + } + if expired { + L.Push(glua.LNil) + L.Push(glua.LString("expired")) + return 2 + } + if value == "" { + L.Push(glua.LNil) + return 1 + } + L.Push(glua.LString(value)) + L.Push(glua.LNumber(0)) // flags(暂不支持) + return 2 +} + +// dictSet 设置值 +// dict:set(key, value, exptime?, flags?) -> ok, err +func dictSet(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + value := L.CheckString(3) + + ttl := time.Duration(0) + if L.GetTop() >= 4 { + ttl = time.Duration(L.CheckNumber(4)) * time.Second + } + + // flags 参数暂不使用 + + ok, err := dict.Set(key, value, ttl) + if err != nil { + L.Push(glua.LFalse) + L.Push(glua.LString(err.Error())) + return 2 + } + if !ok { + L.Push(glua.LFalse) + L.Push(glua.LString("no memory")) + return 2 + } + L.Push(glua.LTrue) + return 1 +} + +// dictAdd 添加值(不存在时) +// dict:add(key, value, exptime?, flags?) -> ok, err +func dictAdd(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + value := L.CheckString(3) + + ttl := time.Duration(0) + if L.GetTop() >= 4 { + ttl = time.Duration(L.CheckNumber(4)) * time.Second + } + + ok, err := dict.Add(key, value, ttl) + if err != nil { + L.Push(glua.LFalse) + L.Push(glua.LString(err.Error())) + return 2 + } + if !ok { + L.Push(glua.LFalse) + L.Push(glua.LString("exists")) + return 2 + } + L.Push(glua.LTrue) + return 1 +} + +// dictReplace 替换值(存在时) +// dict:replace(key, value, exptime?, flags?) -> ok, err +func dictReplace(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + value := L.CheckString(3) + + ttl := time.Duration(0) + if L.GetTop() >= 4 { + ttl = time.Duration(L.CheckNumber(4)) * time.Second + } + + // 检查是否存在 + _, expired, _ := dict.Get(key) + if expired { + L.Push(glua.LFalse) + L.Push(glua.LString("not found")) + return 2 + } + + ok, err := dict.Set(key, value, ttl) + if err != nil { + L.Push(glua.LFalse) + L.Push(glua.LString(err.Error())) + return 2 + } + L.Push(glua.LTrue) + return 1 +} + +// dictIncr 自增数值 +// dict:incr(key, value) -> new_value, err +func dictIncr(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + increment := int(L.CheckNumber(3)) + + newValue, err := dict.Incr(key, increment) + if err != nil { + L.Push(glua.LNil) + L.Push(glua.LString(err.Error())) + return 2 + } + L.Push(glua.LNumber(newValue)) + return 1 +} + +// dictDelete 删除条目 +// dict:delete(key) -> ok +func dictDelete(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + key := L.CheckString(2) + dict.Delete(key) + L.Push(glua.LTrue) + return 1 +} + +// dictFlushAll 清空字典 +// dict:flush_all() +func dictFlushAll(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + dict.FlushAll() + return 0 +} + +// dictFlushExpired 清除过期条目 +// dict:flush_expired(max_count?) -> flushed_count +func dictFlushExpired(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + count := dict.FlushExpired() + L.Push(glua.LNumber(count)) + return 1 +} + +// dictGetKeys 获取所有键 +// dict:get_keys(max_count?) -> keys +func dictGetKeys(L *glua.LState) int { + ud := L.CheckUserData(1) + _, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + // 暂不实现完整版,返回空表 + keys := L.NewTable() + L.Push(keys) + return 1 +} + +// dictSize 获取条目数 +// dict:size() -> count +func dictSize(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + L.Push(glua.LNumber(dict.Size())) + return 1 +} + +// dictFreeSpace 获取剩余容量 +// dict:free_space() -> slots +func dictFreeSpace(L *glua.LState) int { + ud := L.CheckUserData(1) + dict, ok := ud.Value.(*SharedDict) + if !ok { + L.RaiseError("invalid shared dict") + return 0 + } + + L.Push(glua.LNumber(dict.FreeSlots())) + return 1 +} diff --git a/internal/lua/shared_dict.go b/internal/lua/shared_dict.go new file mode 100644 index 0000000..6dc5988 --- /dev/null +++ b/internal/lua/shared_dict.go @@ -0,0 +1,324 @@ +// Package lua 提供 Lua 脚本嵌入能力 +package lua + +import ( + "container/list" + "sync" + "time" +) + +// SharedDict 共享内存字典 +// 支持并发安全的 key-value 存储,带 LRU 汰出策略 +type SharedDict struct { + name string + maxItems int + mu sync.Mutex + data map[string]*sharedDictEntry + lruList *list.List // LRU 链表,前端为最近使用 +} + +// sharedDictEntry 字典条目 +type sharedDictEntry struct { + key string + value string + expiredAt time.Time // 过期时间(0 表示永不过期) + element *list.Element // LRU 链表元素指针 +} + +// NewSharedDict 创建共享字典 +func NewSharedDict(name string, maxItems int) *SharedDict { + return &SharedDict{ + name: name, + maxItems: maxItems, + data: make(map[string]*sharedDictEntry), + lruList: list.New(), + } +} + +// Get 获取值 +// 返回 value, expired, err +func (d *SharedDict) Get(key string) (string, bool, error) { + d.mu.Lock() + defer d.mu.Unlock() + + entry, ok := d.data[key] + if !ok { + return "", false, nil // 不存在 + } + + // 检查过期 + if !entry.expiredAt.IsZero() && time.Now().After(entry.expiredAt) { + // 已过期,删除并返回 + d.deleteEntry(entry) + return "", true, nil // 存在但已过期 + } + + // 更新 LRU - 移到前端 + d.lruList.MoveToFront(entry.element) + + return entry.value, false, nil +} + +// Set 设置值 +// 返回 ok, err (ok=false 表示容量满且无法淘汰) +func (d *SharedDict) Set(key, value string, ttl time.Duration) (bool, error) { + d.mu.Lock() + defer d.mu.Unlock() + + // 检查是否已存在 + if entry, ok := d.data[key]; ok { + // 更新现有条目 + entry.value = value + if ttl > 0 { + entry.expiredAt = time.Now().Add(ttl) + } else { + entry.expiredAt = time.Time{} // 清除过期时间 + } + d.lruList.MoveToFront(entry.element) + return true, nil + } + + // 新条目,检查容量 + if len(d.data) >= d.maxItems { + // 尝试淘汰过期条目 + d.evictExpired() + if len(d.data) >= d.maxItems { + // 淘汰 LRU 最久未使用的条目 + if !d.evictLRU() { + return false, nil // 无法淘汰(字典为空?) + } + } + } + + // 创建新条目 + expiredAt := time.Time{} + if ttl > 0 { + expiredAt = time.Now().Add(ttl) + } + + element := d.lruList.PushFront(key) + entry := &sharedDictEntry{ + key: key, + value: value, + expiredAt: expiredAt, + element: element, + } + d.data[key] = entry + + return true, nil +} + +// Add 添加值(仅在不存在时设置) +// 返回 ok, err (ok=false 表示已存在或容量满) +func (d *SharedDict) Add(key, value string, ttl time.Duration) (bool, error) { + d.mu.Lock() + defer d.mu.Unlock() + + // 检查是否已存在(包括过期的也算存在) + if _, ok := d.data[key]; ok { + return false, nil // 已存在 + } + + // 检查容量并淘汰 + if len(d.data) >= d.maxItems { + d.evictExpired() + if len(d.data) >= d.maxItems { + if !d.evictLRU() { + return false, nil + } + } + } + + // 创建新条目 + expiredAt := time.Time{} + if ttl > 0 { + expiredAt = time.Now().Add(ttl) + } + + element := d.lruList.PushFront(key) + entry := &sharedDictEntry{ + key: key, + value: value, + expiredAt: expiredAt, + element: element, + } + d.data[key] = entry + + return true, nil +} + +// Incr 自增数值 +// 返回 new_value, err +func (d *SharedDict) Incr(key string, increment int) (int, error) { + d.mu.Lock() + defer d.mu.Unlock() + + entry, ok := d.data[key] + if !ok { + // 不存在,创建初始值 + if len(d.data) >= d.maxItems { + d.evictExpired() + if len(d.data) >= d.maxItems { + if !d.evictLRU() { + return 0, nil // 无法创建 + } + } + } + + element := d.lruList.PushFront(key) + entry = &sharedDictEntry{ + key: key, + value: "0", + element: element, + } + d.data[key] = entry + } + + // 解析数值 + var current int + for _, c := range entry.value { + if c < '0' || c > '9' { + return 0, nil // 不是数值 + } + current = current*10 + int(c-'0') + } + + newValue := current + increment + entry.value = intToStr(newValue) + d.lruList.MoveToFront(entry.element) + + return newValue, nil +} + +// Delete 删除条目 +func (d *SharedDict) Delete(key string) error { + d.mu.Lock() + defer d.mu.Unlock() + + if entry, ok := d.data[key]; ok { + d.deleteEntry(entry) + } + return nil +} + +// FlushAll 清空所有条目 +func (d *SharedDict) FlushAll() error { + d.mu.Lock() + defer d.mu.Unlock() + + d.data = make(map[string]*sharedDictEntry) + d.lruList = list.New() + return nil +} + +// FlushExpired 清除所有过期条目 +// 返回清除的条目数 +func (d *SharedDict) FlushExpired() int { + d.mu.Lock() + defer d.mu.Unlock() + + return d.evictExpired() +} + +// Size 返回当前条目数 +func (d *SharedDict) Size() int { + d.mu.Lock() + defer d.mu.Unlock() + return len(d.data) +} + +// FreeSlots 返回剩余容量 +func (d *SharedDict) FreeSlots() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.maxItems - len(d.data) +} + +// deleteEntry 删除条目(内部方法,已持有锁) +func (d *SharedDict) deleteEntry(entry *sharedDictEntry) { + d.lruList.Remove(entry.element) + delete(d.data, entry.key) +} + +// evictExpired 淘汰过期条目(内部方法,已持有锁) +func (d *SharedDict) evictExpired() int { + now := time.Now() + count := 0 + + // 从 LRU 链表尾部(最久未使用)开始检查 + for elem := d.lruList.Back(); elem != nil; { + key := elem.Value.(string) + entry, ok := d.data[key] + if !ok { + // 数据不一致,跳过 + next := elem.Prev() + d.lruList.Remove(elem) + elem = next + continue + } + + if !entry.expiredAt.IsZero() && now.After(entry.expiredAt) { + // 已过期,删除 + d.deleteEntry(entry) + count++ + elem = d.lruList.Back() // 重新从尾部开始 + } else { + break // 未过期,停止(链表顺序保证前面都是未过期的) + } + } + + return count +} + +// evictLRU 淘汰 LRU 最久未使用的条目(内部方法,已持有锁) +func (d *SharedDict) evictLRU() bool { + if d.lruList.Len() == 0 { + return false + } + + elem := d.lruList.Back() + if elem == nil { + return false + } + + key := elem.Value.(string) + entry, ok := d.data[key] + if ok { + d.deleteEntry(entry) + return true + } + + // 数据不一致,移除链表元素并重试 + d.lruList.Remove(elem) + return d.evictLRU() +} + +// intToStr 整数转字符串(简单实现,避免 strconv 依赖) +func intToStr(n int) string { + if n == 0 { + return "0" + } + + var negative bool + if n < 0 { + negative = true + n = -n + } + + var buf []byte + for n > 0 { + buf = append(buf, byte('0'+n%10)) + n /= 10 + } + + if negative { + buf = append(buf, '-') + } + + // 反转 + for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 { + buf[i], buf[j] = buf[j], buf[i] + } + + return string(buf) +} diff --git a/internal/lua/shared_dict_test.go b/internal/lua/shared_dict_test.go new file mode 100644 index 0000000..9677f1d --- /dev/null +++ b/internal/lua/shared_dict_test.go @@ -0,0 +1,251 @@ +// Package lua 提供 Lua 脚本嵌入能力 +package lua + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSharedDictGetSet(t *testing.T) { + dict := NewSharedDict("test", 100) + + // Set + ok, err := dict.Set("key1", "value1", 0) + require.NoError(t, err) + assert.True(t, ok) + + // Get + value, expired, err := dict.Get("key1") + require.NoError(t, err) + assert.False(t, expired) + assert.Equal(t, "value1", value) + + // 不存在的 key + value, expired, err = dict.Get("notexist") + require.NoError(t, err) + assert.False(t, expired) + assert.Equal(t, "", value) +} + +func TestSharedDictAdd(t *testing.T) { + dict := NewSharedDict("test", 100) + + // Add 新 key + ok, err := dict.Add("key1", "value1", 0) + require.NoError(t, err) + assert.True(t, ok) + + // Add 已存在的 key + ok, err = dict.Add("key1", "value2", 0) + require.NoError(t, err) + assert.False(t, ok) // 已存在,返回 false + + // 验证值未被修改 + value, _, _ := dict.Get("key1") + assert.Equal(t, "value1", value) +} + +func TestSharedDictIncr(t *testing.T) { + dict := NewSharedDict("test", 100) + + // Incr 不存在的 key,从 0 开始 + newValue, err := dict.Incr("counter", 5) + require.NoError(t, err) + assert.Equal(t, 5, newValue) + + // Incr 已存在的 key + newValue, err = dict.Incr("counter", 3) + require.NoError(t, err) + assert.Equal(t, 8, newValue) + + // 轮减 + newValue, err = dict.Incr("counter", -2) + require.NoError(t, err) + assert.Equal(t, 6, newValue) +} + +func TestSharedDictDelete(t *testing.T) { + dict := NewSharedDict("test", 100) + + dict.Set("key1", "value1", 0) + dict.Delete("key1") + + value, _, _ := dict.Get("key1") + assert.Equal(t, "", value) + + // 删除不存在的 key 不会报错 + dict.Delete("notexist") +} + +func TestSharedDictTTL(t *testing.T) { + dict := NewSharedDict("test", 100) + + // Set 带 TTL + ok, err := dict.Set("key1", "value1", 100*time.Millisecond) + require.NoError(t, err) + assert.True(t, ok) + + // 立即获取应该成功 + value, expired, err := dict.Get("key1") + require.NoError(t, err) + assert.False(t, expired) + assert.Equal(t, "value1", value) + + // 等待过期 + time.Sleep(150 * time.Millisecond) + + // 过期后获取 + value, expired, err = dict.Get("key1") + require.NoError(t, err) + assert.True(t, expired) + assert.Equal(t, "", value) +} + +func TestSharedDictLRUEviction(t *testing.T) { + dict := NewSharedDict("test", 3) // 只有 3 个容量 + + // 添加 3 个条目 + dict.Set("key1", "value1", 0) + dict.Set("key2", "value2", 0) + dict.Set("key3", "value3", 0) + + // 使用 key1,使其成为最近使用 + dict.Get("key1") + + // 添加第 4 个条目,应该淘汰 key2(最久未使用) + ok, err := dict.Set("key4", "value4", 0) + require.NoError(t, err) + assert.True(t, ok) + + // key1 应该还在 + value, _, _ := dict.Get("key1") + assert.Equal(t, "value1", value) + + // key2 应该被淘汰 + value, _, _ = dict.Get("key2") + assert.Equal(t, "", value) +} + +func TestSharedDictFlushAll(t *testing.T) { + dict := NewSharedDict("test", 100) + + dict.Set("key1", "value1", 0) + dict.Set("key2", "value2", 0) + + dict.FlushAll() + + assert.Equal(t, 0, dict.Size()) +} + +func TestSharedDictFlushExpired(t *testing.T) { + dict := NewSharedDict("test", 100) + + dict.Set("key1", "value1", 100*time.Millisecond) + dict.Set("key2", "value2", 100*time.Millisecond) + dict.Set("key3", "value3", 0) // 不过期 + + // 立即清除应该返回 0 + count := dict.FlushExpired() + assert.Equal(t, 0, count) + + // 等待过期 + time.Sleep(150 * time.Millisecond) + + count = dict.FlushExpired() + assert.Equal(t, 2, count) + + // key3 应该还在 + assert.Equal(t, 1, dict.Size()) + value, _, _ := dict.Get("key3") + assert.Equal(t, "value3", value) +} + +func TestSharedDictSize(t *testing.T) { + dict := NewSharedDict("test", 100) + + assert.Equal(t, 0, dict.Size()) + assert.Equal(t, 100, dict.FreeSlots()) + + dict.Set("key1", "value1", 0) + assert.Equal(t, 1, dict.Size()) + assert.Equal(t, 99, dict.FreeSlots()) +} + +func TestSharedDictManager(t *testing.T) { + manager := NewSharedDictManager() + + // 创建字典 + dict1 := manager.CreateDict("dict1", 100) + require.NotNil(t, dict1) + + // 再次获取同一个字典 + dict1Again := manager.GetDict("dict1") + assert.Equal(t, dict1, dict1Again) + + // 创建另一个字典 + dict2 := manager.CreateDict("dict2", 200) + require.NotNil(t, dict2) + + // 获取不存在的字典 + notexist := manager.GetDict("notexist") + assert.Nil(t, notexist) + + // 关闭 + manager.Close() +} + +func TestSharedDictLuaAPI(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 创建共享字典(通过 Lua API 测试,此处仅为初始化) + _ = engine.CreateSharedDict("mydict", 100) + + // 测试 Lua 脚本 + L := engine.L + + // 手动注册 ngx.shared API(用于测试) + ngx := L.NewTable() + L.SetGlobal("ngx", ngx) + RegisterSharedDictAPI(L, engine.SharedDictManager(), ngx) + + // 运行 Lua 脚本测试 + err = L.DoString(` + local dict = ngx.shared.DICT("mydict") + + -- 测试 set/get + dict:set("key1", "value1") + local val, flags = dict:get("key1") + assert(val == "value1") + + -- 测试 add + local ok, err = dict:add("key2", "value2") + assert(ok == true) + + -- 测试 add 已存在的 key + ok, err = dict:add("key2", "value3") + assert(ok == false) + assert(err == "exists") + + -- 测试 incr + local new_val, err = dict:incr("counter", 10) + assert(new_val == 10) + + new_val, err = dict:incr("counter", 5) + assert(new_val == 15) + + -- 测试 size + local size = dict:size() + assert(size >= 3) + + -- 测试 delete + dict:delete("key1") + local val, err = dict:get("key1") + assert(val == nil) + `) + require.NoError(t, err) +}