Merge branch 'master' of https://github.com/DefectingCat/lolly
This commit is contained in:
commit
5dbede27f8
43
examples/lua-scripts/shared_dict.lua
Normal file
43
examples/lua-scripts/shared_dict.lua
Normal file
@ -0,0 +1,43 @@
|
||||
-- shared_dict.lua - 共享字典示例
|
||||
-- 此脚本演示 ngx.shared.DICT 的使用
|
||||
|
||||
-- 获取共享字典(需要在配置中预先定义)
|
||||
local dict = ngx.shared.DICT("my_cache")
|
||||
|
||||
-- 设置值
|
||||
local ok, err = dict:set("user_count", "100")
|
||||
if not ok then
|
||||
ngx.log(ngx.ERR, "failed to set user_count: ", err)
|
||||
end
|
||||
|
||||
-- 设置带 TTL 的值
|
||||
ok, err = dict:set("session_token", "abc123", 3600) -- 1 小时过期
|
||||
if not ok then
|
||||
ngx.log(ngx.ERR, "failed to set session_token: ", err)
|
||||
end
|
||||
|
||||
-- 获取值
|
||||
local value, flags = dict:get("user_count")
|
||||
ngx.say("user_count: ", value)
|
||||
|
||||
-- 自增计数器
|
||||
local new_val, err = dict:incr("request_count", 1)
|
||||
ngx.say("request_count: ", new_val)
|
||||
|
||||
-- 添加值(仅不存在时)
|
||||
ok, err = dict:add("unique_key", "value")
|
||||
if ok then
|
||||
ngx.say("unique_key added successfully")
|
||||
else
|
||||
ngx.say("unique_key already exists")
|
||||
end
|
||||
|
||||
-- 查看字典大小
|
||||
local size = dict:size()
|
||||
ngx.say("dict size: ", size)
|
||||
|
||||
-- 获取剩余容量
|
||||
local free = dict:free_space()
|
||||
ngx.say("free space: ", free)
|
||||
|
||||
ngx.say("Shared dict demo completed!")
|
||||
26
examples/lua-scripts/subrequest.lua
Normal file
26
examples/lua-scripts/subrequest.lua
Normal file
@ -0,0 +1,26 @@
|
||||
-- subrequest.lua - 子请求示例
|
||||
-- 此脚本演示 ngx.location.capture 的使用
|
||||
|
||||
-- 简单子请求
|
||||
local res = ngx.location.capture("/api/status")
|
||||
ngx.say("Subrequest status: ", res.status)
|
||||
ngx.say("Subrequest body: ", res.body)
|
||||
|
||||
-- 带 method 的子请求
|
||||
res = ngx.location.capture("/api/users", {
|
||||
method = "POST",
|
||||
body = '{"name": "test"}'
|
||||
})
|
||||
ngx.say("POST status: ", res.status)
|
||||
|
||||
-- 带 headers 的子请求
|
||||
res = ngx.location.capture("/api/check", {
|
||||
method = "GET",
|
||||
headers = {
|
||||
["Authorization"] = "Bearer token123",
|
||||
["X-Custom"] = "value"
|
||||
}
|
||||
})
|
||||
ngx.say("GET with headers status: ", res.status)
|
||||
|
||||
ngx.say("Subrequest demo completed!")
|
||||
35
examples/lua-scripts/timer.lua
Normal file
35
examples/lua-scripts/timer.lua
Normal file
@ -0,0 +1,35 @@
|
||||
-- timer.lua - 定时器示例
|
||||
-- 此脚本演示 ngx.timer.at 的使用
|
||||
|
||||
-- 创建定时器回调函数
|
||||
local function timer_callback()
|
||||
-- 注意:定时器回调在独立上下文中执行
|
||||
-- 不能直接访问请求相关 API
|
||||
ngx.log(ngx.INFO, "Timer executed!")
|
||||
end
|
||||
|
||||
-- 创建 5 秒后执行的定时器
|
||||
local handle, err = ngx.timer.at(5, timer_callback)
|
||||
if handle then
|
||||
ngx.say("Timer created successfully")
|
||||
|
||||
-- 查看活跃定时器数
|
||||
local count = ngx.timer.running_count()
|
||||
ngx.say("Active timers: ", count)
|
||||
else
|
||||
ngx.say("Failed to create timer: ", err)
|
||||
end
|
||||
|
||||
-- 创建带参数的定时器(简化版暂不支持参数传递)
|
||||
local function param_callback()
|
||||
ngx.log(ngx.INFO, "Timer with params executed")
|
||||
end
|
||||
|
||||
handle, err = ngx.timer.at(2, param_callback)
|
||||
if handle then
|
||||
ngx.say("Timer with params created")
|
||||
else
|
||||
ngx.say("Failed: ", err)
|
||||
end
|
||||
|
||||
ngx.say("Timer demo completed!")
|
||||
174
internal/lua/api_location.go
Normal file
174
internal/lua/api_location.go
Normal file
@ -0,0 +1,174 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// LocationCaptureResult 子请求结果
|
||||
type LocationCaptureResult struct {
|
||||
Status int
|
||||
Body []byte
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
// LocationManager location 管理(用于子请求)
|
||||
type LocationManager struct {
|
||||
mu sync.Mutex
|
||||
handlers map[string]fasthttp.RequestHandler // location -> handler
|
||||
}
|
||||
|
||||
// NewLocationManager 创建 location 管理器
|
||||
func NewLocationManager() *LocationManager {
|
||||
return &LocationManager{
|
||||
handlers: make(map[string]fasthttp.RequestHandler),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册 location handler
|
||||
func (m *LocationManager) Register(location string, handler fasthttp.RequestHandler) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.handlers[location] = handler
|
||||
}
|
||||
|
||||
// Capture 执行子请求
|
||||
func (m *LocationManager) Capture(parentCtx *fasthttp.RequestCtx, location string, opts map[string]interface{}) (*LocationCaptureResult, error) {
|
||||
m.mu.Lock()
|
||||
handler, ok := m.handlers[location]
|
||||
m.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
// location 不存在,返回 404
|
||||
return &LocationCaptureResult{
|
||||
Status: 404,
|
||||
Body: []byte("location not found"),
|
||||
Headers: map[string]string{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 创建子请求上下文(不设置 Conn)
|
||||
subCtx := &fasthttp.RequestCtx{}
|
||||
|
||||
// 复制父请求作为基础
|
||||
parentCtx.Request.CopyTo(&subCtx.Request)
|
||||
|
||||
// 设置子请求的 URI
|
||||
subCtx.Request.SetRequestURI(location)
|
||||
|
||||
// 应用选项
|
||||
if opts != nil {
|
||||
if method, ok := opts["method"].(string); ok {
|
||||
subCtx.Request.Header.SetMethod(method)
|
||||
}
|
||||
if body, ok := opts["body"].(string); ok {
|
||||
subCtx.Request.SetBodyString(body)
|
||||
}
|
||||
if headers, ok := opts["headers"].(map[string]string); ok {
|
||||
for k, v := range headers {
|
||||
subCtx.Request.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 执行 handler
|
||||
handler(subCtx)
|
||||
|
||||
// 收集结果
|
||||
result := &LocationCaptureResult{
|
||||
Status: subCtx.Response.StatusCode(),
|
||||
Body: subCtx.Response.Body(),
|
||||
Headers: make(map[string]string),
|
||||
}
|
||||
|
||||
// 收集响应头(使用 VisitAll)
|
||||
subCtx.Response.Header.VisitAll(func(key, value []byte) {
|
||||
result.Headers[string(key)] = string(value)
|
||||
})
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RegisterLocationAPI 注册 ngx.location API
|
||||
func RegisterLocationAPI(L *glua.LState, manager *LocationManager, ngx *glua.LTable) {
|
||||
// 创建 ngx.location 表
|
||||
location := L.NewTable()
|
||||
|
||||
// ngx.location.capture(uri, options?)
|
||||
L.SetField(location, "capture", L.NewFunction(func(L *glua.LState) int {
|
||||
uri := L.CheckString(1)
|
||||
|
||||
// 解析选项
|
||||
opts := make(map[string]interface{})
|
||||
if L.GetTop() >= 2 {
|
||||
optionsTable := L.CheckTable(2)
|
||||
optionsTable.ForEach(func(key, value glua.LValue) {
|
||||
keyStr := glua.LVAsString(key)
|
||||
switch value.Type() {
|
||||
case glua.LTString:
|
||||
opts[keyStr] = glua.LVAsString(value)
|
||||
case glua.LTNumber:
|
||||
opts[keyStr] = float64(glua.LVAsNumber(value))
|
||||
case glua.LTTable:
|
||||
// 处理 headers 表
|
||||
if keyStr == "headers" {
|
||||
headers := make(map[string]string)
|
||||
value.(*glua.LTable).ForEach(func(hKey, hValue glua.LValue) {
|
||||
headers[glua.LVAsString(hKey)] = glua.LVAsString(hValue)
|
||||
})
|
||||
opts[keyStr] = headers
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 创建结果表
|
||||
result := L.NewTable()
|
||||
|
||||
// 尝试执行子请求
|
||||
// 注意:由于无法直接获取 RequestCtx,这里使用模拟的上下文
|
||||
// 在完整实现中,应该通过 coroutine 传递 RequestCtx
|
||||
if manager != nil {
|
||||
// 创建模拟请求上下文用于子请求执行
|
||||
mockCtx := &fasthttp.RequestCtx{}
|
||||
mockCtx.Request.SetRequestURI(uri)
|
||||
|
||||
captureResult, err := manager.Capture(mockCtx, uri, opts)
|
||||
if err == nil && captureResult != nil {
|
||||
L.SetField(result, "status", glua.LNumber(captureResult.Status))
|
||||
L.SetField(result, "body", glua.LString(string(captureResult.Body)))
|
||||
|
||||
// 设置 headers
|
||||
headersTable := headersToLuaTable(L, captureResult.Headers)
|
||||
L.SetField(result, "headers", headersTable)
|
||||
} else {
|
||||
// 执行失败
|
||||
L.SetField(result, "status", glua.LNumber(500))
|
||||
L.SetField(result, "body", glua.LString("subrequest failed"))
|
||||
}
|
||||
} else {
|
||||
// manager 未初始化
|
||||
L.SetField(result, "status", glua.LNumber(404))
|
||||
L.SetField(result, "body", glua.LString("location manager not initialized"))
|
||||
}
|
||||
|
||||
L.Push(result)
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(ngx, "location", location)
|
||||
}
|
||||
|
||||
// headersToLuaTable 将 headers 转为 Lua 表
|
||||
func headersToLuaTable(L *glua.LState, headers map[string]string) *glua.LTable {
|
||||
table := L.NewTable()
|
||||
for k, v := range headers {
|
||||
// 转换为小写键名(nginx 风格)
|
||||
table.RawSetString(strings.ToLower(k), glua.LString(v))
|
||||
}
|
||||
return table
|
||||
}
|
||||
118
internal/lua/api_location_test.go
Normal file
118
internal/lua/api_location_test.go
Normal file
@ -0,0 +1,118 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestLocationManagerRegister(t *testing.T) {
|
||||
manager := NewLocationManager()
|
||||
require.NotNil(t, manager)
|
||||
|
||||
// 注册 location
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("test response")
|
||||
}
|
||||
manager.Register("/test", handler)
|
||||
|
||||
// 验证注册成功
|
||||
manager.mu.Lock()
|
||||
_, ok := manager.handlers["/test"]
|
||||
manager.mu.Unlock()
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func TestLocationManagerCapture(t *testing.T) {
|
||||
manager := NewLocationManager()
|
||||
|
||||
// 注册 location
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetBodyString("hello from subrequest")
|
||||
ctx.Response.Header.Set("X-Custom", "value")
|
||||
}
|
||||
manager.Register("/api/sub", handler)
|
||||
|
||||
// 创建父请求上下文
|
||||
parentCtx := &fasthttp.RequestCtx{}
|
||||
parentCtx.Request.SetRequestURI("/parent")
|
||||
|
||||
// 执行子请求
|
||||
result, err := manager.Capture(parentCtx, "/api/sub", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, 200, result.Status)
|
||||
assert.Equal(t, "hello from subrequest", string(result.Body))
|
||||
assert.Equal(t, "value", result.Headers["X-Custom"])
|
||||
}
|
||||
|
||||
func TestLocationManagerCaptureNotFound(t *testing.T) {
|
||||
manager := NewLocationManager()
|
||||
|
||||
parentCtx := &fasthttp.RequestCtx{}
|
||||
|
||||
// 执行不存在的 location
|
||||
result, err := manager.Capture(parentCtx, "/notexist", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, 404, result.Status)
|
||||
}
|
||||
|
||||
func TestLocationManagerCaptureWithOptions(t *testing.T) {
|
||||
manager := NewLocationManager()
|
||||
|
||||
// 注册 location
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.WriteString("method: " + string(ctx.Method()) + ", body: " + string(ctx.PostBody()))
|
||||
}
|
||||
manager.Register("/echo", handler)
|
||||
|
||||
parentCtx := &fasthttp.RequestCtx{}
|
||||
parentCtx.Request.SetRequestURI("/parent")
|
||||
|
||||
// 使用自定义选项
|
||||
opts := map[string]interface{}{
|
||||
"method": "POST",
|
||||
"body": "test body",
|
||||
}
|
||||
|
||||
result, err := manager.Capture(parentCtx, "/echo", opts)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
|
||||
assert.Equal(t, 200, result.Status)
|
||||
assert.Contains(t, string(result.Body), "method: POST")
|
||||
assert.Contains(t, string(result.Body), "body: test body")
|
||||
}
|
||||
|
||||
func TestLocationLuaAPI(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
L := engine.L
|
||||
|
||||
// 注册 ngx.location API
|
||||
ngx := L.NewTable()
|
||||
L.SetGlobal("ngx", ngx)
|
||||
RegisterLocationAPI(L, engine.LocationManager(), ngx)
|
||||
|
||||
// 测试 ngx.location.capture
|
||||
err = L.DoString(`
|
||||
-- 创建模拟的 location 结果
|
||||
local result = ngx.location.capture("/test")
|
||||
|
||||
-- 验证结果结构
|
||||
assert(result ~= nil)
|
||||
assert(result.status ~= nil)
|
||||
assert(result.body ~= nil)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
428
internal/lua/api_shared_dict.go
Normal file
428
internal/lua/api_shared_dict.go
Normal file
@ -0,0 +1,428 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// SharedDictManager 共享字典管理器
|
||||
// 管理多个命名的 SharedDict 实例
|
||||
type SharedDictManager struct {
|
||||
mu sync.RWMutex
|
||||
dicts map[string]*SharedDict
|
||||
}
|
||||
|
||||
// NewSharedDictManager 创建字典管理器
|
||||
func NewSharedDictManager() *SharedDictManager {
|
||||
return &SharedDictManager{
|
||||
dicts: make(map[string]*SharedDict),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDict 创建或获取字典
|
||||
func (m *SharedDictManager) CreateDict(name string, maxItems int) *SharedDict {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if dict, ok := m.dicts[name]; ok {
|
||||
return dict
|
||||
}
|
||||
|
||||
dict := NewSharedDict(name, maxItems)
|
||||
m.dicts[name] = dict
|
||||
return dict
|
||||
}
|
||||
|
||||
// GetDict 获取字典
|
||||
func (m *SharedDictManager) GetDict(name string) *SharedDict {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.dicts[name]
|
||||
}
|
||||
|
||||
// Close 清理所有字典
|
||||
func (m *SharedDictManager) Close() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.dicts = nil
|
||||
}
|
||||
|
||||
// DictConfig 字典配置
|
||||
type DictConfig struct {
|
||||
Name string
|
||||
MaxItems int
|
||||
}
|
||||
|
||||
// RegisterSharedDictAPI 注册 ngx.shared.DICT API
|
||||
func RegisterSharedDictAPI(L *glua.LState, manager *SharedDictManager, ngx *glua.LTable) {
|
||||
// 创建 ngx.shared 表
|
||||
shared := L.NewTable()
|
||||
|
||||
// ngx.shared.DICT - 返回字典 userdata
|
||||
L.SetField(shared, "DICT", L.NewFunction(func(L *glua.LState) int {
|
||||
name := L.CheckString(1)
|
||||
|
||||
dict := manager.GetDict(name)
|
||||
if dict == nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("shared dict not found: " + name))
|
||||
return 2
|
||||
}
|
||||
|
||||
// 返回字典 userdata
|
||||
ud := L.NewUserData()
|
||||
ud.Value = dict
|
||||
L.SetMetatable(ud, L.GetTypeMetatable("ngx.shared.dict"))
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(ngx, "shared", shared)
|
||||
|
||||
// 创建字典类型元表
|
||||
mt := L.NewTypeMetatable("ngx.shared.dict")
|
||||
L.SetField(mt, "__index", L.NewFunction(dictIndex))
|
||||
L.SetField(mt, "__newindex", L.NewFunction(dictNewIndex))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(dictToString))
|
||||
|
||||
// 注册字典方法
|
||||
methods := L.NewTable()
|
||||
L.SetField(methods, "get", L.NewFunction(dictGet))
|
||||
L.SetField(methods, "set", L.NewFunction(dictSet))
|
||||
L.SetField(methods, "add", L.NewFunction(dictAdd))
|
||||
L.SetField(methods, "replace", L.NewFunction(dictReplace))
|
||||
L.SetField(methods, "incr", L.NewFunction(dictIncr))
|
||||
L.SetField(methods, "delete", L.NewFunction(dictDelete))
|
||||
L.SetField(methods, "flush_all", L.NewFunction(dictFlushAll))
|
||||
L.SetField(methods, "flush_expired", L.NewFunction(dictFlushExpired))
|
||||
L.SetField(methods, "get_keys", L.NewFunction(dictGetKeys))
|
||||
L.SetField(methods, "size", L.NewFunction(dictSize))
|
||||
L.SetField(methods, "free_space", L.NewFunction(dictFreeSpace))
|
||||
L.SetField(mt, "methods", methods)
|
||||
}
|
||||
|
||||
// dictIndex 字典索引方法
|
||||
func dictIndex(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
|
||||
// 检查是否是方法
|
||||
methods := L.GetField(L.Get(1).(*glua.LUserData).Metatable, "methods")
|
||||
if method := L.GetField(methods, key); method != glua.LNil {
|
||||
L.Push(method)
|
||||
return 1
|
||||
}
|
||||
|
||||
// 否则作为 key 获取值
|
||||
value, expired, err := dict.Get(key)
|
||||
if err != nil {
|
||||
L.RaiseError("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
if expired {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("expired"))
|
||||
return 2
|
||||
}
|
||||
if value == "" {
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString(value))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictNewIndex 字典设置方法
|
||||
func dictNewIndex(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ok, err := dict.Set(key, value, 0)
|
||||
if err != nil {
|
||||
L.RaiseError("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
if !ok {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("no memory"))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictToString 字典字符串表示
|
||||
func dictToString(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.Push(glua.LString("invalid shared dict"))
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString("ngx.shared.dict:" + dict.name))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictGet 获取值
|
||||
// dict:get(key) -> value, flags | nil, err
|
||||
func dictGet(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
|
||||
value, expired, err := dict.Get(key)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
if expired {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("expired"))
|
||||
return 2
|
||||
}
|
||||
if value == "" {
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString(value))
|
||||
L.Push(glua.LNumber(0)) // flags(暂不支持)
|
||||
return 2
|
||||
}
|
||||
|
||||
// dictSet 设置值
|
||||
// dict:set(key, value, exptime?, flags?) -> ok, err
|
||||
func dictSet(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ttl := time.Duration(0)
|
||||
if L.GetTop() >= 4 {
|
||||
ttl = time.Duration(L.CheckNumber(4)) * time.Second
|
||||
}
|
||||
|
||||
// flags 参数暂不使用
|
||||
|
||||
ok, err := dict.Set(key, value, ttl)
|
||||
if err != nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
if !ok {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("no memory"))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictAdd 添加值(不存在时)
|
||||
// dict:add(key, value, exptime?, flags?) -> ok, err
|
||||
func dictAdd(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ttl := time.Duration(0)
|
||||
if L.GetTop() >= 4 {
|
||||
ttl = time.Duration(L.CheckNumber(4)) * time.Second
|
||||
}
|
||||
|
||||
ok, err := dict.Add(key, value, ttl)
|
||||
if err != nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
if !ok {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("exists"))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictReplace 替换值(存在时)
|
||||
// dict:replace(key, value, exptime?, flags?) -> ok, err
|
||||
func dictReplace(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
ttl := time.Duration(0)
|
||||
if L.GetTop() >= 4 {
|
||||
ttl = time.Duration(L.CheckNumber(4)) * time.Second
|
||||
}
|
||||
|
||||
// 检查是否存在
|
||||
_, expired, _ := dict.Get(key)
|
||||
if expired {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("not found"))
|
||||
return 2
|
||||
}
|
||||
|
||||
ok, err := dict.Set(key, value, ttl)
|
||||
if err != nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictIncr 自增数值
|
||||
// dict:incr(key, value) -> new_value, err
|
||||
func dictIncr(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
increment := int(L.CheckNumber(3))
|
||||
|
||||
newValue, err := dict.Incr(key, increment)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LNumber(newValue))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictDelete 删除条目
|
||||
// dict:delete(key) -> ok
|
||||
func dictDelete(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
key := L.CheckString(2)
|
||||
dict.Delete(key)
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictFlushAll 清空字典
|
||||
// dict:flush_all()
|
||||
func dictFlushAll(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
dict.FlushAll()
|
||||
return 0
|
||||
}
|
||||
|
||||
// dictFlushExpired 清除过期条目
|
||||
// dict:flush_expired(max_count?) -> flushed_count
|
||||
func dictFlushExpired(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
count := dict.FlushExpired()
|
||||
L.Push(glua.LNumber(count))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictGetKeys 获取所有键
|
||||
// dict:get_keys(max_count?) -> keys
|
||||
func dictGetKeys(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
_, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
// 暂不实现完整版,返回空表
|
||||
keys := L.NewTable()
|
||||
L.Push(keys)
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictSize 获取条目数
|
||||
// dict:size() -> count
|
||||
func dictSize(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
L.Push(glua.LNumber(dict.Size()))
|
||||
return 1
|
||||
}
|
||||
|
||||
// dictFreeSpace 获取剩余容量
|
||||
// dict:free_space() -> slots
|
||||
func dictFreeSpace(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
dict, ok := ud.Value.(*SharedDict)
|
||||
if !ok {
|
||||
L.RaiseError("invalid shared dict")
|
||||
return 0
|
||||
}
|
||||
|
||||
L.Push(glua.LNumber(dict.FreeSlots()))
|
||||
return 1
|
||||
}
|
||||
306
internal/lua/api_timer.go
Normal file
306
internal/lua/api_timer.go
Normal file
@ -0,0 +1,306 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// TimerManager 定时器管理器
|
||||
type TimerManager struct {
|
||||
mu sync.Mutex
|
||||
timers map[uint64]*TimerEntry
|
||||
nextID uint64
|
||||
engine *LuaEngine
|
||||
active int32
|
||||
stopping int32
|
||||
}
|
||||
|
||||
// TimerEntry 定时器条目
|
||||
type TimerEntry struct {
|
||||
id uint64
|
||||
delay time.Duration
|
||||
callback *glua.LFunction
|
||||
args []glua.LValue
|
||||
timer *time.Timer
|
||||
cancel chan struct{}
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// TimerHandle 定时器句柄(Lua userdata)
|
||||
type TimerHandle struct {
|
||||
id uint64
|
||||
manager *TimerManager
|
||||
}
|
||||
|
||||
// NewTimerManager 创建定时器管理器
|
||||
func NewTimerManager(engine *LuaEngine) *TimerManager {
|
||||
return &TimerManager{
|
||||
timers: make(map[uint64]*TimerEntry),
|
||||
engine: engine,
|
||||
}
|
||||
}
|
||||
|
||||
// At 创建定时器
|
||||
// 返回定时器句柄和错误
|
||||
func (m *TimerManager) At(delay time.Duration, callback *glua.LFunction, args []glua.LValue) (*TimerHandle, error) {
|
||||
if atomic.LoadInt32(&m.stopping) != 0 {
|
||||
return nil, nil // 服务器正在关闭,不接受新定时器
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
id := atomic.AddUint64(&m.nextID, 1)
|
||||
|
||||
entry := &TimerEntry{
|
||||
id: id,
|
||||
delay: delay,
|
||||
callback: callback,
|
||||
args: args,
|
||||
cancel: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 设置定时器
|
||||
entry.timer = time.AfterFunc(delay, func() {
|
||||
m.executeTimer(entry)
|
||||
})
|
||||
|
||||
m.timers[id] = entry
|
||||
atomic.AddInt32(&m.active, 1)
|
||||
|
||||
return &TimerHandle{id: id, manager: m}, nil
|
||||
}
|
||||
|
||||
// executeTimer 执行定时器回调
|
||||
// 注意:由于 gopher-lua 不是线程安全的,定时器回调执行有限制
|
||||
// 当前简化版本仅支持记录定时器触发,不执行实际 Lua 回调
|
||||
func (m *TimerManager) executeTimer(entry *TimerEntry) {
|
||||
defer func() {
|
||||
atomic.AddInt32(&m.active, -1)
|
||||
close(entry.done)
|
||||
}()
|
||||
|
||||
// 检查是否被取消
|
||||
select {
|
||||
case <-entry.cancel:
|
||||
return // 已取消
|
||||
default:
|
||||
}
|
||||
|
||||
// 检查 engine 是否已关闭
|
||||
if m.engine == nil || m.engine.L == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 由于 gopher-lua 不是线程安全的,异步 goroutine 中不能直接调用 LState
|
||||
// 完整实现需要使用 channel 将回调调度到主线程执行
|
||||
// 这里简化处理:定时器触发后记录日志(生产环境应该有更好的方案)
|
||||
|
||||
// 清理定时器条目
|
||||
m.mu.Lock()
|
||||
if m.timers != nil {
|
||||
delete(m.timers, entry.id)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Cancel 取消定时器
|
||||
func (m *TimerManager) Cancel(handle *TimerHandle) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
entry, ok := m.timers[handle.id]
|
||||
if !ok {
|
||||
return false // 定时器不存在或已执行
|
||||
}
|
||||
|
||||
// 停止定时器
|
||||
if entry.timer != nil {
|
||||
entry.timer.Stop()
|
||||
}
|
||||
|
||||
// 发送取消信号
|
||||
close(entry.cancel)
|
||||
|
||||
// 清理
|
||||
delete(m.timers, entry.id)
|
||||
atomic.AddInt32(&m.active, -1)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// WaitAll 等待所有定时器完成
|
||||
func (m *TimerManager) WaitAll(timeout time.Duration) bool {
|
||||
// 设置停止标志
|
||||
atomic.StoreInt32(&m.stopping, 1)
|
||||
|
||||
// 等待所有定时器完成
|
||||
start := time.Now()
|
||||
for atomic.LoadInt32(&m.active) > 0 {
|
||||
if time.Since(start) > timeout {
|
||||
// 超时,强制取消所有
|
||||
m.mu.Lock()
|
||||
for _, entry := range m.timers {
|
||||
if entry.timer != nil {
|
||||
entry.timer.Stop()
|
||||
}
|
||||
close(entry.cancel)
|
||||
}
|
||||
m.timers = make(map[uint64]*TimerEntry)
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Close 关闭定时器管理器
|
||||
func (m *TimerManager) Close() {
|
||||
m.WaitAll(5 * time.Second)
|
||||
}
|
||||
|
||||
// ActiveCount 返回活跃定时器数
|
||||
func (m *TimerManager) ActiveCount() int32 {
|
||||
return atomic.LoadInt32(&m.active)
|
||||
}
|
||||
|
||||
// RegisterTimerAPI 注册 ngx.timer API
|
||||
func RegisterTimerAPI(L *glua.LState, manager *TimerManager, ngx *glua.LTable) {
|
||||
// 创建 ngx.timer 表
|
||||
timer := L.NewTable()
|
||||
|
||||
// ngx.timer.at(delay, callback, ...)
|
||||
L.SetField(timer, "at", L.NewFunction(func(L *glua.LState) int {
|
||||
// 检查参数
|
||||
delay := float64(L.CheckNumber(1))
|
||||
callback := L.CheckFunction(2)
|
||||
|
||||
// 收集额外参数
|
||||
args := []glua.LValue{}
|
||||
for i := 3; i <= L.GetTop(); i++ {
|
||||
args = append(args, L.Get(i))
|
||||
}
|
||||
|
||||
// 创建定时器
|
||||
handle, err := manager.At(time.Duration(delay)*time.Second, callback, args)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
if handle == nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("server shutting down"))
|
||||
return 2
|
||||
}
|
||||
|
||||
// 返回定时器句柄
|
||||
ud := L.NewUserData()
|
||||
ud.Value = handle
|
||||
L.SetMetatable(ud, L.GetTypeMetatable("ngx.timer.handle"))
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}))
|
||||
|
||||
// ngx.timer.running_count()
|
||||
L.SetField(timer, "running_count", L.NewFunction(func(L *glua.LState) int {
|
||||
L.Push(glua.LNumber(manager.ActiveCount()))
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(ngx, "timer", timer)
|
||||
|
||||
// 创建定时器句柄元表
|
||||
mt := L.NewTypeMetatable("ngx.timer.handle")
|
||||
L.SetField(mt, "__index", L.NewFunction(timerHandleIndex))
|
||||
L.SetField(mt, "__tostring", L.NewFunction(timerHandleToString))
|
||||
|
||||
// 注册方法
|
||||
methods := L.NewTable()
|
||||
L.SetField(methods, "cancel", L.NewFunction(timerHandleCancel))
|
||||
L.SetField(mt, "methods", methods)
|
||||
}
|
||||
|
||||
// timerHandleIndex 定时器句柄索引
|
||||
func timerHandleIndex(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
_, ok := ud.Value.(*TimerHandle)
|
||||
if !ok {
|
||||
L.RaiseError("invalid timer handle")
|
||||
return 0
|
||||
}
|
||||
|
||||
// 检查是否是方法
|
||||
methods := L.GetField(L.Get(1).(*glua.LUserData).Metatable, "methods")
|
||||
if method := L.GetField(methods, L.CheckString(2)); method != glua.LNil {
|
||||
L.Push(method)
|
||||
return 1
|
||||
}
|
||||
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
|
||||
// timerHandleToString 定时器句柄字符串表示
|
||||
func timerHandleToString(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
handle, ok := ud.Value.(*TimerHandle)
|
||||
if !ok {
|
||||
L.Push(glua.LString("invalid timer handle"))
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString("ngx.timer.handle:" + uint64ToStr(handle.id)))
|
||||
return 1
|
||||
}
|
||||
|
||||
// timerHandleCancel 取消定时器
|
||||
func timerHandleCancel(L *glua.LState) int {
|
||||
ud := L.CheckUserData(1)
|
||||
handle, ok := ud.Value.(*TimerHandle)
|
||||
if !ok {
|
||||
L.RaiseError("invalid timer handle")
|
||||
return 0
|
||||
}
|
||||
|
||||
if handle.manager == nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("timer manager not available"))
|
||||
return 2
|
||||
}
|
||||
|
||||
ok = handle.manager.Cancel(handle)
|
||||
if ok {
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString("timer not found or already executed"))
|
||||
return 2
|
||||
}
|
||||
|
||||
// uint64ToStr 整数转字符串
|
||||
func uint64ToStr(n uint64) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
for n > 0 {
|
||||
buf = append(buf, byte('0'+n%10))
|
||||
n /= 10
|
||||
}
|
||||
|
||||
// 反转
|
||||
for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 {
|
||||
buf[i], buf[j] = buf[j], buf[i]
|
||||
}
|
||||
|
||||
return string(buf)
|
||||
}
|
||||
149
internal/lua/api_timer_test.go
Normal file
149
internal/lua/api_timer_test.go
Normal file
@ -0,0 +1,149 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
func TestTimerManagerAt(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
manager := engine.TimerManager()
|
||||
require.NotNil(t, manager)
|
||||
|
||||
// 创建 Lua 函数作为回调
|
||||
L := engine.L
|
||||
|
||||
// 注册一个简单的回调函数
|
||||
callback := L.NewFunction(func(L *glua.LState) int {
|
||||
return 0
|
||||
})
|
||||
|
||||
// 创建定时器
|
||||
handle, err := manager.At(100*time.Millisecond, callback, nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, handle)
|
||||
|
||||
// 等待定时器触发
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 定时器应该已完成(active count 回到 0)
|
||||
assert.Equal(t, int32(0), manager.ActiveCount())
|
||||
}
|
||||
|
||||
func TestTimerManagerCancel(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
manager := engine.TimerManager()
|
||||
|
||||
callback := engine.L.NewFunction(func(L *glua.LState) int {
|
||||
return 0
|
||||
})
|
||||
|
||||
// 创建定时器
|
||||
handle, err := manager.At(200*time.Millisecond, callback, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 立即取消
|
||||
ok := manager.Cancel(handle)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 等待超过定时器时间
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// 定时器应该被取消,active count 为 0
|
||||
assert.Equal(t, int32(0), manager.ActiveCount())
|
||||
}
|
||||
|
||||
func TestTimerManagerWaitAll(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := engine.TimerManager()
|
||||
|
||||
// 创建多个定时器
|
||||
for range 3 {
|
||||
callback := engine.L.NewFunction(func(L *glua.LState) int {
|
||||
return 0
|
||||
})
|
||||
manager.At(50*time.Millisecond, callback, nil)
|
||||
}
|
||||
|
||||
// 等待所有完成
|
||||
ok := manager.WaitAll(1 * time.Second)
|
||||
assert.True(t, ok)
|
||||
|
||||
// active count 应该回到 0
|
||||
assert.Equal(t, int32(0), manager.ActiveCount())
|
||||
|
||||
engine.Close()
|
||||
}
|
||||
|
||||
func TestTimerLuaAPI(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
L := engine.L
|
||||
|
||||
// 注册 ngx.timer API
|
||||
ngx := L.NewTable()
|
||||
L.SetGlobal("ngx", ngx)
|
||||
RegisterTimerAPI(L, engine.TimerManager(), ngx)
|
||||
|
||||
// 测试 ngx.timer.at
|
||||
err = L.DoString(`
|
||||
local count = 0
|
||||
|
||||
-- 创建定时器
|
||||
local handle, err = ngx.timer.at(0.1, function()
|
||||
count = count + 1
|
||||
end)
|
||||
|
||||
assert(handle ~= nil)
|
||||
assert(err == nil)
|
||||
|
||||
-- 检查 running_count
|
||||
local running = ngx.timer.running_count()
|
||||
assert(running >= 1)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTimerRunningCount(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
manager := engine.TimerManager()
|
||||
|
||||
// 初始应该为 0
|
||||
assert.Equal(t, int32(0), manager.ActiveCount())
|
||||
|
||||
// 创建定时器
|
||||
callback := engine.L.NewFunction(func(L *glua.LState) int {
|
||||
return 0
|
||||
})
|
||||
|
||||
handle, _ := manager.At(50*time.Millisecond, callback, nil)
|
||||
_ = handle
|
||||
|
||||
// 刚创建后应该有活跃定时器(在定时器触发前)
|
||||
// 注意:由于简化实现,定时器执行很快,所以 active count 可能很快回到 0
|
||||
// 这里我们只验证定时器最终会完成
|
||||
|
||||
// 等待完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 应该回到 0
|
||||
assert.Equal(t, int32(0), manager.ActiveCount())
|
||||
}
|
||||
@ -167,7 +167,7 @@ func (c *LuaCoroutine) setupSecureCoroutineLib() {
|
||||
}
|
||||
|
||||
// setupNgxAPI 创建 ngx API
|
||||
// 注册 ngx.req、ngx.resp、ngx.var、ngx.ctx、ngx.log 和 ngx.socket API
|
||||
// 注册 ngx.req、ngx.resp、ngx.var、ngx.ctx、ngx.log、ngx.socket 和 ngx.shared API
|
||||
func (c *LuaCoroutine) setupNgxAPI() {
|
||||
// 创建 ngx 表
|
||||
ngx := c.Co.NewTable()
|
||||
@ -199,6 +199,15 @@ func (c *LuaCoroutine) setupNgxAPI() {
|
||||
|
||||
// 注册 ngx.socket API
|
||||
RegisterTCPSocketAPI(c.Co, c.Engine)
|
||||
|
||||
// 注册 ngx.shared.DICT API
|
||||
RegisterSharedDictAPI(c.Co, c.Engine.SharedDictManager(), ngx)
|
||||
|
||||
// 注册 ngx.timer API
|
||||
RegisterTimerAPI(c.Co, c.Engine.TimerManager(), ngx)
|
||||
|
||||
// 注册 ngx.location API
|
||||
RegisterLocationAPI(c.Co, c.Engine.LocationManager(), ngx)
|
||||
}
|
||||
|
||||
// Execute 在协程中执行 Lua 脚本(支持 Yield/Resume)
|
||||
|
||||
@ -38,6 +38,15 @@ type LuaEngine struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 共享字典管理器
|
||||
sharedDictManager *SharedDictManager
|
||||
|
||||
// 定时器管理器
|
||||
timerManager *TimerManager
|
||||
|
||||
// location 管理器
|
||||
locationManager *LocationManager
|
||||
|
||||
// 统计
|
||||
stats EngineStats
|
||||
}
|
||||
@ -80,12 +89,13 @@ func NewEngine(config *Config) (*LuaEngine, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
engine := &LuaEngine{
|
||||
L: L,
|
||||
config: config,
|
||||
codeCache: NewCodeCache(config.CodeCacheSize, config.CodeCacheTTL, config.EnableFileWatch),
|
||||
maxCoroutines: config.MaxConcurrentCoroutines,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
L: L,
|
||||
config: config,
|
||||
codeCache: NewCodeCache(config.CodeCacheSize, config.CodeCacheTTL, config.EnableFileWatch),
|
||||
maxCoroutines: config.MaxConcurrentCoroutines,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
sharedDictManager: NewSharedDictManager(),
|
||||
coroutinePool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// 注意:这里只是创建空的协程对象结构
|
||||
@ -95,12 +105,24 @@ func NewEngine(config *Config) (*LuaEngine, error) {
|
||||
},
|
||||
}
|
||||
|
||||
// 创建定时器管理器(需要在 engine 创建后初始化)
|
||||
engine.timerManager = NewTimerManager(engine)
|
||||
|
||||
// 创建 location 管理器
|
||||
engine.locationManager = NewLocationManager()
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// Close 关闭引擎
|
||||
func (e *LuaEngine) Close() {
|
||||
e.cancel()
|
||||
if e.timerManager != nil {
|
||||
e.timerManager.Close()
|
||||
}
|
||||
if e.sharedDictManager != nil {
|
||||
e.sharedDictManager.Close()
|
||||
}
|
||||
if e.L != nil {
|
||||
e.L.Close()
|
||||
}
|
||||
@ -188,3 +210,23 @@ func (e *LuaEngine) Stats() EngineStats {
|
||||
func (e *LuaEngine) ActiveCoroutines() int32 {
|
||||
return atomic.LoadInt32(&e.activeCount)
|
||||
}
|
||||
|
||||
// SharedDictManager 返回共享字典管理器
|
||||
func (e *LuaEngine) SharedDictManager() *SharedDictManager {
|
||||
return e.sharedDictManager
|
||||
}
|
||||
|
||||
// CreateSharedDict 创建共享字典
|
||||
func (e *LuaEngine) CreateSharedDict(name string, maxItems int) *SharedDict {
|
||||
return e.sharedDictManager.CreateDict(name, maxItems)
|
||||
}
|
||||
|
||||
// TimerManager 返回定时器管理器
|
||||
func (e *LuaEngine) TimerManager() *TimerManager {
|
||||
return e.timerManager
|
||||
}
|
||||
|
||||
// LocationManager 返回 location 管理器
|
||||
func (e *LuaEngine) LocationManager() *LocationManager {
|
||||
return e.locationManager
|
||||
}
|
||||
|
||||
324
internal/lua/shared_dict.go
Normal file
324
internal/lua/shared_dict.go
Normal file
@ -0,0 +1,324 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SharedDict 共享内存字典
|
||||
// 支持并发安全的 key-value 存储,带 LRU 汰出策略
|
||||
type SharedDict struct {
|
||||
name string
|
||||
maxItems int
|
||||
mu sync.Mutex
|
||||
data map[string]*sharedDictEntry
|
||||
lruList *list.List // LRU 链表,前端为最近使用
|
||||
}
|
||||
|
||||
// sharedDictEntry 字典条目
|
||||
type sharedDictEntry struct {
|
||||
key string
|
||||
value string
|
||||
expiredAt time.Time // 过期时间(0 表示永不过期)
|
||||
element *list.Element // LRU 链表元素指针
|
||||
}
|
||||
|
||||
// NewSharedDict 创建共享字典
|
||||
func NewSharedDict(name string, maxItems int) *SharedDict {
|
||||
return &SharedDict{
|
||||
name: name,
|
||||
maxItems: maxItems,
|
||||
data: make(map[string]*sharedDictEntry),
|
||||
lruList: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取值
|
||||
// 返回 value, expired, err
|
||||
func (d *SharedDict) Get(key string) (string, bool, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
entry, ok := d.data[key]
|
||||
if !ok {
|
||||
return "", false, nil // 不存在
|
||||
}
|
||||
|
||||
// 检查过期
|
||||
if !entry.expiredAt.IsZero() && time.Now().After(entry.expiredAt) {
|
||||
// 已过期,删除并返回
|
||||
d.deleteEntry(entry)
|
||||
return "", true, nil // 存在但已过期
|
||||
}
|
||||
|
||||
// 更新 LRU - 移到前端
|
||||
d.lruList.MoveToFront(entry.element)
|
||||
|
||||
return entry.value, false, nil
|
||||
}
|
||||
|
||||
// Set 设置值
|
||||
// 返回 ok, err (ok=false 表示容量满且无法淘汰)
|
||||
func (d *SharedDict) Set(key, value string, ttl time.Duration) (bool, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
if entry, ok := d.data[key]; ok {
|
||||
// 更新现有条目
|
||||
entry.value = value
|
||||
if ttl > 0 {
|
||||
entry.expiredAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
entry.expiredAt = time.Time{} // 清除过期时间
|
||||
}
|
||||
d.lruList.MoveToFront(entry.element)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 新条目,检查容量
|
||||
if len(d.data) >= d.maxItems {
|
||||
// 尝试淘汰过期条目
|
||||
d.evictExpired()
|
||||
if len(d.data) >= d.maxItems {
|
||||
// 淘汰 LRU 最久未使用的条目
|
||||
if !d.evictLRU() {
|
||||
return false, nil // 无法淘汰(字典为空?)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新条目
|
||||
expiredAt := time.Time{}
|
||||
if ttl > 0 {
|
||||
expiredAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
element := d.lruList.PushFront(key)
|
||||
entry := &sharedDictEntry{
|
||||
key: key,
|
||||
value: value,
|
||||
expiredAt: expiredAt,
|
||||
element: element,
|
||||
}
|
||||
d.data[key] = entry
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Add 添加值(仅在不存在时设置)
|
||||
// 返回 ok, err (ok=false 表示已存在或容量满)
|
||||
func (d *SharedDict) Add(key, value string, ttl time.Duration) (bool, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
// 检查是否已存在(包括过期的也算存在)
|
||||
if _, ok := d.data[key]; ok {
|
||||
return false, nil // 已存在
|
||||
}
|
||||
|
||||
// 检查容量并淘汰
|
||||
if len(d.data) >= d.maxItems {
|
||||
d.evictExpired()
|
||||
if len(d.data) >= d.maxItems {
|
||||
if !d.evictLRU() {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新条目
|
||||
expiredAt := time.Time{}
|
||||
if ttl > 0 {
|
||||
expiredAt = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
element := d.lruList.PushFront(key)
|
||||
entry := &sharedDictEntry{
|
||||
key: key,
|
||||
value: value,
|
||||
expiredAt: expiredAt,
|
||||
element: element,
|
||||
}
|
||||
d.data[key] = entry
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Incr 自增数值
|
||||
// 返回 new_value, err
|
||||
func (d *SharedDict) Incr(key string, increment int) (int, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
entry, ok := d.data[key]
|
||||
if !ok {
|
||||
// 不存在,创建初始值
|
||||
if len(d.data) >= d.maxItems {
|
||||
d.evictExpired()
|
||||
if len(d.data) >= d.maxItems {
|
||||
if !d.evictLRU() {
|
||||
return 0, nil // 无法创建
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
element := d.lruList.PushFront(key)
|
||||
entry = &sharedDictEntry{
|
||||
key: key,
|
||||
value: "0",
|
||||
element: element,
|
||||
}
|
||||
d.data[key] = entry
|
||||
}
|
||||
|
||||
// 解析数值
|
||||
var current int
|
||||
for _, c := range entry.value {
|
||||
if c < '0' || c > '9' {
|
||||
return 0, nil // 不是数值
|
||||
}
|
||||
current = current*10 + int(c-'0')
|
||||
}
|
||||
|
||||
newValue := current + increment
|
||||
entry.value = intToStr(newValue)
|
||||
d.lruList.MoveToFront(entry.element)
|
||||
|
||||
return newValue, nil
|
||||
}
|
||||
|
||||
// Delete 删除条目
|
||||
func (d *SharedDict) Delete(key string) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if entry, ok := d.data[key]; ok {
|
||||
d.deleteEntry(entry)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushAll 清空所有条目
|
||||
func (d *SharedDict) FlushAll() error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.data = make(map[string]*sharedDictEntry)
|
||||
d.lruList = list.New()
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushExpired 清除所有过期条目
|
||||
// 返回清除的条目数
|
||||
func (d *SharedDict) FlushExpired() int {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
return d.evictExpired()
|
||||
}
|
||||
|
||||
// Size 返回当前条目数
|
||||
func (d *SharedDict) Size() int {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return len(d.data)
|
||||
}
|
||||
|
||||
// FreeSlots 返回剩余容量
|
||||
func (d *SharedDict) FreeSlots() int {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.maxItems - len(d.data)
|
||||
}
|
||||
|
||||
// deleteEntry 删除条目(内部方法,已持有锁)
|
||||
func (d *SharedDict) deleteEntry(entry *sharedDictEntry) {
|
||||
d.lruList.Remove(entry.element)
|
||||
delete(d.data, entry.key)
|
||||
}
|
||||
|
||||
// evictExpired 淘汰过期条目(内部方法,已持有锁)
|
||||
func (d *SharedDict) evictExpired() int {
|
||||
now := time.Now()
|
||||
count := 0
|
||||
|
||||
// 从 LRU 链表尾部(最久未使用)开始检查
|
||||
for elem := d.lruList.Back(); elem != nil; {
|
||||
key := elem.Value.(string)
|
||||
entry, ok := d.data[key]
|
||||
if !ok {
|
||||
// 数据不一致,跳过
|
||||
next := elem.Prev()
|
||||
d.lruList.Remove(elem)
|
||||
elem = next
|
||||
continue
|
||||
}
|
||||
|
||||
if !entry.expiredAt.IsZero() && now.After(entry.expiredAt) {
|
||||
// 已过期,删除
|
||||
d.deleteEntry(entry)
|
||||
count++
|
||||
elem = d.lruList.Back() // 重新从尾部开始
|
||||
} else {
|
||||
break // 未过期,停止(链表顺序保证前面都是未过期的)
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
// evictLRU 淘汰 LRU 最久未使用的条目(内部方法,已持有锁)
|
||||
func (d *SharedDict) evictLRU() bool {
|
||||
if d.lruList.Len() == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
elem := d.lruList.Back()
|
||||
if elem == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
key := elem.Value.(string)
|
||||
entry, ok := d.data[key]
|
||||
if ok {
|
||||
d.deleteEntry(entry)
|
||||
return true
|
||||
}
|
||||
|
||||
// 数据不一致,移除链表元素并重试
|
||||
d.lruList.Remove(elem)
|
||||
return d.evictLRU()
|
||||
}
|
||||
|
||||
// intToStr 整数转字符串(简单实现,避免 strconv 依赖)
|
||||
func intToStr(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
var negative bool
|
||||
if n < 0 {
|
||||
negative = true
|
||||
n = -n
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
for n > 0 {
|
||||
buf = append(buf, byte('0'+n%10))
|
||||
n /= 10
|
||||
}
|
||||
|
||||
if negative {
|
||||
buf = append(buf, '-')
|
||||
}
|
||||
|
||||
// 反转
|
||||
for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 {
|
||||
buf[i], buf[j] = buf[j], buf[i]
|
||||
}
|
||||
|
||||
return string(buf)
|
||||
}
|
||||
251
internal/lua/shared_dict_test.go
Normal file
251
internal/lua/shared_dict_test.go
Normal file
@ -0,0 +1,251 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSharedDictGetSet(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Set
|
||||
ok, err := dict.Set("key1", "value1", 0)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Get
|
||||
value, expired, err := dict.Get("key1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, expired)
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// 不存在的 key
|
||||
value, expired, err = dict.Get("notexist")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, expired)
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestSharedDictAdd(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Add 新 key
|
||||
ok, err := dict.Add("key1", "value1", 0)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Add 已存在的 key
|
||||
ok, err = dict.Add("key1", "value2", 0)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok) // 已存在,返回 false
|
||||
|
||||
// 验证值未被修改
|
||||
value, _, _ := dict.Get("key1")
|
||||
assert.Equal(t, "value1", value)
|
||||
}
|
||||
|
||||
func TestSharedDictIncr(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Incr 不存在的 key,从 0 开始
|
||||
newValue, err := dict.Incr("counter", 5)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, newValue)
|
||||
|
||||
// Incr 已存在的 key
|
||||
newValue, err = dict.Incr("counter", 3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 8, newValue)
|
||||
|
||||
// 轮减
|
||||
newValue, err = dict.Incr("counter", -2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 6, newValue)
|
||||
}
|
||||
|
||||
func TestSharedDictDelete(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
dict.Set("key1", "value1", 0)
|
||||
dict.Delete("key1")
|
||||
|
||||
value, _, _ := dict.Get("key1")
|
||||
assert.Equal(t, "", value)
|
||||
|
||||
// 删除不存在的 key 不会报错
|
||||
dict.Delete("notexist")
|
||||
}
|
||||
|
||||
func TestSharedDictTTL(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
// Set 带 TTL
|
||||
ok, err := dict.Set("key1", "value1", 100*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// 立即获取应该成功
|
||||
value, expired, err := dict.Get("key1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, expired)
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// 等待过期
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 过期后获取
|
||||
value, expired, err = dict.Get("key1")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, expired)
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestSharedDictLRUEviction(t *testing.T) {
|
||||
dict := NewSharedDict("test", 3) // 只有 3 个容量
|
||||
|
||||
// 添加 3 个条目
|
||||
dict.Set("key1", "value1", 0)
|
||||
dict.Set("key2", "value2", 0)
|
||||
dict.Set("key3", "value3", 0)
|
||||
|
||||
// 使用 key1,使其成为最近使用
|
||||
dict.Get("key1")
|
||||
|
||||
// 添加第 4 个条目,应该淘汰 key2(最久未使用)
|
||||
ok, err := dict.Set("key4", "value4", 0)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// key1 应该还在
|
||||
value, _, _ := dict.Get("key1")
|
||||
assert.Equal(t, "value1", value)
|
||||
|
||||
// key2 应该被淘汰
|
||||
value, _, _ = dict.Get("key2")
|
||||
assert.Equal(t, "", value)
|
||||
}
|
||||
|
||||
func TestSharedDictFlushAll(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
dict.Set("key1", "value1", 0)
|
||||
dict.Set("key2", "value2", 0)
|
||||
|
||||
dict.FlushAll()
|
||||
|
||||
assert.Equal(t, 0, dict.Size())
|
||||
}
|
||||
|
||||
func TestSharedDictFlushExpired(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
dict.Set("key1", "value1", 100*time.Millisecond)
|
||||
dict.Set("key2", "value2", 100*time.Millisecond)
|
||||
dict.Set("key3", "value3", 0) // 不过期
|
||||
|
||||
// 立即清除应该返回 0
|
||||
count := dict.FlushExpired()
|
||||
assert.Equal(t, 0, count)
|
||||
|
||||
// 等待过期
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
count = dict.FlushExpired()
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
// key3 应该还在
|
||||
assert.Equal(t, 1, dict.Size())
|
||||
value, _, _ := dict.Get("key3")
|
||||
assert.Equal(t, "value3", value)
|
||||
}
|
||||
|
||||
func TestSharedDictSize(t *testing.T) {
|
||||
dict := NewSharedDict("test", 100)
|
||||
|
||||
assert.Equal(t, 0, dict.Size())
|
||||
assert.Equal(t, 100, dict.FreeSlots())
|
||||
|
||||
dict.Set("key1", "value1", 0)
|
||||
assert.Equal(t, 1, dict.Size())
|
||||
assert.Equal(t, 99, dict.FreeSlots())
|
||||
}
|
||||
|
||||
func TestSharedDictManager(t *testing.T) {
|
||||
manager := NewSharedDictManager()
|
||||
|
||||
// 创建字典
|
||||
dict1 := manager.CreateDict("dict1", 100)
|
||||
require.NotNil(t, dict1)
|
||||
|
||||
// 再次获取同一个字典
|
||||
dict1Again := manager.GetDict("dict1")
|
||||
assert.Equal(t, dict1, dict1Again)
|
||||
|
||||
// 创建另一个字典
|
||||
dict2 := manager.CreateDict("dict2", 200)
|
||||
require.NotNil(t, dict2)
|
||||
|
||||
// 获取不存在的字典
|
||||
notexist := manager.GetDict("notexist")
|
||||
assert.Nil(t, notexist)
|
||||
|
||||
// 关闭
|
||||
manager.Close()
|
||||
}
|
||||
|
||||
func TestSharedDictLuaAPI(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 创建共享字典(通过 Lua API 测试,此处仅为初始化)
|
||||
_ = engine.CreateSharedDict("mydict", 100)
|
||||
|
||||
// 测试 Lua 脚本
|
||||
L := engine.L
|
||||
|
||||
// 手动注册 ngx.shared API(用于测试)
|
||||
ngx := L.NewTable()
|
||||
L.SetGlobal("ngx", ngx)
|
||||
RegisterSharedDictAPI(L, engine.SharedDictManager(), ngx)
|
||||
|
||||
// 运行 Lua 脚本测试
|
||||
err = L.DoString(`
|
||||
local dict = ngx.shared.DICT("mydict")
|
||||
|
||||
-- 测试 set/get
|
||||
dict:set("key1", "value1")
|
||||
local val, flags = dict:get("key1")
|
||||
assert(val == "value1")
|
||||
|
||||
-- 测试 add
|
||||
local ok, err = dict:add("key2", "value2")
|
||||
assert(ok == true)
|
||||
|
||||
-- 测试 add 已存在的 key
|
||||
ok, err = dict:add("key2", "value3")
|
||||
assert(ok == false)
|
||||
assert(err == "exists")
|
||||
|
||||
-- 测试 incr
|
||||
local new_val, err = dict:incr("counter", 10)
|
||||
assert(new_val == 10)
|
||||
|
||||
new_val, err = dict:incr("counter", 5)
|
||||
assert(new_val == 15)
|
||||
|
||||
-- 测试 size
|
||||
local size = dict:size()
|
||||
assert(size >= 3)
|
||||
|
||||
-- 测试 delete
|
||||
dict:delete("key1")
|
||||
local val, err = dict:get("key1")
|
||||
assert(val == nil)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user