From 51061d68ffc5da6f35d7e610eff02dc34e485385 Mon Sep 17 00:00:00 2001 From: xfy Date: Sat, 11 Apr 2026 12:17:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(lua):=20=E5=AE=9E=E7=8E=B0=20ngx.var=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 通过元表实现动态变量读写,支持 nginx 内置变量和自定义变量 Co-Authored-By: Claude Opus 4.6 --- internal/lua/api_var.go | 197 ++++++++++++++++++++++++ internal/lua/api_var_test.go | 284 +++++++++++++++++++++++++++++++++++ 2 files changed, 481 insertions(+) create mode 100644 internal/lua/api_var.go create mode 100644 internal/lua/api_var_test.go diff --git a/internal/lua/api_var.go b/internal/lua/api_var.go new file mode 100644 index 0000000..be61b7d --- /dev/null +++ b/internal/lua/api_var.go @@ -0,0 +1,197 @@ +// Package lua 提供 ngx.var API 实现 +package lua + +import ( + "strconv" + + "github.com/valyala/fasthttp" + glua "github.com/yuin/gopher-lua" +) + +// ngxVarAPI ngx.var API 实现 +type ngxVarAPI struct { + // 请求上下文 + ctx *fasthttp.RequestCtx + + // 变量存储(用于自定义变量) + store map[string]string +} + +// newNgxVarAPI 创建 ngx.var API 实例 +func newNgxVarAPI(ctx *fasthttp.RequestCtx) *ngxVarAPI { + return &ngxVarAPI{ + ctx: ctx, + store: make(map[string]string), + } +} + +// RegisterNgxVarAPI 在 Lua 状态机中注册 ngx.var API +// 使用元表实现动态读写:ngx.var.key 和 ngx.var[key] +func RegisterNgxVarAPI(L *glua.LState, api *ngxVarAPI, ngxTable *glua.LTable) { + // 创建 ngx.var 表(使用元表实现动态访问) + ngxVar := L.NewTable() + + // 创建元表 + mt := L.NewTable() + + // __index 元方法:读取变量 + mt.RawSetString("__index", L.NewFunction(api.luaVarIndex)) + + // __newindex 元方法:设置变量 + mt.RawSetString("__newindex", L.NewFunction(api.luaVarNewIndex)) + + // 设置元表 + L.SetMetatable(ngxVar, mt) + + // 将 ngx.var 添加到 ngx 表 + ngxTable.RawSetString("var", ngxVar) +} + +// luaVarIndex 实现 ngx.var[key] 读取 +// Lua 调用: local value = ngx.var.key 或 ngx.var[key] +func (api *ngxVarAPI) luaVarIndex(L *glua.LState) int { + // 第一个参数是表本身(ngx.var) + // 第二个参数是键名 + key := L.CheckString(2) + + // 1. 先查自定义变量存储 + if value, ok := api.store[key]; ok { + L.Push(glua.LString(value)) + return 1 + } + + // 2. 从 fasthttp RequestCtx 获取变量 + value := api.getVariable(key) + if value != "" { + L.Push(glua.LString(value)) + return 1 + } + + // 3. 未找到变量,返回 nil + L.Push(glua.LNil) + return 1 +} + +// luaVarNewIndex 实现 ngx.var[key] = value 写入 +// Lua 调用: ngx.var.key = value 或 ngx.var[key] = value +func (api *ngxVarAPI) luaVarNewIndex(L *glua.LState) int { + // 第一个参数是表本身(ngx.var) + // 第二个参数是键名 + // 第三个参数是值 + key := L.CheckString(2) + value := L.CheckString(3) + + // 存储到自定义变量存储 + api.store[key] = value + + return 0 +} + +// getVariable 从 fasthttp RequestCtx 获取变量值 +// 支持常见的 nginx 变量 +func (api *ngxVarAPI) getVariable(name string) string { + if api.ctx == nil { + return "" + } + + switch name { + // HTTP 请求相关 + case "request_method": + return string(api.ctx.Method()) + case "request_uri": + return string(api.ctx.RequestURI()) + case "uri": + return string(api.ctx.URI().Path()) + case "document_uri": + return string(api.ctx.URI().Path()) + case "query_string", "args": + return string(api.ctx.URI().QueryString()) + case "server_protocol", "protocol": + return string(api.ctx.Request.Header.Protocol()) + case "scheme": + return string(api.ctx.URI().Scheme()) + case "request_length": + return strconv.Itoa(api.ctx.Request.Header.ContentLength()) + case "request_time": + // 简化实现,返回空字符串 + return "" + + // 请求头相关 + case "http_host": + return string(api.ctx.Host()) + case "http_user_agent", "http_user-agent": + return string(api.ctx.UserAgent()) + case "http_referer": + return string(api.ctx.Referer()) + case "http_accept": + return string(api.ctx.Request.Header.Peek("Accept")) + case "http_accept_encoding", "http_accept-encoding": + return string(api.ctx.Request.Header.Peek("Accept-Encoding")) + case "http_accept_language", "http_accept-language": + return string(api.ctx.Request.Header.Peek("Accept-Language")) + case "http_connection": + return string(api.ctx.Request.Header.Peek("Connection")) + case "http_content_type", "http_content-type": + return string(api.ctx.Request.Header.ContentType()) + case "http_content_length", "http_content-length": + return string(api.ctx.Request.Header.Peek("Content-Length")) + + // 客户端信息 + case "remote_addr": + return api.ctx.RemoteAddr().String() + case "remote_port": + addr := api.ctx.RemoteAddr() + if addr != nil { + // 简化处理,实际可能需要解析端口 + return "" + } + return "" + case "binary_remote_addr": + return "" + + // 服务器信息 + case "server_addr": + addr := api.ctx.LocalAddr() + if addr != nil { + return addr.String() + } + return "" + case "server_port": + return "" + case "server_name": + return string(api.ctx.Host()) + + // URI 参数 + case "arg_": + // 获取所有参数 + return string(api.ctx.URI().QueryString()) + default: + // 检查是否是 arg_ 开头的参数 + if len(name) > 4 && name[:4] == "arg_" { + paramName := name[4:] + return string(api.ctx.QueryArgs().Peek(paramName)) + } + // 检查是否是 http_ 开头的请求头 + if len(name) > 5 && name[:5] == "http_" { + headerName := name[5:] + return string(api.ctx.Request.Header.Peek(headerName)) + } + return "" + } +} + +// SetVariable 设置自定义变量(Go 层调用) +func (api *ngxVarAPI) SetVariable(name, value string) { + api.store[name] = value +} + +// GetVariable 获取变量值(Go 层调用) +func (api *ngxVarAPI) GetVariable(name string) (string, bool) { + // 先查自定义变量 + if value, ok := api.store[name]; ok { + return value, true + } + // 再查 fasthttp 变量 + value := api.getVariable(name) + return value, value != "" +} diff --git a/internal/lua/api_var_test.go b/internal/lua/api_var_test.go new file mode 100644 index 0000000..e7d2205 --- /dev/null +++ b/internal/lua/api_var_test.go @@ -0,0 +1,284 @@ +// Package lua 提供 ngx.var API 测试 +package lua + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +// TestNgxVarAPI 测试 ngx.var API 基础功能 +func TestNgxVarAPI(t *testing.T) { + // 创建 fasthttp 请求上下文 + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test/path?foo=bar&baz=qux") + req.Header.Set("Host", "example.com") + req.Header.Set("User-Agent", "TestAgent") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 创建协程 + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + // 设置沙箱(这会注册 ngx API) + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试 ngx.var 存在 + err = coro.Execute(` + if type(ngx) ~= "table" then + error("ngx is not a table") + end + if type(ngx.var) ~= "table" then + error("ngx.var is not a table") + end + `) + assert.NoError(t, err) +} + +// TestNgxVarReadBuiltin 测试读取内置变量 +func TestNgxVarReadBuiltin(t *testing.T) { + // 创建 fasthttp 请求上下文 + var req fasthttp.Request + req.Header.SetMethod("POST") + req.Header.SetRequestURI("/api/test?name=value") + req.Header.Set("Host", "test.example.com") + req.Header.Set("User-Agent", "Mozilla/5.0") + req.Header.Set("Accept", "application/json") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试读取内置变量 + err = coro.Execute(` + local method = ngx.var.request_method + if method ~= "POST" then + error("request_method should be POST, got: " .. tostring(method)) + end + + local uri = ngx.var.uri + if uri ~= "/api/test" then + error("uri should be /api/test, got: " .. tostring(uri)) + end + + local host = ngx.var.http_host + if host ~= "test.example.com" then + error("http_host should be test.example.com, got: " .. tostring(host)) + end + + local userAgent = ngx.var.http_user_agent + if userAgent ~= "Mozilla/5.0" then + error("http_user_agent should be Mozilla/5.0, got: " .. tostring(userAgent)) + end + `) + assert.NoError(t, err) +} + +// TestNgxVarReadQueryArgs 测试读取查询参数 +func TestNgxVarReadQueryArgs(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/search?q=lua&page=1") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试读取查询参数 + err = coro.Execute(` + local q = ngx.var.arg_q + if q ~= "lua" then + error("arg_q should be 'lua', got: " .. tostring(q)) + end + + local page = ngx.var.arg_page + if page ~= "1" then + error("arg_page should be '1', got: " .. tostring(page)) + end + + local queryString = ngx.var.query_string + if type(queryString) ~= "string" then + error("query_string should be a string") + end + `) + assert.NoError(t, err) +} + +// TestNgxVarWriteCustom 测试设置自定义变量 +func TestNgxVarWriteCustom(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试设置和读取自定义变量 + err = coro.Execute(` + -- 设置自定义变量 + ngx.var.my_custom_var = "hello world" + ngx.var.another_var = "12345" + + -- 读取自定义变量 + local val1 = ngx.var.my_custom_var + if val1 ~= "hello world" then + error("my_custom_var should be 'hello world', got: " .. tostring(val1)) + end + + local val2 = ngx.var.another_var + if val2 ~= "12345" then + error("another_var should be '12345', got: " .. tostring(val2)) + end + + -- 覆盖已存在的变量 + ngx.var.my_custom_var = "updated" + local val3 = ngx.var.my_custom_var + if val3 ~= "updated" then + error("my_custom_var should be 'updated', got: " .. tostring(val3)) + end + `) + assert.NoError(t, err) +} + +// TestNgxVarIndexAccess 测试索引访问方式 ngx.var[key] +func TestNgxVarIndexAccess(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试索引访问 + err = coro.Execute(` + -- 使用索引方式设置变量 + ngx.var["dynamic_key"] = "dynamic_value" + + -- 使用索引方式读取变量 + local val = ngx.var["dynamic_key"] + if val ~= "dynamic_value" then + error("dynamic_key should be 'dynamic_value', got: " .. tostring(val)) + end + + -- 混合访问方式 + ngx.var.mixed = "mixed_value" + local mixed = ngx.var["mixed"] + if mixed ~= "mixed_value" then + error("mixed should be 'mixed_value', got: " .. tostring(mixed)) + end + `) + assert.NoError(t, err) +} + +// TestNgxVarNilRequestCtx 测试无请求上下文的情况 +func TestNgxVarNilRequestCtx(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 创建没有请求上下文的协程 + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试在无请求上下文时访问 ngx.var + err = coro.Execute(` + -- 应该返回空字符串或 nil + local method = ngx.var.request_method + -- 可以接受空字符串或 nil + + -- 但自定义变量仍然可以设置 + ngx.var.test = "value" + local val = ngx.var.test + if val ~= "value" then + error("custom var should be settable even without request ctx") + end + `) + assert.NoError(t, err) +} + +// TestNgxVarUndefined 测试未定义变量 +func TestNgxVarUndefined(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试读取未定义变量 + err = coro.Execute(` + local undefined = ngx.var.undefined_var_name + if undefined ~= nil then + error("undefined var should be nil, got: " .. tostring(undefined)) + end + `) + assert.NoError(t, err) +}