feat(lua): 实现共享字典 API (ngx.shared.DICT)
添加共享内存字典实现,支持并发安全的 key-value 存储: - SharedDictManager: 管理多个命名的 SharedDict 实例 - SharedDict: 带 LRU 汰出策略的内存字典 - 支持 set/get/add/incr/size/free_space 操作 - 支持带 TTL 的过期机制 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6a6cfcd11c
commit
e3e5b1fe83
43
examples/lua-scripts/shared_dict.lua
Normal file
43
examples/lua-scripts/shared_dict.lua
Normal file
@ -0,0 +1,43 @@
|
||||
-- shared_dict.lua - 共享字典示例
|
||||
-- 此脚本演示 ngx.shared.DICT 的使用
|
||||
|
||||
-- 获取共享字典(需要在配置中预先定义)
|
||||
local dict = ngx.shared.DICT("my_cache")
|
||||
|
||||
-- 设置值
|
||||
local ok, err = dict:set("user_count", "100")
|
||||
if not ok then
|
||||
ngx.log(ngx.ERR, "failed to set user_count: ", err)
|
||||
end
|
||||
|
||||
-- 设置带 TTL 的值
|
||||
ok, err = dict:set("session_token", "abc123", 3600) -- 1 小时过期
|
||||
if not ok then
|
||||
ngx.log(ngx.ERR, "failed to set session_token: ", err)
|
||||
end
|
||||
|
||||
-- 获取值
|
||||
local value, flags = dict:get("user_count")
|
||||
ngx.say("user_count: ", value)
|
||||
|
||||
-- 自增计数器
|
||||
local new_val, err = dict:incr("request_count", 1)
|
||||
ngx.say("request_count: ", new_val)
|
||||
|
||||
-- 添加值(仅不存在时)
|
||||
ok, err = dict:add("unique_key", "value")
|
||||
if ok then
|
||||
ngx.say("unique_key added successfully")
|
||||
else
|
||||
ngx.say("unique_key already exists")
|
||||
end
|
||||
|
||||
-- 查看字典大小
|
||||
local size = dict:size()
|
||||
ngx.say("dict size: ", size)
|
||||
|
||||
-- 获取剩余容量
|
||||
local free = dict:free_space()
|
||||
ngx.say("free space: ", free)
|
||||
|
||||
ngx.say("Shared dict demo completed!")
|
||||
428
internal/lua/api_shared_dict.go
Normal file
428
internal/lua/api_shared_dict.go
Normal file
@ -0,0 +1,428 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// SharedDictManager 共享字典管理器
|
||||
// 管理多个命名的 SharedDict 实例
|
||||
type SharedDictManager struct {
|
||||
mu sync.RWMutex
|
||||
dicts map[string]*SharedDict
|
||||
}
|
||||
|
||||
// NewSharedDictManager 创建字典管理器
|
||||
func NewSharedDictManager() *SharedDictManager {
|
||||
return &SharedDictManager{
|
||||
dicts: make(map[string]*SharedDict),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDict 创建或获取字典
|
||||
func (m *SharedDictManager) CreateDict(name string, maxItems int) *SharedDict {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if dict, ok := m.dicts[name]; ok {
|
||||
return dict
|
||||
}
|
||||
|
||||
dict := NewSharedDict(name, maxItems)
|
||||
m.dicts[name] = dict
|
||||
return dict
|
||||
}
|
||||
|
||||
// GetDict 获取字典
|
||||
func (m *SharedDictManager) GetDict(name string) *SharedDict {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.dicts[name]
|
||||
}
|
||||
|
||||
// Close 清理所有字典
|
||||
func (m *SharedDictManager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.dicts = nil
|
||||
}
|
||||
|
||||
// DictConfig 字典配置
|
||||
type DictConfig struct {
|
||||
Name string
|
||||
MaxItems int
|
||||
}
|
||||
|
||||
// RegisterSharedDictAPI 注册 ngx.shared.DICT API
|
||||
func RegisterSharedDictAPI(L *glua.LState, manager *SharedDictManager, ngx *glua.LTable) {
|
||||
// 创建 ngx.shared 表
|
||||
shared := L.NewTable()
|
||||
|
||||
// ngx.shared.DICT - 返回字典 userdata
|
||||
L.SetField(shared, "DICT", L.NewFunction(func(L *glua.LState) int {
|
||||
name := L.CheckString(1)
|
||||
|
||||
dict := manager.GetDict(name)
|
||||
if dict == nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("shared dict not found: " + name))
|
||||
return 2
|
||||
}
|
||||
|
||||
// 返回字典 userdata
|
||||
ud := L.NewUserData()
|
||||
ud.Value = dict
|
||||
L.SetMetatable(ud, L.GetTypeMetatable("ngx.shared.dict"))
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(ngx, "shared", shared)
|
||||
|
||||
// 创建字典类型元表
|
||||
mt := L.NewTypeMetatable("ngx.shared.dict")
|
||||
L.SetField(mt, "__index", L.NewFunction(dictIndex))
|
||||
L.SetField(mt, "__newindex", L.NewFunction(dictNewIndex))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(dictToString))
|
||||
|
||||
// 注册字典方法
|
||||
methods := L.NewTable()
|
||||
L.SetField(methods, "get", L.NewFunction(dictGet))
|
||||
L.SetField(methods, "set", L.NewFunction(dictSet))
|
||||
L.SetField(methods, "add", L.NewFunction(dictAdd))
|
||||
L.SetField(methods, "replace", L.NewFunction(dictReplace))
|
||||
L.SetField(methods, "incr", L.NewFunction(dictIncr))
|
||||
L.SetField(methods, "delete", L.NewFunction(dictDelete))
|
||||
L.SetField(methods, "flush_all", L.NewFunction(dictFlushAll))
|
||||
L.SetField(methods, "flush_expired", L.NewFunction(dictFlushExpired))
|
||||
L.SetField(methods, "get_keys", L.NewFunction(dictGetKeys))
|
||||
L.SetField(methods, "size", L.NewFunction(dictSize))
|
||||
L.SetField(methods, "free_space", L.NewFunction(dictFreeSpace))
|
||||
L.SetField(mt, "methods", methods)
|
||||
}
|
||||
|
||||
// dictIndex 字典索引方法
|
||||
func dictIndex(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
|
||||
// 检查是否是方法
|
||||
methods := L.GetField(L.Get(1).(*glua.LUserData).Metatable, "methods")
|
||||
if method := L.GetField(methods, key); method != glua.LNil {
|
||||
L.Push(method)
|
||||
return 1
|
||||
}
|
||||
|
||||
// 否则作为 key 获取值
|
||||
value, expired, err := dict.Get(key)
|
||||
if err != nil {
|
||||
L.RaiseError("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
if expired {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("expired"))
|
||||
return 2
|
||||
}
|
||||
if value == "" {
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString(value))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictNewIndex 字典设置方法
|
||||
func dictNewIndex(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ok, err := dict.Set(key, value, 0)
|
||||
if err != nil {
|
||||
L.RaiseError("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
if !ok {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("no memory"))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictToString 字典字符串表示
|
||||
func dictToString(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.Push(glua.LString("invalid shared dict"))
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString("ngx.shared.dict:" + dict.name))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictGet 获取值
|
||||
// dict:get(key) -> value, flags | nil, err
|
||||
func dictGet(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
|
||||
value, expired, err := dict.Get(key)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
if expired {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("expired"))
|
||||
return 2
|
||||
}
|
||||
if value == "" {
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString(value))
|
||||
L.Push(glua.LNumber(0)) // flags(暂不支持)
|
||||
return 2
|
||||
}
|
||||
|
||||
// dictSet 设置值
|
||||
// dict:set(key, value, exptime?, flags?) -> ok, err
|
||||
func dictSet(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ttl := time.Duration(0)
|
||||
if L.GetTop() >= 4 {
|
||||
ttl = time.Duration(L.CheckNumber(4)) * time.Second
|
||||
}
|
||||
|
||||
// flags 参数暂不使用
|
||||
|
||||
ok, err := dict.Set(key, value, ttl)
|
||||
if err != nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
if !ok {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("no memory"))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictAdd 添加值(不存在时)
|
||||
// dict:add(key, value, exptime?, flags?) -> ok, err
|
||||
func dictAdd(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ttl := time.Duration(0)
|
||||
if L.GetTop() >= 4 {
|
||||
ttl = time.Duration(L.CheckNumber(4)) * time.Second
|
||||
}
|
||||
|
||||
ok, err := dict.Add(key, value, ttl)
|
||||
if err != nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
if !ok {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("exists"))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictReplace 替换值(存在时)
|
||||
// dict:replace(key, value, exptime?, flags?) -> ok, err
|
||||
func dictReplace(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ttl := time.Duration(0)
|
||||
if L.GetTop() >= 4 {
|
||||
ttl = time.Duration(L.CheckNumber(4)) * time.Second
|
||||
}
|
||||
|
||||
// 检查是否存在
|
||||
_, expired, _ := dict.Get(key)
|
||||
if expired {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("not found"))
|
||||
return 2
|
||||
}
|
||||
|
||||
ok, err := dict.Set(key, value, ttl)
|
||||
if err != nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictIncr 自增数值
|
||||
// dict:incr(key, value) -> new_value, err
|
||||
func dictIncr(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
increment := int(L.CheckNumber(3))
|
||||
|
||||
newValue, err := dict.Incr(key, increment)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LNumber(newValue))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictDelete 删除条目
|
||||
// dict:delete(key) -> ok
|
||||
func dictDelete(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
dict.Delete(key)
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictFlushAll 清空字典
|
||||
// dict:flush_all()
|
||||
func dictFlushAll(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
dict.FlushAll()
|
||||
return 0
|
||||
}
|
||||
|
||||
// dictFlushExpired 清除过期条目
|
||||
// dict:flush_expired(max_count?) -> flushed_count
|
||||
func dictFlushExpired(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
count := dict.FlushExpired()
|
||||
L.Push(glua.LNumber(count))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictGetKeys 获取所有键
|
||||
// dict:get_keys(max_count?) -> keys
|
||||
func dictGetKeys(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
_, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
// 暂不实现完整版,返回空表
|
||||
keys := L.NewTable()
|
||||
L.Push(keys)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictSize 获取条目数
|
||||
// dict:size() -> count
|
||||
func dictSize(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
L.Push(glua.LNumber(dict.Size()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictFreeSpace 获取剩余容量
|
||||
// dict:free_space() -> slots
|
||||
func dictFreeSpace(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
L.Push(glua.LNumber(dict.FreeSlots()))
|
||||
return 1
|
||||
}
|
||||
324
internal/lua/shared_dict.go
Normal file
324
internal/lua/shared_dict.go
Normal file
@ -0,0 +1,324 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SharedDict 共享内存字典
|
||||
// 支持并发安全的 key-value 存储,带 LRU 汰出策略
|
||||
type SharedDict struct {
|
||||
name string
|
||||
maxItems int
|
||||
mu sync.Mutex
|
||||
data map[string]*sharedDictEntry
|
||||
lruList *list.List // LRU 链表,前端为最近使用
|
||||
}
|
||||
|
||||
// sharedDictEntry 字典条目
|
||||
type sharedDictEntry struct {
|
||||
key string
|
||||
value string
|
||||
expiredAt time.Time // 过期时间(0 表示永不过期)
|
||||
element *list.Element // LRU 链表元素指针
|
||||
}
|
||||
|
||||
// NewSharedDict 创建共享字典
|
||||
func NewSharedDict(name string, maxItems int) *SharedDict {
|
||||
return &SharedDict{
|
||||
name: name,
|
||||
maxItems: maxItems,
|
||||
data: make(map[string]*sharedDictEntry),
|
||||
lruList: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取值
|
||||
// 返回 value, expired, err
|
||||
func (d *SharedDict) Get(key string) (string, bool, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
entry, ok := d.data[key]
|
||||
if !ok {
|
||||
return "", false, nil // 不存在
|
||||
}
|
||||
|
||||
// 检查过期
|
||||
if !entry.expiredAt.IsZero() && time.Now().After(entry.expiredAt) {
|
||||
// 已过期,删除并返回
|
||||
d.deleteEntry(entry)
|
||||
return "", true, nil // 存在但已过期
|
||||
}
|
||||
|
||||
// 更新 LRU - 移到前端
|
||||
d.lruList.MoveToFront(entry.element)
|
||||
|
||||
return entry.value, false, nil
|
||||
}
|
||||
|
||||
// Set 设置值
|
||||
// 返回 ok, err (ok=false 表示容量满且无法淘汰)
|
||||
func (d *SharedDict) Set(key, value string, ttl time.Duration) (bool, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if entry, ok := d.data[key]; ok {
|
||||
// 更新现有条目
|
||||
entry.value = value
|
||||
if ttl > 0 {
|
||||
entry.expiredAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
entry.expiredAt = time.Time{} // 清除过期时间
|
||||
}
|
||||
d.lruList.MoveToFront(entry.element)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 新条目,检查容量
|
||||
if len(d.data) >= d.maxItems {
|
||||
// 尝试淘汰过期条目
|
||||
d.evictExpired()
|
||||
if len(d.data) >= d.maxItems {
|
||||
// 淘汰 LRU 最久未使用的条目
|
||||
if !d.evictLRU() {
|
||||
return false, nil // 无法淘汰(字典为空?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新条目
|
||||
expiredAt := time.Time{}
|
||||
if ttl > 0 {
|
||||
expiredAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
element := d.lruList.PushFront(key)
|
||||
entry := &sharedDictEntry{
|
||||
key: key,
|
||||
value: value,
|
||||
expiredAt: expiredAt,
|
||||
element: element,
|
||||
}
|
||||
d.data[key] = entry
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Add 添加值(仅在不存在时设置)
|
||||
// 返回 ok, err (ok=false 表示已存在或容量满)
|
||||
func (d *SharedDict) Add(key, value string, ttl time.Duration) (bool, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 检查是否已存在(包括过期的也算存在)
|
||||
if _, ok := d.data[key]; ok {
|
||||
return false, nil // 已存在
|
||||
}
|
||||
|
||||
// 检查容量并淘汰
|
||||
if len(d.data) >= d.maxItems {
|
||||
d.evictExpired()
|
||||
if len(d.data) >= d.maxItems {
|
||||
if !d.evictLRU() {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新条目
|
||||
expiredAt := time.Time{}
|
||||
if ttl > 0 {
|
||||
expiredAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
element := d.lruList.PushFront(key)
|
||||
entry := &sharedDictEntry{
|
||||
key: key,
|
||||
value: value,
|
||||
expiredAt: expiredAt,
|
||||
element: element,
|
||||
}
|
||||
d.data[key] = entry
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Incr 自增数值
|
||||
// 返回 new_value, err
|
||||
func (d *SharedDict) Incr(key string, increment int) (int, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
entry, ok := d.data[key]
|
||||
if !ok {
|
||||
// 不存在,创建初始值
|
||||
if len(d.data) >= d.maxItems {
|
||||
d.evictExpired()
|
||||
if len(d.data) >= d.maxItems {
|
||||
if !d.evictLRU() {
|
||||
return 0, nil // 无法创建
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
element := d.lruList.PushFront(key)
|
||||
entry = &sharedDictEntry{
|
||||
key: key,
|
||||
value: "0",
|
||||
element: element,
|
||||
}
|
||||
d.data[key] = entry
|
||||
}
|
||||
|
||||
// 解析数值
|
||||
var current int
|
||||
for _, c := range entry.value {
|
||||
if c < '0' || c > '9' {
|
||||
return 0, nil // 不是数值
|
||||
}
|
||||
current = current*10 + int(c-'0')
|
||||
}
|
||||
|
||||
newValue := current + increment
|
||||
entry.value = intToStr(newValue)
|
||||
d.lruList.MoveToFront(entry.element)
|
||||
|
||||
return newValue, nil
|
||||
}
|
||||
|
||||
// Delete 删除条目
|
||||
func (d *SharedDict) Delete(key string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if entry, ok := d.data[key]; ok {
|
||||
d.deleteEntry(entry)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushAll 清空所有条目
|
||||
func (d *SharedDict) FlushAll() error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.data = make(map[string]*sharedDictEntry)
|
||||
d.lruList = list.New()
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushExpired 清除所有过期条目
|
||||
// 返回清除的条目数
|
||||
func (d *SharedDict) FlushExpired() int {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
return d.evictExpired()
|
||||
}
|
||||
|
||||
// Size 返回当前条目数
|
||||
func (d *SharedDict) Size() int {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return len(d.data)
|
||||
}
|
||||
|
||||
// FreeSlots 返回剩余容量
|
||||
func (d *SharedDict) FreeSlots() int {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.maxItems - len(d.data)
|
||||
}
|
||||
|
||||
// deleteEntry 删除条目(内部方法,已持有锁)
|
||||
func (d *SharedDict) deleteEntry(entry *sharedDictEntry) {
|
||||
d.lruList.Remove(entry.element)
|
||||
delete(d.data, entry.key)
|
||||
}
|
||||
|
||||
// evictExpired 淘汰过期条目(内部方法,已持有锁)
|
||||
func (d *SharedDict) evictExpired() int {
|
||||
now := time.Now()
|
||||
count := 0
|
||||
|
||||
// 从 LRU 链表尾部(最久未使用)开始检查
|
||||
for elem := d.lruList.Back(); elem != nil; {
|
||||
key := elem.Value.(string)
|
||||
entry, ok := d.data[key]
|
||||
if !ok {
|
||||
// 数据不一致,跳过
|
||||
next := elem.Prev()
|
||||
d.lruList.Remove(elem)
|
||||
elem = next
|
||||
continue
|
||||
}
|
||||
|
||||
if !entry.expiredAt.IsZero() && now.After(entry.expiredAt) {
|
||||
// 已过期,删除
|
||||
d.deleteEntry(entry)
|
||||
count++
|
||||
elem = d.lruList.Back() // 重新从尾部开始
|
||||
} else {
|
||||
break // 未过期,停止(链表顺序保证前面都是未过期的)
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// evictLRU 淘汰 LRU 最久未使用的条目(内部方法,已持有锁)
|
||||
func (d *SharedDict) evictLRU() bool {
|
||||
if d.lruList.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
elem := d.lruList.Back()
|
||||
if elem == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
key := elem.Value.(string)
|
||||
entry, ok := d.data[key]
|
||||
if ok {
|
||||
d.deleteEntry(entry)
|
||||
return true
|
||||
}
|
||||
|
||||
// 数据不一致,移除链表元素并重试
|
||||
d.lruList.Remove(elem)
|
||||
return d.evictLRU()
|
||||
}
|
||||
|
||||
// intToStr 整数转字符串(简单实现,避免 strconv 依赖)
|
||||
func intToStr(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
var negative bool
|
||||
if n < 0 {
|
||||
negative = true
|
||||
n = -n
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
for n > 0 {
|
||||
buf = append(buf, byte('0'+n%10))
|
||||
n /= 10
|
||||
}
|
||||
|
||||
if negative {
|
||||
buf = append(buf, '-')
|
||||
}
|
||||
|
||||
// 反转
|
||||
for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 {
|
||||
buf[i], buf[j] = buf[j], buf[i]
|
||||
}
|
||||
|
||||
return string(buf)
|
||||
}
|
||||
251
internal/lua/shared_dict_test.go
Normal file
251
internal/lua/shared_dict_test.go
Normal file
@ -0,0 +1,251 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSharedDictGetSet(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Set
|
||||
ok, err := dict.Set("key1", "value1", 0)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Get
|
||||
value, expired, err := dict.Get("key1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, expired)
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// 不存在的 key
|
||||
value, expired, err = dict.Get("notexist")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, expired)
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestSharedDictAdd(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Add 新 key
|
||||
ok, err := dict.Add("key1", "value1", 0)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add 已存在的 key
|
||||
ok, err = dict.Add("key1", "value2", 0)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok) // 已存在,返回 false
|
||||
|
||||
// 验证值未被修改
|
||||
value, _, _ := dict.Get("key1")
|
||||
assert.Equal(t, "value1", value)
|
||||
}
|
||||
|
||||
func TestSharedDictIncr(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Incr 不存在的 key,从 0 开始
|
||||
newValue, err := dict.Incr("counter", 5)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, newValue)
|
||||
|
||||
// Incr 已存在的 key
|
||||
newValue, err = dict.Incr("counter", 3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 8, newValue)
|
||||
|
||||
// 轮减
|
||||
newValue, err = dict.Incr("counter", -2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 6, newValue)
|
||||
}
|
||||
|
||||
func TestSharedDictDelete(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
dict.Set("key1", "value1", 0)
|
||||
dict.Delete("key1")
|
||||
|
||||
value, _, _ := dict.Get("key1")
|
||||
assert.Equal(t, "", value)
|
||||
|
||||
// 删除不存在的 key 不会报错
|
||||
dict.Delete("notexist")
|
||||
}
|
||||
|
||||
func TestSharedDictTTL(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Set 带 TTL
|
||||
ok, err := dict.Set("key1", "value1", 100*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 立即获取应该成功
|
||||
value, expired, err := dict.Get("key1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, expired)
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// 等待过期
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 过期后获取
|
||||
value, expired, err = dict.Get("key1")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, expired)
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestSharedDictLRUEviction(t *testing.T) {
|
||||
dict := NewSharedDict("test", 3) // 只有 3 个容量
|
||||
|
||||
// 添加 3 个条目
|
||||
dict.Set("key1", "value1", 0)
|
||||
dict.Set("key2", "value2", 0)
|
||||
dict.Set("key3", "value3", 0)
|
||||
|
||||
// 使用 key1,使其成为最近使用
|
||||
dict.Get("key1")
|
||||
|
||||
// 添加第 4 个条目,应该淘汰 key2(最久未使用)
|
||||
ok, err := dict.Set("key4", "value4", 0)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// key1 应该还在
|
||||
value, _, _ := dict.Get("key1")
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// key2 应该被淘汰
|
||||
value, _, _ = dict.Get("key2")
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestSharedDictFlushAll(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
dict.Set("key1", "value1", 0)
|
||||
dict.Set("key2", "value2", 0)
|
||||
|
||||
dict.FlushAll()
|
||||
|
||||
assert.Equal(t, 0, dict.Size())
|
||||
}
|
||||
|
||||
func TestSharedDictFlushExpired(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
dict.Set("key1", "value1", 100*time.Millisecond)
|
||||
dict.Set("key2", "value2", 100*time.Millisecond)
|
||||
dict.Set("key3", "value3", 0) // 不过期
|
||||
|
||||
// 立即清除应该返回 0
|
||||
count := dict.FlushExpired()
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
// 等待过期
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
count = dict.FlushExpired()
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
// key3 应该还在
|
||||
assert.Equal(t, 1, dict.Size())
|
||||
value, _, _ := dict.Get("key3")
|
||||
assert.Equal(t, "value3", value)
|
||||
}
|
||||
|
||||
func TestSharedDictSize(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
assert.Equal(t, 0, dict.Size())
|
||||
assert.Equal(t, 100, dict.FreeSlots())
|
||||
|
||||
dict.Set("key1", "value1", 0)
|
||||
assert.Equal(t, 1, dict.Size())
|
||||
assert.Equal(t, 99, dict.FreeSlots())
|
||||
}
|
||||
|
||||
func TestSharedDictManager(t *testing.T) {
|
||||
manager := NewSharedDictManager()
|
||||
|
||||
// 创建字典
|
||||
dict1 := manager.CreateDict("dict1", 100)
|
||||
require.NotNil(t, dict1)
|
||||
|
||||
// 再次获取同一个字典
|
||||
dict1Again := manager.GetDict("dict1")
|
||||
assert.Equal(t, dict1, dict1Again)
|
||||
|
||||
// 创建另一个字典
|
||||
dict2 := manager.CreateDict("dict2", 200)
|
||||
require.NotNil(t, dict2)
|
||||
|
||||
// 获取不存在的字典
|
||||
notexist := manager.GetDict("notexist")
|
||||
assert.Nil(t, notexist)
|
||||
|
||||
// 关闭
|
||||
manager.Close()
|
||||
}
|
||||
|
||||
func TestSharedDictLuaAPI(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 创建共享字典(通过 Lua API 测试,此处仅为初始化)
|
||||
_ = engine.CreateSharedDict("mydict", 100)
|
||||
|
||||
// 测试 Lua 脚本
|
||||
L := engine.L
|
||||
|
||||
// 手动注册 ngx.shared API(用于测试)
|
||||
ngx := L.NewTable()
|
||||
L.SetGlobal("ngx", ngx)
|
||||
RegisterSharedDictAPI(L, engine.SharedDictManager(), ngx)
|
||||
|
||||
// 运行 Lua 脚本测试
|
||||
err = L.DoString(`
|
||||
local dict = ngx.shared.DICT("mydict")
|
||||
|
||||
-- 测试 set/get
|
||||
dict:set("key1", "value1")
|
||||
local val, flags = dict:get("key1")
|
||||
assert(val == "value1")
|
||||
|
||||
-- 测试 add
|
||||
local ok, err = dict:add("key2", "value2")
|
||||
assert(ok == true)
|
||||
|
||||
-- 测试 add 已存在的 key
|
||||
ok, err = dict:add("key2", "value3")
|
||||
assert(ok == false)
|
||||
assert(err == "exists")
|
||||
|
||||
-- 测试 incr
|
||||
local new_val, err = dict:incr("counter", 10)
|
||||
assert(new_val == 10)
|
||||
|
||||
new_val, err = dict:incr("counter", 5)
|
||||
assert(new_val == 15)
|
||||
|
||||
-- 测试 size
|
||||
local size = dict:size()
|
||||
assert(size >= 3)
|
||||
|
||||
-- 测试 delete
|
||||
dict:delete("key1")
|
||||
local val, err = dict:get("key1")
|
||||
assert(val == nil)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user