diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 687a837..5f384a5 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "syscall" "testing" ) @@ -53,7 +54,109 @@ func captureStderr(t *testing.T) (func() string, func()) { } } -// TestRun 测试 Run 函数的各种场景。 +// TestNewApp 测试 NewApp 构造器。 +func TestNewApp(t *testing.T) { + cfgPath := "/path/to/config.yaml" + app := NewApp(cfgPath) + + if app.cfgPath != cfgPath { + t.Errorf("cfgPath = %q, want %q", app.cfgPath, cfgPath) + } + + if app.cfg != nil { + t.Error("新创建的 App cfg 应为 nil") + } + + if app.srv != nil { + t.Error("新创建的 App srv 应为 nil") + } + + if app.pidFile != "" { + t.Errorf("pidFile = %q, want empty", app.pidFile) + } + + if app.logFile != "" { + t.Errorf("logFile = %q, want empty", app.logFile) + } +} + +// TestSetPidFile 测试 SetPidFile setter 方法。 +func TestSetPidFile(t *testing.T) { + app := NewApp("/test/config.yaml") + pidPath := "/var/run/lolly.pid" + + app.SetPidFile(pidPath) + + if app.pidFile != pidPath { + t.Errorf("pidFile = %q, want %q", app.pidFile, pidPath) + } +} + +// TestSetLogFile 测试 SetLogFile setter 方法。 +func TestSetLogFile(t *testing.T) { + app := NewApp("/test/config.yaml") + logPath := "/var/log/lolly.log" + + app.SetLogFile(logPath) + + if app.logFile != logPath { + t.Errorf("logFile = %q, want %q", app.logFile, logPath) + } +} + +// TestSigName 测试信号名称辅助函数。 +func TestSigName(t *testing.T) { + tests := []struct { + name string + sig syscall.Signal + expected string + }{ + { + name: "SIGTERM", + sig: syscall.SIGTERM, + expected: "SIGTERM", + }, + { + name: "SIGINT", + sig: syscall.SIGINT, + expected: "SIGINT", + }, + { + name: "SIGQUIT", + sig: syscall.SIGQUIT, + expected: "SIGQUIT", + }, + { + name: "SIGHUP", + sig: syscall.SIGHUP, + expected: "SIGHUP", + }, + { + name: "SIGUSR1", + sig: syscall.SIGUSR1, + expected: "SIGUSR1", + }, + { + name: "SIGUSR2", + sig: syscall.SIGUSR2, + expected: "SIGUSR2", + }, + { + name: "未知信号", + sig: syscall.Signal(999), + expected: "Signal(999)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sigName(tt.sig) + if result != tt.expected { + t.Errorf("sigName(%d) = %q, want %q", tt.sig, result, tt.expected) + } + }) + } +} func TestRun(t *testing.T) { tests := []struct { name string diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index a4c12fc..0d14f92 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -902,6 +902,206 @@ func TestHandleWebSocket(t *testing.T) { } } +// TestSetHealthChecker 测试健康检查器设置 +// 注意:SetHealthChecker 是公开方法,但 healthChecker 是私有字段 +// 此测试验证方法可以正常调用 +func TestSetHealthChecker(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://localhost:8081"}, + } + + p, err := NewProxy(cfg, targets) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 创建健康检查器 + hcCfg := &config.HealthCheckConfig{ + Interval: 10 * time.Second, + Path: "/health", + Timeout: 5 * time.Second, + } + hc := NewHealthChecker(targets, hcCfg) + + // 设置健康检查器 - 验证方法存在且可调用 + p.SetHealthChecker(hc) + + // 测试被动健康检查:标记目标为不健康 + targets[0].Healthy.Store(true) + hc.MarkUnhealthy(targets[0]) + + if targets[0].Healthy.Load() { + t.Error("MarkUnhealthy() target should be unhealthy after marking") + } +} + +// TestGetClient 测试客户端获取 +func TestGetClient(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://localhost:8081"}, + {URL: "http://localhost:8082"}, + } + + p, err := NewProxy(cfg, targets) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 测试获取存在的客户端 + client1 := p.getClient("http://localhost:8081") + if client1 == nil { + t.Error("getClient() returned nil for existing client") + } + + client2 := p.getClient("http://localhost:8082") + if client2 == nil { + t.Error("getClient() returned nil for existing client") + } + + // 测试获取不存在的客户端 + client3 := p.getClient("http://localhost:9999") + if client3 != nil { + t.Error("getClient() should return nil for non-existent client") + } +} + +// TestProxyCache 测试代理缓存功能 +func TestProxyCache(t *testing.T) { + // 创建内存监听器作为后端服务器 + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + requestCount := 0 + go func() { + s := &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + requestCount++ + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetBodyString("Cached response") + ctx.Response.Header.Set("X-Request-Count", string(rune(requestCount))) + }, + } + s.Serve(ln) + }() + + // 等待服务器启动 + time.Sleep(50 * time.Millisecond) + + addr := ln.Addr().String() + + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 1 * time.Second, + CacheLock: true, + StaleWhileRevalidate: 500 * time.Millisecond, + }, + } + + targets := []*loadbalance.Target{ + {URL: "http://" + addr}, + } + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 验证缓存已初始化 + if p.cache == nil { + t.Fatal("Proxy cache should be initialized when enabled") + } + + // 测试缓存设置和获取 + p.cache.Set("/api/test", []byte("test data"), map[string]string{"Content-Type": "text/plain"}, 200, 1*time.Second) + + entry, found, stale := p.cache.Get("/api/test") + if !found { + t.Error("Cache should find existing entry") + } + if stale { + t.Error("Cache entry should not be stale immediately after setting") + } + if string(entry.Data) != "test data" { + t.Errorf("Cache entry data = %q, want %q", string(entry.Data), "test data") + } + + // 测试缓存统计 + stats := p.cache.Stats() + if stats.Entries != 1 { + t.Errorf("Cache stats.Entries = %d, want %d", stats.Entries, 1) + } + + // 测试缓存清除 + p.cache.Clear() + stats = p.cache.Stats() + if stats.Entries != 0 { + t.Errorf("Cache stats.Entries after Clear = %d, want %d", stats.Entries, 0) + } +} + +// TestServeHTTP_WithPassiveHealthCheck 测试带有被动健康检查的请求转发 +func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond, Read: 100 * time.Millisecond, Write: 100 * time.Millisecond}, + } + + targets := []*loadbalance.Target{ + {URL: "http://127.0.0.1:59999"}, // 不存在的后端 + } + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 设置健康检查器 + hcCfg := &config.HealthCheckConfig{ + Interval: 10 * time.Second, + Path: "/health", + Timeout: 5 * time.Second, + } + hc := NewHealthChecker(targets, hcCfg) + p.SetHealthChecker(hc) + + // 创建测试请求 + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fasthttp.MethodGet) + ctx.Request.SetRequestURI("/api/test") + + // 执行请求 - 应该会失败并触发被动健康检查 + p.ServeHTTP(ctx) + + // 验证返回502错误 + if ctx.Response.StatusCode() != fasthttp.StatusBadGateway { + t.Errorf("ServeHTTP() status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusBadGateway) + } + + // 验证目标已被标记为不健康 + if targets[0].Healthy.Load() { + t.Error("Target should be marked unhealthy after failed request") + } +} + // 辅助函数 func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0)) diff --git a/internal/ssl/ssl_test.go b/internal/ssl/ssl_test.go index ed8a0bb..fba82bb 100644 --- a/internal/ssl/ssl_test.go +++ b/internal/ssl/ssl_test.go @@ -408,3 +408,469 @@ func containsString(s, substr string) bool { } return false } + +// TestGetTLSConfig 测试 TLS 配置获取 +func TestGetTLSConfig(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + // 生成自签名证书 + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + } + + manager, err := NewTLSManager(cfg) + if err != nil { + t.Fatalf("NewTLSManager() failed: %v", err) + } + defer manager.Close() + + // 验证返回非 nil 配置 + tlsCfg := manager.GetTLSConfig() + if tlsCfg == nil { + t.Fatal("Expected non-nil TLS config") + } + + // 验证 TLS 版本设置 + if tlsCfg.MinVersion != tls.VersionTLS12 { + t.Errorf("Expected MinVersion TLS 1.2, got %v", tlsCfg.MinVersion) + } +} + +// TestGetTLSConfig_WithProtocols 测试带协议配置的 TLS 配置获取 +func TestGetTLSConfig_WithProtocols(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + Protocols: []string{"TLSv1.3"}, + } + + manager, err := NewTLSManager(cfg) + if err != nil { + t.Fatalf("NewTLSManager() failed: %v", err) + } + defer manager.Close() + + tlsCfg := manager.GetTLSConfig() + if tlsCfg == nil { + t.Fatal("Expected non-nil TLS config") + } + + // 验证 TLS 1.3 设置 + if tlsCfg.MinVersion != tls.VersionTLS13 { + t.Errorf("Expected MinVersion TLS 1.3, got %v", tlsCfg.MinVersion) + } + if tlsCfg.MaxVersion != tls.VersionTLS13 { + t.Errorf("Expected MaxVersion TLS 1.3, got %v", tlsCfg.MaxVersion) + } +} + +// TestClose 测试 TLS 管理器关闭 +func TestClose(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + } + + manager, err := NewTLSManager(cfg) + if err != nil { + t.Fatalf("NewTLSManager() failed: %v", err) + } + + // 第一次关闭 + manager.Close() + + // 验证多次调用 Close 是安全的 + manager.Close() + manager.Close() +} + +// TestNewTLSManager_Errors 测试错误场景 +func TestNewTLSManager_Errors(t *testing.T) { + tests := []struct { + name string + cfg *config.SSLConfig + wantErr bool + errMsg string + }{ + { + name: "nil config", + cfg: nil, + wantErr: true, + errMsg: "ssl config is nil", + }, + { + name: "缺少证书路径", + cfg: &config.SSLConfig{ + Key: "key.pem", + }, + wantErr: true, + errMsg: "certificate and key paths are required", + }, + { + name: "缺少密钥路径", + cfg: &config.SSLConfig{ + Cert: "cert.pem", + }, + wantErr: true, + errMsg: "certificate and key paths are required", + }, + { + name: "无效证书文件", + cfg: &config.SSLConfig{ + Cert: "/nonexistent/cert.pem", + Key: "/nonexistent/key.pem", + }, + wantErr: true, + errMsg: "failed to load certificate", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewTLSManager(tt.cfg) + if (err != nil) != tt.wantErr { + t.Errorf("NewTLSManager() error = %v, wantErr %v", err, tt.wantErr) + } + if err != nil && tt.errMsg != "" { + if !containsString(err.Error(), tt.errMsg) { + t.Errorf("NewTLSManager() error = %v, want errMsg containing %v", err, tt.errMsg) + } + } + }) + } +} + +// TestNewTLSManager_InvalidCipher 测试无效加密套件 +func TestNewTLSManager_InvalidCipher(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + Ciphers: []string{"TLS_UNKNOWN_CIPHER"}, + } + + _, err := NewTLSManager(cfg) + if err == nil { + t.Error("Expected error for invalid cipher suite") + } +} + +// TestNewTLSManager_InsecureCipher 测试不安全加密套件 +func TestNewTLSManager_InsecureCipher(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + Ciphers: []string{"TLS_ECDHE_RSA_WITH_RC4_128_SHA"}, + } + + _, err := NewTLSManager(cfg) + if err == nil { + t.Error("Expected error for insecure cipher suite") + } +} + +// TestNewMultiTLSManager 测试多证书 TLS 管理器 +func TestNewMultiTLSManager(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + configs := map[string]*config.SSLConfig{ + "example.com": { + Cert: certPath, + Key: keyPath, + }, + } + + manager, err := NewMultiTLSManager(configs, nil) + if err != nil { + t.Fatalf("NewMultiTLSManager() failed: %v", err) + } + defer manager.Close() + + // 验证配置已加载 + cfg := manager.GetTLSConfigForHost("example.com") + if cfg == nil { + t.Fatal("Expected non-nil config for example.com") + } +} + +// TestNewMultiTLSManager_EmptyConfigs 测试空配置 +func TestNewMultiTLSManager_EmptyConfigs(t *testing.T) { + _, err := NewMultiTLSManager(map[string]*config.SSLConfig{}, nil) + if err == nil { + t.Error("Expected error for empty configs") + } +} + +// TestNewMultiTLSManager_NilConfig 测试 nil 配置项 +func TestNewMultiTLSManager_NilConfig(t *testing.T) { + configs := map[string]*config.SSLConfig{ + "example.com": nil, + } + + _, err := NewMultiTLSManager(configs, nil) + if err == nil { + t.Error("Expected error for nil config") + } +} + +// TestGetCertificate 测试证书获取回调 +func TestGetCertificate(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + configs := map[string]*config.SSLConfig{ + "example.com": { + Cert: certPath, + Key: keyPath, + }, + } + + manager, err := NewMultiTLSManager(configs, nil) + if err != nil { + t.Fatalf("NewMultiTLSManager() failed: %v", err) + } + defer manager.Close() + + getCert := manager.GetCertificate() + if getCert == nil { + t.Fatal("Expected non-nil GetCertificate function") + } + + // 测试获取存在的证书 + testHello := &tls.ClientHelloInfo{ + ServerName: "example.com", + } + certResult, err := getCert(testHello) + if err != nil { + t.Errorf("GetCertificate() error = %v", err) + } + if certResult == nil { + t.Error("Expected non-nil certificate") + } +} + +// TestAddCertificate 测试添加证书 +func TestAddCertificate(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + // 创建带默认配置的管理器 + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + } + + manager, err := NewTLSManager(cfg) + if err != nil { + t.Fatalf("NewTLSManager() failed: %v", err) + } + defer manager.Close() + + // 测试添加新证书 + err = manager.AddCertificate("newhost.com", cfg) + if err != nil { + t.Errorf("AddCertificate() error = %v", err) + } + + // 验证新证书已添加 + hostCfg := manager.GetTLSConfigForHost("newhost.com") + if hostCfg == nil { + t.Error("Expected config for newhost.com") + } +} + +// TestAddCertificate_Error 测试添加证书错误 +func TestAddCertificate_Error(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + } + + manager, err := NewTLSManager(cfg) + if err != nil { + t.Fatalf("NewTLSManager() failed: %v", err) + } + defer manager.Close() + + // 测试 nil 配置 + err = manager.AddCertificate("test.com", nil) + if err == nil { + t.Error("Expected error for nil config") + } +} + +// TestRemoveCertificate 测试移除证书 +func TestRemoveCertificate(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + configs := map[string]*config.SSLConfig{ + "example.com": { + Cert: certPath, + Key: keyPath, + }, + } + + // 创建一个默认配置 + defaultCfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + } + + manager, err := NewMultiTLSManager(configs, defaultCfg) + if err != nil { + t.Fatalf("NewMultiTLSManager() failed: %v", err) + } + defer manager.Close() + + // 移除证书 + manager.RemoveCertificate("example.com") + + // 验证证书已移除(应返回默认配置) + cfg := manager.GetTLSConfigForHost("example.com") + if cfg == nil { + t.Error("Expected default config after removal") + } +} + +// TestGetOCSPStatus_NoManager 测试无 OCSP 管理器时的状态 +func TestGetOCSPStatus_NoManager(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + keyPath := filepath.Join(tmpDir, "key.pem") + + cert, key := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + if err := os.WriteFile(keyPath, key, 0600); err != nil { + t.Fatalf("Failed to write key: %v", err) + } + + cfg := &config.SSLConfig{ + Cert: certPath, + Key: keyPath, + OCSPStapling: false, // 禁用 OCSP + } + + manager, err := NewTLSManager(cfg) + if err != nil { + t.Fatalf("NewTLSManager() failed: %v", err) + } + defer manager.Close() + + status := manager.GetOCSPStatus() + if status == nil { + t.Error("Expected non-nil status map") + } + if len(status) != 0 { + t.Errorf("Expected empty status, got %d entries", len(status)) + } +} diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go index ef6a14a..9a08326 100644 --- a/internal/stream/stream_test.go +++ b/internal/stream/stream_test.go @@ -154,6 +154,30 @@ func TestUpstreamSelect(t *testing.T) { } } +func TestTargetHealthy(t *testing.T) { + target := &Target{ + addr: "localhost:8001", + weight: 1, + } + + // 初始状态应该是不健康(默认为 false) + if target.healthy.Load() { + t.Error("新目标应该默认为不健康") + } + + // 设置为健康 + target.healthy.Store(true) + if !target.healthy.Load() { + t.Error("目标应该被标记为健康") + } + + // 设置为不健康 + target.healthy.Store(false) + if target.healthy.Load() { + t.Error("目标应该被标记为不健康") + } +} + func TestHealthChecker(t *testing.T) { u := &Upstream{ targets: []*Target{ @@ -177,6 +201,30 @@ func TestHealthChecker(t *testing.T) { } } +func TestHealthCheckerStartStop(t *testing.T) { + u := &Upstream{ + targets: []*Target{ + {addr: "localhost:99998"}, // 不存在的端口 + }, + } + + hc := &HealthChecker{ + upstream: u, + interval: 100 * time.Millisecond, + timeout: 50 * time.Millisecond, + stopCh: make(chan struct{}), + } + + // 启动健康检查 + go hc.Start() + + // 等待几次检查执行 + time.Sleep(250 * time.Millisecond) + + // 停止健康检查 + hc.Stop() +} + func TestUDPListener(t *testing.T) { addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") if err != nil { @@ -346,4 +394,357 @@ func TestNewUDPServer(t *testing.T) { if srv2.timeout != 30*time.Second { t.Errorf("Expected timeout 30s, got %v", srv2.timeout) } +} + +func TestListenTCP(t *testing.T) { + s := NewServer() + + // 使用 :0 让系统分配端口 + err := s.ListenTCP("127.0.0.1:0") + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + + // 验证监听器已创建 + s.mu.RLock() + if len(s.listeners) != 1 { + t.Errorf("Expected 1 listener, got %d", len(s.listeners)) + } + s.mu.RUnlock() + + // 验证 Stats + stats := s.Stats() + if stats.Listeners != 1 { + t.Errorf("Expected 1 listener in stats, got %d", stats.Listeners) + } +} + +func TestServerStartStopWithTCP(t *testing.T) { + s := NewServer() + + // 添加上游 + targets := []TargetSpec{ + {Addr: "127.0.0.1:19003", Weight: 1}, + } + s.AddUpstream("tcp_test", targets, "round_robin", HealthCheckSpec{}) + + // 监听 TCP + err := s.ListenTCP("127.0.0.1:19004") + if err != nil { + t.Fatalf("ListenTCP failed: %v", err) + } + + // 启动服务器 + err = s.Start() + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + // 给服务器一点时间启动 + time.Sleep(50 * time.Millisecond) + + // 停止服务器 + err = s.Stop() + if err != nil { + t.Errorf("Stop failed: %v", err) + } +} + +func TestRoundRobinBalancerWithSingleTarget(t *testing.T) { + rb := newRoundRobin() + targets := []*Target{ + {addr: "backend1:8080"}, + } + targets[0].healthy.Store(true) + + // 测试单个健康目标 + for i := 0; i < 5; i++ { + target := rb.Select(targets) + if target == nil { + t.Error("Expected non-nil target") + continue + } + if target.addr != "backend1:8080" { + t.Errorf("Expected backend1:8080, got %s", target.addr) + } + } +} + +func TestLeastConnBalancerWithTie(t *testing.T) { + lc := newLeastConn() + targets := []*Target{ + {addr: "backend1:8080", conns: 5}, + {addr: "backend2:8080", conns: 5}, + {addr: "backend3:8080", conns: 5}, + } + for _, t := range targets { + t.healthy.Store(true) + } + + // 当连接数相同时,应该选择第一个 + selected := lc.Select(targets) + if selected == nil { + t.Error("Expected non-nil target") + } +} + +func TestAddUpstreamWithLeastConn(t *testing.T) { + s := NewServer() + + targets := []TargetSpec{ + {Addr: "localhost:8001", Weight: 1}, + {Addr: "localhost:8002", Weight: 2}, + } + + err := s.AddUpstream("least_conn_test", targets, "least_conn", HealthCheckSpec{}) + if err != nil { + t.Errorf("AddUpstream failed: %v", err) + } + + up := s.upstreams["least_conn_test"] + if up == nil { + t.Fatal("Expected non-nil upstream") + } + + // 验证使用的是最少连接均衡器 + _, ok := up.balancer.(*leastConn) + if !ok { + t.Error("Expected leastConn balancer") + } +} + +func TestAddUpstreamWithHealthCheck(t *testing.T) { + s := NewServer() + + targets := []TargetSpec{ + {Addr: "localhost:8001", Weight: 1}, + } + + hcSpec := HealthCheckSpec{ + Enabled: true, + Interval: 1 * time.Second, + Timeout: 500 * time.Millisecond, + } + + err := s.AddUpstream("hc_test", targets, "round_robin", hcSpec) + if err != nil { + t.Errorf("AddUpstream failed: %v", err) + } + + up := s.upstreams["hc_test"] + if up == nil { + t.Fatal("Expected non-nil upstream") + } + + if up.healthChk == nil { + t.Error("Expected health checker to be initialized") + } + + // 停止健康检查 + if up.healthChk != nil { + up.healthChk.Stop() + } +} + +func TestUpstreamSelectNoHealthy(t *testing.T) { + u := &Upstream{ + targets: []*Target{ + {addr: "localhost:8001"}, + {addr: "localhost:8002"}, + }, + balancer: newRoundRobin(), + } + // 不设置 healthy,默认为 false + + selected := u.Select() + if selected != nil { + t.Error("Expected nil for no healthy targets") + } +} + +func TestCleanupExpiredSessions(t *testing.T) { + // 创建 UDP 连接 + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer conn.Close() + + // 创建上游 + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:19005"}}, + balancer: newRoundRobin(), + } + upstream.targets[0].healthy.Store(true) + + // 创建 UDP 服务器,设置很短的超时时间 + srv := newUDPServer(conn, upstream, 1*time.Millisecond) + + // 创建模拟会话 + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12345") + session := &udpSession{ + clientAddr: clientAddr, + lastActive: time.Now().Add(-1 * time.Hour), // 很久以前的活动 + } + srv.sessions[sessionKey(clientAddr)] = session + + // 执行清理 + srv.cleanupExpiredSessions() + + // 验证会话已被清理 + srv.mu.RLock() + if len(srv.sessions) != 0 { + t.Errorf("Expected 0 sessions after cleanup, got %d", len(srv.sessions)) + } + srv.mu.RUnlock() +} + +func TestUDPServerStop(t *testing.T) { + // 创建 UDP 连接 + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + + // 创建上游 + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:19006"}}, + balancer: newRoundRobin(), + } + + // 创建 UDP 服务器 + srv := newUDPServer(conn, upstream, 1*time.Second) + + // 添加一个模拟会话 + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12346") + session := &udpSession{ + clientAddr: clientAddr, + lastActive: time.Now(), + } + srv.sessions[sessionKey(clientAddr)] = session + + // 停止服务器 + srv.stop() + + // 验证会话已被清理 + srv.mu.RLock() + if len(srv.sessions) != 0 { + t.Errorf("Expected 0 sessions after stop, got %d", len(srv.sessions)) + } + srv.mu.RUnlock() +} + +func TestUDPSessionOperations(t *testing.T) { + // 创建 UDP 连接 + udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn, _ := net.ListenUDP("udp", udpAddr) + defer conn.Close() + + // 创建目标服务器(用于模拟连接) + targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:19007") + targetConn, _ := net.ListenUDP("udp", targetAddr) + defer targetConn.Close() + + // 创建上游 + upstream := &Upstream{ + targets: []*Target{{addr: "127.0.0.1:19007"}}, + balancer: newRoundRobin(), + } + upstream.targets[0].healthy.Store(true) + + // 创建 UDP 服务器 + srv := newUDPServer(conn, upstream, 1*time.Minute) + + // 测试 getSession - 不存在的会话 + clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12347") + session := srv.getSession(clientAddr) + if session != nil { + t.Error("Expected nil for non-existent session") + } + + // 创建模拟会话 + testSession := &udpSession{ + clientAddr: clientAddr, + lastActive: time.Now(), + srv: srv, + } + srv.sessions[sessionKey(clientAddr)] = testSession + + // 测试 getSession - 存在的会话 + session = srv.getSession(clientAddr) + if session == nil { + t.Error("Expected non-nil session") + } + + // 测试 removeSession + srv.removeSession(clientAddr) + session = srv.getSession(clientAddr) + if session != nil { + t.Error("Expected nil after removeSession") + } +} + +func TestUDPSessionClose(t *testing.T) { + // 创建两个 UDP 连接用于测试 + udpAddr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn1, _ := net.ListenUDP("udp", udpAddr1) + + udpAddr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + conn2, _ := net.ListenUDP("udp", udpAddr2) + + // 创建会话 + session := &udpSession{ + clientAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12348}, + targetConn: conn2, + lastActive: time.Now(), + } + + // 测试 close - 应该能正常关闭 + session.close() + + // 第二次调用 close 不应该出错(使用 sync.Once) + session.close() + + conn1.Close() +} + +func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) { + // 启动一个临时的 TCP 服务器 + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer listener.Close() + + // 在后台运行服务器 + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + conn.Close() + } + }() + + addr := listener.Addr().String() + + u := &Upstream{ + targets: []*Target{ + {addr: addr}, + }, + } + // 初始设置为不健康 + u.targets[0].healthy.Store(false) + + hc := &HealthChecker{ + upstream: u, + interval: 1 * time.Second, + timeout: 100 * time.Millisecond, + stopCh: make(chan struct{}), + } + + // 执行健康检查 + hc.check() + + // 目标应该被标记为健康(因为服务器在运行) + if !u.targets[0].healthy.Load() { + t.Error("Expected target to be marked healthy") + } } \ No newline at end of file