feat(lua): 扩展 ngx.req API 并集成所有 ngx API 到沙箱

扩展 API: set_uri, set_uri_args, get_headers, set_header, clear_header, get_body_data, read_body
在 coroutine.SetupSandbox() 中统一注册 ngx.req/resp/var/ctx/log/socket API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-11 12:17:35 +08:00
parent 86e5b0e6f1
commit 6b9df86217
4 changed files with 1017 additions and 13 deletions

View File

@ -83,10 +83,7 @@ func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI {
// RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API
// 这是主入口函数,由 LuaEngine 在初始化时调用
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI) {
// 创建 ngx 表
ngx := L.NewTable()
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI, ngxTable *glua.LTable) {
// 创建 ngx.req 子表
ngxReq := L.NewTable()
@ -98,21 +95,41 @@ func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI) {
// 特点:直接返回请求的 URI 路径(不含 query string
ngxReq.RawSetString("get_uri", L.NewFunction(api.luaGetURI))
// 直接映射层 APIset_uri
// 特点:直接修改请求的 URI 路径,支持可选的内部跳转标记
ngxReq.RawSetString("set_uri", L.NewFunction(api.luaSetURI))
// 兼容层 APIget_uri_args
// 特点:需要解析 query string 为 nginx 兼容的表结构
// 增加了解析开销,但保持 API 兼容性
ngxReq.RawSetString("get_uri_args", L.NewFunction(api.luaGetURIArgs))
// 兼容层 APIset_uri_args
// 特点:支持 table 或 string 参数设置查询参数
ngxReq.RawSetString("set_uri_args", L.NewFunction(api.luaSetURIArgs))
// 兼容层 APIget_headers
// 特点:需要遍历所有请求头,模拟 nginx 的头表结构
ngxReq.RawSetString("get_headers", L.NewFunction(api.luaGetHeaders))
// 直接映射层 APIset_header
// 特点:直接操作 fasthttp 请求头,支持设置和清除
ngxReq.RawSetString("set_header", L.NewFunction(api.luaSetHeader))
// 直接映射层 APIclear_header
// 特点:直接删除 fasthttp 请求头
ngxReq.RawSetString("clear_header", L.NewFunction(api.luaClearHeader))
// 兼容层 APIget_body_data
// 特点:获取请求体内容
ngxReq.RawSetString("get_body_data", L.NewFunction(api.luaGetBodyData))
// 伪非阻塞层 APIread_body
// 特点:使用 yield/resume 模式支持异步读取
// 这是实验性 API展示非阻塞调用模式
ngxReq.RawSetString("read_body", L.NewFunction(api.luaReadBodyAsync))
// 特点确保请求体已被读取fasthttp 已预读)
ngxReq.RawSetString("read_body", L.NewFunction(api.luaReadBody))
// 将 ngx.req 添加到 ngx
ngx.RawSetString("req", ngxReq)
// 注册 ngx 全局变量
L.SetGlobal("ngx", ngx)
// 将 ngx.req 添加到 ngx 表
ngxTable.RawSetString("req", ngxReq)
}
// ==================== 直接映射层 API ====================
@ -138,6 +155,43 @@ func (api *ngxReqAPI) luaGetMethod(L *glua.LState) int {
return 1
}
// luaSetURI 实现 ngx.req.set_uri(uri, jump?) - 直接映射层
// Lua 调用: ngx.req.set_uri("/new/path") 或 ngx.req.set_uri("/new/path", true)
// 参数:
// - uri: 新的 URI 路径
// - jump: 是否触发内部跳转(可选,默认为 false
func (api *ngxReqAPI) luaSetURI(L *glua.LState) int {
start := time.Now()
// 获取 uri 参数
uri := L.CheckString(1)
// 获取可选的 jump 参数
jump := false
if L.GetTop() >= 2 {
jump = L.ToBool(2)
}
// 设置新的 URI
api.ctx.Request.URI().SetPath(uri)
// 如果 jump 为 true记录内部跳转标记供后续处理使用
if jump {
// 在请求上下文中存储跳转标记
api.ctx.SetUserValue("_ngx_req_internal_jump", true)
}
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.DirectCallCount++
api.metrics.DirectTotalNs += elapsed
if elapsed > api.metrics.DirectMaxNs {
api.metrics.DirectMaxNs = elapsed
}
return 0
}
// luaGetURI 实现 ngx.req.get_uri() - 直接映射层
// Lua 调用: local uri = ngx.req.get_uri()
// 返回: string (如 "/path/to/resource")
@ -225,7 +279,241 @@ func (api *ngxReqAPI) parseURIArgs() map[string][]string {
return args
}
// ==================== 伪非阻塞层 API实验性 ====================
// luaSetURIArgs 实现 ngx.req.set_uri_args(args) - 兼容层
// Lua 调用: ngx.req.set_uri_args({ key = "value" }) 或 ngx.req.set_uri_args("key=value&foo=bar")
// 参数:
// - args: table 或 string 类型的查询参数
func (api *ngxReqAPI) luaSetURIArgs(L *glua.LState) int {
start := time.Now()
// 获取参数类型
argType := L.Get(1)
switch argType.Type() {
case glua.LTString:
// 如果是字符串,直接解析并设置
queryStr := string(argType.(glua.LString))
api.ctx.Request.URI().SetQueryString(queryStr)
case glua.LTTable:
// 如果是 table构建查询字符串
table := argType.(*glua.LTable)
args := make(map[string][]string)
table.ForEach(func(key, value glua.LValue) {
keyStr := glua.LVAsString(key)
switch value.Type() {
case glua.LTString:
args[keyStr] = []string{string(value.(glua.LString))}
case glua.LTNumber:
args[keyStr] = []string{glua.LVAsString(value)}
case glua.LTTable:
// 数组形式的多值
arr := value.(*glua.LTable)
values := []string{}
arr.ForEach(func(_, v glua.LValue) {
values = append(values, glua.LVAsString(v))
})
args[keyStr] = values
}
})
// 构建查询字符串
if len(args) > 0 {
query := fasthttp.Args{}
for key, values := range args {
for _, v := range values {
query.Add(key, v)
}
}
api.ctx.Request.URI().SetQueryString(query.String())
}
default:
L.RaiseError("set_uri_args expects table or string, got %s", argType.Type().String())
return 0
}
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.CompatibleCallCount++
api.metrics.CompatibleTotalNs += elapsed
if elapsed > api.metrics.CompatibleMaxNs {
api.metrics.CompatibleMaxNs = elapsed
}
return 0
}
// ==================== 请求头 API ====================
// luaGetHeaders 实现 ngx.req.get_headers(max_headers?) - 兼容层
// Lua 调用: local headers = ngx.req.get_headers() 或 ngx.req.get_headers(50)
// 返回: table (如 { ["host"] = "example.com", ["cookie"] = { "a=1", "b=2" } })
// 注意:兼容层需要遍历所有请求头,模拟 nginx 的头表结构
func (api *ngxReqAPI) luaGetHeaders(L *glua.LState) int {
start := time.Now()
// 获取可选的 max_headers 参数
maxHeaders := 100 // 默认最大头数
if L.GetTop() >= 1 {
maxHeaders = L.ToInt(1)
if maxHeaders <= 0 {
maxHeaders = 100
}
}
// 构建 Lua 表
result := L.NewTable()
headers := &api.ctx.Request.Header
count := 0
// 使用 VisitAll 遍历所有请求头
headers.VisitAll(func(key, value []byte) {
if count >= maxHeaders {
return
}
keyStr := string(key)
valueStr := string(value)
// 检查是否已存在同名头(多值头)
existing := result.RawGetString(keyStr)
if existing == glua.LNil {
// 第一次遇到这个头
result.RawSetString(keyStr, glua.LString(valueStr))
} else if existingStr, ok := existing.(glua.LString); ok {
// 第二次遇到,需要转换为数组
arr := L.NewTable()
arr.Append(existingStr)
arr.Append(glua.LString(valueStr))
result.RawSetString(keyStr, arr)
} else if existingArr, ok := existing.(*glua.LTable); ok {
// 已经是数组,追加
existingArr.Append(glua.LString(valueStr))
}
count++
})
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.CompatibleCallCount++
api.metrics.CompatibleTotalNs += elapsed
if elapsed > api.metrics.CompatibleMaxNs {
api.metrics.CompatibleMaxNs = elapsed
}
L.Push(result)
return 1
}
// luaSetHeader 实现 ngx.req.set_header(key, value) - 直接映射层
// Lua 调用: ngx.req.set_header("X-Custom", "value") 或 ngx.req.set_header("X-Custom", nil) 清除
// 参数:
// - key: 头名称
// - value: 头值,如果为 nil 则清除该头
func (api *ngxReqAPI) luaSetHeader(L *glua.LState) int {
start := time.Now()
// 获取参数
key := L.CheckString(1)
value := L.Get(2)
if value == glua.LNil {
// 值为 nil删除头
api.ctx.Request.Header.Del(key)
} else {
// 设置头值
valueStr := glua.LVAsString(value)
api.ctx.Request.Header.Set(key, valueStr)
}
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.DirectCallCount++
api.metrics.DirectTotalNs += elapsed
if elapsed > api.metrics.DirectMaxNs {
api.metrics.DirectMaxNs = elapsed
}
return 0
}
// luaClearHeader 实现 ngx.req.clear_header(key) - 直接映射层
// Lua 调用: ngx.req.clear_header("X-Custom")
// 参数:
// - key: 要清除的头名称
func (api *ngxReqAPI) luaClearHeader(L *glua.LState) int {
start := time.Now()
// 获取参数
key := L.CheckString(1)
// 删除头
api.ctx.Request.Header.Del(key)
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.DirectCallCount++
api.metrics.DirectTotalNs += elapsed
if elapsed > api.metrics.DirectMaxNs {
api.metrics.DirectMaxNs = elapsed
}
return 0
}
// ==================== 请求体 API ====================
// luaGetBodyData 实现 ngx.req.get_body_data() - 兼容层
// Lua 调用: local body = ngx.req.get_body_data()
// 返回: string 或 nil如果没有请求体
func (api *ngxReqAPI) luaGetBodyData(L *glua.LState) int {
start := time.Now()
// 获取请求体
body := api.ctx.Request.Body()
if len(body) == 0 {
L.Push(glua.LNil)
} else {
L.Push(glua.LString(body))
}
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.CompatibleCallCount++
api.metrics.CompatibleTotalNs += elapsed
if elapsed > api.metrics.CompatibleMaxNs {
api.metrics.CompatibleMaxNs = elapsed
}
return 1
}
// luaReadBody 实现 ngx.req.read_body() - 伪非阻塞层
// Lua 调用: ngx.req.read_body() -- 完成后返回
// 注意fasthttp 已经预读取了请求体,这里主要是确保请求体已被读取
func (api *ngxReqAPI) luaReadBody(L *glua.LState) int {
start := time.Now()
// fasthttp 默认会预读取请求体到内存中
// 这里我们只需要确保请求体已被读取(对于 POST/PUT 等请求)
// 如果请求体未读取,触发读取
if api.ctx.Request.Header.ContentLength() > 0 {
// 访问 Body() 会确保请求体已被读取
_ = api.ctx.Request.Body()
}
// 记录指标(使用伪非阻塞层指标)
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.PseudoBlockingCallCount++
api.metrics.PseudoBlockingTotalNs += elapsed
if elapsed > api.metrics.PseudoBlockingMaxNs {
api.metrics.PseudoBlockingMaxNs = elapsed
}
return 0
}
// luaReadBodyAsync 实现 ngx.req.read_body() - 伪非阻塞层
// Lua 调用: ngx.req.read_body() -- 会 yield完成后 resume

View File

@ -0,0 +1,555 @@
// Package lua 提供 ngx.req API 测试
package lua
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua"
)
// 创建测试用的 fasthttp.RequestCtx
func createTestRequestCtx(method, uri string, headers map[string]string, body []byte) *fasthttp.RequestCtx {
ctx := &fasthttp.RequestCtx{}
// 设置请求
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI(uri)
// 设置请求头
for key, value := range headers {
ctx.Request.Header.Set(key, value)
}
// 设置请求体
if len(body) > 0 {
ctx.Request.SetBody(body)
}
return ctx
}
// TestNgxReqGetMethod 测试 ngx.req.get_method()
func TestNgxReqGetMethod(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
// 创建测试请求
reqCtx := createTestRequestCtx("POST", "/test", nil, nil)
// 创建协程
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
// 设置沙箱(这会自动注册 ngx API
err = coro.SetupSandbox()
require.NoError(t, err)
// 测试获取请求方法
err = coro.Execute(`
local method = ngx.req.get_method()
if method ~= "POST" then
error("expected POST, got " .. tostring(method))
end
`)
assert.NoError(t, err)
}
// TestNgxReqGetURI 测试 ngx.req.get_uri()
func TestNgxReqGetURI(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("GET", "/path/to/resource?key=value", nil, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
err = coro.Execute(`
local uri = ngx.req.get_uri()
if uri ~= "/path/to/resource" then
error("expected /path/to/resource, got " .. tostring(uri))
end
`)
assert.NoError(t, err)
}
// TestNgxReqSetURI 测试 ngx.req.set_uri()
func TestNgxReqSetURI(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
// 测试设置 URI不带 jump
reqCtx := createTestRequestCtx("GET", "/original", nil, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
err = coro.SetupSandbox()
require.NoError(t, err)
err = coro.Execute(`
ngx.req.set_uri("/new/path")
`)
assert.NoError(t, err)
coro.Close()
// 验证 URI 已修改
assert.Equal(t, "/new/path", string(reqCtx.URI().Path()))
// 测试设置 URI带 jump
reqCtx2 := createTestRequestCtx("GET", "/original", nil, nil)
coro2, err := engine.NewCoroutine(reqCtx2)
require.NoError(t, err)
err = coro2.SetupSandbox()
require.NoError(t, err)
err = coro2.Execute(`
ngx.req.set_uri("/redirect/path", true)
`)
assert.NoError(t, err)
coro2.Close()
assert.Equal(t, "/redirect/path", string(reqCtx2.URI().Path()))
// 验证 jump 标记已设置
jumpFlag := reqCtx2.UserValue("_ngx_req_internal_jump")
assert.Equal(t, true, jumpFlag)
}
// TestNgxReqGetURIArgs 测试 ngx.req.get_uri_args()
func TestNgxReqGetURIArgs(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("GET", "/test?foo=bar&baz=qux&arr=1&arr=2", nil, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
err = coro.Execute(`
local args = ngx.req.get_uri_args()
if args.foo ~= "bar" then
error("expected foo=bar, got " .. tostring(args.foo))
end
if args.baz ~= "qux" then
error("expected baz=qux, got " .. tostring(args.baz))
end
-- 多值参数应该返回数组
if type(args.arr) ~= "table" then
error("expected arr to be table, got " .. type(args.arr))
end
if args.arr[1] ~= "1" or args.arr[2] ~= "2" then
error("expected arr = {1, 2}")
end
`)
assert.NoError(t, err)
}
// TestNgxReqSetURIArgs 测试 ngx.req.set_uri_args()
func TestNgxReqSetURIArgs(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("GET", "/test", nil, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
// 测试使用 table 设置参数
err = coro.Execute(`
ngx.req.set_uri_args({ key = "value", num = 123 })
`)
assert.NoError(t, err)
queryStr := string(reqCtx.URI().QueryString())
assert.Contains(t, queryStr, "key=value")
assert.Contains(t, queryStr, "num=123")
// 测试使用字符串设置参数
reqCtx2 := createTestRequestCtx("GET", "/test", nil, nil)
coro2, err := engine.NewCoroutine(reqCtx2)
require.NoError(t, err)
defer coro2.Close()
err = coro2.SetupSandbox()
require.NoError(t, err)
err = coro2.Execute(`
ngx.req.set_uri_args("foo=bar&baz=qux")
`)
assert.NoError(t, err)
assert.Equal(t, "foo=bar&baz=qux", string(reqCtx2.URI().QueryString()))
}
// TestNgxReqGetHeaders 测试 ngx.req.get_headers()
func TestNgxReqGetHeaders(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("GET", "/test", map[string]string{
"Host": "example.com",
"X-Custom": "custom-value",
"Content-Type": "application/json",
}, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
err = coro.Execute(`
local headers = ngx.req.get_headers()
if headers.Host ~= "example.com" then
error("expected Host=example.com, got " .. tostring(headers.Host))
end
if headers["X-Custom"] ~= "custom-value" then
error("expected X-Custom=custom-value")
end
if headers["Content-Type"] ~= "application/json" then
error("expected Content-Type=application/json")
end
`)
assert.NoError(t, err)
}
// TestNgxReqSetHeader 测试 ngx.req.set_header()
func TestNgxReqSetHeader(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
// 测试设置请求头
reqCtx := createTestRequestCtx("GET", "/test", nil, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
err = coro.SetupSandbox()
require.NoError(t, err)
err = coro.Execute(`
ngx.req.set_header("X-Custom-Header", "custom-value")
`)
assert.NoError(t, err)
coro.Close()
assert.Equal(t, "custom-value", string(reqCtx.Request.Header.Peek("X-Custom-Header")))
// 测试使用 nil 清除请求头
reqCtx2 := createTestRequestCtx("GET", "/test", map[string]string{
"X-Custom-Header": "custom-value",
}, nil)
coro2, err := engine.NewCoroutine(reqCtx2)
require.NoError(t, err)
err = coro2.SetupSandbox()
require.NoError(t, err)
err = coro2.Execute(`
ngx.req.set_header("X-Custom-Header", nil)
`)
assert.NoError(t, err)
coro2.Close()
assert.Equal(t, "", string(reqCtx2.Request.Header.Peek("X-Custom-Header")))
}
// TestNgxReqClearHeader 测试 ngx.req.clear_header()
func TestNgxReqClearHeader(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("GET", "/test", map[string]string{
"X-To-Clear": "value",
}, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
// 先验证头存在
assert.Equal(t, "value", string(reqCtx.Request.Header.Peek("X-To-Clear")))
// 清除头
err = coro.Execute(`
ngx.req.clear_header("X-To-Clear")
`)
assert.NoError(t, err)
assert.Equal(t, "", string(reqCtx.Request.Header.Peek("X-To-Clear")))
}
// TestNgxReqGetBodyData 测试 ngx.req.get_body_data()
func TestNgxReqGetBodyData(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("POST", "/test", nil, []byte("test body data"))
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
err = coro.Execute(`
local body = ngx.req.get_body_data()
if body ~= "test body data" then
error("expected 'test body data', got " .. tostring(body))
end
`)
assert.NoError(t, err)
}
// TestNgxReqGetBodyDataEmpty 测试 ngx.req.get_body_data() 空请求体
func TestNgxReqGetBodyDataEmpty(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("GET", "/test", nil, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
err = coro.Execute(`
local body = ngx.req.get_body_data()
if body ~= nil then
error("expected nil for empty body, got " .. tostring(body))
end
`)
assert.NoError(t, err)
}
// TestNgxReqReadBody 测试 ngx.req.read_body()
func TestNgxReqReadBody(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("POST", "/test", map[string]string{
"Content-Length": "14",
}, []byte("test body data"))
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
// read_body 应该成功执行
err = coro.Execute(`
ngx.req.read_body()
`)
assert.NoError(t, err)
// 验证请求体仍可访问
body := reqCtx.Request.Body()
assert.Equal(t, "test body data", string(body))
}
// TestNgxReqAPIIntegration 测试 ngx.req API 集成场景
func TestNgxReqAPIIntegration(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("POST", "/api/users?limit=10&offset=20", map[string]string{
"Content-Type": "application/json",
"X-API-Key": "secret123",
}, []byte(`{"name":"test"}`))
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
// 复杂场景:获取各种请求信息并修改
err = coro.Execute(`
-- 获取请求信息
local method = ngx.req.get_method()
local uri = ngx.req.get_uri()
local args = ngx.req.get_uri_args()
local headers = ngx.req.get_headers()
-- 验证获取的信息
if method ~= "POST" then
error("method should be POST")
end
if uri ~= "/api/users" then
error("uri should be /api/users, got " .. tostring(uri))
end
if args.limit ~= "10" or args.offset ~= "20" then
error("args incorrect")
end
-- 注意fasthttp 会标准化 header 名称所以需要使用实际的 key
if headers["Content-Type"] ~= "application/json" and headers["content-type"] ~= "application/json" then
error("Content-Type header incorrect: " .. tostring(headers["Content-Type"]))
end
-- 修改请求
ngx.req.set_header("X-Request-ID", "req-12345")
ngx.req.set_uri("/api/v2/users")
`)
assert.NoError(t, err)
// 在 Go 层验证修改
assert.Equal(t, "/api/v2/users", string(reqCtx.URI().Path()))
assert.Equal(t, "req-12345", string(reqCtx.Request.Header.Peek("X-Request-ID")))
}
// TestNgxReqMetrics 测试 ngx.req API 性能指标
func TestNgxReqMetrics(t *testing.T) {
reqCtx := createTestRequestCtx("GET", "/test?a=1&b=2", nil, nil)
api := newNgxReqAPI(reqCtx)
L := glua.NewState()
defer L.Close()
// 创建 ngx 表
ngx := L.NewTable()
// 注册 API
RegisterNgxReqAPI(L, api, ngx)
// 将 ngx 设置到全局
L.SetGlobal("ngx", ngx)
// 调用各种 API
L.DoString(`
ngx.req.get_method()
ngx.req.get_uri()
ngx.req.get_uri_args()
`)
// 验证指标
metrics := api.GetMetrics()
assert.Greater(t, metrics.DirectCallCount, uint64(0), "应该有直接层调用")
assert.Greater(t, metrics.CompatibleCallCount, uint64(0), "应该有兼容层调用")
// 验证平均延迟
directAvg := api.GetDirectLayerAvgNs()
compatibleAvg := api.GetCompatibleLayerAvgNs()
assert.GreaterOrEqual(t, directAvg, float64(0))
assert.GreaterOrEqual(t, compatibleAvg, float64(0))
// 验证性能比率
ratio := api.GetPerformanceRatio()
assert.GreaterOrEqual(t, ratio, float64(0))
// 重置指标
api.ResetMetrics()
metrics = api.GetMetrics()
assert.Equal(t, uint64(0), metrics.DirectCallCount)
assert.Equal(t, uint64(0), metrics.CompatibleCallCount)
}
// TestNgxReqGetHeadersWithMaxHeaders 测试 ngx.req.get_headers(max_headers)
func TestNgxReqGetHeadersWithMaxHeaders(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
// 创建带有多个头的请求
reqCtx := &fasthttp.RequestCtx{}
reqCtx.Request.Header.SetMethod("GET")
reqCtx.Request.SetRequestURI("/test")
reqCtx.Request.Header.Set("Header1", "value1")
reqCtx.Request.Header.Set("Header2", "value2")
reqCtx.Request.Header.Set("Header3", "value3")
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
// 测试限制头数
err = coro.Execute(`
local headers = ngx.req.get_headers(2)
local count = 0
for k, v in pairs(headers) do
count = count + 1
end
-- 应该最多返回 2 个头
if count > 2 then
error("expected at most 2 headers, got " .. count)
end
`)
assert.NoError(t, err)
}
// TestNgxReqSetURIArgsWithArray 测试 ngx.req.set_uri_args() 使用数组值
func TestNgxReqSetURIArgsWithArray(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
reqCtx := createTestRequestCtx("GET", "/test", nil, nil)
coro, err := engine.NewCoroutine(reqCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
// 测试使用包含数组的 table
err = coro.Execute(`
ngx.req.set_uri_args({ tags = { "a", "b", "c" }, page = 1 })
`)
assert.NoError(t, err)
queryStr := string(reqCtx.URI().QueryString())
assert.Contains(t, queryStr, "tags=a")
assert.Contains(t, queryStr, "tags=b")
assert.Contains(t, queryStr, "tags=c")
assert.Contains(t, queryStr, "page=1")
}

View File

@ -117,6 +117,9 @@ func (c *LuaCoroutine) SetupSandbox() error {
// Layer 1 & 2: 设置安全的协程库(移除危险函数)
c.setupSecureCoroutineLib()
// Layer 3: 设置 ngx API
c.setupNgxAPI()
return nil
}
@ -163,6 +166,48 @@ func (c *LuaCoroutine) setupSecureCoroutineLib() {
// 因为协程继承的是引擎全局环境,而我们在协程级别设置了独立的 coroutine 表
}
// setupNgxAPI 创建 ngx API
// 注册 ngx.req、ngx.resp、ngx.var、ngx.ctx、ngx.log 和 ngx.socket API
func (c *LuaCoroutine) setupNgxAPI() {
// 检查是否已有 ngx 表(可能已由其他 API 注册)
existingNgx := c.Co.GetGlobal("ngx")
var ngx *glua.LTable
if existingTbl, ok := existingNgx.(*glua.LTable); ok {
ngx = existingTbl
} else {
// 创建 ngx 表
ngx = c.Co.NewTable()
}
// 注册 ngx.req API
if c.RequestCtx != nil {
reqAPI := newNgxReqAPI(c.RequestCtx)
RegisterNgxReqAPI(c.Co, reqAPI, ngx)
// 注册 ngx.resp API
respAPI := newNgxRespAPI(c.RequestCtx)
RegisterNgxRespAPI(c.Co, respAPI)
// 注册 ngx.log API (logger 为 nil 时禁用日志输出)
// ngx.say/print/flush 直接写入 RequestCtx
logAPI := newNgxLogAPI(c.RequestCtx, nil, nil)
RegisterNgxLogAPI(c.Co, logAPI)
}
// 注册 ngx.var API
varAPI := newNgxVarAPI(c.RequestCtx)
RegisterNgxVarAPI(c.Co, varAPI, ngx)
// 注册 ngx.ctx API
RegisterNgxCtxAPI(c.Co, ngx)
// 注册 ngx.socket API
RegisterTCPSocketAPI(c.Co, c.Engine)
// 将 ngx 表设置到协程环境
c.Co.SetGlobal("ngx", ngx)
}
// Execute 在协程中执行 Lua 脚本(支持 Yield/Resume
func (c *LuaCoroutine) Execute(script string) error {
proto, err := c.Engine.codeCache.GetOrCompileInline(script)

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// TestLuaContext 测试 LuaContext 基础功能
@ -456,3 +457,118 @@ func TestConfig(t *testing.T) {
assert.Equal(t, 100, engine.maxCoroutines)
}
// TestNgxAPIRegistrationInSandbox 测试所有 ngx API 在沙箱中的注册
func TestNgxAPIRegistrationInSandbox(t *testing.T) {
engine, err := NewEngine(DefaultConfig())
require.NoError(t, err)
defer engine.Close()
// 创建 mock RequestCtxngx.req/resp/log API 需要 RequestCtx
mockCtx := &fasthttp.RequestCtx{}
coro, err := engine.NewCoroutine(mockCtx)
require.NoError(t, err)
defer coro.Close()
err = coro.SetupSandbox()
require.NoError(t, err)
// 验证 ngx 表存在
err = coro.Execute(`
assert(ngx ~= nil, "ngx table should exist")
assert(type(ngx) == "table", "ngx should be a table")
`)
assert.NoError(t, err)
// 验证 ngx.req API 存在
coro2, err := engine.NewCoroutine(mockCtx)
require.NoError(t, err)
defer coro2.Close()
err = coro2.SetupSandbox()
require.NoError(t, err)
err = coro2.Execute(`
assert(ngx.req ~= nil, "ngx.req should exist")
assert(type(ngx.req.get_method) == "function", "ngx.req.get_method should be a function")
assert(type(ngx.req.get_uri) == "function", "ngx.req.get_uri should be a function")
assert(type(ngx.req.set_uri) == "function", "ngx.req.set_uri should be a function")
assert(type(ngx.req.get_uri_args) == "function", "ngx.req.get_uri_args should be a function")
assert(type(ngx.req.get_headers) == "function", "ngx.req.get_headers should be a function")
assert(type(ngx.req.set_header) == "function", "ngx.req.set_header should be a function")
assert(type(ngx.req.clear_header) == "function", "ngx.req.clear_header should be a function")
assert(type(ngx.req.get_body_data) == "function", "ngx.req.get_body_data should be a function")
`)
assert.NoError(t, err)
// 验证 ngx.resp API 存在
coro3, err := engine.NewCoroutine(mockCtx)
require.NoError(t, err)
defer coro3.Close()
err = coro3.SetupSandbox()
require.NoError(t, err)
err = coro3.Execute(`
assert(ngx.resp ~= nil, "ngx.resp should exist")
assert(type(ngx.resp.get_status) == "function", "ngx.resp.get_status should be a function")
assert(type(ngx.resp.set_status) == "function", "ngx.resp.set_status should be a function")
assert(type(ngx.resp.get_headers) == "function", "ngx.resp.get_headers should be a function")
assert(type(ngx.resp.set_header) == "function", "ngx.resp.set_header should be a function")
assert(type(ngx.resp.clear_header) == "function", "ngx.resp.clear_header should be a function")
`)
assert.NoError(t, err)
// 验证 ngx.var API 存在
coro4, err := engine.NewCoroutine(mockCtx)
require.NoError(t, err)
defer coro4.Close()
err = coro4.SetupSandbox()
require.NoError(t, err)
err = coro4.Execute(`
assert(ngx.var ~= nil, "ngx.var should exist")
`)
assert.NoError(t, err)
// 验证 ngx.ctx API 存在
coro5, err := engine.NewCoroutine(mockCtx)
require.NoError(t, err)
defer coro5.Close()
err = coro5.SetupSandbox()
require.NoError(t, err)
err = coro5.Execute(`
assert(ngx.ctx ~= nil, "ngx.ctx should exist")
assert(type(ngx.ctx) == "table", "ngx.ctx should be a table")
`)
assert.NoError(t, err)
// 验证 ngx.log API 存在(日志级别常量和函数)
coro6, err := engine.NewCoroutine(mockCtx)
require.NoError(t, err)
defer coro6.Close()
err = coro6.SetupSandbox()
require.NoError(t, err)
err = coro6.Execute(`
assert(ngx.log ~= nil, "ngx.log should exist")
assert(type(ngx.log) == "function", "ngx.log should be a function")
assert(ngx.ERR ~= nil, "ngx.ERR should exist")
assert(ngx.WARN ~= nil, "ngx.WARN should exist")
assert(ngx.INFO ~= nil, "ngx.INFO should exist")
assert(ngx.DEBUG ~= nil, "ngx.DEBUG should exist")
assert(type(ngx.say) == "function", "ngx.say should be a function")
assert(type(ngx.print) == "function", "ngx.print should be a function")
assert(type(ngx.flush) == "function", "ngx.flush should be a function")
assert(type(ngx.exit) == "function", "ngx.exit should be a function")
assert(type(ngx.redirect) == "function", "ngx.redirect should be a function")
`)
assert.NoError(t, err)
// 验证 ngx.socket API 存在
coro7, err := engine.NewCoroutine(mockCtx)
require.NoError(t, err)
defer coro7.Close()
err = coro7.SetupSandbox()
require.NoError(t, err)
err = coro7.Execute(`
assert(ngx.socket ~= nil, "ngx.socket should exist")
assert(type(ngx.socket.tcp) == "function", "ngx.socket.tcp should be a function")
`)
assert.NoError(t, err)
}