refactor(test): 提取 testutil 包统一测试辅助函数

- 新增 NewRequestCtx 和 NewRequestCtxWithHeader 辅助函数
- 简化各测试文件中 RequestCtx 创建代码
- 减少测试代码重复,提高可维护性

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-13 13:15:20 +08:00
parent 95b6119e34
commit e646cc5d05
6 changed files with 69 additions and 69 deletions

View File

@ -23,6 +23,8 @@ import (
"testing" "testing"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"rua.plus/lolly/internal/testutil"
) )
// newTestHandler 创建测试用的静态文件处理器 // newTestHandler 创建测试用的静态文件处理器
@ -34,9 +36,7 @@ func newTestHandler(t *testing.T, root string) *StaticHandler {
// newTestContext 创建测试用的 fasthttp 请求上下文 // newTestContext 创建测试用的 fasthttp 请求上下文
func newTestContext(t *testing.T, path string) *fasthttp.RequestCtx { func newTestContext(t *testing.T, path string) *fasthttp.RequestCtx {
t.Helper() t.Helper()
var ctx fasthttp.RequestCtx return testutil.NewRequestCtx("GET", path)
ctx.Request.SetRequestURI(path)
return &ctx
} }
// TestStaticHandlerHandle 测试静态文件处理器 // TestStaticHandlerHandle 测试静态文件处理器

View File

@ -10,12 +10,13 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua" glua "github.com/yuin/gopher-lua"
"rua.plus/lolly/internal/testutil"
) )
// mockRequestCtxForLog 创建模拟的 RequestCtx // mockRequestCtxForLog 创建模拟的 RequestCtx
func mockRequestCtxForLog() *fasthttp.RequestCtx { func mockRequestCtxForLog() *fasthttp.RequestCtx {
ctx := &fasthttp.RequestCtx{} return testutil.NewRequestCtx("GET", "/")
return ctx
} }
// TestNgxLogLevelConstants 测试日志级别常量 // TestNgxLogLevelConstants 测试日志级别常量

View File

@ -8,20 +8,13 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua" glua "github.com/yuin/gopher-lua"
"rua.plus/lolly/internal/testutil"
) )
// 创建测试用的 fasthttp.RequestCtx // 创建测试用的 fasthttp.RequestCtx
func createTestRequestCtx(method, uri string, headers map[string]string, body []byte) *fasthttp.RequestCtx { func createTestRequestCtx(method, uri string, headers map[string]string, body []byte) *fasthttp.RequestCtx {
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtxWithHeader(method, uri, headers)
// 设置请求
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI(uri)
// 设置请求头
for key, value := range headers {
ctx.Request.Header.Set(key, value)
}
// 设置请求体 // 设置请求体
if len(body) > 0 { if len(body) > 0 {

View File

@ -16,7 +16,9 @@ import (
"time" "time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config" "rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/testutil"
) )
func TestNewRateLimiter(t *testing.T) { func TestNewRateLimiter(t *testing.T) {
@ -420,7 +422,7 @@ func TestConnLimiterGlobal(t *testing.T) {
t.Fatalf("NewConnLimiter() error: %v", err) t.Fatalf("NewConnLimiter() error: %v", err)
} }
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtx("GET", "/")
// First two should succeed // First two should succeed
if !cl.Acquire(ctx) { if !cl.Acquire(ctx) {

View File

@ -29,6 +29,7 @@ import (
"rua.plus/lolly/internal/config" "rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/netutil" "rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/testutil"
"rua.plus/lolly/internal/variable" "rua.plus/lolly/internal/variable"
) )
@ -190,9 +191,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) {
} }
// 创建测试请求 // 创建测试请求
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtx("GET", "/api/test")
ctx.Request.Header.SetMethod(fasthttp.MethodGet)
ctx.Request.SetRequestURI("/api/test")
// 执行请求 // 执行请求
p.ServeHTTP(ctx) p.ServeHTTP(ctx)
@ -240,9 +239,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) {
} }
// 创建测试请求 // 创建测试请求
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtx("GET", "/api/test")
ctx.Request.Header.SetMethod(fasthttp.MethodGet)
ctx.Request.SetRequestURI("/api/test")
ctx.Request.Header.Set("X-Custom-Header", "client-value") ctx.Request.Header.Set("X-Custom-Header", "client-value")
// 执行请求 // 执行请求
@ -331,12 +328,9 @@ func TestSelectTarget(t *testing.T) {
t.Fatalf("NewProxy() error: %v", err) t.Fatalf("NewProxy() error: %v", err)
} }
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
if tt.clientIP != "" { "X-Forwarded-For": tt.clientIP,
// 设置远程地址模拟客户端IP })
ctx.Request.Header.Set("X-Forwarded-For", tt.clientIP)
}
ctx.Request.SetRequestURI("/api/test")
target := p.selectTarget(ctx) target := p.selectTarget(ctx)
@ -448,26 +442,22 @@ func TestModifyRequestHeaders(t *testing.T) {
t.Fatalf("NewProxy() error: %v", err) t.Fatalf("NewProxy() error: %v", err)
} }
ctx := &fasthttp.RequestCtx{} // 构建 headers map
ctx.Request.SetRequestURI("/api/test") headers := make(map[string]string)
// 设置客户端IP
if tt.clientIP != "" { if tt.clientIP != "" {
ctx.Request.Header.Set("X-Real-IP", tt.clientIP) headers["X-Real-IP"] = tt.clientIP
} }
// 设置已有的X-Forwarded-For
if tt.existingXFF != "" { if tt.existingXFF != "" {
ctx.Request.Header.Set("X-Forwarded-For", tt.existingXFF) headers["X-Forwarded-For"] = tt.existingXFF
} }
// 设置需要被移除的头
if len(tt.removeHeaders) > 0 { if len(tt.removeHeaders) > 0 {
for _, h := range tt.removeHeaders { for _, h := range tt.removeHeaders {
ctx.Request.Header.Set(h, "should-be-removed") headers[h] = "should-be-removed"
} }
} }
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", headers)
target := &loadbalance.Target{URL: "http://localhost:8080"} target := &loadbalance.Target{URL: "http://localhost:8080"}
p.modifyRequestHeaders(ctx, target) p.modifyRequestHeaders(ctx, target)
@ -543,8 +533,7 @@ func TestModifyResponseHeaders(t *testing.T) {
t.Fatalf("NewProxy() error: %v", err) t.Fatalf("NewProxy() error: %v", err)
} }
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtx("GET", "/")
ctx.Response.SetStatusCode(fasthttp.StatusOK)
p.modifyResponseHeaders(ctx) p.modifyResponseHeaders(ctx)
@ -597,13 +586,10 @@ func TestGetClientIP(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtxWithHeader("GET", "/", map[string]string{
if tt.xff != "" { "X-Forwarded-For": tt.xff,
ctx.Request.Header.Set("X-Forwarded-For", tt.xff) "X-Real-IP": tt.xri,
} })
if tt.xri != "" {
ctx.Request.Header.Set("X-Real-IP", tt.xri)
}
ip := netutil.ExtractClientIP(ctx) ip := netutil.ExtractClientIP(ctx)
if ip != tt.expected { if ip != tt.expected {
@ -761,13 +747,10 @@ func TestIsWebSocketRequest(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtxWithHeader("GET", "/", map[string]string{
if tt.upgrade != "" { "Upgrade": tt.upgrade,
ctx.Request.Header.Set("Upgrade", tt.upgrade) "Connection": tt.connection,
} })
if tt.connection != "" {
ctx.Request.Header.Set("Connection", tt.connection)
}
result := isWebSocketRequest(ctx) result := isWebSocketRequest(ctx)
if result != tt.expected { if result != tt.expected {
@ -1103,9 +1086,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
p.SetHealthChecker(hc) p.SetHealthChecker(hc)
// 创建测试请求 // 创建测试请求
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtx("GET", "/api/test")
ctx.Request.Header.SetMethod(fasthttp.MethodGet)
ctx.Request.SetRequestURI("/api/test")
// 执行请求 - 应该会失败并触发被动健康检查 // 执行请求 - 应该会失败并触发被动健康检查
p.ServeHTTP(ctx) p.ServeHTTP(ctx)
@ -1181,10 +1162,9 @@ func TestUpstreamVariablesCapture(t *testing.T) {
} }
// 创建请求 // 创建请求
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtxWithHeader("GET", "/test", map[string]string{
ctx.Request.Header.SetMethod("GET") "Host": "example.com",
ctx.Request.Header.SetRequestURI("/test") })
ctx.Request.Header.SetHost("example.com")
// 执行代理请求 // 执行代理请求
p.ServeHTTP(ctx) p.ServeHTTP(ctx)
@ -1264,10 +1244,9 @@ func TestUpstreamVariablesErrorPaths(t *testing.T) {
t.Fatalf("failed to create proxy: %v", err) t.Fatalf("failed to create proxy: %v", err)
} }
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtxWithHeader("GET", "/test", map[string]string{
ctx.Request.Header.SetMethod("GET") "Host": "example.com",
ctx.Request.Header.SetRequestURI("/test") })
ctx.Request.Header.SetHost("example.com")
p.ServeHTTP(ctx) p.ServeHTTP(ctx)
@ -1283,9 +1262,7 @@ func TestUpstreamVariablesErrorPaths(t *testing.T) {
// TestFinalizeUpstreamVars 测试 FinalizeUpstreamVars 函数 // TestFinalizeUpstreamVars 测试 FinalizeUpstreamVars 函数
func TestFinalizeUpstreamVars(t *testing.T) { func TestFinalizeUpstreamVars(t *testing.T) {
ctx := &fasthttp.RequestCtx{} ctx := testutil.NewRequestCtx("GET", "/test")
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetRequestURI("/test")
vc := variable.NewContext(ctx) vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc) defer variable.ReleaseContext(vc)

View File

@ -0,0 +1,27 @@
package testutil
import "github.com/valyala/fasthttp"
// NewRequestCtx 创建测试用的请求上下文
func NewRequestCtx(method, path string) *fasthttp.RequestCtx {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI(path)
return ctx
}
// NewRequestCtxWithBody 创建带 body 的测试请求上下文
func NewRequestCtxWithBody(method, path, body string) *fasthttp.RequestCtx {
ctx := NewRequestCtx(method, path)
ctx.Request.SetBodyString(body)
return ctx
}
// NewRequestCtxWithHeader 创建带 header 的测试请求上下文
func NewRequestCtxWithHeader(method, path string, headers map[string]string) *fasthttp.RequestCtx {
ctx := NewRequestCtx(method, path)
for k, v := range headers {
ctx.Request.Header.Set(k, v)
}
return ctx
}