feat(lua): 实现 ngx.var API
通过元表实现动态变量读写,支持 nginx 内置变量和自定义变量 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d021b0e9fd
commit
51061d68ff
197
internal/lua/api_var.go
Normal file
197
internal/lua/api_var.go
Normal file
@ -0,0 +1,197 @@
|
||||
// Package lua 提供 ngx.var API 实现
|
||||
package lua
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// ngxVarAPI ngx.var API 实现
|
||||
type ngxVarAPI struct {
|
||||
// 请求上下文
|
||||
ctx *fasthttp.RequestCtx
|
||||
|
||||
// 变量存储(用于自定义变量)
|
||||
store map[string]string
|
||||
}
|
||||
|
||||
// newNgxVarAPI 创建 ngx.var API 实例
|
||||
func newNgxVarAPI(ctx *fasthttp.RequestCtx) *ngxVarAPI {
|
||||
return &ngxVarAPI{
|
||||
ctx: ctx,
|
||||
store: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterNgxVarAPI 在 Lua 状态机中注册 ngx.var API
|
||||
// 使用元表实现动态读写:ngx.var.key 和 ngx.var[key]
|
||||
func RegisterNgxVarAPI(L *glua.LState, api *ngxVarAPI, ngxTable *glua.LTable) {
|
||||
// 创建 ngx.var 表(使用元表实现动态访问)
|
||||
ngxVar := L.NewTable()
|
||||
|
||||
// 创建元表
|
||||
mt := L.NewTable()
|
||||
|
||||
// __index 元方法:读取变量
|
||||
mt.RawSetString("__index", L.NewFunction(api.luaVarIndex))
|
||||
|
||||
// __newindex 元方法:设置变量
|
||||
mt.RawSetString("__newindex", L.NewFunction(api.luaVarNewIndex))
|
||||
|
||||
// 设置元表
|
||||
L.SetMetatable(ngxVar, mt)
|
||||
|
||||
// 将 ngx.var 添加到 ngx 表
|
||||
ngxTable.RawSetString("var", ngxVar)
|
||||
}
|
||||
|
||||
// luaVarIndex 实现 ngx.var[key] 读取
|
||||
// Lua 调用: local value = ngx.var.key 或 ngx.var[key]
|
||||
func (api *ngxVarAPI) luaVarIndex(L *glua.LState) int {
|
||||
// 第一个参数是表本身(ngx.var)
|
||||
// 第二个参数是键名
|
||||
key := L.CheckString(2)
|
||||
|
||||
// 1. 先查自定义变量存储
|
||||
if value, ok := api.store[key]; ok {
|
||||
L.Push(glua.LString(value))
|
||||
return 1
|
||||
}
|
||||
|
||||
// 2. 从 fasthttp RequestCtx 获取变量
|
||||
value := api.getVariable(key)
|
||||
if value != "" {
|
||||
L.Push(glua.LString(value))
|
||||
return 1
|
||||
}
|
||||
|
||||
// 3. 未找到变量,返回 nil
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
|
||||
// luaVarNewIndex 实现 ngx.var[key] = value 写入
|
||||
// Lua 调用: ngx.var.key = value 或 ngx.var[key] = value
|
||||
func (api *ngxVarAPI) luaVarNewIndex(L *glua.LState) int {
|
||||
// 第一个参数是表本身(ngx.var)
|
||||
// 第二个参数是键名
|
||||
// 第三个参数是值
|
||||
key := L.CheckString(2)
|
||||
value := L.CheckString(3)
|
||||
|
||||
// 存储到自定义变量存储
|
||||
api.store[key] = value
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// getVariable 从 fasthttp RequestCtx 获取变量值
|
||||
// 支持常见的 nginx 变量
|
||||
func (api *ngxVarAPI) getVariable(name string) string {
|
||||
if api.ctx == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch name {
|
||||
// HTTP 请求相关
|
||||
case "request_method":
|
||||
return string(api.ctx.Method())
|
||||
case "request_uri":
|
||||
return string(api.ctx.RequestURI())
|
||||
case "uri":
|
||||
return string(api.ctx.URI().Path())
|
||||
case "document_uri":
|
||||
return string(api.ctx.URI().Path())
|
||||
case "query_string", "args":
|
||||
return string(api.ctx.URI().QueryString())
|
||||
case "server_protocol", "protocol":
|
||||
return string(api.ctx.Request.Header.Protocol())
|
||||
case "scheme":
|
||||
return string(api.ctx.URI().Scheme())
|
||||
case "request_length":
|
||||
return strconv.Itoa(api.ctx.Request.Header.ContentLength())
|
||||
case "request_time":
|
||||
// 简化实现,返回空字符串
|
||||
return ""
|
||||
|
||||
// 请求头相关
|
||||
case "http_host":
|
||||
return string(api.ctx.Host())
|
||||
case "http_user_agent", "http_user-agent":
|
||||
return string(api.ctx.UserAgent())
|
||||
case "http_referer":
|
||||
return string(api.ctx.Referer())
|
||||
case "http_accept":
|
||||
return string(api.ctx.Request.Header.Peek("Accept"))
|
||||
case "http_accept_encoding", "http_accept-encoding":
|
||||
return string(api.ctx.Request.Header.Peek("Accept-Encoding"))
|
||||
case "http_accept_language", "http_accept-language":
|
||||
return string(api.ctx.Request.Header.Peek("Accept-Language"))
|
||||
case "http_connection":
|
||||
return string(api.ctx.Request.Header.Peek("Connection"))
|
||||
case "http_content_type", "http_content-type":
|
||||
return string(api.ctx.Request.Header.ContentType())
|
||||
case "http_content_length", "http_content-length":
|
||||
return string(api.ctx.Request.Header.Peek("Content-Length"))
|
||||
|
||||
// 客户端信息
|
||||
case "remote_addr":
|
||||
return api.ctx.RemoteAddr().String()
|
||||
case "remote_port":
|
||||
addr := api.ctx.RemoteAddr()
|
||||
if addr != nil {
|
||||
// 简化处理,实际可能需要解析端口
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
case "binary_remote_addr":
|
||||
return ""
|
||||
|
||||
// 服务器信息
|
||||
case "server_addr":
|
||||
addr := api.ctx.LocalAddr()
|
||||
if addr != nil {
|
||||
return addr.String()
|
||||
}
|
||||
return ""
|
||||
case "server_port":
|
||||
return ""
|
||||
case "server_name":
|
||||
return string(api.ctx.Host())
|
||||
|
||||
// URI 参数
|
||||
case "arg_":
|
||||
// 获取所有参数
|
||||
return string(api.ctx.URI().QueryString())
|
||||
default:
|
||||
// 检查是否是 arg_ 开头的参数
|
||||
if len(name) > 4 && name[:4] == "arg_" {
|
||||
paramName := name[4:]
|
||||
return string(api.ctx.QueryArgs().Peek(paramName))
|
||||
}
|
||||
// 检查是否是 http_ 开头的请求头
|
||||
if len(name) > 5 && name[:5] == "http_" {
|
||||
headerName := name[5:]
|
||||
return string(api.ctx.Request.Header.Peek(headerName))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// SetVariable 设置自定义变量(Go 层调用)
|
||||
func (api *ngxVarAPI) SetVariable(name, value string) {
|
||||
api.store[name] = value
|
||||
}
|
||||
|
||||
// GetVariable 获取变量值(Go 层调用)
|
||||
func (api *ngxVarAPI) GetVariable(name string) (string, bool) {
|
||||
// 先查自定义变量
|
||||
if value, ok := api.store[name]; ok {
|
||||
return value, true
|
||||
}
|
||||
// 再查 fasthttp 变量
|
||||
value := api.getVariable(name)
|
||||
return value, value != ""
|
||||
}
|
||||
284
internal/lua/api_var_test.go
Normal file
284
internal/lua/api_var_test.go
Normal file
@ -0,0 +1,284 @@
|
||||
// Package lua 提供 ngx.var API 测试
|
||||
package lua
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestNgxVarAPI 测试 ngx.var API 基础功能
|
||||
func TestNgxVarAPI(t *testing.T) {
|
||||
// 创建 fasthttp 请求上下文
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test/path?foo=bar&baz=qux")
|
||||
req.Header.Set("Host", "example.com")
|
||||
req.Header.Set("User-Agent", "TestAgent")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 创建协程
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
// 设置沙箱(这会注册 ngx API)
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试 ngx.var 存在
|
||||
err = coro.Execute(`
|
||||
if type(ngx) ~= "table" then
|
||||
error("ngx is not a table")
|
||||
end
|
||||
if type(ngx.var) ~= "table" then
|
||||
error("ngx.var is not a table")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarReadBuiltin 测试读取内置变量
|
||||
func TestNgxVarReadBuiltin(t *testing.T) {
|
||||
// 创建 fasthttp 请求上下文
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("POST")
|
||||
req.Header.SetRequestURI("/api/test?name=value")
|
||||
req.Header.Set("Host", "test.example.com")
|
||||
req.Header.Set("User-Agent", "Mozilla/5.0")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试读取内置变量
|
||||
err = coro.Execute(`
|
||||
local method = ngx.var.request_method
|
||||
if method ~= "POST" then
|
||||
error("request_method should be POST, got: " .. tostring(method))
|
||||
end
|
||||
|
||||
local uri = ngx.var.uri
|
||||
if uri ~= "/api/test" then
|
||||
error("uri should be /api/test, got: " .. tostring(uri))
|
||||
end
|
||||
|
||||
local host = ngx.var.http_host
|
||||
if host ~= "test.example.com" then
|
||||
error("http_host should be test.example.com, got: " .. tostring(host))
|
||||
end
|
||||
|
||||
local userAgent = ngx.var.http_user_agent
|
||||
if userAgent ~= "Mozilla/5.0" then
|
||||
error("http_user_agent should be Mozilla/5.0, got: " .. tostring(userAgent))
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarReadQueryArgs 测试读取查询参数
|
||||
func TestNgxVarReadQueryArgs(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/search?q=lua&page=1")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试读取查询参数
|
||||
err = coro.Execute(`
|
||||
local q = ngx.var.arg_q
|
||||
if q ~= "lua" then
|
||||
error("arg_q should be 'lua', got: " .. tostring(q))
|
||||
end
|
||||
|
||||
local page = ngx.var.arg_page
|
||||
if page ~= "1" then
|
||||
error("arg_page should be '1', got: " .. tostring(page))
|
||||
end
|
||||
|
||||
local queryString = ngx.var.query_string
|
||||
if type(queryString) ~= "string" then
|
||||
error("query_string should be a string")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarWriteCustom 测试设置自定义变量
|
||||
func TestNgxVarWriteCustom(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试设置和读取自定义变量
|
||||
err = coro.Execute(`
|
||||
-- 设置自定义变量
|
||||
ngx.var.my_custom_var = "hello world"
|
||||
ngx.var.another_var = "12345"
|
||||
|
||||
-- 读取自定义变量
|
||||
local val1 = ngx.var.my_custom_var
|
||||
if val1 ~= "hello world" then
|
||||
error("my_custom_var should be 'hello world', got: " .. tostring(val1))
|
||||
end
|
||||
|
||||
local val2 = ngx.var.another_var
|
||||
if val2 ~= "12345" then
|
||||
error("another_var should be '12345', got: " .. tostring(val2))
|
||||
end
|
||||
|
||||
-- 覆盖已存在的变量
|
||||
ngx.var.my_custom_var = "updated"
|
||||
local val3 = ngx.var.my_custom_var
|
||||
if val3 ~= "updated" then
|
||||
error("my_custom_var should be 'updated', got: " .. tostring(val3))
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarIndexAccess 测试索引访问方式 ngx.var[key]
|
||||
func TestNgxVarIndexAccess(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试索引访问
|
||||
err = coro.Execute(`
|
||||
-- 使用索引方式设置变量
|
||||
ngx.var["dynamic_key"] = "dynamic_value"
|
||||
|
||||
-- 使用索引方式读取变量
|
||||
local val = ngx.var["dynamic_key"]
|
||||
if val ~= "dynamic_value" then
|
||||
error("dynamic_key should be 'dynamic_value', got: " .. tostring(val))
|
||||
end
|
||||
|
||||
-- 混合访问方式
|
||||
ngx.var.mixed = "mixed_value"
|
||||
local mixed = ngx.var["mixed"]
|
||||
if mixed ~= "mixed_value" then
|
||||
error("mixed should be 'mixed_value', got: " .. tostring(mixed))
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarNilRequestCtx 测试无请求上下文的情况
|
||||
func TestNgxVarNilRequestCtx(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 创建没有请求上下文的协程
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试在无请求上下文时访问 ngx.var
|
||||
err = coro.Execute(`
|
||||
-- 应该返回空字符串或 nil
|
||||
local method = ngx.var.request_method
|
||||
-- 可以接受空字符串或 nil
|
||||
|
||||
-- 但自定义变量仍然可以设置
|
||||
ngx.var.test = "value"
|
||||
local val = ngx.var.test
|
||||
if val ~= "value" then
|
||||
error("custom var should be settable even without request ctx")
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestNgxVarUndefined 测试未定义变量
|
||||
func TestNgxVarUndefined(t *testing.T) {
|
||||
var req fasthttp.Request
|
||||
req.Header.SetMethod("GET")
|
||||
req.Header.SetRequestURI("/test")
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&req, nil, nil)
|
||||
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(ctx)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试读取未定义变量
|
||||
err = coro.Execute(`
|
||||
local undefined = ngx.var.undefined_var_name
|
||||
if undefined ~= nil then
|
||||
error("undefined var should be nil, got: " .. tostring(undefined))
|
||||
end
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user