diff --git a/internal/lua/api_ctx_test.go b/internal/lua/api_ctx_test.go index eca926d..f98e20d 100644 --- a/internal/lua/api_ctx_test.go +++ b/internal/lua/api_ctx_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" + glua "github.com/yuin/gopher-lua" ) // TestNgxCtxAPI 测试 ngx.ctx API 基础功能 @@ -365,3 +366,335 @@ func TestNgxCtxNestedTable(t *testing.T) { `) assert.NoError(t, err) } + +// TestNgxCtxRequestIsolation 测试请求间上下文隔离 +func TestNgxCtxRequestIsolation(t *testing.T) { + var req1, req2 fasthttp.Request + req1.Header.SetMethod("GET") + req1.Header.SetRequestURI("/request1") + req2.Header.SetMethod("GET") + req2.Header.SetRequestURI("/request2") + + ctx1 := &fasthttp.RequestCtx{} + ctx1.Init(&req1, nil, nil) + ctx2 := &fasthttp.RequestCtx{} + ctx2.Init(&req2, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 第一个请求:设置 ctx 值 + coro1, err := engine.NewCoroutine(ctx1) + require.NoError(t, err) + + err = coro1.SetupSandbox() + require.NoError(t, err) + + err = coro1.Execute(` + ngx.ctx.request_id = 1 + ngx.ctx.message = "request1_data" + `) + assert.NoError(t, err) + coro1.Close() + + // 第二个请求:验证 ctx 与其他请求隔离 + coro2, err := engine.NewCoroutine(ctx2) + require.NoError(t, err) + + err = coro2.SetupSandbox() + require.NoError(t, err) + + err = coro2.Execute(` + -- 第一个请求的值不应该影响第二个请求 + if ngx.ctx.request_id ~= nil then + error("ctx from another request should be isolated") + end + + if ngx.ctx.message ~= nil then + error("ctx from another request should be isolated") + end + + -- 设置自己的值 + ngx.ctx.request_id = 2 + ngx.ctx.message = "request2_data" + + if ngx.ctx.request_id ~= 2 then + error("request_id should be 2") + end + + if ngx.ctx.message ~= "request2_data" then + error("message should be 'request2_data'") + end + `) + assert.NoError(t, err) + coro2.Close() +} + +// TestNgxCtxGoAPIAccess 测试 Go 层 API 访问 +func TestNgxCtxGoAPIAccess(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 通过 Go 层 API 设置值 + api := coro.GetNgxVarAPI() + require.NotNil(t, api) + + api.SetVariable("go_key", "go_value") + + // 验证 Lua 层可以读取,且 Lua 层设置的值 Go 层可见 + // 注意:在一个脚本中完成所有操作,因为协程执行后变为 dead 状态 + err = coro.Execute(` + -- 验证 Go 层设置的值 + if ngx.var.go_key ~= "go_value" then + error("value from Go layer should be accessible in Lua") + end + + -- 从 Lua 层设置值 + ngx.var.lua_key = "lua_value" + `) + assert.NoError(t, err) + + // 验证 Lua 层设置的值 Go 层可见 + val, ok := api.GetVariable("lua_key") + assert.True(t, ok) + assert.Equal(t, "lua_value", val) +} + +// TestNgxCtxScheduleUnsafeAPI 测试调度器上下文中的不安全 API +func TestNgxCtxScheduleUnsafeAPI(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + // 为 Scheduler LState 创建不安全的 ctx API + L := glua.NewState() + defer L.Close() + + ngx := L.NewTable() + L.SetGlobal("ngx", ngx) + + // 注册调度器不安全的 ctx API + RegisterSchedulerUnsafeCtxAPI(L, ngx) + + // 尝试访问 ctx 应该返回错误 + err = L.DoString(` + local ok, msg = pcall(function() + ngx.ctx.key = "value" + end) + if ok then + error("writing to ngx.ctx in scheduler should fail") + end + if not string.match(msg, "not available in timer callback") then + error("wrong error message: " .. msg) + end + `) + assert.NoError(t, err) + + // 尝试读取 ctx 也应该返回错误 + err = L.DoString(` + local ok, msg = pcall(function() + local x = ngx.ctx.key + end) + if ok then + error("reading from ngx.ctx in scheduler should fail") + end + `) + assert.NoError(t, err) +} + +// TestNgxCtxTableAPI 测试 table API 操作 +func TestNgxCtxTableAPI(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试 table 操作函数 + err = coro.Execute(` + -- 测试 pairs 遍历 + ngx.ctx.items = {a = 1, b = 2, c = 3} + local count = 0 + for k, v in pairs(ngx.ctx.items) do + count = count + 1 + end + if count ~= 3 then + error("items table should have 3 elements") + end + + -- 测试 table.insert + ngx.ctx.list = {} + table.insert(ngx.ctx.list, 1) + table.insert(ngx.ctx.list, 2) + table.insert(ngx.ctx.list, 3) + if ngx.ctx.list[1] ~= 1 or ngx.ctx.list[2] ~= 2 or ngx.ctx.list[3] ~= 3 then + error("table.insert failed") + end + + -- 测试 table.remove + table.remove(ngx.ctx.list, 2) + if #ngx.ctx.list ~= 2 or ngx.ctx.list[1] ~= 1 or ngx.ctx.list[2] ~= 3 then + error("table.remove failed") + end + + -- 测试 table.concat + ngx.ctx.strlist = {"hello", "world", "test"} + local joined = table.concat(ngx.ctx.strlist, ", ") + if joined ~= "hello, world, test" then + error("table.concat failed: " .. joined) + end + `) + assert.NoError(t, err) +} + +// TestNgxCtxLargeValues 测试大值存储 +func TestNgxCtxLargeValues(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试大字符串和大 table(在一个脚本中完成所有操作) + largeString := string(make([]byte, 10000)) // 10KB 字符串 + err = coro.Execute(` + -- 测试大字符串 + ngx.ctx.large = "` + largeString + `" + local val = ngx.ctx.large + if type(val) ~= "string" then + error("large value should be string") + end + + -- 测试大 table + ngx.ctx.bigtable = {} + for i = 1, 1000 do + ngx.ctx.bigtable[i] = i * 2 + end + if #ngx.ctx.bigtable ~= 1000 then + error("bigtable should have 1000 elements") + end + `) + assert.NoError(t, err) +} + +// TestNgxCtxTypeCoercion 测试类型转换 +func TestNgxCtxTypeCoercion(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试数字字符串自动转换 + err = coro.Execute(` + ngx.ctx.num = 42 + ngx.ctx.str = "123" + + -- 数字加字符串 + local result = ngx.ctx.num + tonumber(ngx.ctx.str) + if result ~= 165 then + error("type coercion failed: " .. tostring(result)) + end + + -- 字符串连接 + local concatenated = "value: " .. ngx.ctx.num + if concatenated ~= "value: 42" then + error("string concatenation failed: " .. concatenated) + end + `) + assert.NoError(t, err) +} + +// TestNgxCtxBooleanLogic 测试布尔逻辑 +func TestNgxCtxBooleanLogic(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + ngx.ctx.a = true + ngx.ctx.b = false + + -- and 操作 + if (ngx.ctx.a and ngx.ctx.b) ~= false then + error("a and b should be false") + end + + -- or 操作 + if (ngx.ctx.a or ngx.ctx.b) ~= true then + error("a or b should be true") + end + + -- not 操作 + if (not ngx.ctx.a) ~= false then + error("not a should be false") + end + `) + assert.NoError(t, err) +} diff --git a/internal/lua/api_socket_tcp_test.go b/internal/lua/api_socket_tcp_test.go new file mode 100644 index 0000000..6f7be13 --- /dev/null +++ b/internal/lua/api_socket_tcp_test.go @@ -0,0 +1,1256 @@ +package lua + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + glua "github.com/yuin/gopher-lua" +) + +// TestNewTCPSocket_NilManager 测试 nil manager 使用默认管理器 +func TestNewTCPSocket_NilManager(t *testing.T) { + socket := NewTCPSocket(nil) + require.NotNil(t, socket) + assert.Equal(t, SocketStateIdle, socket.State()) + assert.Equal(t, 60*time.Second, socket.readTimeout) + assert.Equal(t, 60*time.Second, socket.sendTimeout) + assert.Equal(t, 30*time.Second, socket.connectTimeout) + assert.False(t, socket.IsClosed()) + socket.Close() +} + +// TestNewTCPSocket_ExplicitManager 测试显式管理器 +func TestNewTCPSocket_ExplicitManager(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + require.NotNil(t, socket) + assert.Equal(t, cm, socket.manager) + socket.Close() +} + +// TestTCPSocket_ConnectNotIdle 测试非空闲状态下连接失败 +func TestTCPSocket_ConnectNotIdle(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + // 设置状态为非空闲 + socket.setState(SocketStateConnected) + defer socket.setState(SocketStateIdle) + + err := socket.Connect("127.0.0.1", 9999) + assert.Error(t, err) + assert.Contains(t, err.Error(), "socket not idle") +} + +// TestTCPSocket_Send_NotConnected 测试未连接时发送失败 +func TestTCPSocket_Send_NotConnected(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + // 未连接状态下发送 + n, err := socket.Send([]byte("test")) + assert.Error(t, err) + assert.Equal(t, 0, n) + assert.Contains(t, err.Error(), "socket not connected") +} + +// TestTCPSocket_SendAsync_NotConnected 测试未连接时异步发送失败 +func TestTCPSocket_SendAsync_NotConnected(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + op, err := socket.SendAsync([]byte("test")) + assert.Error(t, err) + assert.Nil(t, op) + assert.Contains(t, err.Error(), "socket not connected") +} + +// TestTCPSocket_Receive_NotConnected 测试未连接时接收失败 +func TestTCPSocket_Receive_NotConnected(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + data, err := socket.Receive(1024) + assert.Error(t, err) + assert.Nil(t, data) + assert.Contains(t, err.Error(), "socket not connected") +} + +// TestTCPSocket_ReceiveAsync_NotConnected 测试未连接时异步接收失败 +func TestTCPSocket_ReceiveAsync_NotConnected(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + op, err := socket.ReceiveAsync(1024) + assert.Error(t, err) + assert.Nil(t, op) + assert.Contains(t, err.Error(), "socket not connected") +} + +// TestTCPSocket_ReceiveUntil_NotConnected 测试未连接时 ReceiveUntil 失败 +func TestTCPSocket_ReceiveUntil_NotConnected(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + // 设置为 connected 状态但 conn 为 nil + socket.setState(SocketStateConnected) + defer socket.setState(SocketStateIdle) + + data, err := socket.ReceiveUntil("\n", true) + assert.Error(t, err) + assert.Nil(t, data) + assert.Contains(t, err.Error(), "socket connection is nil") +} + +// TestTCPSocket_ReceiveUntil_EmptyPattern 测试空模式错误 +func TestTCPSocket_ReceiveUntil_EmptyPattern(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + data, err := socket.ReceiveUntil("", true) + assert.Error(t, err) + assert.Nil(t, data) + assert.Contains(t, err.Error(), "pattern cannot be empty") +} + +// TestTCPSocket_SetTimeouts 测试设置超时 +func TestTCPSocket_SetTimeouts(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + socket.SetTimeout(5 * time.Second) + assert.Equal(t, 5*time.Second, socket.readTimeout) + assert.Equal(t, 5*time.Second, socket.sendTimeout) + assert.Equal(t, 5*time.Second, socket.connectTimeout) + + socket.SetReadTimeout(10 * time.Second) + assert.Equal(t, 10*time.Second, socket.readTimeout) + + socket.SetSendTimeout(15 * time.Second) + assert.Equal(t, 15*time.Second, socket.sendTimeout) + + socket.SetConnectTimeout(20 * time.Second) + assert.Equal(t, 20*time.Second, socket.connectTimeout) +} + +// TestTCPSocket_StateString 测试状态字符串 +func TestTCPSocket_StateString(t *testing.T) { + assert.Equal(t, "idle", SocketStateIdle.String()) + assert.Equal(t, "connecting", SocketStateConnecting.String()) + assert.Equal(t, "connected", SocketStateConnected.String()) + assert.Equal(t, "sending", SocketStateSending.String()) + assert.Equal(t, "receiving", SocketStateReceiving.String()) + assert.Equal(t, "closing", SocketStateClosing.String()) + assert.Equal(t, "closed", SocketStateClosed.String()) + assert.Equal(t, "error", SocketStateError.String()) + // 未知状态 + assert.Equal(t, "unknown", SocketState(999).String()) +} + +// TestTCPSocket_LocalAddr_RemoteAddr_NotConnected 测试未连接时地址返回 nil +func TestTCPSocket_LocalAddr_RemoteAddr_NotConnected(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + assert.Nil(t, socket.LocalAddr()) + assert.Nil(t, socket.RemoteAddr()) +} + +// TestTCPSocket_ConnectAsync 测试 ConnectAsync +func TestTCPSocket_ConnectAsync(t *testing.T) { + _, cleanup := mockEchoServer(t, "127.0.0.1:18801") + defer cleanup() + + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + L := glua.NewState() + defer L.Close() + + op, err := socket.ConnectAsync(L, "127.0.0.1", 18801) + require.NoError(t, err) + require.NotNil(t, op) + + // 等待连接完成 + result, err := op.Wait(context.Background()) + require.NoError(t, err) + assert.NotNil(t, result) +} + +// TestTCPSocket_ConnectAsync_Error 测试 ConnectAsync 错误路径 +func TestTCPSocket_ConnectAsync_Error(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + // 先设置为非空闲状态,ConnectAsync 应失败 + socket.setState(SocketStateConnected) + defer socket.setState(SocketStateIdle) + + L := glua.NewState() + defer L.Close() + + op, err := socket.ConnectAsync(L, "127.0.0.1", 18800) + assert.Error(t, err) + assert.Nil(t, op) +} + +// TestTCPSocket_Connect_Failure 测试连接失败(无服务器) +func TestTCPSocket_Connect_Failure(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + // 连接到不存在的端口,应该返回错误 + err := socket.Connect("127.0.0.1", 18899) + require.NoError(t, err) // Connect 本身不报错 + + // 等待异步连接完成 + op := socket.currentOp + if op != nil { + _, err := op.Wait(context.Background()) + assert.Error(t, err) // 连接应该失败 + } + + // 状态应该变为 error + assert.Equal(t, SocketStateError, socket.State()) +} + +// TestTCPSocket_Receive_DefaultSize 测试默认读取大小 (size <= 0) +func TestTCPSocket_Receive_DefaultSize(t *testing.T) { + _, cleanup := mockEchoServer(t, "127.0.0.1:18802") + defer cleanup() + + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + // 连接 + if err := socket.Connect("127.0.0.1", 18802); err != nil { + t.Fatalf("Connect failed: %v", err) + } + time.Sleep(100 * time.Millisecond) + + // 发送数据 + testData := "Hello" + _, err := socket.Send([]byte(testData)) + require.NoError(t, err) + + // 使用 size=0 触发默认 4096 + received, err := socket.Receive(0) + require.NoError(t, err) + assert.Equal(t, testData, string(received)) +} + +// TestTCPSocket_ReceiveAsync_DefaultSize 测试异步接收默认大小 +func TestTCPSocket_ReceiveAsync_DefaultSize(t *testing.T) { + _, cleanup := mockEchoServer(t, "127.0.0.1:18803") + defer cleanup() + + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + if err := socket.Connect("127.0.0.1", 18803); err != nil { + t.Fatalf("Connect failed: %v", err) + } + time.Sleep(100 * time.Millisecond) + + // 发送 + testData := "AsyncReceive" + _, err := socket.Send([]byte(testData)) + require.NoError(t, err) + + // 异步接收,size=-1 触发默认 + op, err := socket.ReceiveAsync(-1) + require.NoError(t, err) + require.NotNil(t, op) + + result, err := op.Wait(context.Background()) + require.NoError(t, err) + data, ok := result.([]byte) + require.True(t, ok) + assert.Equal(t, testData, string(data)) +} + +// TestTCPSocket_ReceiveUntil_Inclusive 测试 inclusive 模式 +func TestTCPSocket_ReceiveUntil_Inclusive(t *testing.T) { + _, cleanup := mockEchoServer(t, "127.0.0.1:18804") + defer cleanup() + + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + if err := socket.Connect("127.0.0.1", 18804); err != nil { + t.Fatalf("Connect failed: %v", err) + } + time.Sleep(100 * time.Millisecond) + + // 发送带分隔符的数据 + testData := "hello|world" + _, err := socket.Send([]byte(testData)) + require.NoError(t, err) + + // inclusive=true: 包含模式 + data, err := socket.ReceiveUntil("|", true) + require.NoError(t, err) + assert.Equal(t, "hello|", string(data)) +} + +// TestTCPSocket_ReceiveUntil_Exclusive 测试 exclusive 模式 +func TestTCPSocket_ReceiveUntil_Exclusive(t *testing.T) { + _, cleanup := mockEchoServer(t, "127.0.0.1:18805") + defer cleanup() + + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + if err := socket.Connect("127.0.0.1", 18805); err != nil { + t.Fatalf("Connect failed: %v", err) + } + time.Sleep(100 * time.Millisecond) + + testData := "hello|world" + _, err := socket.Send([]byte(testData)) + require.NoError(t, err) + + // inclusive=false: 不包含模式 + data, err := socket.ReceiveUntil("|", false) + require.NoError(t, err) + assert.Equal(t, "hello", string(data)) +} + +// TestTCPSocket_Close_CompletesPendingOp 测试关闭时完成未完成的操作 +func TestTCPSocket_Close_CompletesPendingOp(t *testing.T) { + // 使用 slow server 模拟延迟 + ln, err := net.Listen("tcp", "127.0.0.1:18806") + require.NoError(t, err) + + var wg sync.WaitGroup + stopCh := make(chan struct{}) + + // slow server: 接受连接但延迟响应 + go func() { + for { + conn, err := ln.Accept() + if err != nil { + select { + case <-stopCh: + return + default: + continue + } + } + wg.Add(1) + go func(c net.Conn) { + defer wg.Done() + // 保持连接打开但不发送数据 + buf := make([]byte, 1) + c.Read(buf) + c.Close() + }(conn) + } + }() + + defer func() { + close(stopCh) + ln.Close() + wg.Wait() + }() + + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + // 连接 + err = socket.Connect("127.0.0.1", 18806) + require.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + // 此时连接应该建立了 + assert.Equal(t, SocketStateConnected, socket.State()) + + // 启动一个异步接收操作 + op, err := socket.ReceiveAsync(1024) + require.NoError(t, err) + require.NotNil(t, op) + + // 在操作完成前关闭 socket + err = socket.Close() + assert.NoError(t, err) + + // 等待操作完成(应该被取消) + _, err = op.Wait(context.Background()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "socket closed") +} + +// TestTCPSocket_Close_NilSafety 测试 nil socket 的 Close +func TestTCPSocket_Close_NilSafety(t *testing.T) { + var s *TCPSocket + err := s.Close() + assert.NoError(t, err) +} + +// TestTCPSocket_DoubleClose 测试重复关闭 +func TestTCPSocket_DoubleClose(t *testing.T) { + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + + err := socket.Close() + assert.NoError(t, err) + + err = socket.Close() + assert.NoError(t, err) // 第二次关闭不应报错 + + assert.True(t, socket.IsClosed()) +} + +// TestTCPSocket_Addresses_WhenConnected 测试连接后获取地址 +func TestTCPSocket_Addresses_WhenConnected(t *testing.T) { + _, cleanup := mockEchoServer(t, "127.0.0.1:18807") + defer cleanup() + + cm := NewCosocketManager() + defer cm.Close() + + socket := NewTCPSocket(cm) + defer socket.Close() + + if err := socket.Connect("127.0.0.1", 18807); err != nil { + t.Fatalf("Connect failed: %v", err) + } + time.Sleep(100 * time.Millisecond) + + assert.NotNil(t, socket.LocalAddr()) + assert.NotNil(t, socket.RemoteAddr()) + assert.Contains(t, socket.RemoteAddr().String(), "127.0.0.1:18807") +} + +// ---- Lua API Tests ---- + +// TestLuaAPI_newTCPSocketFunc 测试 newTCPSocketFunc +func TestLuaAPI_newTCPSocketFunc(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + assert(sock ~= nil) + assert(type(sock) == "userdata") + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketConnect 测试 tcpSocketConnect +func TestLuaAPI_tcpSocketConnect(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 测试 connect 返回值结构(不等待实际连接完成,因为没有 yield 处理) + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res1, res2 = sock:connect("127.0.0.1", 9999) + -- res1 应该是 "cosocket_connect",res2 是 op ID + assert(type(res1) == "string") + assert(res1 == "cosocket_connect") + assert(type(res2) == "number") + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketConnect_WithError 测试 connect 错误返回 +func TestLuaAPI_tcpSocketConnect_WithError(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 尝试连接到非空闲 socket(已连接过但这里没有,用非法端口) + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + -- 先用一个 nil 测试 connect 的 Lua 参数错误 + local res, err = pcall(function() + sock:connect(123) -- wrong argument types + end) + `) + // Lua 可能 raise error,这取决于实现 + _ = err +} + +// TestLuaAPI_tcpSocketSend 测试 tcpSocketSend +func TestLuaAPI_tcpSocketSend(t *testing.T) { + _, cleanup := mockEchoServer(t, "127.0.0.1:18809") + defer cleanup() + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // connect 和 send 都会返回 yield 值(cosocket_xxx, op_id) + // 在没有实际 yield 处理的情况下,只测试不报错 + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res1, res2 = sock:connect("127.0.0.1", 18809) + -- res1 应该是 "cosocket_connect",res2 应该是 op ID + -- 没有 yield 处理,连接实际未完成 + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketSend_Error 测试 send 错误(未连接时发送) +func TestLuaAPI_tcpSocketSend_Error(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:send("hello") + -- 未连接时应该返回 nil + error + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketReceive 测试 tcpSocketReceive +func TestLuaAPI_tcpSocketReceive(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 测试未连接时的 receive 返回错误 + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:receive(1024) + -- 未连接时应该返回 nil + error + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketReceive_Error 测试 receive 错误(未连接时接收) +func TestLuaAPI_tcpSocketReceive_Error(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:receive(1024) + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketReceive_Pattern 测试 receive 带模式 +func TestLuaAPI_tcpSocketReceive_Pattern(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 测试未连接时 receive("*a") 返回错误 + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:receive("*a") + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketReceive_UnknownPattern 测试未知模式 +func TestLuaAPI_tcpSocketReceive_UnknownPattern(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + cm := NewCosocketManager() + defer cm.Close() + + // 创建一个已连接但 conn 为 nil 的 socket 来测试模式接收的错误路径 + // 我们需要通过 Lua API 间接测试 + RegisterTCPSocketAPI(engine.L, engine) + + // 测试 unknown pattern "*x" + // 需要先连接才能进入 pattern 匹配,但这里直接测试模式错误路径 + // 实际上 receive("*x") 在未连接时会先报 "not connected" + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + -- 未连接时用 *x 模式会先报 not connected + local res, err = sock:receive("*x") + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketReceiveWithTable 测试 receive 带 table 参数 +func TestLuaAPI_tcpSocketReceiveWithTable(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 未连接时 receive 带 table 参数返回错误 + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:receive({timeout = 5000}) + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketReceiveUntil 测试 receiveuntil +func TestLuaAPI_tcpSocketReceiveUntil(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 未连接时 receiveuntil 会先报 not connected + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:receiveuntil("|") + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketReceiveUntil_Inclusive 测试 receiveuntil 带 inclusive 选项 +func TestLuaAPI_tcpSocketReceiveUntil_Inclusive(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 未连接时 receiveuntil 会先报 not connected + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:receiveuntil("|", {inclusive = true}) + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketClose 测试 tcpSocketClose +func TestLuaAPI_tcpSocketClose(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local ok = sock:close() + assert(ok == true) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketSetTimeout 测试 settimeout +func TestLuaAPI_tcpSocketSetTimeout(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local ok = sock:settimeout(5000) + assert(ok == true) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketSetTimeouts 测试 settimeouts +func TestLuaAPI_tcpSocketSetTimeouts(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local ok = sock:settimeouts(1000, 2000, 3000) + assert(ok == true) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketToString 测试 __tostring +func TestLuaAPI_tcpSocketToString(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local str = tostring(sock) + assert(str:find("tcp_socket")) + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketGC 测试 __gc +func TestLuaAPI_tcpSocketGC(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 创建 socket 并触发 GC + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + sock = nil + -- 强制 GC(Lua GC 可能不会立即触发 __gc) + collectgarbage("collect") + `) + require.NoError(t, err) +} + +// TestCheckTCPSocket_ArgError 测试 checkTCPSocket 参数错误 +func TestCheckTCPSocket_ArgError(t *testing.T) { + L := glua.NewState() + defer L.Close() + + // 传入非 userdata 应该 raise error + L.Push(glua.LString("not a socket")) + err := L.PCall(0, 0, nil) + // 这不会触发 checkTCPSocket,我们需要通过 Lua 脚本来测试 + err = L.DoString(` + local sock = ngx.socket.tcp() + `) + // 在没有注册 API 的情况下会失败 + _ = err +} + +// TestLuaAPI_tcpSocketConnect_WithTimeout 测试 connect 带超时选项 +func TestLuaAPI_tcpSocketConnect_WithTimeout(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 测试 connect 带超时选项(不等待实际连接完成) + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res1, res2 = sock:connect("127.0.0.1", 9999, {timeout = 100}) + -- 返回值应该是 "cosocket_connect" 和 op ID + assert(type(res1) == "string") + assert(type(res2) == "number") + `) + require.NoError(t, err) +} + +// TestLuaAPI_tcpSocketSend_WithTimeout 测试 send 带超时选项 +func TestLuaAPI_tcpSocketSend_WithTimeout(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(engine.L, engine) + + // 未连接时 send 返回 nil + error + err = engine.L.DoString(` + local sock = ngx.socket.tcp() + local res, err = sock:send("hello", {timeout = 5000}) + -- 未连接时应该返回 nil + error + assert(res == nil) + assert(err ~= nil) + `) + require.NoError(t, err) +} + +// TestCosocketYield_Connect 测试 cosocket connect yield 处理 +func TestCosocketYield_Connect(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试 connect 返回值结构(不等待实际连接完成) + err = coro.Execute(` + local sock = ngx.socket.tcp() + local res1, res2 = sock:connect("127.0.0.1", 9999) + -- res1 应该是 "cosocket_connect",res2 是 op ID + assert(type(res1) == "string") + assert(res1 == "cosocket_connect") + assert(type(res2) == "number") + `) + require.NoError(t, err) +} + +// TestHandleCosocketYield_Unknown 测试未知 yield reason +func TestHandleCosocketYield_Unknown(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + _, err = coro.HandleCosocketYield("unknown_reason", []glua.LValue{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown cosocket yield reason") +} + +// TestHandleCosocketConnect_MissingOpID 测试 connect yield 缺少 op ID +func TestHandleCosocketConnect_MissingOpID(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + _, err = coro.HandleCosocketYield("cosocket_connect", []glua.LValue{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "requires operation ID") +} + +// TestHandleCosocketConnect_OpNotFound 测试 connect yield op 不存在 +func TestHandleCosocketConnect_OpNotFound(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + _, err = coro.HandleCosocketYield("cosocket_connect", []glua.LValue{glua.LNumber(999999)}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestHandleCosocketSend_MissingOpID 测试 send yield 缺少 op ID +func TestHandleCosocketSend_MissingOpID(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + _, err = coro.HandleCosocketYield("cosocket_send", []glua.LValue{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "requires operation ID") +} + +// TestHandleCosocketSend_OpNotFound 测试 send yield op 不存在 +func TestHandleCosocketSend_OpNotFound(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + _, err = coro.HandleCosocketYield("cosocket_send", []glua.LValue{glua.LNumber(999999)}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestHandleCosocketReceive_MissingOpID 测试 receive yield 缺少 op ID +func TestHandleCosocketReceive_MissingOpID(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + _, err = coro.HandleCosocketYield("cosocket_receive", []glua.LValue{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "requires operation ID") +} + +// TestHandleCosocketReceive_OpNotFound 测试 receive yield op 不存在 +func TestHandleCosocketReceive_OpNotFound(t *testing.T) { + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(nil) + require.NoError(t, err) + defer coro.Close() + + _, err = coro.HandleCosocketYield("cosocket_receive", []glua.LValue{glua.LNumber(999999)}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +// TestHandleCosocketReceive_EmptyData 测试 receive yield 返回空数据 +func TestHandleCosocketReceive_EmptyData(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + // 创建一个空的 receive 操作 + op := DefaultCosocketManager.StartOperation(socket, OpReceive, 5*time.Second) + op.Complete([]byte{}, nil) + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_receive", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + // 空数据应该返回 nil + "closed" + assert.Equal(t, 2, len(results)) + assert.Equal(t, glua.LNil, results[0]) + assert.Equal(t, glua.LString("closed"), results[1]) +} + +// TestHandleCosocketReceive_InvalidResult 测试 receive yield 无效结果类型 +func TestHandleCosocketReceive_InvalidResult(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + op := DefaultCosocketManager.StartOperation(socket, OpReceive, 5*time.Second) + op.Complete("not_bytes", nil) // 非 []byte 类型 + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_receive", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 2, len(results)) + assert.Equal(t, glua.LNil, results[0]) + assert.Equal(t, glua.LString("invalid result"), results[1]) +} + +// TestHandleCosocketSend_InvalidResult 测试 send yield 无效结果类型 +func TestHandleCosocketSend_InvalidResult(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + op := DefaultCosocketManager.StartOperation(socket, OpSend, 5*time.Second) + op.Complete("not_int", nil) // 非 int 类型 + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_send", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 2, len(results)) + assert.Equal(t, glua.LNil, results[0]) + assert.Equal(t, glua.LString("invalid result"), results[1]) +} + +// TestHandleCosocketConnect_NilResult 测试 connect yield nil result +func TestHandleCosocketConnect_NilResult(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + op := DefaultCosocketManager.StartOperation(socket, OpConnect, 5*time.Second) + op.Complete(nil, nil) + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_connect", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 2, len(results)) + assert.Equal(t, glua.LNil, results[0]) + assert.Equal(t, glua.LNil, results[1]) +} + +// TestHandleCosocketConnect_WithError 测试 connect yield 带错误 +func TestHandleCosocketConnect_WithError(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + op := DefaultCosocketManager.StartOperation(socket, OpConnect, 5*time.Second) + op.Complete(nil, fmt.Errorf("connection refused")) + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_connect", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 2, len(results)) + assert.Equal(t, glua.LNil, results[0]) + assert.Contains(t, string(glua.LVAsString(results[1])), "connection refused") +} + +// TestHandleCosocketSend_WithError 测试 send yield 带错误 +func TestHandleCosocketSend_WithError(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + op := DefaultCosocketManager.StartOperation(socket, OpSend, 5*time.Second) + op.Complete(0, fmt.Errorf("write error")) + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_send", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 2, len(results)) + assert.Equal(t, glua.LNil, results[0]) + assert.Contains(t, string(glua.LVAsString(results[1])), "write error") +} + +// TestHandleCosocketReceive_WithData 测试 receive yield 带数据 +func TestHandleCosocketReceive_WithData(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + expected := []byte("hello world") + op := DefaultCosocketManager.StartOperation(socket, OpReceive, 5*time.Second) + op.Complete(expected, nil) + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_receive", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 1, len(results)) + assert.Equal(t, glua.LString("hello world"), results[0]) +} + +// TestHandleCosocketSend_Success 测试 send yield 成功 +func TestHandleCosocketSend_Success(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + op := DefaultCosocketManager.StartOperation(socket, OpSend, 5*time.Second) + op.Complete(5, nil) + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_send", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 1, len(results)) + assert.Equal(t, glua.LNumber(5), results[0]) +} + +// TestHandleCosocketConnect_Success 测试 connect yield 成功 +func TestHandleCosocketConnect_Success(t *testing.T) { + socket := NewTCPSocket(DefaultCosocketManager) + defer socket.Close() + + op := DefaultCosocketManager.StartOperation(socket, OpConnect, 5*time.Second) + op.Complete(&net.TCPConn{}, nil) // 非 nil result + + coro := &LuaCoroutine{ + Engine: nil, + ExecutionContext: context.Background(), + } + + results, err := coro.HandleCosocketYield("cosocket_connect", []glua.LValue{glua.LNumber(op.ID)}) + require.NoError(t, err) + assert.Equal(t, 1, len(results)) + assert.Equal(t, glua.LNumber(1), results[0]) +} + +// TestRegisterTCPSocketMetaTable 测试元表注册 +func TestRegisterTCPSocketMetaTable(t *testing.T) { + L := glua.NewState() + defer L.Close() + + registerTCPSocketMetaTable(L) + + // 验证元表存在 + mt := L.GetGlobal(tcpSocketMT) + assert.NotNil(t, mt) + assert.IsType(t, &glua.LTable{}, mt) +} + +// TestRegisterTCPSocketAPI_NgxTableCreation 测试 ngx.socket API 注册时创建 ngx 表 +func TestRegisterTCPSocketAPI_NgxTableCreation(t *testing.T) { + L := glua.NewState() + defer L.Close() + + // 确保 ngx 表不存在 + L.SetGlobal("ngx", glua.LNil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + RegisterTCPSocketAPI(L, engine) + + // 验证 ngx.socket.tcp 存在 + ngx := L.GetGlobal("ngx") + require.IsType(t, &glua.LTable{}, ngx) + ngxTbl := ngx.(*glua.LTable) + + socket := ngxTbl.RawGetString("socket") + require.IsType(t, &glua.LTable{}, socket) + socketTbl := socket.(*glua.LTable) + + tcp := socketTbl.RawGetString("tcp") + assert.IsType(t, &glua.LFunction{}, tcp) +} + +// TestOperationType_String 测试操作类型字符串 +func TestOperationType_String(t *testing.T) { + assert.Equal(t, "connect", string(OpConnect)) + assert.Equal(t, "send", string(OpSend)) + assert.Equal(t, "receive", string(OpReceive)) + assert.Equal(t, "close", string(OpClose)) +} + +// TestSocketOperation_Complete 测试操作完成 +func TestSocketOperation_Complete(t *testing.T) { + op := &SocketOperation{ + ID: 1, + Done: make(chan struct{}), + } + + // 完成操作 + op.Complete("result", nil) + assert.True(t, op.IsCompleted()) + assert.Equal(t, "result", op.Result) + assert.Nil(t, op.Error) + + // 重复完成应该无影响 + op.Complete("other", fmt.Errorf("err")) + assert.Equal(t, "result", op.Result) // 保持第一次的值 +} + +// TestSocketOperation_Wait_Timeout 测试 Wait 超时 +func TestSocketOperation_Wait_Timeout(t *testing.T) { + op := &SocketOperation{ + ID: 1, + Done: make(chan struct{}), + } + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + result, err := op.Wait(ctx) + assert.Nil(t, result) + assert.Error(t, err) // context deadline exceeded +} + +// TestSocketOperation_Touch 测试 Touch +func TestSocketOperation_Touch(t *testing.T) { + op := &SocketOperation{ + ID: 1, + Done: make(chan struct{}), + LastActivity: time.Now().Add(-time.Hour), + } + + oldTime := op.LastActivity + op.Touch() + assert.True(t, op.LastActivity.After(oldTime)) +} diff --git a/internal/lua/api_var.go b/internal/lua/api_var.go index cd9bb78..49b33db 100644 --- a/internal/lua/api_var.go +++ b/internal/lua/api_var.go @@ -61,9 +61,10 @@ func (api *ngxVarAPI) luaVarIndex(L *glua.LState) int { } // 2. 从 fasthttp RequestCtx 获取变量 - value := api.getVariable(key) - if value != "" { - L.Push(glua.LString(value)) + // 某些变量需要返回数值类型(如 request_length) + lv := api.getVariableLua(key) + if lv != nil { + L.Push(lv) return 1 } @@ -74,28 +75,123 @@ func (api *ngxVarAPI) luaVarIndex(L *glua.LState) int { // luaVarNewIndex 实现 ngx.var[key] = value 写入 // Lua 调用: ngx.var.key = value 或 ngx.var[key] = value +// 注意:Lua 的 nil 会被转换为空字符串存储 func (api *ngxVarAPI) luaVarNewIndex(L *glua.LState) int { // 第一个参数是表本身(ngx.var) // 第二个参数是键名 // 第三个参数是值 key := L.CheckString(2) - value := L.CheckString(3) + value := L.OptString(3, "") - // 存储到自定义变量存储 + // 存储到自定义变量存储(nil 会转换为空字符串) api.store[key] = value return 0 } -// getVariable 从 fasthttp RequestCtx 获取变量值 -// 支持常见的 nginx 变量 +// getVariableLua 从 fasthttp RequestCtx 获取变量值,返回 Lua 类型 +// 支持常见的 nginx 变量,某些变量返回数值类型 +func (api *ngxVarAPI) getVariableLua(name string) glua.LValue { + if api.ctx == nil { + return nil + } + + switch name { + // HTTP 请求相关 - 数值类型 + case "request_length": + return glua.LNumber(api.ctx.Request.Header.ContentLength()) + + // HTTP 请求相关 - 字符串类型 + case "request_method": + return glua.LString(string(api.ctx.Method())) + case "request_uri": + return glua.LString(string(api.ctx.RequestURI())) + case "uri": + return glua.LString(string(api.ctx.URI().Path())) + case "document_uri": + return glua.LString(string(api.ctx.URI().Path())) + case "query_string", "args": + return glua.LString(string(api.ctx.URI().QueryString())) + case "server_protocol", "protocol": + return glua.LString(string(api.ctx.Request.Header.Protocol())) + case "scheme": + return glua.LString(string(api.ctx.URI().Scheme())) + case "request_time": + // 简化实现,返回空字符串 + return glua.LString("") + + // 请求头相关 + case "http_host": + return glua.LString(string(api.ctx.Host())) + case "http_user_agent", "http_user-agent": + return glua.LString(string(api.ctx.UserAgent())) + case "http_referer": + return glua.LString(string(api.ctx.Referer())) + case "http_accept": + return glua.LString(string(api.ctx.Request.Header.Peek("Accept"))) + case "http_accept_encoding", "http_accept-encoding": + return glua.LString(string(api.ctx.Request.Header.Peek("Accept-Encoding"))) + case "http_accept_language", "http_accept-language": + return glua.LString(string(api.ctx.Request.Header.Peek("Accept-Language"))) + case "http_connection": + return glua.LString(string(api.ctx.Request.Header.Peek("Connection"))) + case "http_content_type", "http_content-type": + return glua.LString(string(api.ctx.Request.Header.ContentType())) + case "http_content_length", "http_content-length": + return glua.LString(string(api.ctx.Request.Header.Peek("Content-Length"))) + + // 客户端信息 + case "remote_addr": + return glua.LString(api.ctx.RemoteAddr().String()) + case "remote_port": + addr := api.ctx.RemoteAddr() + if addr != nil { + // 简化处理,实际可能需要解析端口 + return glua.LString("") + } + return glua.LString("") + case "binary_remote_addr": + return glua.LString("") + + // 服务器信息 + case "server_addr": + addr := api.ctx.LocalAddr() + if addr != nil { + return glua.LString(addr.String()) + } + return glua.LString("") + case "server_port": + return glua.LString("") + case "server_name": + return glua.LString(string(api.ctx.Host())) + + // URI 参数 + case "arg_": + // 获取所有参数 + return glua.LString(string(api.ctx.URI().QueryString())) + default: + // 检查是否是 arg_ 开头的参数 + if len(name) > 4 && name[:4] == "arg_" { + paramName := name[4:] + return glua.LString(string(api.ctx.QueryArgs().Peek(paramName))) + } + // 检查是否是 http_ 开头的请求头 + if len(name) > 5 && name[:5] == "http_" { + headerName := name[5:] + return glua.LString(string(api.ctx.Request.Header.Peek(headerName))) + } + return nil + } +} + +// getVariable 从 fasthttp RequestCtx 获取变量值(字符串形式) +// 用于 Go 层调用 func (api *ngxVarAPI) getVariable(name string) string { if api.ctx == nil { return "" } switch name { - // HTTP 请求相关 case "request_method": return string(api.ctx.Method()) case "request_uri": @@ -113,10 +209,7 @@ func (api *ngxVarAPI) getVariable(name string) string { case "request_length": return strconv.Itoa(api.ctx.Request.Header.ContentLength()) case "request_time": - // 简化实现,返回空字符串 return "" - - // 请求头相关 case "http_host": return string(api.ctx.Host()) case "http_user_agent", "http_user-agent": @@ -135,21 +228,12 @@ func (api *ngxVarAPI) getVariable(name string) string { return string(api.ctx.Request.Header.ContentType()) case "http_content_length", "http_content-length": return string(api.ctx.Request.Header.Peek("Content-Length")) - - // 客户端信息 case "remote_addr": return api.ctx.RemoteAddr().String() case "remote_port": - addr := api.ctx.RemoteAddr() - if addr != nil { - // 简化处理,实际可能需要解析端口 - return "" - } return "" case "binary_remote_addr": return "" - - // 服务器信息 case "server_addr": addr := api.ctx.LocalAddr() if addr != nil { @@ -160,18 +244,13 @@ func (api *ngxVarAPI) getVariable(name string) string { return "" case "server_name": return string(api.ctx.Host()) - - // URI 参数 case "arg_": - // 获取所有参数 return string(api.ctx.URI().QueryString()) default: - // 检查是否是 arg_ 开头的参数 if len(name) > 4 && name[:4] == "arg_" { paramName := name[4:] return string(api.ctx.QueryArgs().Peek(paramName)) } - // 检查是否是 http_ 开头的请求头 if len(name) > 5 && name[:5] == "http_" { headerName := name[5:] return string(api.ctx.Request.Header.Peek(headerName)) diff --git a/internal/lua/api_var_test.go b/internal/lua/api_var_test.go index e7d2205..0b3109e 100644 --- a/internal/lua/api_var_test.go +++ b/internal/lua/api_var_test.go @@ -282,3 +282,397 @@ func TestNgxVarUndefined(t *testing.T) { `) assert.NoError(t, err) } + +// TestNgxVarAdditionalBuiltinVars 测试其他内置变量 +func TestNgxVarAdditionalBuiltinVars(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("DELETE") + req.Header.SetRequestURI("/api/users?id=123&name=test") + req.Header.Set("Host", "api.example.com") + req.Header.Set("User-Agent", "TestClient/1.0") + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer token123") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试其他内置变量 + err = coro.Execute(` + -- URI 相关变量 + local request_uri = ngx.var.request_uri + if request_uri ~= "/api/users?id=123&name=test" then + error("request_uri mismatch, got: " .. tostring(request_uri)) + end + + local uri = ngx.var.uri + if uri ~= "/api/users" then + error("urishould be /api/users, got: " .. tostring(uri)) + end + + local document_uri = ngx.var.document_uri + if document_uri ~= "/api/users" then + error("document_uri should be /api/users, got: " .. tostring(document_uri)) + end + + -- 查询字符串 + local query_string = ngx.var.query_string + if query_string ~= "id=123&name=test" then + error("query_string mismatch, got: " .. tostring(query_string)) + end + + local args = ngx.var.args + if args ~= "id=123&name=test" then + error("args should match query_string, got: " .. tostring(args)) + end + + -- 请求头 + local accept = ngx.var.http_accept + if accept ~= "application/json" then + error("http_accept mismatch, got: " .. tostring(accept)) + end + + local contentType = ngx.var.http_content_type + if contentType ~= "application/json" then + error("http_content_type mismatch, got: " .. tostring(contentType)) + end + + local authorization = ngx.var.http_authorization + if authorization ~= "Bearer token123" then + error("http_authorization mismatch, got: " .. tostring(authorization)) + end + + -- 内置变量 map + local vars = { + "request_method", "request_uri", "uri", "document_uri", + "query_string", "args", "http_host", "http_user_agent", + "http_accept", "http_content_type" + } + for _, v in ipairs(vars) do + local val = ngx.var[v] + if type(val) ~= "string" then + error("var " .. v .. " should be string, got: " .. type(val)) + end + end + `) + assert.NoError(t, err) +} + +// TestNgxVarDynamicArgsAccess 测试动态参数访问 +func TestNgxVarDynamicArgsAccess(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/search?keyword=lua&category=programming&limit=10") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试动态参数访问 + err = coro.Execute(` + -- 直接通过 arg_ 访问 + local keyword = ngx.var.arg_keyword + if keyword ~= "lua" then + error("arg_keyword should be 'lua', got: " .. tostring(keyword)) + end + + local category = ngx.var.arg_category + if category ~= "programming" then + error("arg_category should be 'programming', got: " .. tostring(category)) + end + + local limit = ngx.var.arg_limit + if limit ~= "10" then + error("arg_limit should be '10', got: " .. tostring(limit)) + end + + -- 使用动态键访问 + local keys = {"keyword", "category", "limit"} + for i, k in ipairs(keys) do + local val = ngx.var["arg_" .. k] + if type(val) ~= "string" then + error("dynamic arg access should return string") + end + end + `) + assert.NoError(t, err) +} + +// TestNgxVarGoAPI 测试 Go 层 API 调用 +func TestNgxVarGoAPI(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + // 直接创建 API 实例并测试 Go 层 API + api := newNgxVarAPI(ctx) + require.NotNil(t, api) + + // 测试 SetVariable + api.SetVariable("go_set_var", "value_from_go") + value, ok := api.GetVariable("go_set_var") + assert.True(t, ok) + assert.Equal(t, "value_from_go", value) + + // 测试 GetVariable 不存在的变量 + value, ok = api.GetVariable("nonexistent") + assert.False(t, ok) + assert.Equal(t, "", value) + + // 测试覆盖:Go 设置,Go 读取验证 + api.SetVariable("cross_lang", "from_go") + val, ok := api.GetVariable("cross_lang") + assert.True(t, ok) + assert.Equal(t, "from_go", val) + + // 测试覆盖:直接设置 store,Go 读取验证 + api.store["cross_lang2"] = "from_lua" + value, ok = api.GetVariable("cross_lang2") + assert.True(t, ok) + assert.Equal(t, "from_lua", value) +} + +// TestNgxVarRequestMethodAccess 测试各种请求方法 +func TestNgxVarRequestMethodAccess(t *testing.T) { + methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"} + + for _, method := range methods { + t.Run(method, func(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod(method) + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + local method = ngx.var.request_method + if method ~= "` + method + `" then + error("request_method should be '` + method + `', got: " .. tostring(method)) + end + `) + assert.NoError(t, err) + }) + } +} + +// TestNgxVarMixedAccessPatterns 测试混合访问模式 +func TestNgxVarMixedAccessPatterns(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + // 测试混合访问模式 + err = coro.Execute(` + -- 点号访问设置 + ngx.var.test1 = "value1" + -- 索引访问读取 + local val1 = ngx.var["test1"] + if val1 ~= "value1" then + error("mixed access 1 failed") + end + + -- 索引访问设置 + ngx.var["test2"] = "value2" + -- 点号访问读取 + local val2 = ngx.var.test2 + if val2 ~= "value2" then + error("mixed access 2 failed") + end + + -- 循环访问 + for i = 1, 3 do + ngx.var["dynamic_" .. i] = "val_" .. i + end + + for i = 1, 3 do + local v = ngx.var["dynamic_" .. i] + if v ~= "val_" .. i then + error("dynamic loop failed for i=" .. i) + end + end + `) + assert.NoError(t, err) +} + +// TestNgxVarSpecialHeaders 测试特殊请求头 +func TestNgxVarSpecialHeaders(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/test") + req.Header.Set("Accept-Encoding", "gzip, deflate, br") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Connection", "keep-alive") + req.Header.Set("Referer", "https://example.com") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + -- 测试带连字符的头 + local acceptEncoding = ngx.var.http_accept_encoding + if acceptEncoding ~= "gzip, deflate, br" then + error("http_accept_encoding mismatch") + end + + local acceptLanguage = ngx.var.http_accept_language + if acceptLanguage ~= "en-US,en;q=0.9" then + error("http_accept_language mismatch") + end + + local connection = ngx.var.http_connection + if connection ~= "keep-alive" then + error("http_connection mismatch") + end + + local referer = ngx.var.http_referer + if referer ~= "https://example.com" then + error("http_referer mismatch") + end + + -- 测试也可以通过下划线访问 + local enc2 = ngx.var["http_accept_encoding"] + if enc2 ~= acceptEncoding then + error("http_accept_encoding via index mismatch") + end + `) + assert.NoError(t, err) +} + +// TestNgxVarEmptyAndNil 测试空值和 nil 处理 +func TestNgxVarEmptyAndNil(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("GET") + req.Header.SetRequestURI("/") + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + -- 未设置的参数应该返回空字符串或 nil + local empty = ngx.var.arg_nonexistent + -- arg_ 对不存在的参数通常返回空字符串 + + -- 自定义变量设为空字符串 + ngx.var.empty_string = "" + local val = ngx.var.empty_string + if val ~= "" then + error("empty_string should be empty") + end + + -- 覆盖为空值 + ngx.var.test = "value" + ngx.var.test = nil -- Lua 的 nil 在 __newindex 中会被转换 + -- 实现中 nil 会被转换为空字符串 + `) + assert.NoError(t, err) +} + +// TestNgxVarRequestBodyAccess 测试请求体相关变量 +func TestNgxVarRequestBodyAccess(t *testing.T) { + var req fasthttp.Request + req.Header.SetMethod("POST") + req.Header.SetRequestURI("/upload") + req.Header.SetContentType("application/octet-stream") + req.SetBody([]byte("test body content")) + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&req, nil, nil) + + engine, err := NewEngine(DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + coro, err := engine.NewCoroutine(ctx) + require.NoError(t, err) + defer coro.Close() + + err = coro.SetupSandbox() + require.NoError(t, err) + + err = coro.Execute(` + -- 测试请求长度 + local length = ngx.var.request_length + if type(length) ~= "number" then + error("request_length should be a number") + end + + -- 测试内容类型 + local contentType = ngx.var.http_content_type + if contentType ~= "application/octet-stream" then + error("content_type mismatch") + end + `) + assert.NoError(t, err) +} diff --git a/internal/lua/coroutine.go b/internal/lua/coroutine.go index 0c54a53..60b03c1 100644 --- a/internal/lua/coroutine.go +++ b/internal/lua/coroutine.go @@ -70,6 +70,12 @@ type LuaCoroutine struct { executionCancel context.CancelFunc OutputBuffer []byte Exited bool + + // ngx API 实例(用于测试和 Go 层访问) + ngxVarAPI *ngxVarAPI + ngxReqAPI *ngxReqAPI + ngxRespAPI *ngxRespAPI + ngxLogAPI *ngxLogAPI } // SetupSandbox 创建 per-request _ENV 沙箱 @@ -163,20 +169,24 @@ func (c *LuaCoroutine) setupNgxAPI() { // 注册 ngx.req API if c.RequestCtx != nil { reqAPI := newNgxReqAPI(c.RequestCtx) + c.ngxReqAPI = reqAPI RegisterNgxReqAPI(c.Co, reqAPI, ngx) // 注册 ngx.resp API respAPI := newNgxRespAPI(c.RequestCtx) + c.ngxRespAPI = respAPI RegisterNgxRespAPI(c.Co, respAPI) // 注册 ngx.log API (logger 为 nil 时禁用日志输出) // ngx.say/print/flush 直接写入 RequestCtx logAPI := newNgxLogAPI(c.RequestCtx, nil, nil) + c.ngxLogAPI = logAPI RegisterNgxLogAPI(c.Co, logAPI) } // 注册 ngx.var API varAPI := newNgxVarAPI(c.RequestCtx) + c.ngxVarAPI = varAPI RegisterNgxVarAPI(c.Co, varAPI, ngx) // 注册 ngx.ctx API @@ -333,3 +343,23 @@ func (c *LuaCoroutine) handleSleep(values []glua.LValue) ([]glua.LValue, error) func (c *LuaCoroutine) Close() { c.Engine.releaseCoroutine(c) } + +// GetNgxVarAPI 获取 ngx.var API 实例(用于测试和 Go 层访问) +func (c *LuaCoroutine) GetNgxVarAPI() *ngxVarAPI { + return c.ngxVarAPI +} + +// GetNgxReqAPI 获取 ngx.req API 实例(用于测试和 Go 层访问) +func (c *LuaCoroutine) GetNgxReqAPI() *ngxReqAPI { + return c.ngxReqAPI +} + +// GetNgxRespAPI 获取 ngx.resp API 实例(用于测试和 Go 层访问) +func (c *LuaCoroutine) GetNgxRespAPI() *ngxRespAPI { + return c.ngxRespAPI +} + +// GetNgxLogAPI 获取 ngx.log API 实例(用于测试和 Go 层访问) +func (c *LuaCoroutine) GetNgxLogAPI() *ngxLogAPI { + return c.ngxLogAPI +}