From e646cc5d050be4add2a6c1d61780536ef99a888a Mon Sep 17 00:00:00 2001 From: xfy Date: Mon, 13 Apr 2026 13:15:20 +0800 Subject: [PATCH] =?UTF-8?q?refactor(test):=20=E6=8F=90=E5=8F=96=20testutil?= =?UTF-8?q?=20=E5=8C=85=E7=BB=9F=E4=B8=80=E6=B5=8B=E8=AF=95=E8=BE=85?= =?UTF-8?q?=E5=8A=A9=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 NewRequestCtx 和 NewRequestCtxWithHeader 辅助函数 - 简化各测试文件中 RequestCtx 创建代码 - 减少测试代码重复,提高可维护性 Co-Authored-By: Claude Opus 4.6 --- internal/handler/static_test.go | 6 +- internal/lua/api_log_test.go | 5 +- internal/lua/api_req_test.go | 13 +-- .../middleware/security/ratelimit_test.go | 4 +- internal/proxy/proxy_test.go | 83 +++++++------------ internal/testutil/request.go | 27 ++++++ 6 files changed, 69 insertions(+), 69 deletions(-) create mode 100644 internal/testutil/request.go diff --git a/internal/handler/static_test.go b/internal/handler/static_test.go index 0da5b49..f13e429 100644 --- a/internal/handler/static_test.go +++ b/internal/handler/static_test.go @@ -23,6 +23,8 @@ import ( "testing" "github.com/valyala/fasthttp" + + "rua.plus/lolly/internal/testutil" ) // newTestHandler 创建测试用的静态文件处理器 @@ -34,9 +36,7 @@ func newTestHandler(t *testing.T, root string) *StaticHandler { // newTestContext 创建测试用的 fasthttp 请求上下文 func newTestContext(t *testing.T, path string) *fasthttp.RequestCtx { t.Helper() - var ctx fasthttp.RequestCtx - ctx.Request.SetRequestURI(path) - return &ctx + return testutil.NewRequestCtx("GET", path) } // TestStaticHandlerHandle 测试静态文件处理器 diff --git a/internal/lua/api_log_test.go b/internal/lua/api_log_test.go index 7daa07b..3044a1d 100644 --- a/internal/lua/api_log_test.go +++ b/internal/lua/api_log_test.go @@ -10,12 +10,13 @@ import ( "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" glua "github.com/yuin/gopher-lua" + + "rua.plus/lolly/internal/testutil" ) // mockRequestCtxForLog 创建模拟的 RequestCtx func mockRequestCtxForLog() *fasthttp.RequestCtx { - ctx := &fasthttp.RequestCtx{} - return ctx + return testutil.NewRequestCtx("GET", "/") } // TestNgxLogLevelConstants 测试日志级别常量 diff --git a/internal/lua/api_req_test.go b/internal/lua/api_req_test.go index 4055072..c78c729 100644 --- a/internal/lua/api_req_test.go +++ b/internal/lua/api_req_test.go @@ -8,20 +8,13 @@ import ( "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" glua "github.com/yuin/gopher-lua" + + "rua.plus/lolly/internal/testutil" ) // 创建测试用的 fasthttp.RequestCtx func createTestRequestCtx(method, uri string, headers map[string]string, body []byte) *fasthttp.RequestCtx { - ctx := &fasthttp.RequestCtx{} - - // 设置请求 - ctx.Request.Header.SetMethod(method) - ctx.Request.SetRequestURI(uri) - - // 设置请求头 - for key, value := range headers { - ctx.Request.Header.Set(key, value) - } + ctx := testutil.NewRequestCtxWithHeader(method, uri, headers) // 设置请求体 if len(body) > 0 { diff --git a/internal/middleware/security/ratelimit_test.go b/internal/middleware/security/ratelimit_test.go index 2df592b..d7150d6 100644 --- a/internal/middleware/security/ratelimit_test.go +++ b/internal/middleware/security/ratelimit_test.go @@ -16,7 +16,9 @@ import ( "time" "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/testutil" ) func TestNewRateLimiter(t *testing.T) { @@ -420,7 +422,7 @@ func TestConnLimiterGlobal(t *testing.T) { t.Fatalf("NewConnLimiter() error: %v", err) } - ctx := &fasthttp.RequestCtx{} + ctx := testutil.NewRequestCtx("GET", "/") // First two should succeed if !cl.Acquire(ctx) { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index a78c88b..9d7f514 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -29,6 +29,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/netutil" + "rua.plus/lolly/internal/testutil" "rua.plus/lolly/internal/variable" ) @@ -190,9 +191,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) { } // 创建测试请求 - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod(fasthttp.MethodGet) - ctx.Request.SetRequestURI("/api/test") + ctx := testutil.NewRequestCtx("GET", "/api/test") // 执行请求 p.ServeHTTP(ctx) @@ -240,9 +239,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) { } // 创建测试请求 - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod(fasthttp.MethodGet) - ctx.Request.SetRequestURI("/api/test") + ctx := testutil.NewRequestCtx("GET", "/api/test") ctx.Request.Header.Set("X-Custom-Header", "client-value") // 执行请求 @@ -331,12 +328,9 @@ func TestSelectTarget(t *testing.T) { t.Fatalf("NewProxy() error: %v", err) } - ctx := &fasthttp.RequestCtx{} - if tt.clientIP != "" { - // 设置远程地址模拟客户端IP - ctx.Request.Header.Set("X-Forwarded-For", tt.clientIP) - } - ctx.Request.SetRequestURI("/api/test") + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Forwarded-For": tt.clientIP, + }) target := p.selectTarget(ctx) @@ -448,26 +442,22 @@ func TestModifyRequestHeaders(t *testing.T) { t.Fatalf("NewProxy() error: %v", err) } - ctx := &fasthttp.RequestCtx{} - ctx.Request.SetRequestURI("/api/test") - - // 设置客户端IP + // 构建 headers map + headers := make(map[string]string) if tt.clientIP != "" { - ctx.Request.Header.Set("X-Real-IP", tt.clientIP) + headers["X-Real-IP"] = tt.clientIP } - - // 设置已有的X-Forwarded-For if tt.existingXFF != "" { - ctx.Request.Header.Set("X-Forwarded-For", tt.existingXFF) + headers["X-Forwarded-For"] = tt.existingXFF } - - // 设置需要被移除的头 if len(tt.removeHeaders) > 0 { for _, h := range tt.removeHeaders { - ctx.Request.Header.Set(h, "should-be-removed") + headers[h] = "should-be-removed" } } + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", headers) + target := &loadbalance.Target{URL: "http://localhost:8080"} p.modifyRequestHeaders(ctx, target) @@ -543,8 +533,7 @@ func TestModifyResponseHeaders(t *testing.T) { t.Fatalf("NewProxy() error: %v", err) } - ctx := &fasthttp.RequestCtx{} - ctx.Response.SetStatusCode(fasthttp.StatusOK) + ctx := testutil.NewRequestCtx("GET", "/") p.modifyResponseHeaders(ctx) @@ -597,13 +586,10 @@ func TestGetClientIP(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - if tt.xff != "" { - ctx.Request.Header.Set("X-Forwarded-For", tt.xff) - } - if tt.xri != "" { - ctx.Request.Header.Set("X-Real-IP", tt.xri) - } + ctx := testutil.NewRequestCtxWithHeader("GET", "/", map[string]string{ + "X-Forwarded-For": tt.xff, + "X-Real-IP": tt.xri, + }) ip := netutil.ExtractClientIP(ctx) if ip != tt.expected { @@ -761,13 +747,10 @@ func TestIsWebSocketRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - if tt.upgrade != "" { - ctx.Request.Header.Set("Upgrade", tt.upgrade) - } - if tt.connection != "" { - ctx.Request.Header.Set("Connection", tt.connection) - } + ctx := testutil.NewRequestCtxWithHeader("GET", "/", map[string]string{ + "Upgrade": tt.upgrade, + "Connection": tt.connection, + }) result := isWebSocketRequest(ctx) if result != tt.expected { @@ -1103,9 +1086,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) { p.SetHealthChecker(hc) // 创建测试请求 - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod(fasthttp.MethodGet) - ctx.Request.SetRequestURI("/api/test") + ctx := testutil.NewRequestCtx("GET", "/api/test") // 执行请求 - 应该会失败并触发被动健康检查 p.ServeHTTP(ctx) @@ -1181,10 +1162,9 @@ func TestUpstreamVariablesCapture(t *testing.T) { } // 创建请求 - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod("GET") - ctx.Request.Header.SetRequestURI("/test") - ctx.Request.Header.SetHost("example.com") + ctx := testutil.NewRequestCtxWithHeader("GET", "/test", map[string]string{ + "Host": "example.com", + }) // 执行代理请求 p.ServeHTTP(ctx) @@ -1264,10 +1244,9 @@ func TestUpstreamVariablesErrorPaths(t *testing.T) { t.Fatalf("failed to create proxy: %v", err) } - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod("GET") - ctx.Request.Header.SetRequestURI("/test") - ctx.Request.Header.SetHost("example.com") + ctx := testutil.NewRequestCtxWithHeader("GET", "/test", map[string]string{ + "Host": "example.com", + }) p.ServeHTTP(ctx) @@ -1283,9 +1262,7 @@ func TestUpstreamVariablesErrorPaths(t *testing.T) { // TestFinalizeUpstreamVars 测试 FinalizeUpstreamVars 函数 func TestFinalizeUpstreamVars(t *testing.T) { - ctx := &fasthttp.RequestCtx{} - ctx.Request.Header.SetMethod("GET") - ctx.Request.Header.SetRequestURI("/test") + ctx := testutil.NewRequestCtx("GET", "/test") vc := variable.NewContext(ctx) defer variable.ReleaseContext(vc) diff --git a/internal/testutil/request.go b/internal/testutil/request.go new file mode 100644 index 0000000..28d21b5 --- /dev/null +++ b/internal/testutil/request.go @@ -0,0 +1,27 @@ +package testutil + +import "github.com/valyala/fasthttp" + +// NewRequestCtx 创建测试用的请求上下文 +func NewRequestCtx(method, path string) *fasthttp.RequestCtx { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI(path) + return ctx +} + +// NewRequestCtxWithBody 创建带 body 的测试请求上下文 +func NewRequestCtxWithBody(method, path, body string) *fasthttp.RequestCtx { + ctx := NewRequestCtx(method, path) + ctx.Request.SetBodyString(body) + return ctx +} + +// NewRequestCtxWithHeader 创建带 header 的测试请求上下文 +func NewRequestCtxWithHeader(method, path string, headers map[string]string) *fasthttp.RequestCtx { + ctx := NewRequestCtx(method, path) + for k, v := range headers { + ctx.Request.Header.Set(k, v) + } + return ctx +}