test: 添加各模块覆盖率补充测试
- middleware/security: access 中间件覆盖率测试 - proxy: proxy 核心功能覆盖率测试 - server: server 扩展功能测试 - stream: stream 处理覆盖率测试 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
c82e6dcdb7
commit
5f5717d6a4
340
internal/middleware/security/access_coverage_test.go
Normal file
340
internal/middleware/security/access_coverage_test.go
Normal file
@ -0,0 +1,340 @@
|
||||
// Package security 提供访问控制覆盖测试。
|
||||
//
|
||||
// 该文件补充测试 access.go 中未覆盖的方法:
|
||||
// - Name() 方法
|
||||
// - Process() 完整处理链(允许/拒绝路径)
|
||||
// - getClientIP() 通过 Process 间接测试
|
||||
// - Close() 方法
|
||||
// - actionToString 边缘情况
|
||||
// - trustedProxies 相关逻辑
|
||||
//
|
||||
// 作者:xfy
|
||||
package security
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestAccessControlName 测试 Name 方法
|
||||
func TestAccessControlName(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
name := ac.Name()
|
||||
if name != "access_control" {
|
||||
t.Errorf("Name() = %q, want 'access_control'", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlProcess_AllowPath 测试 Process 允许路径
|
||||
func TestAccessControlProcess_AllowPath(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
_, _ = ctx.WriteString("allowed")
|
||||
}
|
||||
|
||||
handler := ac.Process(nextHandler)
|
||||
if handler == nil {
|
||||
t.Fatal("Process() returned nil")
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
handler(ctx)
|
||||
|
||||
if !called {
|
||||
t.Error("Process() should call next handler when access allowed")
|
||||
}
|
||||
if string(ctx.Response.Body()) != "allowed" {
|
||||
t.Errorf("Process() body = %q, want 'allowed'", string(ctx.Response.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlProcess_DenyPath 测试 Process 拒绝路径
|
||||
func TestAccessControlProcess_DenyPath(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "deny",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := ac.Process(nextHandler)
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
handler(ctx)
|
||||
|
||||
if called {
|
||||
t.Error("Process() should NOT call next handler when access denied")
|
||||
}
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusForbidden {
|
||||
t.Errorf("Process() status = %d, want 403", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlProcess_ExplicitAllow 测试显式允许列表
|
||||
func TestAccessControlProcess_ExplicitAllow(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Allow: []string{"127.0.0.1"},
|
||||
Default: "deny",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := ac.Process(nextHandler)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345})
|
||||
handler(ctx)
|
||||
|
||||
if !called {
|
||||
t.Error("Process() should call next handler for allowed IP")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlProcess_ExplicitDeny 测试显式拒绝列表
|
||||
func TestAccessControlProcess_ExplicitDeny(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Deny: []string{"10.0.0.1"},
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := ac.Process(nextHandler)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345})
|
||||
handler(ctx)
|
||||
|
||||
if called {
|
||||
t.Error("Process() should NOT call next handler for denied IP")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlProcess_TrustedProxies_XFF 测试可信代理 XFF 解析
|
||||
func TestAccessControlProcess_TrustedProxies_XFF(t *testing.T) {
|
||||
// 配置可信代理,10.0.0.0/8 为可信代理段
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
TrustedProxies: []string{"10.0.0.0/8"},
|
||||
Default: "deny",
|
||||
Allow: []string{"192.168.1.0/24"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := ac.Process(nextHandler)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// 请求来自可信代理 10.0.0.1,XFF 中包含真实客户端 192.168.1.100
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345})
|
||||
ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
|
||||
handler(ctx)
|
||||
|
||||
if !called {
|
||||
t.Error("Process() should allow real client IP behind trusted proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlProcess_TrustedProxies_UntrustedSource 测试不可信来源不解析 XFF
|
||||
func TestAccessControlProcess_TrustedProxies_UntrustedSource(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
TrustedProxies: []string{"10.0.0.0/8"},
|
||||
Default: "deny",
|
||||
Allow: []string{"192.168.1.0/24"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := ac.Process(nextHandler)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// 请求来自不可信地址,即使 XFF 包含允许列表 IP 也不应解析
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 12345})
|
||||
ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100")
|
||||
handler(ctx)
|
||||
|
||||
if called {
|
||||
t.Error("Process() should not trust XFF from untrusted source")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlProcess_TrustedProxies_XRealIP 测试可信代理 X-Real-IP
|
||||
func TestAccessControlProcess_TrustedProxies_XRealIP(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
TrustedProxies: []string{"10.0.0.0/8"},
|
||||
Default: "deny",
|
||||
Allow: []string{"192.168.1.0/24"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := ac.Process(nextHandler)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.50"), Port: 12345})
|
||||
ctx.Request.Header.Set("X-Real-IP", "192.168.1.50")
|
||||
handler(ctx)
|
||||
|
||||
if !called {
|
||||
t.Error("Process() should use X-Real-IP from trusted proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControlClose 测试 Close 方法
|
||||
func TestAccessControlClose(t *testing.T) {
|
||||
// 无 GeoIP 的 Close
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
err = ac.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestActionToString 测试 actionToString 边缘情况
|
||||
func TestActionToString(t *testing.T) {
|
||||
// 测试 ActionAllow
|
||||
result := actionToString(ActionAllow)
|
||||
if result != "allow" {
|
||||
t.Errorf("actionToString(ActionAllow) = %q, want 'allow'", result)
|
||||
}
|
||||
|
||||
// 测试 ActionDeny
|
||||
result = actionToString(ActionDeny)
|
||||
if result != "deny" {
|
||||
t.Errorf("actionToString(ActionDeny) = %q, want 'deny'", result)
|
||||
}
|
||||
|
||||
// 测试未知值
|
||||
result = actionToString(Action(999))
|
||||
if result != "unknown" {
|
||||
t.Errorf("actionToString(999) = %q, want 'unknown'", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetStatsWithEmpty 测试 GetStats 空列表
|
||||
func TestGetStatsWithEmpty(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
stats := ac.GetStats()
|
||||
if stats.AllowCount != 0 {
|
||||
t.Errorf("GetStats().AllowCount = %d, want 0", stats.AllowCount)
|
||||
}
|
||||
if stats.DenyCount != 0 {
|
||||
t.Errorf("GetStats().DenyCount = %d, want 0", stats.DenyCount)
|
||||
}
|
||||
if stats.Default != "allow" {
|
||||
t.Errorf("GetStats().Default = %q, want 'allow'", stats.Default)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetDefaultValidCases 测试 SetDefault 所有有效值
|
||||
func TestSetDefaultValidCases(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
// 切换为 deny
|
||||
err = ac.SetDefault("deny")
|
||||
if err != nil {
|
||||
t.Errorf("SetDefault('deny') error: %v", err)
|
||||
}
|
||||
stats := ac.GetStats()
|
||||
if stats.Default != "deny" {
|
||||
t.Errorf("After SetDefault('deny'), Default = %q, want 'deny'", stats.Default)
|
||||
}
|
||||
|
||||
// 切换回 allow
|
||||
err = ac.SetDefault("allow")
|
||||
if err != nil {
|
||||
t.Errorf("SetDefault('allow') error: %v", err)
|
||||
}
|
||||
stats = ac.GetStats()
|
||||
if stats.Default != "allow" {
|
||||
t.Errorf("After SetDefault('allow'), Default = %q, want 'allow'", stats.Default)
|
||||
}
|
||||
|
||||
// 大小写不敏感
|
||||
err = ac.SetDefault("DENY")
|
||||
if err != nil {
|
||||
t.Errorf("SetDefault('DENY') error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateDenyListError 测试 UpdateDenyList 错误路径
|
||||
func TestUpdateDenyListError(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
err = ac.UpdateDenyList([]string{"not-an-ip"})
|
||||
if err == nil {
|
||||
t.Error("UpdateDenyList() should return error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
720
internal/proxy/proxy_coverage_test.go
Normal file
720
internal/proxy/proxy_coverage_test.go
Normal file
@ -0,0 +1,720 @@
|
||||
// Package proxy 提供反向代理覆盖测试,补充 proxy.go 中未覆盖的方法。
|
||||
//
|
||||
// 该文件测试代理模块的以下功能:
|
||||
// - selectTargetExcluding 排除已失败目标的选择
|
||||
// - extractHashKey 哈希键提取
|
||||
// - buildCacheKeyHash / buildCacheKeyHashValue 缓存键计算
|
||||
// - writeCachedResponse 缓存响应写入
|
||||
// - GetCache / GetCacheStats 缓存访问
|
||||
// - getCacheDuration 不同状态码的缓存时间
|
||||
// - redirect_rewrite 相关功能
|
||||
//
|
||||
// 作者:xfy
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/testutil"
|
||||
)
|
||||
|
||||
// TestSelectTargetExcluding 测试排除失败目标的目标选择
|
||||
func TestSelectTargetExcluding(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://backend1:8080"},
|
||||
{URL: "http://backend2:8080"},
|
||||
{URL: "http://backend3:8080"},
|
||||
}
|
||||
for _, target := range targets {
|
||||
target.Healthy.Store(true)
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
|
||||
// 排除第一个目标,应该选择第二个
|
||||
excluded := []*loadbalance.Target{targets[0]}
|
||||
selected := p.selectTargetExcluding(ctx, excluded)
|
||||
if selected == nil {
|
||||
t.Fatal("selectTargetExcluding() returned nil")
|
||||
}
|
||||
if selected.URL == "http://backend1:8080" {
|
||||
t.Error("selectTargetExcluding() should not select excluded target")
|
||||
}
|
||||
|
||||
// 排除所有目标,应该返回 nil
|
||||
allExcluded := []*loadbalance.Target{targets[0], targets[1], targets[2]}
|
||||
selected = p.selectTargetExcluding(ctx, allExcluded)
|
||||
if selected != nil {
|
||||
t.Error("selectTargetExcluding() should return nil when all excluded")
|
||||
}
|
||||
|
||||
// 空目标列表
|
||||
p2, _ := NewProxy(cfg, []*loadbalance.Target{{URL: "http://a:1"}}, nil, nil)
|
||||
p2.targets = nil
|
||||
selected = p2.selectTargetExcluding(ctx, nil)
|
||||
if selected != nil {
|
||||
t.Error("selectTargetExcluding() should return nil for empty targets")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectTargetExcluding_IPHash 测试 IP Hash 排除选择
|
||||
func TestSelectTargetExcluding_IPHash(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "ip_hash",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://backend1:8080"},
|
||||
{URL: "http://backend2:8080"},
|
||||
{URL: "http://backend3:8080"},
|
||||
}
|
||||
for _, target := range targets {
|
||||
target.Healthy.Store(true)
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
|
||||
"X-Forwarded-For": "192.168.1.1",
|
||||
})
|
||||
|
||||
// 获取第一次选择
|
||||
first := p.selectTarget(ctx)
|
||||
if first == nil {
|
||||
t.Fatal("selectTarget() returned nil")
|
||||
}
|
||||
|
||||
// 排除第一次选择的目标
|
||||
excluded := []*loadbalance.Target{first}
|
||||
second := p.selectTargetExcluding(ctx, excluded)
|
||||
if second == nil {
|
||||
t.Fatal("selectTargetExcluding() returned nil")
|
||||
}
|
||||
if second.URL == first.URL {
|
||||
t.Errorf("selectTargetExcluding() should not select same target %s", first.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectTargetExcluding_ConsistentHash 测试一致性哈希排除选择
|
||||
func TestSelectTargetExcluding_ConsistentHash(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "consistent_hash",
|
||||
HashKey: "uri",
|
||||
VirtualNodes: 150,
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://backend1:8080"},
|
||||
{URL: "http://backend2:8080"},
|
||||
}
|
||||
for _, target := range targets {
|
||||
target.Healthy.Store(true)
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
|
||||
// 排除一个目标,应该还能选到另一个
|
||||
excluded := []*loadbalance.Target{targets[0]}
|
||||
selected := p.selectTargetExcluding(ctx, excluded)
|
||||
if selected == nil {
|
||||
t.Error("selectTargetExcluding() should return remaining target")
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractHashKey 测试哈希键提取
|
||||
func TestExtractHashKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hashKey string
|
||||
headers map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "ip hash key",
|
||||
hashKey: "ip",
|
||||
headers: map[string]string{"X-Forwarded-For": "10.0.0.1"},
|
||||
expected: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "empty hash key defaults to ip",
|
||||
hashKey: "",
|
||||
headers: map[string]string{"X-Forwarded-For": "10.0.0.2"},
|
||||
expected: "10.0.0.2",
|
||||
},
|
||||
{
|
||||
name: "uri hash key",
|
||||
hashKey: "uri",
|
||||
headers: nil,
|
||||
expected: "/api/test",
|
||||
},
|
||||
{
|
||||
name: "header hash key - found",
|
||||
hashKey: "header:X-Custom-ID",
|
||||
headers: map[string]string{"X-Custom-ID": "abc123"},
|
||||
expected: "abc123",
|
||||
},
|
||||
{
|
||||
name: "header hash key - fallback to ip",
|
||||
hashKey: "header:X-Missing",
|
||||
headers: map[string]string{"X-Forwarded-For": "10.0.0.3"},
|
||||
expected: "10.0.0.3",
|
||||
},
|
||||
{
|
||||
name: "unknown hash key defaults to ip",
|
||||
hashKey: "unknown",
|
||||
headers: map[string]string{"X-Forwarded-For": "10.0.0.4"},
|
||||
expected: "10.0.0.4",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "consistent_hash",
|
||||
HashKey: tt.hashKey,
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", tt.headers)
|
||||
result := p.extractHashKey(ctx, tt.hashKey)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractHashKey() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildCacheKeyHash 测试缓存键哈希计算
|
||||
func TestBuildCacheKeyHash(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
|
||||
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||||
if hashKey == 0 {
|
||||
t.Error("buildCacheKeyHash() should return non-zero hash")
|
||||
}
|
||||
if origKey == "" {
|
||||
t.Error("buildCacheKeyHash() should return non-empty origKey")
|
||||
}
|
||||
|
||||
// 相同请求应产生相同哈希
|
||||
ctx2 := testutil.NewRequestCtx("GET", "/api/test")
|
||||
hashKey2, _ := p.buildCacheKeyHash(ctx2)
|
||||
if hashKey != hashKey2 {
|
||||
t.Error("Same request should produce same hash")
|
||||
}
|
||||
|
||||
// 不同请求应产生不同哈希
|
||||
ctx3 := testutil.NewRequestCtx("POST", "/api/other")
|
||||
hashKey3, _ := p.buildCacheKeyHash(ctx3)
|
||||
if hashKey == hashKey3 {
|
||||
t.Error("Different request should produce different hash")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildCacheKeyHashValue 测试零分配缓存键哈希
|
||||
func TestBuildCacheKeyHashValue(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
|
||||
hashValue := p.buildCacheKeyHashValue(ctx)
|
||||
if hashValue == 0 {
|
||||
t.Error("buildCacheKeyHashValue() should return non-zero hash")
|
||||
}
|
||||
|
||||
// 应该与 buildCacheKeyHash 结果一致
|
||||
hashKey, _ := p.buildCacheKeyHash(ctx)
|
||||
if hashValue != hashKey {
|
||||
t.Error("buildCacheKeyHashValue() should match buildCacheKeyHash()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteCachedResponse 测试缓存响应写入
|
||||
func TestWriteCachedResponse(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 手动创建一个 Response 用于验证 writeCachedResponse 写入正确
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
|
||||
entry := &cache.ProxyCacheEntry{
|
||||
Data: []byte("cached body"),
|
||||
Headers: map[string]string{"Content-Type": "text/html", "X-Cached": "true"},
|
||||
Status: 200,
|
||||
}
|
||||
|
||||
p.writeCachedResponse(ctx, entry)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("writeCachedResponse() status = %d, want 200", ctx.Response.StatusCode())
|
||||
}
|
||||
if string(ctx.Response.Body()) != "cached body" {
|
||||
t.Errorf("writeCachedResponse() body = %q, want %q", string(ctx.Response.Body()), "cached body")
|
||||
}
|
||||
ct := string(ctx.Response.Header.Peek("Content-Type"))
|
||||
if ct != "text/html" {
|
||||
t.Errorf("writeCachedResponse() Content-Type = %q, want %q", ct, "text/html")
|
||||
}
|
||||
xc := string(ctx.Response.Header.Peek("X-Cache"))
|
||||
if xc != "HIT" {
|
||||
t.Errorf("writeCachedResponse() X-Cache = %q, want HIT", xc)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCache 测试 GetCache 方法
|
||||
func TestGetCache(t *testing.T) {
|
||||
// 启用缓存时
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
c := p.GetCache()
|
||||
if c == nil {
|
||||
t.Error("GetCache() should return non-nil when cache enabled")
|
||||
}
|
||||
|
||||
// 禁用缓存时
|
||||
cfg2 := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
p2, _ := NewProxy(cfg2, targets, nil, nil)
|
||||
c2 := p2.GetCache()
|
||||
if c2 != nil {
|
||||
t.Error("GetCache() should return nil when cache disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCacheStats 测试 GetCacheStats 方法
|
||||
func TestGetCacheStats(t *testing.T) {
|
||||
// 启用缓存时
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
stats := p.GetCacheStats()
|
||||
if stats == nil {
|
||||
t.Error("GetCacheStats() should return non-nil when cache enabled")
|
||||
}
|
||||
|
||||
// 禁用缓存时
|
||||
cfg2 := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
p2, _ := NewProxy(cfg2, targets, nil, nil)
|
||||
stats2 := p2.GetCacheStats()
|
||||
if stats2 != nil {
|
||||
t.Error("GetCacheStats() should return nil when cache disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCacheDuration 测试不同状态码的缓存时间计算
|
||||
func TestGetCacheDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cacheValid *config.ProxyCacheValidConfig
|
||||
maxAge time.Duration
|
||||
statusCode int
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "no CacheValid config uses MaxAge",
|
||||
maxAge: 5 * time.Minute,
|
||||
statusCode: 200,
|
||||
expected: 5 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "2xx with CacheValid.OK set",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
OK: 10 * time.Minute,
|
||||
},
|
||||
statusCode: 200,
|
||||
expected: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "2xx with CacheValid.OK=0 inherits MaxAge",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
OK: 0,
|
||||
},
|
||||
maxAge: 3 * time.Minute,
|
||||
statusCode: 201,
|
||||
expected: 3 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "301 redirect",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
Redirect: 1 * time.Hour,
|
||||
},
|
||||
statusCode: 301,
|
||||
expected: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "302 redirect",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
Redirect: 30 * time.Minute,
|
||||
},
|
||||
statusCode: 302,
|
||||
expected: 30 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "302 with zero Redirect means no cache",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
Redirect: 0,
|
||||
},
|
||||
statusCode: 302,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "404",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
NotFound: 1 * time.Minute,
|
||||
},
|
||||
statusCode: 404,
|
||||
expected: 1 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "4xx client error",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
ClientError: 30 * time.Second,
|
||||
},
|
||||
statusCode: 400,
|
||||
expected: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "5xx server error",
|
||||
cacheValid: &config.ProxyCacheValidConfig{
|
||||
ServerError: 0,
|
||||
},
|
||||
statusCode: 500,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "other status code",
|
||||
statusCode: 100,
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: tt.maxAge,
|
||||
},
|
||||
CacheValid: tt.cacheValid,
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
duration := p.getCacheDuration(tt.statusCode)
|
||||
if duration != tt.expected {
|
||||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, duration, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackgroundRefresh 测试后台缓存刷新(标记为 skip 因为需要真实网络)
|
||||
func TestBackgroundRefresh(t *testing.T) {
|
||||
t.Skip("skipping: requires real network connection and is timing-sensitive")
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
go func() {
|
||||
s := &fasthttp.Server{
|
||||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetBodyString("refreshed")
|
||||
ctx.Response.Header.Set("Content-Type", "text/plain")
|
||||
},
|
||||
}
|
||||
_ = s.Serve(ln)
|
||||
}()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
addr := ln.Addr().String()
|
||||
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://" + addr},
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
|
||||
// 设置缓存锁
|
||||
hashKey := uint64(12345)
|
||||
p.cache.AcquireLock(hashKey)
|
||||
|
||||
// 调用后台刷新(它会执行实际请求来刷新缓存)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test")
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// 等待完成
|
||||
select {
|
||||
case <-done:
|
||||
// 完成
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("backgroundRefresh() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBackgroundRefresh_NoClient 测试后台刷新时客户端不存在的情况
|
||||
func TestBackgroundRefresh_NoClient(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://nonexistent:9999"},
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 移除客户端
|
||||
p.mu.Lock()
|
||||
delete(p.clients, targets[0].URL)
|
||||
p.mu.Unlock()
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
hashKey := uint64(99999)
|
||||
p.cache.AcquireLock(hashKey)
|
||||
|
||||
// 应该不会 panic,直接返回
|
||||
p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test")
|
||||
}
|
||||
|
||||
// TestServeHTTP_CacheHit 测试缓存命中路径
|
||||
func TestServeHTTP_CacheHit(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 预填充缓存
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/cached")
|
||||
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||||
p.cache.Set(hashKey, origKey, []byte("cached!"), map[string]string{
|
||||
"Content-Type": "text/plain",
|
||||
}, 200, 10*time.Second)
|
||||
|
||||
// 执行请求
|
||||
p.ServeHTTP(ctx)
|
||||
|
||||
// 应该返回缓存的响应
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode())
|
||||
}
|
||||
if string(ctx.Response.Body()) != "cached!" {
|
||||
t.Errorf("ServeHTTP() body = %q, want %q", string(ctx.Response.Body()), "cached!")
|
||||
}
|
||||
xc := string(ctx.Response.Header.Peek("X-Cache"))
|
||||
if xc != "HIT" {
|
||||
t.Errorf("ServeHTTP() X-Cache = %q, want HIT", xc)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_ClientNil 测试客户端为 nil 时的行为
|
||||
func TestServeHTTP_ClientNil(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
NextUpstream: config.NextUpstreamConfig{
|
||||
Tries: 2,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://backend1:8080"},
|
||||
{URL: "http://backend2:8080"},
|
||||
}
|
||||
for _, target := range targets {
|
||||
target.Healthy.Store(true)
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 移除所有客户端
|
||||
p.mu.Lock()
|
||||
p.clients = make(map[string]*fasthttp.HostClient)
|
||||
p.mu.Unlock()
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
p.ServeHTTP(ctx)
|
||||
|
||||
// 所有客户端都不存在,应该返回 502
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusBadGateway {
|
||||
t.Errorf("ServeHTTP() status = %d, want 502", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_WithRedirectRewrite 测试带 redirect_rewrite 的缓存命中
|
||||
func TestServeHTTP_WithRedirectRewrite_CacheHit(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
RedirectRewrite: &config.RedirectRewriteConfig{
|
||||
Mode: "off", // 关闭改写
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||||
p.cache.Set(hashKey, origKey, []byte("ok"), map[string]string{
|
||||
"Content-Type": "text/plain",
|
||||
}, 200, 10*time.Second)
|
||||
|
||||
p.ServeHTTP(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
465
internal/server/server_extended_test.go
Normal file
465
internal/server/server_extended_test.go
Normal file
@ -0,0 +1,465 @@
|
||||
// Package server 提供 HTTP 服务器的核心实现测试补充。
|
||||
//
|
||||
// 该文件补充以下测试覆盖:
|
||||
// - 多模式启动逻辑测试(single/vhost/multi_server/auto)
|
||||
// - 多服务器模式 shutdownServers 函数测试
|
||||
// - 监听器创建测试(TCP/Unix socket)
|
||||
// - StopWithTimeout 超时行为测试
|
||||
// - GracefulStop 超时行为测试
|
||||
// - 中间件链错误路径测试
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/resolver"
|
||||
)
|
||||
|
||||
// TestServer_GetMode_Single 测试单服务器模式
|
||||
func TestServer_GetMode_Single(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Mode: config.ServerModeSingle,
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s.config.GetMode() != config.ServerModeSingle {
|
||||
t.Errorf("Expected mode single, got %s", s.config.GetMode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_GetMode_VHost 测试虚拟主机模式
|
||||
func TestServer_GetMode_VHost(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Mode: config.ServerModeVHost,
|
||||
Servers: []config.ServerConfig{
|
||||
{Listen: ":0", Name: "host1.example.com"},
|
||||
{Listen: ":0", Name: "host2.example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s.config.GetMode() != config.ServerModeVHost {
|
||||
t.Errorf("Expected mode vhost, got %s", s.config.GetMode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_GetMode_MultiServer 测试多服务器模式
|
||||
func TestServer_GetMode_MultiServer(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Mode: config.ServerModeMultiServer,
|
||||
Servers: []config.ServerConfig{
|
||||
{Listen: ":8080", Name: "server1"},
|
||||
{Listen: ":8081", Name: "server2"},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s.config.GetMode() != config.ServerModeMultiServer {
|
||||
t.Errorf("Expected mode multi_server, got %s", s.config.GetMode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_GetMode_Auto 测试自动模式
|
||||
func TestServer_GetMode_Auto(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Mode: config.ServerModeAuto,
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
mode := s.config.GetMode()
|
||||
if mode != config.ServerModeSingle {
|
||||
t.Errorf("Expected auto to resolve to single, got %s", mode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShutdownServers_Empty 测试空服务器列表关闭
|
||||
func TestShutdownServers_Empty(t *testing.T) {
|
||||
err := shutdownServers(nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error for empty servers, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShutdownServers_NilServer 测试含 nil 的服务器列表关闭
|
||||
func TestShutdownServers_NilServer(t *testing.T) {
|
||||
servers := []*fasthttp.Server{nil, nil}
|
||||
err := shutdownServers(nil, servers)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error with nil servers, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShutdownServers_Timeout 测试关闭超时
|
||||
func TestShutdownServers_Timeout(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
fastSrv := &fasthttp.Server{
|
||||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||
select {}
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = fastSrv.Serve(ln)
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err = shutdownServers(ctx, []*fasthttp.Server{fastSrv})
|
||||
// context 超时后 shutdownServers 会返回 ctx.Err()
|
||||
if err != nil && err != context.DeadlineExceeded {
|
||||
t.Errorf("Expected context.DeadlineExceeded or nil, got: %v", err)
|
||||
}
|
||||
|
||||
_ = fastSrv.Shutdown()
|
||||
_ = ln.Close()
|
||||
}
|
||||
|
||||
// TestStopWithTimeout_DefaultTimeout 测试零超时使用默认值
|
||||
func TestStopWithTimeout_DefaultTimeout(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
err := s.StopWithTimeout(0)
|
||||
if err != nil {
|
||||
t.Errorf("StopWithTimeout(0) should succeed, got %v", err)
|
||||
}
|
||||
|
||||
err = s.StopWithTimeout(-1 * time.Second)
|
||||
if err != nil {
|
||||
t.Errorf("StopWithTimeout(-1s) should succeed, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStopWithTimeout_MultiServerMode 测试多服务器模式停止
|
||||
func TestStopWithTimeout_MultiServerMode(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Mode: config.ServerModeMultiServer,
|
||||
Servers: []config.ServerConfig{
|
||||
{Listen: ":0", Name: "server1"},
|
||||
{Listen: ":0", Name: "server2"},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
err := s.StopWithTimeout(1 * time.Second)
|
||||
if err != nil {
|
||||
t.Errorf("StopWithTimeout on non-started multi-server should succeed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGracefulStop_Timeout 测试优雅停止超时
|
||||
func TestGracefulStop_Timeout(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
s.fastServer = &fasthttp.Server{
|
||||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetBodyString("ok")
|
||||
},
|
||||
}
|
||||
s.running = true
|
||||
|
||||
err := s.GracefulStop(100 * time.Millisecond)
|
||||
if err != nil {
|
||||
t.Errorf("GracefulStop should succeed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_SetUpgradeManager 测试设置升级管理器
|
||||
func TestServer_SetUpgradeManager(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
mgr := NewUpgradeManager(s)
|
||||
|
||||
s.SetUpgradeManager(mgr)
|
||||
if s.upgradeManager != mgr {
|
||||
t.Error("upgradeManager not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// mockResolver 用于测试的 mock DNS 解析器
|
||||
type mockResolver struct{}
|
||||
|
||||
func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||
return []string{"127.0.0.1"}, nil
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) {
|
||||
return []string{"127.0.0.1"}, nil
|
||||
}
|
||||
|
||||
func (m *mockResolver) Refresh(host string) error { return nil }
|
||||
func (m *mockResolver) Start() error { return nil }
|
||||
func (m *mockResolver) Stop() error { return nil }
|
||||
func (m *mockResolver) Stats() resolver.Stats { return resolver.Stats{} }
|
||||
|
||||
// TestServer_Resolver 测试 DNS 解析器设置
|
||||
func TestServer_Resolver(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
if s.GetResolver() != nil {
|
||||
t.Error("Expected nil resolver initially")
|
||||
}
|
||||
|
||||
mockRes := &mockResolver{}
|
||||
s.SetResolver(mockRes)
|
||||
|
||||
if s.GetResolver() == nil {
|
||||
t.Error("Resolver not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateListener_TCP 测试 TCP 监听器创建
|
||||
func TestCreateListener_TCP(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
serverCfg := &cfg.Servers[0]
|
||||
|
||||
ln, err := s.createListener(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("createListener failed: %v", err)
|
||||
}
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
if ln == nil {
|
||||
t.Fatal("Expected non-nil listener")
|
||||
}
|
||||
|
||||
addr := ln.Addr().(*net.TCPAddr)
|
||||
if addr.Port == 0 {
|
||||
t.Error("Expected non-zero port")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateListener_InvalidTCP 测试无效 TCP 监听地址
|
||||
func TestCreateListener_InvalidTCP(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "invalid:address:format",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
_, err := s.createListener(&cfg.Servers[0])
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid listen address")
|
||||
}
|
||||
}
|
||||
|
||||
// TestListenerManagement 测试监听器管理
|
||||
func TestListenerManagement(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
ln1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener 1: %v", err)
|
||||
}
|
||||
defer func() { _ = ln1.Close() }()
|
||||
|
||||
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener 2: %v", err)
|
||||
}
|
||||
defer func() { _ = ln2.Close() }()
|
||||
|
||||
s.SetListeners([]net.Listener{ln1, ln2})
|
||||
|
||||
listeners := s.GetListeners()
|
||||
if len(listeners) != 2 {
|
||||
t.Errorf("Expected 2 listeners, got %d", len(listeners))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithGoroutinePoolAndFileCache 测试同时启用 GoroutinePool 和 FileCache
|
||||
func TestStart_WithGoroutinePoolAndFileCache(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
Performance: config.PerformanceConfig{
|
||||
GoroutinePool: config.GoroutinePoolConfig{
|
||||
Enabled: true,
|
||||
MaxWorkers: 50,
|
||||
MinWorkers: 5,
|
||||
},
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 500,
|
||||
MaxSize: 50 * 1024 * 1024,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
if s.config.Performance.GoroutinePool.Enabled != true {
|
||||
t.Error("GoroutinePool should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_GetHandler_NilThenSet 测试 handler 的 nil 到设置
|
||||
func TestServer_GetHandler_NilThenSet(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
if s.GetHandler() != nil {
|
||||
t.Error("Expected nil handler initially")
|
||||
}
|
||||
|
||||
testHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetBodyString("test handler response")
|
||||
}
|
||||
s.handler = testHandler
|
||||
|
||||
got := s.GetHandler()
|
||||
if got == nil {
|
||||
t.Error("Expected non-nil handler after setting")
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
got(ctx)
|
||||
if string(ctx.Response.Body()) != "test handler response" {
|
||||
t.Errorf("Handler response = %q, want %q", string(ctx.Response.Body()), "test handler response")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_TrackStats_Concurrent 测试并发统计追踪
|
||||
func TestServer_TrackStats_Concurrent(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetBodyString("ok")
|
||||
}
|
||||
|
||||
wrappedHandler := s.trackStats(handler)
|
||||
|
||||
const numGoroutines = 100
|
||||
done := make(chan bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
wrappedHandler(ctx)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
if s.requests.Load() != int64(numGoroutines) {
|
||||
t.Errorf("Expected %d requests, got %d", numGoroutines, s.requests.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildMiddlewareChain_BodyLimit 测试请求体限制中间件
|
||||
func TestBuildMiddlewareChain_BodyLimit(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Logging: config.LoggingConfig{},
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
Proxy: []config.ProxyConfig{{
|
||||
Path: "/api/",
|
||||
ClientMaxBodySize: "1MB",
|
||||
}},
|
||||
ClientMaxBodySize: "10MB",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
chain, err := s.buildMiddlewareChain(&cfg.Servers[0])
|
||||
if err != nil {
|
||||
t.Errorf("buildMiddlewareChain failed: %v", err)
|
||||
}
|
||||
if chain == nil {
|
||||
t.Error("Expected non-nil chain")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildMiddlewareChain_BodyLimit_Invalid 测试无效的请求体限制
|
||||
func TestBuildMiddlewareChain_BodyLimit_Invalid(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Logging: config.LoggingConfig{},
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":0",
|
||||
ClientMaxBodySize: "invalid_size",
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
_, err := s.buildMiddlewareChain(&cfg.Servers[0])
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid body limit size")
|
||||
}
|
||||
}
|
||||
554
internal/stream/stream_coverage_test.go
Normal file
554
internal/stream/stream_coverage_test.go
Normal file
@ -0,0 +1,554 @@
|
||||
// Package stream 提供流代理覆盖测试。
|
||||
//
|
||||
// 该文件补充测试 stream.go 中未覆盖的方法:
|
||||
// - ipHash.Select() (空 IP)
|
||||
// - handleConnection() 连接处理
|
||||
// - getOrCreateSession() 会话创建
|
||||
// - handleBackendResponse() 后端响应处理
|
||||
// - Stats 完整统计
|
||||
//
|
||||
// 作者:xfy
|
||||
package stream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIPHashSelect 测试 ipHash 的 Select 方法(空字符串 IP)
|
||||
func TestIPHashSelect(t *testing.T) {
|
||||
targets := []*Target{
|
||||
{addr: "localhost:8001"},
|
||||
{addr: "localhost:8002"},
|
||||
}
|
||||
for _, target := range targets {
|
||||
target.healthy.Store(true)
|
||||
}
|
||||
|
||||
ih := newIPHash()
|
||||
|
||||
// Select() 使用空字符串作为 IP
|
||||
selected := ih.Select(targets)
|
||||
if selected == nil {
|
||||
t.Error("Select() with empty IP should return a target")
|
||||
}
|
||||
|
||||
// 多次调用应返回相同目标(确定性哈希)
|
||||
selected2 := ih.Select(targets)
|
||||
if selected != selected2 {
|
||||
t.Error("Select() with same empty IP should be consistent")
|
||||
}
|
||||
|
||||
// 无健康目标时应返回 nil
|
||||
for _, target := range targets {
|
||||
target.healthy.Store(false)
|
||||
}
|
||||
selected = ih.Select(targets)
|
||||
if selected != nil {
|
||||
t.Error("Select() with no healthy targets should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSelectByIPNoHealthy 测试 SelectByIP 无健康目标
|
||||
func TestSelectByIPNoHealthy(t *testing.T) {
|
||||
targets := []*Target{
|
||||
{addr: "localhost:8001"},
|
||||
{addr: "localhost:8002"},
|
||||
}
|
||||
|
||||
ih := newIPHash()
|
||||
selected := ih.(*ipHash).SelectByIP(targets, "192.168.1.1")
|
||||
if selected != nil {
|
||||
t.Error("SelectByIP() with no healthy targets should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWeightedRoundRobinZeroWeight 测试零权重处理
|
||||
func TestWeightedRoundRobinZeroWeight(t *testing.T) {
|
||||
targets := []*Target{
|
||||
{addr: "localhost:8001", weight: 0},
|
||||
{addr: "localhost:8002", weight: -1},
|
||||
}
|
||||
for _, target := range targets {
|
||||
target.healthy.Store(true)
|
||||
}
|
||||
|
||||
wrr := newWeightedRoundRobin().(*weightedRoundRobin)
|
||||
|
||||
// 权重为 0 或负数应视为权重 1
|
||||
for i := 0; i < 4; i++ {
|
||||
selected := wrr.Select(targets)
|
||||
if selected == nil {
|
||||
t.Error("Select() should return target with zero/negative weight")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleConnection 测试 handleConnection 方法
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
// 添加上游配置
|
||||
targets := []TargetSpec{
|
||||
{Addr: "127.0.0.1:29001", Weight: 1},
|
||||
}
|
||||
_ = s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
|
||||
s.upstreams["test"].targets[0].healthy.Store(true)
|
||||
|
||||
// 创建模拟客户端连接(不会实际建立连接,测试无上游路径)
|
||||
s.mu.Lock()
|
||||
// 设置上游为空,测试无上游配置路径
|
||||
s.upstreams = make(map[string]*Upstream)
|
||||
s.mu.Unlock()
|
||||
|
||||
// 创建一对连接
|
||||
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() { _ = serverLn.Close() }()
|
||||
|
||||
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
serverConn, err := serverLn.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to accept: %v", err)
|
||||
}
|
||||
defer func() { _ = serverConn.Close() }()
|
||||
|
||||
// 测试无上游配置的 handleConnection
|
||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||
}
|
||||
|
||||
// TestHandleConnection_NoHealthyTarget 测试无健康目标路径
|
||||
func TestHandleConnection_NoHealthyTarget(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
// 添加不健康的上游
|
||||
targets := []TargetSpec{
|
||||
{Addr: "127.0.0.1:29002", Weight: 1},
|
||||
}
|
||||
_ = s.AddUpstream("test2", targets, "round_robin", HealthCheckSpec{})
|
||||
// 目标不健康(默认 false)
|
||||
|
||||
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() { _ = serverLn.Close() }()
|
||||
|
||||
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
serverConn, err := serverLn.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to accept: %v", err)
|
||||
}
|
||||
defer func() { _ = serverConn.Close() }()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// 完成
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("handleConnection() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleConnection_DialFail 测试连接目标失败路径
|
||||
func TestHandleConnection_DialFail(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
// 添加上游,目标不可达
|
||||
targets := []TargetSpec{
|
||||
{Addr: "127.0.0.1:29999", Weight: 1},
|
||||
}
|
||||
_ = s.AddUpstream("test3", targets, "round_robin", HealthCheckSpec{})
|
||||
s.upstreams["test3"].targets[0].healthy.Store(true)
|
||||
|
||||
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() { _ = serverLn.Close() }()
|
||||
|
||||
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
serverConn, err := serverLn.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to accept: %v", err)
|
||||
}
|
||||
defer func() { _ = serverConn.Close() }()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// 完成 - 连接目标失败后应标记为不健康
|
||||
if s.upstreams["test3"].targets[0].healthy.Load() {
|
||||
t.Error("Target should be marked unhealthy after dial failure")
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("handleConnection() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetOrCreateSession 测试 getOrCreateSession 方法
|
||||
func TestGetOrCreateSession(t *testing.T) {
|
||||
// 创建 UDP 连接
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
// 创建上游
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:29003"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
upstream.targets[0].healthy.Store(true)
|
||||
|
||||
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29010")
|
||||
|
||||
// 第一次调用 - 应该创建新会话(但由于后端不可达,应该失败)
|
||||
session, err := srv.getOrCreateSession(clientAddr)
|
||||
if err != nil {
|
||||
// 预期失败,因为后端不可达
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
t.Error("getOrCreateSession() should return a session")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetOrCreateSession_DoubleCheck 测试双重检查锁定
|
||||
func TestGetOrCreateSession_DoubleCheck(t *testing.T) {
|
||||
// 创建 UDP 连接
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
// 创建上游
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:29004"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
upstream.targets[0].healthy.Store(true)
|
||||
|
||||
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29011")
|
||||
|
||||
// 手动创建一个会话来测试双重检查
|
||||
srv.mu.Lock()
|
||||
testSession := &udpSession{
|
||||
clientAddr: clientAddr,
|
||||
lastActive: time.Now(),
|
||||
srv: srv,
|
||||
}
|
||||
srv.sessions[sessionKey(clientAddr)] = testSession
|
||||
srv.mu.Unlock()
|
||||
|
||||
// 再次获取应该返回现有会话
|
||||
session, err := srv.getOrCreateSession(clientAddr)
|
||||
if err != nil {
|
||||
t.Errorf("getOrCreateSession() should not error for existing session: %v", err)
|
||||
}
|
||||
if session != testSession {
|
||||
t.Error("getOrCreateSession() should return existing session")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetOrCreateSession_NoHealthyTarget 测试无健康目标
|
||||
func TestGetOrCreateSession_NoHealthyTarget(t *testing.T) {
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:29005"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
// 不设置 healthy,默认为 false
|
||||
|
||||
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29012")
|
||||
|
||||
session, err := srv.getOrCreateSession(clientAddr)
|
||||
if err == nil {
|
||||
t.Error("getOrCreateSession() should return error when no healthy target")
|
||||
}
|
||||
if session != nil {
|
||||
t.Error("getOrCreateSession() should return nil session on error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetOrCreateSession_InvalidTargetAddr 测试无效目标地址
|
||||
func TestGetOrCreateSession_InvalidTargetAddr(t *testing.T) {
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "invalid-address"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
upstream.targets[0].healthy.Store(true)
|
||||
|
||||
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29013")
|
||||
|
||||
session, err := srv.getOrCreateSession(clientAddr)
|
||||
if err == nil {
|
||||
t.Error("getOrCreateSession() should return error for invalid target address")
|
||||
}
|
||||
if session != nil {
|
||||
t.Error("getOrCreateSession() should return nil session on error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleBackendResponse 测试 handleBackendResponse 超时清理路径
|
||||
func TestHandleBackendResponse(t *testing.T) {
|
||||
// 创建 UDP 连接(服务端)
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
// 创建上游
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:29006"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
upstream.targets[0].healthy.Store(true)
|
||||
|
||||
srv := newUDPServer(conn, upstream, 50*time.Millisecond) // 短超时
|
||||
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29014")
|
||||
|
||||
// 创建目标连接(监听器用于接收目标连接)
|
||||
targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29006")
|
||||
targetConn, _ := net.ListenUDP("udp", targetAddr)
|
||||
defer func() { _ = targetConn.Close() }()
|
||||
|
||||
// 创建会话时需要先添加到 WaitGroup(handleBackendResponse 会调用 Done)
|
||||
srv.wg.Add(1)
|
||||
session := &udpSession{
|
||||
clientAddr: clientAddr,
|
||||
targetConn: targetConn,
|
||||
lastActive: time.Now().Add(-2 * time.Hour), // 很久以前
|
||||
srv: srv,
|
||||
}
|
||||
|
||||
// 添加会话到服务器
|
||||
srv.sessions[sessionKey(clientAddr)] = session
|
||||
|
||||
// 启动后端响应处理
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
session.handleBackendResponse()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// 应该因为超时而清理会话
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("handleBackendResponse() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleBackendResponse_ErrorPath 测试 handleBackendResponse 错误路径
|
||||
func TestHandleBackendResponse_ErrorPath(t *testing.T) {
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:29007"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
|
||||
srv := newUDPServer(conn, upstream, 10*time.Millisecond)
|
||||
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29015")
|
||||
|
||||
// 创建一个已关闭的连接作为 targetConn
|
||||
targetUDPAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29007")
|
||||
targetConn, err := net.DialUDP("udp", nil, targetUDPAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create target connection: %v", err)
|
||||
}
|
||||
|
||||
session := &udpSession{
|
||||
clientAddr: clientAddr,
|
||||
targetConn: targetConn,
|
||||
lastActive: time.Now(),
|
||||
srv: srv,
|
||||
}
|
||||
|
||||
// 创建会话时需要先添加到 WaitGroup(handleBackendResponse 会调用 Done)
|
||||
srv.wg.Add(1)
|
||||
srv.sessions[sessionKey(clientAddr)] = session
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
session.handleBackendResponse()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// 完成
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("handleBackendResponse() timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServe_InvalidUpstream 测试 serve 方法无效上游路径
|
||||
func TestServe_InvalidUpstream(t *testing.T) {
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
|
||||
srv := newUDPServer(conn, upstream, 50*time.Millisecond)
|
||||
|
||||
// 启动 serve
|
||||
go srv.serve()
|
||||
|
||||
// 立即停止
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
srv.stop()
|
||||
}
|
||||
|
||||
// TestServerStop 测试 Server.Stop 方法
|
||||
func TestServerStop(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
// 添加上游
|
||||
targets := []TargetSpec{
|
||||
{Addr: "127.0.0.1:29008", Weight: 1},
|
||||
}
|
||||
hcSpec := HealthCheckSpec{
|
||||
Enabled: true,
|
||||
Interval: 1 * time.Second,
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
_ = s.AddUpstream("stop_test", targets, "round_robin", hcSpec)
|
||||
|
||||
// 监听 TCP
|
||||
err := s.ListenTCP("127.0.0.1:29009")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenTCP failed: %v", err)
|
||||
}
|
||||
|
||||
// 监听 UDP
|
||||
err = s.ListenUDP("127.0.0.1:29010", "stop_test", 1*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP failed: %v", err)
|
||||
}
|
||||
|
||||
// 启动
|
||||
err = s.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 停止
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Stop() error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStatsComplete 测试 Stats 完整统计
|
||||
func TestStatsComplete(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
// 添加 TCP 监听
|
||||
err := s.ListenTCP("127.0.0.1:29020")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenTCP failed: %v", err)
|
||||
}
|
||||
|
||||
// 添加上游
|
||||
targets := []TargetSpec{{Addr: "127.0.0.1:29021", Weight: 1}}
|
||||
_ = s.AddUpstream("stats_test", targets, "round_robin", HealthCheckSpec{})
|
||||
|
||||
// 添加 UDP 监听
|
||||
err = s.ListenUDP("127.0.0.1:29022", "stats_test", 1*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("ListenUDP failed: %v", err)
|
||||
}
|
||||
|
||||
stats := s.Stats()
|
||||
if stats.Listeners != 2 {
|
||||
t.Errorf("Stats().Listeners = %d, want 2 (1 TCP + 1 UDP)", stats.Listeners)
|
||||
}
|
||||
if stats.Upstreams != 1 {
|
||||
t.Errorf("Stats().Upstreams = %d, want 1", stats.Upstreams)
|
||||
}
|
||||
if stats.Connections != 0 {
|
||||
t.Errorf("Stats().Connections = %d, want 0", stats.Connections)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAcceptLoop_Error 测试 acceptLoop 错误处理路径
|
||||
func TestAcceptLoop_Error(t *testing.T) {
|
||||
s := NewServer()
|
||||
s.running.Store(true)
|
||||
|
||||
// 创建一个立即关闭的监听器
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
// 在另一个 goroutine 中关闭监听器
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.acceptLoop("test", ln)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// 完成
|
||||
case <-time.After(2 * time.Second):
|
||||
s.running.Store(false)
|
||||
<-done
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user