test(app,proxy,ssl,stream): 完善测试覆盖率

- app: 添加 NewApp/SetPidFile/SetLogFile/sigName 测试
- proxy: 扩展健康检查器测试
- ssl: 添加 TLS 配置和 Close 方法测试
- stream: 添加负载均衡器和 UDP 会话测试

覆盖率从 55.4% 提升至 60.3%

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-03 13:36:43 +08:00
parent 80936ae66b
commit c70ab305b7
4 changed files with 1171 additions and 1 deletions

View File

@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"strings"
"syscall"
"testing"
)
@ -53,7 +54,109 @@ func captureStderr(t *testing.T) (func() string, func()) {
}
}
// TestRun 测试 Run 函数的各种场景。
// TestNewApp 测试 NewApp 构造器。
func TestNewApp(t *testing.T) {
cfgPath := "/path/to/config.yaml"
app := NewApp(cfgPath)
if app.cfgPath != cfgPath {
t.Errorf("cfgPath = %q, want %q", app.cfgPath, cfgPath)
}
if app.cfg != nil {
t.Error("新创建的 App cfg 应为 nil")
}
if app.srv != nil {
t.Error("新创建的 App srv 应为 nil")
}
if app.pidFile != "" {
t.Errorf("pidFile = %q, want empty", app.pidFile)
}
if app.logFile != "" {
t.Errorf("logFile = %q, want empty", app.logFile)
}
}
// TestSetPidFile 测试 SetPidFile setter 方法。
func TestSetPidFile(t *testing.T) {
app := NewApp("/test/config.yaml")
pidPath := "/var/run/lolly.pid"
app.SetPidFile(pidPath)
if app.pidFile != pidPath {
t.Errorf("pidFile = %q, want %q", app.pidFile, pidPath)
}
}
// TestSetLogFile 测试 SetLogFile setter 方法。
func TestSetLogFile(t *testing.T) {
app := NewApp("/test/config.yaml")
logPath := "/var/log/lolly.log"
app.SetLogFile(logPath)
if app.logFile != logPath {
t.Errorf("logFile = %q, want %q", app.logFile, logPath)
}
}
// TestSigName 测试信号名称辅助函数。
func TestSigName(t *testing.T) {
tests := []struct {
name string
sig syscall.Signal
expected string
}{
{
name: "SIGTERM",
sig: syscall.SIGTERM,
expected: "SIGTERM",
},
{
name: "SIGINT",
sig: syscall.SIGINT,
expected: "SIGINT",
},
{
name: "SIGQUIT",
sig: syscall.SIGQUIT,
expected: "SIGQUIT",
},
{
name: "SIGHUP",
sig: syscall.SIGHUP,
expected: "SIGHUP",
},
{
name: "SIGUSR1",
sig: syscall.SIGUSR1,
expected: "SIGUSR1",
},
{
name: "SIGUSR2",
sig: syscall.SIGUSR2,
expected: "SIGUSR2",
},
{
name: "未知信号",
sig: syscall.Signal(999),
expected: "Signal(999)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sigName(tt.sig)
if result != tt.expected {
t.Errorf("sigName(%d) = %q, want %q", tt.sig, result, tt.expected)
}
})
}
}
func TestRun(t *testing.T) {
tests := []struct {
name string

View File

@ -902,6 +902,206 @@ func TestHandleWebSocket(t *testing.T) {
}
}
// TestSetHealthChecker 测试健康检查器设置
// 注意SetHealthChecker 是公开方法,但 healthChecker 是私有字段
// 此测试验证方法可以正常调用
func TestSetHealthChecker(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://localhost:8081"},
}
p, err := NewProxy(cfg, targets)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 创建健康检查器
hcCfg := &config.HealthCheckConfig{
Interval: 10 * time.Second,
Path: "/health",
Timeout: 5 * time.Second,
}
hc := NewHealthChecker(targets, hcCfg)
// 设置健康检查器 - 验证方法存在且可调用
p.SetHealthChecker(hc)
// 测试被动健康检查:标记目标为不健康
targets[0].Healthy.Store(true)
hc.MarkUnhealthy(targets[0])
if targets[0].Healthy.Load() {
t.Error("MarkUnhealthy() target should be unhealthy after marking")
}
}
// TestGetClient 测试客户端获取
func TestGetClient(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://localhost:8081"},
{URL: "http://localhost:8082"},
}
p, err := NewProxy(cfg, targets)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 测试获取存在的客户端
client1 := p.getClient("http://localhost:8081")
if client1 == nil {
t.Error("getClient() returned nil for existing client")
}
client2 := p.getClient("http://localhost:8082")
if client2 == nil {
t.Error("getClient() returned nil for existing client")
}
// 测试获取不存在的客户端
client3 := p.getClient("http://localhost:9999")
if client3 != nil {
t.Error("getClient() should return nil for non-existent client")
}
}
// TestProxyCache 测试代理缓存功能
func TestProxyCache(t *testing.T) {
// 创建内存监听器作为后端服务器
ln := fasthttputil.NewInmemoryListener()
defer ln.Close()
requestCount := 0
go func() {
s := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
requestCount++
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetBodyString("Cached response")
ctx.Response.Header.Set("X-Request-Count", string(rune(requestCount)))
},
}
s.Serve(ln)
}()
// 等待服务器启动
time.Sleep(50 * time.Millisecond)
addr := ln.Addr().String()
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 1 * time.Second,
CacheLock: true,
StaleWhileRevalidate: 500 * time.Millisecond,
},
}
targets := []*loadbalance.Target{
{URL: "http://" + addr},
}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 验证缓存已初始化
if p.cache == nil {
t.Fatal("Proxy cache should be initialized when enabled")
}
// 测试缓存设置和获取
p.cache.Set("/api/test", []byte("test data"), map[string]string{"Content-Type": "text/plain"}, 200, 1*time.Second)
entry, found, stale := p.cache.Get("/api/test")
if !found {
t.Error("Cache should find existing entry")
}
if stale {
t.Error("Cache entry should not be stale immediately after setting")
}
if string(entry.Data) != "test data" {
t.Errorf("Cache entry data = %q, want %q", string(entry.Data), "test data")
}
// 测试缓存统计
stats := p.cache.Stats()
if stats.Entries != 1 {
t.Errorf("Cache stats.Entries = %d, want %d", stats.Entries, 1)
}
// 测试缓存清除
p.cache.Clear()
stats = p.cache.Stats()
if stats.Entries != 0 {
t.Errorf("Cache stats.Entries after Clear = %d, want %d", stats.Entries, 0)
}
}
// TestServeHTTP_WithPassiveHealthCheck 测试带有被动健康检查的请求转发
func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond, Read: 100 * time.Millisecond, Write: 100 * time.Millisecond},
}
targets := []*loadbalance.Target{
{URL: "http://127.0.0.1:59999"}, // 不存在的后端
}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 设置健康检查器
hcCfg := &config.HealthCheckConfig{
Interval: 10 * time.Second,
Path: "/health",
Timeout: 5 * time.Second,
}
hc := NewHealthChecker(targets, hcCfg)
p.SetHealthChecker(hc)
// 创建测试请求
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fasthttp.MethodGet)
ctx.Request.SetRequestURI("/api/test")
// 执行请求 - 应该会失败并触发被动健康检查
p.ServeHTTP(ctx)
// 验证返回502错误
if ctx.Response.StatusCode() != fasthttp.StatusBadGateway {
t.Errorf("ServeHTTP() status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusBadGateway)
}
// 验证目标已被标记为不健康
if targets[0].Healthy.Load() {
t.Error("Target should be marked unhealthy after failed request")
}
}
// 辅助函数
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0))

View File

@ -408,3 +408,469 @@ func containsString(s, substr string) bool {
}
return false
}
// TestGetTLSConfig 测试 TLS 配置获取
func TestGetTLSConfig(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
// 生成自签名证书
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
}
manager, err := NewTLSManager(cfg)
if err != nil {
t.Fatalf("NewTLSManager() failed: %v", err)
}
defer manager.Close()
// 验证返回非 nil 配置
tlsCfg := manager.GetTLSConfig()
if tlsCfg == nil {
t.Fatal("Expected non-nil TLS config")
}
// 验证 TLS 版本设置
if tlsCfg.MinVersion != tls.VersionTLS12 {
t.Errorf("Expected MinVersion TLS 1.2, got %v", tlsCfg.MinVersion)
}
}
// TestGetTLSConfig_WithProtocols 测试带协议配置的 TLS 配置获取
func TestGetTLSConfig_WithProtocols(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
Protocols: []string{"TLSv1.3"},
}
manager, err := NewTLSManager(cfg)
if err != nil {
t.Fatalf("NewTLSManager() failed: %v", err)
}
defer manager.Close()
tlsCfg := manager.GetTLSConfig()
if tlsCfg == nil {
t.Fatal("Expected non-nil TLS config")
}
// 验证 TLS 1.3 设置
if tlsCfg.MinVersion != tls.VersionTLS13 {
t.Errorf("Expected MinVersion TLS 1.3, got %v", tlsCfg.MinVersion)
}
if tlsCfg.MaxVersion != tls.VersionTLS13 {
t.Errorf("Expected MaxVersion TLS 1.3, got %v", tlsCfg.MaxVersion)
}
}
// TestClose 测试 TLS 管理器关闭
func TestClose(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
}
manager, err := NewTLSManager(cfg)
if err != nil {
t.Fatalf("NewTLSManager() failed: %v", err)
}
// 第一次关闭
manager.Close()
// 验证多次调用 Close 是安全的
manager.Close()
manager.Close()
}
// TestNewTLSManager_Errors 测试错误场景
func TestNewTLSManager_Errors(t *testing.T) {
tests := []struct {
name string
cfg *config.SSLConfig
wantErr bool
errMsg string
}{
{
name: "nil config",
cfg: nil,
wantErr: true,
errMsg: "ssl config is nil",
},
{
name: "缺少证书路径",
cfg: &config.SSLConfig{
Key: "key.pem",
},
wantErr: true,
errMsg: "certificate and key paths are required",
},
{
name: "缺少密钥路径",
cfg: &config.SSLConfig{
Cert: "cert.pem",
},
wantErr: true,
errMsg: "certificate and key paths are required",
},
{
name: "无效证书文件",
cfg: &config.SSLConfig{
Cert: "/nonexistent/cert.pem",
Key: "/nonexistent/key.pem",
},
wantErr: true,
errMsg: "failed to load certificate",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewTLSManager(tt.cfg)
if (err != nil) != tt.wantErr {
t.Errorf("NewTLSManager() error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil && tt.errMsg != "" {
if !containsString(err.Error(), tt.errMsg) {
t.Errorf("NewTLSManager() error = %v, want errMsg containing %v", err, tt.errMsg)
}
}
})
}
}
// TestNewTLSManager_InvalidCipher 测试无效加密套件
func TestNewTLSManager_InvalidCipher(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
Ciphers: []string{"TLS_UNKNOWN_CIPHER"},
}
_, err := NewTLSManager(cfg)
if err == nil {
t.Error("Expected error for invalid cipher suite")
}
}
// TestNewTLSManager_InsecureCipher 测试不安全加密套件
func TestNewTLSManager_InsecureCipher(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
Ciphers: []string{"TLS_ECDHE_RSA_WITH_RC4_128_SHA"},
}
_, err := NewTLSManager(cfg)
if err == nil {
t.Error("Expected error for insecure cipher suite")
}
}
// TestNewMultiTLSManager 测试多证书 TLS 管理器
func TestNewMultiTLSManager(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
configs := map[string]*config.SSLConfig{
"example.com": {
Cert: certPath,
Key: keyPath,
},
}
manager, err := NewMultiTLSManager(configs, nil)
if err != nil {
t.Fatalf("NewMultiTLSManager() failed: %v", err)
}
defer manager.Close()
// 验证配置已加载
cfg := manager.GetTLSConfigForHost("example.com")
if cfg == nil {
t.Fatal("Expected non-nil config for example.com")
}
}
// TestNewMultiTLSManager_EmptyConfigs 测试空配置
func TestNewMultiTLSManager_EmptyConfigs(t *testing.T) {
_, err := NewMultiTLSManager(map[string]*config.SSLConfig{}, nil)
if err == nil {
t.Error("Expected error for empty configs")
}
}
// TestNewMultiTLSManager_NilConfig 测试 nil 配置项
func TestNewMultiTLSManager_NilConfig(t *testing.T) {
configs := map[string]*config.SSLConfig{
"example.com": nil,
}
_, err := NewMultiTLSManager(configs, nil)
if err == nil {
t.Error("Expected error for nil config")
}
}
// TestGetCertificate 测试证书获取回调
func TestGetCertificate(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
configs := map[string]*config.SSLConfig{
"example.com": {
Cert: certPath,
Key: keyPath,
},
}
manager, err := NewMultiTLSManager(configs, nil)
if err != nil {
t.Fatalf("NewMultiTLSManager() failed: %v", err)
}
defer manager.Close()
getCert := manager.GetCertificate()
if getCert == nil {
t.Fatal("Expected non-nil GetCertificate function")
}
// 测试获取存在的证书
testHello := &tls.ClientHelloInfo{
ServerName: "example.com",
}
certResult, err := getCert(testHello)
if err != nil {
t.Errorf("GetCertificate() error = %v", err)
}
if certResult == nil {
t.Error("Expected non-nil certificate")
}
}
// TestAddCertificate 测试添加证书
func TestAddCertificate(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
// 创建带默认配置的管理器
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
}
manager, err := NewTLSManager(cfg)
if err != nil {
t.Fatalf("NewTLSManager() failed: %v", err)
}
defer manager.Close()
// 测试添加新证书
err = manager.AddCertificate("newhost.com", cfg)
if err != nil {
t.Errorf("AddCertificate() error = %v", err)
}
// 验证新证书已添加
hostCfg := manager.GetTLSConfigForHost("newhost.com")
if hostCfg == nil {
t.Error("Expected config for newhost.com")
}
}
// TestAddCertificate_Error 测试添加证书错误
func TestAddCertificate_Error(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
}
manager, err := NewTLSManager(cfg)
if err != nil {
t.Fatalf("NewTLSManager() failed: %v", err)
}
defer manager.Close()
// 测试 nil 配置
err = manager.AddCertificate("test.com", nil)
if err == nil {
t.Error("Expected error for nil config")
}
}
// TestRemoveCertificate 测试移除证书
func TestRemoveCertificate(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
configs := map[string]*config.SSLConfig{
"example.com": {
Cert: certPath,
Key: keyPath,
},
}
// 创建一个默认配置
defaultCfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
}
manager, err := NewMultiTLSManager(configs, defaultCfg)
if err != nil {
t.Fatalf("NewMultiTLSManager() failed: %v", err)
}
defer manager.Close()
// 移除证书
manager.RemoveCertificate("example.com")
// 验证证书已移除(应返回默认配置)
cfg := manager.GetTLSConfigForHost("example.com")
if cfg == nil {
t.Error("Expected default config after removal")
}
}
// TestGetOCSPStatus_NoManager 测试无 OCSP 管理器时的状态
func TestGetOCSPStatus_NoManager(t *testing.T) {
tmpDir := t.TempDir()
certPath := filepath.Join(tmpDir, "cert.pem")
keyPath := filepath.Join(tmpDir, "key.pem")
cert, key := generateTestCert(t)
if err := os.WriteFile(certPath, cert, 0644); err != nil {
t.Fatalf("Failed to write cert: %v", err)
}
if err := os.WriteFile(keyPath, key, 0600); err != nil {
t.Fatalf("Failed to write key: %v", err)
}
cfg := &config.SSLConfig{
Cert: certPath,
Key: keyPath,
OCSPStapling: false, // 禁用 OCSP
}
manager, err := NewTLSManager(cfg)
if err != nil {
t.Fatalf("NewTLSManager() failed: %v", err)
}
defer manager.Close()
status := manager.GetOCSPStatus()
if status == nil {
t.Error("Expected non-nil status map")
}
if len(status) != 0 {
t.Errorf("Expected empty status, got %d entries", len(status))
}
}

View File

@ -154,6 +154,30 @@ func TestUpstreamSelect(t *testing.T) {
}
}
func TestTargetHealthy(t *testing.T) {
target := &Target{
addr: "localhost:8001",
weight: 1,
}
// 初始状态应该是不健康(默认为 false
if target.healthy.Load() {
t.Error("新目标应该默认为不健康")
}
// 设置为健康
target.healthy.Store(true)
if !target.healthy.Load() {
t.Error("目标应该被标记为健康")
}
// 设置为不健康
target.healthy.Store(false)
if target.healthy.Load() {
t.Error("目标应该被标记为不健康")
}
}
func TestHealthChecker(t *testing.T) {
u := &Upstream{
targets: []*Target{
@ -177,6 +201,30 @@ func TestHealthChecker(t *testing.T) {
}
}
func TestHealthCheckerStartStop(t *testing.T) {
u := &Upstream{
targets: []*Target{
{addr: "localhost:99998"}, // 不存在的端口
},
}
hc := &HealthChecker{
upstream: u,
interval: 100 * time.Millisecond,
timeout: 50 * time.Millisecond,
stopCh: make(chan struct{}),
}
// 启动健康检查
go hc.Start()
// 等待几次检查执行
time.Sleep(250 * time.Millisecond)
// 停止健康检查
hc.Stop()
}
func TestUDPListener(t *testing.T) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
if err != nil {
@ -346,4 +394,357 @@ func TestNewUDPServer(t *testing.T) {
if srv2.timeout != 30*time.Second {
t.Errorf("Expected timeout 30s, got %v", srv2.timeout)
}
}
func TestListenTCP(t *testing.T) {
s := NewServer()
// 使用 :0 让系统分配端口
err := s.ListenTCP("127.0.0.1:0")
if err != nil {
t.Fatalf("ListenTCP failed: %v", err)
}
// 验证监听器已创建
s.mu.RLock()
if len(s.listeners) != 1 {
t.Errorf("Expected 1 listener, got %d", len(s.listeners))
}
s.mu.RUnlock()
// 验证 Stats
stats := s.Stats()
if stats.Listeners != 1 {
t.Errorf("Expected 1 listener in stats, got %d", stats.Listeners)
}
}
func TestServerStartStopWithTCP(t *testing.T) {
s := NewServer()
// 添加上游
targets := []TargetSpec{
{Addr: "127.0.0.1:19003", Weight: 1},
}
s.AddUpstream("tcp_test", targets, "round_robin", HealthCheckSpec{})
// 监听 TCP
err := s.ListenTCP("127.0.0.1:19004")
if err != nil {
t.Fatalf("ListenTCP failed: %v", err)
}
// 启动服务器
err = s.Start()
if err != nil {
t.Fatalf("Start failed: %v", err)
}
// 给服务器一点时间启动
time.Sleep(50 * time.Millisecond)
// 停止服务器
err = s.Stop()
if err != nil {
t.Errorf("Stop failed: %v", err)
}
}
func TestRoundRobinBalancerWithSingleTarget(t *testing.T) {
rb := newRoundRobin()
targets := []*Target{
{addr: "backend1:8080"},
}
targets[0].healthy.Store(true)
// 测试单个健康目标
for i := 0; i < 5; i++ {
target := rb.Select(targets)
if target == nil {
t.Error("Expected non-nil target")
continue
}
if target.addr != "backend1:8080" {
t.Errorf("Expected backend1:8080, got %s", target.addr)
}
}
}
func TestLeastConnBalancerWithTie(t *testing.T) {
lc := newLeastConn()
targets := []*Target{
{addr: "backend1:8080", conns: 5},
{addr: "backend2:8080", conns: 5},
{addr: "backend3:8080", conns: 5},
}
for _, t := range targets {
t.healthy.Store(true)
}
// 当连接数相同时,应该选择第一个
selected := lc.Select(targets)
if selected == nil {
t.Error("Expected non-nil target")
}
}
func TestAddUpstreamWithLeastConn(t *testing.T) {
s := NewServer()
targets := []TargetSpec{
{Addr: "localhost:8001", Weight: 1},
{Addr: "localhost:8002", Weight: 2},
}
err := s.AddUpstream("least_conn_test", targets, "least_conn", HealthCheckSpec{})
if err != nil {
t.Errorf("AddUpstream failed: %v", err)
}
up := s.upstreams["least_conn_test"]
if up == nil {
t.Fatal("Expected non-nil upstream")
}
// 验证使用的是最少连接均衡器
_, ok := up.balancer.(*leastConn)
if !ok {
t.Error("Expected leastConn balancer")
}
}
func TestAddUpstreamWithHealthCheck(t *testing.T) {
s := NewServer()
targets := []TargetSpec{
{Addr: "localhost:8001", Weight: 1},
}
hcSpec := HealthCheckSpec{
Enabled: true,
Interval: 1 * time.Second,
Timeout: 500 * time.Millisecond,
}
err := s.AddUpstream("hc_test", targets, "round_robin", hcSpec)
if err != nil {
t.Errorf("AddUpstream failed: %v", err)
}
up := s.upstreams["hc_test"]
if up == nil {
t.Fatal("Expected non-nil upstream")
}
if up.healthChk == nil {
t.Error("Expected health checker to be initialized")
}
// 停止健康检查
if up.healthChk != nil {
up.healthChk.Stop()
}
}
func TestUpstreamSelectNoHealthy(t *testing.T) {
u := &Upstream{
targets: []*Target{
{addr: "localhost:8001"},
{addr: "localhost:8002"},
},
balancer: newRoundRobin(),
}
// 不设置 healthy默认为 false
selected := u.Select()
if selected != nil {
t.Error("Expected nil for no healthy targets")
}
}
func TestCleanupExpiredSessions(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer conn.Close()
// 创建上游
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:19005"}},
balancer: newRoundRobin(),
}
upstream.targets[0].healthy.Store(true)
// 创建 UDP 服务器,设置很短的超时时间
srv := newUDPServer(conn, upstream, 1*time.Millisecond)
// 创建模拟会话
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12345")
session := &udpSession{
clientAddr: clientAddr,
lastActive: time.Now().Add(-1 * time.Hour), // 很久以前的活动
}
srv.sessions[sessionKey(clientAddr)] = session
// 执行清理
srv.cleanupExpiredSessions()
// 验证会话已被清理
srv.mu.RLock()
if len(srv.sessions) != 0 {
t.Errorf("Expected 0 sessions after cleanup, got %d", len(srv.sessions))
}
srv.mu.RUnlock()
}
func TestUDPServerStop(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
// 创建上游
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:19006"}},
balancer: newRoundRobin(),
}
// 创建 UDP 服务器
srv := newUDPServer(conn, upstream, 1*time.Second)
// 添加一个模拟会话
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12346")
session := &udpSession{
clientAddr: clientAddr,
lastActive: time.Now(),
}
srv.sessions[sessionKey(clientAddr)] = session
// 停止服务器
srv.stop()
// 验证会话已被清理
srv.mu.RLock()
if len(srv.sessions) != 0 {
t.Errorf("Expected 0 sessions after stop, got %d", len(srv.sessions))
}
srv.mu.RUnlock()
}
func TestUDPSessionOperations(t *testing.T) {
// 创建 UDP 连接
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn, _ := net.ListenUDP("udp", udpAddr)
defer conn.Close()
// 创建目标服务器(用于模拟连接)
targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:19007")
targetConn, _ := net.ListenUDP("udp", targetAddr)
defer targetConn.Close()
// 创建上游
upstream := &Upstream{
targets: []*Target{{addr: "127.0.0.1:19007"}},
balancer: newRoundRobin(),
}
upstream.targets[0].healthy.Store(true)
// 创建 UDP 服务器
srv := newUDPServer(conn, upstream, 1*time.Minute)
// 测试 getSession - 不存在的会话
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12347")
session := srv.getSession(clientAddr)
if session != nil {
t.Error("Expected nil for non-existent session")
}
// 创建模拟会话
testSession := &udpSession{
clientAddr: clientAddr,
lastActive: time.Now(),
srv: srv,
}
srv.sessions[sessionKey(clientAddr)] = testSession
// 测试 getSession - 存在的会话
session = srv.getSession(clientAddr)
if session == nil {
t.Error("Expected non-nil session")
}
// 测试 removeSession
srv.removeSession(clientAddr)
session = srv.getSession(clientAddr)
if session != nil {
t.Error("Expected nil after removeSession")
}
}
func TestUDPSessionClose(t *testing.T) {
// 创建两个 UDP 连接用于测试
udpAddr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn1, _ := net.ListenUDP("udp", udpAddr1)
udpAddr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
conn2, _ := net.ListenUDP("udp", udpAddr2)
// 创建会话
session := &udpSession{
clientAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12348},
targetConn: conn2,
lastActive: time.Now(),
}
// 测试 close - 应该能正常关闭
session.close()
// 第二次调用 close 不应该出错(使用 sync.Once
session.close()
conn1.Close()
}
func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
// 启动一个临时的 TCP 服务器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer listener.Close()
// 在后台运行服务器
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
conn.Close()
}
}()
addr := listener.Addr().String()
u := &Upstream{
targets: []*Target{
{addr: addr},
},
}
// 初始设置为不健康
u.targets[0].healthy.Store(false)
hc := &HealthChecker{
upstream: u,
interval: 1 * time.Second,
timeout: 100 * time.Millisecond,
stopCh: make(chan struct{}),
}
// 执行健康检查
hc.check()
// 目标应该被标记为健康(因为服务器在运行)
if !u.targets[0].healthy.Load() {
t.Error("Expected target to be marked healthy")
}
}