diff --git a/internal/middleware/security/access_coverage_test.go b/internal/middleware/security/access_coverage_test.go new file mode 100644 index 0000000..1595bdb --- /dev/null +++ b/internal/middleware/security/access_coverage_test.go @@ -0,0 +1,340 @@ +// Package security 提供访问控制覆盖测试。 +// +// 该文件补充测试 access.go 中未覆盖的方法: +// - Name() 方法 +// - Process() 完整处理链(允许/拒绝路径) +// - getClientIP() 通过 Process 间接测试 +// - Close() 方法 +// - actionToString 边缘情况 +// - trustedProxies 相关逻辑 +// +// 作者:xfy +package security + +import ( + "net" + "testing" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" +) + +// TestAccessControlName 测试 Name 方法 +func TestAccessControlName(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + name := ac.Name() + if name != "access_control" { + t.Errorf("Name() = %q, want 'access_control'", name) + } +} + +// TestAccessControlProcess_AllowPath 测试 Process 允许路径 +func TestAccessControlProcess_AllowPath(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + _, _ = ctx.WriteString("allowed") + } + + handler := ac.Process(nextHandler) + if handler == nil { + t.Fatal("Process() returned nil") + } + + ctx := &fasthttp.RequestCtx{} + handler(ctx) + + if !called { + t.Error("Process() should call next handler when access allowed") + } + if string(ctx.Response.Body()) != "allowed" { + t.Errorf("Process() body = %q, want 'allowed'", string(ctx.Response.Body())) + } +} + +// TestAccessControlProcess_DenyPath 测试 Process 拒绝路径 +func TestAccessControlProcess_DenyPath(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "deny", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + } + + handler := ac.Process(nextHandler) + ctx := &fasthttp.RequestCtx{} + handler(ctx) + + if called { + t.Error("Process() should NOT call next handler when access denied") + } + if ctx.Response.StatusCode() != fasthttp.StatusForbidden { + t.Errorf("Process() status = %d, want 403", ctx.Response.StatusCode()) + } +} + +// TestAccessControlProcess_ExplicitAllow 测试显式允许列表 +func TestAccessControlProcess_ExplicitAllow(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + Default: "deny", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + } + + handler := ac.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}) + handler(ctx) + + if !called { + t.Error("Process() should call next handler for allowed IP") + } +} + +// TestAccessControlProcess_ExplicitDeny 测试显式拒绝列表 +func TestAccessControlProcess_ExplicitDeny(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Deny: []string{"10.0.0.1"}, + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + } + + handler := ac.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345}) + handler(ctx) + + if called { + t.Error("Process() should NOT call next handler for denied IP") + } +} + +// TestAccessControlProcess_TrustedProxies_XFF 测试可信代理 XFF 解析 +func TestAccessControlProcess_TrustedProxies_XFF(t *testing.T) { + // 配置可信代理,10.0.0.0/8 为可信代理段 + ac, err := NewAccessControl(&config.AccessConfig{ + TrustedProxies: []string{"10.0.0.0/8"}, + Default: "deny", + Allow: []string{"192.168.1.0/24"}, + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + } + + handler := ac.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + // 请求来自可信代理 10.0.0.1,XFF 中包含真实客户端 192.168.1.100 + ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345}) + ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1") + handler(ctx) + + if !called { + t.Error("Process() should allow real client IP behind trusted proxy") + } +} + +// TestAccessControlProcess_TrustedProxies_UntrustedSource 测试不可信来源不解析 XFF +func TestAccessControlProcess_TrustedProxies_UntrustedSource(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + TrustedProxies: []string{"10.0.0.0/8"}, + Default: "deny", + Allow: []string{"192.168.1.0/24"}, + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + } + + handler := ac.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + // 请求来自不可信地址,即使 XFF 包含允许列表 IP 也不应解析 + ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 12345}) + ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100") + handler(ctx) + + if called { + t.Error("Process() should not trust XFF from untrusted source") + } +} + +// TestAccessControlProcess_TrustedProxies_XRealIP 测试可信代理 X-Real-IP +func TestAccessControlProcess_TrustedProxies_XRealIP(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + TrustedProxies: []string{"10.0.0.0/8"}, + Default: "deny", + Allow: []string{"192.168.1.0/24"}, + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + } + + handler := ac.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.50"), Port: 12345}) + ctx.Request.Header.Set("X-Real-IP", "192.168.1.50") + handler(ctx) + + if !called { + t.Error("Process() should use X-Real-IP from trusted proxy") + } +} + +// TestAccessControlClose 测试 Close 方法 +func TestAccessControlClose(t *testing.T) { + // 无 GeoIP 的 Close + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + err = ac.Close() + if err != nil { + t.Errorf("Close() error: %v", err) + } +} + +// TestActionToString 测试 actionToString 边缘情况 +func TestActionToString(t *testing.T) { + // 测试 ActionAllow + result := actionToString(ActionAllow) + if result != "allow" { + t.Errorf("actionToString(ActionAllow) = %q, want 'allow'", result) + } + + // 测试 ActionDeny + result = actionToString(ActionDeny) + if result != "deny" { + t.Errorf("actionToString(ActionDeny) = %q, want 'deny'", result) + } + + // 测试未知值 + result = actionToString(Action(999)) + if result != "unknown" { + t.Errorf("actionToString(999) = %q, want 'unknown'", result) + } +} + +// TestGetStatsWithEmpty 测试 GetStats 空列表 +func TestGetStatsWithEmpty(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + stats := ac.GetStats() + if stats.AllowCount != 0 { + t.Errorf("GetStats().AllowCount = %d, want 0", stats.AllowCount) + } + if stats.DenyCount != 0 { + t.Errorf("GetStats().DenyCount = %d, want 0", stats.DenyCount) + } + if stats.Default != "allow" { + t.Errorf("GetStats().Default = %q, want 'allow'", stats.Default) + } +} + +// TestSetDefaultValidCases 测试 SetDefault 所有有效值 +func TestSetDefaultValidCases(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + // 切换为 deny + err = ac.SetDefault("deny") + if err != nil { + t.Errorf("SetDefault('deny') error: %v", err) + } + stats := ac.GetStats() + if stats.Default != "deny" { + t.Errorf("After SetDefault('deny'), Default = %q, want 'deny'", stats.Default) + } + + // 切换回 allow + err = ac.SetDefault("allow") + if err != nil { + t.Errorf("SetDefault('allow') error: %v", err) + } + stats = ac.GetStats() + if stats.Default != "allow" { + t.Errorf("After SetDefault('allow'), Default = %q, want 'allow'", stats.Default) + } + + // 大小写不敏感 + err = ac.SetDefault("DENY") + if err != nil { + t.Errorf("SetDefault('DENY') error: %v", err) + } +} + +// TestUpdateDenyListError 测试 UpdateDenyList 错误路径 +func TestUpdateDenyListError(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + err = ac.UpdateDenyList([]string{"not-an-ip"}) + if err == nil { + t.Error("UpdateDenyList() should return error for invalid CIDR") + } +} diff --git a/internal/proxy/proxy_coverage_test.go b/internal/proxy/proxy_coverage_test.go new file mode 100644 index 0000000..e59fd0b --- /dev/null +++ b/internal/proxy/proxy_coverage_test.go @@ -0,0 +1,720 @@ +// Package proxy 提供反向代理覆盖测试,补充 proxy.go 中未覆盖的方法。 +// +// 该文件测试代理模块的以下功能: +// - selectTargetExcluding 排除已失败目标的选择 +// - extractHashKey 哈希键提取 +// - buildCacheKeyHash / buildCacheKeyHashValue 缓存键计算 +// - writeCachedResponse 缓存响应写入 +// - GetCache / GetCacheStats 缓存访问 +// - getCacheDuration 不同状态码的缓存时间 +// - redirect_rewrite 相关功能 +// +// 作者:xfy +package proxy + +import ( + "testing" + "time" + + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" + "rua.plus/lolly/internal/cache" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/testutil" +) + +// TestSelectTargetExcluding 测试排除失败目标的目标选择 +func TestSelectTargetExcluding(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + {URL: "http://backend3:8080"}, + } + for _, target := range targets { + target.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + // 排除第一个目标,应该选择第二个 + excluded := []*loadbalance.Target{targets[0]} + selected := p.selectTargetExcluding(ctx, excluded) + if selected == nil { + t.Fatal("selectTargetExcluding() returned nil") + } + if selected.URL == "http://backend1:8080" { + t.Error("selectTargetExcluding() should not select excluded target") + } + + // 排除所有目标,应该返回 nil + allExcluded := []*loadbalance.Target{targets[0], targets[1], targets[2]} + selected = p.selectTargetExcluding(ctx, allExcluded) + if selected != nil { + t.Error("selectTargetExcluding() should return nil when all excluded") + } + + // 空目标列表 + p2, _ := NewProxy(cfg, []*loadbalance.Target{{URL: "http://a:1"}}, nil, nil) + p2.targets = nil + selected = p2.selectTargetExcluding(ctx, nil) + if selected != nil { + t.Error("selectTargetExcluding() should return nil for empty targets") + } +} + +// TestSelectTargetExcluding_IPHash 测试 IP Hash 排除选择 +func TestSelectTargetExcluding_IPHash(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "ip_hash", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + {URL: "http://backend3:8080"}, + } + for _, target := range targets { + target.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Forwarded-For": "192.168.1.1", + }) + + // 获取第一次选择 + first := p.selectTarget(ctx) + if first == nil { + t.Fatal("selectTarget() returned nil") + } + + // 排除第一次选择的目标 + excluded := []*loadbalance.Target{first} + second := p.selectTargetExcluding(ctx, excluded) + if second == nil { + t.Fatal("selectTargetExcluding() returned nil") + } + if second.URL == first.URL { + t.Errorf("selectTargetExcluding() should not select same target %s", first.URL) + } +} + +// TestSelectTargetExcluding_ConsistentHash 测试一致性哈希排除选择 +func TestSelectTargetExcluding_ConsistentHash(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + HashKey: "uri", + VirtualNodes: 150, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, target := range targets { + target.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + // 排除一个目标,应该还能选到另一个 + excluded := []*loadbalance.Target{targets[0]} + selected := p.selectTargetExcluding(ctx, excluded) + if selected == nil { + t.Error("selectTargetExcluding() should return remaining target") + } +} + +// TestExtractHashKey 测试哈希键提取 +func TestExtractHashKey(t *testing.T) { + tests := []struct { + name string + hashKey string + headers map[string]string + expected string + }{ + { + name: "ip hash key", + hashKey: "ip", + headers: map[string]string{"X-Forwarded-For": "10.0.0.1"}, + expected: "10.0.0.1", + }, + { + name: "empty hash key defaults to ip", + hashKey: "", + headers: map[string]string{"X-Forwarded-For": "10.0.0.2"}, + expected: "10.0.0.2", + }, + { + name: "uri hash key", + hashKey: "uri", + headers: nil, + expected: "/api/test", + }, + { + name: "header hash key - found", + hashKey: "header:X-Custom-ID", + headers: map[string]string{"X-Custom-ID": "abc123"}, + expected: "abc123", + }, + { + name: "header hash key - fallback to ip", + hashKey: "header:X-Missing", + headers: map[string]string{"X-Forwarded-For": "10.0.0.3"}, + expected: "10.0.0.3", + }, + { + name: "unknown hash key defaults to ip", + hashKey: "unknown", + headers: map[string]string{"X-Forwarded-For": "10.0.0.4"}, + expected: "10.0.0.4", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + HashKey: tt.hashKey, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", tt.headers) + result := p.extractHashKey(ctx, tt.hashKey) + if result != tt.expected { + t.Errorf("extractHashKey() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestBuildCacheKeyHash 测试缓存键哈希计算 +func TestBuildCacheKeyHash(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + hashKey, origKey := p.buildCacheKeyHash(ctx) + if hashKey == 0 { + t.Error("buildCacheKeyHash() should return non-zero hash") + } + if origKey == "" { + t.Error("buildCacheKeyHash() should return non-empty origKey") + } + + // 相同请求应产生相同哈希 + ctx2 := testutil.NewRequestCtx("GET", "/api/test") + hashKey2, _ := p.buildCacheKeyHash(ctx2) + if hashKey != hashKey2 { + t.Error("Same request should produce same hash") + } + + // 不同请求应产生不同哈希 + ctx3 := testutil.NewRequestCtx("POST", "/api/other") + hashKey3, _ := p.buildCacheKeyHash(ctx3) + if hashKey == hashKey3 { + t.Error("Different request should produce different hash") + } +} + +// TestBuildCacheKeyHashValue 测试零分配缓存键哈希 +func TestBuildCacheKeyHashValue(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + hashValue := p.buildCacheKeyHashValue(ctx) + if hashValue == 0 { + t.Error("buildCacheKeyHashValue() should return non-zero hash") + } + + // 应该与 buildCacheKeyHash 结果一致 + hashKey, _ := p.buildCacheKeyHash(ctx) + if hashValue != hashKey { + t.Error("buildCacheKeyHashValue() should match buildCacheKeyHash()") + } +} + +// TestWriteCachedResponse 测试缓存响应写入 +func TestWriteCachedResponse(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 手动创建一个 Response 用于验证 writeCachedResponse 写入正确 + ctx := testutil.NewRequestCtx("GET", "/api/test") + + entry := &cache.ProxyCacheEntry{ + Data: []byte("cached body"), + Headers: map[string]string{"Content-Type": "text/html", "X-Cached": "true"}, + Status: 200, + } + + p.writeCachedResponse(ctx, entry) + + if ctx.Response.StatusCode() != 200 { + t.Errorf("writeCachedResponse() status = %d, want 200", ctx.Response.StatusCode()) + } + if string(ctx.Response.Body()) != "cached body" { + t.Errorf("writeCachedResponse() body = %q, want %q", string(ctx.Response.Body()), "cached body") + } + ct := string(ctx.Response.Header.Peek("Content-Type")) + if ct != "text/html" { + t.Errorf("writeCachedResponse() Content-Type = %q, want %q", ct, "text/html") + } + xc := string(ctx.Response.Header.Peek("X-Cache")) + if xc != "HIT" { + t.Errorf("writeCachedResponse() X-Cache = %q, want HIT", xc) + } +} + +// TestGetCache 测试 GetCache 方法 +func TestGetCache(t *testing.T) { + // 启用缓存时 + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + c := p.GetCache() + if c == nil { + t.Error("GetCache() should return non-nil when cache enabled") + } + + // 禁用缓存时 + cfg2 := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + p2, _ := NewProxy(cfg2, targets, nil, nil) + c2 := p2.GetCache() + if c2 != nil { + t.Error("GetCache() should return nil when cache disabled") + } +} + +// TestGetCacheStats 测试 GetCacheStats 方法 +func TestGetCacheStats(t *testing.T) { + // 启用缓存时 + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + stats := p.GetCacheStats() + if stats == nil { + t.Error("GetCacheStats() should return non-nil when cache enabled") + } + + // 禁用缓存时 + cfg2 := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + p2, _ := NewProxy(cfg2, targets, nil, nil) + stats2 := p2.GetCacheStats() + if stats2 != nil { + t.Error("GetCacheStats() should return nil when cache disabled") + } +} + +// TestGetCacheDuration 测试不同状态码的缓存时间计算 +func TestGetCacheDuration(t *testing.T) { + tests := []struct { + name string + cacheValid *config.ProxyCacheValidConfig + maxAge time.Duration + statusCode int + expected time.Duration + }{ + { + name: "no CacheValid config uses MaxAge", + maxAge: 5 * time.Minute, + statusCode: 200, + expected: 5 * time.Minute, + }, + { + name: "2xx with CacheValid.OK set", + cacheValid: &config.ProxyCacheValidConfig{ + OK: 10 * time.Minute, + }, + statusCode: 200, + expected: 10 * time.Minute, + }, + { + name: "2xx with CacheValid.OK=0 inherits MaxAge", + cacheValid: &config.ProxyCacheValidConfig{ + OK: 0, + }, + maxAge: 3 * time.Minute, + statusCode: 201, + expected: 3 * time.Minute, + }, + { + name: "301 redirect", + cacheValid: &config.ProxyCacheValidConfig{ + Redirect: 1 * time.Hour, + }, + statusCode: 301, + expected: 1 * time.Hour, + }, + { + name: "302 redirect", + cacheValid: &config.ProxyCacheValidConfig{ + Redirect: 30 * time.Minute, + }, + statusCode: 302, + expected: 30 * time.Minute, + }, + { + name: "302 with zero Redirect means no cache", + cacheValid: &config.ProxyCacheValidConfig{ + Redirect: 0, + }, + statusCode: 302, + expected: 0, + }, + { + name: "404", + cacheValid: &config.ProxyCacheValidConfig{ + NotFound: 1 * time.Minute, + }, + statusCode: 404, + expected: 1 * time.Minute, + }, + { + name: "4xx client error", + cacheValid: &config.ProxyCacheValidConfig{ + ClientError: 30 * time.Second, + }, + statusCode: 400, + expected: 30 * time.Second, + }, + { + name: "5xx server error", + cacheValid: &config.ProxyCacheValidConfig{ + ServerError: 0, + }, + statusCode: 500, + expected: 0, + }, + { + name: "other status code", + statusCode: 100, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: tt.maxAge, + }, + CacheValid: tt.cacheValid, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + duration := p.getCacheDuration(tt.statusCode) + if duration != tt.expected { + t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, duration, tt.expected) + } + }) + } +} + +// TestBackgroundRefresh 测试后台缓存刷新(标记为 skip 因为需要真实网络) +func TestBackgroundRefresh(t *testing.T) { + t.Skip("skipping: requires real network connection and is timing-sensitive") + ln := fasthttputil.NewInmemoryListener() + defer func() { _ = ln.Close() }() + + go func() { + s := &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + ctx.SetBodyString("refreshed") + ctx.Response.Header.Set("Content-Type", "text/plain") + }, + } + _ = s.Serve(ln) + }() + time.Sleep(10 * time.Millisecond) + + addr := ln.Addr().String() + + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{ + {URL: "http://" + addr}, + } + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + // 设置缓存锁 + hashKey := uint64(12345) + p.cache.AcquireLock(hashKey) + + // 调用后台刷新(它会执行实际请求来刷新缓存) + done := make(chan struct{}) + go func() { + p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test") + close(done) + }() + + // 等待完成 + select { + case <-done: + // 完成 + case <-time.After(2 * time.Second): + t.Fatal("backgroundRefresh() timed out") + } +} + +// TestBackgroundRefresh_NoClient 测试后台刷新时客户端不存在的情况 +func TestBackgroundRefresh_NoClient(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{ + {URL: "http://nonexistent:9999"}, + } + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 移除客户端 + p.mu.Lock() + delete(p.clients, targets[0].URL) + p.mu.Unlock() + + ctx := testutil.NewRequestCtx("GET", "/api/test") + hashKey := uint64(99999) + p.cache.AcquireLock(hashKey) + + // 应该不会 panic,直接返回 + p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test") +} + +// TestServeHTTP_CacheHit 测试缓存命中路径 +func TestServeHTTP_CacheHit(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 预填充缓存 + ctx := testutil.NewRequestCtx("GET", "/api/cached") + hashKey, origKey := p.buildCacheKeyHash(ctx) + p.cache.Set(hashKey, origKey, []byte("cached!"), map[string]string{ + "Content-Type": "text/plain", + }, 200, 10*time.Second) + + // 执行请求 + p.ServeHTTP(ctx) + + // 应该返回缓存的响应 + if ctx.Response.StatusCode() != 200 { + t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode()) + } + if string(ctx.Response.Body()) != "cached!" { + t.Errorf("ServeHTTP() body = %q, want %q", string(ctx.Response.Body()), "cached!") + } + xc := string(ctx.Response.Header.Peek("X-Cache")) + if xc != "HIT" { + t.Errorf("ServeHTTP() X-Cache = %q, want HIT", xc) + } +} + +// TestServeHTTP_ClientNil 测试客户端为 nil 时的行为 +func TestServeHTTP_ClientNil(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + NextUpstream: config.NextUpstreamConfig{ + Tries: 2, + }, + } + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, target := range targets { + target.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 移除所有客户端 + p.mu.Lock() + p.clients = make(map[string]*fasthttp.HostClient) + p.mu.Unlock() + + ctx := testutil.NewRequestCtx("GET", "/api/test") + p.ServeHTTP(ctx) + + // 所有客户端都不存在,应该返回 502 + if ctx.Response.StatusCode() != fasthttp.StatusBadGateway { + t.Errorf("ServeHTTP() status = %d, want 502", ctx.Response.StatusCode()) + } +} + +// TestServeHTTP_WithRedirectRewrite 测试带 redirect_rewrite 的缓存命中 +func TestServeHTTP_WithRedirectRewrite_CacheHit(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + RedirectRewrite: &config.RedirectRewriteConfig{ + Mode: "off", // 关闭改写 + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + hashKey, origKey := p.buildCacheKeyHash(ctx) + p.cache.Set(hashKey, origKey, []byte("ok"), map[string]string{ + "Content-Type": "text/plain", + }, 200, 10*time.Second) + + p.ServeHTTP(ctx) + + if ctx.Response.StatusCode() != 200 { + t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode()) + } +} diff --git a/internal/server/server_extended_test.go b/internal/server/server_extended_test.go new file mode 100644 index 0000000..5116ec5 --- /dev/null +++ b/internal/server/server_extended_test.go @@ -0,0 +1,465 @@ +// Package server 提供 HTTP 服务器的核心实现测试补充。 +// +// 该文件补充以下测试覆盖: +// - 多模式启动逻辑测试(single/vhost/multi_server/auto) +// - 多服务器模式 shutdownServers 函数测试 +// - 监听器创建测试(TCP/Unix socket) +// - StopWithTimeout 超时行为测试 +// - GracefulStop 超时行为测试 +// - 中间件链错误路径测试 + +package server + +import ( + "context" + "net" + "testing" + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/resolver" +) + +// TestServer_GetMode_Single 测试单服务器模式 +func TestServer_GetMode_Single(t *testing.T) { + cfg := &config.Config{ + Mode: config.ServerModeSingle, + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + if s.config.GetMode() != config.ServerModeSingle { + t.Errorf("Expected mode single, got %s", s.config.GetMode()) + } +} + +// TestServer_GetMode_VHost 测试虚拟主机模式 +func TestServer_GetMode_VHost(t *testing.T) { + cfg := &config.Config{ + Mode: config.ServerModeVHost, + Servers: []config.ServerConfig{ + {Listen: ":0", Name: "host1.example.com"}, + {Listen: ":0", Name: "host2.example.com"}, + }, + } + + s := New(cfg) + if s.config.GetMode() != config.ServerModeVHost { + t.Errorf("Expected mode vhost, got %s", s.config.GetMode()) + } +} + +// TestServer_GetMode_MultiServer 测试多服务器模式 +func TestServer_GetMode_MultiServer(t *testing.T) { + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + {Listen: ":8080", Name: "server1"}, + {Listen: ":8081", Name: "server2"}, + }, + } + + s := New(cfg) + if s.config.GetMode() != config.ServerModeMultiServer { + t.Errorf("Expected mode multi_server, got %s", s.config.GetMode()) + } +} + +// TestServer_GetMode_Auto 测试自动模式 +func TestServer_GetMode_Auto(t *testing.T) { + cfg := &config.Config{ + Mode: config.ServerModeAuto, + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + mode := s.config.GetMode() + if mode != config.ServerModeSingle { + t.Errorf("Expected auto to resolve to single, got %s", mode) + } +} + +// TestShutdownServers_Empty 测试空服务器列表关闭 +func TestShutdownServers_Empty(t *testing.T) { + err := shutdownServers(nil, nil) + if err != nil { + t.Errorf("Expected nil error for empty servers, got %v", err) + } +} + +// TestShutdownServers_NilServer 测试含 nil 的服务器列表关闭 +func TestShutdownServers_NilServer(t *testing.T) { + servers := []*fasthttp.Server{nil, nil} + err := shutdownServers(nil, servers) + if err != nil { + t.Errorf("Expected nil error with nil servers, got %v", err) + } +} + +// TestShutdownServers_Timeout 测试关闭超时 +func TestShutdownServers_Timeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + + fastSrv := &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + select {} + }, + } + + go func() { + _ = fastSrv.Serve(ln) + }() + + time.Sleep(50 * time.Millisecond) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + err = shutdownServers(ctx, []*fasthttp.Server{fastSrv}) + // context 超时后 shutdownServers 会返回 ctx.Err() + if err != nil && err != context.DeadlineExceeded { + t.Errorf("Expected context.DeadlineExceeded or nil, got: %v", err) + } + + _ = fastSrv.Shutdown() + _ = ln.Close() +} + +// TestStopWithTimeout_DefaultTimeout 测试零超时使用默认值 +func TestStopWithTimeout_DefaultTimeout(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + + err := s.StopWithTimeout(0) + if err != nil { + t.Errorf("StopWithTimeout(0) should succeed, got %v", err) + } + + err = s.StopWithTimeout(-1 * time.Second) + if err != nil { + t.Errorf("StopWithTimeout(-1s) should succeed, got %v", err) + } +} + +// TestStopWithTimeout_MultiServerMode 测试多服务器模式停止 +func TestStopWithTimeout_MultiServerMode(t *testing.T) { + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + {Listen: ":0", Name: "server1"}, + {Listen: ":0", Name: "server2"}, + }, + } + + s := New(cfg) + + err := s.StopWithTimeout(1 * time.Second) + if err != nil { + t.Errorf("StopWithTimeout on non-started multi-server should succeed: %v", err) + } +} + +// TestGracefulStop_Timeout 测试优雅停止超时 +func TestGracefulStop_Timeout(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + + s.fastServer = &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + ctx.SetBodyString("ok") + }, + } + s.running = true + + err := s.GracefulStop(100 * time.Millisecond) + if err != nil { + t.Errorf("GracefulStop should succeed: %v", err) + } +} + +// TestServer_SetUpgradeManager 测试设置升级管理器 +func TestServer_SetUpgradeManager(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + mgr := NewUpgradeManager(s) + + s.SetUpgradeManager(mgr) + if s.upgradeManager != mgr { + t.Error("upgradeManager not set correctly") + } +} + +// mockResolver 用于测试的 mock DNS 解析器 +type mockResolver struct{} + +func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + return []string{"127.0.0.1"}, nil +} + +func (m *mockResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) { + return []string{"127.0.0.1"}, nil +} + +func (m *mockResolver) Refresh(host string) error { return nil } +func (m *mockResolver) Start() error { return nil } +func (m *mockResolver) Stop() error { return nil } +func (m *mockResolver) Stats() resolver.Stats { return resolver.Stats{} } + +// TestServer_Resolver 测试 DNS 解析器设置 +func TestServer_Resolver(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + + if s.GetResolver() != nil { + t.Error("Expected nil resolver initially") + } + + mockRes := &mockResolver{} + s.SetResolver(mockRes) + + if s.GetResolver() == nil { + t.Error("Resolver not set correctly") + } +} + +// TestCreateListener_TCP 测试 TCP 监听器创建 +func TestCreateListener_TCP(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + serverCfg := &cfg.Servers[0] + + ln, err := s.createListener(serverCfg) + if err != nil { + t.Fatalf("createListener failed: %v", err) + } + defer func() { _ = ln.Close() }() + + if ln == nil { + t.Fatal("Expected non-nil listener") + } + + addr := ln.Addr().(*net.TCPAddr) + if addr.Port == 0 { + t.Error("Expected non-zero port") + } +} + +// TestCreateListener_InvalidTCP 测试无效 TCP 监听地址 +func TestCreateListener_InvalidTCP(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "invalid:address:format", + }}, + } + + s := New(cfg) + _, err := s.createListener(&cfg.Servers[0]) + if err == nil { + t.Error("Expected error for invalid listen address") + } +} + +// TestListenerManagement 测试监听器管理 +func TestListenerManagement(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + + ln1, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener 1: %v", err) + } + defer func() { _ = ln1.Close() }() + + ln2, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener 2: %v", err) + } + defer func() { _ = ln2.Close() }() + + s.SetListeners([]net.Listener{ln1, ln2}) + + listeners := s.GetListeners() + if len(listeners) != 2 { + t.Errorf("Expected 2 listeners, got %d", len(listeners)) + } +} + +// TestStart_WithGoroutinePoolAndFileCache 测试同时启用 GoroutinePool 和 FileCache +func TestStart_WithGoroutinePoolAndFileCache(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 50, + MinWorkers: 5, + }, + FileCache: config.FileCacheConfig{ + MaxEntries: 500, + MaxSize: 50 * 1024 * 1024, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("New() returned nil") + } + + if s.config.Performance.GoroutinePool.Enabled != true { + t.Error("GoroutinePool should be enabled") + } +} + +// TestServer_GetHandler_NilThenSet 测试 handler 的 nil 到设置 +func TestServer_GetHandler_NilThenSet(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + + if s.GetHandler() != nil { + t.Error("Expected nil handler initially") + } + + testHandler := func(ctx *fasthttp.RequestCtx) { + ctx.SetBodyString("test handler response") + } + s.handler = testHandler + + got := s.GetHandler() + if got == nil { + t.Error("Expected non-nil handler after setting") + } + + ctx := &fasthttp.RequestCtx{} + got(ctx) + if string(ctx.Response.Body()) != "test handler response" { + t.Errorf("Handler response = %q, want %q", string(ctx.Response.Body()), "test handler response") + } +} + +// TestServer_TrackStats_Concurrent 测试并发统计追踪 +func TestServer_TrackStats_Concurrent(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + + handler := func(ctx *fasthttp.RequestCtx) { + ctx.SetBodyString("ok") + } + + wrappedHandler := s.trackStats(handler) + + const numGoroutines = 100 + done := make(chan bool, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + wrappedHandler(ctx) + done <- true + }() + } + + for i := 0; i < numGoroutines; i++ { + <-done + } + + if s.requests.Load() != int64(numGoroutines) { + t.Errorf("Expected %d requests, got %d", numGoroutines, s.requests.Load()) + } +} + +// TestBuildMiddlewareChain_BodyLimit 测试请求体限制中间件 +func TestBuildMiddlewareChain_BodyLimit(t *testing.T) { + cfg := &config.Config{ + Logging: config.LoggingConfig{}, + Servers: []config.ServerConfig{{ + Listen: ":0", + Proxy: []config.ProxyConfig{{ + Path: "/api/", + ClientMaxBodySize: "1MB", + }}, + ClientMaxBodySize: "10MB", + }}, + } + + s := New(cfg) + chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) + if err != nil { + t.Errorf("buildMiddlewareChain failed: %v", err) + } + if chain == nil { + t.Error("Expected non-nil chain") + } +} + +// TestBuildMiddlewareChain_BodyLimit_Invalid 测试无效的请求体限制 +func TestBuildMiddlewareChain_BodyLimit_Invalid(t *testing.T) { + cfg := &config.Config{ + Logging: config.LoggingConfig{}, + Servers: []config.ServerConfig{{ + Listen: ":0", + ClientMaxBodySize: "invalid_size", + }}, + } + + s := New(cfg) + _, err := s.buildMiddlewareChain(&cfg.Servers[0]) + if err == nil { + t.Error("Expected error for invalid body limit size") + } +} diff --git a/internal/stream/stream_coverage_test.go b/internal/stream/stream_coverage_test.go new file mode 100644 index 0000000..fc86ccf --- /dev/null +++ b/internal/stream/stream_coverage_test.go @@ -0,0 +1,554 @@ +// Package stream 提供流代理覆盖测试。 +// +// 该文件补充测试 stream.go 中未覆盖的方法: +// - ipHash.Select() (空 IP) +// - handleConnection() 连接处理 +// - getOrCreateSession() 会话创建 +// - handleBackendResponse() 后端响应处理 +// - Stats 完整统计 +// +// 作者:xfy +package stream + +import ( + "net" + "testing" + "time" +) + +// TestIPHashSelect 测试 ipHash 的 Select 方法(空字符串 IP) +func TestIPHashSelect(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001"}, + {addr: "localhost:8002"}, + } + for _, target := range targets { + target.healthy.Store(true) + } + + ih := newIPHash() + + // Select() 使用空字符串作为 IP + selected := ih.Select(targets) + if selected == nil { + t.Error("Select() with empty IP should return a target") + } + + // 多次调用应返回相同目标(确定性哈希) + selected2 := ih.Select(targets) + if selected != selected2 { + t.Error("Select() with same empty IP should be consistent") + } + + // 无健康目标时应返回 nil + for _, target := range targets { + target.healthy.Store(false) + } + selected = ih.Select(targets) + if selected != nil { + t.Error("Select() with no healthy targets should return nil") + } +} + +// TestSelectByIPNoHealthy 测试 SelectByIP 无健康目标 +func TestSelectByIPNoHealthy(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001"}, + {addr: "localhost:8002"}, + } + + ih := newIPHash() + selected := ih.(*ipHash).SelectByIP(targets, "192.168.1.1") + if selected != nil { + t.Error("SelectByIP() with no healthy targets should return nil") + } +} + +// TestWeightedRoundRobinZeroWeight 测试零权重处理 +func TestWeightedRoundRobinZeroWeight(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001", weight: 0}, + {addr: "localhost:8002", weight: -1}, + } + for _, target := range targets { + target.healthy.Store(true) + } + + wrr := newWeightedRoundRobin().(*weightedRoundRobin) + + // 权重为 0 或负数应视为权重 1 + for i := 0; i < 4; i++ { + selected := wrr.Select(targets) + if selected == nil { + t.Error("Select() should return target with zero/negative weight") + return + } + } +} + +// TestHandleConnection 测试 handleConnection 方法 +func TestHandleConnection(t *testing.T) { + s := NewServer() + + // 添加上游配置 + targets := []TargetSpec{ + {Addr: "127.0.0.1:29001", Weight: 1}, + } + _ = s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{}) + s.upstreams["test"].targets[0].healthy.Store(true) + + // 创建模拟客户端连接(不会实际建立连接,测试无上游路径) + s.mu.Lock() + // 设置上游为空,测试无上游配置路径 + s.upstreams = make(map[string]*Upstream) + s.mu.Unlock() + + // 创建一对连接 + serverLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { _ = serverLn.Close() }() + + clientConn, err := net.Dial("tcp", serverLn.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer func() { _ = clientConn.Close() }() + + serverConn, err := serverLn.Accept() + if err != nil { + t.Fatalf("Failed to accept: %v", err) + } + defer func() { _ = serverConn.Close() }() + + // 测试无上游配置的 handleConnection + s.handleConnection(clientConn, "127.0.0.1:0") +} + +// TestHandleConnection_NoHealthyTarget 测试无健康目标路径 +func TestHandleConnection_NoHealthyTarget(t *testing.T) { + s := NewServer() + + // 添加不健康的上游 + targets := []TargetSpec{ + {Addr: "127.0.0.1:29002", Weight: 1}, + } + _ = s.AddUpstream("test2", targets, "round_robin", HealthCheckSpec{}) + // 目标不健康(默认 false) + + serverLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { _ = serverLn.Close() }() + + clientConn, err := net.Dial("tcp", serverLn.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer func() { _ = clientConn.Close() }() + + serverConn, err := serverLn.Accept() + if err != nil { + t.Fatalf("Failed to accept: %v", err) + } + defer func() { _ = serverConn.Close() }() + + done := make(chan struct{}) + go func() { + s.handleConnection(clientConn, "127.0.0.1:0") + close(done) + }() + + select { + case <-done: + // 完成 + case <-time.After(2 * time.Second): + t.Fatal("handleConnection() timed out") + } +} + +// TestHandleConnection_DialFail 测试连接目标失败路径 +func TestHandleConnection_DialFail(t *testing.T) { + s := NewServer() + + // 添加上游,目标不可达 + targets := []TargetSpec{ + {Addr: "127.0.0.1:29999", Weight: 1}, + } + _ = s.AddUpstream("test3", targets, "round_robin", HealthCheckSpec{}) + s.upstreams["test3"].targets[0].healthy.Store(true) + + serverLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { _ = serverLn.Close() }() + + clientConn, err := net.Dial("tcp", serverLn.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer func() { _ = clientConn.Close() }() + + serverConn, err := serverLn.Accept() + if err != nil { + t.Fatalf("Failed to accept: %v", err) + } + defer func() { _ = serverConn.Close() }() + + done := make(chan struct{}) + go func() { + s.handleConnection(clientConn, "127.0.0.1:0") + close(done) + }() + + select { + case <-done: + // 完成 - 连接目标失败后应标记为不健康 + if s.upstreams["test3"].targets[0].healthy.Load() { + t.Error("Target should be marked unhealthy after dial failure") + } + case <-time.After(15 * time.Second): + t.Fatal("handleConnection() timed out") + } +} + +// TestGetOrCreateSession 测试 getOrCreateSession 方法 +func TestGetOrCreateSession(t *testing.T) { + // 创建 UDP 连接 + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer func() { _ = conn.Close() }() + + // 创建上游 + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:29003"}}, + balancer: newRoundRobin(), + } + upstream.targets[0].healthy.Store(true) + + srv := newUDPServer(conn, upstream, 1*time.Minute) + + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29010") + + // 第一次调用 - 应该创建新会话(但由于后端不可达,应该失败) + session, err := srv.getOrCreateSession(clientAddr) + if err != nil { + // 预期失败,因为后端不可达 + return + } + if session == nil { + t.Error("getOrCreateSession() should return a session") + } +} + +// TestGetOrCreateSession_DoubleCheck 测试双重检查锁定 +func TestGetOrCreateSession_DoubleCheck(t *testing.T) { + // 创建 UDP 连接 + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer func() { _ = conn.Close() }() + + // 创建上游 + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:29004"}}, + balancer: newRoundRobin(), + } + upstream.targets[0].healthy.Store(true) + + srv := newUDPServer(conn, upstream, 1*time.Minute) + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29011") + + // 手动创建一个会话来测试双重检查 + srv.mu.Lock() + testSession := &udpSession{ + clientAddr: clientAddr, + lastActive: time.Now(), + srv: srv, + } + srv.sessions[sessionKey(clientAddr)] = testSession + srv.mu.Unlock() + + // 再次获取应该返回现有会话 + session, err := srv.getOrCreateSession(clientAddr) + if err != nil { + t.Errorf("getOrCreateSession() should not error for existing session: %v", err) + } + if session != testSession { + t.Error("getOrCreateSession() should return existing session") + } +} + +// TestGetOrCreateSession_NoHealthyTarget 测试无健康目标 +func TestGetOrCreateSession_NoHealthyTarget(t *testing.T) { + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer func() { _ = conn.Close() }() + + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:29005"}}, + balancer: newRoundRobin(), + } + // 不设置 healthy,默认为 false + + srv := newUDPServer(conn, upstream, 1*time.Minute) + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29012") + + session, err := srv.getOrCreateSession(clientAddr) + if err == nil { + t.Error("getOrCreateSession() should return error when no healthy target") + } + if session != nil { + t.Error("getOrCreateSession() should return nil session on error") + } +} + +// TestGetOrCreateSession_InvalidTargetAddr 测试无效目标地址 +func TestGetOrCreateSession_InvalidTargetAddr(t *testing.T) { + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer func() { _ = conn.Close() }() + + upstream := &Upstream{ + targets: []*Target{{addr: "invalid-address"}}, + balancer: newRoundRobin(), + } + upstream.targets[0].healthy.Store(true) + + srv := newUDPServer(conn, upstream, 1*time.Minute) + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29013") + + session, err := srv.getOrCreateSession(clientAddr) + if err == nil { + t.Error("getOrCreateSession() should return error for invalid target address") + } + if session != nil { + t.Error("getOrCreateSession() should return nil session on error") + } +} + +// TestHandleBackendResponse 测试 handleBackendResponse 超时清理路径 +func TestHandleBackendResponse(t *testing.T) { + // 创建 UDP 连接(服务端) + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer func() { _ = conn.Close() }() + + // 创建上游 + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:29006"}}, + balancer: newRoundRobin(), + } + upstream.targets[0].healthy.Store(true) + + srv := newUDPServer(conn, upstream, 50*time.Millisecond) // 短超时 + + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29014") + + // 创建目标连接(监听器用于接收目标连接) + targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29006") + targetConn, _ := net.ListenUDP("udp", targetAddr) + defer func() { _ = targetConn.Close() }() + + // 创建会话时需要先添加到 WaitGroup(handleBackendResponse 会调用 Done) + srv.wg.Add(1) + session := &udpSession{ + clientAddr: clientAddr, + targetConn: targetConn, + lastActive: time.Now().Add(-2 * time.Hour), // 很久以前 + srv: srv, + } + + // 添加会话到服务器 + srv.sessions[sessionKey(clientAddr)] = session + + // 启动后端响应处理 + done := make(chan struct{}) + go func() { + session.handleBackendResponse() + close(done) + }() + + select { + case <-done: + // 应该因为超时而清理会话 + case <-time.After(2 * time.Second): + t.Fatal("handleBackendResponse() timed out") + } +} + +// TestHandleBackendResponse_ErrorPath 测试 handleBackendResponse 错误路径 +func TestHandleBackendResponse_ErrorPath(t *testing.T) { + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer func() { _ = conn.Close() }() + + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:29007"}}, + balancer: newRoundRobin(), + } + + srv := newUDPServer(conn, upstream, 10*time.Millisecond) + + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29015") + + // 创建一个已关闭的连接作为 targetConn + targetUDPAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29007") + targetConn, err := net.DialUDP("udp", nil, targetUDPAddr) + if err != nil { + t.Fatalf("Failed to create target connection: %v", err) + } + + session := &udpSession{ + clientAddr: clientAddr, + targetConn: targetConn, + lastActive: time.Now(), + srv: srv, + } + + // 创建会话时需要先添加到 WaitGroup(handleBackendResponse 会调用 Done) + srv.wg.Add(1) + srv.sessions[sessionKey(clientAddr)] = session + + done := make(chan struct{}) + go func() { + session.handleBackendResponse() + close(done) + }() + + select { + case <-done: + // 完成 + case <-time.After(3 * time.Second): + t.Fatal("handleBackendResponse() timed out") + } +} + +// TestServe_InvalidUpstream 测试 serve 方法无效上游路径 +func TestServe_InvalidUpstream(t *testing.T) { + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer func() { _ = conn.Close() }() + + upstream := &Upstream{ + targets: []*Target{}, + balancer: newRoundRobin(), + } + + srv := newUDPServer(conn, upstream, 50*time.Millisecond) + + // 启动 serve + go srv.serve() + + // 立即停止 + time.Sleep(20 * time.Millisecond) + srv.stop() +} + +// TestServerStop 测试 Server.Stop 方法 +func TestServerStop(t *testing.T) { + s := NewServer() + + // 添加上游 + targets := []TargetSpec{ + {Addr: "127.0.0.1:29008", Weight: 1}, + } + hcSpec := HealthCheckSpec{ + Enabled: true, + Interval: 1 * time.Second, + Timeout: 500 * time.Millisecond, + } + _ = s.AddUpstream("stop_test", targets, "round_robin", hcSpec) + + // 监听 TCP + err := s.ListenTCP("127.0.0.1:29009") + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + + // 监听 UDP + err = s.ListenUDP("127.0.0.1:29010", "stop_test", 1*time.Second) + if err != nil { + t.Fatalf("ListenUDP failed: %v", err) + } + + // 启动 + err = s.Start() + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + // 停止 + err = s.Stop() + if err != nil { + t.Errorf("Stop() error: %v", err) + } +} + +// TestStatsComplete 测试 Stats 完整统计 +func TestStatsComplete(t *testing.T) { + s := NewServer() + + // 添加 TCP 监听 + err := s.ListenTCP("127.0.0.1:29020") + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + + // 添加上游 + targets := []TargetSpec{{Addr: "127.0.0.1:29021", Weight: 1}} + _ = s.AddUpstream("stats_test", targets, "round_robin", HealthCheckSpec{}) + + // 添加 UDP 监听 + err = s.ListenUDP("127.0.0.1:29022", "stats_test", 1*time.Second) + if err != nil { + t.Fatalf("ListenUDP failed: %v", err) + } + + stats := s.Stats() + if stats.Listeners != 2 { + t.Errorf("Stats().Listeners = %d, want 2 (1 TCP + 1 UDP)", stats.Listeners) + } + if stats.Upstreams != 1 { + t.Errorf("Stats().Upstreams = %d, want 1", stats.Upstreams) + } + if stats.Connections != 0 { + t.Errorf("Stats().Connections = %d, want 0", stats.Connections) + } +} + +// TestAcceptLoop_Error 测试 acceptLoop 错误处理路径 +func TestAcceptLoop_Error(t *testing.T) { + s := NewServer() + s.running.Store(true) + + // 创建一个立即关闭的监听器 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + + // 在另一个 goroutine 中关闭监听器 + go func() { + time.Sleep(50 * time.Millisecond) + _ = ln.Close() + }() + + done := make(chan struct{}) + go func() { + s.acceptLoop("test", ln) + close(done) + }() + + select { + case <-done: + // 完成 + case <-time.After(2 * time.Second): + s.running.Store(false) + <-done + } +}