feat(lua): 增强变量 API 支持数值类型返回和测试覆盖
- 新增 getVariableLua 方法返回 Lua 类型而非字符串 - request_length 等变量返回数值类型而非字符串 - luaVarNewIndex 支持 nil 值转换为空字符串 - 添加 api_var 全面的单元测试覆盖 - 添加 api_ctx 上下文 API 测试 - 添加 api_socket_tcp TCP socket 测试 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
bec8932561
commit
f123018f2d
@ -7,6 +7,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// TestNgxCtxAPI 测试 ngx.ctx API 基础功能
|
||||
@ -365,3 +366,335 @@ func TestNgxCtxNestedTable(t *testing.T) {
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxCtxRequestIsolation 测试请求间上下文隔离
|
||||
func TestNgxCtxRequestIsolation(t *testing.T) {
|
||||
var req1, req2 fasthttp.Request
|
||||
req1.Header.SetMethod("GET")
|
||||
req1.Header.SetRequestURI("/request1")
|
||||
req2.Header.SetMethod("GET")
|
||||
req2.Header.SetRequestURI("/request2")
|
||||
|
||||
ctx1 := &fasthttp.RequestCtx{}
|
||||
ctx1.Init(&req1, nil, nil)
|
||||
ctx2 := &fasthttp.RequestCtx{}
|
||||
ctx2.Init(&req2, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 第一个请求:设置 ctx 值
|
||||
coro1, err := engine.NewCoroutine(ctx1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro1.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro1.Execute(`
|
||||
ngx.ctx.request_id = 1
|
||||
ngx.ctx.message = "request1_data"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
coro1.Close()
|
||||
|
||||
// 第二个请求:验证 ctx 与其他请求隔离
|
||||
coro2, err := engine.NewCoroutine(ctx2)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro2.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro2.Execute(`
|
||||
-- 第一个请求的值不应该影响第二个请求
|
||||
if ngx.ctx.request_id ~= nil then
|
||||
error("ctx from another request should be isolated")
|
||||
end
|
||||
|
||||
if ngx.ctx.message ~= nil then
|
||||
error("ctx from another request should be isolated")
|
||||
end
|
||||
|
||||
-- 设置自己的值
|
||||
ngx.ctx.request_id = 2
|
||||
ngx.ctx.message = "request2_data"
|
||||
|
||||
if ngx.ctx.request_id ~= 2 then
|
||||
error("request_id should be 2")
|
||||
end
|
||||
|
||||
if ngx.ctx.message ~= "request2_data" then
|
||||
error("message should be 'request2_data'")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
coro2.Close()
|
||||
}
|
||||
|
||||
// TestNgxCtxGoAPIAccess 测试 Go 层 API 访问
|
||||
func TestNgxCtxGoAPIAccess(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 通过 Go 层 API 设置值
|
||||
api := coro.GetNgxVarAPI()
|
||||
require.NotNil(t, api)
|
||||
|
||||
api.SetVariable("go_key", "go_value")
|
||||
|
||||
// 验证 Lua 层可以读取,且 Lua 层设置的值 Go 层可见
|
||||
// 注意:在一个脚本中完成所有操作,因为协程执行后变为 dead 状态
|
||||
err = coro.Execute(`
|
||||
-- 验证 Go 层设置的值
|
||||
if ngx.var.go_key ~= "go_value" then
|
||||
error("value from Go layer should be accessible in Lua")
|
||||
end
|
||||
|
||||
-- 从 Lua 层设置值
|
||||
ngx.var.lua_key = "lua_value"
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证 Lua 层设置的值 Go 层可见
|
||||
val, ok := api.GetVariable("lua_key")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "lua_value", val)
|
||||
}
|
||||
|
||||
// TestNgxCtxScheduleUnsafeAPI 测试调度器上下文中的不安全 API
|
||||
func TestNgxCtxScheduleUnsafeAPI(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 为 Scheduler LState 创建不安全的 ctx API
|
||||
L := glua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
ngx := L.NewTable()
|
||||
L.SetGlobal("ngx", ngx)
|
||||
|
||||
// 注册调度器不安全的 ctx API
|
||||
RegisterSchedulerUnsafeCtxAPI(L, ngx)
|
||||
|
||||
// 尝试访问 ctx 应该返回错误
|
||||
err = L.DoString(`
|
||||
local ok, msg = pcall(function()
|
||||
ngx.ctx.key = "value"
|
||||
end)
|
||||
if ok then
|
||||
error("writing to ngx.ctx in scheduler should fail")
|
||||
end
|
||||
if not string.match(msg, "not available in timer callback") then
|
||||
error("wrong error message: " .. msg)
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 尝试读取 ctx 也应该返回错误
|
||||
err = L.DoString(`
|
||||
local ok, msg = pcall(function()
|
||||
local x = ngx.ctx.key
|
||||
end)
|
||||
if ok then
|
||||
error("reading from ngx.ctx in scheduler should fail")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxCtxTableAPI 测试 table API 操作
|
||||
func TestNgxCtxTableAPI(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试 table 操作函数
|
||||
err = coro.Execute(`
|
||||
-- 测试 pairs 遍历
|
||||
ngx.ctx.items = {a = 1, b = 2, c = 3}
|
||||
local count = 0
|
||||
for k, v in pairs(ngx.ctx.items) do
|
||||
count = count + 1
|
||||
end
|
||||
if count ~= 3 then
|
||||
error("items table should have 3 elements")
|
||||
end
|
||||
|
||||
-- 测试 table.insert
|
||||
ngx.ctx.list = {}
|
||||
table.insert(ngx.ctx.list, 1)
|
||||
table.insert(ngx.ctx.list, 2)
|
||||
table.insert(ngx.ctx.list, 3)
|
||||
if ngx.ctx.list[1] ~= 1 or ngx.ctx.list[2] ~= 2 or ngx.ctx.list[3] ~= 3 then
|
||||
error("table.insert failed")
|
||||
end
|
||||
|
||||
-- 测试 table.remove
|
||||
table.remove(ngx.ctx.list, 2)
|
||||
if #ngx.ctx.list ~= 2 or ngx.ctx.list[1] ~= 1 or ngx.ctx.list[2] ~= 3 then
|
||||
error("table.remove failed")
|
||||
end
|
||||
|
||||
-- 测试 table.concat
|
||||
ngx.ctx.strlist = {"hello", "world", "test"}
|
||||
local joined = table.concat(ngx.ctx.strlist, ", ")
|
||||
if joined ~= "hello, world, test" then
|
||||
error("table.concat failed: " .. joined)
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxCtxLargeValues 测试大值存储
|
||||
func TestNgxCtxLargeValues(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试大字符串和大 table(在一个脚本中完成所有操作)
|
||||
largeString := string(make([]byte, 10000)) // 10KB 字符串
|
||||
err = coro.Execute(`
|
||||
-- 测试大字符串
|
||||
ngx.ctx.large = "` + largeString + `"
|
||||
local val = ngx.ctx.large
|
||||
if type(val) ~= "string" then
|
||||
error("large value should be string")
|
||||
end
|
||||
|
||||
-- 测试大 table
|
||||
ngx.ctx.bigtable = {}
|
||||
for i = 1, 1000 do
|
||||
ngx.ctx.bigtable[i] = i * 2
|
||||
end
|
||||
if #ngx.ctx.bigtable ~= 1000 then
|
||||
error("bigtable should have 1000 elements")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxCtxTypeCoercion 测试类型转换
|
||||
func TestNgxCtxTypeCoercion(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试数字字符串自动转换
|
||||
err = coro.Execute(`
|
||||
ngx.ctx.num = 42
|
||||
ngx.ctx.str = "123"
|
||||
|
||||
-- 数字加字符串
|
||||
local result = ngx.ctx.num + tonumber(ngx.ctx.str)
|
||||
if result ~= 165 then
|
||||
error("type coercion failed: " .. tostring(result))
|
||||
end
|
||||
|
||||
-- 字符串连接
|
||||
local concatenated = "value: " .. ngx.ctx.num
|
||||
if concatenated ~= "value: 42" then
|
||||
error("string concatenation failed: " .. concatenated)
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxCtxBooleanLogic 测试布尔逻辑
|
||||
func TestNgxCtxBooleanLogic(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro.Execute(`
|
||||
ngx.ctx.a = true
|
||||
ngx.ctx.b = false
|
||||
|
||||
-- and 操作
|
||||
if (ngx.ctx.a and ngx.ctx.b) ~= false then
|
||||
error("a and b should be false")
|
||||
end
|
||||
|
||||
-- or 操作
|
||||
if (ngx.ctx.a or ngx.ctx.b) ~= true then
|
||||
error("a or b should be true")
|
||||
end
|
||||
|
||||
-- not 操作
|
||||
if (not ngx.ctx.a) ~= false then
|
||||
error("not a should be false")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
1256
internal/lua/api_socket_tcp_test.go
Normal file
1256
internal/lua/api_socket_tcp_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -61,9 +61,10 @@ func (api *ngxVarAPI) luaVarIndex(L *glua.LState) int {
|
||||
}
|
||||
|
||||
// 2. 从 fasthttp RequestCtx 获取变量
|
||||
value := api.getVariable(key)
|
||||
if value != "" {
|
||||
L.Push(glua.LString(value))
|
||||
// 某些变量需要返回数值类型(如 request_length)
|
||||
lv := api.getVariableLua(key)
|
||||
if lv != nil {
|
||||
L.Push(lv)
|
||||
return 1
|
||||
}
|
||||
|
||||
@ -74,28 +75,123 @@ func (api *ngxVarAPI) luaVarIndex(L *glua.LState) int {
|
||||
|
||||
// luaVarNewIndex 实现 ngx.var[key] = value 写入
|
||||
// Lua 调用: ngx.var.key = value 或 ngx.var[key] = value
|
||||
// 注意:Lua 的 nil 会被转换为空字符串存储
|
||||
func (api *ngxVarAPI) luaVarNewIndex(L *glua.LState) int {
|
||||
// 第一个参数是表本身(ngx.var)
|
||||
// 第二个参数是键名
|
||||
// 第三个参数是值
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
value := L.OptString(3, "")
|
||||
|
||||
// 存储到自定义变量存储
|
||||
// 存储到自定义变量存储(nil 会转换为空字符串)
|
||||
api.store[key] = value
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// getVariable 从 fasthttp RequestCtx 获取变量值
|
||||
// 支持常见的 nginx 变量
|
||||
// getVariableLua 从 fasthttp RequestCtx 获取变量值,返回 Lua 类型
|
||||
// 支持常见的 nginx 变量,某些变量返回数值类型
|
||||
func (api *ngxVarAPI) getVariableLua(name string) glua.LValue {
|
||||
if api.ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch name {
|
||||
// HTTP 请求相关 - 数值类型
|
||||
case "request_length":
|
||||
return glua.LNumber(api.ctx.Request.Header.ContentLength())
|
||||
|
||||
// HTTP 请求相关 - 字符串类型
|
||||
case "request_method":
|
||||
return glua.LString(string(api.ctx.Method()))
|
||||
case "request_uri":
|
||||
return glua.LString(string(api.ctx.RequestURI()))
|
||||
case "uri":
|
||||
return glua.LString(string(api.ctx.URI().Path()))
|
||||
case "document_uri":
|
||||
return glua.LString(string(api.ctx.URI().Path()))
|
||||
case "query_string", "args":
|
||||
return glua.LString(string(api.ctx.URI().QueryString()))
|
||||
case "server_protocol", "protocol":
|
||||
return glua.LString(string(api.ctx.Request.Header.Protocol()))
|
||||
case "scheme":
|
||||
return glua.LString(string(api.ctx.URI().Scheme()))
|
||||
case "request_time":
|
||||
// 简化实现,返回空字符串
|
||||
return glua.LString("")
|
||||
|
||||
// 请求头相关
|
||||
case "http_host":
|
||||
return glua.LString(string(api.ctx.Host()))
|
||||
case "http_user_agent", "http_user-agent":
|
||||
return glua.LString(string(api.ctx.UserAgent()))
|
||||
case "http_referer":
|
||||
return glua.LString(string(api.ctx.Referer()))
|
||||
case "http_accept":
|
||||
return glua.LString(string(api.ctx.Request.Header.Peek("Accept")))
|
||||
case "http_accept_encoding", "http_accept-encoding":
|
||||
return glua.LString(string(api.ctx.Request.Header.Peek("Accept-Encoding")))
|
||||
case "http_accept_language", "http_accept-language":
|
||||
return glua.LString(string(api.ctx.Request.Header.Peek("Accept-Language")))
|
||||
case "http_connection":
|
||||
return glua.LString(string(api.ctx.Request.Header.Peek("Connection")))
|
||||
case "http_content_type", "http_content-type":
|
||||
return glua.LString(string(api.ctx.Request.Header.ContentType()))
|
||||
case "http_content_length", "http_content-length":
|
||||
return glua.LString(string(api.ctx.Request.Header.Peek("Content-Length")))
|
||||
|
||||
// 客户端信息
|
||||
case "remote_addr":
|
||||
return glua.LString(api.ctx.RemoteAddr().String())
|
||||
case "remote_port":
|
||||
addr := api.ctx.RemoteAddr()
|
||||
if addr != nil {
|
||||
// 简化处理,实际可能需要解析端口
|
||||
return glua.LString("")
|
||||
}
|
||||
return glua.LString("")
|
||||
case "binary_remote_addr":
|
||||
return glua.LString("")
|
||||
|
||||
// 服务器信息
|
||||
case "server_addr":
|
||||
addr := api.ctx.LocalAddr()
|
||||
if addr != nil {
|
||||
return glua.LString(addr.String())
|
||||
}
|
||||
return glua.LString("")
|
||||
case "server_port":
|
||||
return glua.LString("")
|
||||
case "server_name":
|
||||
return glua.LString(string(api.ctx.Host()))
|
||||
|
||||
// URI 参数
|
||||
case "arg_":
|
||||
// 获取所有参数
|
||||
return glua.LString(string(api.ctx.URI().QueryString()))
|
||||
default:
|
||||
// 检查是否是 arg_ 开头的参数
|
||||
if len(name) > 4 && name[:4] == "arg_" {
|
||||
paramName := name[4:]
|
||||
return glua.LString(string(api.ctx.QueryArgs().Peek(paramName)))
|
||||
}
|
||||
// 检查是否是 http_ 开头的请求头
|
||||
if len(name) > 5 && name[:5] == "http_" {
|
||||
headerName := name[5:]
|
||||
return glua.LString(string(api.ctx.Request.Header.Peek(headerName)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// getVariable 从 fasthttp RequestCtx 获取变量值(字符串形式)
|
||||
// 用于 Go 层调用
|
||||
func (api *ngxVarAPI) getVariable(name string) string {
|
||||
if api.ctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch name {
|
||||
// HTTP 请求相关
|
||||
case "request_method":
|
||||
return string(api.ctx.Method())
|
||||
case "request_uri":
|
||||
@ -113,10 +209,7 @@ func (api *ngxVarAPI) getVariable(name string) string {
|
||||
case "request_length":
|
||||
return strconv.Itoa(api.ctx.Request.Header.ContentLength())
|
||||
case "request_time":
|
||||
// 简化实现,返回空字符串
|
||||
return ""
|
||||
|
||||
// 请求头相关
|
||||
case "http_host":
|
||||
return string(api.ctx.Host())
|
||||
case "http_user_agent", "http_user-agent":
|
||||
@ -135,21 +228,12 @@ func (api *ngxVarAPI) getVariable(name string) string {
|
||||
return string(api.ctx.Request.Header.ContentType())
|
||||
case "http_content_length", "http_content-length":
|
||||
return string(api.ctx.Request.Header.Peek("Content-Length"))
|
||||
|
||||
// 客户端信息
|
||||
case "remote_addr":
|
||||
return api.ctx.RemoteAddr().String()
|
||||
case "remote_port":
|
||||
addr := api.ctx.RemoteAddr()
|
||||
if addr != nil {
|
||||
// 简化处理,实际可能需要解析端口
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
case "binary_remote_addr":
|
||||
return ""
|
||||
|
||||
// 服务器信息
|
||||
case "server_addr":
|
||||
addr := api.ctx.LocalAddr()
|
||||
if addr != nil {
|
||||
@ -160,18 +244,13 @@ func (api *ngxVarAPI) getVariable(name string) string {
|
||||
return ""
|
||||
case "server_name":
|
||||
return string(api.ctx.Host())
|
||||
|
||||
// URI 参数
|
||||
case "arg_":
|
||||
// 获取所有参数
|
||||
return string(api.ctx.URI().QueryString())
|
||||
default:
|
||||
// 检查是否是 arg_ 开头的参数
|
||||
if len(name) > 4 && name[:4] == "arg_" {
|
||||
paramName := name[4:]
|
||||
return string(api.ctx.QueryArgs().Peek(paramName))
|
||||
}
|
||||
// 检查是否是 http_ 开头的请求头
|
||||
if len(name) > 5 && name[:5] == "http_" {
|
||||
headerName := name[5:]
|
||||
return string(api.ctx.Request.Header.Peek(headerName))
|
||||
|
||||
@ -282,3 +282,397 @@ func TestNgxVarUndefined(t *testing.T) {
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarAdditionalBuiltinVars 测试其他内置变量
|
||||
func TestNgxVarAdditionalBuiltinVars(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("DELETE")
|
||||
req.Header.SetRequestURI("/api/users?id=123&name=test")
|
||||
req.Header.Set("Host", "api.example.com")
|
||||
req.Header.Set("User-Agent", "TestClient/1.0")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer token123")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试其他内置变量
|
||||
err = coro.Execute(`
|
||||
-- URI 相关变量
|
||||
local request_uri = ngx.var.request_uri
|
||||
if request_uri ~= "/api/users?id=123&name=test" then
|
||||
error("request_uri mismatch, got: " .. tostring(request_uri))
|
||||
end
|
||||
|
||||
local uri = ngx.var.uri
|
||||
if uri ~= "/api/users" then
|
||||
error("urishould be /api/users, got: " .. tostring(uri))
|
||||
end
|
||||
|
||||
local document_uri = ngx.var.document_uri
|
||||
if document_uri ~= "/api/users" then
|
||||
error("document_uri should be /api/users, got: " .. tostring(document_uri))
|
||||
end
|
||||
|
||||
-- 查询字符串
|
||||
local query_string = ngx.var.query_string
|
||||
if query_string ~= "id=123&name=test" then
|
||||
error("query_string mismatch, got: " .. tostring(query_string))
|
||||
end
|
||||
|
||||
local args = ngx.var.args
|
||||
if args ~= "id=123&name=test" then
|
||||
error("args should match query_string, got: " .. tostring(args))
|
||||
end
|
||||
|
||||
-- 请求头
|
||||
local accept = ngx.var.http_accept
|
||||
if accept ~= "application/json" then
|
||||
error("http_accept mismatch, got: " .. tostring(accept))
|
||||
end
|
||||
|
||||
local contentType = ngx.var.http_content_type
|
||||
if contentType ~= "application/json" then
|
||||
error("http_content_type mismatch, got: " .. tostring(contentType))
|
||||
end
|
||||
|
||||
local authorization = ngx.var.http_authorization
|
||||
if authorization ~= "Bearer token123" then
|
||||
error("http_authorization mismatch, got: " .. tostring(authorization))
|
||||
end
|
||||
|
||||
-- 内置变量 map
|
||||
local vars = {
|
||||
"request_method", "request_uri", "uri", "document_uri",
|
||||
"query_string", "args", "http_host", "http_user_agent",
|
||||
"http_accept", "http_content_type"
|
||||
}
|
||||
for _, v in ipairs(vars) do
|
||||
local val = ngx.var[v]
|
||||
if type(val) ~= "string" then
|
||||
error("var " .. v .. " should be string, got: " .. type(val))
|
||||
end
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarDynamicArgsAccess 测试动态参数访问
|
||||
func TestNgxVarDynamicArgsAccess(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/search?keyword=lua&category=programming&limit=10")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试动态参数访问
|
||||
err = coro.Execute(`
|
||||
-- 直接通过 arg_ 访问
|
||||
local keyword = ngx.var.arg_keyword
|
||||
if keyword ~= "lua" then
|
||||
error("arg_keyword should be 'lua', got: " .. tostring(keyword))
|
||||
end
|
||||
|
||||
local category = ngx.var.arg_category
|
||||
if category ~= "programming" then
|
||||
error("arg_category should be 'programming', got: " .. tostring(category))
|
||||
end
|
||||
|
||||
local limit = ngx.var.arg_limit
|
||||
if limit ~= "10" then
|
||||
error("arg_limit should be '10', got: " .. tostring(limit))
|
||||
end
|
||||
|
||||
-- 使用动态键访问
|
||||
local keys = {"keyword", "category", "limit"}
|
||||
for i, k in ipairs(keys) do
|
||||
local val = ngx.var["arg_" .. k]
|
||||
if type(val) ~= "string" then
|
||||
error("dynamic arg access should return string")
|
||||
end
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarGoAPI 测试 Go 层 API 调用
|
||||
func TestNgxVarGoAPI(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
// 直接创建 API 实例并测试 Go 层 API
|
||||
api := newNgxVarAPI(ctx)
|
||||
require.NotNil(t, api)
|
||||
|
||||
// 测试 SetVariable
|
||||
api.SetVariable("go_set_var", "value_from_go")
|
||||
value, ok := api.GetVariable("go_set_var")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value_from_go", value)
|
||||
|
||||
// 测试 GetVariable 不存在的变量
|
||||
value, ok = api.GetVariable("nonexistent")
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, "", value)
|
||||
|
||||
// 测试覆盖:Go 设置,Go 读取验证
|
||||
api.SetVariable("cross_lang", "from_go")
|
||||
val, ok := api.GetVariable("cross_lang")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "from_go", val)
|
||||
|
||||
// 测试覆盖:直接设置 store,Go 读取验证
|
||||
api.store["cross_lang2"] = "from_lua"
|
||||
value, ok = api.GetVariable("cross_lang2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "from_lua", value)
|
||||
}
|
||||
|
||||
// TestNgxVarRequestMethodAccess 测试各种请求方法
|
||||
func TestNgxVarRequestMethodAccess(t *testing.T) {
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
||||
|
||||
for _, method := range methods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod(method)
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro.Execute(`
|
||||
local method = ngx.var.request_method
|
||||
if method ~= "` + method + `" then
|
||||
error("request_method should be '` + method + `', got: " .. tostring(method))
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNgxVarMixedAccessPatterns 测试混合访问模式
|
||||
func TestNgxVarMixedAccessPatterns(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试混合访问模式
|
||||
err = coro.Execute(`
|
||||
-- 点号访问设置
|
||||
ngx.var.test1 = "value1"
|
||||
-- 索引访问读取
|
||||
local val1 = ngx.var["test1"]
|
||||
if val1 ~= "value1" then
|
||||
error("mixed access 1 failed")
|
||||
end
|
||||
|
||||
-- 索引访问设置
|
||||
ngx.var["test2"] = "value2"
|
||||
-- 点号访问读取
|
||||
local val2 = ngx.var.test2
|
||||
if val2 ~= "value2" then
|
||||
error("mixed access 2 failed")
|
||||
end
|
||||
|
||||
-- 循环访问
|
||||
for i = 1, 3 do
|
||||
ngx.var["dynamic_" .. i] = "val_" .. i
|
||||
end
|
||||
|
||||
for i = 1, 3 do
|
||||
local v = ngx.var["dynamic_" .. i]
|
||||
if v ~= "val_" .. i then
|
||||
error("dynamic loop failed for i=" .. i)
|
||||
end
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarSpecialHeaders 测试特殊请求头
|
||||
func TestNgxVarSpecialHeaders(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
req.Header.Set("Accept-Encoding", "gzip, deflate, br")
|
||||
req.Header.Set("Accept-Language", "en-US,en;q=0.9")
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("Referer", "https://example.com")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro.Execute(`
|
||||
-- 测试带连字符的头
|
||||
local acceptEncoding = ngx.var.http_accept_encoding
|
||||
if acceptEncoding ~= "gzip, deflate, br" then
|
||||
error("http_accept_encoding mismatch")
|
||||
end
|
||||
|
||||
local acceptLanguage = ngx.var.http_accept_language
|
||||
if acceptLanguage ~= "en-US,en;q=0.9" then
|
||||
error("http_accept_language mismatch")
|
||||
end
|
||||
|
||||
local connection = ngx.var.http_connection
|
||||
if connection ~= "keep-alive" then
|
||||
error("http_connection mismatch")
|
||||
end
|
||||
|
||||
local referer = ngx.var.http_referer
|
||||
if referer ~= "https://example.com" then
|
||||
error("http_referer mismatch")
|
||||
end
|
||||
|
||||
-- 测试也可以通过下划线访问
|
||||
local enc2 = ngx.var["http_accept_encoding"]
|
||||
if enc2 ~= acceptEncoding then
|
||||
error("http_accept_encoding via index mismatch")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarEmptyAndNil 测试空值和 nil 处理
|
||||
func TestNgxVarEmptyAndNil(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro.Execute(`
|
||||
-- 未设置的参数应该返回空字符串或 nil
|
||||
local empty = ngx.var.arg_nonexistent
|
||||
-- arg_ 对不存在的参数通常返回空字符串
|
||||
|
||||
-- 自定义变量设为空字符串
|
||||
ngx.var.empty_string = ""
|
||||
local val = ngx.var.empty_string
|
||||
if val ~= "" then
|
||||
error("empty_string should be empty")
|
||||
end
|
||||
|
||||
-- 覆盖为空值
|
||||
ngx.var.test = "value"
|
||||
ngx.var.test = nil -- Lua 的 nil 在 __newindex 中会被转换
|
||||
-- 实现中 nil 会被转换为空字符串
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarRequestBodyAccess 测试请求体相关变量
|
||||
func TestNgxVarRequestBodyAccess(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("POST")
|
||||
req.Header.SetRequestURI("/upload")
|
||||
req.Header.SetContentType("application/octet-stream")
|
||||
req.SetBody([]byte("test body content"))
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro.Execute(`
|
||||
-- 测试请求长度
|
||||
local length = ngx.var.request_length
|
||||
if type(length) ~= "number" then
|
||||
error("request_length should be a number")
|
||||
end
|
||||
|
||||
-- 测试内容类型
|
||||
local contentType = ngx.var.http_content_type
|
||||
if contentType ~= "application/octet-stream" then
|
||||
error("content_type mismatch")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
@ -70,6 +70,12 @@ type LuaCoroutine struct {
|
||||
executionCancel context.CancelFunc
|
||||
OutputBuffer []byte
|
||||
Exited bool
|
||||
|
||||
// ngx API 实例(用于测试和 Go 层访问)
|
||||
ngxVarAPI *ngxVarAPI
|
||||
ngxReqAPI *ngxReqAPI
|
||||
ngxRespAPI *ngxRespAPI
|
||||
ngxLogAPI *ngxLogAPI
|
||||
}
|
||||
|
||||
// SetupSandbox 创建 per-request _ENV 沙箱
|
||||
@ -163,20 +169,24 @@ func (c *LuaCoroutine) setupNgxAPI() {
|
||||
// 注册 ngx.req API
|
||||
if c.RequestCtx != nil {
|
||||
reqAPI := newNgxReqAPI(c.RequestCtx)
|
||||
c.ngxReqAPI = reqAPI
|
||||
RegisterNgxReqAPI(c.Co, reqAPI, ngx)
|
||||
|
||||
// 注册 ngx.resp API
|
||||
respAPI := newNgxRespAPI(c.RequestCtx)
|
||||
c.ngxRespAPI = respAPI
|
||||
RegisterNgxRespAPI(c.Co, respAPI)
|
||||
|
||||
// 注册 ngx.log API (logger 为 nil 时禁用日志输出)
|
||||
// ngx.say/print/flush 直接写入 RequestCtx
|
||||
logAPI := newNgxLogAPI(c.RequestCtx, nil, nil)
|
||||
c.ngxLogAPI = logAPI
|
||||
RegisterNgxLogAPI(c.Co, logAPI)
|
||||
}
|
||||
|
||||
// 注册 ngx.var API
|
||||
varAPI := newNgxVarAPI(c.RequestCtx)
|
||||
c.ngxVarAPI = varAPI
|
||||
RegisterNgxVarAPI(c.Co, varAPI, ngx)
|
||||
|
||||
// 注册 ngx.ctx API
|
||||
@ -333,3 +343,23 @@ func (c *LuaCoroutine) handleSleep(values []glua.LValue) ([]glua.LValue, error)
|
||||
func (c *LuaCoroutine) Close() {
|
||||
c.Engine.releaseCoroutine(c)
|
||||
}
|
||||
|
||||
// GetNgxVarAPI 获取 ngx.var API 实例(用于测试和 Go 层访问)
|
||||
func (c *LuaCoroutine) GetNgxVarAPI() *ngxVarAPI {
|
||||
return c.ngxVarAPI
|
||||
}
|
||||
|
||||
// GetNgxReqAPI 获取 ngx.req API 实例(用于测试和 Go 层访问)
|
||||
func (c *LuaCoroutine) GetNgxReqAPI() *ngxReqAPI {
|
||||
return c.ngxReqAPI
|
||||
}
|
||||
|
||||
// GetNgxRespAPI 获取 ngx.resp API 实例(用于测试和 Go 层访问)
|
||||
func (c *LuaCoroutine) GetNgxRespAPI() *ngxRespAPI {
|
||||
return c.ngxRespAPI
|
||||
}
|
||||
|
||||
// GetNgxLogAPI 获取 ngx.log API 实例(用于测试和 Go 层访问)
|
||||
func (c *LuaCoroutine) GetNgxLogAPI() *ngxLogAPI {
|
||||
return c.ngxLogAPI
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user