This commit is contained in:
xfy 2026-04-12 11:50:35 +08:00
commit 5dbede27f8
12 changed files with 1912 additions and 7 deletions

View File

@ -0,0 +1,43 @@
-- shared_dict.lua - 共享字典示例
-- 此脚本演示 ngx.shared.DICT 的使用
-- 获取共享字典(需要在配置中预先定义)
local dict = ngx.shared.DICT("my_cache")
-- 设置值
local ok, err = dict:set("user_count", "100")
if not ok then
ngx.log(ngx.ERR, "failed to set user_count: ", err)
end
-- 设置带 TTL 的值
ok, err = dict:set("session_token", "abc123", 3600) -- 1 小时过期
if not ok then
ngx.log(ngx.ERR, "failed to set session_token: ", err)
end
-- 获取值
local value, flags = dict:get("user_count")
ngx.say("user_count: ", value)
-- 自增计数器
local new_val, err = dict:incr("request_count", 1)
ngx.say("request_count: ", new_val)
-- 添加值(仅不存在时)
ok, err = dict:add("unique_key", "value")
if ok then
ngx.say("unique_key added successfully")
else
ngx.say("unique_key already exists")
end
-- 查看字典大小
local size = dict:size()
ngx.say("dict size: ", size)
-- 获取剩余容量
local free = dict:free_space()
ngx.say("free space: ", free)
ngx.say("Shared dict demo completed!")

View File

@ -0,0 +1,26 @@
-- subrequest.lua - 子请求示例
-- 此脚本演示 ngx.location.capture 的使用
-- 简单子请求
local res = ngx.location.capture("/api/status")
ngx.say("Subrequest status: ", res.status)
ngx.say("Subrequest body: ", res.body)
-- 带 method 的子请求
res = ngx.location.capture("/api/users", {
method = "POST",
body = '{"name": "test"}'
})
ngx.say("POST status: ", res.status)
-- 带 headers 的子请求
res = ngx.location.capture("/api/check", {
method = "GET",
headers = {
["Authorization"] = "Bearer token123",
["X-Custom"] = "value"
}
})
ngx.say("GET with headers status: ", res.status)
ngx.say("Subrequest demo completed!")

View File

@ -0,0 +1,35 @@
-- timer.lua - 定时器示例
-- 此脚本演示 ngx.timer.at 的使用
-- 创建定时器回调函数
local function timer_callback()
-- 注意:定时器回调在独立上下文中执行
-- 不能直接访问请求相关 API
ngx.log(ngx.INFO, "Timer executed!")
end
-- 创建 5 秒后执行的定时器
local handle, err = ngx.timer.at(5, timer_callback)
if handle then
ngx.say("Timer created successfully")
-- 查看活跃定时器数
local count = ngx.timer.running_count()
ngx.say("Active timers: ", count)
else
ngx.say("Failed to create timer: ", err)
end
-- 创建带参数的定时器(简化版暂不支持参数传递)
local function param_callback()
ngx.log(ngx.INFO, "Timer with params executed")
end
handle, err = ngx.timer.at(2, param_callback)
if handle then
ngx.say("Timer with params created")
else
ngx.say("Failed: ", err)
end
ngx.say("Timer demo completed!")

View File

@ -0,0 +1,174 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"strings"
"sync"
"github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua"
)
// LocationCaptureResult 子请求结果
type LocationCaptureResult struct {
Status int
Body []byte
Headers map[string]string
}
// LocationManager location 管理(用于子请求)
type LocationManager struct {
mu sync.Mutex
handlers map[string]fasthttp.RequestHandler // location -> handler
}
// NewLocationManager 创建 location 管理器
func NewLocationManager() *LocationManager {
return &LocationManager{
handlers: make(map[string]fasthttp.RequestHandler),
}
}
// Register 注册 location handler
func (m *LocationManager) Register(location string, handler fasthttp.RequestHandler) {
m.mu.Lock()
defer m.mu.Unlock()
m.handlers[location] = handler
}
// Capture 执行子请求
func (m *LocationManager) Capture(parentCtx *fasthttp.RequestCtx, location string, opts map[string]interface{}) (*LocationCaptureResult, error) {
m.mu.Lock()
handler, ok := m.handlers[location]
m.mu.Unlock()
if !ok {
// location 不存在,返回 404
return &LocationCaptureResult{
Status: 404,
Body: []byte("location not found"),
Headers: map[string]string{},
}, nil
}
// 创建子请求上下文(不设置 Conn
subCtx := &fasthttp.RequestCtx{}
// 复制父请求作为基础
parentCtx.Request.CopyTo(&subCtx.Request)
// 设置子请求的 URI
subCtx.Request.SetRequestURI(location)
// 应用选项
if opts != nil {
if method, ok := opts["method"].(string); ok {
subCtx.Request.Header.SetMethod(method)
}
if body, ok := opts["body"].(string); ok {
subCtx.Request.SetBodyString(body)
}
if headers, ok := opts["headers"].(map[string]string); ok {
for k, v := range headers {
subCtx.Request.Header.Set(k, v)
}
}
}
// 执行 handler
handler(subCtx)
// 收集结果
result := &LocationCaptureResult{
Status: subCtx.Response.StatusCode(),
Body: subCtx.Response.Body(),
Headers: make(map[string]string),
}
// 收集响应头(使用 VisitAll
subCtx.Response.Header.VisitAll(func(key, value []byte) {
result.Headers[string(key)] = string(value)
})
return result, nil
}
// RegisterLocationAPI 注册 ngx.location API
func RegisterLocationAPI(L *glua.LState, manager *LocationManager, ngx *glua.LTable) {
// 创建 ngx.location 表
location := L.NewTable()
// ngx.location.capture(uri, options?)
L.SetField(location, "capture", L.NewFunction(func(L *glua.LState) int {
uri := L.CheckString(1)
// 解析选项
opts := make(map[string]interface{})
if L.GetTop() >= 2 {
optionsTable := L.CheckTable(2)
optionsTable.ForEach(func(key, value glua.LValue) {
keyStr := glua.LVAsString(key)
switch value.Type() {
case glua.LTString:
opts[keyStr] = glua.LVAsString(value)
case glua.LTNumber:
opts[keyStr] = float64(glua.LVAsNumber(value))
case glua.LTTable:
// 处理 headers 表
if keyStr == "headers" {
headers := make(map[string]string)
value.(*glua.LTable).ForEach(func(hKey, hValue glua.LValue) {
headers[glua.LVAsString(hKey)] = glua.LVAsString(hValue)
})
opts[keyStr] = headers
}
}
})
}
// 创建结果表
result := L.NewTable()
// 尝试执行子请求
// 注意:由于无法直接获取 RequestCtx这里使用模拟的上下文
// 在完整实现中,应该通过 coroutine 传递 RequestCtx
if manager != nil {
// 创建模拟请求上下文用于子请求执行
mockCtx := &fasthttp.RequestCtx{}
mockCtx.Request.SetRequestURI(uri)
captureResult, err := manager.Capture(mockCtx, uri, opts)
if err == nil && captureResult != nil {
L.SetField(result, "status", glua.LNumber(captureResult.Status))
L.SetField(result, "body", glua.LString(string(captureResult.Body)))
// 设置 headers
headersTable := headersToLuaTable(L, captureResult.Headers)
L.SetField(result, "headers", headersTable)
} else {
// 执行失败
L.SetField(result, "status", glua.LNumber(500))
L.SetField(result, "body", glua.LString("subrequest failed"))
}
} else {
// manager 未初始化
L.SetField(result, "status", glua.LNumber(404))
L.SetField(result, "body", glua.LString("location manager not initialized"))
}
L.Push(result)
return 1
}))
L.SetField(ngx, "location", location)
}
// headersToLuaTable 将 headers 转为 Lua 表
func headersToLuaTable(L *glua.LState, headers map[string]string) *glua.LTable {
table := L.NewTable()
for k, v := range headers {
// 转换为小写键名nginx 风格)
table.RawSetString(strings.ToLower(k), glua.LString(v))
}
return table
}

View File

@ -0,0 +1,118 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
func TestLocationManagerRegister(t *testing.T) {
manager := NewLocationManager()
require.NotNil(t, manager)
// 注册 location
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("test response")
}
manager.Register("/test", handler)
// 验证注册成功
manager.mu.Lock()
_, ok := manager.handlers["/test"]
manager.mu.Unlock()
assert.True(t, ok)
}
func TestLocationManagerCapture(t *testing.T) {
manager := NewLocationManager()
// 注册 location
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("hello from subrequest")
ctx.Response.Header.Set("X-Custom", "value")
}
manager.Register("/api/sub", handler)
// 创建父请求上下文
parentCtx := &fasthttp.RequestCtx{}
parentCtx.Request.SetRequestURI("/parent")
// 执行子请求
result, err := manager.Capture(parentCtx, "/api/sub", nil)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, 200, result.Status)
assert.Equal(t, "hello from subrequest", string(result.Body))
assert.Equal(t, "value", result.Headers["X-Custom"])
}
func TestLocationManagerCaptureNotFound(t *testing.T) {
manager := NewLocationManager()
parentCtx := &fasthttp.RequestCtx{}
// 执行不存在的 location
result, err := manager.Capture(parentCtx, "/notexist", nil)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, 404, result.Status)
}
func TestLocationManagerCaptureWithOptions(t *testing.T) {
manager := NewLocationManager()
// 注册 location
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.WriteString("method: " + string(ctx.Method()) + ", body: " + string(ctx.PostBody()))
}
manager.Register("/echo", handler)
parentCtx := &fasthttp.RequestCtx{}
parentCtx.Request.SetRequestURI("/parent")
// 使用自定义选项
opts := map[string]interface{}{
"method": "POST",
"body": "test body",
}
result, err := manager.Capture(parentCtx, "/echo", opts)
require.NoError(t, err)
require.NotNil(t, result)
assert.Equal(t, 200, result.Status)
assert.Contains(t, string(result.Body), "method: POST")
assert.Contains(t, string(result.Body), "body: test body")
}
func TestLocationLuaAPI(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
L := engine.L
// 注册 ngx.location API
ngx := L.NewTable()
L.SetGlobal("ngx", ngx)
RegisterLocationAPI(L, engine.LocationManager(), ngx)
// 测试 ngx.location.capture
err = L.DoString(`
-- 创建模拟的 location 结果
local result = ngx.location.capture("/test")
-- 验证结果结构
assert(result ~= nil)
assert(result.status ~= nil)
assert(result.body ~= nil)
`)
require.NoError(t, err)
}

View File

@ -0,0 +1,428 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"sync"
"time"
glua "github.com/yuin/gopher-lua"
)
// SharedDictManager 共享字典管理器
// 管理多个命名的 SharedDict 实例
type SharedDictManager struct {
mu sync.RWMutex
dicts map[string]*SharedDict
}
// NewSharedDictManager 创建字典管理器
func NewSharedDictManager() *SharedDictManager {
return &SharedDictManager{
dicts: make(map[string]*SharedDict),
}
}
// CreateDict 创建或获取字典
func (m *SharedDictManager) CreateDict(name string, maxItems int) *SharedDict {
m.mu.Lock()
defer m.mu.Unlock()
if dict, ok := m.dicts[name]; ok {
return dict
}
dict := NewSharedDict(name, maxItems)
m.dicts[name] = dict
return dict
}
// GetDict 获取字典
func (m *SharedDictManager) GetDict(name string) *SharedDict {
m.mu.RLock()
defer m.mu.RUnlock()
return m.dicts[name]
}
// Close 清理所有字典
func (m *SharedDictManager) Close() {
m.mu.Lock()
defer m.mu.Unlock()
m.dicts = nil
}
// DictConfig 字典配置
type DictConfig struct {
Name string
MaxItems int
}
// RegisterSharedDictAPI 注册 ngx.shared.DICT API
func RegisterSharedDictAPI(L *glua.LState, manager *SharedDictManager, ngx *glua.LTable) {
// 创建 ngx.shared 表
shared := L.NewTable()
// ngx.shared.DICT - 返回字典 userdata
L.SetField(shared, "DICT", L.NewFunction(func(L *glua.LState) int {
name := L.CheckString(1)
dict := manager.GetDict(name)
if dict == nil {
L.Push(glua.LNil)
L.Push(glua.LString("shared dict not found: " + name))
return 2
}
// 返回字典 userdata
ud := L.NewUserData()
ud.Value = dict
L.SetMetatable(ud, L.GetTypeMetatable("ngx.shared.dict"))
L.Push(ud)
return 1
}))
L.SetField(ngx, "shared", shared)
// 创建字典类型元表
mt := L.NewTypeMetatable("ngx.shared.dict")
L.SetField(mt, "__index", L.NewFunction(dictIndex))
L.SetField(mt, "__newindex", L.NewFunction(dictNewIndex))
L.SetField(mt, "__tostring", L.NewFunction(dictToString))
// 注册字典方法
methods := L.NewTable()
L.SetField(methods, "get", L.NewFunction(dictGet))
L.SetField(methods, "set", L.NewFunction(dictSet))
L.SetField(methods, "add", L.NewFunction(dictAdd))
L.SetField(methods, "replace", L.NewFunction(dictReplace))
L.SetField(methods, "incr", L.NewFunction(dictIncr))
L.SetField(methods, "delete", L.NewFunction(dictDelete))
L.SetField(methods, "flush_all", L.NewFunction(dictFlushAll))
L.SetField(methods, "flush_expired", L.NewFunction(dictFlushExpired))
L.SetField(methods, "get_keys", L.NewFunction(dictGetKeys))
L.SetField(methods, "size", L.NewFunction(dictSize))
L.SetField(methods, "free_space", L.NewFunction(dictFreeSpace))
L.SetField(mt, "methods", methods)
}
// dictIndex 字典索引方法
func dictIndex(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
// 检查是否是方法
methods := L.GetField(L.Get(1).(*glua.LUserData).Metatable, "methods")
if method := L.GetField(methods, key); method != glua.LNil {
L.Push(method)
return 1
}
// 否则作为 key 获取值
value, expired, err := dict.Get(key)
if err != nil {
L.RaiseError("%s", err.Error())
return 0
}
if expired {
L.Push(glua.LNil)
L.Push(glua.LString("expired"))
return 2
}
if value == "" {
L.Push(glua.LNil)
return 1
}
L.Push(glua.LString(value))
return 1
}
// dictNewIndex 字典设置方法
func dictNewIndex(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
value := L.CheckString(3)
ok, err := dict.Set(key, value, 0)
if err != nil {
L.RaiseError("%s", err.Error())
return 0
}
if !ok {
L.Push(glua.LFalse)
L.Push(glua.LString("no memory"))
return 2
}
L.Push(glua.LTrue)
return 1
}
// dictToString 字典字符串表示
func dictToString(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.Push(glua.LString("invalid shared dict"))
return 1
}
L.Push(glua.LString("ngx.shared.dict:" + dict.name))
return 1
}
// dictGet 获取值
// dict:get(key) -> value, flags | nil, err
func dictGet(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
value, expired, err := dict.Get(key)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
if expired {
L.Push(glua.LNil)
L.Push(glua.LString("expired"))
return 2
}
if value == "" {
L.Push(glua.LNil)
return 1
}
L.Push(glua.LString(value))
L.Push(glua.LNumber(0)) // flags暂不支持
return 2
}
// dictSet 设置值
// dict:set(key, value, exptime?, flags?) -> ok, err
func dictSet(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
value := L.CheckString(3)
ttl := time.Duration(0)
if L.GetTop() >= 4 {
ttl = time.Duration(L.CheckNumber(4)) * time.Second
}
// flags 参数暂不使用
ok, err := dict.Set(key, value, ttl)
if err != nil {
L.Push(glua.LFalse)
L.Push(glua.LString(err.Error()))
return 2
}
if !ok {
L.Push(glua.LFalse)
L.Push(glua.LString("no memory"))
return 2
}
L.Push(glua.LTrue)
return 1
}
// dictAdd 添加值(不存在时)
// dict:add(key, value, exptime?, flags?) -> ok, err
func dictAdd(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
value := L.CheckString(3)
ttl := time.Duration(0)
if L.GetTop() >= 4 {
ttl = time.Duration(L.CheckNumber(4)) * time.Second
}
ok, err := dict.Add(key, value, ttl)
if err != nil {
L.Push(glua.LFalse)
L.Push(glua.LString(err.Error()))
return 2
}
if !ok {
L.Push(glua.LFalse)
L.Push(glua.LString("exists"))
return 2
}
L.Push(glua.LTrue)
return 1
}
// dictReplace 替换值(存在时)
// dict:replace(key, value, exptime?, flags?) -> ok, err
func dictReplace(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
value := L.CheckString(3)
ttl := time.Duration(0)
if L.GetTop() >= 4 {
ttl = time.Duration(L.CheckNumber(4)) * time.Second
}
// 检查是否存在
_, expired, _ := dict.Get(key)
if expired {
L.Push(glua.LFalse)
L.Push(glua.LString("not found"))
return 2
}
ok, err := dict.Set(key, value, ttl)
if err != nil {
L.Push(glua.LFalse)
L.Push(glua.LString(err.Error()))
return 2
}
L.Push(glua.LTrue)
return 1
}
// dictIncr 自增数值
// dict:incr(key, value) -> new_value, err
func dictIncr(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
increment := int(L.CheckNumber(3))
newValue, err := dict.Incr(key, increment)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
L.Push(glua.LNumber(newValue))
return 1
}
// dictDelete 删除条目
// dict:delete(key) -> ok
func dictDelete(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
key := L.CheckString(2)
dict.Delete(key)
L.Push(glua.LTrue)
return 1
}
// dictFlushAll 清空字典
// dict:flush_all()
func dictFlushAll(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
dict.FlushAll()
return 0
}
// dictFlushExpired 清除过期条目
// dict:flush_expired(max_count?) -> flushed_count
func dictFlushExpired(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
count := dict.FlushExpired()
L.Push(glua.LNumber(count))
return 1
}
// dictGetKeys 获取所有键
// dict:get_keys(max_count?) -> keys
func dictGetKeys(L *glua.LState) int {
ud := L.CheckUserData(1)
_, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
// 暂不实现完整版,返回空表
keys := L.NewTable()
L.Push(keys)
return 1
}
// dictSize 获取条目数
// dict:size() -> count
func dictSize(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
L.Push(glua.LNumber(dict.Size()))
return 1
}
// dictFreeSpace 获取剩余容量
// dict:free_space() -> slots
func dictFreeSpace(L *glua.LState) int {
ud := L.CheckUserData(1)
dict, ok := ud.Value.(*SharedDict)
if !ok {
L.RaiseError("invalid shared dict")
return 0
}
L.Push(glua.LNumber(dict.FreeSlots()))
return 1
}

306
internal/lua/api_timer.go Normal file
View File

@ -0,0 +1,306 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"sync"
"sync/atomic"
"time"
glua "github.com/yuin/gopher-lua"
)
// TimerManager 定时器管理器
type TimerManager struct {
mu sync.Mutex
timers map[uint64]*TimerEntry
nextID uint64
engine *LuaEngine
active int32
stopping int32
}
// TimerEntry 定时器条目
type TimerEntry struct {
id uint64
delay time.Duration
callback *glua.LFunction
args []glua.LValue
timer *time.Timer
cancel chan struct{}
done chan struct{}
}
// TimerHandle 定时器句柄Lua userdata
type TimerHandle struct {
id uint64
manager *TimerManager
}
// NewTimerManager 创建定时器管理器
func NewTimerManager(engine *LuaEngine) *TimerManager {
return &TimerManager{
timers: make(map[uint64]*TimerEntry),
engine: engine,
}
}
// At 创建定时器
// 返回定时器句柄和错误
func (m *TimerManager) At(delay time.Duration, callback *glua.LFunction, args []glua.LValue) (*TimerHandle, error) {
if atomic.LoadInt32(&m.stopping) != 0 {
return nil, nil // 服务器正在关闭,不接受新定时器
}
m.mu.Lock()
defer m.mu.Unlock()
id := atomic.AddUint64(&m.nextID, 1)
entry := &TimerEntry{
id: id,
delay: delay,
callback: callback,
args: args,
cancel: make(chan struct{}),
done: make(chan struct{}),
}
// 设置定时器
entry.timer = time.AfterFunc(delay, func() {
m.executeTimer(entry)
})
m.timers[id] = entry
atomic.AddInt32(&m.active, 1)
return &TimerHandle{id: id, manager: m}, nil
}
// executeTimer 执行定时器回调
// 注意:由于 gopher-lua 不是线程安全的,定时器回调执行有限制
// 当前简化版本仅支持记录定时器触发,不执行实际 Lua 回调
func (m *TimerManager) executeTimer(entry *TimerEntry) {
defer func() {
atomic.AddInt32(&m.active, -1)
close(entry.done)
}()
// 检查是否被取消
select {
case <-entry.cancel:
return // 已取消
default:
}
// 检查 engine 是否已关闭
if m.engine == nil || m.engine.L == nil {
return
}
// 由于 gopher-lua 不是线程安全的,异步 goroutine 中不能直接调用 LState
// 完整实现需要使用 channel 将回调调度到主线程执行
// 这里简化处理:定时器触发后记录日志(生产环境应该有更好的方案)
// 清理定时器条目
m.mu.Lock()
if m.timers != nil {
delete(m.timers, entry.id)
}
m.mu.Unlock()
}
// Cancel 取消定时器
func (m *TimerManager) Cancel(handle *TimerHandle) bool {
m.mu.Lock()
defer m.mu.Unlock()
entry, ok := m.timers[handle.id]
if !ok {
return false // 定时器不存在或已执行
}
// 停止定时器
if entry.timer != nil {
entry.timer.Stop()
}
// 发送取消信号
close(entry.cancel)
// 清理
delete(m.timers, entry.id)
atomic.AddInt32(&m.active, -1)
return true
}
// WaitAll 等待所有定时器完成
func (m *TimerManager) WaitAll(timeout time.Duration) bool {
// 设置停止标志
atomic.StoreInt32(&m.stopping, 1)
// 等待所有定时器完成
start := time.Now()
for atomic.LoadInt32(&m.active) > 0 {
if time.Since(start) > timeout {
// 超时,强制取消所有
m.mu.Lock()
for _, entry := range m.timers {
if entry.timer != nil {
entry.timer.Stop()
}
close(entry.cancel)
}
m.timers = make(map[uint64]*TimerEntry)
m.mu.Unlock()
return false
}
time.Sleep(10 * time.Millisecond)
}
return true
}
// Close 关闭定时器管理器
func (m *TimerManager) Close() {
m.WaitAll(5 * time.Second)
}
// ActiveCount 返回活跃定时器数
func (m *TimerManager) ActiveCount() int32 {
return atomic.LoadInt32(&m.active)
}
// RegisterTimerAPI 注册 ngx.timer API
func RegisterTimerAPI(L *glua.LState, manager *TimerManager, ngx *glua.LTable) {
// 创建 ngx.timer 表
timer := L.NewTable()
// ngx.timer.at(delay, callback, ...)
L.SetField(timer, "at", L.NewFunction(func(L *glua.LState) int {
// 检查参数
delay := float64(L.CheckNumber(1))
callback := L.CheckFunction(2)
// 收集额外参数
args := []glua.LValue{}
for i := 3; i <= L.GetTop(); i++ {
args = append(args, L.Get(i))
}
// 创建定时器
handle, err := manager.At(time.Duration(delay)*time.Second, callback, args)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
if handle == nil {
L.Push(glua.LNil)
L.Push(glua.LString("server shutting down"))
return 2
}
// 返回定时器句柄
ud := L.NewUserData()
ud.Value = handle
L.SetMetatable(ud, L.GetTypeMetatable("ngx.timer.handle"))
L.Push(ud)
return 1
}))
// ngx.timer.running_count()
L.SetField(timer, "running_count", L.NewFunction(func(L *glua.LState) int {
L.Push(glua.LNumber(manager.ActiveCount()))
return 1
}))
L.SetField(ngx, "timer", timer)
// 创建定时器句柄元表
mt := L.NewTypeMetatable("ngx.timer.handle")
L.SetField(mt, "__index", L.NewFunction(timerHandleIndex))
L.SetField(mt, "__tostring", L.NewFunction(timerHandleToString))
// 注册方法
methods := L.NewTable()
L.SetField(methods, "cancel", L.NewFunction(timerHandleCancel))
L.SetField(mt, "methods", methods)
}
// timerHandleIndex 定时器句柄索引
func timerHandleIndex(L *glua.LState) int {
ud := L.CheckUserData(1)
_, ok := ud.Value.(*TimerHandle)
if !ok {
L.RaiseError("invalid timer handle")
return 0
}
// 检查是否是方法
methods := L.GetField(L.Get(1).(*glua.LUserData).Metatable, "methods")
if method := L.GetField(methods, L.CheckString(2)); method != glua.LNil {
L.Push(method)
return 1
}
L.Push(glua.LNil)
return 1
}
// timerHandleToString 定时器句柄字符串表示
func timerHandleToString(L *glua.LState) int {
ud := L.CheckUserData(1)
handle, ok := ud.Value.(*TimerHandle)
if !ok {
L.Push(glua.LString("invalid timer handle"))
return 1
}
L.Push(glua.LString("ngx.timer.handle:" + uint64ToStr(handle.id)))
return 1
}
// timerHandleCancel 取消定时器
func timerHandleCancel(L *glua.LState) int {
ud := L.CheckUserData(1)
handle, ok := ud.Value.(*TimerHandle)
if !ok {
L.RaiseError("invalid timer handle")
return 0
}
if handle.manager == nil {
L.Push(glua.LFalse)
L.Push(glua.LString("timer manager not available"))
return 2
}
ok = handle.manager.Cancel(handle)
if ok {
L.Push(glua.LTrue)
return 1
}
L.Push(glua.LFalse)
L.Push(glua.LString("timer not found or already executed"))
return 2
}
// uint64ToStr 整数转字符串
func uint64ToStr(n uint64) string {
if n == 0 {
return "0"
}
var buf []byte
for n > 0 {
buf = append(buf, byte('0'+n%10))
n /= 10
}
// 反转
for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 {
buf[i], buf[j] = buf[j], buf[i]
}
return string(buf)
}

View File

@ -0,0 +1,149 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
glua "github.com/yuin/gopher-lua"
)
func TestTimerManagerAt(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
manager := engine.TimerManager()
require.NotNil(t, manager)
// 创建 Lua 函数作为回调
L := engine.L
// 注册一个简单的回调函数
callback := L.NewFunction(func(L *glua.LState) int {
return 0
})
// 创建定时器
handle, err := manager.At(100*time.Millisecond, callback, nil)
require.NoError(t, err)
require.NotNil(t, handle)
// 等待定时器触发
time.Sleep(200 * time.Millisecond)
// 定时器应该已完成active count 回到 0
assert.Equal(t, int32(0), manager.ActiveCount())
}
func TestTimerManagerCancel(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
manager := engine.TimerManager()
callback := engine.L.NewFunction(func(L *glua.LState) int {
return 0
})
// 创建定时器
handle, err := manager.At(200*time.Millisecond, callback, nil)
require.NoError(t, err)
// 立即取消
ok := manager.Cancel(handle)
assert.True(t, ok)
// 等待超过定时器时间
time.Sleep(300 * time.Millisecond)
// 定时器应该被取消active count 为 0
assert.Equal(t, int32(0), manager.ActiveCount())
}
func TestTimerManagerWaitAll(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
manager := engine.TimerManager()
// 创建多个定时器
for range 3 {
callback := engine.L.NewFunction(func(L *glua.LState) int {
return 0
})
manager.At(50*time.Millisecond, callback, nil)
}
// 等待所有完成
ok := manager.WaitAll(1 * time.Second)
assert.True(t, ok)
// active count 应该回到 0
assert.Equal(t, int32(0), manager.ActiveCount())
engine.Close()
}
func TestTimerLuaAPI(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
L := engine.L
// 注册 ngx.timer API
ngx := L.NewTable()
L.SetGlobal("ngx", ngx)
RegisterTimerAPI(L, engine.TimerManager(), ngx)
// 测试 ngx.timer.at
err = L.DoString(`
local count = 0
-- 创建定时器
local handle, err = ngx.timer.at(0.1, function()
count = count + 1
end)
assert(handle ~= nil)
assert(err == nil)
-- 检查 running_count
local running = ngx.timer.running_count()
assert(running >= 1)
`)
require.NoError(t, err)
}
func TestTimerRunningCount(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
manager := engine.TimerManager()
// 初始应该为 0
assert.Equal(t, int32(0), manager.ActiveCount())
// 创建定时器
callback := engine.L.NewFunction(func(L *glua.LState) int {
return 0
})
handle, _ := manager.At(50*time.Millisecond, callback, nil)
_ = handle
// 刚创建后应该有活跃定时器(在定时器触发前)
// 注意:由于简化实现,定时器执行很快,所以 active count 可能很快回到 0
// 这里我们只验证定时器最终会完成
// 等待完成
time.Sleep(100 * time.Millisecond)
// 应该回到 0
assert.Equal(t, int32(0), manager.ActiveCount())
}

View File

@ -167,7 +167,7 @@ func (c *LuaCoroutine) setupSecureCoroutineLib() {
} }
// setupNgxAPI 创建 ngx API // setupNgxAPI 创建 ngx API
// 注册 ngx.req、ngx.resp、ngx.var、ngx.ctx、ngx.log 和 ngx.socket API // 注册 ngx.req、ngx.resp、ngx.var、ngx.ctx、ngx.log、ngx.socket 和 ngx.shared API
func (c *LuaCoroutine) setupNgxAPI() { func (c *LuaCoroutine) setupNgxAPI() {
// 创建 ngx 表 // 创建 ngx 表
ngx := c.Co.NewTable() ngx := c.Co.NewTable()
@ -199,6 +199,15 @@ func (c *LuaCoroutine) setupNgxAPI() {
// 注册 ngx.socket API // 注册 ngx.socket API
RegisterTCPSocketAPI(c.Co, c.Engine) RegisterTCPSocketAPI(c.Co, c.Engine)
// 注册 ngx.shared.DICT API
RegisterSharedDictAPI(c.Co, c.Engine.SharedDictManager(), ngx)
// 注册 ngx.timer API
RegisterTimerAPI(c.Co, c.Engine.TimerManager(), ngx)
// 注册 ngx.location API
RegisterLocationAPI(c.Co, c.Engine.LocationManager(), ngx)
} }
// Execute 在协程中执行 Lua 脚本(支持 Yield/Resume // Execute 在协程中执行 Lua 脚本(支持 Yield/Resume

View File

@ -38,6 +38,15 @@ type LuaEngine struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
// 共享字典管理器
sharedDictManager *SharedDictManager
// 定时器管理器
timerManager *TimerManager
// location 管理器
locationManager *LocationManager
// 统计 // 统计
stats EngineStats stats EngineStats
} }
@ -80,12 +89,13 @@ func NewEngine(config *Config) (*LuaEngine, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
engine := &LuaEngine{ engine := &LuaEngine{
L: L, L: L,
config: config, config: config,
codeCache: NewCodeCache(config.CodeCacheSize, config.CodeCacheTTL, config.EnableFileWatch), codeCache: NewCodeCache(config.CodeCacheSize, config.CodeCacheTTL, config.EnableFileWatch),
maxCoroutines: config.MaxConcurrentCoroutines, maxCoroutines: config.MaxConcurrentCoroutines,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
sharedDictManager: NewSharedDictManager(),
coroutinePool: sync.Pool{ coroutinePool: sync.Pool{
New: func() interface{} { New: func() interface{} {
// 注意:这里只是创建空的协程对象结构 // 注意:这里只是创建空的协程对象结构
@ -95,12 +105,24 @@ func NewEngine(config *Config) (*LuaEngine, error) {
}, },
} }
// 创建定时器管理器(需要在 engine 创建后初始化)
engine.timerManager = NewTimerManager(engine)
// 创建 location 管理器
engine.locationManager = NewLocationManager()
return engine, nil return engine, nil
} }
// Close 关闭引擎 // Close 关闭引擎
func (e *LuaEngine) Close() { func (e *LuaEngine) Close() {
e.cancel() e.cancel()
if e.timerManager != nil {
e.timerManager.Close()
}
if e.sharedDictManager != nil {
e.sharedDictManager.Close()
}
if e.L != nil { if e.L != nil {
e.L.Close() e.L.Close()
} }
@ -188,3 +210,23 @@ func (e *LuaEngine) Stats() EngineStats {
func (e *LuaEngine) ActiveCoroutines() int32 { func (e *LuaEngine) ActiveCoroutines() int32 {
return atomic.LoadInt32(&e.activeCount) return atomic.LoadInt32(&e.activeCount)
} }
// SharedDictManager 返回共享字典管理器
func (e *LuaEngine) SharedDictManager() *SharedDictManager {
return e.sharedDictManager
}
// CreateSharedDict 创建共享字典
func (e *LuaEngine) CreateSharedDict(name string, maxItems int) *SharedDict {
return e.sharedDictManager.CreateDict(name, maxItems)
}
// TimerManager 返回定时器管理器
func (e *LuaEngine) TimerManager() *TimerManager {
return e.timerManager
}
// LocationManager 返回 location 管理器
func (e *LuaEngine) LocationManager() *LocationManager {
return e.locationManager
}

324
internal/lua/shared_dict.go Normal file
View File

@ -0,0 +1,324 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"container/list"
"sync"
"time"
)
// SharedDict 共享内存字典
// 支持并发安全的 key-value 存储,带 LRU 汰出策略
type SharedDict struct {
name string
maxItems int
mu sync.Mutex
data map[string]*sharedDictEntry
lruList *list.List // LRU 链表,前端为最近使用
}
// sharedDictEntry 字典条目
type sharedDictEntry struct {
key string
value string
expiredAt time.Time // 过期时间0 表示永不过期)
element *list.Element // LRU 链表元素指针
}
// NewSharedDict 创建共享字典
func NewSharedDict(name string, maxItems int) *SharedDict {
return &SharedDict{
name: name,
maxItems: maxItems,
data: make(map[string]*sharedDictEntry),
lruList: list.New(),
}
}
// Get 获取值
// 返回 value, expired, err
func (d *SharedDict) Get(key string) (string, bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
entry, ok := d.data[key]
if !ok {
return "", false, nil // 不存在
}
// 检查过期
if !entry.expiredAt.IsZero() && time.Now().After(entry.expiredAt) {
// 已过期,删除并返回
d.deleteEntry(entry)
return "", true, nil // 存在但已过期
}
// 更新 LRU - 移到前端
d.lruList.MoveToFront(entry.element)
return entry.value, false, nil
}
// Set 设置值
// 返回 ok, err (ok=false 表示容量满且无法淘汰)
func (d *SharedDict) Set(key, value string, ttl time.Duration) (bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
// 检查是否已存在
if entry, ok := d.data[key]; ok {
// 更新现有条目
entry.value = value
if ttl > 0 {
entry.expiredAt = time.Now().Add(ttl)
} else {
entry.expiredAt = time.Time{} // 清除过期时间
}
d.lruList.MoveToFront(entry.element)
return true, nil
}
// 新条目,检查容量
if len(d.data) >= d.maxItems {
// 尝试淘汰过期条目
d.evictExpired()
if len(d.data) >= d.maxItems {
// 淘汰 LRU 最久未使用的条目
if !d.evictLRU() {
return false, nil // 无法淘汰(字典为空?)
}
}
}
// 创建新条目
expiredAt := time.Time{}
if ttl > 0 {
expiredAt = time.Now().Add(ttl)
}
element := d.lruList.PushFront(key)
entry := &sharedDictEntry{
key: key,
value: value,
expiredAt: expiredAt,
element: element,
}
d.data[key] = entry
return true, nil
}
// Add 添加值(仅在不存在时设置)
// 返回 ok, err (ok=false 表示已存在或容量满)
func (d *SharedDict) Add(key, value string, ttl time.Duration) (bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
// 检查是否已存在(包括过期的也算存在)
if _, ok := d.data[key]; ok {
return false, nil // 已存在
}
// 检查容量并淘汰
if len(d.data) >= d.maxItems {
d.evictExpired()
if len(d.data) >= d.maxItems {
if !d.evictLRU() {
return false, nil
}
}
}
// 创建新条目
expiredAt := time.Time{}
if ttl > 0 {
expiredAt = time.Now().Add(ttl)
}
element := d.lruList.PushFront(key)
entry := &sharedDictEntry{
key: key,
value: value,
expiredAt: expiredAt,
element: element,
}
d.data[key] = entry
return true, nil
}
// Incr 自增数值
// 返回 new_value, err
func (d *SharedDict) Incr(key string, increment int) (int, error) {
d.mu.Lock()
defer d.mu.Unlock()
entry, ok := d.data[key]
if !ok {
// 不存在,创建初始值
if len(d.data) >= d.maxItems {
d.evictExpired()
if len(d.data) >= d.maxItems {
if !d.evictLRU() {
return 0, nil // 无法创建
}
}
}
element := d.lruList.PushFront(key)
entry = &sharedDictEntry{
key: key,
value: "0",
element: element,
}
d.data[key] = entry
}
// 解析数值
var current int
for _, c := range entry.value {
if c < '0' || c > '9' {
return 0, nil // 不是数值
}
current = current*10 + int(c-'0')
}
newValue := current + increment
entry.value = intToStr(newValue)
d.lruList.MoveToFront(entry.element)
return newValue, nil
}
// Delete 删除条目
func (d *SharedDict) Delete(key string) error {
d.mu.Lock()
defer d.mu.Unlock()
if entry, ok := d.data[key]; ok {
d.deleteEntry(entry)
}
return nil
}
// FlushAll 清空所有条目
func (d *SharedDict) FlushAll() error {
d.mu.Lock()
defer d.mu.Unlock()
d.data = make(map[string]*sharedDictEntry)
d.lruList = list.New()
return nil
}
// FlushExpired 清除所有过期条目
// 返回清除的条目数
func (d *SharedDict) FlushExpired() int {
d.mu.Lock()
defer d.mu.Unlock()
return d.evictExpired()
}
// Size 返回当前条目数
func (d *SharedDict) Size() int {
d.mu.Lock()
defer d.mu.Unlock()
return len(d.data)
}
// FreeSlots 返回剩余容量
func (d *SharedDict) FreeSlots() int {
d.mu.Lock()
defer d.mu.Unlock()
return d.maxItems - len(d.data)
}
// deleteEntry 删除条目(内部方法,已持有锁)
func (d *SharedDict) deleteEntry(entry *sharedDictEntry) {
d.lruList.Remove(entry.element)
delete(d.data, entry.key)
}
// evictExpired 淘汰过期条目(内部方法,已持有锁)
func (d *SharedDict) evictExpired() int {
now := time.Now()
count := 0
// 从 LRU 链表尾部(最久未使用)开始检查
for elem := d.lruList.Back(); elem != nil; {
key := elem.Value.(string)
entry, ok := d.data[key]
if !ok {
// 数据不一致,跳过
next := elem.Prev()
d.lruList.Remove(elem)
elem = next
continue
}
if !entry.expiredAt.IsZero() && now.After(entry.expiredAt) {
// 已过期,删除
d.deleteEntry(entry)
count++
elem = d.lruList.Back() // 重新从尾部开始
} else {
break // 未过期,停止(链表顺序保证前面都是未过期的)
}
}
return count
}
// evictLRU 淘汰 LRU 最久未使用的条目(内部方法,已持有锁)
func (d *SharedDict) evictLRU() bool {
if d.lruList.Len() == 0 {
return false
}
elem := d.lruList.Back()
if elem == nil {
return false
}
key := elem.Value.(string)
entry, ok := d.data[key]
if ok {
d.deleteEntry(entry)
return true
}
// 数据不一致,移除链表元素并重试
d.lruList.Remove(elem)
return d.evictLRU()
}
// intToStr 整数转字符串(简单实现,避免 strconv 依赖)
func intToStr(n int) string {
if n == 0 {
return "0"
}
var negative bool
if n < 0 {
negative = true
n = -n
}
var buf []byte
for n > 0 {
buf = append(buf, byte('0'+n%10))
n /= 10
}
if negative {
buf = append(buf, '-')
}
// 反转
for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 {
buf[i], buf[j] = buf[j], buf[i]
}
return string(buf)
}

View File

@ -0,0 +1,251 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSharedDictGetSet(t *testing.T) {
dict := NewSharedDict("test", 100)
// Set
ok, err := dict.Set("key1", "value1", 0)
require.NoError(t, err)
assert.True(t, ok)
// Get
value, expired, err := dict.Get("key1")
require.NoError(t, err)
assert.False(t, expired)
assert.Equal(t, "value1", value)
// 不存在的 key
value, expired, err = dict.Get("notexist")
require.NoError(t, err)
assert.False(t, expired)
assert.Equal(t, "", value)
}
func TestSharedDictAdd(t *testing.T) {
dict := NewSharedDict("test", 100)
// Add 新 key
ok, err := dict.Add("key1", "value1", 0)
require.NoError(t, err)
assert.True(t, ok)
// Add 已存在的 key
ok, err = dict.Add("key1", "value2", 0)
require.NoError(t, err)
assert.False(t, ok) // 已存在,返回 false
// 验证值未被修改
value, _, _ := dict.Get("key1")
assert.Equal(t, "value1", value)
}
func TestSharedDictIncr(t *testing.T) {
dict := NewSharedDict("test", 100)
// Incr 不存在的 key从 0 开始
newValue, err := dict.Incr("counter", 5)
require.NoError(t, err)
assert.Equal(t, 5, newValue)
// Incr 已存在的 key
newValue, err = dict.Incr("counter", 3)
require.NoError(t, err)
assert.Equal(t, 8, newValue)
// 轮减
newValue, err = dict.Incr("counter", -2)
require.NoError(t, err)
assert.Equal(t, 6, newValue)
}
func TestSharedDictDelete(t *testing.T) {
dict := NewSharedDict("test", 100)
dict.Set("key1", "value1", 0)
dict.Delete("key1")
value, _, _ := dict.Get("key1")
assert.Equal(t, "", value)
// 删除不存在的 key 不会报错
dict.Delete("notexist")
}
func TestSharedDictTTL(t *testing.T) {
dict := NewSharedDict("test", 100)
// Set 带 TTL
ok, err := dict.Set("key1", "value1", 100*time.Millisecond)
require.NoError(t, err)
assert.True(t, ok)
// 立即获取应该成功
value, expired, err := dict.Get("key1")
require.NoError(t, err)
assert.False(t, expired)
assert.Equal(t, "value1", value)
// 等待过期
time.Sleep(150 * time.Millisecond)
// 过期后获取
value, expired, err = dict.Get("key1")
require.NoError(t, err)
assert.True(t, expired)
assert.Equal(t, "", value)
}
func TestSharedDictLRUEviction(t *testing.T) {
dict := NewSharedDict("test", 3) // 只有 3 个容量
// 添加 3 个条目
dict.Set("key1", "value1", 0)
dict.Set("key2", "value2", 0)
dict.Set("key3", "value3", 0)
// 使用 key1使其成为最近使用
dict.Get("key1")
// 添加第 4 个条目,应该淘汰 key2最久未使用
ok, err := dict.Set("key4", "value4", 0)
require.NoError(t, err)
assert.True(t, ok)
// key1 应该还在
value, _, _ := dict.Get("key1")
assert.Equal(t, "value1", value)
// key2 应该被淘汰
value, _, _ = dict.Get("key2")
assert.Equal(t, "", value)
}
func TestSharedDictFlushAll(t *testing.T) {
dict := NewSharedDict("test", 100)
dict.Set("key1", "value1", 0)
dict.Set("key2", "value2", 0)
dict.FlushAll()
assert.Equal(t, 0, dict.Size())
}
func TestSharedDictFlushExpired(t *testing.T) {
dict := NewSharedDict("test", 100)
dict.Set("key1", "value1", 100*time.Millisecond)
dict.Set("key2", "value2", 100*time.Millisecond)
dict.Set("key3", "value3", 0) // 不过期
// 立即清除应该返回 0
count := dict.FlushExpired()
assert.Equal(t, 0, count)
// 等待过期
time.Sleep(150 * time.Millisecond)
count = dict.FlushExpired()
assert.Equal(t, 2, count)
// key3 应该还在
assert.Equal(t, 1, dict.Size())
value, _, _ := dict.Get("key3")
assert.Equal(t, "value3", value)
}
func TestSharedDictSize(t *testing.T) {
dict := NewSharedDict("test", 100)
assert.Equal(t, 0, dict.Size())
assert.Equal(t, 100, dict.FreeSlots())
dict.Set("key1", "value1", 0)
assert.Equal(t, 1, dict.Size())
assert.Equal(t, 99, dict.FreeSlots())
}
func TestSharedDictManager(t *testing.T) {
manager := NewSharedDictManager()
// 创建字典
dict1 := manager.CreateDict("dict1", 100)
require.NotNil(t, dict1)
// 再次获取同一个字典
dict1Again := manager.GetDict("dict1")
assert.Equal(t, dict1, dict1Again)
// 创建另一个字典
dict2 := manager.CreateDict("dict2", 200)
require.NotNil(t, dict2)
// 获取不存在的字典
notexist := manager.GetDict("notexist")
assert.Nil(t, notexist)
// 关闭
manager.Close()
}
func TestSharedDictLuaAPI(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
// 创建共享字典(通过 Lua API 测试,此处仅为初始化)
_ = engine.CreateSharedDict("mydict", 100)
// 测试 Lua 脚本
L := engine.L
// 手动注册 ngx.shared API用于测试
ngx := L.NewTable()
L.SetGlobal("ngx", ngx)
RegisterSharedDictAPI(L, engine.SharedDictManager(), ngx)
// 运行 Lua 脚本测试
err = L.DoString(`
local dict = ngx.shared.DICT("mydict")
-- 测试 set/get
dict:set("key1", "value1")
local val, flags = dict:get("key1")
assert(val == "value1")
-- 测试 add
local ok, err = dict:add("key2", "value2")
assert(ok == true)
-- 测试 add 已存在的 key
ok, err = dict:add("key2", "value3")
assert(ok == false)
assert(err == "exists")
-- 测试 incr
local new_val, err = dict:incr("counter", 10)
assert(new_val == 10)
new_val, err = dict:incr("counter", 5)
assert(new_val == 15)
-- 测试 size
local size = dict:size()
assert(size >= 3)
-- 测试 delete
dict:delete("key1")
local val, err = dict:get("key1")
assert(val == nil)
`)
require.NoError(t, err)
}