lolly/internal/lua/socket_test.go
xfy 8bac2fdcfa feat(lua): 实现 Cosocket API 和响应拦截器
- ngx.req API 双层边界验证原型
- TCP Cosocket API (connect/send/receive/close)
- Cosocket 状态管理器和连接池
- ResponseInterceptor 响应拦截器
- 完整单元测试覆盖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 18:30:28 +08:00

703 lines
15 KiB
Go

package lua
import (
"context"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
// mockEchoServer starts a TCP echo server listening on addr and returns the
// listener together with a cleanup function. Every byte a client sends is
// written straight back on the same connection.
//
// t may be nil (the benchmarks in this file pass nil). In that case a listen
// failure panics with a descriptive message instead of dereferencing a nil
// test handle.
//
// cleanup stops accepting, closes the listener and every connection that is
// still open, then waits for all server goroutines to exit. Closing the open
// connections matters: otherwise a handler blocked in Read on a client that
// never closed its side would make cleanup hang forever.
func mockEchoServer(t testing.TB, addr string) (net.Listener, func()) {
	if t != nil {
		t.Helper()
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		if t == nil {
			panic(fmt.Sprintf("mockEchoServer: failed to listen on %s: %v", addr, err))
		}
		t.Fatalf("Failed to listen: %v", err)
	}

	var (
		mu     sync.Mutex
		closed bool                          // set once cleanup has started
		conns  = make(map[net.Conn]struct{}) // open connections cleanup must close
		wg     sync.WaitGroup                // accept loop + one slot per connection
	)

	// serve echoes c until the peer disconnects or cleanup closes c.
	serve := func(c net.Conn) {
		defer wg.Done()
		defer func() {
			mu.Lock()
			delete(conns, c)
			mu.Unlock()
			c.Close()
		}()
		buf := make([]byte, 4096)
		for {
			n, err := c.Read(buf)
			// Per the io.Reader contract, consume the n bytes before looking at err.
			if n > 0 {
				if _, werr := c.Write(buf[:n]); werr != nil {
					return
				}
			}
			if err != nil {
				return
			}
		}
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			conn, err := ln.Accept()
			mu.Lock()
			if closed {
				mu.Unlock()
				if conn != nil {
					conn.Close()
				}
				return
			}
			if err != nil {
				mu.Unlock()
				// Transient failure (e.g. EMFILE under the stress test): back off
				// briefly instead of spinning on Accept.
				time.Sleep(5 * time.Millisecond)
				continue
			}
			conns[conn] = struct{}{}
			// This goroutine still holds its own wg slot, so the counter is
			// non-zero here and this Add cannot race with wg.Wait.
			wg.Add(1)
			mu.Unlock()
			go serve(conn)
		}
	}()

	cleanup := func() {
		mu.Lock()
		closed = true
		for c := range conns {
			c.Close()
		}
		mu.Unlock()
		ln.Close()
		wg.Wait()
	}
	return ln, cleanup
}
// TestCosocketManager_Basic checks that a fresh manager reports no operations,
// that StartOperation returns a populated operation which is counted both in
// the total and as active, and that CompleteOperation removes it from the
// active set.
func TestCosocketManager_Basic(t *testing.T) {
	mgr := NewCosocketManager()
	defer mgr.Close()

	// A brand-new manager has not recorded anything yet.
	if got := mgr.Stats().TotalOperations; got != 0 {
		t.Errorf("Expected 0 operations, got %d", got)
	}

	sock := NewTCPSocket(mgr)
	defer sock.Close()

	op := mgr.StartOperation(sock, OpConnect, 5*time.Second)
	if op == nil {
		t.Fatal("Expected non-nil operation")
	}
	if op.ID == 0 {
		t.Error("Expected non-zero operation ID")
	}
	if op.Type != OpConnect {
		t.Errorf("Expected OpConnect, got %v", op.Type)
	}

	// Starting the operation bumps both counters.
	afterStart := mgr.Stats()
	if afterStart.TotalOperations != 1 {
		t.Errorf("Expected 1 operation, got %d", afterStart.TotalOperations)
	}
	if afterStart.ActiveOperations != 1 {
		t.Errorf("Expected 1 active operation, got %d", afterStart.ActiveOperations)
	}

	// Completing it drops it out of the active set.
	mgr.CompleteOperation(op.ID, "done", nil)
	if active := mgr.Stats().ActiveOperations; active != 0 {
		t.Errorf("Expected 0 active operations after complete, got %d", active)
	}
}
// TestCosocketManager_Timeout verifies that an operation which is never
// completed gets completed by the manager's timeout handling and is counted in
// Stats().TimeoutOperations.
func TestCosocketManager_Timeout(t *testing.T) {
	// Use a short cleanup interval so timeouts are checked frequently.
	cm := NewCosocketManager()
	cm.SetDefaultTimeout(100 * time.Millisecond)
	// NOTE(review): these fields are mutated after NewCosocketManager returns;
	// if a background checker reads cleanupInterval concurrently this is a data
	// race under -race — confirm against the manager implementation.
	cm.cleanupInterval = 50 * time.Millisecond
	cm.timeoutChecker.Reset(50 * time.Millisecond)
	defer cm.Close()

	socket := NewTCPSocket(cm)
	defer socket.Close()

	// Start an operation that nothing will ever complete.
	op := cm.StartOperation(socket, OpConnect, 100*time.Millisecond)

	// Poll with a generous deadline instead of sleeping a fixed 300ms: the
	// timeout should fire after roughly 100-150ms, but a loaded CI machine or
	// the race detector can delay the checker well past that.
	deadline := time.Now().Add(2 * time.Second)
	for time.Now().Before(deadline) {
		if op.IsCompleted() && cm.Stats().TimeoutOperations >= 1 {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}

	if !op.IsCompleted() {
		t.Error("Expected operation to be completed due to timeout")
	}
	stats := cm.Stats()
	if stats.TimeoutOperations != 1 {
		t.Errorf("Expected 1 timeout operation, got %d", stats.TimeoutOperations)
	}
}
// TestTCPSocket_Connect connects a socket to a local echo server, waits on the
// pending connect operation when there is one, and expects the socket to end
// up in the connected state.
func TestTCPSocket_Connect(t *testing.T) {
	_, stopServer := mockEchoServer(t, "127.0.0.1:19999")
	defer stopServer()

	mgr := NewCosocketManager()
	defer mgr.Close()
	sock := NewTCPSocket(mgr)
	defer sock.Close()

	if err := sock.Connect("127.0.0.1", 19999); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}

	// If a connect operation is still pending, block until it resolves.
	if pending := sock.currentOp; pending != nil {
		conn, err := pending.Wait(context.Background())
		switch {
		case err != nil:
			t.Fatalf("Connect wait failed: %v", err)
		case conn == nil:
			t.Fatal("Expected non-nil connection")
		}
	}

	if state := sock.State(); state != SocketStateConnected {
		t.Errorf("Expected state connected, got %v", state)
	}
}
// TestTCPSocket_SendReceive sends a payload through a connected socket and
// expects the echo server to return exactly the same bytes.
func TestTCPSocket_SendReceive(t *testing.T) {
	_, stopServer := mockEchoServer(t, "127.0.0.1:19998")
	defer stopServer()

	mgr := NewCosocketManager()
	defer mgr.Close()
	sock := NewTCPSocket(mgr)
	defer sock.Close()

	if err := sock.Connect("127.0.0.1", 19998); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	// Give the dial time to finish before using the socket.
	time.Sleep(100 * time.Millisecond)

	const payload = "Hello, Cosocket!"
	sent, err := sock.Send([]byte(payload))
	if err != nil {
		t.Fatalf("Send failed: %v", err)
	}
	if sent != len(payload) {
		t.Errorf("Expected %d bytes sent, got %d", len(payload), sent)
	}

	// The echo server hands the same bytes straight back.
	echoed, err := sock.Receive(1024)
	if err != nil {
		t.Fatalf("Receive failed: %v", err)
	}
	if got := string(echoed); got != payload {
		t.Errorf("Expected '%s', got '%s'", payload, got)
	}
}
// TestTCPSocket_AsyncOperations exercises SendAsync and ReceiveAsync: the send
// operation's result must be the byte count (an int) and the receive
// operation's result must be the echoed payload (a []byte).
func TestTCPSocket_AsyncOperations(t *testing.T) {
	_, stopServer := mockEchoServer(t, "127.0.0.1:19997")
	defer stopServer()

	mgr := NewCosocketManager()
	defer mgr.Close()
	sock := NewTCPSocket(mgr)
	defer sock.Close()

	if err := sock.Connect("127.0.0.1", 19997); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	// Give the dial time to finish before using the socket.
	time.Sleep(100 * time.Millisecond)

	const payload = "Async test"
	ctx := context.Background()

	// Asynchronous send.
	sendOp, err := sock.SendAsync([]byte(payload))
	if err != nil {
		t.Fatalf("SendAsync failed: %v", err)
	}
	sendRes, err := sendOp.Wait(ctx)
	if err != nil {
		t.Fatalf("Send wait failed: %v", err)
	}
	if count, isInt := sendRes.(int); !isInt || count != len(payload) {
		t.Errorf("Expected %d bytes, got %v", len(payload), sendRes)
	}

	// Asynchronous receive of the echoed payload.
	recvOp, err := sock.ReceiveAsync(1024)
	if err != nil {
		t.Fatalf("ReceiveAsync failed: %v", err)
	}
	recvRes, err := recvOp.Wait(ctx)
	if err != nil {
		t.Fatalf("Receive wait failed: %v", err)
	}
	if data, isBytes := recvRes.([]byte); !isBytes || string(data) != payload {
		t.Errorf("Expected '%s', got %v", payload, recvRes)
	}
}
// TestTCPSocket_ReceiveUntil sends three newline-terminated lines and expects
// ReceiveUntil("\n", true) to return just the first one, newline included.
func TestTCPSocket_ReceiveUntil(t *testing.T) {
	_, stopServer := mockEchoServer(t, "127.0.0.1:19996")
	defer stopServer()

	mgr := NewCosocketManager()
	defer mgr.Close()
	sock := NewTCPSocket(mgr)
	defer sock.Close()

	if err := sock.Connect("127.0.0.1", 19996); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	// Give the dial time to finish before using the socket.
	time.Sleep(100 * time.Millisecond)

	if _, err := sock.Send([]byte("Line1\nLine2\nLine3\n")); err != nil {
		t.Fatalf("Send failed: %v", err)
	}

	line, err := sock.ReceiveUntil("\n", true)
	if err != nil {
		t.Fatalf("ReceiveUntil failed: %v", err)
	}
	if string(line) != "Line1\n" {
		t.Errorf("Expected 'Line1\\n', got '%s'", string(line))
	}
}
// TestTCPSocket_Close checks IsClosed before and after Close, and that closing
// an already-closed socket reports no error.
func TestTCPSocket_Close(t *testing.T) {
	mgr := NewCosocketManager()
	defer mgr.Close()
	sock := NewTCPSocket(mgr)

	if sock.IsClosed() {
		t.Error("Socket should not be closed initially")
	}
	if err := sock.Close(); err != nil {
		t.Errorf("Close failed: %v", err)
	}
	if !sock.IsClosed() {
		t.Error("Socket should be closed")
	}

	// A second Close is expected to be a no-op.
	if err := sock.Close(); err != nil {
		t.Errorf("Second close should not error: %v", err)
	}
}
// TestTCPSocket_StateTransitions checks the socket lifecycle: idle on creation,
// connecting (or already connected) right after Connect returns, and connected
// once the dial completes.
func TestTCPSocket_StateTransitions(t *testing.T) {
	_, cleanup := mockEchoServer(t, "127.0.0.1:19995")
	defer cleanup()
	cm := NewCosocketManager()
	defer cm.Close()
	socket := NewTCPSocket(cm)
	defer socket.Close()

	// A fresh socket starts idle.
	if socket.State() != SocketStateIdle {
		t.Errorf("Expected idle state, got %v", socket.State())
	}

	// Connect must succeed against a live server. Its error used to be dropped
	// and would only surface as a confusing state mismatch below.
	if err := socket.Connect("127.0.0.1", 19995); err != nil {
		t.Fatalf("Connect failed: %v", err)
	}
	// Connect returns before the dial finishes, but on loopback the dial can
	// complete before the state is sampled, so Connected is accepted as well.
	if st := socket.State(); st != SocketStateConnecting && st != SocketStateConnected {
		t.Errorf("Expected connecting state, got %v", st)
	}

	// Wait (bounded) for the dial to finish instead of sleeping a fixed amount.
	deadline := time.Now().Add(2 * time.Second)
	for socket.State() != SocketStateConnected && time.Now().Before(deadline) {
		time.Sleep(5 * time.Millisecond)
	}
	if socket.State() != SocketStateConnected {
		t.Errorf("Expected connected state, got %v", socket.State())
	}
}
// TestCosocketManager_Concurrent runs numSockets connect+send cycles spread
// over numGoroutines workers that share a single manager. It logs how many
// cycles completed, every error seen, and the final manager stats; it makes no
// pass/fail assertions of its own.
func TestCosocketManager_Concurrent(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping concurrent test in short mode")
	}
	_, stopServer := mockEchoServer(t, "127.0.0.1:19994")
	defer stopServer()
	mgr := NewCosocketManager()
	defer mgr.Close()

	const (
		numSockets    = 1000
		numGoroutines = 100
		perWorker     = numSockets / numGoroutines
	)

	// exercise runs a single connect/send cycle and reports the first failure.
	exercise := func(id int) error {
		sock := NewTCPSocket(mgr)
		defer sock.Close()
		if err := sock.Connect("127.0.0.1", 19994); err != nil {
			return fmt.Errorf("connect failed: %v", err)
		}
		time.Sleep(50 * time.Millisecond) // let the dial finish
		if _, err := sock.Send([]byte(fmt.Sprintf("Test%d", id))); err != nil {
			return fmt.Errorf("send failed: %v", err)
		}
		return nil
	}

	// Sized for one error per socket so workers never block on it.
	errCh := make(chan error, numSockets)
	var completed int32
	var wg sync.WaitGroup
	for w := 0; w < numGoroutines; w++ {
		wg.Add(1)
		go func(base int) {
			defer wg.Done()
			for j := 0; j < perWorker; j++ {
				if err := exercise(base + j); err != nil {
					errCh <- err
					continue
				}
				atomic.AddInt32(&completed, 1)
			}
		}(w * perWorker)
	}
	wg.Wait()
	close(errCh)

	errCount := 0
	for err := range errCh {
		t.Logf("Error: %v", err)
		errCount++
	}
	t.Logf("Completed: %d/%d, Errors: %d", completed, numSockets, errCount)
	t.Logf("Stats: %+v", mgr.Stats())
}
// TestCosocketManager_MemoryLeak opens and closes 10,000 sockets against a
// local echo server, then checks that the goroutine count and the manager's
// active socket/operation counts have returned close to their starting levels.
func TestCosocketManager_MemoryLeak(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping memory leak test in short mode")
	}
	_, cleanup := mockEchoServer(t, "127.0.0.1:19993")
	defer cleanup()
	cm := NewCosocketManager()
	defer cm.Close()

	// Baseline, taken after the echo server and manager are already running.
	initialGoroutines := runtime.NumGoroutine()

	const iterations = 10000
	connectFailures := 0
	for i := 0; i < iterations; i++ {
		socket := NewTCPSocket(cm)
		// Connect only starts the dial (TestTCPSocket_StateTransitions expects
		// the Connecting state right after it returns), so the short sleep gives
		// it a chance to finish; Close may still race an in-flight dial.
		if err := socket.Connect("127.0.0.1", 19993); err != nil {
			connectFailures++
		}
		time.Sleep(time.Millisecond)
		socket.Close()
	}
	if connectFailures > 0 {
		t.Logf("Connect failed %d/%d times", connectFailures, iterations)
	}
	if connectFailures == iterations {
		t.Fatal("Every Connect failed; the leak checks below would prove nothing")
	}

	// Force a GC and give background goroutines time to wind down.
	runtime.GC()
	time.Sleep(100 * time.Millisecond)

	finalGoroutines := runtime.NumGoroutine()
	t.Logf("Initial goroutines: %d, Final goroutines: %d", initialGoroutines, finalGoroutines)
	// Allow some slack for runtime/background goroutines.
	if finalGoroutines > initialGoroutines+100 {
		t.Errorf("Possible goroutine leak: started with %d, ended with %d", initialGoroutines, finalGoroutines)
	}

	stats := cm.Stats()
	if stats.ActiveSockets > 100 {
		t.Errorf("Active sockets leak: %d", stats.ActiveSockets)
	}
	if stats.ActiveOperations > 100 {
		t.Errorf("Active operations leak: %d", stats.ActiveOperations)
	}
}
// TestCosocketManager_LongRunning repeatedly connects, exchanges a small
// message and closes a socket for about 10 seconds, then expects the manager
// to report no active sockets or operations.
func TestCosocketManager_LongRunning(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping long running test in short mode")
	}
	_, cleanup := mockEchoServer(t, "127.0.0.1:19992")
	defer cleanup()
	cm := NewCosocketManager()
	defer cm.Close()

	duration := 10 * time.Second // shortened run length for the test suite
	interval := 100 * time.Millisecond
	totalOps := 0 // only touched by this goroutine, so no atomics are needed
	start := time.Now()
	for time.Since(start) < duration {
		socket := NewTCPSocket(cm)
		if err := socket.Connect("127.0.0.1", 19992); err != nil {
			t.Logf("Connect error: %v", err)
			socket.Close()
			// Pace failures like successes; without this a persistently failing
			// Connect spins and floods the log for the whole run.
			time.Sleep(interval)
			continue
		}
		// Give the dial time to finish.
		time.Sleep(50 * time.Millisecond)
		// Best-effort exchange; the test only checks resource accounting.
		if _, err := socket.Send([]byte("test")); err == nil {
			socket.Receive(1024)
		}
		socket.Close()
		totalOps++
		time.Sleep(interval)
	}
	elapsed := time.Since(start)
	t.Logf("Completed %d operations in %v", totalOps, elapsed)

	stats := cm.Stats()
	t.Logf("Final stats: %+v", stats)
	if stats.ActiveSockets > 0 {
		t.Errorf("Expected 0 active sockets, got %d", stats.ActiveSockets)
	}
	if stats.ActiveOperations > 0 {
		t.Errorf("Expected 0 active operations, got %d", stats.ActiveOperations)
	}
}
// BenchmarkCosocket_Connect measures creating, connecting and closing sockets
// in parallel against a local echo server. Connect failures are counted and
// reported after the run so a broken setup does not yield silent numbers.
func BenchmarkCosocket_Connect(b *testing.B) {
	_, cleanup := mockEchoServer(nil, "127.0.0.1:19991")
	defer cleanup()
	cm := NewCosocketManager()
	defer cm.Close()

	// b.Fatal is not allowed from RunParallel goroutines, so tally instead.
	var failures int64
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			socket := NewTCPSocket(cm)
			if err := socket.Connect("127.0.0.1", 19991); err != nil {
				atomic.AddInt64(&failures, 1)
			}
			// NOTE: this sleep is inside the timed region and dominates ns/op.
			time.Sleep(10 * time.Millisecond)
			socket.Close()
		}
	})
	if n := atomic.LoadInt64(&failures); n > 0 {
		b.Errorf("%d of %d connects failed", n, b.N)
	}
}
// BenchmarkCosocket_SendReceive measures one Send+Receive round trip per
// iteration over a single pre-connected socket. Any error aborts the benchmark,
// since timings of failing calls are meaningless.
func BenchmarkCosocket_SendReceive(b *testing.B) {
	_, cleanup := mockEchoServer(nil, "127.0.0.1:19990")
	defer cleanup()
	cm := NewCosocketManager()
	defer cm.Close()

	// Connect once up front; the dial is given time to finish before timing.
	socket := NewTCPSocket(cm)
	defer socket.Close()
	if err := socket.Connect("127.0.0.1", 19990); err != nil {
		b.Fatalf("Connect failed: %v", err)
	}
	time.Sleep(100 * time.Millisecond)

	payload := []byte("benchmark") // allocated once, outside the timed loop
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := socket.Send(payload); err != nil {
			b.Fatalf("Send failed: %v", err)
		}
		if _, err := socket.Receive(1024); err != nil {
			b.Fatalf("Receive failed: %v", err)
		}
	}
	b.StopTimer()
}
// TestLuaAPI_TCPSocket registers the TCP socket API on a fresh engine and runs
// a script in a sandboxed coroutine that creates an ngx.socket.tcp() object.
func TestLuaAPI_TCPSocket(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Lua API test in short mode")
	}

	engine, err := NewEngine(nil)
	if err != nil {
		t.Fatalf("Failed to create engine: %v", err)
	}
	defer engine.Close()

	RegisterTCPSocketAPI(engine.L, engine)

	// The script raises a Lua error when socket creation yields nothing, so a
	// broken binding surfaces as an Execute error. A plain `return nil, msg` is
	// just a return value, which Execute (as far as this test can see) does not
	// inspect, so the previous script could never make the test fail.
	script := `
local sock = ngx.socket.tcp()
if not sock then
	error("failed to create socket")
end
return "ok"
`
	coro, err := engine.NewCoroutine(nil)
	if err != nil {
		t.Fatalf("Failed to create coroutine: %v", err)
	}
	defer coro.Close()
	if err := coro.SetupSandbox(); err != nil {
		t.Fatalf("Failed to setup sandbox: %v", err)
	}
	if err := coro.Execute(script); err != nil {
		t.Errorf("Execute failed: %v", err)
	}
}
// TestCosocketManager_Stress opens totalConnections short-lived sockets, at
// most concurrency at a time, spread round-robin over four local echo servers.
// It logs throughput, average latency and heap size, and fails if more than 10%
// of connects fail or more than 100 sockets are still reported active.
func TestCosocketManager_Stress(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping stress test in short mode")
	}

	// One echo server per port.
	ports := []int{19980, 19981, 19982, 19983}
	stops := make([]func(), 0, len(ports))
	for _, port := range ports {
		_, stop := mockEchoServer(t, fmt.Sprintf("127.0.0.1:%d", port))
		stops = append(stops, stop)
	}
	defer func() {
		for _, stop := range stops {
			stop()
		}
	}()

	mgr := NewCosocketManager()
	defer mgr.Close()

	const (
		totalConnections = 10000
		concurrency      = 1000
	)
	var (
		wg           sync.WaitGroup
		successCount int32
		errorCount   int32
		latencySum   int64
	)

	// run performs one connect/exchange/close cycle against port. Only a failed
	// Connect counts as an error; the data exchange is best effort.
	run := func(port int) {
		sock := NewTCPSocket(mgr)
		began := time.Now()
		if err := sock.Connect("127.0.0.1", port); err != nil {
			atomic.AddInt32(&errorCount, 1)
			sock.Close()
			return
		}
		time.Sleep(10 * time.Millisecond) // let the dial finish
		if _, err := sock.Send([]byte("hello")); err == nil {
			sock.Receive(1024)
		}
		sock.Close()
		atomic.AddInt64(&latencySum, time.Since(began).Milliseconds())
		atomic.AddInt32(&successCount, 1)
	}

	start := time.Now()
	// sem caps the number of in-flight cycles at concurrency.
	sem := make(chan struct{}, concurrency)
	for i := 0; i < totalConnections; i++ {
		wg.Add(1)
		sem <- struct{}{}
		go func(idx int) {
			defer wg.Done()
			defer func() { <-sem }()
			run(ports[idx%len(ports)])
		}(i)
	}
	wg.Wait()
	elapsed := time.Since(start)

	t.Logf("Stress test completed:")
	t.Logf(" Total: %d, Success: %d, Errors: %d", totalConnections, successCount, errorCount)
	t.Logf(" Duration: %v", elapsed)
	t.Logf(" RPS: %.2f", float64(totalConnections)/elapsed.Seconds())
	if successCount > 0 {
		t.Logf(" Avg Latency: %dms", latencySum/int64(successCount))
	}

	// Heap snapshot after a forced GC.
	var mem runtime.MemStats
	runtime.GC()
	runtime.ReadMemStats(&mem)
	t.Logf(" Memory: %.2f MB", float64(mem.HeapAlloc)/(1024*1024))

	// Resource-leak checks.
	stats := mgr.Stats()
	t.Logf(" Active sockets: %d, Active operations: %d", stats.ActiveSockets, stats.ActiveOperations)
	if errorCount > totalConnections/10 { // tolerate up to a 10% error rate
		t.Errorf("Too many errors: %d", errorCount)
	}
	if stats.ActiveSockets > 100 {
		t.Errorf("Socket leak detected: %d active sockets", stats.ActiveSockets)
	}
}