From 8f79fb679768621a707ae0a9a4fd1f9e82b9e962 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 22 Apr 2026 18:28:28 +0800 Subject: [PATCH] =?UTF-8?q?test(config,handler,loadbalance,proxy):=20?= =?UTF-8?q?=E6=89=A9=E5=B1=95=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95=E8=A6=86?= =?UTF-8?q?=E7=9B=96=E7=8E=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加以下测试: - validate_test.go: Rewrite、NextUpstream、DefaultServer、Mode、 ListenConflicts、HTTP2、RedirectRewrite 验证测试 - sendfile_test.go: 无效文件描述符、零长度传输、部分传输、 带偏移量传输测试 - balancer_test.go: ConsistentHash Select/SelectExcluding、 RandomBalancer 边界条件和 Power of Two Choices 测试 - health_test.go: MarkHealthy/MarkUnhealthy 与 SlowStartManager 集成测试 Co-Authored-By: Claude Opus 4.7 --- internal/config/validate_test.go | 520 ++++++++++++++++++++++++++ internal/handler/sendfile_test.go | 193 ++++++++++ internal/loadbalance/balancer_test.go | 198 ++++++++-- internal/proxy/health_test.go | 57 +++ 4 files changed, 947 insertions(+), 21 deletions(-) diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 8313a7d..2591739 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -1371,3 +1371,523 @@ func TestValidateStaticsWithTryFiles(t *testing.T) { }) } } + +func TestValidateRewrite(t *testing.T) { + tests := []struct { + name string + errMsg string + config RewriteRule + wantErr bool + }{ + { + name: "有效重写规则", + config: RewriteRule{ + Pattern: "^/old/(.*)$", + Replacement: "/new/$1", + Flag: "last", + }, + wantErr: false, + }, + { + name: "有效redirect标志", + config: RewriteRule{ + Pattern: "^/old$", + Replacement: "/new", + Flag: "redirect", + }, + wantErr: false, + }, + { + name: "有效permanent标志", + config: RewriteRule{ + Pattern: "^/old$", + Replacement: "/new", + Flag: "permanent", + }, + wantErr: false, + }, + { + name: "有效break标志", + config: RewriteRule{ + Pattern: "^/api/(.*)$", + Replacement: "/backend/$1", + Flag: "break", + }, + wantErr: false, + }, + { + name: "空标志有效", + config: RewriteRule{ + Pattern: "^/old$", + Replacement: "/new", + }, + wantErr: false, + }, + { + name: "Pattern缺失", + config: RewriteRule{ + Replacement: "/new", + }, + wantErr: true, + errMsg: "pattern 必填", + }, + { + name: "无效flag", + config: RewriteRule{ + Pattern: "^/old$", + Replacement: "/new", + Flag: "invalid", + }, + wantErr: true, + errMsg: "无效的 flag", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateRewrite(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateRewrite() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateRewrite() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateRewrite() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateNextUpstream(t *testing.T) { + tests := []struct { + name string + errMsg string + config NextUpstreamConfig + wantErr bool + }{ + { + name: "空配置有效", + config: NextUpstreamConfig{}, + wantErr: false, + }, + { + name: "有效重试配置", + config: NextUpstreamConfig{ + Tries: 3, + HTTPCodes: []int{500, 502, 503, 504}, + }, + wantErr: false, + }, + { + name: "仅Tries配置", + config: NextUpstreamConfig{ + Tries: 2, + }, + wantErr: false, + }, + { + name: "仅HTTPCodes配置", + config: NextUpstreamConfig{ + HTTPCodes: []int{500, 502}, + }, + wantErr: false, + }, + { + name: "负数Tries", + config: NextUpstreamConfig{ + Tries: -1, + }, + wantErr: true, + errMsg: "tries 不能为负数", + }, + { + name: "无效HTTP状态码-过低", + config: NextUpstreamConfig{ + HTTPCodes: []int{99}, + }, + wantErr: true, + errMsg: "无效的 HTTP 状态码", + }, + { + name: "无效HTTP状态码-过高", + config: NextUpstreamConfig{ + HTTPCodes: []int{600}, + }, + wantErr: true, + errMsg: "无效的 HTTP 状态码", + }, + { + name: "有效边界状态码100", + config: NextUpstreamConfig{ + HTTPCodes: []int{100}, + }, + wantErr: false, + }, + { + name: "有效边界状态码599", + config: NextUpstreamConfig{ + HTTPCodes: []int{599}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateNextUpstream(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateNextUpstream() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateNextUpstream() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateNextUpstream() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateDefaultServer(t *testing.T) { + tests := []struct { + name string + errMsg string + servers []ServerConfig + wantErr bool + }{ + { + name: "空服务器列表", + servers: []ServerConfig{}, + wantErr: false, + }, + { + name: "无默认服务器", + servers: []ServerConfig{ + {Listen: ":8080"}, + {Listen: ":8081"}, + }, + wantErr: false, + }, + { + name: "单个默认服务器", + servers: []ServerConfig{ + {Listen: ":8080", Default: true}, + {Listen: ":8081"}, + }, + wantErr: false, + }, + { + name: "多个默认服务器", + servers: []ServerConfig{ + {Listen: ":8080", Default: true}, + {Listen: ":8081", Default: true}, + }, + wantErr: true, + errMsg: "只能有一个 default: true 服务器", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDefaultServer(tt.servers) + if tt.wantErr { + if err == nil { + t.Errorf("validateDefaultServer() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateDefaultServer() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateDefaultServer() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateMode(t *testing.T) { + tests := []struct { + name string + errMsg string + mode ServerMode + wantErr bool + }{ + {name: "空模式有效", mode: "", wantErr: false}, + {name: "auto模式有效", mode: ServerModeAuto, wantErr: false}, + {name: "single模式有效", mode: ServerModeSingle, wantErr: false}, + {name: "vhost模式有效", mode: ServerModeVHost, wantErr: false}, + {name: "multi_server模式有效", mode: ServerModeMultiServer, wantErr: false}, + {name: "无效模式", mode: "invalid", wantErr: true, errMsg: "无效的 mode"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateMode(tt.mode) + if tt.wantErr { + if err == nil { + t.Errorf("validateMode() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateMode() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateMode() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateListenConflicts(t *testing.T) { + tests := []struct { + name string + errMsg string + servers []ServerConfig + mode ServerMode + wantErr bool + }{ + { + name: "非multi_server模式跳过验证", + servers: []ServerConfig{{Listen: ":8080"}}, + mode: ServerModeSingle, + wantErr: false, + }, + { + name: "multi_server模式有效配置", + servers: []ServerConfig{ + {Listen: ":8080"}, + {Listen: ":8081"}, + }, + mode: ServerModeMultiServer, + wantErr: false, + }, + { + name: "multi_server模式缺少listen", + servers: []ServerConfig{{Listen: ""}}, + mode: ServerModeMultiServer, + wantErr: true, + errMsg: "multi_server 模式下每个 server 必须配置 listen 地址", + }, + { + name: "multi_server模式监听地址冲突", + servers: []ServerConfig{ + {Listen: ":8080"}, + {Listen: ":8080"}, + }, + mode: ServerModeMultiServer, + wantErr: true, + errMsg: "监听地址冲突", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateListenConflicts(tt.servers, tt.mode) + if tt.wantErr { + if err == nil { + t.Errorf("validateListenConflicts() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateListenConflicts() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateListenConflicts() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateHTTP2(t *testing.T) { + tests := []struct { + name string + errMsg string + config HTTP2Config + hasSSL bool + wantErr bool + }{ + { + name: "未启用HTTP2", + config: HTTP2Config{Enabled: false}, + hasSSL: false, + wantErr: false, + }, + { + name: "启用HTTP2且有SSL", + config: HTTP2Config{Enabled: true}, + hasSSL: true, + wantErr: false, + }, + { + name: "启用HTTP2但无SSL", + config: HTTP2Config{Enabled: true}, + hasSSL: false, + wantErr: true, + errMsg: "HTTP/2 需要配置 SSL/TLS 证书", + }, + { + name: "启用H2C但无SSL", + config: HTTP2Config{Enabled: true, H2CEnabled: true}, + hasSSL: false, + wantErr: false, + }, + { + name: "负数MaxConcurrentStreams", + config: HTTP2Config{MaxConcurrentStreams: -1}, + hasSSL: false, + wantErr: true, + errMsg: "max_concurrent_streams 不能为负数", + }, + { + name: "负数MaxHeaderListSize", + config: HTTP2Config{MaxHeaderListSize: -1}, + hasSSL: false, + wantErr: true, + errMsg: "max_header_list_size 不能为负数", + }, + { + name: "负数IdleTimeout", + config: HTTP2Config{IdleTimeout: -1}, + hasSSL: false, + wantErr: true, + errMsg: "idle_timeout 不能为负数", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateHTTP2(&tt.config, tt.hasSSL) + if tt.wantErr { + if err == nil { + t.Errorf("validateHTTP2() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateHTTP2() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateHTTP2() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateRedirectRewrite(t *testing.T) { + tests := []struct { + name string + errMsg string + config *RedirectRewriteConfig + wantErr bool + }{ + { + name: "nil配置有效", + config: nil, + wantErr: false, + }, + { + name: "空配置有效", + config: &RedirectRewriteConfig{}, + wantErr: false, + }, + { + name: "default模式有效", + config: &RedirectRewriteConfig{Mode: "default"}, + wantErr: false, + }, + { + name: "off模式有效", + config: &RedirectRewriteConfig{Mode: "off"}, + wantErr: false, + }, + { + name: "custom模式有规则", + config: &RedirectRewriteConfig{ + Mode: "custom", + Rules: []RedirectRewriteRule{ + {Pattern: "^/old$", Replacement: "/new"}, + }, + }, + wantErr: false, + }, + { + name: "custom模式无规则", + config: &RedirectRewriteConfig{Mode: "custom"}, + wantErr: true, + errMsg: "rules required when mode is custom", + }, + { + name: "无效模式", + config: &RedirectRewriteConfig{Mode: "invalid"}, + wantErr: true, + errMsg: "must be one of", + }, + { + name: "规则pattern为空", + config: &RedirectRewriteConfig{ + Mode: "custom", + Rules: []RedirectRewriteRule{ + {Pattern: "", Replacement: "/new"}, + }, + }, + wantErr: true, + errMsg: "pattern cannot be empty", + }, + { + name: "正则模式有效", + config: &RedirectRewriteConfig{ + Mode: "custom", + Rules: []RedirectRewriteRule{ + {Pattern: "~^/old/(.*)$", Replacement: "/new/$1"}, + }, + }, + wantErr: false, + }, + { + name: "正则模式无效", + config: &RedirectRewriteConfig{ + Mode: "custom", + Rules: []RedirectRewriteRule{ + {Pattern: "~[invalid(regex", Replacement: "/new"}, + }, + }, + wantErr: true, + errMsg: "invalid regex", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateRedirectRewrite(tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateRedirectRewrite() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateRedirectRewrite() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateRedirectRewrite() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} diff --git a/internal/handler/sendfile_test.go b/internal/handler/sendfile_test.go index d7ead91..d951e67 100644 --- a/internal/handler/sendfile_test.go +++ b/internal/handler/sendfile_test.go @@ -723,6 +723,199 @@ func TestLinuxSendfile_SendfileError(t *testing.T) { } } +// TestLinuxSendfile_InvalidFileFd 测试无效文件描述符 +func TestLinuxSendfile_InvalidFileFd(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + go func() { + serverConn, _ = ln.Accept() + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer clientConn.Close() + + time.Sleep(100 * time.Millisecond) + + // 使用无效的文件描述符 + err = linuxSendfile(clientConn, uintptr(99999), 0, 1024) + if err == nil { + t.Error("Expected error for invalid file descriptor") + } + + if serverConn != nil { + serverConn.Close() + } +} + +// TestLinuxSendfile_ZeroLength 测试零长度传输 +func TestLinuxSendfile_ZeroLength(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.txt") + content := []byte("test") + _ = os.WriteFile(tmpFile, content, 0o644) + + file, err := os.Open(tmpFile) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + defer func() { _ = file.Close() }() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + go func() { + serverConn, _ = ln.Accept() + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer clientConn.Close() + + time.Sleep(100 * time.Millisecond) + + // 零长度应该立即返回 + err = linuxSendfile(clientConn, file.Fd(), 0, 0) + if err != nil { + t.Errorf("Expected nil for zero length, got: %v", err) + } + + if serverConn != nil { + serverConn.Close() + } +} + +// TestLinuxSendfile_PartialTransfer 测试部分传输后继续 +func TestLinuxSendfile_PartialTransfer(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "partial.bin") + + // 创建大文件 + content := make([]byte, 32*1024) + for i := range content { + content[i] = byte(i % 256) + } + if err := os.WriteFile(tmpFile, content, 0o644); err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + file, err := os.Open(tmpFile) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + defer func() { _ = file.Close() }() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + var received []byte + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + serverConn, _ = ln.Accept() + buf := make([]byte, len(content)) + n, _ := serverConn.Read(buf) + received = buf[:n] + serverConn.Close() + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer clientConn.Close() + + clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + + // 调用 linuxSendfile 传输整个文件 + err = linuxSendfile(clientConn, file.Fd(), 0, int64(len(content))) + // EPIPE/ECONNRESET 是可接受的,因为服务器可能提前关闭 + if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET { + t.Logf("linuxSendfile returned: %v", err) + } + + clientConn.Close() + wg.Wait() + + // 验证至少传输了部分数据 + if len(received) > 0 { + t.Logf("Received %d bytes", len(received)) + } +} + +// TestLinuxSendfile_WithOffset 测试带偏移量的传输 +func TestLinuxSendfile_WithOffset(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "offset.bin") + + content := make([]byte, 16*1024) + for i := range content { + content[i] = byte(i % 256) + } + if err := os.WriteFile(tmpFile, content, 0o644); err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + file, err := os.Open(tmpFile) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + defer func() { _ = file.Close() }() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + serverConn, _ = ln.Accept() + buf := make([]byte, 8*1024) + _, _ = serverConn.Read(buf) + serverConn.Close() + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer clientConn.Close() + + clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + + // 注意:linuxSendfile 的 offset 参数未使用(由内核处理) + // 这里测试 length 参数 + err = linuxSendfile(clientConn, file.Fd(), 0, 8*1024) + if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET { + t.Logf("linuxSendfile returned: %v", err) + } + + clientConn.Close() + wg.Wait() +} + // TestSendFile_NegativeLength 测试负长度参数 func TestSendFile_NegativeLength(t *testing.T) { tmpDir := t.TempDir() diff --git a/internal/loadbalance/balancer_test.go b/internal/loadbalance/balancer_test.go index 807e0d7..94f0567 100644 --- a/internal/loadbalance/balancer_test.go +++ b/internal/loadbalance/balancer_test.go @@ -1508,27 +1508,6 @@ func TestConsistentHash_PrecomputeHashes(t *testing.T) { }) } -// TestConsistentHash_SelectExcluding 测试一致性哈希SelectExcluding方法。 -func TestConsistentHash_SelectExcluding(t *testing.T) { - t.Run("委托给SelectExcludingByKey", func(_ *testing.T) { - ch := NewConsistentHash(100, "ip") - targets := []*Target{ - createHealthyTarget("http://backend1:8080", true), - createHealthyTarget("http://backend2:8080", true), - } - ch.Rebuild(targets) - - excluded := []*Target{targets[0]} - got := ch.SelectExcluding(targets, excluded) - if got == nil { - t.Fatal("SelectExcluding() = nil, want non-nil") - } - if got.URL == targets[0].URL { - t.Errorf("选中了被排除的目标: %q", got.URL) - } - }) -} - // TestLeastConnections_ConcurrentSelection 测试最少连接并发选择。 func TestLeastConnections_ConcurrentSelection(t *testing.T) { targets := []*Target{ @@ -1857,6 +1836,102 @@ func TestFilterHealthyBackup(t *testing.T) { }) } +// TestConsistentHash_Select 测试一致性哈希 Select 方法(委托给 SelectByKey)。 +func TestConsistentHash_Select(t *testing.T) { + t.Run("委托给SelectByKey", func(_ *testing.T) { + ch := NewConsistentHash(100, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + + // Select 内部调用 SelectByKey(targets, "") + got := ch.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + }) + + t.Run("空目标返回nil", func(_ *testing.T) { + ch := NewConsistentHash(100, "ip") + got := ch.Select([]*Target{}) + if got != nil { + t.Errorf("Select() = %v, want nil", got) + } + }) + + t.Run("单目标直接返回", func(_ *testing.T) { + ch := NewConsistentHash(100, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + } + + got := ch.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + if got.URL != "http://backend1:8080" { + t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080") + } + }) + + t.Run("多目标一致性", func(_ *testing.T) { + ch := NewConsistentHash(100, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), + } + + // 多次调用应该返回相同结果(空键的一致性) + first := ch.Select(targets) + for i := 0; i < 10; i++ { + got := ch.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + if got.URL != first.URL { + t.Errorf("Select() 不一致: first=%q, got=%q", first.URL, got.URL) + } + } + }) +} + +// TestConsistentHash_SelectExcluding 测试一致性哈希 SelectExcluding 方法。 +func TestConsistentHash_SelectExcluding(t *testing.T) { + t.Run("委托给SelectExcludingByKey", func(_ *testing.T) { + ch := NewConsistentHash(100, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + ch.Rebuild(targets) + + excluded := []*Target{targets[0]} + got := ch.SelectExcluding(targets, excluded) + if got == nil { + t.Fatal("SelectExcluding() = nil, want non-nil") + } + if got.URL == targets[0].URL { + t.Errorf("选中了被排除的目标: %q", got.URL) + } + }) + + t.Run("空排除列表", func(_ *testing.T) { + ch := NewConsistentHash(100, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + } + ch.Rebuild(targets) + + got := ch.SelectExcluding(targets, nil) + if got == nil { + t.Fatal("SelectExcluding() = nil, want non-nil") + } + }) +} + +// TestRandomBalancer 测试随机负载均衡器。 func TestRandomBalancer(t *testing.T) { t.Run("selects from available targets", func(t *testing.T) { targets := []*Target{ @@ -1892,4 +1967,85 @@ func TestRandomBalancer(t *testing.T) { t.Error("should exclude first target") } }) + + t.Run("select excluding all targets returns nil", func(t *testing.T) { + targets := []*Target{ + NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""), + NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""), + } + b := NewRandom() + selected := b.SelectExcluding(targets, targets) + if selected != nil { + t.Error("should return nil when all targets excluded") + } + }) + + t.Run("select excluding empty list", func(t *testing.T) { + targets := []*Target{ + NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""), + } + b := NewRandom() + selected := b.SelectExcluding(targets, nil) + if selected == nil { + t.Error("should select a target with empty exclusion") + } + }) + + t.Run("select excluding with nil in excluded list", func(t *testing.T) { + targets := []*Target{ + NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""), + NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""), + } + b := NewRandom() + selected := b.SelectExcluding(targets, []*Target{nil}) + if selected == nil { + t.Error("should select a target even with nil in excluded list") + } + }) + + t.Run("power of two choices prefers fewer connections", func(t *testing.T) { + targets := []*Target{ + NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""), + NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""), + } + targets[0].Connections = 100 + targets[1].Connections = 1 + b := NewRandom() + + // 多次选择,验证总是选择连接数少的目标 + for i := 0; i < 100; i++ { + selected := b.Select(targets) + if selected == nil { + t.Error("should select a target") + continue + } + // Power of Two Choices 总是选择连接数少的 + if selected.URL != "http://b:8080" { + t.Errorf("should prefer target with fewer connections, got %q", selected.URL) + } + } + }) + + t.Run("power of two choices with equal connections", func(t *testing.T) { + targets := []*Target{ + NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""), + NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""), + } + targets[0].Connections = 10 + targets[1].Connections = 10 + b := NewRandom() + + // 连接数相等时,两个目标都应该被选中 + counts := make(map[string]int) + for i := 0; i < 100; i++ { + selected := b.Select(targets) + if selected != nil { + counts[selected.URL]++ + } + } + + if counts["http://a:8080"] == 0 || counts["http://b:8080"] == 0 { + t.Error("both targets should be selected when connections are equal") + } + }) } diff --git a/internal/proxy/health_test.go b/internal/proxy/health_test.go index 98e77ef..c564d76 100644 --- a/internal/proxy/health_test.go +++ b/internal/proxy/health_test.go @@ -444,3 +444,60 @@ func TestMarkUnhealthy(t *testing.T) { } }) } + +// TestMarkUnhealthy_WithSlowStartManager 测试 MarkUnhealthy 与 SlowStartManager 集成。 +func TestMarkUnhealthy_WithSlowStartManager(t *testing.T) { + target := &loadbalance.Target{ + URL: "http://127.0.0.1:8080", + Weight: 100, + } + target.Healthy.Store(true) + target.SlowStart = 30 * time.Second + + checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ + Interval: 1 * time.Hour, + Timeout: 5 * time.Second, + Path: "/health", + SlowStart: 30 * time.Second, + }) + + // 先标记为健康以初始化慢启动 + checker.MarkHealthy(target) + + // 标记目标为不健康 + checker.MarkUnhealthy(target) + + if target.Healthy.Load() { + t.Error("target 应标记为 unhealthy") + } +} + +// TestMarkHealthy_WithSlowStartManager 测试 MarkHealthy 与 SlowStartManager 集成。 +func TestMarkHealthy_WithSlowStartManager(t *testing.T) { + target := &loadbalance.Target{ + URL: "http://127.0.0.1:8080", + Weight: 100, + } + target.Healthy.Store(false) + target.SlowStart = 30 * time.Second + + checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ + Interval: 1 * time.Hour, + Timeout: 5 * time.Second, + Path: "/health", + SlowStart: 30 * time.Second, + }) + + // 标记目标为健康 + checker.MarkHealthy(target) + + if !target.Healthy.Load() { + t.Error("target 应标记为 healthy") + } + + // 验证慢启动已开始(EffectiveWeight 应被设置为 1) + ew := target.EffectiveWeight.Load() + if ew <= 0 { + t.Errorf("慢启动 EffectiveWeight 应大于 0,got: %d", ew) + } +}