test: 添加各模块覆盖率补充测试

- middleware/security: access 中间件覆盖率测试
- proxy: proxy 核心功能覆盖率测试
- server: server 扩展功能测试
- stream: stream 处理覆盖率测试

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-20 08:27:25 +08:00
parent c82e6dcdb7
commit 5f5717d6a4
4 changed files with 2079 additions and 0 deletions

View File

@ -0,0 +1,340 @@
// Package security 提供访问控制覆盖测试。
//
// 该文件补充测试 access.go 中未覆盖的方法:
// - Name() 方法
// - Process() 完整处理链(允许/拒绝路径)
// - getClientIP() 通过 Process 间接测试
// - Close() 方法
// - actionToString 边缘情况
// - trustedProxies 相关逻辑
//
// 作者xfy
package security
import (
"net"
"testing"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
// TestAccessControlName 测试 Name 方法
func TestAccessControlName(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Default: "allow",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
name := ac.Name()
if name != "access_control" {
t.Errorf("Name() = %q, want 'access_control'", name)
}
}
// TestAccessControlProcess_AllowPath 测试 Process 允许路径
func TestAccessControlProcess_AllowPath(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Default: "allow",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
_, _ = ctx.WriteString("allowed")
}
handler := ac.Process(nextHandler)
if handler == nil {
t.Fatal("Process() returned nil")
}
ctx := &fasthttp.RequestCtx{}
handler(ctx)
if !called {
t.Error("Process() should call next handler when access allowed")
}
if string(ctx.Response.Body()) != "allowed" {
t.Errorf("Process() body = %q, want 'allowed'", string(ctx.Response.Body()))
}
}
// TestAccessControlProcess_DenyPath 测试 Process 拒绝路径
func TestAccessControlProcess_DenyPath(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Default: "deny",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := ac.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
handler(ctx)
if called {
t.Error("Process() should NOT call next handler when access denied")
}
if ctx.Response.StatusCode() != fasthttp.StatusForbidden {
t.Errorf("Process() status = %d, want 403", ctx.Response.StatusCode())
}
}
// TestAccessControlProcess_ExplicitAllow 测试显式允许列表
func TestAccessControlProcess_ExplicitAllow(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Allow: []string{"127.0.0.1"},
Default: "deny",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := ac.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345})
handler(ctx)
if !called {
t.Error("Process() should call next handler for allowed IP")
}
}
// TestAccessControlProcess_ExplicitDeny 测试显式拒绝列表
func TestAccessControlProcess_ExplicitDeny(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Deny: []string{"10.0.0.1"},
Default: "allow",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := ac.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345})
handler(ctx)
if called {
t.Error("Process() should NOT call next handler for denied IP")
}
}
// TestAccessControlProcess_TrustedProxies_XFF 测试可信代理 XFF 解析
func TestAccessControlProcess_TrustedProxies_XFF(t *testing.T) {
// 配置可信代理10.0.0.0/8 为可信代理段
ac, err := NewAccessControl(&config.AccessConfig{
TrustedProxies: []string{"10.0.0.0/8"},
Default: "deny",
Allow: []string{"192.168.1.0/24"},
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := ac.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
// 请求来自可信代理 10.0.0.1XFF 中包含真实客户端 192.168.1.100
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345})
ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
handler(ctx)
if !called {
t.Error("Process() should allow real client IP behind trusted proxy")
}
}
// TestAccessControlProcess_TrustedProxies_UntrustedSource 测试不可信来源不解析 XFF
func TestAccessControlProcess_TrustedProxies_UntrustedSource(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
TrustedProxies: []string{"10.0.0.0/8"},
Default: "deny",
Allow: []string{"192.168.1.0/24"},
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := ac.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
// 请求来自不可信地址,即使 XFF 包含允许列表 IP 也不应解析
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 12345})
ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100")
handler(ctx)
if called {
t.Error("Process() should not trust XFF from untrusted source")
}
}
// TestAccessControlProcess_TrustedProxies_XRealIP 测试可信代理 X-Real-IP
func TestAccessControlProcess_TrustedProxies_XRealIP(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
TrustedProxies: []string{"10.0.0.0/8"},
Default: "deny",
Allow: []string{"192.168.1.0/24"},
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := ac.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.50"), Port: 12345})
ctx.Request.Header.Set("X-Real-IP", "192.168.1.50")
handler(ctx)
if !called {
t.Error("Process() should use X-Real-IP from trusted proxy")
}
}
// TestAccessControlClose 测试 Close 方法
func TestAccessControlClose(t *testing.T) {
// 无 GeoIP 的 Close
ac, err := NewAccessControl(&config.AccessConfig{
Default: "allow",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
err = ac.Close()
if err != nil {
t.Errorf("Close() error: %v", err)
}
}
// TestActionToString 测试 actionToString 边缘情况
func TestActionToString(t *testing.T) {
// 测试 ActionAllow
result := actionToString(ActionAllow)
if result != "allow" {
t.Errorf("actionToString(ActionAllow) = %q, want 'allow'", result)
}
// 测试 ActionDeny
result = actionToString(ActionDeny)
if result != "deny" {
t.Errorf("actionToString(ActionDeny) = %q, want 'deny'", result)
}
// 测试未知值
result = actionToString(Action(999))
if result != "unknown" {
t.Errorf("actionToString(999) = %q, want 'unknown'", result)
}
}
// TestGetStatsWithEmpty 测试 GetStats 空列表
func TestGetStatsWithEmpty(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Default: "allow",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
stats := ac.GetStats()
if stats.AllowCount != 0 {
t.Errorf("GetStats().AllowCount = %d, want 0", stats.AllowCount)
}
if stats.DenyCount != 0 {
t.Errorf("GetStats().DenyCount = %d, want 0", stats.DenyCount)
}
if stats.Default != "allow" {
t.Errorf("GetStats().Default = %q, want 'allow'", stats.Default)
}
}
// TestSetDefaultValidCases 测试 SetDefault 所有有效值
func TestSetDefaultValidCases(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Default: "allow",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
// 切换为 deny
err = ac.SetDefault("deny")
if err != nil {
t.Errorf("SetDefault('deny') error: %v", err)
}
stats := ac.GetStats()
if stats.Default != "deny" {
t.Errorf("After SetDefault('deny'), Default = %q, want 'deny'", stats.Default)
}
// 切换回 allow
err = ac.SetDefault("allow")
if err != nil {
t.Errorf("SetDefault('allow') error: %v", err)
}
stats = ac.GetStats()
if stats.Default != "allow" {
t.Errorf("After SetDefault('allow'), Default = %q, want 'allow'", stats.Default)
}
// 大小写不敏感
err = ac.SetDefault("DENY")
if err != nil {
t.Errorf("SetDefault('DENY') error: %v", err)
}
}
// TestUpdateDenyListError 测试 UpdateDenyList 错误路径
func TestUpdateDenyListError(t *testing.T) {
ac, err := NewAccessControl(&config.AccessConfig{
Default: "allow",
})
if err != nil {
t.Fatalf("NewAccessControl() error: %v", err)
}
err = ac.UpdateDenyList([]string{"not-an-ip"})
if err == nil {
t.Error("UpdateDenyList() should return error for invalid CIDR")
}
}

View File

@ -0,0 +1,720 @@
// Package proxy 提供反向代理覆盖测试,补充 proxy.go 中未覆盖的方法。
//
// 该文件测试代理模块的以下功能:
// - selectTargetExcluding 排除已失败目标的选择
// - extractHashKey 哈希键提取
// - buildCacheKeyHash / buildCacheKeyHashValue 缓存键计算
// - writeCachedResponse 缓存响应写入
// - GetCache / GetCacheStats 缓存访问
// - getCacheDuration 不同状态码的缓存时间
// - redirect_rewrite 相关功能
//
// 作者xfy
package proxy
import (
"testing"
"time"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/testutil"
)
// TestSelectTargetExcluding 测试排除失败目标的目标选择
func TestSelectTargetExcluding(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
{URL: "http://backend3:8080"},
}
for _, target := range targets {
target.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
// 排除第一个目标,应该选择第二个
excluded := []*loadbalance.Target{targets[0]}
selected := p.selectTargetExcluding(ctx, excluded)
if selected == nil {
t.Fatal("selectTargetExcluding() returned nil")
}
if selected.URL == "http://backend1:8080" {
t.Error("selectTargetExcluding() should not select excluded target")
}
// 排除所有目标,应该返回 nil
allExcluded := []*loadbalance.Target{targets[0], targets[1], targets[2]}
selected = p.selectTargetExcluding(ctx, allExcluded)
if selected != nil {
t.Error("selectTargetExcluding() should return nil when all excluded")
}
// 空目标列表
p2, _ := NewProxy(cfg, []*loadbalance.Target{{URL: "http://a:1"}}, nil, nil)
p2.targets = nil
selected = p2.selectTargetExcluding(ctx, nil)
if selected != nil {
t.Error("selectTargetExcluding() should return nil for empty targets")
}
}
// TestSelectTargetExcluding_IPHash 测试 IP Hash 排除选择
func TestSelectTargetExcluding_IPHash(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "ip_hash",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
{URL: "http://backend3:8080"},
}
for _, target := range targets {
target.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
"X-Forwarded-For": "192.168.1.1",
})
// 获取第一次选择
first := p.selectTarget(ctx)
if first == nil {
t.Fatal("selectTarget() returned nil")
}
// 排除第一次选择的目标
excluded := []*loadbalance.Target{first}
second := p.selectTargetExcluding(ctx, excluded)
if second == nil {
t.Fatal("selectTargetExcluding() returned nil")
}
if second.URL == first.URL {
t.Errorf("selectTargetExcluding() should not select same target %s", first.URL)
}
}
// TestSelectTargetExcluding_ConsistentHash 测试一致性哈希排除选择
func TestSelectTargetExcluding_ConsistentHash(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "consistent_hash",
HashKey: "uri",
VirtualNodes: 150,
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
}
for _, target := range targets {
target.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
// 排除一个目标,应该还能选到另一个
excluded := []*loadbalance.Target{targets[0]}
selected := p.selectTargetExcluding(ctx, excluded)
if selected == nil {
t.Error("selectTargetExcluding() should return remaining target")
}
}
// TestExtractHashKey 测试哈希键提取
func TestExtractHashKey(t *testing.T) {
tests := []struct {
name string
hashKey string
headers map[string]string
expected string
}{
{
name: "ip hash key",
hashKey: "ip",
headers: map[string]string{"X-Forwarded-For": "10.0.0.1"},
expected: "10.0.0.1",
},
{
name: "empty hash key defaults to ip",
hashKey: "",
headers: map[string]string{"X-Forwarded-For": "10.0.0.2"},
expected: "10.0.0.2",
},
{
name: "uri hash key",
hashKey: "uri",
headers: nil,
expected: "/api/test",
},
{
name: "header hash key - found",
hashKey: "header:X-Custom-ID",
headers: map[string]string{"X-Custom-ID": "abc123"},
expected: "abc123",
},
{
name: "header hash key - fallback to ip",
hashKey: "header:X-Missing",
headers: map[string]string{"X-Forwarded-For": "10.0.0.3"},
expected: "10.0.0.3",
},
{
name: "unknown hash key defaults to ip",
hashKey: "unknown",
headers: map[string]string{"X-Forwarded-For": "10.0.0.4"},
expected: "10.0.0.4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "consistent_hash",
HashKey: tt.hashKey,
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", tt.headers)
result := p.extractHashKey(ctx, tt.hashKey)
if result != tt.expected {
t.Errorf("extractHashKey() = %q, want %q", result, tt.expected)
}
})
}
}
// TestBuildCacheKeyHash 测试缓存键哈希计算
func TestBuildCacheKeyHash(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
hashKey, origKey := p.buildCacheKeyHash(ctx)
if hashKey == 0 {
t.Error("buildCacheKeyHash() should return non-zero hash")
}
if origKey == "" {
t.Error("buildCacheKeyHash() should return non-empty origKey")
}
// 相同请求应产生相同哈希
ctx2 := testutil.NewRequestCtx("GET", "/api/test")
hashKey2, _ := p.buildCacheKeyHash(ctx2)
if hashKey != hashKey2 {
t.Error("Same request should produce same hash")
}
// 不同请求应产生不同哈希
ctx3 := testutil.NewRequestCtx("POST", "/api/other")
hashKey3, _ := p.buildCacheKeyHash(ctx3)
if hashKey == hashKey3 {
t.Error("Different request should produce different hash")
}
}
// TestBuildCacheKeyHashValue 测试零分配缓存键哈希
func TestBuildCacheKeyHashValue(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
hashValue := p.buildCacheKeyHashValue(ctx)
if hashValue == 0 {
t.Error("buildCacheKeyHashValue() should return non-zero hash")
}
// 应该与 buildCacheKeyHash 结果一致
hashKey, _ := p.buildCacheKeyHash(ctx)
if hashValue != hashKey {
t.Error("buildCacheKeyHashValue() should match buildCacheKeyHash()")
}
}
// TestWriteCachedResponse 测试缓存响应写入
func TestWriteCachedResponse(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 手动创建一个 Response 用于验证 writeCachedResponse 写入正确
ctx := testutil.NewRequestCtx("GET", "/api/test")
entry := &cache.ProxyCacheEntry{
Data: []byte("cached body"),
Headers: map[string]string{"Content-Type": "text/html", "X-Cached": "true"},
Status: 200,
}
p.writeCachedResponse(ctx, entry)
if ctx.Response.StatusCode() != 200 {
t.Errorf("writeCachedResponse() status = %d, want 200", ctx.Response.StatusCode())
}
if string(ctx.Response.Body()) != "cached body" {
t.Errorf("writeCachedResponse() body = %q, want %q", string(ctx.Response.Body()), "cached body")
}
ct := string(ctx.Response.Header.Peek("Content-Type"))
if ct != "text/html" {
t.Errorf("writeCachedResponse() Content-Type = %q, want %q", ct, "text/html")
}
xc := string(ctx.Response.Header.Peek("X-Cache"))
if xc != "HIT" {
t.Errorf("writeCachedResponse() X-Cache = %q, want HIT", xc)
}
}
// TestGetCache 测试 GetCache 方法
func TestGetCache(t *testing.T) {
// 启用缓存时
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
c := p.GetCache()
if c == nil {
t.Error("GetCache() should return non-nil when cache enabled")
}
// 禁用缓存时
cfg2 := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
p2, _ := NewProxy(cfg2, targets, nil, nil)
c2 := p2.GetCache()
if c2 != nil {
t.Error("GetCache() should return nil when cache disabled")
}
}
// TestGetCacheStats 测试 GetCacheStats 方法
func TestGetCacheStats(t *testing.T) {
// 启用缓存时
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
stats := p.GetCacheStats()
if stats == nil {
t.Error("GetCacheStats() should return non-nil when cache enabled")
}
// 禁用缓存时
cfg2 := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
p2, _ := NewProxy(cfg2, targets, nil, nil)
stats2 := p2.GetCacheStats()
if stats2 != nil {
t.Error("GetCacheStats() should return nil when cache disabled")
}
}
// TestGetCacheDuration 测试不同状态码的缓存时间计算
func TestGetCacheDuration(t *testing.T) {
tests := []struct {
name string
cacheValid *config.ProxyCacheValidConfig
maxAge time.Duration
statusCode int
expected time.Duration
}{
{
name: "no CacheValid config uses MaxAge",
maxAge: 5 * time.Minute,
statusCode: 200,
expected: 5 * time.Minute,
},
{
name: "2xx with CacheValid.OK set",
cacheValid: &config.ProxyCacheValidConfig{
OK: 10 * time.Minute,
},
statusCode: 200,
expected: 10 * time.Minute,
},
{
name: "2xx with CacheValid.OK=0 inherits MaxAge",
cacheValid: &config.ProxyCacheValidConfig{
OK: 0,
},
maxAge: 3 * time.Minute,
statusCode: 201,
expected: 3 * time.Minute,
},
{
name: "301 redirect",
cacheValid: &config.ProxyCacheValidConfig{
Redirect: 1 * time.Hour,
},
statusCode: 301,
expected: 1 * time.Hour,
},
{
name: "302 redirect",
cacheValid: &config.ProxyCacheValidConfig{
Redirect: 30 * time.Minute,
},
statusCode: 302,
expected: 30 * time.Minute,
},
{
name: "302 with zero Redirect means no cache",
cacheValid: &config.ProxyCacheValidConfig{
Redirect: 0,
},
statusCode: 302,
expected: 0,
},
{
name: "404",
cacheValid: &config.ProxyCacheValidConfig{
NotFound: 1 * time.Minute,
},
statusCode: 404,
expected: 1 * time.Minute,
},
{
name: "4xx client error",
cacheValid: &config.ProxyCacheValidConfig{
ClientError: 30 * time.Second,
},
statusCode: 400,
expected: 30 * time.Second,
},
{
name: "5xx server error",
cacheValid: &config.ProxyCacheValidConfig{
ServerError: 0,
},
statusCode: 500,
expected: 0,
},
{
name: "other status code",
statusCode: 100,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: tt.maxAge,
},
CacheValid: tt.cacheValid,
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
duration := p.getCacheDuration(tt.statusCode)
if duration != tt.expected {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, duration, tt.expected)
}
})
}
}
// TestBackgroundRefresh 测试后台缓存刷新(标记为 skip 因为需要真实网络)
func TestBackgroundRefresh(t *testing.T) {
t.Skip("skipping: requires real network connection and is timing-sensitive")
ln := fasthttputil.NewInmemoryListener()
defer func() { _ = ln.Close() }()
go func() {
s := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("refreshed")
ctx.Response.Header.Set("Content-Type", "text/plain")
},
}
_ = s.Serve(ln)
}()
time.Sleep(10 * time.Millisecond)
addr := ln.Addr().String()
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{
{URL: "http://" + addr},
}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
// 设置缓存锁
hashKey := uint64(12345)
p.cache.AcquireLock(hashKey)
// 调用后台刷新(它会执行实际请求来刷新缓存)
done := make(chan struct{})
go func() {
p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test")
close(done)
}()
// 等待完成
select {
case <-done:
// 完成
case <-time.After(2 * time.Second):
t.Fatal("backgroundRefresh() timed out")
}
}
// TestBackgroundRefresh_NoClient 测试后台刷新时客户端不存在的情况
func TestBackgroundRefresh_NoClient(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{
{URL: "http://nonexistent:9999"},
}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 移除客户端
p.mu.Lock()
delete(p.clients, targets[0].URL)
p.mu.Unlock()
ctx := testutil.NewRequestCtx("GET", "/api/test")
hashKey := uint64(99999)
p.cache.AcquireLock(hashKey)
// 应该不会 panic直接返回
p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test")
}
// TestServeHTTP_CacheHit 测试缓存命中路径
func TestServeHTTP_CacheHit(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 预填充缓存
ctx := testutil.NewRequestCtx("GET", "/api/cached")
hashKey, origKey := p.buildCacheKeyHash(ctx)
p.cache.Set(hashKey, origKey, []byte("cached!"), map[string]string{
"Content-Type": "text/plain",
}, 200, 10*time.Second)
// 执行请求
p.ServeHTTP(ctx)
// 应该返回缓存的响应
if ctx.Response.StatusCode() != 200 {
t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode())
}
if string(ctx.Response.Body()) != "cached!" {
t.Errorf("ServeHTTP() body = %q, want %q", string(ctx.Response.Body()), "cached!")
}
xc := string(ctx.Response.Header.Peek("X-Cache"))
if xc != "HIT" {
t.Errorf("ServeHTTP() X-Cache = %q, want HIT", xc)
}
}
// TestServeHTTP_ClientNil 测试客户端为 nil 时的行为
func TestServeHTTP_ClientNil(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
NextUpstream: config.NextUpstreamConfig{
Tries: 2,
},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
}
for _, target := range targets {
target.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 移除所有客户端
p.mu.Lock()
p.clients = make(map[string]*fasthttp.HostClient)
p.mu.Unlock()
ctx := testutil.NewRequestCtx("GET", "/api/test")
p.ServeHTTP(ctx)
// 所有客户端都不存在,应该返回 502
if ctx.Response.StatusCode() != fasthttp.StatusBadGateway {
t.Errorf("ServeHTTP() status = %d, want 502", ctx.Response.StatusCode())
}
}
// TestServeHTTP_WithRedirectRewrite 测试带 redirect_rewrite 的缓存命中
func TestServeHTTP_WithRedirectRewrite_CacheHit(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
RedirectRewrite: &config.RedirectRewriteConfig{
Mode: "off", // 关闭改写
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/api/test")
hashKey, origKey := p.buildCacheKeyHash(ctx)
p.cache.Set(hashKey, origKey, []byte("ok"), map[string]string{
"Content-Type": "text/plain",
}, 200, 10*time.Second)
p.ServeHTTP(ctx)
if ctx.Response.StatusCode() != 200 {
t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode())
}
}

View File

@ -0,0 +1,465 @@
// Package server 提供 HTTP 服务器的核心实现测试补充。
//
// 该文件补充以下测试覆盖:
// - 多模式启动逻辑测试single/vhost/multi_server/auto
// - 多服务器模式 shutdownServers 函数测试
// - 监听器创建测试TCP/Unix socket
// - StopWithTimeout 超时行为测试
// - GracefulStop 超时行为测试
// - 中间件链错误路径测试
package server
import (
"context"
"net"
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/resolver"
)
// TestServer_GetMode_Single 测试单服务器模式
func TestServer_GetMode_Single(t *testing.T) {
cfg := &config.Config{
Mode: config.ServerModeSingle,
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
if s.config.GetMode() != config.ServerModeSingle {
t.Errorf("Expected mode single, got %s", s.config.GetMode())
}
}
// TestServer_GetMode_VHost 测试虚拟主机模式
func TestServer_GetMode_VHost(t *testing.T) {
cfg := &config.Config{
Mode: config.ServerModeVHost,
Servers: []config.ServerConfig{
{Listen: ":0", Name: "host1.example.com"},
{Listen: ":0", Name: "host2.example.com"},
},
}
s := New(cfg)
if s.config.GetMode() != config.ServerModeVHost {
t.Errorf("Expected mode vhost, got %s", s.config.GetMode())
}
}
// TestServer_GetMode_MultiServer 测试多服务器模式
func TestServer_GetMode_MultiServer(t *testing.T) {
cfg := &config.Config{
Mode: config.ServerModeMultiServer,
Servers: []config.ServerConfig{
{Listen: ":8080", Name: "server1"},
{Listen: ":8081", Name: "server2"},
},
}
s := New(cfg)
if s.config.GetMode() != config.ServerModeMultiServer {
t.Errorf("Expected mode multi_server, got %s", s.config.GetMode())
}
}
// TestServer_GetMode_Auto 测试自动模式
func TestServer_GetMode_Auto(t *testing.T) {
cfg := &config.Config{
Mode: config.ServerModeAuto,
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
mode := s.config.GetMode()
if mode != config.ServerModeSingle {
t.Errorf("Expected auto to resolve to single, got %s", mode)
}
}
// TestShutdownServers_Empty 测试空服务器列表关闭
func TestShutdownServers_Empty(t *testing.T) {
err := shutdownServers(nil, nil)
if err != nil {
t.Errorf("Expected nil error for empty servers, got %v", err)
}
}
// TestShutdownServers_NilServer 测试含 nil 的服务器列表关闭
func TestShutdownServers_NilServer(t *testing.T) {
servers := []*fasthttp.Server{nil, nil}
err := shutdownServers(nil, servers)
if err != nil {
t.Errorf("Expected nil error with nil servers, got %v", err)
}
}
// TestShutdownServers_Timeout 测试关闭超时
func TestShutdownServers_Timeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
fastSrv := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
select {}
},
}
go func() {
_ = fastSrv.Serve(ln)
}()
time.Sleep(50 * time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
err = shutdownServers(ctx, []*fasthttp.Server{fastSrv})
// context 超时后 shutdownServers 会返回 ctx.Err()
if err != nil && err != context.DeadlineExceeded {
t.Errorf("Expected context.DeadlineExceeded or nil, got: %v", err)
}
_ = fastSrv.Shutdown()
_ = ln.Close()
}
// TestStopWithTimeout_DefaultTimeout 测试零超时使用默认值
func TestStopWithTimeout_DefaultTimeout(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
err := s.StopWithTimeout(0)
if err != nil {
t.Errorf("StopWithTimeout(0) should succeed, got %v", err)
}
err = s.StopWithTimeout(-1 * time.Second)
if err != nil {
t.Errorf("StopWithTimeout(-1s) should succeed, got %v", err)
}
}
// TestStopWithTimeout_MultiServerMode 测试多服务器模式停止
func TestStopWithTimeout_MultiServerMode(t *testing.T) {
cfg := &config.Config{
Mode: config.ServerModeMultiServer,
Servers: []config.ServerConfig{
{Listen: ":0", Name: "server1"},
{Listen: ":0", Name: "server2"},
},
}
s := New(cfg)
err := s.StopWithTimeout(1 * time.Second)
if err != nil {
t.Errorf("StopWithTimeout on non-started multi-server should succeed: %v", err)
}
}
// TestGracefulStop_Timeout 测试优雅停止超时
func TestGracefulStop_Timeout(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.fastServer = &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetBodyString("ok")
},
}
s.running = true
err := s.GracefulStop(100 * time.Millisecond)
if err != nil {
t.Errorf("GracefulStop should succeed: %v", err)
}
}
// TestServer_SetUpgradeManager 测试设置升级管理器
func TestServer_SetUpgradeManager(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
mgr := NewUpgradeManager(s)
s.SetUpgradeManager(mgr)
if s.upgradeManager != mgr {
t.Error("upgradeManager not set correctly")
}
}
// mockResolver 用于测试的 mock DNS 解析器
type mockResolver struct{}
func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
return []string{"127.0.0.1"}, nil
}
func (m *mockResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) {
return []string{"127.0.0.1"}, nil
}
func (m *mockResolver) Refresh(host string) error { return nil }
func (m *mockResolver) Start() error { return nil }
func (m *mockResolver) Stop() error { return nil }
func (m *mockResolver) Stats() resolver.Stats { return resolver.Stats{} }
// TestServer_Resolver 测试 DNS 解析器设置
func TestServer_Resolver(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
if s.GetResolver() != nil {
t.Error("Expected nil resolver initially")
}
mockRes := &mockResolver{}
s.SetResolver(mockRes)
if s.GetResolver() == nil {
t.Error("Resolver not set correctly")
}
}
// TestCreateListener_TCP 测试 TCP 监听器创建
func TestCreateListener_TCP(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
serverCfg := &cfg.Servers[0]
ln, err := s.createListener(serverCfg)
if err != nil {
t.Fatalf("createListener failed: %v", err)
}
defer func() { _ = ln.Close() }()
if ln == nil {
t.Fatal("Expected non-nil listener")
}
addr := ln.Addr().(*net.TCPAddr)
if addr.Port == 0 {
t.Error("Expected non-zero port")
}
}
// TestCreateListener_InvalidTCP 测试无效 TCP 监听地址
func TestCreateListener_InvalidTCP(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: "invalid:address:format",
}},
}
s := New(cfg)
_, err := s.createListener(&cfg.Servers[0])
if err == nil {
t.Error("Expected error for invalid listen address")
}
}
// TestListenerManagement 测试监听器管理
func TestListenerManagement(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
ln1, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener 1: %v", err)
}
defer func() { _ = ln1.Close() }()
ln2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener 2: %v", err)
}
defer func() { _ = ln2.Close() }()
s.SetListeners([]net.Listener{ln1, ln2})
listeners := s.GetListeners()
if len(listeners) != 2 {
t.Errorf("Expected 2 listeners, got %d", len(listeners))
}
}
// TestStart_WithGoroutinePoolAndFileCache 测试同时启用 GoroutinePool 和 FileCache
func TestStart_WithGoroutinePoolAndFileCache(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
Performance: config.PerformanceConfig{
GoroutinePool: config.GoroutinePoolConfig{
Enabled: true,
MaxWorkers: 50,
MinWorkers: 5,
},
FileCache: config.FileCacheConfig{
MaxEntries: 500,
MaxSize: 50 * 1024 * 1024,
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("New() returned nil")
}
if s.config.Performance.GoroutinePool.Enabled != true {
t.Error("GoroutinePool should be enabled")
}
}
// TestServer_GetHandler_NilThenSet 测试 handler 的 nil 到设置
func TestServer_GetHandler_NilThenSet(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
if s.GetHandler() != nil {
t.Error("Expected nil handler initially")
}
testHandler := func(ctx *fasthttp.RequestCtx) {
ctx.SetBodyString("test handler response")
}
s.handler = testHandler
got := s.GetHandler()
if got == nil {
t.Error("Expected non-nil handler after setting")
}
ctx := &fasthttp.RequestCtx{}
got(ctx)
if string(ctx.Response.Body()) != "test handler response" {
t.Errorf("Handler response = %q, want %q", string(ctx.Response.Body()), "test handler response")
}
}
// TestServer_TrackStats_Concurrent 测试并发统计追踪
func TestServer_TrackStats_Concurrent(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetBodyString("ok")
}
wrappedHandler := s.trackStats(handler)
const numGoroutines = 100
done := make(chan bool, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
wrappedHandler(ctx)
done <- true
}()
}
for i := 0; i < numGoroutines; i++ {
<-done
}
if s.requests.Load() != int64(numGoroutines) {
t.Errorf("Expected %d requests, got %d", numGoroutines, s.requests.Load())
}
}
// TestBuildMiddlewareChain_BodyLimit 测试请求体限制中间件
func TestBuildMiddlewareChain_BodyLimit(t *testing.T) {
cfg := &config.Config{
Logging: config.LoggingConfig{},
Servers: []config.ServerConfig{{
Listen: ":0",
Proxy: []config.ProxyConfig{{
Path: "/api/",
ClientMaxBodySize: "1MB",
}},
ClientMaxBodySize: "10MB",
}},
}
s := New(cfg)
chain, err := s.buildMiddlewareChain(&cfg.Servers[0])
if err != nil {
t.Errorf("buildMiddlewareChain failed: %v", err)
}
if chain == nil {
t.Error("Expected non-nil chain")
}
}
// TestBuildMiddlewareChain_BodyLimit_Invalid 测试无效的请求体限制
func TestBuildMiddlewareChain_BodyLimit_Invalid(t *testing.T) {
cfg := &config.Config{
Logging: config.LoggingConfig{},
Servers: []config.ServerConfig{{
Listen: ":0",
ClientMaxBodySize: "invalid_size",
}},
}
s := New(cfg)
_, err := s.buildMiddlewareChain(&cfg.Servers[0])
if err == nil {
t.Error("Expected error for invalid body limit size")
}
}

View File

@ -0,0 +1,554 @@
// Package stream 提供流代理覆盖测试。
//
// 该文件补充测试 stream.go 中未覆盖的方法:
// - ipHash.Select() (空 IP)
// - handleConnection() 连接处理
// - getOrCreateSession() 会话创建
// - handleBackendResponse() 后端响应处理
// - Stats 完整统计
//
// 作者xfy
package stream
import (
"net"
"testing"
"time"
)
// TestIPHashSelect 测试 ipHash 的 Select 方法(空字符串 IP
func TestIPHashSelect(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001"},
{addr: "localhost:8002"},
}
for _, target := range targets {
target.healthy.Store(true)
}
ih := newIPHash()
// Select() 使用空字符串作为 IP
selected := ih.Select(targets)
if selected == nil {
t.Error("Select() with empty IP should return a target")
}
// 多次调用应返回相同目标(确定性哈希)
selected2 := ih.Select(targets)
if selected != selected2 {
t.Error("Select() with same empty IP should be consistent")
}
// 无健康目标时应返回 nil
for _, target := range targets {
target.healthy.Store(false)
}
selected = ih.Select(targets)
if selected != nil {
t.Error("Select() with no healthy targets should return nil")
}
}
// TestSelectByIPNoHealthy 测试 SelectByIP 无健康目标
func TestSelectByIPNoHealthy(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001"},
{addr: "localhost:8002"},
}
ih := newIPHash()
selected := ih.(*ipHash).SelectByIP(targets, "192.168.1.1")
if selected != nil {
t.Error("SelectByIP() with no healthy targets should return nil")
}
}
// TestWeightedRoundRobinZeroWeight 测试零权重处理
func TestWeightedRoundRobinZeroWeight(t *testing.T) {
targets := []*Target{
{addr: "localhost:8001", weight: 0},
{addr: "localhost:8002", weight: -1},
}
for _, target := range targets {
target.healthy.Store(true)
}
wrr := newWeightedRoundRobin().(*weightedRoundRobin)
// 权重为 0 或负数应视为权重 1
for i := 0; i < 4; i++ {
selected := wrr.Select(targets)
if selected == nil {
t.Error("Select() should return target with zero/negative weight")
return
}
}
}
// TestHandleConnection 测试 handleConnection 方法
func TestHandleConnection(t *testing.T) {
s := NewServer()
// 添加上游配置
targets := []TargetSpec{
{Addr: "127.0.0.1:29001", Weight: 1},
}
_ = s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
s.upstreams["test"].targets[0].healthy.Store(true)
// 创建模拟客户端连接(不会实际建立连接,测试无上游路径)
s.mu.Lock()
// 设置上游为空,测试无上游配置路径
s.upstreams = make(map[string]*Upstream)
s.mu.Unlock()
// 创建一对连接
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() { _ = serverLn.Close() }()
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer func() { _ = clientConn.Close() }()
serverConn, err := serverLn.Accept()
if err != nil {
t.Fatalf("Failed to accept: %v", err)
}
defer func() { _ = serverConn.Close() }()
// 测试无上游配置的 handleConnection
s.handleConnection(clientConn, "127.0.0.1:0")
}
// TestHandleConnection_NoHealthyTarget 测试无健康目标路径
func TestHandleConnection_NoHealthyTarget(t *testing.T) {
s := NewServer()
// 添加不健康的上游
targets := []TargetSpec{
{Addr: "127.0.0.1:29002", Weight: 1},
}
_ = s.AddUpstream("test2", targets, "round_robin", HealthCheckSpec{})
// 目标不健康(默认 false
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() { _ = serverLn.Close() }()
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer func() { _ = clientConn.Close() }()
serverConn, err := serverLn.Accept()
if err != nil {
t.Fatalf("Failed to accept: %v", err)
}
defer func() { _ = serverConn.Close() }()
done := make(chan struct{})
go func() {
s.handleConnection(clientConn, "127.0.0.1:0")
close(done)
}()
select {
case <-done:
// 完成
case <-time.After(2 * time.Second):
t.Fatal("handleConnection() timed out")
}
}
// TestHandleConnection_DialFail 测试连接目标失败路径
func TestHandleConnection_DialFail(t *testing.T) {
s := NewServer()
// 添加上游,目标不可达
targets := []TargetSpec{
{Addr: "127.0.0.1:29999", Weight: 1},
}
_ = s.AddUpstream("test3", targets, "round_robin", HealthCheckSpec{})
s.upstreams["test3"].targets[0].healthy.Store(true)
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() { _ = serverLn.Close() }()
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer func() { _ = clientConn.Close() }()
serverConn, err := serverLn.Accept()
if err != nil {
t.Fatalf("Failed to accept: %v", err)
}
defer func() { _ = serverConn.Close() }()
done := make(chan struct{})
go func() {
s.handleConnection(clientConn, "127.0.0.1:0")
close(done)
}()
select {
case <-done:
// 完成 - 连接目标失败后应标记为不健康
if s.upstreams["test3"].targets[0].healthy.Load() {
t.Error("Target should be marked unhealthy after dial failure")
}
case <-time.After(15 * time.Second):
t.Fatal("handleConnection() timed out")
}
}
// TestGetOrCreateSession 测试 getOrCreateSession 方法
func TestGetOrCreateSession(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer func() { _ = conn.Close() }()
// 创建上游
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:29003"}},
balancer: newRoundRobin(),
}
upstream.targets[0].healthy.Store(true)
srv := newUDPServer(conn, upstream, 1*time.Minute)
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29010")
// 第一次调用 - 应该创建新会话(但由于后端不可达,应该失败)
session, err := srv.getOrCreateSession(clientAddr)
if err != nil {
// 预期失败,因为后端不可达
return
}
if session == nil {
t.Error("getOrCreateSession() should return a session")
}
}
// TestGetOrCreateSession_DoubleCheck 测试双重检查锁定
func TestGetOrCreateSession_DoubleCheck(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer func() { _ = conn.Close() }()
// 创建上游
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:29004"}},
balancer: newRoundRobin(),
}
upstream.targets[0].healthy.Store(true)
srv := newUDPServer(conn, upstream, 1*time.Minute)
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29011")
// 手动创建一个会话来测试双重检查
srv.mu.Lock()
testSession := &udpSession{
clientAddr: clientAddr,
lastActive: time.Now(),
srv: srv,
}
srv.sessions[sessionKey(clientAddr)] = testSession
srv.mu.Unlock()
// 再次获取应该返回现有会话
session, err := srv.getOrCreateSession(clientAddr)
if err != nil {
t.Errorf("getOrCreateSession() should not error for existing session: %v", err)
}
if session != testSession {
t.Error("getOrCreateSession() should return existing session")
}
}
// TestGetOrCreateSession_NoHealthyTarget 测试无健康目标
func TestGetOrCreateSession_NoHealthyTarget(t *testing.T) {
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer func() { _ = conn.Close() }()
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:29005"}},
balancer: newRoundRobin(),
}
// 不设置 healthy默认为 false
srv := newUDPServer(conn, upstream, 1*time.Minute)
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29012")
session, err := srv.getOrCreateSession(clientAddr)
if err == nil {
t.Error("getOrCreateSession() should return error when no healthy target")
}
if session != nil {
t.Error("getOrCreateSession() should return nil session on error")
}
}
// TestGetOrCreateSession_InvalidTargetAddr 测试无效目标地址
func TestGetOrCreateSession_InvalidTargetAddr(t *testing.T) {
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer func() { _ = conn.Close() }()
upstream := &Upstream{
targets: []*Target{{addr: "invalid-address"}},
balancer: newRoundRobin(),
}
upstream.targets[0].healthy.Store(true)
srv := newUDPServer(conn, upstream, 1*time.Minute)
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29013")
session, err := srv.getOrCreateSession(clientAddr)
if err == nil {
t.Error("getOrCreateSession() should return error for invalid target address")
}
if session != nil {
t.Error("getOrCreateSession() should return nil session on error")
}
}
// TestHandleBackendResponse 测试 handleBackendResponse 超时清理路径
func TestHandleBackendResponse(t *testing.T) {
// 创建 UDP 连接(服务端)
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer func() { _ = conn.Close() }()
// 创建上游
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:29006"}},
balancer: newRoundRobin(),
}
upstream.targets[0].healthy.Store(true)
srv := newUDPServer(conn, upstream, 50*time.Millisecond) // 短超时
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29014")
// 创建目标连接(监听器用于接收目标连接)
targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29006")
targetConn, _ := net.ListenUDP("udp", targetAddr)
defer func() { _ = targetConn.Close() }()
// 创建会话时需要先添加到 WaitGrouphandleBackendResponse 会调用 Done
srv.wg.Add(1)
session := &udpSession{
clientAddr: clientAddr,
targetConn: targetConn,
lastActive: time.Now().Add(-2 * time.Hour), // 很久以前
srv: srv,
}
// 添加会话到服务器
srv.sessions[sessionKey(clientAddr)] = session
// 启动后端响应处理
done := make(chan struct{})
go func() {
session.handleBackendResponse()
close(done)
}()
select {
case <-done:
// 应该因为超时而清理会话
case <-time.After(2 * time.Second):
t.Fatal("handleBackendResponse() timed out")
}
}
// TestHandleBackendResponse_ErrorPath 测试 handleBackendResponse 错误路径
func TestHandleBackendResponse_ErrorPath(t *testing.T) {
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer func() { _ = conn.Close() }()
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:29007"}},
balancer: newRoundRobin(),
}
srv := newUDPServer(conn, upstream, 10*time.Millisecond)
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29015")
// 创建一个已关闭的连接作为 targetConn
targetUDPAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29007")
targetConn, err := net.DialUDP("udp", nil, targetUDPAddr)
if err != nil {
t.Fatalf("Failed to create target connection: %v", err)
}
session := &udpSession{
clientAddr: clientAddr,
targetConn: targetConn,
lastActive: time.Now(),
srv: srv,
}
// 创建会话时需要先添加到 WaitGrouphandleBackendResponse 会调用 Done
srv.wg.Add(1)
srv.sessions[sessionKey(clientAddr)] = session
done := make(chan struct{})
go func() {
session.handleBackendResponse()
close(done)
}()
select {
case <-done:
// 完成
case <-time.After(3 * time.Second):
t.Fatal("handleBackendResponse() timed out")
}
}
// TestServe_InvalidUpstream 测试 serve 方法无效上游路径
func TestServe_InvalidUpstream(t *testing.T) {
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer func() { _ = conn.Close() }()
upstream := &Upstream{
targets: []*Target{},
balancer: newRoundRobin(),
}
srv := newUDPServer(conn, upstream, 50*time.Millisecond)
// 启动 serve
go srv.serve()
// 立即停止
time.Sleep(20 * time.Millisecond)
srv.stop()
}
// TestServerStop 测试 Server.Stop 方法
func TestServerStop(t *testing.T) {
s := NewServer()
// 添加上游
targets := []TargetSpec{
{Addr: "127.0.0.1:29008", Weight: 1},
}
hcSpec := HealthCheckSpec{
Enabled: true,
Interval: 1 * time.Second,
Timeout: 500 * time.Millisecond,
}
_ = s.AddUpstream("stop_test", targets, "round_robin", hcSpec)
// 监听 TCP
err := s.ListenTCP("127.0.0.1:29009")
if err != nil {
t.Fatalf("ListenTCP failed: %v", err)
}
// 监听 UDP
err = s.ListenUDP("127.0.0.1:29010", "stop_test", 1*time.Second)
if err != nil {
t.Fatalf("ListenUDP failed: %v", err)
}
// 启动
err = s.Start()
if err != nil {
t.Fatalf("Start failed: %v", err)
}
time.Sleep(50 * time.Millisecond)
// 停止
err = s.Stop()
if err != nil {
t.Errorf("Stop() error: %v", err)
}
}
// TestStatsComplete 测试 Stats 完整统计
func TestStatsComplete(t *testing.T) {
s := NewServer()
// 添加 TCP 监听
err := s.ListenTCP("127.0.0.1:29020")
if err != nil {
t.Fatalf("ListenTCP failed: %v", err)
}
// 添加上游
targets := []TargetSpec{{Addr: "127.0.0.1:29021", Weight: 1}}
_ = s.AddUpstream("stats_test", targets, "round_robin", HealthCheckSpec{})
// 添加 UDP 监听
err = s.ListenUDP("127.0.0.1:29022", "stats_test", 1*time.Second)
if err != nil {
t.Fatalf("ListenUDP failed: %v", err)
}
stats := s.Stats()
if stats.Listeners != 2 {
t.Errorf("Stats().Listeners = %d, want 2 (1 TCP + 1 UDP)", stats.Listeners)
}
if stats.Upstreams != 1 {
t.Errorf("Stats().Upstreams = %d, want 1", stats.Upstreams)
}
if stats.Connections != 0 {
t.Errorf("Stats().Connections = %d, want 0", stats.Connections)
}
}
// TestAcceptLoop_Error 测试 acceptLoop 错误处理路径
func TestAcceptLoop_Error(t *testing.T) {
s := NewServer()
s.running.Store(true)
// 创建一个立即关闭的监听器
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
// 在另一个 goroutine 中关闭监听器
go func() {
time.Sleep(50 * time.Millisecond)
_ = ln.Close()
}()
done := make(chan struct{})
go func() {
s.acceptLoop("test", ln)
close(done)
}()
select {
case <-done:
// 完成
case <-time.After(2 * time.Second):
s.running.Store(false)
<-done
}
}