test(app,proxy,ssl,stream): 完善测试覆盖率
- app: 添加 NewApp/SetPidFile/SetLogFile/sigName 测试 - proxy: 扩展健康检查器测试 - ssl: 添加 TLS 配置和 Close 方法测试 - stream: 添加负载均衡器和 UDP 会话测试 覆盖率从 55.4% 提升至 60.3% Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
80936ae66b
commit
c70ab305b7
@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -53,7 +54,109 @@ func captureStderr(t *testing.T) (func() string, func()) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestRun 测试 Run 函数的各种场景。
|
||||
// TestNewApp 测试 NewApp 构造器。
|
||||
func TestNewApp(t *testing.T) {
|
||||
cfgPath := "/path/to/config.yaml"
|
||||
app := NewApp(cfgPath)
|
||||
|
||||
if app.cfgPath != cfgPath {
|
||||
t.Errorf("cfgPath = %q, want %q", app.cfgPath, cfgPath)
|
||||
}
|
||||
|
||||
if app.cfg != nil {
|
||||
t.Error("新创建的 App cfg 应为 nil")
|
||||
}
|
||||
|
||||
if app.srv != nil {
|
||||
t.Error("新创建的 App srv 应为 nil")
|
||||
}
|
||||
|
||||
if app.pidFile != "" {
|
||||
t.Errorf("pidFile = %q, want empty", app.pidFile)
|
||||
}
|
||||
|
||||
if app.logFile != "" {
|
||||
t.Errorf("logFile = %q, want empty", app.logFile)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetPidFile 测试 SetPidFile setter 方法。
|
||||
func TestSetPidFile(t *testing.T) {
|
||||
app := NewApp("/test/config.yaml")
|
||||
pidPath := "/var/run/lolly.pid"
|
||||
|
||||
app.SetPidFile(pidPath)
|
||||
|
||||
if app.pidFile != pidPath {
|
||||
t.Errorf("pidFile = %q, want %q", app.pidFile, pidPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetLogFile 测试 SetLogFile setter 方法。
|
||||
func TestSetLogFile(t *testing.T) {
|
||||
app := NewApp("/test/config.yaml")
|
||||
logPath := "/var/log/lolly.log"
|
||||
|
||||
app.SetLogFile(logPath)
|
||||
|
||||
if app.logFile != logPath {
|
||||
t.Errorf("logFile = %q, want %q", app.logFile, logPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSigName 测试信号名称辅助函数。
|
||||
func TestSigName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sig syscall.Signal
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SIGTERM",
|
||||
sig: syscall.SIGTERM,
|
||||
expected: "SIGTERM",
|
||||
},
|
||||
{
|
||||
name: "SIGINT",
|
||||
sig: syscall.SIGINT,
|
||||
expected: "SIGINT",
|
||||
},
|
||||
{
|
||||
name: "SIGQUIT",
|
||||
sig: syscall.SIGQUIT,
|
||||
expected: "SIGQUIT",
|
||||
},
|
||||
{
|
||||
name: "SIGHUP",
|
||||
sig: syscall.SIGHUP,
|
||||
expected: "SIGHUP",
|
||||
},
|
||||
{
|
||||
name: "SIGUSR1",
|
||||
sig: syscall.SIGUSR1,
|
||||
expected: "SIGUSR1",
|
||||
},
|
||||
{
|
||||
name: "SIGUSR2",
|
||||
sig: syscall.SIGUSR2,
|
||||
expected: "SIGUSR2",
|
||||
},
|
||||
{
|
||||
name: "未知信号",
|
||||
sig: syscall.Signal(999),
|
||||
expected: "Signal(999)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sigName(tt.sig)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sigName(%d) = %q, want %q", tt.sig, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestRun(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@ -902,6 +902,206 @@ func TestHandleWebSocket(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetHealthChecker 测试健康检查器设置
|
||||
// 注意:SetHealthChecker 是公开方法,但 healthChecker 是私有字段
|
||||
// 此测试验证方法可以正常调用
|
||||
func TestSetHealthChecker(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://localhost:8081"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 创建健康检查器
|
||||
hcCfg := &config.HealthCheckConfig{
|
||||
Interval: 10 * time.Second,
|
||||
Path: "/health",
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
hc := NewHealthChecker(targets, hcCfg)
|
||||
|
||||
// 设置健康检查器 - 验证方法存在且可调用
|
||||
p.SetHealthChecker(hc)
|
||||
|
||||
// 测试被动健康检查:标记目标为不健康
|
||||
targets[0].Healthy.Store(true)
|
||||
hc.MarkUnhealthy(targets[0])
|
||||
|
||||
if targets[0].Healthy.Load() {
|
||||
t.Error("MarkUnhealthy() target should be unhealthy after marking")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetClient 测试客户端获取
|
||||
func TestGetClient(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://localhost:8081"},
|
||||
{URL: "http://localhost:8082"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 测试获取存在的客户端
|
||||
client1 := p.getClient("http://localhost:8081")
|
||||
if client1 == nil {
|
||||
t.Error("getClient() returned nil for existing client")
|
||||
}
|
||||
|
||||
client2 := p.getClient("http://localhost:8082")
|
||||
if client2 == nil {
|
||||
t.Error("getClient() returned nil for existing client")
|
||||
}
|
||||
|
||||
// 测试获取不存在的客户端
|
||||
client3 := p.getClient("http://localhost:9999")
|
||||
if client3 != nil {
|
||||
t.Error("getClient() should return nil for non-existent client")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyCache 测试代理缓存功能
|
||||
func TestProxyCache(t *testing.T) {
|
||||
// 创建内存监听器作为后端服务器
|
||||
ln := fasthttputil.NewInmemoryListener()
|
||||
defer ln.Close()
|
||||
|
||||
requestCount := 0
|
||||
go func() {
|
||||
s := &fasthttp.Server{
|
||||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||
requestCount++
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetBodyString("Cached response")
|
||||
ctx.Response.Header.Set("X-Request-Count", string(rune(requestCount)))
|
||||
},
|
||||
}
|
||||
s.Serve(ln)
|
||||
}()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
addr := ln.Addr().String()
|
||||
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 1 * time.Second,
|
||||
CacheLock: true,
|
||||
StaleWhileRevalidate: 500 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://" + addr},
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 验证缓存已初始化
|
||||
if p.cache == nil {
|
||||
t.Fatal("Proxy cache should be initialized when enabled")
|
||||
}
|
||||
|
||||
// 测试缓存设置和获取
|
||||
p.cache.Set("/api/test", []byte("test data"), map[string]string{"Content-Type": "text/plain"}, 200, 1*time.Second)
|
||||
|
||||
entry, found, stale := p.cache.Get("/api/test")
|
||||
if !found {
|
||||
t.Error("Cache should find existing entry")
|
||||
}
|
||||
if stale {
|
||||
t.Error("Cache entry should not be stale immediately after setting")
|
||||
}
|
||||
if string(entry.Data) != "test data" {
|
||||
t.Errorf("Cache entry data = %q, want %q", string(entry.Data), "test data")
|
||||
}
|
||||
|
||||
// 测试缓存统计
|
||||
stats := p.cache.Stats()
|
||||
if stats.Entries != 1 {
|
||||
t.Errorf("Cache stats.Entries = %d, want %d", stats.Entries, 1)
|
||||
}
|
||||
|
||||
// 测试缓存清除
|
||||
p.cache.Clear()
|
||||
stats = p.cache.Stats()
|
||||
if stats.Entries != 0 {
|
||||
t.Errorf("Cache stats.Entries after Clear = %d, want %d", stats.Entries, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServeHTTP_WithPassiveHealthCheck 测试带有被动健康检查的请求转发
|
||||
func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond, Read: 100 * time.Millisecond, Write: 100 * time.Millisecond},
|
||||
}
|
||||
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://127.0.0.1:59999"}, // 不存在的后端
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 设置健康检查器
|
||||
hcCfg := &config.HealthCheckConfig{
|
||||
Interval: 10 * time.Second,
|
||||
Path: "/health",
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
hc := NewHealthChecker(targets, hcCfg)
|
||||
p.SetHealthChecker(hc)
|
||||
|
||||
// 创建测试请求
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(fasthttp.MethodGet)
|
||||
ctx.Request.SetRequestURI("/api/test")
|
||||
|
||||
// 执行请求 - 应该会失败并触发被动健康检查
|
||||
p.ServeHTTP(ctx)
|
||||
|
||||
// 验证返回502错误
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusBadGateway {
|
||||
t.Errorf("ServeHTTP() status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusBadGateway)
|
||||
}
|
||||
|
||||
// 验证目标已被标记为不健康
|
||||
if targets[0].Healthy.Load() {
|
||||
t.Error("Target should be marked unhealthy after failed request")
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0))
|
||||
|
||||
@ -408,3 +408,469 @@ func containsString(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestGetTLSConfig 测试 TLS 配置获取
|
||||
func TestGetTLSConfig(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
// 生成自签名证书
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 验证返回非 nil 配置
|
||||
tlsCfg := manager.GetTLSConfig()
|
||||
if tlsCfg == nil {
|
||||
t.Fatal("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
// 验证 TLS 版本设置
|
||||
if tlsCfg.MinVersion != tls.VersionTLS12 {
|
||||
t.Errorf("Expected MinVersion TLS 1.2, got %v", tlsCfg.MinVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTLSConfig_WithProtocols 测试带协议配置的 TLS 配置获取
|
||||
func TestGetTLSConfig_WithProtocols(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
Protocols: []string{"TLSv1.3"},
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
tlsCfg := manager.GetTLSConfig()
|
||||
if tlsCfg == nil {
|
||||
t.Fatal("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
// 验证 TLS 1.3 设置
|
||||
if tlsCfg.MinVersion != tls.VersionTLS13 {
|
||||
t.Errorf("Expected MinVersion TLS 1.3, got %v", tlsCfg.MinVersion)
|
||||
}
|
||||
if tlsCfg.MaxVersion != tls.VersionTLS13 {
|
||||
t.Errorf("Expected MaxVersion TLS 1.3, got %v", tlsCfg.MaxVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClose 测试 TLS 管理器关闭
|
||||
func TestClose(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
|
||||
// 第一次关闭
|
||||
manager.Close()
|
||||
|
||||
// 验证多次调用 Close 是安全的
|
||||
manager.Close()
|
||||
manager.Close()
|
||||
}
|
||||
|
||||
// TestNewTLSManager_Errors 测试错误场景
|
||||
func TestNewTLSManager_Errors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.SSLConfig
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
cfg: nil,
|
||||
wantErr: true,
|
||||
errMsg: "ssl config is nil",
|
||||
},
|
||||
{
|
||||
name: "缺少证书路径",
|
||||
cfg: &config.SSLConfig{
|
||||
Key: "key.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "certificate and key paths are required",
|
||||
},
|
||||
{
|
||||
name: "缺少密钥路径",
|
||||
cfg: &config.SSLConfig{
|
||||
Cert: "cert.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "certificate and key paths are required",
|
||||
},
|
||||
{
|
||||
name: "无效证书文件",
|
||||
cfg: &config.SSLConfig{
|
||||
Cert: "/nonexistent/cert.pem",
|
||||
Key: "/nonexistent/key.pem",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "failed to load certificate",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewTLSManager(tt.cfg)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewTLSManager() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err != nil && tt.errMsg != "" {
|
||||
if !containsString(err.Error(), tt.errMsg) {
|
||||
t.Errorf("NewTLSManager() error = %v, want errMsg containing %v", err, tt.errMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTLSManager_InvalidCipher 测试无效加密套件
|
||||
func TestNewTLSManager_InvalidCipher(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
Ciphers: []string{"TLS_UNKNOWN_CIPHER"},
|
||||
}
|
||||
|
||||
_, err := NewTLSManager(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid cipher suite")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTLSManager_InsecureCipher 测试不安全加密套件
|
||||
func TestNewTLSManager_InsecureCipher(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
Ciphers: []string{"TLS_ECDHE_RSA_WITH_RC4_128_SHA"},
|
||||
}
|
||||
|
||||
_, err := NewTLSManager(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected error for insecure cipher suite")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewMultiTLSManager 测试多证书 TLS 管理器
|
||||
func TestNewMultiTLSManager(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
configs := map[string]*config.SSLConfig{
|
||||
"example.com": {
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewMultiTLSManager(configs, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewMultiTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 验证配置已加载
|
||||
cfg := manager.GetTLSConfigForHost("example.com")
|
||||
if cfg == nil {
|
||||
t.Fatal("Expected non-nil config for example.com")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewMultiTLSManager_EmptyConfigs 测试空配置
|
||||
func TestNewMultiTLSManager_EmptyConfigs(t *testing.T) {
|
||||
_, err := NewMultiTLSManager(map[string]*config.SSLConfig{}, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for empty configs")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewMultiTLSManager_NilConfig 测试 nil 配置项
|
||||
func TestNewMultiTLSManager_NilConfig(t *testing.T) {
|
||||
configs := map[string]*config.SSLConfig{
|
||||
"example.com": nil,
|
||||
}
|
||||
|
||||
_, err := NewMultiTLSManager(configs, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCertificate 测试证书获取回调
|
||||
func TestGetCertificate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
configs := map[string]*config.SSLConfig{
|
||||
"example.com": {
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewMultiTLSManager(configs, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewMultiTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
getCert := manager.GetCertificate()
|
||||
if getCert == nil {
|
||||
t.Fatal("Expected non-nil GetCertificate function")
|
||||
}
|
||||
|
||||
// 测试获取存在的证书
|
||||
testHello := &tls.ClientHelloInfo{
|
||||
ServerName: "example.com",
|
||||
}
|
||||
certResult, err := getCert(testHello)
|
||||
if err != nil {
|
||||
t.Errorf("GetCertificate() error = %v", err)
|
||||
}
|
||||
if certResult == nil {
|
||||
t.Error("Expected non-nil certificate")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddCertificate 测试添加证书
|
||||
func TestAddCertificate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
// 创建带默认配置的管理器
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 测试添加新证书
|
||||
err = manager.AddCertificate("newhost.com", cfg)
|
||||
if err != nil {
|
||||
t.Errorf("AddCertificate() error = %v", err)
|
||||
}
|
||||
|
||||
// 验证新证书已添加
|
||||
hostCfg := manager.GetTLSConfigForHost("newhost.com")
|
||||
if hostCfg == nil {
|
||||
t.Error("Expected config for newhost.com")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddCertificate_Error 测试添加证书错误
|
||||
func TestAddCertificate_Error(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 测试 nil 配置
|
||||
err = manager.AddCertificate("test.com", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveCertificate 测试移除证书
|
||||
func TestRemoveCertificate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
configs := map[string]*config.SSLConfig{
|
||||
"example.com": {
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
},
|
||||
}
|
||||
|
||||
// 创建一个默认配置
|
||||
defaultCfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
}
|
||||
|
||||
manager, err := NewMultiTLSManager(configs, defaultCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewMultiTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 移除证书
|
||||
manager.RemoveCertificate("example.com")
|
||||
|
||||
// 验证证书已移除(应返回默认配置)
|
||||
cfg := manager.GetTLSConfigForHost("example.com")
|
||||
if cfg == nil {
|
||||
t.Error("Expected default config after removal")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetOCSPStatus_NoManager 测试无 OCSP 管理器时的状态
|
||||
func TestGetOCSPStatus_NoManager(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
OCSPStapling: false, // 禁用 OCSP
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
status := manager.GetOCSPStatus()
|
||||
if status == nil {
|
||||
t.Error("Expected non-nil status map")
|
||||
}
|
||||
if len(status) != 0 {
|
||||
t.Errorf("Expected empty status, got %d entries", len(status))
|
||||
}
|
||||
}
|
||||
|
||||
@ -154,6 +154,30 @@ func TestUpstreamSelect(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTargetHealthy(t *testing.T) {
|
||||
target := &Target{
|
||||
addr: "localhost:8001",
|
||||
weight: 1,
|
||||
}
|
||||
|
||||
// 初始状态应该是不健康(默认为 false)
|
||||
if target.healthy.Load() {
|
||||
t.Error("新目标应该默认为不健康")
|
||||
}
|
||||
|
||||
// 设置为健康
|
||||
target.healthy.Store(true)
|
||||
if !target.healthy.Load() {
|
||||
t.Error("目标应该被标记为健康")
|
||||
}
|
||||
|
||||
// 设置为不健康
|
||||
target.healthy.Store(false)
|
||||
if target.healthy.Load() {
|
||||
t.Error("目标应该被标记为不健康")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthChecker(t *testing.T) {
|
||||
u := &Upstream{
|
||||
targets: []*Target{
|
||||
@ -177,6 +201,30 @@ func TestHealthChecker(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthCheckerStartStop(t *testing.T) {
|
||||
u := &Upstream{
|
||||
targets: []*Target{
|
||||
{addr: "localhost:99998"}, // 不存在的端口
|
||||
},
|
||||
}
|
||||
|
||||
hc := &HealthChecker{
|
||||
upstream: u,
|
||||
interval: 100 * time.Millisecond,
|
||||
timeout: 50 * time.Millisecond,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 启动健康检查
|
||||
go hc.Start()
|
||||
|
||||
// 等待几次检查执行
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// 停止健康检查
|
||||
hc.Stop()
|
||||
}
|
||||
|
||||
func TestUDPListener(t *testing.T) {
|
||||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
@ -346,4 +394,357 @@ func TestNewUDPServer(t *testing.T) {
|
||||
if srv2.timeout != 30*time.Second {
|
||||
t.Errorf("Expected timeout 30s, got %v", srv2.timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenTCP(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
// 使用 :0 让系统分配端口
|
||||
err := s.ListenTCP("127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenTCP failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证监听器已创建
|
||||
s.mu.RLock()
|
||||
if len(s.listeners) != 1 {
|
||||
t.Errorf("Expected 1 listener, got %d", len(s.listeners))
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 验证 Stats
|
||||
stats := s.Stats()
|
||||
if stats.Listeners != 1 {
|
||||
t.Errorf("Expected 1 listener in stats, got %d", stats.Listeners)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerStartStopWithTCP(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
// 添加上游
|
||||
targets := []TargetSpec{
|
||||
{Addr: "127.0.0.1:19003", Weight: 1},
|
||||
}
|
||||
s.AddUpstream("tcp_test", targets, "round_robin", HealthCheckSpec{})
|
||||
|
||||
// 监听 TCP
|
||||
err := s.ListenTCP("127.0.0.1:19004")
|
||||
if err != nil {
|
||||
t.Fatalf("ListenTCP failed: %v", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
err = s.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
|
||||
// 给服务器一点时间启动
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
err = s.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinBalancerWithSingleTarget(t *testing.T) {
|
||||
rb := newRoundRobin()
|
||||
targets := []*Target{
|
||||
{addr: "backend1:8080"},
|
||||
}
|
||||
targets[0].healthy.Store(true)
|
||||
|
||||
// 测试单个健康目标
|
||||
for i := 0; i < 5; i++ {
|
||||
target := rb.Select(targets)
|
||||
if target == nil {
|
||||
t.Error("Expected non-nil target")
|
||||
continue
|
||||
}
|
||||
if target.addr != "backend1:8080" {
|
||||
t.Errorf("Expected backend1:8080, got %s", target.addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastConnBalancerWithTie(t *testing.T) {
|
||||
lc := newLeastConn()
|
||||
targets := []*Target{
|
||||
{addr: "backend1:8080", conns: 5},
|
||||
{addr: "backend2:8080", conns: 5},
|
||||
{addr: "backend3:8080", conns: 5},
|
||||
}
|
||||
for _, t := range targets {
|
||||
t.healthy.Store(true)
|
||||
}
|
||||
|
||||
// 当连接数相同时,应该选择第一个
|
||||
selected := lc.Select(targets)
|
||||
if selected == nil {
|
||||
t.Error("Expected non-nil target")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUpstreamWithLeastConn(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
targets := []TargetSpec{
|
||||
{Addr: "localhost:8001", Weight: 1},
|
||||
{Addr: "localhost:8002", Weight: 2},
|
||||
}
|
||||
|
||||
err := s.AddUpstream("least_conn_test", targets, "least_conn", HealthCheckSpec{})
|
||||
if err != nil {
|
||||
t.Errorf("AddUpstream failed: %v", err)
|
||||
}
|
||||
|
||||
up := s.upstreams["least_conn_test"]
|
||||
if up == nil {
|
||||
t.Fatal("Expected non-nil upstream")
|
||||
}
|
||||
|
||||
// 验证使用的是最少连接均衡器
|
||||
_, ok := up.balancer.(*leastConn)
|
||||
if !ok {
|
||||
t.Error("Expected leastConn balancer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddUpstreamWithHealthCheck(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
targets := []TargetSpec{
|
||||
{Addr: "localhost:8001", Weight: 1},
|
||||
}
|
||||
|
||||
hcSpec := HealthCheckSpec{
|
||||
Enabled: true,
|
||||
Interval: 1 * time.Second,
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}
|
||||
|
||||
err := s.AddUpstream("hc_test", targets, "round_robin", hcSpec)
|
||||
if err != nil {
|
||||
t.Errorf("AddUpstream failed: %v", err)
|
||||
}
|
||||
|
||||
up := s.upstreams["hc_test"]
|
||||
if up == nil {
|
||||
t.Fatal("Expected non-nil upstream")
|
||||
}
|
||||
|
||||
if up.healthChk == nil {
|
||||
t.Error("Expected health checker to be initialized")
|
||||
}
|
||||
|
||||
// 停止健康检查
|
||||
if up.healthChk != nil {
|
||||
up.healthChk.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamSelectNoHealthy(t *testing.T) {
|
||||
u := &Upstream{
|
||||
targets: []*Target{
|
||||
{addr: "localhost:8001"},
|
||||
{addr: "localhost:8002"},
|
||||
},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
// 不设置 healthy,默认为 false
|
||||
|
||||
selected := u.Select()
|
||||
if selected != nil {
|
||||
t.Error("Expected nil for no healthy targets")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupExpiredSessions(t *testing.T) {
|
||||
// 创建 UDP 连接
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer conn.Close()
|
||||
|
||||
// 创建上游
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:19005"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
upstream.targets[0].healthy.Store(true)
|
||||
|
||||
// 创建 UDP 服务器,设置很短的超时时间
|
||||
srv := newUDPServer(conn, upstream, 1*time.Millisecond)
|
||||
|
||||
// 创建模拟会话
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12345")
|
||||
session := &udpSession{
|
||||
clientAddr: clientAddr,
|
||||
lastActive: time.Now().Add(-1 * time.Hour), // 很久以前的活动
|
||||
}
|
||||
srv.sessions[sessionKey(clientAddr)] = session
|
||||
|
||||
// 执行清理
|
||||
srv.cleanupExpiredSessions()
|
||||
|
||||
// 验证会话已被清理
|
||||
srv.mu.RLock()
|
||||
if len(srv.sessions) != 0 {
|
||||
t.Errorf("Expected 0 sessions after cleanup, got %d", len(srv.sessions))
|
||||
}
|
||||
srv.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestUDPServerStop(t *testing.T) {
|
||||
// 创建 UDP 连接
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
|
||||
// 创建上游
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:19006"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
|
||||
// 创建 UDP 服务器
|
||||
srv := newUDPServer(conn, upstream, 1*time.Second)
|
||||
|
||||
// 添加一个模拟会话
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12346")
|
||||
session := &udpSession{
|
||||
clientAddr: clientAddr,
|
||||
lastActive: time.Now(),
|
||||
}
|
||||
srv.sessions[sessionKey(clientAddr)] = session
|
||||
|
||||
// 停止服务器
|
||||
srv.stop()
|
||||
|
||||
// 验证会话已被清理
|
||||
srv.mu.RLock()
|
||||
if len(srv.sessions) != 0 {
|
||||
t.Errorf("Expected 0 sessions after stop, got %d", len(srv.sessions))
|
||||
}
|
||||
srv.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestUDPSessionOperations(t *testing.T) {
|
||||
// 创建 UDP 连接
|
||||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||
defer conn.Close()
|
||||
|
||||
// 创建目标服务器(用于模拟连接)
|
||||
targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:19007")
|
||||
targetConn, _ := net.ListenUDP("udp", targetAddr)
|
||||
defer targetConn.Close()
|
||||
|
||||
// 创建上游
|
||||
upstream := &Upstream{
|
||||
targets: []*Target{{addr: "127.0.0.1:19007"}},
|
||||
balancer: newRoundRobin(),
|
||||
}
|
||||
upstream.targets[0].healthy.Store(true)
|
||||
|
||||
// 创建 UDP 服务器
|
||||
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||
|
||||
// 测试 getSession - 不存在的会话
|
||||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12347")
|
||||
session := srv.getSession(clientAddr)
|
||||
if session != nil {
|
||||
t.Error("Expected nil for non-existent session")
|
||||
}
|
||||
|
||||
// 创建模拟会话
|
||||
testSession := &udpSession{
|
||||
clientAddr: clientAddr,
|
||||
lastActive: time.Now(),
|
||||
srv: srv,
|
||||
}
|
||||
srv.sessions[sessionKey(clientAddr)] = testSession
|
||||
|
||||
// 测试 getSession - 存在的会话
|
||||
session = srv.getSession(clientAddr)
|
||||
if session == nil {
|
||||
t.Error("Expected non-nil session")
|
||||
}
|
||||
|
||||
// 测试 removeSession
|
||||
srv.removeSession(clientAddr)
|
||||
session = srv.getSession(clientAddr)
|
||||
if session != nil {
|
||||
t.Error("Expected nil after removeSession")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPSessionClose(t *testing.T) {
|
||||
// 创建两个 UDP 连接用于测试
|
||||
udpAddr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn1, _ := net.ListenUDP("udp", udpAddr1)
|
||||
|
||||
udpAddr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
conn2, _ := net.ListenUDP("udp", udpAddr2)
|
||||
|
||||
// 创建会话
|
||||
session := &udpSession{
|
||||
clientAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12348},
|
||||
targetConn: conn2,
|
||||
lastActive: time.Now(),
|
||||
}
|
||||
|
||||
// 测试 close - 应该能正常关闭
|
||||
session.close()
|
||||
|
||||
// 第二次调用 close 不应该出错(使用 sync.Once)
|
||||
session.close()
|
||||
|
||||
conn1.Close()
|
||||
}
|
||||
|
||||
func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
|
||||
// 启动一个临时的 TCP 服务器
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
// 在后台运行服务器
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
addr := listener.Addr().String()
|
||||
|
||||
u := &Upstream{
|
||||
targets: []*Target{
|
||||
{addr: addr},
|
||||
},
|
||||
}
|
||||
// 初始设置为不健康
|
||||
u.targets[0].healthy.Store(false)
|
||||
|
||||
hc := &HealthChecker{
|
||||
upstream: u,
|
||||
interval: 1 * time.Second,
|
||||
timeout: 100 * time.Millisecond,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 执行健康检查
|
||||
hc.check()
|
||||
|
||||
// 目标应该被标记为健康(因为服务器在运行)
|
||||
if !u.targets[0].healthy.Load() {
|
||||
t.Error("Expected target to be marked healthy")
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user