From 6b9df86217e9fb0d9bcd7e700a75a467c31ebd53 Mon Sep 17 00:00:00 2001 From: xfy Date: Sat, 11 Apr 2026 12:17:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(lua):=20=E6=89=A9=E5=B1=95=20ngx.req=20API?= =?UTF-8?q?=20=E5=B9=B6=E9=9B=86=E6=88=90=E6=89=80=E6=9C=89=20ngx=20API=20?= =?UTF-8?q?=E5=88=B0=E6=B2=99=E7=AE=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 扩展 API: set_uri, set_uri_args, get_headers, set_header, clear_header, get_body_data, read_body 在 coroutine.SetupSandbox() 中统一注册 ngx.req/resp/var/ctx/log/socket API Co-Authored-By: Claude Opus 4.6 --- internal/lua/api_req.go | 314 +++++++++++++++++++- internal/lua/api_req_test.go | 555 +++++++++++++++++++++++++++++++++++ internal/lua/coroutine.go | 45 +++ internal/lua/lua_test.go | 116 ++++++++ 4 files changed, 1017 insertions(+), 13 deletions(-) create mode 100644 internal/lua/api_req_test.go diff --git a/internal/lua/api_req.go b/internal/lua/api_req.go index b1eea79..fa11fe0 100644 --- a/internal/lua/api_req.go +++ b/internal/lua/api_req.go @@ -83,10 +83,7 @@ func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI { // RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API // 这是主入口函数,由 LuaEngine 在初始化时调用 -func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI) { - // 创建 ngx 表 - ngx := L.NewTable() - +func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI, ngxTable *glua.LTable) { // 创建 ngx.req 子表 ngxReq := L.NewTable() @@ -98,21 +95,41 @@ func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI) { // 特点:直接返回请求的 URI 路径(不含 query string) ngxReq.RawSetString("get_uri", L.NewFunction(api.luaGetURI)) + // 直接映射层 API:set_uri + // 特点:直接修改请求的 URI 路径,支持可选的内部跳转标记 + ngxReq.RawSetString("set_uri", L.NewFunction(api.luaSetURI)) + // 兼容层 API:get_uri_args // 特点:需要解析 query string 为 nginx 兼容的表结构 // 增加了解析开销,但保持 API 兼容性 ngxReq.RawSetString("get_uri_args", L.NewFunction(api.luaGetURIArgs)) + // 兼容层 API:set_uri_args + // 特点:支持 table 或 string 参数设置查询参数 + ngxReq.RawSetString("set_uri_args", L.NewFunction(api.luaSetURIArgs)) + + // 兼容层 API:get_headers + // 特点:需要遍历所有请求头,模拟 nginx 的头表结构 + ngxReq.RawSetString("get_headers", L.NewFunction(api.luaGetHeaders)) + + // 直接映射层 API:set_header + // 特点:直接操作 fasthttp 请求头,支持设置和清除 + ngxReq.RawSetString("set_header", L.NewFunction(api.luaSetHeader)) + + // 直接映射层 API:clear_header + // 特点:直接删除 fasthttp 请求头 + ngxReq.RawSetString("clear_header", L.NewFunction(api.luaClearHeader)) + + // 兼容层 API:get_body_data + // 特点:获取请求体内容 + ngxReq.RawSetString("get_body_data", L.NewFunction(api.luaGetBodyData)) + // 伪非阻塞层 API:read_body - // 特点:使用 yield/resume 模式支持异步读取 - // 这是实验性 API,展示非阻塞调用模式 - ngxReq.RawSetString("read_body", L.NewFunction(api.luaReadBodyAsync)) + // 特点:确保请求体已被读取(fasthttp 已预读) + ngxReq.RawSetString("read_body", L.NewFunction(api.luaReadBody)) - // 将 ngx.req 添加到 ngx - ngx.RawSetString("req", ngxReq) - - // 注册 ngx 全局变量 - L.SetGlobal("ngx", ngx) + // 将 ngx.req 添加到 ngx 表 + ngxTable.RawSetString("req", ngxReq) } // ==================== 直接映射层 API ==================== @@ -138,6 +155,43 @@ func (api *ngxReqAPI) luaGetMethod(L *glua.LState) int { return 1 } +// luaSetURI 实现 ngx.req.set_uri(uri, jump?) - 直接映射层 +// Lua 调用: ngx.req.set_uri("/new/path") 或 ngx.req.set_uri("/new/path", true) +// 参数: +// - uri: 新的 URI 路径 +// - jump: 是否触发内部跳转(可选,默认为 false) +func (api *ngxReqAPI) luaSetURI(L *glua.LState) int { + start := time.Now() + + // 获取 uri 参数 + uri := L.CheckString(1) + + // 获取可选的 jump 参数 + jump := false + if L.GetTop() >= 2 { + jump = L.ToBool(2) + } + + // 设置新的 URI + api.ctx.Request.URI().SetPath(uri) + + // 如果 jump 为 true,记录内部跳转标记(供后续处理使用) + if jump { + // 在请求上下文中存储跳转标记 + api.ctx.SetUserValue("_ngx_req_internal_jump", true) + } + + // 记录指标 + elapsed := uint64(time.Since(start).Nanoseconds()) + api.metrics.DirectCallCount++ + api.metrics.DirectTotalNs += elapsed + if elapsed > api.metrics.DirectMaxNs { + api.metrics.DirectMaxNs = elapsed + } + + return 0 +} + // luaGetURI 实现 ngx.req.get_uri() - 直接映射层 // Lua 调用: local uri = ngx.req.get_uri() // 返回: string (如 "/path/to/resource") @@ -225,7 +279,241 @@ func (api *ngxReqAPI) parseURIArgs() map[string][]string { return args } -// ==================== 伪非阻塞层 API(实验性) ==================== +// luaSetURIArgs 实现 ngx.req.set_uri_args(args) - 兼容层 +// Lua 调用: ngx.req.set_uri_args({ key = "value" }) 或 ngx.req.set_uri_args("key=value&foo=bar") +// 参数: +// - args: table 或 string 类型的查询参数 +func (api *ngxReqAPI) luaSetURIArgs(L *glua.LState) int { + start := time.Now() + + // 获取参数类型 + argType := L.Get(1) + + switch argType.Type() { + case glua.LTString: + // 如果是字符串,直接解析并设置 + queryStr := string(argType.(glua.LString)) + api.ctx.Request.URI().SetQueryString(queryStr) + + case glua.LTTable: + // 如果是 table,构建查询字符串 + table := argType.(*glua.LTable) + args := make(map[string][]string) + + table.ForEach(func(key, value glua.LValue) { + keyStr := glua.LVAsString(key) + switch value.Type() { + case glua.LTString: + args[keyStr] = []string{string(value.(glua.LString))} + case glua.LTNumber: + args[keyStr] = []string{glua.LVAsString(value)} + case glua.LTTable: + // 数组形式的多值 + arr := value.(*glua.LTable) + values := []string{} + arr.ForEach(func(_, v glua.LValue) { + values = append(values, glua.LVAsString(v)) + }) + args[keyStr] = values + } + }) + + // 构建查询字符串 + if len(args) > 0 { + query := fasthttp.Args{} + for key, values := range args { + for _, v := range values { + query.Add(key, v) + } + } + api.ctx.Request.URI().SetQueryString(query.String()) + } + + default: + L.RaiseError("set_uri_args expects table or string, got %s", argType.Type().String()) + return 0 + } + + // 记录指标 + elapsed := uint64(time.Since(start).Nanoseconds()) + api.metrics.CompatibleCallCount++ + api.metrics.CompatibleTotalNs += elapsed + if elapsed > api.metrics.CompatibleMaxNs { + api.metrics.CompatibleMaxNs = elapsed + } + + return 0 +} + +// ==================== 请求头 API ==================== + +// luaGetHeaders 实现 ngx.req.get_headers(max_headers?) - 兼容层 +// Lua 调用: local headers = ngx.req.get_headers() 或 ngx.req.get_headers(50) +// 返回: table (如 { ["host"] = "example.com", ["cookie"] = { "a=1", "b=2" } }) +// 注意:兼容层需要遍历所有请求头,模拟 nginx 的头表结构 +func (api *ngxReqAPI) luaGetHeaders(L *glua.LState) int { + start := time.Now() + + // 获取可选的 max_headers 参数 + maxHeaders := 100 // 默认最大头数 + if L.GetTop() >= 1 { + maxHeaders = L.ToInt(1) + if maxHeaders <= 0 { + maxHeaders = 100 + } + } + + // 构建 Lua 表 + result := L.NewTable() + headers := &api.ctx.Request.Header + + count := 0 + // 使用 VisitAll 遍历所有请求头 + headers.VisitAll(func(key, value []byte) { + if count >= maxHeaders { + return + } + keyStr := string(key) + valueStr := string(value) + + // 检查是否已存在同名头(多值头) + existing := result.RawGetString(keyStr) + if existing == glua.LNil { + // 第一次遇到这个头 + result.RawSetString(keyStr, glua.LString(valueStr)) + } else if existingStr, ok := existing.(glua.LString); ok { + // 第二次遇到,需要转换为数组 + arr := L.NewTable() + arr.Append(existingStr) + arr.Append(glua.LString(valueStr)) + result.RawSetString(keyStr, arr) + } else if existingArr, ok := existing.(*glua.LTable); ok { + // 已经是数组,追加 + existingArr.Append(glua.LString(valueStr)) + } + count++ + }) + + // 记录指标 + elapsed := uint64(time.Since(start).Nanoseconds()) + api.metrics.CompatibleCallCount++ + api.metrics.CompatibleTotalNs += elapsed + if elapsed > api.metrics.CompatibleMaxNs { + api.metrics.CompatibleMaxNs = elapsed + } + + L.Push(result) + return 1 +} + +// luaSetHeader 实现 ngx.req.set_header(key, value) - 直接映射层 +// Lua 调用: ngx.req.set_header("X-Custom", "value") 或 ngx.req.set_header("X-Custom", nil) 清除 +// 参数: +// - key: 头名称 +// - value: 头值,如果为 nil 则清除该头 +func (api *ngxReqAPI) luaSetHeader(L *glua.LState) int { + start := time.Now() + + // 获取参数 + key := L.CheckString(1) + value := L.Get(2) + + if value == glua.LNil { + // 值为 nil,删除头 + api.ctx.Request.Header.Del(key) + } else { + // 设置头值 + valueStr := glua.LVAsString(value) + api.ctx.Request.Header.Set(key, valueStr) + } + + // 记录指标 + elapsed := uint64(time.Since(start).Nanoseconds()) + api.metrics.DirectCallCount++ + api.metrics.DirectTotalNs += elapsed + if elapsed > api.metrics.DirectMaxNs { + api.metrics.DirectMaxNs = elapsed + } + + return 0 +} + +// luaClearHeader 实现 ngx.req.clear_header(key) - 直接映射层 +// Lua 调用: ngx.req.clear_header("X-Custom") +// 参数: +// - key: 要清除的头名称 +func (api *ngxReqAPI) luaClearHeader(L *glua.LState) int { + start := time.Now() + + // 获取参数 + key := L.CheckString(1) + + // 删除头 + api.ctx.Request.Header.Del(key) + + // 记录指标 + elapsed := uint64(time.Since(start).Nanoseconds()) + api.metrics.DirectCallCount++ + api.metrics.DirectTotalNs += elapsed + if elapsed > api.metrics.DirectMaxNs { + api.metrics.DirectMaxNs = elapsed + } + + return 0 +} + +// ==================== 请求体 API ==================== + +// luaGetBodyData 实现 ngx.req.get_body_data() - 兼容层 +// Lua 调用: local body = ngx.req.get_body_data() +// 返回: string 或 nil(如果没有请求体) +func (api *ngxReqAPI) luaGetBodyData(L *glua.LState) int { + start := time.Now() + + // 获取请求体 + body := api.ctx.Request.Body() + + if len(body) == 0 { + L.Push(glua.LNil) + } else { + L.Push(glua.LString(body)) + } + + // 记录指标 + elapsed := uint64(time.Since(start).Nanoseconds()) + api.metrics.CompatibleCallCount++ + api.metrics.CompatibleTotalNs += elapsed + if elapsed > api.metrics.CompatibleMaxNs { + api.metrics.CompatibleMaxNs = elapsed + } + + return 1 +} + +// luaReadBody 实现 ngx.req.read_body() - 伪非阻塞层 +// Lua 调用: ngx.req.read_body() -- 完成后返回 +// 注意:fasthttp 已经预读取了请求体,这里主要是确保请求体已被读取 +func (api *ngxReqAPI) luaReadBody(L *glua.LState) int { + start := time.Now() + + // fasthttp 默认会预读取请求体到内存中 + // 这里我们只需要确保请求体已被读取(对于 POST/PUT 等请求) + // 如果请求体未读取,触发读取 + if api.ctx.Request.Header.ContentLength() > 0 { + // 访问 Body() 会确保请求体已被读取 + _ = api.ctx.Request.Body() + } + + // 记录指标(使用伪非阻塞层指标) + elapsed := uint64(time.Since(start).Nanoseconds()) + api.metrics.PseudoBlockingCallCount++ + api.metrics.PseudoBlockingTotalNs += elapsed + if elapsed > api.metrics.PseudoBlockingMaxNs { + api.metrics.PseudoBlockingMaxNs = elapsed + } + + return 0 +} // luaReadBodyAsync 实现 ngx.req.read_body() - 伪非阻塞层 // Lua 调用: ngx.req.read_body() -- 会 yield,完成后 resume diff --git a/internal/lua/api_req_test.go b/internal/lua/api_req_test.go new file mode 100644 index 0000000..4055072 --- /dev/null +++ b/internal/lua/api_req_test.go @@ -0,0 +1,555 @@ +// Package lua 提供 ngx.req API 测试 +package lua + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + glua "github.com/yuin/gopher-lua" +) + +// 创建测试用的 fasthttp.RequestCtx +func createTestRequestCtx(method, uri string, headers map[string]string, body []byte) *fasthttp.RequestCtx { + ctx := &fasthttp.RequestCtx{} + + // 设置请求 + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI(uri) + + // 设置请求头 + for key, value := range headers { + ctx.Request.Header.Set(key, value) + } + + // 设置请求体 + if len(body) > 0 { + ctx.Request.SetBody(body) + } + + return ctx +} + +// TestNgxReqGetMethod 测试 ngx.req.get_method() +func TestNgxReqGetMethod(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 创建测试请求 + reqCtx := createTestRequestCtx("POST", "/test", nil, nil) + + // 创建协程 + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + // 设置沙箱(这会自动注册 ngx API) + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试获取请求方法 + err = coro.Execute(` + local method = ngx.req.get_method() + if method ~= "POST" then + error("expected POST, got " .. tostring(method)) + end + `) + assert.NoError(t, err) +} + +// TestNgxReqGetURI 测试 ngx.req.get_uri() +func TestNgxReqGetURI(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("GET", "/path/to/resource?key=value", nil, nil) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + local uri = ngx.req.get_uri() + if uri ~= "/path/to/resource" then + error("expected /path/to/resource, got " .. tostring(uri)) + end + `) + assert.NoError(t, err) +} + +// TestNgxReqSetURI 测试 ngx.req.set_uri() +func TestNgxReqSetURI(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 测试设置 URI(不带 jump) + reqCtx := createTestRequestCtx("GET", "/original", nil, nil) + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + ngx.req.set_uri("/new/path") + `) + assert.NoError(t, err) + coro.Close() + + // 验证 URI 已修改 + assert.Equal(t, "/new/path", string(reqCtx.URI().Path())) + + // 测试设置 URI(带 jump) + reqCtx2 := createTestRequestCtx("GET", "/original", nil, nil) + coro2, err := engine.NewCoroutine(reqCtx2) + require.NoError(t, err) + + err = coro2.SetupSandbox() + require.NoError(t, err) + + err = coro2.Execute(` + ngx.req.set_uri("/redirect/path", true) + `) + assert.NoError(t, err) + coro2.Close() + + assert.Equal(t, "/redirect/path", string(reqCtx2.URI().Path())) + // 验证 jump 标记已设置 + jumpFlag := reqCtx2.UserValue("_ngx_req_internal_jump") + assert.Equal(t, true, jumpFlag) +} + +// TestNgxReqGetURIArgs 测试 ngx.req.get_uri_args() +func TestNgxReqGetURIArgs(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("GET", "/test?foo=bar&baz=qux&arr=1&arr=2", nil, nil) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + local args = ngx.req.get_uri_args() + + if args.foo ~= "bar" then + error("expected foo=bar, got " .. tostring(args.foo)) + end + + if args.baz ~= "qux" then + error("expected baz=qux, got " .. tostring(args.baz)) + end + + -- 多值参数应该返回数组 + if type(args.arr) ~= "table" then + error("expected arr to be table, got " .. type(args.arr)) + end + + if args.arr[1] ~= "1" or args.arr[2] ~= "2" then + error("expected arr = {1, 2}") + end + `) + assert.NoError(t, err) +} + +// TestNgxReqSetURIArgs 测试 ngx.req.set_uri_args() +func TestNgxReqSetURIArgs(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("GET", "/test", nil, nil) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试使用 table 设置参数 + err = coro.Execute(` + ngx.req.set_uri_args({ key = "value", num = 123 }) + `) + assert.NoError(t, err) + + queryStr := string(reqCtx.URI().QueryString()) + assert.Contains(t, queryStr, "key=value") + assert.Contains(t, queryStr, "num=123") + + // 测试使用字符串设置参数 + reqCtx2 := createTestRequestCtx("GET", "/test", nil, nil) + coro2, err := engine.NewCoroutine(reqCtx2) + require.NoError(t, err) + defer coro2.Close() + + err = coro2.SetupSandbox() + require.NoError(t, err) + + err = coro2.Execute(` + ngx.req.set_uri_args("foo=bar&baz=qux") + `) + assert.NoError(t, err) + + assert.Equal(t, "foo=bar&baz=qux", string(reqCtx2.URI().QueryString())) +} + +// TestNgxReqGetHeaders 测试 ngx.req.get_headers() +func TestNgxReqGetHeaders(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("GET", "/test", map[string]string{ + "Host": "example.com", + "X-Custom": "custom-value", + "Content-Type": "application/json", + }, nil) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + local headers = ngx.req.get_headers() + + if headers.Host ~= "example.com" then + error("expected Host=example.com, got " .. tostring(headers.Host)) + end + + if headers["X-Custom"] ~= "custom-value" then + error("expected X-Custom=custom-value") + end + + if headers["Content-Type"] ~= "application/json" then + error("expected Content-Type=application/json") + end + `) + assert.NoError(t, err) +} + +// TestNgxReqSetHeader 测试 ngx.req.set_header() +func TestNgxReqSetHeader(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 测试设置请求头 + reqCtx := createTestRequestCtx("GET", "/test", nil, nil) + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + ngx.req.set_header("X-Custom-Header", "custom-value") + `) + assert.NoError(t, err) + coro.Close() + + assert.Equal(t, "custom-value", string(reqCtx.Request.Header.Peek("X-Custom-Header"))) + + // 测试使用 nil 清除请求头 + reqCtx2 := createTestRequestCtx("GET", "/test", map[string]string{ + "X-Custom-Header": "custom-value", + }, nil) + coro2, err := engine.NewCoroutine(reqCtx2) + require.NoError(t, err) + + err = coro2.SetupSandbox() + require.NoError(t, err) + + err = coro2.Execute(` + ngx.req.set_header("X-Custom-Header", nil) + `) + assert.NoError(t, err) + coro2.Close() + + assert.Equal(t, "", string(reqCtx2.Request.Header.Peek("X-Custom-Header"))) +} + +// TestNgxReqClearHeader 测试 ngx.req.clear_header() +func TestNgxReqClearHeader(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("GET", "/test", map[string]string{ + "X-To-Clear": "value", + }, nil) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 先验证头存在 + assert.Equal(t, "value", string(reqCtx.Request.Header.Peek("X-To-Clear"))) + + // 清除头 + err = coro.Execute(` + ngx.req.clear_header("X-To-Clear") + `) + assert.NoError(t, err) + + assert.Equal(t, "", string(reqCtx.Request.Header.Peek("X-To-Clear"))) +} + +// TestNgxReqGetBodyData 测试 ngx.req.get_body_data() +func TestNgxReqGetBodyData(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("POST", "/test", nil, []byte("test body data")) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + local body = ngx.req.get_body_data() + if body ~= "test body data" then + error("expected 'test body data', got " .. tostring(body)) + end + `) + assert.NoError(t, err) +} + +// TestNgxReqGetBodyDataEmpty 测试 ngx.req.get_body_data() 空请求体 +func TestNgxReqGetBodyDataEmpty(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("GET", "/test", nil, nil) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + local body = ngx.req.get_body_data() + if body ~= nil then + error("expected nil for empty body, got " .. tostring(body)) + end + `) + assert.NoError(t, err) +} + +// TestNgxReqReadBody 测试 ngx.req.read_body() +func TestNgxReqReadBody(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("POST", "/test", map[string]string{ + "Content-Length": "14", + }, []byte("test body data")) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // read_body 应该成功执行 + err = coro.Execute(` + ngx.req.read_body() + `) + assert.NoError(t, err) + + // 验证请求体仍可访问 + body := reqCtx.Request.Body() + assert.Equal(t, "test body data", string(body)) +} + +// TestNgxReqAPIIntegration 测试 ngx.req API 集成场景 +func TestNgxReqAPIIntegration(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("POST", "/api/users?limit=10&offset=20", map[string]string{ + "Content-Type": "application/json", + "X-API-Key": "secret123", + }, []byte(`{"name":"test"}`)) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 复杂场景:获取各种请求信息并修改 + err = coro.Execute(` + -- 获取请求信息 + local method = ngx.req.get_method() + local uri = ngx.req.get_uri() + local args = ngx.req.get_uri_args() + local headers = ngx.req.get_headers() + + -- 验证获取的信息 + if method ~= "POST" then + error("method should be POST") + end + + if uri ~= "/api/users" then + error("uri should be /api/users, got " .. tostring(uri)) + end + + if args.limit ~= "10" or args.offset ~= "20" then + error("args incorrect") + end + + -- 注意:fasthttp 会标准化 header 名称,所以需要使用实际的 key + if headers["Content-Type"] ~= "application/json" and headers["content-type"] ~= "application/json" then + error("Content-Type header incorrect: " .. tostring(headers["Content-Type"])) + end + + -- 修改请求 + ngx.req.set_header("X-Request-ID", "req-12345") + ngx.req.set_uri("/api/v2/users") + `) + assert.NoError(t, err) + + // 在 Go 层验证修改 + assert.Equal(t, "/api/v2/users", string(reqCtx.URI().Path())) + assert.Equal(t, "req-12345", string(reqCtx.Request.Header.Peek("X-Request-ID"))) +} + +// TestNgxReqMetrics 测试 ngx.req API 性能指标 +func TestNgxReqMetrics(t *testing.T) { + reqCtx := createTestRequestCtx("GET", "/test?a=1&b=2", nil, nil) + api := newNgxReqAPI(reqCtx) + + L := glua.NewState() + defer L.Close() + + // 创建 ngx 表 + ngx := L.NewTable() + + // 注册 API + RegisterNgxReqAPI(L, api, ngx) + + // 将 ngx 设置到全局 + L.SetGlobal("ngx", ngx) + + // 调用各种 API + L.DoString(` + ngx.req.get_method() + ngx.req.get_uri() + ngx.req.get_uri_args() + `) + + // 验证指标 + metrics := api.GetMetrics() + assert.Greater(t, metrics.DirectCallCount, uint64(0), "应该有直接层调用") + assert.Greater(t, metrics.CompatibleCallCount, uint64(0), "应该有兼容层调用") + + // 验证平均延迟 + directAvg := api.GetDirectLayerAvgNs() + compatibleAvg := api.GetCompatibleLayerAvgNs() + assert.GreaterOrEqual(t, directAvg, float64(0)) + assert.GreaterOrEqual(t, compatibleAvg, float64(0)) + + // 验证性能比率 + ratio := api.GetPerformanceRatio() + assert.GreaterOrEqual(t, ratio, float64(0)) + + // 重置指标 + api.ResetMetrics() + metrics = api.GetMetrics() + assert.Equal(t, uint64(0), metrics.DirectCallCount) + assert.Equal(t, uint64(0), metrics.CompatibleCallCount) +} + +// TestNgxReqGetHeadersWithMaxHeaders 测试 ngx.req.get_headers(max_headers) +func TestNgxReqGetHeadersWithMaxHeaders(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 创建带有多个头的请求 + reqCtx := &fasthttp.RequestCtx{} + reqCtx.Request.Header.SetMethod("GET") + reqCtx.Request.SetRequestURI("/test") + reqCtx.Request.Header.Set("Header1", "value1") + reqCtx.Request.Header.Set("Header2", "value2") + reqCtx.Request.Header.Set("Header3", "value3") + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试限制头数 + err = coro.Execute(` + local headers = ngx.req.get_headers(2) + local count = 0 + for k, v in pairs(headers) do + count = count + 1 + end + -- 应该最多返回 2 个头 + if count > 2 then + error("expected at most 2 headers, got " .. count) + end + `) + assert.NoError(t, err) +} + +// TestNgxReqSetURIArgsWithArray 测试 ngx.req.set_uri_args() 使用数组值 +func TestNgxReqSetURIArgsWithArray(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + reqCtx := createTestRequestCtx("GET", "/test", nil, nil) + + coro, err := engine.NewCoroutine(reqCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试使用包含数组的 table + err = coro.Execute(` + ngx.req.set_uri_args({ tags = { "a", "b", "c" }, page = 1 }) + `) + assert.NoError(t, err) + + queryStr := string(reqCtx.URI().QueryString()) + assert.Contains(t, queryStr, "tags=a") + assert.Contains(t, queryStr, "tags=b") + assert.Contains(t, queryStr, "tags=c") + assert.Contains(t, queryStr, "page=1") +} diff --git a/internal/lua/coroutine.go b/internal/lua/coroutine.go index 50d3cb7..5fcb5b5 100644 --- a/internal/lua/coroutine.go +++ b/internal/lua/coroutine.go @@ -117,6 +117,9 @@ func (c *LuaCoroutine) SetupSandbox() error { // Layer 1 & 2: 设置安全的协程库(移除危险函数) c.setupSecureCoroutineLib() + // Layer 3: 设置 ngx API + c.setupNgxAPI() + return nil } @@ -163,6 +166,48 @@ func (c *LuaCoroutine) setupSecureCoroutineLib() { // 因为协程继承的是引擎全局环境,而我们在协程级别设置了独立的 coroutine 表 } +// setupNgxAPI 创建 ngx API +// 注册 ngx.req、ngx.resp、ngx.var、ngx.ctx、ngx.log 和 ngx.socket API +func (c *LuaCoroutine) setupNgxAPI() { + // 检查是否已有 ngx 表(可能已由其他 API 注册) + existingNgx := c.Co.GetGlobal("ngx") + var ngx *glua.LTable + if existingTbl, ok := existingNgx.(*glua.LTable); ok { + ngx = existingTbl + } else { + // 创建 ngx 表 + ngx = c.Co.NewTable() + } + + // 注册 ngx.req API + if c.RequestCtx != nil { + reqAPI := newNgxReqAPI(c.RequestCtx) + RegisterNgxReqAPI(c.Co, reqAPI, ngx) + + // 注册 ngx.resp API + respAPI := newNgxRespAPI(c.RequestCtx) + RegisterNgxRespAPI(c.Co, respAPI) + + // 注册 ngx.log API (logger 为 nil 时禁用日志输出) + // ngx.say/print/flush 直接写入 RequestCtx + logAPI := newNgxLogAPI(c.RequestCtx, nil, nil) + RegisterNgxLogAPI(c.Co, logAPI) + } + + // 注册 ngx.var API + varAPI := newNgxVarAPI(c.RequestCtx) + RegisterNgxVarAPI(c.Co, varAPI, ngx) + + // 注册 ngx.ctx API + RegisterNgxCtxAPI(c.Co, ngx) + + // 注册 ngx.socket API + RegisterTCPSocketAPI(c.Co, c.Engine) + + // 将 ngx 表设置到协程环境 + c.Co.SetGlobal("ngx", ngx) +} + // Execute 在协程中执行 Lua 脚本(支持 Yield/Resume) func (c *LuaCoroutine) Execute(script string) error { proto, err := c.Engine.codeCache.GetOrCompileInline(script) diff --git a/internal/lua/lua_test.go b/internal/lua/lua_test.go index 8dc24bf..6c39b2b 100644 --- a/internal/lua/lua_test.go +++ b/internal/lua/lua_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" ) // TestLuaContext 测试 LuaContext 基础功能 @@ -456,3 +457,118 @@ func TestConfig(t *testing.T) { assert.Equal(t, 100, engine.maxCoroutines) } + +// TestNgxAPIRegistrationInSandbox 测试所有 ngx API 在沙箱中的注册 +func TestNgxAPIRegistrationInSandbox(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 创建 mock RequestCtx(ngx.req/resp/log API 需要 RequestCtx) + mockCtx := &fasthttp.RequestCtx{} + + coro, err := engine.NewCoroutine(mockCtx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 验证 ngx 表存在 + err = coro.Execute(` + assert(ngx ~= nil, "ngx table should exist") + assert(type(ngx) == "table", "ngx should be a table") + `) + assert.NoError(t, err) + + // 验证 ngx.req API 存在 + coro2, err := engine.NewCoroutine(mockCtx) + require.NoError(t, err) + defer coro2.Close() + err = coro2.SetupSandbox() + require.NoError(t, err) + err = coro2.Execute(` + assert(ngx.req ~= nil, "ngx.req should exist") + assert(type(ngx.req.get_method) == "function", "ngx.req.get_method should be a function") + assert(type(ngx.req.get_uri) == "function", "ngx.req.get_uri should be a function") + assert(type(ngx.req.set_uri) == "function", "ngx.req.set_uri should be a function") + assert(type(ngx.req.get_uri_args) == "function", "ngx.req.get_uri_args should be a function") + assert(type(ngx.req.get_headers) == "function", "ngx.req.get_headers should be a function") + assert(type(ngx.req.set_header) == "function", "ngx.req.set_header should be a function") + assert(type(ngx.req.clear_header) == "function", "ngx.req.clear_header should be a function") + assert(type(ngx.req.get_body_data) == "function", "ngx.req.get_body_data should be a function") + `) + assert.NoError(t, err) + + // 验证 ngx.resp API 存在 + coro3, err := engine.NewCoroutine(mockCtx) + require.NoError(t, err) + defer coro3.Close() + err = coro3.SetupSandbox() + require.NoError(t, err) + err = coro3.Execute(` + assert(ngx.resp ~= nil, "ngx.resp should exist") + assert(type(ngx.resp.get_status) == "function", "ngx.resp.get_status should be a function") + assert(type(ngx.resp.set_status) == "function", "ngx.resp.set_status should be a function") + assert(type(ngx.resp.get_headers) == "function", "ngx.resp.get_headers should be a function") + assert(type(ngx.resp.set_header) == "function", "ngx.resp.set_header should be a function") + assert(type(ngx.resp.clear_header) == "function", "ngx.resp.clear_header should be a function") + `) + assert.NoError(t, err) + + // 验证 ngx.var API 存在 + coro4, err := engine.NewCoroutine(mockCtx) + require.NoError(t, err) + defer coro4.Close() + err = coro4.SetupSandbox() + require.NoError(t, err) + err = coro4.Execute(` + assert(ngx.var ~= nil, "ngx.var should exist") + `) + assert.NoError(t, err) + + // 验证 ngx.ctx API 存在 + coro5, err := engine.NewCoroutine(mockCtx) + require.NoError(t, err) + defer coro5.Close() + err = coro5.SetupSandbox() + require.NoError(t, err) + err = coro5.Execute(` + assert(ngx.ctx ~= nil, "ngx.ctx should exist") + assert(type(ngx.ctx) == "table", "ngx.ctx should be a table") + `) + assert.NoError(t, err) + + // 验证 ngx.log API 存在(日志级别常量和函数) + coro6, err := engine.NewCoroutine(mockCtx) + require.NoError(t, err) + defer coro6.Close() + err = coro6.SetupSandbox() + require.NoError(t, err) + err = coro6.Execute(` + assert(ngx.log ~= nil, "ngx.log should exist") + assert(type(ngx.log) == "function", "ngx.log should be a function") + assert(ngx.ERR ~= nil, "ngx.ERR should exist") + assert(ngx.WARN ~= nil, "ngx.WARN should exist") + assert(ngx.INFO ~= nil, "ngx.INFO should exist") + assert(ngx.DEBUG ~= nil, "ngx.DEBUG should exist") + assert(type(ngx.say) == "function", "ngx.say should be a function") + assert(type(ngx.print) == "function", "ngx.print should be a function") + assert(type(ngx.flush) == "function", "ngx.flush should be a function") + assert(type(ngx.exit) == "function", "ngx.exit should be a function") + assert(type(ngx.redirect) == "function", "ngx.redirect should be a function") + `) + assert.NoError(t, err) + + // 验证 ngx.socket API 存在 + coro7, err := engine.NewCoroutine(mockCtx) + require.NoError(t, err) + defer coro7.Close() + err = coro7.SetupSandbox() + require.NoError(t, err) + err = coro7.Execute(` + assert(ngx.socket ~= nil, "ngx.socket should exist") + assert(type(ngx.socket.tcp) == "function", "ngx.socket.tcp should be a function") + `) + assert.NoError(t, err) +}