lolly/internal/proxy/proxy_coverage_extra_test.go

1481 lines
40 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy 提供额外的覆盖测试,补充低覆盖率函数的测试。
//
// 该文件测试以下功能:
// - HealthChecker.MarkHealthy 和 run 方法
// - selectByLua 和 selectByFallback 方法
// - rewriteCookies 和 rewriteCookieAttr 函数
// - modifyResponseHeaders 边缘情况
// - createHostClient 完整选项
// - TempFileManager 和 TempFileCleaner getter 方法
// - NewRedirectRewriter 正则规则和 RewriteRefreshOnly
// - rewriteCustom 正则模式
// - selectTarget 边缘情况
//
// 作者xfy
package proxy
import (
"net"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/testutil"
)
// TestHealthChecker_MarkHealthy checks that MarkHealthy flips a target back to
// healthy (also after failures were recorded) without touching other targets.
func TestHealthChecker_MarkHealthy(t *testing.T) {
	// unhealthyTarget returns a new target for url whose Healthy flag is false.
	unhealthyTarget := func(url string) *loadbalance.Target {
		tgt := &loadbalance.Target{URL: url}
		tgt.Healthy.Store(false)
		return tgt
	}

	t.Run("标记健康状态", func(t *testing.T) {
		tgt := unhealthyTarget("http://backend:8080")
		hc := NewHealthChecker([]*loadbalance.Target{tgt}, &config.HealthCheckConfig{
			Interval: 10 * time.Second,
			Timeout:  5 * time.Second,
		})

		hc.MarkHealthy(tgt)

		if healthy := tgt.Healthy.Load(); !healthy {
			t.Error("MarkHealthy() 后 target 应标记为 healthy")
		}
	})

	t.Run("重置失败计数", func(t *testing.T) {
		tgt := unhealthyTarget("http://backend:8080")
		// Record several failures before bringing the target back.
		for i := 0; i < 3; i++ {
			tgt.RecordFailure()
		}
		hc := NewHealthChecker([]*loadbalance.Target{tgt}, &config.HealthCheckConfig{})

		hc.MarkHealthy(tgt)

		// NOTE(review): only the Healthy flag is asserted; the failure counter
		// that the subtest name refers to is not inspected here.
		if healthy := tgt.Healthy.Load(); !healthy {
			t.Error("MarkHealthy() 后 target 应标记为 healthy")
		}
	})

	t.Run("多目标场景", func(t *testing.T) {
		first := unhealthyTarget("http://backend1:8080")
		second := unhealthyTarget("http://backend2:8080")
		hc := NewHealthChecker([]*loadbalance.Target{first, second}, &config.HealthCheckConfig{})

		// Only the first target is marked; the second must stay unhealthy.
		hc.MarkHealthy(first)

		if !first.Healthy.Load() {
			t.Error("target1 应标记为 healthy")
		}
		if second.Healthy.Load() {
			t.Error("target2 应保持 unhealthy")
		}
	})
}
// TestHealthChecker_Run exercises the checker's background loop: the initial
// probe issued by Start, periodic probes on the configured interval, and the
// absence of probes after Stop.
func TestHealthChecker_Run(t *testing.T) {
	// waitFor polls cond until it reports true or timeout elapses, returning the
	// final outcome. Polling with a generous deadline keeps the assertions below
	// independent of scheduler/CI speed, unlike a single fixed sleep.
	waitFor := func(timeout time.Duration, cond func() bool) bool {
		deadline := time.Now().Add(timeout)
		for !cond() {
			if time.Now().After(deadline) {
				return false
			}
			time.Sleep(5 * time.Millisecond)
		}
		return true
	}

	t.Run("初始检查执行", func(t *testing.T) {
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		}))
		defer server.Close()

		target := &loadbalance.Target{URL: server.URL}
		target.Healthy.Store(false)
		checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
			Interval: 1 * time.Hour,
			Timeout:  5 * time.Second,
			Path:     "/health",
		})
		checker.Start()
		// Deferred after server.Close, so it runs first on return.
		defer checker.Stop()

		// With a 1h interval only the initial probe can mark the target healthy.
		if !waitFor(time.Second, func() bool { return target.Healthy.Load() }) {
			t.Error("初始检查后 target 应标记为 healthy")
		}
	})

	t.Run("定时检查执行", func(t *testing.T) {
		var requestCount atomic.Int32
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			requestCount.Add(1)
			w.WriteHeader(http.StatusOK)
		}))
		defer server.Close()

		target := &loadbalance.Target{URL: server.URL}
		target.Healthy.Store(true)
		checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
			Interval: 50 * time.Millisecond,
			Timeout:  5 * time.Second,
			Path:     "/health",
		})
		checker.Start()
		// Expect at least the initial probe plus one interval-driven probe.
		// Wait up to a generous deadline instead of a fixed 120ms window.
		waitFor(2*time.Second, func() bool { return requestCount.Load() >= 2 })
		checker.Stop()

		if requestCount.Load() < 2 {
			t.Errorf("期望至少 2 次检查,实际 %d 次", requestCount.Load())
		}
	})

	t.Run("停止后不再检查", func(t *testing.T) {
		var requestCount atomic.Int32
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			requestCount.Add(1)
			w.WriteHeader(http.StatusOK)
		}))
		defer server.Close()

		target := &loadbalance.Target{URL: server.URL}
		target.Healthy.Store(true)
		checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
			Interval: 50 * time.Millisecond,
			Timeout:  5 * time.Second,
		})
		checker.Start()
		time.Sleep(60 * time.Millisecond)
		checker.Stop()
		countAfterStop := requestCount.Load()

		// Wait for two intervals; the request count must not move after Stop.
		time.Sleep(100 * time.Millisecond)
		if requestCount.Load() != countAfterStop {
			t.Error("停止后不应再执行检查")
		}
	})
}
// TestSelectByFallback verifies that selectByFallback returns a target using
// the algorithm named in BalancerByLua.Fallback, and that the ip_hash fallback
// is stable for a fixed client IP.
func TestSelectByFallback(t *testing.T) {
	t.Run("round_robin fallback", func(t *testing.T) {
		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			BalancerByLua: config.BalancerByLuaConfig{
				Fallback: "round_robin",
			},
			Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
			{URL: "http://backend2:8080"},
		}
		// tgt, not t: avoid shadowing the *testing.T inside the loop.
		for _, tgt := range targets {
			tgt.Healthy.Store(true)
		}
		p, err := NewProxy(cfg, targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}
		ctx := testutil.NewRequestCtx("GET", "/api/test")
		if selected := p.selectByFallback(ctx, targets); selected == nil {
			t.Error("selectByFallback() should return a target")
		}
	})

	t.Run("ip_hash fallback", func(t *testing.T) {
		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			BalancerByLua: config.BalancerByLuaConfig{
				Fallback: "ip_hash",
			},
			Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
			{URL: "http://backend2:8080"},
		}
		for _, tgt := range targets {
			tgt.Healthy.Store(true)
		}
		p, err := NewProxy(cfg, targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}
		ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
			"X-Forwarded-For": "192.168.1.1",
		})
		selected := p.selectByFallback(ctx, targets)
		if selected == nil {
			// Fatal, not Error: the consistency check below dereferences selected.
			t.Fatal("selectByFallback() should return a target for ip_hash")
		}
		// The same client IP must map to the same target.
		selected2 := p.selectByFallback(ctx, targets)
		if selected2 == nil || selected.URL != selected2.URL {
			t.Error("ip_hash should consistently return same target for same IP")
		}
	})
}
// TestSelectByLua verifies that selectByLua reports an error when the balancer
// script is missing or raises a Lua error at runtime.
func TestSelectByLua(t *testing.T) {
	t.Run("有 Lua 引擎但脚本不存在", func(t *testing.T) {
		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			BalancerByLua: config.BalancerByLuaConfig{
				Enabled: true,
				Script:  "/nonexistent/script.lua",
			},
			Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
		}
		targets[0].Healthy.Store(true)
		luaEngine, err := lua.NewEngine(nil)
		if err != nil {
			t.Fatalf("NewEngine() error: %v", err)
		}
		p, err := NewProxy(cfg, targets, nil, luaEngine)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}
		ctx := testutil.NewRequestCtx("GET", "/api/test")
		if _, err = p.selectByLua(ctx, targets); err == nil {
			t.Error("selectByLua() should return error for nonexistent script")
		}
	})

	t.Run("Lua 引擎正常工作但脚本返回错误", func(t *testing.T) {
		// Create the script under t.TempDir(), which the testing package
		// removes automatically when the subtest finishes.
		tmpFile, err := os.CreateTemp(t.TempDir(), "test_*.lua")
		if err != nil {
			t.Fatalf("创建临时文件失败: %v", err)
		}
		// The script raises a Lua error as soon as it runs.
		if _, err := tmpFile.WriteString("error('test error')"); err != nil {
			t.Fatalf("写入临时文件失败: %v", err)
		}
		// Release the file handle; the original leaked it for the whole test
		// (an open handle also prevents removal on Windows).
		if err := tmpFile.Close(); err != nil {
			t.Fatalf("关闭临时文件失败: %v", err)
		}

		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			BalancerByLua: config.BalancerByLuaConfig{
				Enabled: true,
				Script:  tmpFile.Name(),
			},
			Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
		}
		targets[0].Healthy.Store(true)
		luaEngine, err := lua.NewEngine(nil)
		if err != nil {
			t.Fatalf("NewEngine() error: %v", err)
		}
		p, err := NewProxy(cfg, targets, nil, luaEngine)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}
		ctx := testutil.NewRequestCtx("GET", "/api/test")
		// A script execution error must be surfaced to the caller.
		if _, err = p.selectByLua(ctx, targets); err == nil {
			t.Error("selectByLua() should return error for script error")
		}
	})
}
// TestRewriteCookies verifies that modifyResponseHeaders rewrites the Domain
// and Path attributes of Set-Cookie according to CookieDomain / CookiePath.
// Cookies are left alone when there is nothing to rewrite. When a rewrite
// happens, the old attribute value must no longer appear.
func TestRewriteCookies(t *testing.T) {
	tests := []struct {
		name            string
		cookies         []string
		cookieDomain    string
		cookiePath      string
		wantContains    []string
		wantNotContains []string
	}{
		{
			name:            "改写 Domain",
			cookies:         []string{"session=abc123; Domain=old.example.com; Path=/"},
			cookieDomain:    "new.example.com",
			wantContains:    []string{"Domain=new.example.com"},
			wantNotContains: []string{"Domain=old.example.com"},
		},
		{
			name:            "改写 Path",
			cookies:         []string{"session=abc123; Domain=example.com; Path=/old/"},
			cookiePath:      "/new/",
			wantContains:    []string{"Path=/new/"},
			wantNotContains: []string{"Path=/old/"},
		},
		{
			name:            "同时改写 Domain 和 Path",
			cookies:         []string{"session=abc123; Domain=old.example.com; Path=/old/"},
			cookieDomain:    "new.example.com",
			cookiePath:      "/new/",
			wantContains:    []string{"Domain=new.example.com", "Path=/new/"},
			wantNotContains: []string{"Domain=old.example.com", "Path=/old/"},
		},
		{
			name:         "无 Domain 属性时不改写",
			cookies:      []string{"session=abc123"},
			cookiePath:   "/new/",
			wantContains: []string{"session=abc123"},
		},
		{
			name:         "空配置不改写",
			cookies:      []string{"session=abc123; Domain=example.com"},
			cookieDomain: "",
			cookiePath:   "",
			wantContains: []string{"Domain=example.com"},
		},
		{
			name:            "大小写不敏感匹配",
			cookies:         []string{"session=abc123; domain=old.example.com; path=/old/"},
			cookieDomain:    "new.example.com",
			cookiePath:      "/new/",
			wantContains:    []string{"domain=new.example.com", "path=/new/"},
			wantNotContains: []string{"domain=old.example.com", "path=/old/"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: "round_robin",
				Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
				Headers: config.ProxyHeaders{
					CookieDomain: tt.cookieDomain,
					CookiePath:   tt.cookiePath,
				},
			}
			targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
			p, err := NewProxy(cfg, targets, nil, nil)
			if err != nil {
				t.Fatalf("NewProxy() error: %v", err)
			}

			ctx := testutil.NewRequestCtx("GET", "/api/test")
			for _, cookie := range tt.cookies {
				ctx.Response.Header.Set("Set-Cookie", cookie)
			}
			p.modifyResponseHeaders(ctx)

			// A plain substring check on the whole header value covers every
			// expectation; the previous per-segment split was redundant.
			cookieStr := string(ctx.Response.Header.Peek("Set-Cookie"))
			for _, want := range tt.wantContains {
				if !strings.Contains(cookieStr, want) {
					t.Errorf("cookie 应包含 %q, 实际: %q", want, cookieStr)
				}
			}
			for _, notWant := range tt.wantNotContains {
				if strings.Contains(cookieStr, notWant) {
					t.Errorf("cookie 不应包含 %q, 实际: %q", notWant, cookieStr)
				}
			}
		})
	}
}
// TestRewriteCookieAttr checks rewriteCookieAttr on single cookie strings. It
// covers replacing an existing attribute's value, matching the attribute name
// case-insensitively, and leaving a cookie unchanged when the attribute is absent.
func TestRewriteCookieAttr(t *testing.T) {
	type attrCase struct {
		name     string
		cookie   string
		attr     string
		newValue string
		want     string
	}

	cases := []attrCase{
		{
			name: "改写 Domain", attr: "Domain", newValue: "new.com",
			cookie: "session=abc; Domain=old.com; Path=/",
			want:   "session=abc; Domain=new.com; Path=/",
		},
		{
			name: "改写 Path", attr: "Path", newValue: "/new",
			cookie: "session=abc; Domain=example.com; Path=/old",
			want:   "session=abc; Domain=example.com; Path=/new",
		},
		{
			name: "属性不存在则不改写", attr: "Domain", newValue: "new.com",
			cookie: "session=abc",
			want:   "session=abc",
		},
		{
			// The attribute name in the cookie keeps its original casing.
			name: "大小写不敏感", attr: "Domain", newValue: "new.com",
			cookie: "session=abc; domain=old.com",
			want:   "session=abc; domain=new.com",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			if got := rewriteCookieAttr(tc.cookie, tc.attr, tc.newValue); got != tc.want {
				t.Errorf("rewriteCookieAttr() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestModifyResponseHeaders_PassResponse 测试 PassResponse 白名单模式。
func TestModifyResponseHeaders_PassResponse(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Headers: config.ProxyHeaders{
PassResponse: []string{"Content-Type", "X-Allowed"},
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
ctx.Response.Header.Set("Content-Type", "application/json")
ctx.Response.Header.Set("X-Allowed", "allowed-value")
ctx.Response.Header.Set("X-Blocked", "blocked-value")
p.modifyResponseHeaders(ctx)
// 白名单中的头应保留
if string(ctx.Response.Header.Peek("Content-Type")) != "application/json" {
t.Error("Content-Type 应被保留")
}
if string(ctx.Response.Header.Peek("X-Allowed")) != "allowed-value" {
t.Error("X-Allowed 应被保留")
}
// 不在白名单中的头应被删除
if len(ctx.Response.Header.Peek("X-Blocked")) > 0 {
t.Error("X-Blocked 应被删除")
}
}
// TestModifyResponseHeaders_HideResponse 测试 HideResponse 功能。
func TestModifyResponseHeaders_HideResponse(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Headers: config.ProxyHeaders{
HideResponse: []string{"X-Hidden-Header"},
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
ctx.Response.Header.Set("X-Hidden-Header", "should-be-hidden")
ctx.Response.Header.Set("X-Visible-Header", "should-be-visible")
p.modifyResponseHeaders(ctx)
if len(ctx.Response.Header.Peek("X-Hidden-Header")) > 0 {
t.Error("X-Hidden-Header 应被删除")
}
if string(ctx.Response.Header.Peek("X-Visible-Header")) != "should-be-visible" {
t.Error("X-Visible-Header 应被保留")
}
}
// TestModifyResponseHeaders_IgnoreHeaders 测试 IgnoreHeaders 功能。
func TestModifyResponseHeaders_IgnoreHeaders(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Headers: config.ProxyHeaders{
IgnoreHeaders: []string{"X-Ignored"},
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
ctx.Request.Header.Set("X-Ignored", "ignored-value")
ctx.Response.Header.Set("X-Ignored", "ignored-response-value")
ctx.Response.Header.Set("X-Not-Ignored", "not-ignored")
p.modifyResponseHeaders(ctx)
if len(ctx.Request.Header.Peek("X-Ignored")) > 0 {
t.Error("请求中的 X-Ignored 应被删除")
}
if len(ctx.Response.Header.Peek("X-Ignored")) > 0 {
t.Error("响应中的 X-Ignored 应被删除")
}
if string(ctx.Response.Header.Peek("X-Not-Ignored")) != "not-ignored" {
t.Error("X-Not-Ignored 应被保留")
}
}
// TestCreateHostClient_TransportConfig verifies that TransportConfig settings
// reach the client: IdleConnTimeout becomes MaxIdleConnDuration and
// MaxConnsPerHost becomes MaxConns.
func TestCreateHostClient_TransportConfig(t *testing.T) {
	const (
		wantIdle  = 60 * time.Second
		wantConns = 50
	)

	timeouts := config.ProxyTimeout{
		Connect: 5 * time.Second,
		Read:    30 * time.Second,
		Write:   30 * time.Second,
	}
	transport := &config.TransportConfig{
		IdleConnTimeout: wantIdle,
		MaxConnsPerHost: wantConns,
	}

	client := createHostClient("http://localhost:8080", timeouts, transport, nil, "", nil)
	if client == nil {
		t.Fatal("createHostClient() returned nil")
	}

	if got := client.MaxIdleConnDuration; got != wantIdle {
		t.Errorf("MaxIdleConnDuration = %v, want 60s", got)
	}
	if got := client.MaxConns; got != wantConns {
		t.Errorf("MaxConns = %d, want 50", got)
	}
}
// TestCreateHostClient_Buffering verifies how ProxyBufferingConfig maps onto
// the client. Mode "off" enables response body streaming, and BufferSize sets
// both the read and write buffer sizes.
func TestCreateHostClient_Buffering(t *testing.T) {
	t.Run("streaming mode", func(t *testing.T) {
		buffering := &config.ProxyBufferingConfig{
			Mode: "off",
		}
		client := createHostClient("http://localhost:8080", config.ProxyTimeout{}, nil, nil, "", buffering)
		// Guard before dereferencing: a nil client would otherwise panic and
		// obscure the real failure.
		if client == nil {
			t.Fatal("createHostClient() returned nil")
		}
		if !client.StreamResponseBody {
			t.Error("StreamResponseBody should be true when buffering is off")
		}
	})

	t.Run("custom buffer size", func(t *testing.T) {
		buffering := &config.ProxyBufferingConfig{
			BufferSize: 64 * 1024,
		}
		client := createHostClient("http://localhost:8080", config.ProxyTimeout{}, nil, nil, "", buffering)
		if client == nil {
			t.Fatal("createHostClient() returned nil")
		}
		if client.ReadBufferSize != 64*1024 {
			t.Errorf("ReadBufferSize = %d, want 64KB", client.ReadBufferSize)
		}
		if client.WriteBufferSize != 64*1024 {
			t.Errorf("WriteBufferSize = %d, want 64KB", client.WriteBufferSize)
		}
	})
}
// TestCreateHostClient_ProxyBind verifies that passing a ProxyBind address
// does not panic and results in a custom Dial function on the client.
func TestCreateHostClient_ProxyBind(t *testing.T) {
	client := createHostClient("http://localhost:8080", config.ProxyTimeout{
		Connect: 5 * time.Second,
	}, nil, nil, "127.0.0.1", nil)
	// Fatal, not Error: the next check dereferences client, and a nil client
	// previously caused a panic here instead of a clean failure.
	if client == nil {
		t.Fatal("createHostClient() returned nil")
	}
	if client.Dial == nil {
		t.Error("Dial should be set when ProxyBind is specified")
	}
}
// TestNewRedirectRewriter_RegexRules covers regex rules in custom mode:
//   - "~" (case-sensitive) rewrites the Location header of a 301 response;
//   - "~*" (case-insensitive) does the same;
//   - an invalid regex is rejected at construction time.
func TestNewRedirectRewriter_RegexRules(t *testing.T) {
	// rewriteLocation builds a custom-mode rewriter from rules and feeds it a 301
	// response carrying location. It returns the Location header afterwards.
	rewriteLocation := func(t *testing.T, rules []config.RedirectRewriteRule, location string) string {
		t.Helper()
		rw, err := NewRedirectRewriter(&config.RedirectRewriteConfig{Mode: "custom", Rules: rules}, "/")
		if err != nil {
			t.Fatalf("NewRedirectRewriter() error: %v", err)
		}
		ctx := testutil.NewRequestCtx("GET", "/")
		resp := &fasthttp.Response{}
		resp.Header.Set("Location", location)
		resp.SetStatusCode(301)
		rw.RewriteResponse(resp, ctx, "", "frontend")
		return string(resp.Header.Peek("Location"))
	}

	t.Run("正则模式", func(t *testing.T) {
		rules := []config.RedirectRewriteRule{
			{Pattern: `~http://backend:\d+`, Replacement: "http://frontend"},
		}
		got := rewriteLocation(t, rules, "http://backend:8080/api")
		if want := "http://frontend/api"; got != want {
			t.Errorf("Location = %q, want %q", got, want)
		}
	})

	t.Run("大小写不敏感正则", func(t *testing.T) {
		// Per the original author's note, the implementation lowercases the
		// input for "~*" patterns, so the pattern is written in lowercase while
		// the Location uses uppercase.
		rules := []config.RedirectRewriteRule{
			{Pattern: "~*backend", Replacement: "frontend"},
		}
		got := rewriteLocation(t, rules, "http://BACKEND/api")
		if want := "http://frontend/api"; got != want {
			t.Errorf("Location = %q, want %q", got, want)
		}
	})

	t.Run("无效正则返回错误", func(t *testing.T) {
		// "[invalid" is an unterminated character class.
		cfg := &config.RedirectRewriteConfig{
			Mode:  "custom",
			Rules: []config.RedirectRewriteRule{{Pattern: "~[invalid", Replacement: "/"}},
		}
		if _, err := NewRedirectRewriter(cfg, "/"); err == nil {
			t.Error("NewRedirectRewriter() should return error for invalid regex")
		}
	})
}
// TestRedirectRewriter_RewriteRefreshOnly verifies that RewriteRefreshOnly
// rewrites the URL embedded in a Refresh header on a non-3xx response.
func TestRedirectRewriter_RewriteRefreshOnly(t *testing.T) {
	rule := config.RedirectRewriteRule{
		Pattern:     "http://backend:8080/",
		Replacement: "http://frontend/",
	}
	rw, err := NewRedirectRewriter(&config.RedirectRewriteConfig{
		Mode:  "custom",
		Rules: []config.RedirectRewriteRule{rule},
	}, "/")
	if err != nil {
		t.Fatalf("NewRedirectRewriter() error: %v", err)
	}

	ctx := testutil.NewRequestCtx("GET", "/")
	var resp fasthttp.Response
	resp.Header.Set("Refresh", "5; url=http://backend:8080/page")
	resp.SetStatusCode(200) // deliberately not a redirect status

	rw.RewriteRefreshOnly(&resp, ctx, "", "frontend")

	const want = "5; url=http://frontend/page"
	if got := string(resp.Header.Peek("Refresh")); got != want {
		t.Errorf("Refresh = %q, want %q", got, want)
	}
}
// TestRewriteCustom exercises custom-mode URL rewriting through rewriteURL. It
// covers a regex rule, a plain prefix rule, and a rule that does not match.
func TestRewriteCustom(t *testing.T) {
	cases := []struct {
		name        string
		rule        config.RedirectRewriteRule
		input, want string
	}{
		{
			// Per the original note, rewriteCustom replaces only the matched
			// text and does not expand capture groups.
			name:  "正则替换",
			rule:  config.RedirectRewriteRule{Pattern: `~http://[a-z]+:\d+`, Replacement: "https://new.example.com"},
			input: "http://backend:8080/api/users",
			want:  "https://new.example.com/api/users",
		},
		{
			name:  "精确前缀匹配",
			rule:  config.RedirectRewriteRule{Pattern: "http://old.example.com/", Replacement: "http://new.example.com/"},
			input: "http://old.example.com/page",
			want:  "http://new.example.com/page",
		},
		{
			name:  "无匹配则原样返回",
			rule:  config.RedirectRewriteRule{Pattern: "http://other.com/", Replacement: "/"},
			input: "http://example.com/page",
			want:  "http://example.com/page",
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			rw, err := NewRedirectRewriter(&config.RedirectRewriteConfig{
				Mode:  "custom",
				Rules: []config.RedirectRewriteRule{tc.rule},
			}, "/")
			if err != nil {
				t.Fatalf("NewRedirectRewriter() error: %v", err)
			}
			ctx := testutil.NewRequestCtx("GET", "/")
			if got := rw.rewriteURL(tc.input, ctx, "", "frontend"); got != tc.want {
				t.Errorf("rewriteURL() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestSelectTarget_EmptyTargets 测试空目标列表。
func TestSelectTarget_EmptyTargets(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 清空目标
p.mu.Lock()
p.targets = nil
p.mu.Unlock()
ctx := testutil.NewRequestCtx("GET", "/api/test")
selected := p.selectTarget(ctx)
if selected != nil {
t.Error("selectTarget() should return nil for empty targets")
}
}
// TestDialTarget checks dialTarget against unreachable HTTP/HTTPS addresses
// (which must fail) and against a local TCP listener (which must succeed).
func TestDialTarget(t *testing.T) {
	// 10.255.255.1 is assumed to be unroutable in the test environment, so the
	// dial should fail within the short timeout.
	failing := []struct{ name, rawURL, msg string }{
		{"连接超时", "http://10.255.255.1:9999", "dialTarget() should return error for unreachable address"},
		{"HTTPS 连接失败", "https://10.255.255.1:9999", "dialTarget() should return error for unreachable HTTPS address"},
	}
	for _, fc := range failing {
		fc := fc
		t.Run(fc.name, func(t *testing.T) {
			if _, err := dialTarget(fc.rawURL, 100*time.Millisecond); err == nil {
				t.Error(fc.msg)
			}
		})
	}

	t.Run("成功连接", func(t *testing.T) {
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		if err != nil {
			t.Fatalf("Failed to listen: %v", err)
		}
		defer ln.Close()

		// Accept exactly one connection and close it immediately.
		go func() {
			if c, _ := ln.Accept(); c != nil {
				_ = c.Close()
			}
		}()

		conn, err := dialTarget("http://"+ln.Addr().String(), 1*time.Second)
		if err != nil {
			t.Errorf("dialTarget() error: %v", err)
		}
		if conn != nil {
			_ = conn.Close()
		}
	})
}
// TestBackgroundRefresh_Extra covers additional backgroundRefresh paths:
// returning early when the target has no client, and a refresh against an
// in-memory backend (the subtest name refers to releasing the cache lock).
func TestBackgroundRefresh_Extra(t *testing.T) {
	// cachedCfg returns a round-robin proxy config with a 10s response cache.
	cachedCfg := func() *config.ProxyConfig {
		return &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			Cache: config.ProxyCacheConfig{
				Enabled: true,
				MaxAge:  10 * time.Second,
			},
		}
	}
	const (
		refreshHash = uint64(12345)
		refreshKey  = "GET:/api/test"
	)

	t.Run("客户端不存在时直接返回", func(t *testing.T) {
		targets := []*loadbalance.Target{{URL: "http://nonexistent:9999"}}
		targets[0].Healthy.Store(true)
		p, err := NewProxy(cachedCfg(), targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}

		// Remove the target's client so backgroundRefresh finds none.
		p.mu.Lock()
		delete(p.clients, targets[0].URL)
		p.mu.Unlock()

		rctx := testutil.NewRequestCtx("GET", "/api/test")
		// The call must simply return without panicking.
		p.backgroundRefresh(&rctx.Request, targets[0], refreshHash, refreshKey)
	})

	t.Run("缓存锁释放", func(t *testing.T) {
		ln := fasthttputil.NewInmemoryListener()
		defer ln.Close()
		srv := &fasthttp.Server{
			Handler: func(ctx *fasthttp.RequestCtx) {
				ctx.SetStatusCode(200)
				ctx.SetBodyString("refreshed")
			},
		}
		go func() { _ = srv.Serve(ln) }()
		time.Sleep(10 * time.Millisecond)

		// NOTE(review): ln.Addr() of an in-memory listener does not look like a
		// dialable network address. Unless the proxy's client is wired to
		// ln.Dial, this refresh may never reach srv. The subtest also asserts
		// nothing beyond "does not panic/hang" — consider checking the lock
		// state afterwards.
		targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
		targets[0].Healthy.Store(true)
		p, err := NewProxy(cachedCfg(), targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}

		rctx := testutil.NewRequestCtx("GET", "/api/test")
		p.backgroundRefresh(&rctx.Request, targets[0], refreshHash, refreshKey)
	})
}
// TestWebSocket_ErrorCases verifies that WebSocket reports an error when the
// backend cannot be reached.
func TestWebSocket_ErrorCases(t *testing.T) {
	t.Run("连接无效后端", func(t *testing.T) {
		upgradeHeaders := map[string]string{
			"Upgrade":    "websocket",
			"Connection": "Upgrade",
		}
		ctx := testutil.NewRequestCtxWithHeader("GET", "/ws", upgradeHeaders)

		// Port 1 on loopback is assumed to have no listener.
		backend := &loadbalance.Target{URL: "http://127.0.0.1:1"}
		backend.Healthy.Store(true)

		// A tiny timeout keeps the failing dial fast.
		if err := WebSocket(ctx, backend, 10*time.Millisecond, nil); err == nil {
			t.Error("WebSocket() should return error for invalid backend")
		}
	})
}
// TestDialTarget_TLS_Extra checks that dialTarget fails for an HTTPS target
// that cannot be reached.
func TestDialTarget_TLS_Extra(t *testing.T) {
	t.Run("TLS 握手失败", func(t *testing.T) {
		// 10.255.255.1 is assumed to be unroutable in the test environment.
		const unreachable = "https://10.255.255.1:9999"
		if _, err := dialTarget(unreachable, 100*time.Millisecond); err == nil {
			t.Error("dialTarget() should return error for unreachable HTTPS address")
		}
	})
}
// TestCreateHostClient_SSL verifies that an HTTPS target with SSL enabled gets
// a non-nil TLSConfig on the created client.
func TestCreateHostClient_SSL(t *testing.T) {
	t.Run("启用 SSL 验证", func(t *testing.T) {
		sslCfg := &config.ProxySSLConfig{
			Enabled:            true,
			InsecureSkipVerify: false,
		}
		client := createHostClient("https://example.com:443", config.ProxyTimeout{
			Connect: 5 * time.Second,
		}, nil, sslCfg, "", nil)
		// Fatal, not Error: the next check dereferences client and would panic
		// on nil.
		if client == nil {
			t.Fatal("createHostClient() returned nil")
		}
		if client.TLSConfig == nil {
			t.Error("TLSConfig should be set for HTTPS target")
		}
	})
}
// TestBackgroundRefresh_Revalidate calls backgroundRefresh with cache
// revalidation enabled while no cache entry exists for the key.
func TestBackgroundRefresh_Revalidate(t *testing.T) {
	t.Run("Revalidate 启用但无缓存条目", func(t *testing.T) {
		ln := fasthttputil.NewInmemoryListener()
		defer ln.Close()
		srv := &fasthttp.Server{
			Handler: func(ctx *fasthttp.RequestCtx) {
				ctx.SetStatusCode(200)
				ctx.SetBodyString("refreshed")
			},
		}
		go func() { _ = srv.Serve(ln) }()
		time.Sleep(10 * time.Millisecond)

		// NOTE(review): as in TestBackgroundRefresh_Extra, ln.Addr() of an
		// in-memory listener may not be dialable by the proxy's client, so srv
		// might never be hit. Confirm the intended wiring.
		backend := &loadbalance.Target{URL: "http://" + ln.Addr().String()}
		backend.Healthy.Store(true)
		targets := []*loadbalance.Target{backend}

		p, err := NewProxy(&config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			Cache: config.ProxyCacheConfig{
				Enabled:    true,
				MaxAge:     10 * time.Second,
				Revalidate: true,
			},
		}, targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}

		rctx := testutil.NewRequestCtx("GET", "/api/test")
		// No cache entry has been stored under this key.
		p.backgroundRefresh(&rctx.Request, backend, uint64(12345), "GET:/api/test")
	})
}
// TestSelectByBalancer 测试 selectByBalancer 方法。
func TestSelectByBalancer(t *testing.T) {
t.Run("IPHash 负载均衡", func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "ip_hash",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
}
for _, t := range targets {
t.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 使用不同 IP 的请求应选择不同目标
ctx1 := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
"X-Forwarded-For": "192.168.1.1",
})
ctx2 := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
"X-Forwarded-For": "192.168.1.2",
})
selected1 := p.selectByBalancer(ctx1, targets)
selected2 := p.selectByBalancer(ctx2, targets)
if selected1 == nil || selected2 == nil {
t.Error("selectByBalancer() should return a target")
}
})
t.Run("ConsistentHash 负载均衡", func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "consistent_hash",
VirtualNodes: 100,
HashKey: "uri",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
}
for _, t := range targets {
t.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/users/123")
selected := p.selectByBalancer(ctx, targets)
if selected == nil {
t.Error("selectByBalancer() should return a target for consistent_hash")
}
// 相同 URI 应选择相同目标
selected2 := p.selectByBalancer(ctx, targets)
if selected2 == nil || selected.URL != selected2.URL {
t.Error("consistent_hash should return same target for same URI")
}
})
t.Run("ConsistentHash with header key", func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "consistent_hash",
VirtualNodes: 100,
HashKey: "header:X-User-Id",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
}
for _, t := range targets {
t.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
"X-User-Id": "user-123",
})
selected := p.selectByBalancer(ctx, targets)
if selected == nil {
t.Error("selectByBalancer() should return a target for header-based hash")
}
})
t.Run("RoundRobin 负载均衡", func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
}
for _, t := range targets {
t.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
selected := p.selectByBalancer(ctx, targets)
if selected == nil {
t.Error("selectByBalancer() should return a target for round_robin")
}
})
}
// TestSelectTargetExcluding_Extra verifies that selectTargetExcluding never
// returns an excluded target, and returns nil when every target is excluded.
func TestSelectTargetExcluding_Extra(t *testing.T) {
	t.Run("排除已失败目标", func(t *testing.T) {
		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
			{URL: "http://backend2:8080"},
			{URL: "http://backend3:8080"},
		}
		// tgt, not t: avoid shadowing the *testing.T inside the loop.
		for _, tgt := range targets {
			tgt.Healthy.Store(true)
		}
		p, err := NewProxy(cfg, targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}

		ctx := testutil.NewRequestCtx("GET", "/api/test")
		// Exclude the first target; one of the remaining two must be chosen.
		excluded := []*loadbalance.Target{targets[0]}
		selected := p.selectTargetExcluding(ctx, excluded)
		if selected == nil {
			t.Fatal("selectTargetExcluding() should return a target")
		}
		if selected.URL == targets[0].URL {
			t.Error("selectTargetExcluding() should not return excluded target")
		}
	})

	t.Run("排除所有目标返回 nil", func(t *testing.T) {
		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
		}
		targets[0].Healthy.Store(true)
		p, err := NewProxy(cfg, targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}

		ctx := testutil.NewRequestCtx("GET", "/api/test")
		// The only target is excluded, so nothing is left to choose.
		excluded := []*loadbalance.Target{targets[0]}
		if selected := p.selectTargetExcluding(ctx, excluded); selected != nil {
			t.Error("selectTargetExcluding() should return nil when all targets excluded")
		}
	})
}
// TestExtractHashKey_Extra checks extractHashKey for the "ip", "uri" and
// "header:<name>" key sources. It also checks the fallback to the client IP
// when the named header is absent.
func TestExtractHashKey_Extra(t *testing.T) {
	cases := []struct {
		name     string
		hashKey  string
		path     string
		headers  map[string]string // nil: build the request without extra headers
		want     string
		fallback bool // report failures with the "fallback to IP" message
	}{
		{
			name:    "使用 IP 作为 hash key",
			hashKey: "ip",
			path:    "/api/test",
			headers: map[string]string{"X-Forwarded-For": "10.0.0.1"},
			want:    "10.0.0.1",
		},
		{
			name:    "使用 URI 作为 hash key",
			hashKey: "uri",
			path:    "/api/users/123",
			want:    "/api/users/123",
		},
		{
			name:    "使用 header 作为 hash key",
			hashKey: "header:X-Session-Id",
			path:    "/api/test",
			headers: map[string]string{"X-Session-Id": "session-abc-123"},
			want:    "session-abc-123",
		},
		{
			name:     "header 不存在时回退到 IP",
			hashKey:  "header:X-Nonexistent",
			path:     "/api/test",
			headers:  map[string]string{"X-Forwarded-For": "10.0.0.5"},
			want:     "10.0.0.5",
			fallback: true,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			cfg := &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: "consistent_hash",
				HashKey:     tc.hashKey,
				Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			}
			p, err := NewProxy(cfg, []*loadbalance.Target{{URL: "http://localhost:8080"}}, nil, nil)
			if err != nil {
				t.Fatalf("NewProxy() error: %v", err)
			}

			var key string
			if tc.headers == nil {
				key = p.extractHashKey(testutil.NewRequestCtx("GET", tc.path), tc.hashKey)
			} else {
				key = p.extractHashKey(testutil.NewRequestCtxWithHeader("GET", tc.path, tc.headers), tc.hashKey)
			}

			if key == tc.want {
				return
			}
			if tc.fallback {
				t.Errorf("extractHashKey() should fallback to IP, got %q", key)
			} else {
				t.Errorf("extractHashKey() = %q, want %q", key, tc.want)
			}
		})
	}
}
// TestSelectTarget_LuaEnabled tests selectTarget when balancer_by_lua is
// enabled but the Lua path cannot be used (no Lua engine passed to NewProxy,
// or an empty script). In both cases a non-nil target is expected from the
// configured traditional load-balancing algorithm.
func TestSelectTarget_LuaEnabled(t *testing.T) {
	t.Run("Lua 引擎为 nil 时使用传统算法", func(t *testing.T) {
		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			BalancerByLua: config.BalancerByLuaConfig{
				Enabled: true,
				Script:  "/nonexistent.lua",
			},
			Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
			{URL: "http://backend2:8080"},
		}
		// Use a distinct loop variable name so the subtest's *testing.T is
		// not shadowed inside the loop.
		for _, target := range targets {
			target.Healthy.Store(true)
		}

		// The Lua engine argument is nil, so selection must not go through Lua.
		p, err := NewProxy(cfg, targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}

		ctx := testutil.NewRequestCtx("GET", "/api/test")
		if selected := p.selectTarget(ctx); selected == nil {
			t.Error("selectTarget() should return a target using fallback")
		}
	})

	t.Run("Lua 脚本为空时使用传统算法", func(t *testing.T) {
		cfg := &config.ProxyConfig{
			Path:        "/api",
			LoadBalance: "round_robin",
			BalancerByLua: config.BalancerByLuaConfig{
				Enabled: true,
				Script:  "",
			},
			Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
		}
		targets := []*loadbalance.Target{
			{URL: "http://backend1:8080"},
		}
		targets[0].Healthy.Store(true)

		// A real engine is supplied here; only the script path is empty.
		// NOTE(review): if lua.Engine exposes a Close method, release it via
		// t.Cleanup — confirm against the lua package.
		luaEngine, err := lua.NewEngine(nil)
		if err != nil {
			t.Fatalf("NewEngine() error: %v", err)
		}

		p, err := NewProxy(cfg, targets, nil, luaEngine)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}

		ctx := testutil.NewRequestCtx("GET", "/api/test")
		if selected := p.selectTarget(ctx); selected == nil {
			t.Error("selectTarget() should return a target using traditional balancer")
		}
	})
}
// TestBackgroundRefresh_304 tests backgroundRefresh against an in-memory
// upstream. The upstream answers any request carrying If-Modified-Since or
// If-None-Match with 304 Not Modified plus new validators, and answers any
// other request with 200 and a fresh body.
//
// NOTE(review): apart from NewProxy succeeding, nothing is asserted after
// backgroundRefresh returns. Consider checking the state of p.cache for the
// entry once the intended post-304 behavior is pinned down.
func TestBackgroundRefresh_304(t *testing.T) {
	ln := fasthttputil.NewInmemoryListener()

	s := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			// Treat the request as conditional if either validator header is present.
			if len(ctx.Request.Header.Peek("If-Modified-Since")) > 0 ||
				len(ctx.Request.Header.Peek("If-None-Match")) > 0 {
				ctx.SetStatusCode(fasthttp.StatusNotModified)
				ctx.Response.Header.Set("Last-Modified", "Wed, 21 Oct 2015 07:28:00 GMT")
				ctx.Response.Header.Set("ETag", "\"abc123\"")
				return
			}
			ctx.SetStatusCode(fasthttp.StatusOK)
			ctx.SetBodyString("fresh content")
		},
	}

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		_ = s.Serve(ln) // returns once ln is closed
	}()
	// Tie the server goroutine's lifetime to the test. No startup sleep is
	// needed: a Dial on an InmemoryListener only returns after the server has
	// accepted the connection.
	t.Cleanup(func() {
		_ = ln.Close()
		select {
		case <-serveDone:
		case <-time.After(time.Second):
			t.Error("in-memory upstream server did not stop after listener close")
		}
	})

	cfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
		Cache: config.ProxyCacheConfig{
			Enabled:    true,
			MaxAge:     10 * time.Second,
			Revalidate: true,
		},
	}
	targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
	targets[0].Healthy.Store(true)

	p, err := NewProxy(cfg, targets, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	// Pre-populate a cache entry carrying stale validators, so the refresh
	// has Last-Modified / ETag values available for a conditional request.
	ctx := testutil.NewRequestCtx("GET", "/api/test")
	hashKey, origKey := p.buildCacheKeyHash(ctx)
	p.cache.Set(hashKey, origKey, []byte("cached"), map[string]string{
		"Last-Modified": "Tue, 20 Oct 2015 07:28:00 GMT",
		"ETag":          "\"old\"",
	}, 200, 10*time.Second)

	p.backgroundRefresh(&ctx.Request, targets[0], hashKey, origKey)
}