From d0867bfe3e8e234741a4b836ff6984746fa742d0 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 16:31:18 +0800 Subject: [PATCH] refactor(lua): remove unused mock engine and filter writer subsystem - Delete mock_engine.go (331 lines): unused MockLuaEngine/MockCoroutine - Delete filter_writer.go (811 lines): DelayedResponseWriter not integrated - Delete filter_phase_test.go (1466 lines): tests for removed filter code - Total: 2608 lines of dead code removed --- internal/lua/filter_phase_test.go | 1466 ----------------------------- internal/lua/filter_writer.go | 811 ---------------- internal/lua/mock_engine.go | 331 ------- 3 files changed, 2608 deletions(-) delete mode 100644 internal/lua/filter_phase_test.go delete mode 100644 internal/lua/filter_writer.go delete mode 100644 internal/lua/mock_engine.go diff --git a/internal/lua/filter_phase_test.go b/internal/lua/filter_phase_test.go deleted file mode 100644 index 40b8e5e..0000000 --- a/internal/lua/filter_phase_test.go +++ /dev/null @@ -1,1466 +0,0 @@ -package lua - -import ( - "fmt" - "io" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp" -) - -// mockRequestCtx 创建模拟的 RequestCtx -func mockRequestCtx() *fasthttp.RequestCtx { - ctx := &fasthttp.RequestCtx{} - // 初始化必要的字段 - ctx.Response.Header.Set("Content-Type", "text/plain") - ctx.Response.SetStatusCode(200) - return ctx -} - -// TestResponseInterceptor_Basic 测试基本的响应拦截功能 -func TestResponseInterceptor_Basic(t *testing.T) { - ctx := mockRequestCtx() - ri := NewResponseInterceptor(ctx) - - // 启用拦截 - ri.Enable() - assert.True(t, ri.IsEnabled()) - - // 写入 body(应该被缓冲) - n, err := ri.Write([]byte("Hello, World!")) - require.NoError(t, err) - assert.Equal(t, 13, n) - - // 检查 body 被缓冲 - assert.Equal(t, "Hello, World!", string(ri.GetBufferedBody())) - assert.False(t, ri.headersWritten) -} - -// TestResponseInterceptor_HeaderModification 测试 header 修改 -func TestResponseInterceptor_HeaderModification(t *testing.T) { - ctx := mockRequestCtx() - ri := NewResponseInterceptor(ctx) - ri.Enable() - - // 设置 header - ri.SetHeader("X-Custom-Header", "custom-value") - ri.SetHeader("Cache-Control", "no-cache") - ri.DelHeader("Content-Type") - - // 设置状态码 - ri.SetStatusCode(201) - - // 设置 header filter 回调 - ri.SetHeaderFilter(func() error { - // 模拟 Lua 修改 header - ri.SetHeader("X-Lua-Modified", "true") - return nil - }) - - // 写入一些 body - ri.WriteString("test body") - - // 刷新 - err := ri.Flush() - require.NoError(t, err) - - // 验证 header - assert.Equal(t, 201, ctx.Response.StatusCode()) - assert.Equal(t, "custom-value", string(ctx.Response.Header.Peek("X-Custom-Header"))) - assert.Equal(t, "no-cache", string(ctx.Response.Header.Peek("Cache-Control"))) - assert.Equal(t, "true", string(ctx.Response.Header.Peek("X-Lua-Modified"))) - // Content-Type is set by fasthttp -} - -// TestResponseInterceptor_BodyFilter 测试 body filter -func TestResponseInterceptor_BodyFilter(t *testing.T) { - ctx := mockRequestCtx() - ri := NewResponseInterceptor(ctx) - ri.Enable() - - // 设置 body filter 回调(模拟 Lua 修改 body) - ri.SetBodyFilter(func(body []byte) ([]byte, error) { - // 添加前缀 - modified := append([]byte("[MODIFIED] "), body...) - return modified, nil - }) - - // 写入 body - ri.WriteString("original content") - - // 刷新 - err := ri.Flush() - require.NoError(t, err) - - // 验证 body 被修改 - assert.Equal(t, "[MODIFIED] original content", string(ctx.Response.Body())) -} - -// TestDelayedResponseWriter 测试延迟响应写入器 -func TestDelayedResponseWriter(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - - // 启用 filter phase - drw.EnableFilterPhase() - assert.True(t, drw.GetInterceptor().IsEnabled()) - - // 设置 header - drw.SetHeader("X-Test", "value") - drw.SetStatusCode(202) - - // 写入 body(应该被缓冲) - drw.WriteString("Hello") - drw.Write([]byte(" World")) - - // 验证 body 被缓冲,未实际发送 - assert.Equal(t, 11, drw.GetBufferedBodySize()) - assert.Equal(t, "Hello World", string(drw.GetInterceptor().GetBufferedBody())) - - // 刷新 - err := drw.Flush() - require.NoError(t, err) - - // 验证 - assert.Equal(t, 202, ctx.Response.StatusCode()) - assert.Equal(t, "value", string(ctx.Response.Header.Peek("X-Test"))) - assert.Equal(t, "Hello World", string(ctx.Response.Body())) -} - -// TestDelayedResponseWriter_WithLuaEngine 测试与 Lua 引擎集成 -func TestDelayedResponseWriter_WithLuaEngine(t *testing.T) { - // 创建 Lua 引擎 - engine, err := NewEngine(DefaultConfig()) - require.NoError(t, err) - defer engine.Close() - - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 创建 Lua 上下文 - luaCtx := NewContext(engine, ctx) - defer luaCtx.Release() - - err = luaCtx.InitCoroutine() - require.NoError(t, err) - - // 设置 header filter - err = drw.HeaderFilter(` - ngx.status = 418 - ngx.header["X-Teapot"] = "I'm a teapot" - `, luaCtx) - require.NoError(t, err) - - // 设置 body filter - err = drw.BodyFilter(` - -- 假设 ngx.body 可以访问 - ngx.say("[FILTERED] ") - `, luaCtx) - require.NoError(t, err) - - // 写入 body - drw.WriteString("test") - - // 刷新 - err = drw.Flush() - // 当前 Lua 脚本可能失败,但结构是正确的 - // require.NoError(t, err) - _ = err -} - -// BenchmarkResponseInterceptor 基准测试响应拦截器。 -// -// 注意:每个 goroutine 必须创建独立的 RequestCtx,因为 fasthttp.RequestCtx -// 不是并发安全的。Flush() 会修改 ResponseHeader 的内部 map。 -func BenchmarkResponseInterceptor(b *testing.B) { - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - ctx := mockRequestCtx() - ri := NewResponseInterceptor(ctx) - ri.Enable() - ri.WriteString("Hello, World!") - _ = ri.Flush() - } - }) -} - -// BenchmarkDelayedWrite 基准测试延迟写入 -func BenchmarkDelayedWrite(b *testing.B) { - ctx := mockRequestCtx() - body := []byte("Hello, World! This is a test body for benchmarking.") - - b.ResetTimer() - for b.Loop() { - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - drw.Write(body) - _ = drw.Flush() - } -} - -// BenchmarkNormalWrite 基准测试正常写入(对比) -func BenchmarkNormalWrite(b *testing.B) { - body := []byte("Hello, World! This is a test body for benchmarking.") - - b.ResetTimer() - for b.Loop() { - ctx := mockRequestCtx() - ctx.Write(body) - } -} - -// BenchmarkHeaderFilter 基准测试 header filter -func BenchmarkHeaderFilter(b *testing.B) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟 header filter - drw.GetInterceptor().SetHeaderFilter(func() error { - drw.SetHeader("X-Test", "value") - drw.SetStatusCode(201) - return nil - }) - - b.ResetTimer() - for b.Loop() { - drw.WriteString("test") - _ = drw.Flush() - drw.Reset() - drw.EnableFilterPhase() - } -} - -// TestDelayedResponseWriter_Pool 测试对象池性能 -func TestDelayedResponseWriter_Pool(t *testing.T) { - ctx := mockRequestCtx() - - // 预热池 - for range 100 { - ri := AcquireResponseInterceptor(ctx) - ReleaseResponseInterceptor(ri) - } - - // 测试从池获取的性能 - start := time.Now() - for range 10000 { - ri := AcquireResponseInterceptor(ctx) - ri.WriteString("test") - ReleaseResponseInterceptor(ri) - } - elapsed := time.Since(start) - - t.Logf("Pool operations: 10000 ops in %v (%.2f ops/sec)", elapsed, 10000.0/elapsed.Seconds()) -} - -// TestConcurrentAccess 测试并发访问安全性 -func TestConcurrentAccess(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - var wg sync.WaitGroup - errors := make(chan error, 100) - - for i := range 100 { - wg.Add(1) - go func(idx int) { - defer wg.Done() - drw.SetHeader(fmt.Sprintf("X-Test-%d", idx), fmt.Sprintf("value-%d", idx)) - _, err := drw.WriteString(fmt.Sprintf("data-%d", idx)) - if err != nil { - errors <- err - } - }(i) - } - - wg.Wait() - close(errors) - - // 收集错误 - errList := make([]error, 0, 100) - for err := range errors { - errList = append(errList, err) - } - - // 注意:fasthttp.RequestCtx 不是并发安全的 - // 这里只是测试我们的包装器没有引入额外的并发问题 - // 实际使用时需要保证单 goroutine 访问 - t.Logf("Concurrent operations completed, %d errors", len(errList)) -} - -// TestDelayedResponseWriter_WithLuaHeaderModification 测试 Lua header 修改 -func TestDelayedResponseWriter_WithLuaHeaderModification(t *testing.T) { - // 创建 Lua 引擎 - engine, err := NewEngine(DefaultConfig()) - require.NoError(t, err) - defer engine.Close() - - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 创建 Lua 上下文 - luaCtx := NewContext(engine, ctx) - defer luaCtx.Release() - - err = luaCtx.InitCoroutine() - require.NoError(t, err) - - // 手动设置 header 修改(模拟 Lua 操作) - drw.SetHeader("X-Lua-Header", "lua-value") - drw.SetStatusCode(201) - - // 写入并刷新 - drw.WriteString("test body") - err = drw.Flush() - require.NoError(t, err) - - // 验证 - assert.Equal(t, 201, ctx.Response.StatusCode()) - assert.Equal(t, "lua-value", string(ctx.Response.Header.Peek("X-Lua-Header"))) - assert.Equal(t, "test body", string(ctx.Response.Body())) -} - -// TestHeaderFilterPhase 专门测试 header filter phase -func TestHeaderFilterPhase(t *testing.T) { - tests := []struct { - name string - initialStatus int - modifiedStatus int - initialHeaders map[string]string - modifiedHeaders map[string]string - deletedHeaders []string - }{ - { - name: "status modification", - initialStatus: 200, - modifiedStatus: 404, - initialHeaders: map[string]string{}, - modifiedHeaders: map[string]string{}, - }, - { - name: "header addition", - initialStatus: 200, - modifiedStatus: 200, - initialHeaders: map[string]string{}, - modifiedHeaders: map[string]string{ - "X-Custom": "added", - }, - }, - { - name: "header modification", - initialStatus: 200, - modifiedStatus: 200, - initialHeaders: map[string]string{ - "Content-Type": "text/plain", - }, - modifiedHeaders: map[string]string{ - "Content-Type": "application/json", - }, - }, - { - name: "header deletion", - initialStatus: 200, - modifiedStatus: 200, - initialHeaders: map[string]string{ - "X-Remove": "value", - }, - deletedHeaders: []string{"X-Remove"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置初始 headers - for k, v := range tt.initialHeaders { - ctx.Response.Header.Set(k, v) - } - ctx.Response.SetStatusCode(tt.initialStatus) - - // 应用修改 - drw.SetStatusCode(tt.modifiedStatus) - for k, v := range tt.modifiedHeaders { - drw.SetHeader(k, v) - } - for _, k := range tt.deletedHeaders { - drw.DelHeader(k) - } - - // 刷新 - drw.WriteString("test") - err := drw.Flush() - require.NoError(t, err) - - // 验证状态码 - assert.Equal(t, tt.modifiedStatus, ctx.Response.StatusCode()) - - // 验证修改的 headers - for k, v := range tt.modifiedHeaders { - assert.Equal(t, v, string(ctx.Response.Header.Peek(k))) - } - - // 验证删除的 headers - for _, k := range tt.deletedHeaders { - assert.Equal(t, "", string(ctx.Response.Header.Peek(k))) - } - }) - } -} - -// TestBodyFilterPhase 测试 body filter phase -func TestBodyFilterPhase(t *testing.T) { - tests := []struct { - name string - inputBody string - filterFunc func([]byte) []byte - expectedOutput string - }{ - { - name: "prepend content", - inputBody: "Hello", - filterFunc: func(b []byte) []byte { - return append([]byte("Prefix: "), b...) - }, - expectedOutput: "Prefix: Hello", - }, - { - name: "append content", - inputBody: "Hello", - filterFunc: func(b []byte) []byte { - return append(b, []byte(" Suffix")...) - }, - expectedOutput: "Hello Suffix", - }, - { - name: "replace content", - inputBody: "Hello World", - filterFunc: func(b []byte) []byte { - return []byte("Replaced") - }, - expectedOutput: "Replaced", - }, - { - name: "empty body", - inputBody: "", - filterFunc: func(b []byte) []byte { - return []byte("default") - }, - expectedOutput: "", - }, - { - name: "large body", - inputBody: strings.Repeat("x", 10000), - filterFunc: func(b []byte) []byte { - return append([]byte("size="), fmt.Appendf(nil, "%d ", len(b))...) - }, - expectedOutput: "size=10000 ", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置 body filter - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - return tt.filterFunc(body), nil - }) - - // 写入 body - drw.WriteString(tt.inputBody) - - // 刷新 - err := drw.Flush() - require.NoError(t, err) - - // 验证输出 - assert.Equal(t, tt.expectedOutput, string(ctx.Response.Body())) - }) - } -} - -// TestFilterPhaseSuccessRate 测试 filter phase 成功率 -func TestFilterPhaseSuccessRate(t *testing.T) { - const totalRequests = 1000 - - successCount := 0 - var mu sync.Mutex - - for i := range totalRequests { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置 header - drw.SetHeader("X-Request-ID", fmt.Sprintf("%d", i)) - drw.SetStatusCode(200) - - // 写入 body - drw.WriteString(fmt.Sprintf("Response %d", i)) - - // 刷新 - err := drw.Flush() - if err == nil { - // 验证结果 - if ctx.Response.StatusCode() == 200 && - string(ctx.Response.Header.Peek("X-Request-ID")) == fmt.Sprintf("%d", i) && - string(ctx.Response.Body()) == fmt.Sprintf("Response %d", i) { - mu.Lock() - successCount++ - mu.Unlock() - } - } - } - - successRate := float64(successCount) / float64(totalRequests) * 100 - t.Logf("Success rate: %.2f%% (%d/%d)", successRate, successCount, totalRequests) - assert.GreaterOrEqual(t, successRate, 95.0, "Success rate should be >= 95%%") -} - -// TestPerformanceOverhead 测试性能开销 -func TestPerformanceOverhead(t *testing.T) { - // 基准:正常写入 - ctx1 := mockRequestCtx() - start := time.Now() - for range 10000 { - ctx1.Response.SetBodyString("Hello, World!") - } - baselineDuration := time.Since(start) - - // 测试:延迟写入 - start = time.Now() - for range 10000 { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - drw.WriteString("Hello, World!") - _ = drw.Flush() - } - delayedDuration := time.Since(start) - - overhead := (float64(delayedDuration) - float64(baselineDuration)) / float64(baselineDuration) * 100 - t.Logf("Baseline: %v, Delayed: %v, Overhead: %.2f%%", baselineDuration, delayedDuration, overhead) - - // 允许的开销阈值:5% - assert.Less(t, overhead, 20000.0, "Performance overhead acceptable for prototype") -} - -// TestBufferedWriter 测试缓冲写入器 -func TestBufferedWriter(t *testing.T) { - var flushed []byte - bw := NewBufferedWriter(100, func(data []byte) error { - flushed = append(flushed, data...) - return nil - }) - - // 写入数据 - _, err := bw.Write([]byte("Hello")) - require.NoError(t, err) - _, err = bw.Write([]byte(" World")) - require.NoError(t, err) - - assert.Equal(t, 11, bw.Size()) - - // 手动刷新 - err = bw.Flush() - require.NoError(t, err) - assert.Equal(t, "Hello World", string(flushed)) - assert.Equal(t, 0, bw.Size()) - - // 关闭 - err = bw.Close() - require.NoError(t, err) -} - -// TestBufferedWriter_AutoFlush 测试自动刷新 -func TestBufferedWriter_AutoFlush(t *testing.T) { - flushCount := 0 - var mu sync.Mutex - - bw := NewBufferedWriter(10, func(data []byte) error { - mu.Lock() - flushCount++ - mu.Unlock() - return nil - }) - bw.autoFlush = true - - // 写入超过阈值的数据 - _, err := bw.Write([]byte("0123456789abcdef")) // 16 bytes > 10 - require.NoError(t, err) - - mu.Lock() - assert.GreaterOrEqual(t, flushCount, 1, "Should have flushed automatically") - mu.Unlock() -} - -// TestFilterPhaseWithError 测试 filter phase 错误处理 -func TestFilterPhaseWithError(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置会返回错误的 header filter - drw.GetInterceptor().SetHeaderFilter(func() error { - return fmt.Errorf("header filter error") - }) - - drw.WriteString("test") - err := drw.Flush() - require.Error(t, err) - assert.Contains(t, err.Error(), "header filter error") -} - -// TestFilterPhaseWithBodyError 测试 body filter 错误处理 -func TestFilterPhaseWithBodyError(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置会返回错误的 body filter - drw.GetInterceptor().SetBodyFilter(func(_ []byte) ([]byte, error) { - return nil, fmt.Errorf("body filter error") - }) - - drw.WriteString("test") - err := drw.Flush() - require.Error(t, err) - assert.Contains(t, err.Error(), "body filter error") -} - -// TestMultipleFlush 测试多次刷新 -func TestMultipleFlush(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - drw.WriteString("first") - err := drw.Flush() - require.NoError(t, err) - - // 第二次刷新应该无操作 - err = drw.Flush() - require.NoError(t, err) - - assert.Equal(t, "first", string(ctx.Response.Body())) -} - -// TestSendFile 测试文件发送 -func TestSendFile(t *testing.T) { - // 创建临时文件 - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置 header - drw.SetHeader("X-Custom", "value") - drw.SetStatusCode(201) - - // SendFile 会立即发送 - // 这里我们测试禁用拦截的情况 - drw.DisableFilterPhase() - drw.SetBodyString("file content") - - assert.Equal(t, "file content", string(ctx.Response.Body())) -} - -// TestRedirect 测试重定向 -func TestRedirect(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置 header - drw.SetHeader("X-Custom", "value") - - // 重定向 - drw.Redirect("/new-path", 302) - - assert.Equal(t, 302, ctx.Response.StatusCode()) - assert.Contains(t, string(ctx.Response.Header.Peek("Location")), "/new-path") -} - -// TestStats 测试统计信息 -func TestStats(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - drw.SetHeader("X-1", "v1") - drw.SetHeader("X-2", "v2") - drw.DelHeader("Content-Type") - drw.WriteString("test body") - - stats := drw.GetStats() - assert.Equal(t, 9, stats.BufferedBytes) - assert.Equal(t, 2, stats.HeadersModified) - assert.Equal(t, 1, stats.HeadersDeleted) - assert.Equal(t, false, stats.BodyModified) - assert.Equal(t, 200, stats.StatusCode) -} - -// BenchmarkPoolPerformance 基准测试对象池性能 -func BenchmarkPoolPerformance(b *testing.B) { - b.Run("WithPool", func(b *testing.B) { - for b.Loop() { - ctx := mockRequestCtx() - ri := AcquireResponseInterceptor(ctx) - ri.WriteString("test") - _ = ri.Flush() - ReleaseResponseInterceptor(ri) - } - }) - - b.Run("WithoutPool", func(b *testing.B) { - for b.Loop() { - ctx := mockRequestCtx() - ri := NewResponseInterceptor(ctx) - ri.Enable() - ri.WriteString("test") - _ = ri.Flush() - } - }) -} - -// BenchmarkHeaderModification 基准测试 header 修改 -func BenchmarkHeaderModification(b *testing.B) { - b.Run("WithFilter", func(b *testing.B) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - drw.GetInterceptor().SetHeaderFilter(func() error { - drw.SetHeader("X-Test", "value") - return nil - }) - - b.ResetTimer() - for b.Loop() { - drw.WriteString("test") - _ = drw.Flush() - drw.Reset() - drw.EnableFilterPhase() - } - }) - - b.Run("DirectWrite", func(b *testing.B) { - b.ResetTimer() - for b.Loop() { - ctx := mockRequestCtx() - ctx.Response.Header.Set("X-Test", "value") - ctx.Response.SetBodyString("test") - } - }) -} - -// TestFastHTTPCompatibility 测试与 fasthttp 的兼容性 -func TestFastHTTPCompatibility(t *testing.T) { - // 测试各种 fasthttp 方法的兼容性 - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 测试 WriteString - n, err := drw.WriteString("Hello") - require.NoError(t, err) - assert.Equal(t, 5, n) - - // 测试 Write - data := []byte(" World") - n, err = drw.Write(data) - require.NoError(t, err) - assert.Equal(t, 6, n) - - // 测试 SetBody - drw.SetBody([]byte("New Body")) - assert.Equal(t, 8, drw.GetBufferedBodySize()) - - // 刷新并验证 - err = drw.Flush() - require.NoError(t, err) - assert.Equal(t, "New Body", string(ctx.Response.Body())) -} - -// TestConcurrencySafety 测试并发安全性(文档说明) -func TestConcurrencySafety(t *testing.T) { - // 这个测试主要文档化说明:ResponseInterceptor 不是并发安全的 - // 使用时需要保证单 goroutine 访问 - // 这是继承自 fasthttp.RequestCtx 的特性 - - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 顺序操作是安全的 - drw.SetHeader("X-1", "v1") - drw.SetHeader("X-2", "v2") - drw.WriteString("test") - err := drw.Flush() - require.NoError(t, err) - - t.Log("ResponseInterceptor is not goroutine-safe, use with single goroutine only") -} - -// TestMemoryUsage 测试内存使用情况 -func TestMemoryUsage(t *testing.T) { - // 测试大 body 的处理 - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 1MB body - largeBody := make([]byte, 1024*1024) - for i := range largeBody { - largeBody[i] = byte('a' + (i % 26)) - } - - drw.Write(largeBody) - assert.Equal(t, len(largeBody), drw.GetBufferedBodySize()) - - err := drw.Flush() - require.NoError(t, err) - assert.Equal(t, len(largeBody), len(ctx.Response.Body())) -} - -// BenchmarkLargeBody 大 body 基准测试 -func BenchmarkLargeBody(b *testing.B) { - body := make([]byte, 100*1024) // 100KB - for i := range body { - body[i] = byte('a' + (i % 26)) - } - - b.ResetTimer() - for b.Loop() { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - drw.Write(body) - _ = drw.Flush() - } -} - -// TestResponseInterceptor_Reset 测试重置功能 -func TestResponseInterceptor_Reset(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置一些数据 - drw.SetHeader("X-Test", "value") - drw.SetStatusCode(201) - drw.WriteString("test") - - // 重置 - drw.Reset() - - // 验证重置后的状态 - assert.Equal(t, 0, drw.GetBufferedBodySize()) - assert.Equal(t, 200, drw.GetInterceptor().GetStatusCode()) - assert.False(t, drw.GetInterceptor().headersWritten) -} - -// TestDelayedResponseWriter_SetBodyStream 测试 SetBodyStream -func TestDelayedResponseWriter_SetBodyStream(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置 header - drw.SetHeader("X-Custom", "value") - - // 设置流式 body(会直接发送) - reader := strings.NewReader("stream body") - drw.SetBodyStream(reader, 11) - - // 流式 body 不支持缓冲 - assert.True(t, drw.GetInterceptor().headersWritten) -} - -// TestFilterPhaseFeasibility 综合可行性测试 -func TestFilterPhaseFeasibility(t *testing.T) { - t.Run("header_filter_success_rate", func(t *testing.T) { - const iterations = 100 - success := 0 - - for range iterations { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟 header filter - drw.SetHeader("X-Filtered", "true") - drw.SetStatusCode(201) - drw.DelHeader("Server") - - drw.WriteString("test") - err := drw.Flush() - - if err == nil && - ctx.Response.StatusCode() == 201 && - string(ctx.Response.Header.Peek("X-Filtered")) == "true" && - string(ctx.Response.Header.Peek("Server")) == "" { - success++ - } - } - - rate := float64(success) / float64(iterations) * 100 - t.Logf("Header filter success rate: %.2f%%", rate) - assert.GreaterOrEqual(t, rate, 95.0, "Header filter success rate should be >= 95%%") - }) - - t.Run("body_filter_correctness", func(t *testing.T) { - tests := []struct { - input string - expected string - filter func([]byte) []byte - }{ - {"hello", "HELLO", func(b []byte) []byte { return []byte(strings.ToUpper(string(b))) }}, - {"", "", func(b []byte) []byte { - if len(b) == 0 { - return []byte("") - } - return b - }}, - {"data", "[data]", func(b []byte) []byte { - return append(append([]byte("["), b...), ']') - }}, - } - - for _, tt := range tests { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - return tt.filter(body), nil - }) - - drw.WriteString(tt.input) - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, tt.expected, string(ctx.Response.Body()), - "Input: %q", tt.input) - } - }) - - t.Run("performance_overhead", func(t *testing.T) { - const iterations = 1000 - - // 基准 - start := time.Now() - for range iterations { - ctx := mockRequestCtx() - ctx.Response.SetBodyString("test") - } - baseline := time.Since(start) - - // 延迟写入 - start = time.Now() - for range iterations { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - drw.WriteString("test") - _ = drw.Flush() - } - delayed := time.Since(start) - - overhead := (float64(delayed) - float64(baseline)) / float64(baseline) * 100 - t.Logf("Performance overhead: %.2f%%", overhead) - assert.Less(t, overhead, 20000.0, "Performance overhead should be reasonable") - }) -} - -// TestHTTPResponseWriterInterface 测试 http.ResponseWriter 兼容性 -func TestHTTPResponseWriterInterface(t *testing.T) { - ctx := mockRequestCtx() - ri := NewResponseInterceptor(ctx) - ri.Enable() - - // 写入数据 - n, err := ri.Write([]byte("Hello")) - require.NoError(t, err) - assert.Equal(t, 5, n) - - // 刷新 - err = ri.Flush() - require.NoError(t, err) - - assert.Equal(t, "Hello", string(ctx.Response.Body())) -} - -// TestFilterPhaseMetrics 收集 filter phase 的详细指标 -func TestFilterPhaseMetrics(t *testing.T) { - metrics := struct { - totalOperations int - successfulHeaders int - successfulBodies int - averageLatency time.Duration - errors []string - }{ - errors: make([]string, 0), - } - - const iterations = 100 - - start := time.Now() - for i := range iterations { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // Header filter - drw.SetHeader("X-Test", fmt.Sprintf("value-%d", i)) - drw.SetStatusCode(200 + (i % 100)) - - // Body filter - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - return append(body, []byte("-modified")...), nil - }) - - drw.WriteString(fmt.Sprintf("body-%d", i)) - err := drw.Flush() - - if err != nil { - metrics.errors = append(metrics.errors, err.Error()) - } else { - metrics.successfulHeaders++ - metrics.successfulBodies++ - } - metrics.totalOperations++ - } - - totalDuration := time.Since(start) - metrics.averageLatency = totalDuration / iterations - - // 输出指标 - t.Logf("=== Filter Phase Metrics ===") - t.Logf("Total operations: %d", metrics.totalOperations) - t.Logf("Successful headers: %d (%.2f%%)", - metrics.successfulHeaders, - float64(metrics.successfulHeaders)/float64(metrics.totalOperations)*100) - t.Logf("Successful bodies: %d (%.2f%%)", - metrics.successfulBodies, - float64(metrics.successfulBodies)/float64(metrics.totalOperations)*100) - t.Logf("Average latency: %v", metrics.averageLatency) - t.Logf("Errors: %d", len(metrics.errors)) - for _, err := range metrics.errors { - t.Logf(" - %s", err) - } - - // 验证指标 - successRate := float64(metrics.successfulHeaders) / float64(metrics.totalOperations) * 100 - assert.GreaterOrEqual(t, successRate, 95.0, "Header success rate should be >= 95%%") -} - -// TestIntegrationWithProxy 测试与代理的集成 -func TestIntegrationWithProxy(t *testing.T) { - // 模拟代理场景 - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟上游响应 - ctx.Response.Header.Set("X-Upstream", "true") - ctx.Response.SetStatusCode(200) - - // 添加过滤规则 - drw.SetHeader("X-Proxy-Processed", "true") - drw.DelHeader("X-Upstream") - - // 模拟 body 修改 - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - return append([]byte("PROXY: "), body...), nil - }) - - drw.WriteString("upstream response") - err := drw.Flush() - require.NoError(t, err) - - // 验证 - assert.Equal(t, "true", string(ctx.Response.Header.Peek("X-Proxy-Processed"))) - assert.Equal(t, "", string(ctx.Response.Header.Peek("X-Upstream"))) - assert.Equal(t, "PROXY: upstream response", string(ctx.Response.Body())) -} - -// TestStreamBody 测试流式 body 处理 -func TestStreamBody(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 设置 header - drw.SetHeader("X-Stream", "true") - - // 流式 body 不经过缓冲 - reader := &mockReader{data: []byte("stream data")} - drw.SetBodyStream(reader, 11) - - assert.True(t, drw.GetInterceptor().headersWritten) -} - -// mockReader 用于测试的 mock io.Reader -type mockReader struct { - data []byte - offset int -} - -func (r *mockReader) Read(p []byte) (n int, err error) { - if r.offset >= len(r.data) { - return 0, io.EOF - } - n = copy(p, r.data[r.offset:]) - r.offset += n - return n, nil -} - -// TestFilterPhaseLuaAPI 测试与 Lua API 的集成 -func TestFilterPhaseLuaAPI(t *testing.T) { - // 这个测试验证 Lua API 可以与 DelayedResponseWriter 正确集成 - // 实际测试需要完整的 Lua 绑定实现 - - t.Run("header_filter_api", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟 Lua header_filter_by_lua 的效果 - drw.SetHeader("Content-Type", "application/json") - drw.SetStatusCode(201) - drw.DelHeader("Server") - - drw.WriteString("{}") - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, 201, ctx.Response.StatusCode()) - assert.Equal(t, "application/json", string(ctx.Response.Header.Peek("Content-Type"))) - }) - - t.Run("body_filter_api", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟 Lua body_filter_by_lua 的效果 - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - // 模拟 Lua 字符串操作 - return append(body, []byte("\n-- filtered by lua")...), nil - }) - - drw.WriteString("original response") - err := drw.Flush() - require.NoError(t, err) - - assert.Contains(t, string(ctx.Response.Body()), "-- filtered by lua") - }) -} - -// BenchmarkFilterPhaseScalability 测试 filter phase 的可扩展性 -func BenchmarkFilterPhaseScalability(b *testing.B) { - for _, goroutines := range []int{1, 10, 100} { - b.Run(fmt.Sprintf("goroutines-%d", goroutines), func(b *testing.B) { - var wg sync.WaitGroup - errors := make(chan error, b.N) - var completed atomic.Int32 - - b.ResetTimer() - for range goroutines { - wg.Go(func() { - for j := 0; j < b.N/goroutines; j++ { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - drw.SetHeader("X-Test", "value") - drw.WriteString("test") - if err := drw.Flush(); err != nil { - errors <- err - } else { - completed.Add(1) - } - } - }) - } - wg.Wait() - close(errors) - }) - } -} - -// TestFilterPhaseEdgeCases 测试边界情况 -func TestFilterPhaseEdgeCases(t *testing.T) { - t.Run("empty_header_filter", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 不设置任何 filter - drw.WriteString("test") - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, "test", string(ctx.Response.Body())) - }) - - t.Run("multiple_flushes", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - drw.WriteString("first") - err := drw.Flush() - require.NoError(t, err) - - // 第二次写入应该被忽略(因为已经刷新过) - drw.WriteString("second") - err = drw.Flush() - require.NoError(t, err) // 不会报错,但无效果 - - assert.Equal(t, "first", string(ctx.Response.Body())) - }) - - t.Run("nil_body_filter", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - drw.WriteString("test") - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, "test", string(ctx.Response.Body())) - }) - - t.Run("large_header_value", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 8KB header value - largeValue := strings.Repeat("x", 8192) - drw.SetHeader("X-Large", largeValue) - - drw.WriteString("test") - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, largeValue, string(ctx.Response.Header.Peek("X-Large"))) - }) - - t.Run("unicode_body", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - return append([]byte("[UTF-8] "), body...), nil - }) - - // UTF-8 内容 - drw.WriteString("你好,世界!🌍") - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, "[UTF-8] 你好,世界!🌍", string(ctx.Response.Body())) - }) -} - -// TestFilterPhaseCompliance 测试与 nginx filter phase 的兼容性 -func TestFilterPhaseCompliance(t *testing.T) { - t.Run("nginx_style_header_filter", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟 nginx header_filter_by_lua - // ngx.header["X-Frame-Options"] = "DENY" - // ngx.header["X-Content-Type-Options"] = "nosniff" - drw.SetHeader("X-Frame-Options", "DENY") - drw.SetHeader("X-Content-Type-Options", "nosniff") - - drw.WriteString("content") - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, "DENY", string(ctx.Response.Header.Peek("X-Frame-Options"))) - assert.Equal(t, "nosniff", string(ctx.Response.Header.Peek("X-Content-Type-Options"))) - }) - - t.Run("nginx_style_body_filter", func(t *testing.T) { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟 nginx body_filter_by_lua - // ngx.arg[1] = ngx.arg[1]:gsub("secret", "***") - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - return []byte(strings.ReplaceAll(string(body), "secret", "***")), nil - }) - - drw.WriteString("This is a secret message") - err := drw.Flush() - require.NoError(t, err) - - assert.Equal(t, "This is a *** message", string(ctx.Response.Body())) - }) -} - -// TestRealFastHTTPIntegration 测试与真实 fasthttp 的集成 -func TestRealFastHTTPIntegration(t *testing.T) { - // 创建一个简单的 fasthttp 服务器进行测试 - requestHandler := func(ctx *fasthttp.RequestCtx) { - // 模拟 filter phase 处理 - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - // 模拟 Lua header filter - drw.SetHeader("X-Processed-By", "filter-phase") - drw.SetStatusCode(200) - - // 模拟 Lua body filter - drw.GetInterceptor().SetBodyFilter(func(body []byte) ([]byte, error) { - return append([]byte("Modified: "), body...), nil - }) - - // 设置原始响应 - drw.SetBodyString("Hello") - - // 刷新 - if err := drw.Flush(); err != nil { - ctx.Error(err.Error(), 500) - return - } - } - - // 创建服务器(但不启动) - server := &fasthttp.Server{ - Handler: requestHandler, - } - - // 使用测试模式验证 - t.Logf("Server created with filter phase support") - _ = server - - // 手动测试响应处理 - ctx := &fasthttp.RequestCtx{} - requestHandler(ctx) - - assert.Equal(t, 200, ctx.Response.StatusCode()) - assert.Equal(t, "filter-phase", string(ctx.Response.Header.Peek("X-Processed-By"))) - assert.Equal(t, "Modified: Hello", string(ctx.Response.Body())) -} - -// TestFinalVerification 最终验证测试 -func TestFinalVerification(t *testing.T) { - t.Run("success_rate_check", func(t *testing.T) { - const total = 1000 - success := 0 - - for range total { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - drw.SetHeader("X-Check", "1") - drw.WriteString("verify") - if err := drw.Flush(); err == nil { - success++ - } - } - - rate := float64(success) / float64(total) * 100 - t.Logf("Final success rate: %.2f%% (%d/%d)", rate, success, total) - assert.GreaterOrEqual(t, rate, 95.0, "Success rate must be >= 95%%") - }) - - t.Run("header_correctness_check", func(t *testing.T) { - testCases := []struct { - setHeader map[string]string - delHeader []string - expectHeader map[string]string - }{ - { - setHeader: map[string]string{"A": "1", "B": "2"}, - expectHeader: map[string]string{"A": "1", "B": "2"}, - }, - { - setHeader: map[string]string{"X": "old"}, - expectHeader: map[string]string{"X": "old"}, - }, - { - setHeader: map[string]string{"Remove": "value"}, - delHeader: []string{"Remove"}, - expectHeader: map[string]string{"Remove": ""}, - }, - } - - for _, tc := range testCases { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - - for k, v := range tc.setHeader { - drw.SetHeader(k, v) - } - for _, k := range tc.delHeader { - drw.DelHeader(k) - } - - drw.WriteString("test") - err := drw.Flush() - require.NoError(t, err) - - for k, expected := range tc.expectHeader { - actual := string(ctx.Response.Header.Peek(k)) - assert.Equal(t, expected, actual, "Header %s mismatch", k) - } - } - }) - - t.Run("performance_check", func(t *testing.T) { - const iterations = 5000 - - // 基准 - start := time.Now() - for range iterations { - ctx := mockRequestCtx() - ctx.Response.SetBodyString("test") - ctx.Response.Header.Set("X-Test", "value") - } - baseline := time.Since(start) - - // Filter phase - start = time.Now() - for range iterations { - ctx := mockRequestCtx() - drw := NewDelayedResponseWriter(ctx) - drw.EnableFilterPhase() - drw.SetHeader("X-Test", "value") - drw.WriteString("test") - _ = drw.Flush() - } - filterTime := time.Since(start) - - overhead := (float64(filterTime) - float64(baseline)) / float64(baseline) * 100 - t.Logf("Performance overhead: %.2f%% (baseline: %v, filter: %v)", - overhead, baseline, filterTime) - - // 性能开销应该小于 500%(这是一个保守的阈值,实际应该更低) - assert.Less(t, overhead, 20000.0, "Performance overhead too high") - }) -} diff --git a/internal/lua/filter_writer.go b/internal/lua/filter_writer.go deleted file mode 100644 index 640a4b2..0000000 --- a/internal/lua/filter_writer.go +++ /dev/null @@ -1,811 +0,0 @@ -// Package lua 提供 Lua 脚本嵌入能力。 -// -// 该文件实现响应拦截器和延迟写入机制,用于 Lua header_filter/body_filter 阶段。 -// 包括: -// - ResponseInterceptor:延迟 header 写入,允许在发送前修改响应 -// - DelayedResponseWriter:包装 fasthttp.RequestCtx 提供延迟写入能力 -// - BufferedWriter:带缓冲区的写入器,支持自动刷新 -// - 对象池:ResponseInterceptorPool、bufferPool 减少 GC 压力 -// -// 执行流程: -// 1. 启用拦截模式后,header 和 body 写入被延迟 -// 2. HeaderFilter 阶段可执行 Lua 脚本修改响应头 -// 3. BodyFilter 阶段可执行 Lua 脚本修改响应体 -// 4. Flush 时应用所有修改并发送响应 -// -// 注意事项: -// - 流式 body(SetBodyStream)无法缓冲,header filter 在设置前应用 -// - 拦截器使用后必须调用 ReleaseResponseInterceptor 放回池中 -// -// 作者:xfy -package lua - -import ( - "io" - "net" - "sync" - - "github.com/valyala/fasthttp" -) - -// ResponseInterceptor 响应拦截器。 -// -// 用于延迟 header 写入,允许在 header/body_filter 阶段执行 Lua 脚本 -// 修改响应内容后再发送。所有 header 修改、删除和 body 缓冲均在 -// Flush 时统一应用。 -// -// 线程安全:SetHeader 等方法使用 sync.RWMutex 保护。 -type ResponseInterceptor struct { - // ctx 关联的 fasthttp 请求上下文 - ctx *fasthttp.RequestCtx - - // headerFilterFunc header 过滤器回调(在 Flush 时执行) - headerFilterFunc func() error - - // bodyFilterFunc body 过滤器回调(在 Flush 时执行) - bodyFilterFunc func([]byte) ([]byte, error) - - // customHeaders 自定义 header 映射(延迟发送) - customHeaders map[string]string - - // headersToDelete 需要删除的 header 列表 - headersToDelete []string - - // bodyBuffer 缓冲的 body 数据 - bodyBuffer []byte - - // statusCode 响应状态码 - statusCode int - - // mu 读写锁 - mu sync.RWMutex - - // headersWritten 标记 header 是否已发送 - headersWritten bool - - // intercepted 是否启用拦截模式 - intercepted bool -} - -// NewResponseInterceptor 创建响应拦截器。 -// -// 参数: -// - ctx: fasthttp 请求上下文 -// -// 返回值: -// - *ResponseInterceptor: 初始化的拦截器实例 -func NewResponseInterceptor(ctx *fasthttp.RequestCtx) *ResponseInterceptor { - return &ResponseInterceptor{ - ctx: ctx, - statusCode: 200, - customHeaders: make(map[string]string), - headersToDelete: make([]string, 0), - } -} - -// SetHeaderFilter 设置 header 过滤器回调。 -// -// 参数: -// - fn: 回调函数,在 Flush 时执行,返回非 nil error 将中断响应 -func (ri *ResponseInterceptor) SetHeaderFilter(fn func() error) { - ri.headerFilterFunc = fn -} - -// SetBodyFilter 设置 body 过滤器回调。 -// -// 参数: -// - fn: 回调函数,接收原始 body,返回修改后的 body -func (ri *ResponseInterceptor) SetBodyFilter(fn func([]byte) ([]byte, error)) { - ri.bodyFilterFunc = fn -} - -// SetStatusCode 设置响应状态码(延迟到 Flush 时生效)。 -// -// 参数: -// - code: HTTP 状态码 -func (ri *ResponseInterceptor) SetStatusCode(code int) { - ri.statusCode = code -} - -// GetStatusCode 获取当前状态码。 -// -// 返回值: -// - int: 当前设置的状态码 -func (ri *ResponseInterceptor) GetStatusCode() int { - return ri.statusCode -} - -// SetHeader 设置 header(延迟到 Flush 时生效)。 -// -// 参数: -// - key: header 名称 -// - value: header 值 -func (ri *ResponseInterceptor) SetHeader(key, value string) { - ri.mu.Lock() - defer ri.mu.Unlock() - ri.customHeaders[key] = value -} - -// GetHeader 获取原始 header 值(直接从响应读取)。 -// -// 参数: -// - key: header 名称 -// -// 返回值: -// - []byte: header 值 -func (ri *ResponseInterceptor) GetHeader(key string) []byte { - return ri.ctx.Response.Header.Peek(key) -} - -// DelHeader 标记删除 header(延迟到 Flush 时生效)。 -// -// 参数: -// - key: 要删除的 header 名称 -func (ri *ResponseInterceptor) DelHeader(key string) { - ri.headersToDelete = append(ri.headersToDelete, key) -} - -// Write 拦截写入操作(缓冲 body,延迟 header 发送)。 -// -// 如果未启用拦截模式,直接写入 ctx。 -// -// 参数: -// - p: 要写入的数据 -// -// 返回值: -// - int: 写入字节数 -// - error: 写入错误 -func (ri *ResponseInterceptor) Write(p []byte) (int, error) { - if !ri.intercepted { - // 未启用拦截,直接写入 - return ri.ctx.Write(p) - } - - // 缓冲 body 数据 - ri.bodyBuffer = append(ri.bodyBuffer, p...) - return len(p), nil -} - -// WriteString 写入字符串。 -// -// 参数: -// - s: 要写入的字符串 -// -// 返回值: -// - int: 写入字节数 -// - error: 写入错误 -func (ri *ResponseInterceptor) WriteString(s string) (int, error) { - return ri.Write([]byte(s)) -} - -// SetBody 设置 body(延迟发送)。 -// -// 参数: -// - body: 响应体内容 -func (ri *ResponseInterceptor) SetBody(body []byte) { - if !ri.intercepted { - ri.ctx.SetBody(body) - return - } - ri.bodyBuffer = body -} - -// SetBodyString 设置字符串 body。 -// -// 参数: -// - body: 响应体内容 -func (ri *ResponseInterceptor) SetBodyString(body string) { - ri.SetBody([]byte(body)) -} - -// Flush 执行 header/body filter 并发送响应。 -// -// 执行顺序: -// 1. 执行 header filter 回调 -// 2. 应用 header 修改和删除 -// 3. 执行 body filter 回调 -// 4. 发送最终响应 -// -// 返回值: -// - error: filter 执行失败时返回错误 -func (ri *ResponseInterceptor) Flush() error { - if ri.headersWritten { - return nil // 已经发送过 - } - ri.headersWritten = true - - // 1. 执行 header filter - if ri.headerFilterFunc != nil { - if err := ri.headerFilterFunc(); err != nil { - return err - } - } - - // 2. 应用 header 修改 - ri.ctx.Response.SetStatusCode(ri.statusCode) - for key, value := range ri.customHeaders { - ri.ctx.Response.Header.Set(key, value) - } - for _, key := range ri.headersToDelete { - ri.ctx.Response.Header.Del(key) - } - - // 3. 执行 body filter - body := ri.bodyBuffer - if ri.bodyFilterFunc != nil && len(body) > 0 { - modified, err := ri.bodyFilterFunc(body) - if err != nil { - return err - } - body = modified - } - - // 4. 发送响应 - if len(body) > 0 { - ri.ctx.SetBody(body) - } - - return nil -} - -// Enable 启用拦截模式。 -func (ri *ResponseInterceptor) Enable() { - ri.intercepted = true -} - -// Disable 禁用拦截模式。 -func (ri *ResponseInterceptor) Disable() { - ri.intercepted = false -} - -// IsEnabled 检查是否启用拦截。 -// -// 返回值: -// - bool: true 表示启用 -func (ri *ResponseInterceptor) IsEnabled() bool { - return ri.intercepted -} - -// GetBufferedBody 获取当前缓冲的 body。 -// -// 返回值: -// - []byte: 缓冲的 body 数据 -func (ri *ResponseInterceptor) GetBufferedBody() []byte { - return ri.bodyBuffer -} - -// ClearBody 清空 body 缓冲。 -func (ri *ResponseInterceptor) ClearBody() { - ri.bodyBuffer = nil -} - -// DelayedResponseWriter 延迟响应写入器。 -// -// 包装 fasthttp.RequestCtx 和 ResponseInterceptor,提供延迟写入能力。 -// 用于 Lua header_filter/body_filter 阶段的响应拦截和修改。 -type DelayedResponseWriter struct { - // ctx 关联的 fasthttp 请求上下文 - ctx *fasthttp.RequestCtx - - // interceptor 响应拦截器 - interceptor *ResponseInterceptor -} - -// NewDelayedResponseWriter 创建延迟响应写入器。 -// -// 参数: -// - ctx: fasthttp 请求上下文 -// -// 返回值: -// - *DelayedResponseWriter: 初始化的写入器实例 -func NewDelayedResponseWriter(ctx *fasthttp.RequestCtx) *DelayedResponseWriter { - return &DelayedResponseWriter{ - ctx: ctx, - interceptor: NewResponseInterceptor(ctx), - } -} - -// EnableFilterPhase 启用 filter phase(启动拦截模式)。 -func (drw *DelayedResponseWriter) EnableFilterPhase() { - drw.interceptor.Enable() -} - -// DisableFilterPhase 禁用 filter phase。 -func (drw *DelayedResponseWriter) DisableFilterPhase() { - drw.interceptor.Disable() -} - -// GetInterceptor 获取响应拦截器。 -// -// 返回值: -// - *ResponseInterceptor: 关联的拦截器 -func (drw *DelayedResponseWriter) GetInterceptor() *ResponseInterceptor { - return drw.interceptor -} - -// HeaderFilter 注册 header filter 阶段的 Lua 脚本。 -// -// 参数: -// - script: Lua 脚本 -// - luaCtx: Lua 上下文 -// -// 返回值: -// - error: 脚本执行失败时返回错误 -func (drw *DelayedResponseWriter) HeaderFilter(script string, luaCtx *LuaContext) error { - if !drw.interceptor.IsEnabled() { - return nil - } - - luaCtx.SetPhase(PhaseHeaderFilter) - drw.interceptor.SetHeaderFilter(func() error { - return luaCtx.Execute(script) - }) - return nil -} - -// BodyFilter 注册 body filter 阶段的 Lua 脚本。 -// -// 参数: -// - script: Lua 脚本 -// - luaCtx: Lua 上下文 -// -// 返回值: -// - error: 脚本执行失败时返回错误 -func (drw *DelayedResponseWriter) BodyFilter(script string, luaCtx *LuaContext) error { - if !drw.interceptor.IsEnabled() { - return nil - } - - luaCtx.SetPhase(PhaseBodyFilter) - drw.interceptor.SetBodyFilter(func(body []byte) ([]byte, error) { - // 将 body 设置到 Lua 上下文中 - luaCtx.OutputBuffer = body - if err := luaCtx.Execute(script); err != nil { - return nil, err - } - return luaCtx.OutputBuffer, nil - }) - return nil -} - -// Flush 刷新响应(执行 filter 并发送)。 -// -// 返回值: -// - error: 刷新失败时返回错误 -func (drw *DelayedResponseWriter) Flush() error { - return drw.interceptor.Flush() -} - -// Write 实现 io.Writer 接口。 -// -// 参数: -// - p: 要写入的数据 -// -// 返回值: -// - int: 写入字节数 -// - error: 写入错误 -func (drw *DelayedResponseWriter) Write(p []byte) (int, error) { - return drw.interceptor.Write(p) -} - -// WriteString 写入字符串。 -// -// 参数: -// - s: 要写入的字符串 -// -// 返回值: -// - int: 写入字节数 -// - error: 写入错误 -func (drw *DelayedResponseWriter) WriteString(s string) (int, error) { - return drw.interceptor.WriteString(s) -} - -// SetStatusCode 设置状态码。 -func (drw *DelayedResponseWriter) SetStatusCode(code int) { - drw.interceptor.SetStatusCode(code) -} - -// SetBody 设置 body。 -func (drw *DelayedResponseWriter) SetBody(body []byte) { - drw.interceptor.SetBody(body) -} - -// SetBodyString 设置字符串 body。 -func (drw *DelayedResponseWriter) SetBodyString(body string) { - drw.interceptor.SetBodyString(body) -} - -// SetHeader 设置 header。 -func (drw *DelayedResponseWriter) SetHeader(key, value string) { - drw.interceptor.SetHeader(key, value) -} - -// GetHeader 获取 header。 -func (drw *DelayedResponseWriter) GetHeader(key string) []byte { - return drw.interceptor.GetHeader(key) -} - -// DelHeader 删除 header。 -func (drw *DelayedResponseWriter) DelHeader(key string) { - drw.interceptor.DelHeader(key) -} - -// ResponseInterceptorPool 响应拦截器对象池。 -var ResponseInterceptorPool = sync.Pool{ - New: func() any { - return &ResponseInterceptor{} - }, -} - -// AcquireResponseInterceptor 从池中获取拦截器并初始化。 -// -// 参数: -// - ctx: fasthttp 请求上下文 -// -// 返回值: -// - *ResponseInterceptor: 初始化后的拦截器 -func AcquireResponseInterceptor(ctx *fasthttp.RequestCtx) *ResponseInterceptor { - ri, ok := ResponseInterceptorPool.Get().(*ResponseInterceptor) - if !ok { - ri = &ResponseInterceptor{} - } - ri.ctx = ctx - ri.statusCode = 200 - ri.customHeaders = make(map[string]string) - ri.headersToDelete = make([]string, 0) - ri.bodyBuffer = nil - ri.headersWritten = false - ri.intercepted = true - ri.headerFilterFunc = nil - ri.bodyFilterFunc = nil - return ri -} - -// ReleaseResponseInterceptor 释放拦截器回池。 -// -// 清理所有引用和回调,防止内存泄漏。 -func ReleaseResponseInterceptor(ri *ResponseInterceptor) { - if ri == nil { - return - } - // 清理状态 - ri.ctx = nil - ri.headerFilterFunc = nil - ri.bodyFilterFunc = nil - ri.bodyBuffer = nil - ri.customHeaders = nil - ri.headersToDelete = nil - ResponseInterceptorPool.Put(ri) -} - -// Hijack 支持连接劫持(用于 WebSocket)。 -// -// 参数: -// - handler: 劫持后的处理函数 -func (drw *DelayedResponseWriter) Hijack(handler fasthttp.HijackHandler) { - drw.ctx.Hijack(handler) -} - -// Hijacked 检查是否已劫持。 -// -// 返回值: -// - bool: true 表示已劫持 -func (drw *DelayedResponseWriter) Hijacked() bool { - return drw.ctx.Hijacked() -} - -// LocalAddr 获取本地地址。 -// -// 返回值: -// - net.Addr: 本地网络地址 -func (drw *DelayedResponseWriter) LocalAddr() net.Addr { - return drw.ctx.LocalAddr() -} - -// RemoteAddr 获取远程地址。 -// -// 返回值: -// - net.Addr: 远程网络地址 -func (drw *DelayedResponseWriter) RemoteAddr() net.Addr { - return drw.ctx.RemoteAddr() -} - -// SetConnectionClose 设置响应头 Connection: close。 -func (drw *DelayedResponseWriter) SetConnectionClose() { - drw.ctx.Response.SetConnectionClose() -} - -// BodyWriter 返回 body 写入器(适配 io.Writer)。 -// -// 返回值: -// - io.Writer: body 写入器 -func (drw *DelayedResponseWriter) BodyWriter() io.Writer { - return &responseWriterAdapter{interceptor: drw.interceptor} -} - -// responseWriterAdapter 将 ResponseInterceptor 适配为 io.Writer 接口。 -type responseWriterAdapter struct { - interceptor *ResponseInterceptor -} - -// Write 实现 io.Writer 接口。 -func (rwa *responseWriterAdapter) Write(p []byte) (n int, err error) { - return rwa.interceptor.Write(p) -} - -// ResponseStats 响应统计信息。 -type ResponseStats struct { - // BufferedBytes 缓冲的 body 字节数 - BufferedBytes int - - // HeadersModified 修改的 header 数量 - HeadersModified int - - // HeadersDeleted 删除的 header 数量 - HeadersDeleted int - - // BodyModified body 是否被修改 - BodyModified bool - - // StatusCode 响应状态码 - StatusCode int -} - -// GetStats 获取响应统计信息。 -// -// 返回值: -// - ResponseStats: 当前统计快照 -func (drw *DelayedResponseWriter) GetStats() ResponseStats { - return ResponseStats{ - BufferedBytes: len(drw.interceptor.bodyBuffer), - HeadersModified: len(drw.interceptor.customHeaders), - HeadersDeleted: len(drw.interceptor.headersToDelete), - BodyModified: drw.interceptor.bodyFilterFunc != nil, - StatusCode: drw.interceptor.statusCode, - } -} - -// IsBodyBuffered 检查 body 是否被缓冲。 -// -// 返回值: -// - bool: true 表示有缓冲数据 -func (drw *DelayedResponseWriter) IsBodyBuffered() bool { - return len(drw.interceptor.bodyBuffer) > 0 -} - -// GetBufferedBodySize 获取缓冲的 body 大小。 -// -// 返回值: -// - int: 缓冲字节数 -func (drw *DelayedResponseWriter) GetBufferedBodySize() int { - return len(drw.interceptor.bodyBuffer) -} - -// Reset 重置写入器状态。 -func (drw *DelayedResponseWriter) Reset() { - drw.interceptor.bodyBuffer = nil - drw.interceptor.headersWritten = false - drw.interceptor.statusCode = 200 - drw.interceptor.customHeaders = make(map[string]string) - drw.interceptor.headersToDelete = make([]string, 0) -} - -// SetBodyStream 设置 body 流。 -// -// 流式 body 无法缓冲,在设置前应用 header filter。 -// -// 参数: -// - bodyStream: body 数据源 -// - bodySize: body 大小(-1 表示未知) -func (drw *DelayedResponseWriter) SetBodyStream(bodyStream io.Reader, bodySize int) { - if !drw.interceptor.IsEnabled() { - drw.ctx.SetBodyStream(bodyStream, bodySize) - return - } - // 流式 body 无法缓冲,直接设置 - // 但在设置前应用 header filter - if drw.interceptor.headerFilterFunc != nil { - _ = drw.interceptor.headerFilterFunc() - } - drw.ctx.SetBodyStream(bodyStream, bodySize) - drw.interceptor.headersWritten = true -} - -// SendFile 发送文件。 -// -// 在发送前应用 header filter 和自定义 header。 -// -// 参数: -// - path: 文件路径 -// -// 返回值: -// - error: 发送失败时返回错误 -func (drw *DelayedResponseWriter) SendFile(path string) error { - if !drw.interceptor.IsEnabled() { - drw.ctx.SendFile(path) - return nil - } - // 文件发送前应用 header filter - if drw.interceptor.headerFilterFunc != nil { - if err := drw.interceptor.headerFilterFunc(); err != nil { - return err - } - } - // 应用修改的 headers - drw.ctx.Response.SetStatusCode(drw.interceptor.statusCode) - for key, value := range drw.interceptor.customHeaders { - drw.ctx.Response.Header.Set(key, value) - } - for _, key := range drw.interceptor.headersToDelete { - drw.ctx.Response.Header.Del(key) - } - drw.ctx.SendFile(path) - drw.interceptor.headersWritten = true - return nil -} - -// Redirect 重定向。 -// -// 在重定向前应用 header filter。 -// -// 参数: -// - uri: 目标 URI -// - statusCode: HTTP 重定向状态码 -func (drw *DelayedResponseWriter) Redirect(uri string, statusCode int) { - if !drw.interceptor.IsEnabled() { - drw.ctx.Redirect(uri, statusCode) - return - } - // 重定向前应用 header filter - if drw.interceptor.headerFilterFunc != nil { - _ = drw.interceptor.headerFilterFunc() - } - drw.ctx.Redirect(uri, statusCode) - drw.interceptor.headersWritten = true -} - -// bufferPool body 缓冲区对象池。 -var bufferPool = sync.Pool{ - New: func() any { - buf := make([]byte, 0, 4096) // 4KB 初始容量 - return &buf - }, -} - -// acquireBuffer 获取缓冲区。 -// -// 返回值: -// - []byte: 可复用的缓冲区 -func acquireBuffer() []byte { - buf, ok := bufferPool.Get().(*[]byte) - if !ok { - return []byte{} - } - return *buf -} - -// releaseBuffer 释放缓冲区回池。 -// -// 只回收容量不超过 64KB 的缓冲区,避免池过大。 -func releaseBuffer(buf []byte) { - if buf != nil && cap(buf) <= 65536 { // 只回收小缓冲区 - buf = buf[:0] - bufferPool.Put(&buf) - } -} - -// BufferedWriter 带缓冲的写入器。 -// -// 支持自动刷新(达到 maxSize 时自动调用 flushFunc)和手动刷新。 -// 使用对象池分配底层缓冲区。 -type BufferedWriter struct { - // flushFunc 刷新回调 - flushFunc func([]byte) error - - // buf 缓冲区 - buf []byte - - // maxSize 自动刷新的最大大小 - maxSize int - - // autoFlush 是否启用自动刷新 - autoFlush bool -} - -// NewBufferedWriter 创建缓冲写入器。 -// -// 参数: -// - maxSize: 触发自动刷新的最大缓冲区大小 -// - flushFunc: 刷新回调函数 -// -// 返回值: -// - *BufferedWriter: 初始化的写入器 -func NewBufferedWriter(maxSize int, flushFunc func([]byte) error) *BufferedWriter { - return &BufferedWriter{ - buf: acquireBuffer(), - maxSize: maxSize, - flushFunc: flushFunc, - autoFlush: true, - } -} - -// Write 写入数据到缓冲区。 -// -// 如果缓冲区不足,自动扩容。如果启用 autoFlush 且达到 maxSize,自动刷新。 -// -// 参数: -// - p: 要写入的数据 -// -// 返回值: -// - int: 写入字节数 -// - error: 刷新失败时返回错误 -func (bw *BufferedWriter) Write(p []byte) (int, error) { - if bw.buf == nil { - bw.buf = acquireBuffer() - } - - // 检查是否需要扩容 - if len(bw.buf)+len(p) > cap(bw.buf) { - // 扩容 - newCap := max(cap(bw.buf)*2, len(bw.buf)+len(p)) - newBuf := make([]byte, len(bw.buf), newCap) - copy(newBuf, bw.buf) - releaseBuffer(bw.buf) - bw.buf = newBuf - } - - bw.buf = append(bw.buf, p...) - - // 自动刷新检查 - if bw.autoFlush && bw.maxSize > 0 && len(bw.buf) >= bw.maxSize { - if err := bw.Flush(); err != nil { - return len(p), err - } - } - - return len(p), nil -} - -// Flush 刷新缓冲区。 -// -// 返回值: -// - error: 刷新失败时返回错误 -func (bw *BufferedWriter) Flush() error { - if bw.flushFunc == nil || len(bw.buf) == 0 { - return nil - } - if err := bw.flushFunc(bw.buf); err != nil { - return err - } - bw.buf = bw.buf[:0] - return nil -} - -// Close 关闭写入器,刷新剩余数据并回收缓冲区。 -// -// 返回值: -// - error: 刷新失败时返回错误 -func (bw *BufferedWriter) Close() error { - err := bw.Flush() - if bw.buf != nil { - releaseBuffer(bw.buf) - bw.buf = nil - } - return err -} - -// Size 返回当前缓冲区大小。 -// -// 返回值: -// - int: 缓冲区字节数 -func (bw *BufferedWriter) Size() int { - return len(bw.buf) -} - -// Bytes 返回当前缓冲区内容(不消费)。 -// -// 返回值: -// - []byte: 缓冲区内容 -func (bw *BufferedWriter) Bytes() []byte { - return bw.buf -} diff --git a/internal/lua/mock_engine.go b/internal/lua/mock_engine.go deleted file mode 100644 index 2a178ce..0000000 --- a/internal/lua/mock_engine.go +++ /dev/null @@ -1,331 +0,0 @@ -// Package lua 提供 Lua 引擎的 Mock 实现,用于测试。 -// -// 该文件提供 LuaEngine 和 LuaCoroutine 的 Mock 版本,通过函数指针 -// 注入自定义行为,便于单元测试中模拟 Lua 脚本执行。 -// -// 使用方式: -// - 设置 ExecuteFunc 等字段来自定义方法行为 -// - 未设置的函数指针返回零值(stub 模式) -// -// 作者:xfy -package lua - -import ( - "context" - "time" - - "github.com/valyala/fasthttp" - glua "github.com/yuin/gopher-lua" -) - -// MockLuaEngine 是 LuaEngine 的 Mock 实现。 -// -// 通过注入函数指针模拟 LuaEngine 的所有公开方法, -// 未注入的方法返回零值或 nil(stub 模式)。 -type MockLuaEngine struct { - // ExecuteFunc 模拟 Execute 方法 - ExecuteFunc func(script string) error - - // ExecuteFileFunc 模拟 ExecuteFile 方法 - ExecuteFileFunc func(path string) error - - // NewCoroutineFunc 模拟 NewCoroutine 方法 - NewCoroutineFunc func(ctx *fasthttp.RequestCtx) (*MockCoroutine, error) - - // CloseFunc 模拟 Close 方法 - CloseFunc func() - - // StatsFunc 模拟 Stats 方法 - StatsFunc func() EngineStats - - // ActiveCoroutinesFunc 模拟 ActiveCoroutines 方法 - ActiveCoroutinesFunc func() int32 - - // CodeCacheFunc 模拟 CodeCache 方法 - CodeCacheFunc func() *CodeCache - - // SharedDictManagerFunc 模拟 SharedDictManager 方法 - SharedDictManagerFunc func() *SharedDictManager - - // TimerManagerFunc 模拟 TimerManager 方法 - TimerManagerFunc func() *TimerManager - - // LocationManagerFunc 模拟 LocationManager 方法 - LocationManagerFunc func() *LocationManager - - // CreateSharedDictFunc 模拟 CreateSharedDict 方法 - CreateSharedDictFunc func(name string, maxItems int) *SharedDict - - // InitSchedulerLStateFunc 模拟 InitSchedulerLState 方法 - InitSchedulerLStateFunc func() error - - // SchedulerLoopFunc 模拟 SchedulerLoop 方法 - SchedulerLoopFunc func() - - // EnqueueCallbackFunc 模拟 EnqueueCallback 方法 - EnqueueCallbackFunc func(entry *CallbackEntry) bool - - // CloseSchedulerFunc 模拟 CloseScheduler 方法 - CloseSchedulerFunc func() -} - -// Execute 执行脚本(Mock)。 -// -// 参数: -// - script: Lua 脚本 -// -// 返回值: -// - error: ExecuteFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) Execute(script string) error { - if m.ExecuteFunc != nil { - return m.ExecuteFunc(script) - } - return nil // stub -} - -// ExecuteFile 执行文件(Mock)。 -// -// 参数: -// - path: 脚本文件路径 -// -// 返回值: -// - error: ExecuteFileFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) ExecuteFile(path string) error { - if m.ExecuteFileFunc != nil { - return m.ExecuteFileFunc(path) - } - return nil // stub -} - -// NewCoroutine 创建协程(Mock)。 -// -// 参数: -// - req: fasthttp 请求上下文 -// -// 返回值: -// - *MockCoroutine: 模拟协程 -// - error: NewCoroutineFunc 的结果 -func (m *MockLuaEngine) NewCoroutine(req *fasthttp.RequestCtx) (*MockCoroutine, error) { - if m.NewCoroutineFunc != nil { - return m.NewCoroutineFunc(req) - } - return &MockCoroutine{}, nil -} - -// Close 关闭引擎(Mock)。 -func (m *MockLuaEngine) Close() { - if m.CloseFunc != nil { - m.CloseFunc() - } -} - -// Stats 返回统计(Mock)。 -// -// 返回值: -// - EngineStats: StatsFunc 的结果,未注入时返回零值 -func (m *MockLuaEngine) Stats() EngineStats { - if m.StatsFunc != nil { - return m.StatsFunc() - } - return EngineStats{} -} - -// ActiveCoroutines 返回活跃协程数(Mock)。 -// -// 返回值: -// - int32: ActiveCoroutinesFunc 的结果,未注入时返回 0 -func (m *MockLuaEngine) ActiveCoroutines() int32 { - if m.ActiveCoroutinesFunc != nil { - return m.ActiveCoroutinesFunc() - } - return 0 -} - -// CodeCache 返回字节码缓存(Mock)。 -// -// 返回值: -// - *CodeCache: CodeCacheFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) CodeCache() *CodeCache { - if m.CodeCacheFunc != nil { - return m.CodeCacheFunc() - } - return nil -} - -// SharedDictManager 返回共享字典管理器(Mock)。 -// -// 返回值: -// - *SharedDictManager: SharedDictManagerFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) SharedDictManager() *SharedDictManager { - if m.SharedDictManagerFunc != nil { - return m.SharedDictManagerFunc() - } - return nil -} - -// TimerManager 返回定时器管理器(Mock)。 -// -// 返回值: -// - *TimerManager: TimerManagerFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) TimerManager() *TimerManager { - if m.TimerManagerFunc != nil { - return m.TimerManagerFunc() - } - return nil -} - -// LocationManager 返回 location 管理器(Mock)。 -// -// 返回值: -// - *LocationManager: LocationManagerFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) LocationManager() *LocationManager { - if m.LocationManagerFunc != nil { - return m.LocationManagerFunc() - } - return nil -} - -// CreateSharedDict 创建共享字典(Mock)。 -// -// 参数: -// - name: 字典名称 -// - maxItems: 最大条目数 -// -// 返回值: -// - *SharedDict: CreateSharedDictFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) CreateSharedDict(name string, maxItems int) *SharedDict { - if m.CreateSharedDictFunc != nil { - return m.CreateSharedDictFunc(name, maxItems) - } - return nil -} - -// InitSchedulerLState 初始化调度器 LState(Mock)。 -// -// 返回值: -// - error: InitSchedulerLStateFunc 的结果,未注入时返回 nil -func (m *MockLuaEngine) InitSchedulerLState() error { - if m.InitSchedulerLStateFunc != nil { - return m.InitSchedulerLStateFunc() - } - return nil -} - -// SchedulerLoop 调度器循环(Mock)。 -func (m *MockLuaEngine) SchedulerLoop() { - if m.SchedulerLoopFunc != nil { - m.SchedulerLoopFunc() - } -} - -// EnqueueCallback 将回调加入调度队列(Mock)。 -// -// 参数: -// - entry: 回调条目 -// -// 返回值: -// - bool: EnqueueCallbackFunc 的结果,未注入时返回 false -func (m *MockLuaEngine) EnqueueCallback(entry *CallbackEntry) bool { - if m.EnqueueCallbackFunc != nil { - return m.EnqueueCallbackFunc(entry) - } - return false -} - -// CloseScheduler 关闭调度器(Mock)。 -func (m *MockLuaEngine) CloseScheduler() { - if m.CloseSchedulerFunc != nil { - m.CloseSchedulerFunc() - } -} - -// MockCoroutine 是 LuaCoroutine 的 Mock 实现。 -// -// 通过注入函数指针模拟 LuaCoroutine 的核心方法, -// 同时包含模拟字段供测试验证。 -type MockCoroutine struct { - // ExecuteFunc 模拟 Execute 方法 - ExecuteFunc func(script string) error - - // ExecuteFileFunc 模拟 ExecuteFile 方法 - ExecuteFileFunc func(path string) error - - // SetupSandboxFunc 模拟 SetupSandbox 方法 - SetupSandboxFunc func() error - - // CloseFunc 模拟 Close 方法 - CloseFunc func() - - // HandleYieldFunc 模拟 handleYield 方法 - HandleYieldFunc func(values []glua.LValue) ([]glua.LValue, error) - - // CreatedAt 协程创建时间 - CreatedAt time.Time - - // ExecutionContext 执行上下文 - ExecutionContext context.Context - - // Engine 所属引擎 - Engine *MockLuaEngine - - // Co 底层 Lua 协程 - Co *glua.LState - - // Cancel 取消函数 - Cancel context.CancelFunc - - // RequestCtx fasthttp 请求上下文 - RequestCtx *fasthttp.RequestCtx - - // OutputBuffer 输出缓冲 - OutputBuffer []byte - - // Exited 退出标记 - Exited bool -} - -// Execute 执行脚本(Mock)。 -// -// 参数: -// - script: Lua 脚本 -// -// 返回值: -// - error: ExecuteFunc 的结果,未注入时返回 nil -func (c *MockCoroutine) Execute(script string) error { - if c.ExecuteFunc != nil { - return c.ExecuteFunc(script) - } - return nil -} - -// ExecuteFile 执行文件(Mock)。 -// -// 参数: -// - path: 脚本文件路径 -// -// 返回值: -// - error: ExecuteFileFunc 的结果,未注入时返回 nil -func (c *MockCoroutine) ExecuteFile(path string) error { - if c.ExecuteFileFunc != nil { - return c.ExecuteFileFunc(path) - } - return nil -} - -// SetupSandbox 设置沙箱(Mock)。 -// -// 返回值: -// - error: SetupSandboxFunc 的结果,未注入时返回 nil -func (c *MockCoroutine) SetupSandbox() error { - if c.SetupSandboxFunc != nil { - return c.SetupSandboxFunc() - } - return nil -} - -// Close 关闭协程(Mock)。 -func (c *MockCoroutine) Close() { - if c.CloseFunc != nil { - c.CloseFunc() - } -}