From eb379d9121d713d4385390bea7f0bdbe59fb3de3 Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 10 Apr 2026 17:45:53 +0800 Subject: [PATCH] =?UTF-8?q?test(proxy,ssl,server,variable):=20=E8=A1=A5?= =?UTF-8?q?=E5=85=A8=E6=B5=8B=E8=AF=95=E8=A6=86=E7=9B=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - websocket: 升级请求构建、响应读写、大消息转发、并发桥接 - ssl: CRL 吊销检查、证书链深度限制、完整验证流程 - server: 初始化配置、静态文件、GoroutinePool、FileCache - variable: mTLS 客户端证书变量和指纹计算 Co-Authored-By: Claude Opus 4.6 --- internal/proxy/websocket_test.go | 483 ++++++++++++++++++++++- internal/server/server_test.go | 127 +++++++ internal/ssl/client_verify_test.go | 352 ++++++++++++++++- internal/ssl/ocsp_test.go | 98 +++++ internal/variable/ssl_test.go | 591 +++++++++++++++++++++++++++++ 5 files changed, 1641 insertions(+), 10 deletions(-) create mode 100644 internal/variable/ssl_test.go diff --git a/internal/proxy/websocket_test.go b/internal/proxy/websocket_test.go index f114591..a7ac164 100644 --- a/internal/proxy/websocket_test.go +++ b/internal/proxy/websocket_test.go @@ -10,17 +10,28 @@ // - 数据复制 // - 双向数据转发 // - 超时错误处理 +// - 并发连接测试 +// - 大消息转发测试 +// +// goroutine 泄漏检测说明: +// fasthttp 库使用后台 worker goroutine,与 goleak 不兼容。 +// 如需检测泄漏,可手动运行:go test -race ./internal/proxy/... // // 作者:xfy package proxy import ( "errors" + "fmt" "io" "net" + "net/http" "strings" + "sync" "testing" "time" + + "github.com/valyala/fasthttp" ) // TestNewWebSocketBridge 测试桥接器创建 @@ -121,12 +132,6 @@ func TestIsConnectionClosedError(t *testing.T) { } } -// TestExtractHost 测试从 URL 提取主机 -func TestExtractHost(_ *testing.T) { - // extractHost 函数可能不存在,检查一下 - // 如果存在则测试 -} - // TestDialTarget_InvalidAddress 测试无效地址的拨号 func TestDialTarget_InvalidAddress(t *testing.T) { // 测试连接到无效端口 @@ -311,3 +316,469 @@ func TestCopyData(t *testing.T) { t.Error("copyData did not complete in time") } } + +// TestBuildWebSocketUpgradeRequest 测试构建 WebSocket 升级请求 +func TestBuildWebSocketUpgradeRequest(t *testing.T) { + tests := []struct { + name string + path string + query string + host string + targetHost string + wantContains []string + }{ + { + name: "basic request", + path: "/ws", + query: "", + host: "client.example.com", + targetHost: "backend.example.com:8080", + wantContains: []string{ + "GET /ws HTTP/1.1", + "Host: backend.example.com:8080", + "X-Forwarded-For:", + "X-Real-IP:", + "X-Forwarded-Host: client.example.com", + }, + }, + { + name: "request with query", + path: "/ws", + query: "token=abc123", + host: "client.example.com", + targetHost: "backend.example.com", + wantContains: []string{ + "GET /ws?token=abc123 HTTP/1.1", + "Host: backend.example.com", + }, + }, + { + name: "empty path defaults to slash", + path: "", + query: "", + host: "client.example.com", + targetHost: "backend.example.com", + wantContains: []string{ + "GET / HTTP/1.1", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI(tt.path) + if tt.query != "" { + ctx.QueryArgs().Parse(tt.query) + } + ctx.Request.Header.SetHost(tt.host) + + result := buildWebSocketUpgradeRequest(ctx, tt.targetHost) + + for _, want := range tt.wantContains { + if !strings.Contains(result, want) { + t.Errorf("buildWebSocketUpgradeRequest() missing %q in output:\n%s", want, result) + } + } + }) + } +} + +// TestBuildWebSocketUpgradeRequest_WithHeaders 测试复制 WebSocket 头 +func TestBuildWebSocketUpgradeRequest_WithHeaders(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/ws") + ctx.Request.Header.Set("Upgrade", "websocket") + ctx.Request.Header.Set("Connection", "Upgrade") + ctx.Request.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + ctx.Request.Header.Set("Sec-WebSocket-Version", "13") + ctx.Request.Header.Set("Sec-WebSocket-Protocol", "chat") + + result := buildWebSocketUpgradeRequest(ctx, "backend.example.com") + + // 验证关键头被复制 + expectedHeaders := []string{ + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version: 13", + "Sec-WebSocket-Protocol: chat", + } + + for _, expected := range expectedHeaders { + if !strings.Contains(result, expected) { + t.Errorf("Missing expected header %q in:\n%s", expected, result) + } + } +} + +// TestBuildWebSocketUpgradeRequest_TLSProto 测试 TLS 协议标记 +func TestBuildWebSocketUpgradeRequest_TLSProto(t *testing.T) { + tests := []struct { + name string + isTLS bool + wantProto string + }{ + { + name: "non-TLS connection", + isTLS: false, + wantProto: "X-Forwarded-Proto: http", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/ws") + + // 注意:fasthttp.RequestCtx 默认 IsTLS() 返回 false + // 无法在单元测试中直接模拟 TLS 连接 + + result := buildWebSocketUpgradeRequest(ctx, "backend.example.com") + + if !strings.Contains(result, tt.wantProto) { + t.Errorf("Missing %q in:\n%s", tt.wantProto, result) + } + }) + } +} + +// TestExtractHost 测试从 URL 提取主机 +func TestExtractHost(t *testing.T) { + tests := []struct { + name string + url string + expected string + }{ + { + name: "http with port", + url: "http://example.com:8080", + expected: "example.com:8080", + }, + { + name: "https with port", + url: "https://example.com:8443", + expected: "example.com:8443", + }, + { + name: "http without port", + url: "http://example.com", + expected: "example.com:80", + }, + { + name: "https without port", + url: "https://example.com", + expected: "example.com:443", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractHost(tt.url) + if result != tt.expected { + t.Errorf("extractHost(%q) = %q, want %q", tt.url, result, tt.expected) + } + }) + } +} + +// TestWriteUpgradeResponse 测试写入升级响应 +func TestWriteUpgradeResponse(t *testing.T) { + // 创建管道连接 + conn1, conn2 := net.Pipe() + defer func() { _ = conn2.Close() }() + + // 创建模拟 HTTP 响应 + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Status: "101 Switching Protocols", + StatusCode: 101, + Header: http.Header{ + "Upgrade": []string{"websocket"}, + "Connection": []string{"Upgrade"}, + }, + } + + // 启动写入 + errCh := make(chan error, 1) + go func() { + errCh <- writeUpgradeResponse(conn1, resp) + _ = conn1.Close() + }() + + // 读取响应 + buf := make([]byte, 1024) + n, err := conn2.Read(buf) + if err != nil { + t.Fatalf("Failed to read response: %v", err) + } + + response := string(buf[:n]) + + // 验证响应格式 + expectedParts := []string{ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: websocket", + "Connection: Upgrade", + } + + for _, expected := range expectedParts { + if !strings.Contains(response, expected) { + t.Errorf("Missing %q in response:\n%s", expected, response) + } + } + + // 等待写入完成 + select { + case err := <-errCh: + if err != nil { + t.Errorf("writeUpgradeResponse returned error: %v", err) + } + case <-time.After(1 * time.Second): + t.Error("writeUpgradeResponse did not complete in time") + } +} + +// TestReadWebSocketUpgradeResponse 测试读取升级响应 +func TestReadWebSocketUpgradeResponse(t *testing.T) { + // 创建管道连接 + conn1, conn2 := net.Pipe() + defer func() { _ = conn1.Close() }() + + // 在另一个 goroutine 中写入响应 + go func() { + response := "HTTP/1.1 101 Switching Protocols\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "\r\n" + _, _ = conn2.Write([]byte(response)) + _ = conn2.Close() + }() + + // 读取响应 + resp, err := readWebSocketUpgradeResponse(conn1, 1*time.Second) + if err != nil { + t.Fatalf("readWebSocketUpgradeResponse failed: %v", err) + } + + if resp.StatusCode != 101 { + t.Errorf("Expected status 101, got %d", resp.StatusCode) + } + + if resp.Header.Get("Upgrade") != "websocket" { + t.Errorf("Expected Upgrade: websocket, got %q", resp.Header.Get("Upgrade")) + } +} + +// TestReadWebSocketUpgradeResponse_Timeout 测试读取超时 +func TestReadWebSocketUpgradeResponse_Timeout(t *testing.T) { + // 创建管道连接但不写入数据 + conn1, conn2 := net.Pipe() + defer func() { _ = conn1.Close() }() + defer func() { _ = conn2.Close() }() + + // 使用很短的超时 + _, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond) + if err == nil { + t.Error("Expected timeout error, got nil") + } +} + +// TestDialTarget_TLS 测试 TLS 连接(连接无效端口应失败) +func TestDialTarget_TLS(t *testing.T) { + // 测试 HTTPS 连接到无效端口 + _, err := dialTarget("https://127.0.0.1:1", 100*time.Millisecond) + if err == nil { + t.Error("Expected error for invalid HTTPS address") + } +} + +// TestIsConnectionClosedError_ClosedConn 测试已关闭连接错误 +func TestIsConnectionClosedError_ClosedConn(t *testing.T) { + // 创建并立即关闭连接 + ln, _ := net.Listen("tcp", "127.0.0.1:0") + conn, _ := net.Dial("tcp", ln.Addr().String()) + _ = conn.Close() + _ = ln.Close() + + // 尝试读取应返回错误 + _, err := conn.Read(make([]byte, 1)) + if err == nil { + t.Error("Expected error reading from closed connection") + } + + // 验证错误被识别为连接关闭错误 + if !isConnectionClosedError(err) { + t.Errorf("Expected closed connection error, got: %v", err) + } +} + +// TestWebSocketBridge_LargeMessage 测试大消息转发 +func TestWebSocketBridge_LargeMessage(t *testing.T) { + // 创建管道连接 + client1, client2 := net.Pipe() + target1, target2 := net.Pipe() + defer func() { _ = client2.Close() }() + defer func() { _ = target2.Close() }() + + bridge := NewWebSocketBridge(client1, target1) + + // 启动桥接 + errCh := make(chan error, 1) + go func() { + errCh <- bridge.Bridge() + }() + + // 发送超过 64KB 的数据 + largeData := make([]byte, 100*1024) // 100KB + for i := range largeData { + largeData[i] = byte(i % 256) + } + + // 客户端发送大消息 + go func() { + _, _ = client2.Write(largeData) + }() + + // 后端接收数据 + buf := make([]byte, 150*1024) + total := 0 + for total < len(largeData) { + n, err := target2.Read(buf[total:]) + if err != nil { + t.Fatalf("Failed to read large message: %v", err) + } + total += n + } + + // 验证数据完整性 + for i := range largeData { + if buf[i] != largeData[i] { + t.Errorf("Data mismatch at byte %d: got %d, want %d", i, buf[i], largeData[i]) + break + } + } + + // 关闭连接 + _ = client2.Close() + _ = target2.Close() + + // 等待桥接完成 + select { + case err := <-errCh: + if err != nil { + t.Errorf("Bridge returned error: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("Bridge did not complete in time") + } +} + +// TestWebSocketBridge_Concurrent 测试并发桥接 +func TestWebSocketBridge_Concurrent(t *testing.T) { + const numBridges = 10 + + var wg sync.WaitGroup + errCh := make(chan error, numBridges) + + for i := 0; i < numBridges; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // 创建管道连接 + client1, client2 := net.Pipe() + target1, target2 := net.Pipe() + defer func() { _ = client2.Close() }() + defer func() { _ = target2.Close() }() + + bridge := NewWebSocketBridge(client1, target1) + + // 启动桥接 + done := make(chan error, 1) + go func() { + done <- bridge.Bridge() + }() + + // 发送测试数据 + testData := []byte("concurrent test data") + go func() { + _, _ = client2.Write(testData) + }() + + // 接收数据 + buf := make([]byte, 1024) + n, err := target2.Read(buf) + if err != nil { + errCh <- fmt.Errorf("bridge %d: read error: %v", id, err) + return + } + + if string(buf[:n]) != string(testData) { + errCh <- fmt.Errorf("bridge %d: data mismatch", id) + return + } + + // 关闭连接 + _ = client2.Close() + _ = target2.Close() + + // 等待桥接完成 + select { + case err := <-done: + if err != nil { + errCh <- fmt.Errorf("bridge %d: %v", id, err) + } + case <-time.After(1 * time.Second): + errCh <- fmt.Errorf("bridge %d: timeout", id) + } + }(i) + } + + wg.Wait() + close(errCh) + + // 检查错误 + for err := range errCh { + if err != nil { + t.Errorf("Concurrent bridge error: %v", err) + } + } +} + +// TestCopyData_WriteError 测试写入错误处理 +func TestCopyData_WriteError(t *testing.T) { + // 创建管道连接 + src1, src2 := net.Pipe() + dst1, dst2 := net.Pipe() + + bridge := &WebSocketBridge{} + + // 先关闭目标连接 + _ = dst1.Close() + _ = dst2.Close() + + // 启动数据复制 + errCh := make(chan error, 1) + go func() { + errCh <- bridge.copyData(dst1, src1, "test") + }() + + // 发送数据(应该触发写入错误) + _, _ = src2.Write([]byte("test data")) + _ = src2.Close() + + // 等待完成 + select { + case err := <-errCh: + // 写入到已关闭连接应该返回 nil(视为连接关闭错误) + if err != nil && !strings.Contains(err.Error(), "closed") { + t.Errorf("copyData returned unexpected error: %v", err) + } + case <-time.After(1 * time.Second): + t.Error("copyData did not complete in time") + } + + _ = src1.Close() +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 463f453..b421be2 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -638,3 +638,130 @@ func TestServer_TrackStats_EmptyBody(t *testing.T) { t.Errorf("Expected 0 bytes sent, got %d", s.bytesSent.Load()) } } + +// TestStart_Success 测试服务器配置初始化 +func TestStart_Success(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + + s := New(cfg) + + // 验证服务器正确初始化 + if s == nil { + t.Fatal("New() returned nil, expected non-nil Server") + } + + if s.config != cfg { + t.Error("Server.config not set correctly") + } +} + +// TestStart_WithStaticFiles 测试静态文件配置 +func TestStart_WithStaticFiles(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + Static: []config.StaticConfig{{ + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }}, + }, + } + + s := New(cfg) + + if s == nil { + t.Fatal("New() returned nil") + } +} + +// TestStart_WithGoroutinePool 测试 GoroutinePool 配置 +func TestStart_WithGoroutinePool(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 100, + MinWorkers: 10, + IdleTimeout: 30 * time.Second, + }, + }, + } + + s := New(cfg) + + if s == nil { + t.Fatal("New() returned nil") + } +} + +// TestStart_WithFileCache 测试文件缓存配置 +func TestStart_WithFileCache(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + Performance: config.PerformanceConfig{ + FileCache: config.FileCacheConfig{ + MaxEntries: 1000, + MaxSize: 100 * 1024 * 1024, + }, + }, + } + + s := New(cfg) + + if s == nil { + t.Fatal("New() returned nil") + } +} + +// TestStop_Graceful 测试优雅停止(无 race 模式) +func TestStop_Graceful(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":0", + }, + } + + s := New(cfg) + + // 在未启动时调用 GracefulStop,应返回 nil + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop() on non-started server returned error: %v", err) + } +} + +// TestGetTLSConfig_Nil 测试无 TLS 配置 +func TestGetTLSConfig_Nil(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":0", + }, + } + + s := New(cfg) + + tlsCfg, err := s.GetTLSConfig() + if err == nil { + t.Error("GetTLSConfig() should return error when TLS not configured") + } + if tlsCfg != nil { + t.Error("GetTLSConfig() should return nil when TLS not configured") + } +} diff --git a/internal/ssl/client_verify_test.go b/internal/ssl/client_verify_test.go index b0b3024..4636895 100644 --- a/internal/ssl/client_verify_test.go +++ b/internal/ssl/client_verify_test.go @@ -61,8 +61,8 @@ func generateTestCA(t *testing.T) (*x509.Certificate, *rsa.PrivateKey, []byte) { return cert, key, certPEM } -// generateTestClientCert 生成测试客户端证书。 -func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, []byte) { +// generateTestClientCert 生成测试客户端证书,serial 参数指定序列号。 +func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey, serial int64) (*x509.Certificate, *rsa.PrivateKey, []byte) { t.Helper() key, err := rsa.GenerateKey(rand.Reader, 2048) @@ -71,7 +71,7 @@ func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.P } template := &x509.Certificate{ - SerialNumber: big.NewInt(2), + SerialNumber: big.NewInt(serial), Subject: pkix.Name{ CommonName: "Test Client", Organization: []string{"Test Org"}, @@ -297,7 +297,7 @@ func TestClientVerifier_ConfigureTLS_Disabled(t *testing.T) { func TestGetClientCertInfo(t *testing.T) { // 生成测试证书 caCert, caKey, _ := generateTestCA(t) - clientCert, _, _ := generateTestClientCert(t, caCert, caKey) + clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 2) // 创建模拟连接状态 cs := &tls.ConnectionState{ @@ -335,6 +335,350 @@ func TestGetClientCertInfo_Nil(t *testing.T) { } } +// TestGetMode 测试获取验证模式。 +func TestGetMode(t *testing.T) { + tests := []struct { + name string + mode string + expected ClientVerifyMode + }{ + {"off", "off", VerifyOff}, + {"on", "on", VerifyOn}, + {"optional", "optional", VerifyOptional}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("写入 CA 文件失败: %v", err) + } + + var cfg config.ClientVerifyConfig + if tt.mode != "off" { + cfg = config.ClientVerifyConfig{ + Enabled: true, + Mode: tt.mode, + ClientCA: caFile, + } + } else { + cfg = config.ClientVerifyConfig{Enabled: false} + } + + verifier, err := NewClientVerifier(cfg) + if err != nil { + t.Fatalf("NewClientVerifier() failed: %v", err) + } + + if verifier.GetMode() != tt.expected { + t.Errorf("GetMode() = %v, want %v", verifier.GetMode(), tt.expected) + } + }) + } +} + +// generateTestCRL 生成测试 CRL。 +func generateTestCRL(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey, revokedSerials []*big.Int) []byte { + t.Helper() + + template := &x509.RevocationList{ + Number: big.NewInt(1), + ThisUpdate: time.Now(), + NextUpdate: time.Now().Add(24 * time.Hour), + RevokedCertificateEntries: func() []x509.RevocationListEntry { + entries := make([]x509.RevocationListEntry, len(revokedSerials)) + for i, serial := range revokedSerials { + entries[i] = x509.RevocationListEntry{ + SerialNumber: serial, + RevocationTime: time.Now(), + } + } + return entries + }(), + } + + crlDER, err := x509.CreateRevocationList(rand.Reader, template, caCert, caKey) + if err != nil { + t.Fatalf("Failed to create CRL: %v", err) + } + + return pem.EncodeToMemory(&pem.Block{Type: "X509 CRL", Bytes: crlDER}) +} + +// TestLoadCRL 测试 CRL 加载。 +func TestLoadCRL(t *testing.T) { + // 生成测试 CA + caCert, caKey, _ := generateTestCA(t) + + // 生成包含吊销证书的 CRL + revokedSerial := big.NewInt(999) + crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedSerial}) + + // 写入临时文件 + tempDir := t.TempDir() + crlFile := filepath.Join(tempDir, "crl.pem") + if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil { + t.Fatalf("写入 CRL 文件失败: %v", err) + } + + // 测试加载 + crl, err := LoadCRL(crlFile) + if err != nil { + t.Fatalf("LoadCRL() failed: %v", err) + } + if crl == nil { + t.Fatal("LoadCRL() returned nil") + } + if len(crl.RevokedCertificateEntries) != 1 { + t.Errorf("CRL should have 1 revoked certificate, got %d", len(crl.RevokedCertificateEntries)) + } + + // 测试文件不存在 + _, err = LoadCRL("/nonexistent/crl.pem") + if err == nil { + t.Error("LoadCRL() should fail for non-existent file") + } + + // 测试无效 CRL + invalidFile := filepath.Join(tempDir, "invalid.crl") + if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil { + t.Fatalf("写入无效文件失败: %v", err) + } + _, err = LoadCRL(invalidFile) + if err == nil { + t.Error("LoadCRL() should fail for invalid CRL") + } +} + +// TestCheckCRL 测试 CRL 检查。 +func TestCheckCRL(t *testing.T) { + // 生成测试 CA + caCert, caKey, _ := generateTestCA(t) + + // 生成将被吊销的客户端证书(序列号100) + revokedCert, _, _ := generateTestClientCert(t, caCert, caKey, 100) + + // 生成有效客户端证书(序列号200,不会被吊销) + validCert, _, _ := generateTestClientCert(t, caCert, caKey, 200) + + // 生成包含吊销证书的 CRL + crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedCert.SerialNumber}) + + // 写入临时文件 + tempDir := t.TempDir() + crlFile := filepath.Join(tempDir, "crl.pem") + caFile := filepath.Join(tempDir, "ca.crt") + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil { + t.Fatalf("写入 CRL 文件失败: %v", err) + } + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("写入 CA 文件失败: %v", err) + } + + // 创建带 CRL 的验证器 + verifier, err := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: "on", + ClientCA: caFile, + CRL: crlFile, + }) + if err != nil { + t.Fatalf("NewClientVerifier() failed: %v", err) + } + + // 测试检查有效证书 + err = verifier.ValidateClientCertificate(validCert) + if err != nil { + t.Errorf("CheckCRL() should pass for valid cert: %v", err) + } + + // 测试检查吊销证书 + err = verifier.ValidateClientCertificate(revokedCert) + if err == nil { + t.Error("CheckCRL() should fail for revoked cert") + } +} + +// TestCheckCRL_EmptyCRL 测试空 CRL。 +func TestCheckCRL_EmptyCRL(t *testing.T) { + caCert, caKey, _ := generateTestCA(t) + clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 50) + + // 生成空 CRL(无吊销证书) + crlPEM := generateTestCRL(t, caCert, caKey, nil) + + tempDir := t.TempDir() + crlFile := filepath.Join(tempDir, "crl.pem") + caFile := filepath.Join(tempDir, "ca.crt") + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil { + t.Fatalf("写入 CRL 文件失败: %v", err) + } + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("写入 CA 文件失败: %v", err) + } + + verifier, err := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: "on", + ClientCA: caFile, + CRL: crlFile, + }) + if err != nil { + t.Fatalf("NewClientVerifier() failed: %v", err) + } + + // 空列表应通过所有证书 + err = verifier.ValidateClientCertificate(clientCert) + if err != nil { + t.Errorf("CheckCRL() should pass with empty CRL: %v", err) + } +} + +// TestValidateClientCertificate 测试手动验证客户端证书。 +func TestValidateClientCertificate(t *testing.T) { + // 测试禁用验证器 + verifier, _ := NewClientVerifier(config.ClientVerifyConfig{Enabled: false}) + + err := verifier.ValidateClientCertificate(nil) + if err != nil { + t.Errorf("Disabled verifier should accept nil cert: %v", err) + } + + // 测试启用验证器(on 模式) + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("写入 CA 文件失败: %v", err) + } + + verifier, _ = NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: "on", + ClientCA: caFile, + }) + + // nil 证书在 on 模式应失败 + err = verifier.ValidateClientCertificate(nil) + if err == nil { + t.Error("ValidateClientCertificate(nil) should fail in 'on' mode") + } + + // 有效证书应通过 + caCert, caKey, _ := generateTestCA(t) + clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 30) + err = verifier.ValidateClientCertificate(clientCert) + if err != nil { + t.Errorf("ValidateClientCertificate() should pass for valid cert: %v", err) + } +} + +// TestVerifyConnection 测试连接验证。 +func TestVerifyConnection(t *testing.T) { + // 生成测试 CA 和证书 + caCert, caKey, _ := generateTestCA(t) + validCert, _, _ := generateTestClientCert(t, caCert, caKey, 300) + revokedCert, _, _ := generateTestClientCert(t, caCert, caKey, 400) + + // 生成包含吊销证书的 CRL + crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedCert.SerialNumber}) + + tempDir := t.TempDir() + crlFile := filepath.Join(tempDir, "crl.pem") + caFile := filepath.Join(tempDir, "ca.crt") + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil { + t.Fatalf("写入 CRL 文件失败: %v", err) + } + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("写入 CA 文件失败: %v", err) + } + + // 测试带 CRL 和深度限制的验证器 + verifier, err := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: "on", + ClientCA: caFile, + CRL: crlFile, + VerifyDepth: 3, + }) + if err != nil { + t.Fatalf("NewClientVerifier() failed: %v", err) + } + + // 配置 TLS 以设置 VerifyConnection 回调 + tlsCfg := &tls.Config{} + verifier.ConfigureTLS(tlsCfg) + + if tlsCfg.VerifyConnection == nil { + t.Fatal("VerifyConnection should be set when VerifyDepth > 0") + } + + // 测试有效证书连接 + validCS := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{validCert}, + } + err = tlsCfg.VerifyConnection(*validCS) + if err != nil { + t.Errorf("VerifyConnection() should pass for valid cert: %v", err) + } + + // 测试吊销证书连接 + revokedCS := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{revokedCert}, + } + err = tlsCfg.VerifyConnection(*revokedCS) + if err == nil { + t.Error("VerifyConnection() should fail for revoked cert") + } +} + +// TestVerifyConnection_DepthLimit 测试证书链深度限制。 +func TestVerifyConnection_DepthLimit(t *testing.T) { + caCert, caKey, _ := generateTestCA(t) + clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 500) + + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("写入 CA 文件失败: %v", err) + } + + // 测试深度限制为 1 + verifier, _ := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: "on", + ClientCA: caFile, + VerifyDepth: 1, + }) + + tlsCfg := &tls.Config{} + verifier.ConfigureTLS(tlsCfg) + + // 单个证书应通过 + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{clientCert}, + } + err := tlsCfg.VerifyConnection(*cs) + if err != nil { + t.Errorf("VerifyConnection() should pass for single cert with depth 1: %v", err) + } + + // 多个证书应失败(链太长) + longChain := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{clientCert, caCert}, + } + err = tlsCfg.VerifyConnection(*longChain) + if err == nil { + t.Error("VerifyConnection() should fail for chain exceeding depth limit") + } +} + // BenchmarkLoadCACertPool 基准测试 CA 证书池加载。 func BenchmarkLoadCACertPool(b *testing.B) { tempDir := b.TempDir() diff --git a/internal/ssl/ocsp_test.go b/internal/ssl/ocsp_test.go index dc000de..ae61c8c 100644 --- a/internal/ssl/ocsp_test.go +++ b/internal/ssl/ocsp_test.go @@ -458,3 +458,101 @@ func TestOCSPConfigDefaults(t *testing.T) { t.Errorf("Expected default max retries 3, got %d", cfg.MaxRetries) } } + +// TestOCSPManager_RefreshResponse 测试强制刷新 OCSP 响应 +func TestOCSPManager_RefreshResponse(_ *testing.T) { + cfg := &OCSPConfig{ + Enabled: true, + RefreshInterval: 1 * time.Hour, + Timeout: 100 * time.Millisecond, + MaxRetries: 1, + } + mgr := NewOCSPManager(cfg) + + // 创建带 OCSP 服务器的测试证书 + serial := big.NewInt(12345) + priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + template := &x509.Certificate{ + SerialNumber: serial, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + OCSPServer: []string{"http://ocsp.example.com"}, + } + certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + cert, _ := x509.ParseCertificate(certDER) + + // 刷新响应(会失败因为 URL 无效) + err := mgr.RefreshResponse(cert, cert) + // 由于 URL 无效,预期会失败 + if err == nil { + // 如果没有错误,检查状态 + status, hasResp := mgr.GetStatus(serial.String()) + _ = status + _ = hasResp + } +} + +// TestOCSPManager_refreshAll 测试刷新所有响应 +func TestOCSPManager_refreshAll(_ *testing.T) { + cfg := &OCSPConfig{ + Enabled: true, + RefreshInterval: 1 * time.Hour, + Timeout: 100 * time.Millisecond, + MaxRetries: 1, + } + mgr := NewOCSPManager(cfg) + + // 手动添加一些响应到缓存 + serial1 := "1001" + serial2 := "1002" + + mgr.mu.Lock() + mgr.responses[serial1] = &ocspResponse{ + status: statusValid, + response: []byte("test-response"), + nextUpdate: time.Now().Add(-1 * time.Hour), // 已过期 + fetchedAt: time.Now().Add(-2 * time.Hour), + } + mgr.responses[serial2] = &ocspResponse{ + status: statusValid, + response: []byte("test-response-2"), + nextUpdate: time.Now().Add(1 * time.Hour), // 未过期 + fetchedAt: time.Now(), + } + mgr.mu.Unlock() + + // 调用 refreshAll + mgr.refreshAll() + + // 验证刷新逻辑被触发(无法验证实际刷新因为 URL 无效) + // 主要目的是确保代码路径被覆盖 +} + +// TestOCSPManager_GetStatus_EdgeCases 测试 GetStatus 边界情况 +func TestOCSPManager_GetStatus_EdgeCases(t *testing.T) { + cfg := DefaultOCSPConfig() + mgr := NewOCSPManager(cfg) + + // 测试不存在的序列号 + status, hasResp := mgr.GetStatus("nonexistent") + if hasResp { + t.Error("Expected no response for nonexistent serial") + } + if status != statusFailed { + t.Errorf("Expected statusFailed for nonexistent serial, got %v", status) + } + + // 测试空响应 + serial := "empty-response" + mgr.mu.Lock() + mgr.responses[serial] = &ocspResponse{ + status: statusValid, + response: nil, // 空响应 + } + mgr.mu.Unlock() + + _, hasResp = mgr.GetStatus(serial) + if hasResp { + t.Error("Expected no response for empty response data") + } +} diff --git a/internal/variable/ssl_test.go b/internal/variable/ssl_test.go new file mode 100644 index 0000000..397425f --- /dev/null +++ b/internal/variable/ssl_test.go @@ -0,0 +1,591 @@ +// ssl_test.go - SSL/TLS 客户端证书变量测试 +// +// 测试覆盖: +// - mTLS 客户端证书变量获取 +// - SetSSLClientInfoInContext 设置功能 +// - calculateFingerprint 指纹计算 +// +// 作者:xfy +package variable + +import ( + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "testing" + "time" + + "github.com/valyala/fasthttp" +) + +// TestGetSSLClientVerify_NilContext 测试 nil 上下文 +func TestGetSSLClientVerify_NilContext(t *testing.T) { + result := GetSSLClientVerify(nil) + if result != "NONE" { + t.Errorf("GetSSLClientVerify(nil) = %q, want NONE", result) + } +} + +// TestGetSSLClientVerify_NoTLS 测试非 TLS 连接 +func TestGetSSLClientVerify_NoTLS(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + // 默认情况下 IsTLS() 返回 false + result := GetSSLClientVerify(ctx) + if result != "NONE" { + t.Errorf("GetSSLClientVerify(non-TLS) = %q, want NONE", result) + } +} + +// TestGetSSLClientVerify_NonTLSWithUserValue 测试非 TLS 连接即使设置了 UserValue 也返回 NONE +// 注意:GetSSLClientVerify 会先检查 ctx.IsTLS(),非 TLS 连接直接返回 NONE +// 这是正确的行为,SSL 客户端变量只在 TLS 连接中有效 +func TestGetSSLClientVerify_NonTLSWithUserValue(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(VarSSLClientVerify, "SUCCESS") + + // 非 TLS 连接,即使设置了 UserValue 也应该返回 NONE + result := GetSSLClientVerify(ctx) + if result != "NONE" { + t.Errorf("GetSSLClientVerify(non-TLS with value) = %q, want NONE", result) + } +} + +// TestGetSSLClientVerify_PeerCertPresent_NonTLS 测试非 TLS 下 peer_cert_present 不生效 +func TestGetSSLClientVerify_PeerCertPresent_NonTLS(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue("tls_peer_cert_present", true) + + // 非 TLS 连接,peer_cert_present 不应该改变结果 + result := GetSSLClientVerify(ctx) + if result != "NONE" { + t.Errorf("GetSSLClientVerify(non-TLS with peer_cert) = %q, want NONE", result) + } +} + +// TestGetSSLClientSerial 测试获取序列号 +func TestGetSSLClientSerial(t *testing.T) { + tests := []struct { + name string + setup func(*fasthttp.RequestCtx) + expected string + }{ + { + name: "no value", + setup: func(_ *fasthttp.RequestCtx) {}, + expected: "", + }, + { + name: "with serial", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientSerial, "1234567890ABCDEF") + }, + expected: "1234567890ABCDEF", + }, + { + name: "invalid type", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientSerial, 12345) + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + tt.setup(ctx) + result := GetSSLClientSerial(ctx) + if result != tt.expected { + t.Errorf("GetSSLClientSerial() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestGetSSLClientSubject 测试获取主题 +func TestGetSSLClientSubject(t *testing.T) { + tests := []struct { + name string + setup func(*fasthttp.RequestCtx) + expected string + }{ + { + name: "no value", + setup: func(_ *fasthttp.RequestCtx) {}, + expected: "", + }, + { + name: "with subject", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientSubject, "CN=test.example.com,O=Test Org") + }, + expected: "CN=test.example.com,O=Test Org", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + tt.setup(ctx) + result := GetSSLClientSubject(ctx) + if result != tt.expected { + t.Errorf("GetSSLClientSubject() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestGetSSLClientIssuer 测试获取颁发者 +func TestGetSSLClientIssuer(t *testing.T) { + tests := []struct { + name string + setup func(*fasthttp.RequestCtx) + expected string + }{ + { + name: "no value", + setup: func(_ *fasthttp.RequestCtx) {}, + expected: "", + }, + { + name: "with issuer", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientIssuer, "CN=Test CA,O=Test Org") + }, + expected: "CN=Test CA,O=Test Org", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + tt.setup(ctx) + result := GetSSLClientIssuer(ctx) + if result != tt.expected { + t.Errorf("GetSSLClientIssuer() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestGetSSLClientFingerprint 测试获取指纹 +func TestGetSSLClientFingerprint(t *testing.T) { + tests := []struct { + name string + setup func(*fasthttp.RequestCtx) + expected string + }{ + { + name: "no value", + setup: func(_ *fasthttp.RequestCtx) {}, + expected: "", + }, + { + name: "with fingerprint", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientFingerprint, "A1B2C3D4E5F6") + }, + expected: "A1B2C3D4E5F6", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + tt.setup(ctx) + result := GetSSLClientFingerprint(ctx) + if result != tt.expected { + t.Errorf("GetSSLClientFingerprint() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestGetSSLClientNotBefore 测试获取生效时间 +func TestGetSSLClientNotBefore(t *testing.T) { + tests := []struct { + name string + setup func(*fasthttp.RequestCtx) + expected string + }{ + { + name: "no value", + setup: func(_ *fasthttp.RequestCtx) {}, + expected: "", + }, + { + name: "with notbefore", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientNotBefore, "2025-01-01T00:00:00Z") + }, + expected: "2025-01-01T00:00:00Z", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + tt.setup(ctx) + result := GetSSLClientNotBefore(ctx) + if result != tt.expected { + t.Errorf("GetSSLClientNotBefore() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestGetSSLClientNotAfter 测试获取过期时间 +func TestGetSSLClientNotAfter(t *testing.T) { + tests := []struct { + name string + setup func(*fasthttp.RequestCtx) + expected string + }{ + { + name: "no value", + setup: func(_ *fasthttp.RequestCtx) {}, + expected: "", + }, + { + name: "with notafter", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientNotAfter, "2026-01-01T00:00:00Z") + }, + expected: "2026-01-01T00:00:00Z", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + tt.setup(ctx) + result := GetSSLClientNotAfter(ctx) + if result != tt.expected { + t.Errorf("GetSSLClientNotAfter() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestGetSSLClientEmail 测试获取邮箱 +func TestGetSSLClientEmail(t *testing.T) { + tests := []struct { + name string + setup func(*fasthttp.RequestCtx) + expected string + }{ + { + name: "no value", + setup: func(_ *fasthttp.RequestCtx) {}, + expected: "", + }, + { + name: "with email", + setup: func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(VarSSLClientEmail, "test@example.com") + }, + expected: "test@example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + tt.setup(ctx) + result := GetSSLClientEmail(ctx) + if result != tt.expected { + t.Errorf("GetSSLClientEmail() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestSetSSLClientInfoInContext_NilCtx 测试 nil 上下文 +func TestSetSSLClientInfoInContext_NilCtx(_ *testing.T) { + // 不应该 panic + SetSSLClientInfoInContext(nil, &tls.ConnectionState{}, "SUCCESS") +} + +// TestSetSSLClientInfoInContext_NilConnState 测试 nil 连接状态 +func TestSetSSLClientInfoInContext_NilConnState(_ *testing.T) { + ctx := &fasthttp.RequestCtx{} + // 不应该 panic + SetSSLClientInfoInContext(ctx, nil, "SUCCESS") +} + +// TestSetSSLClientInfoInContext_NoPeerCerts 测试无客户端证书 +func TestSetSSLClientInfoInContext_NoPeerCerts(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{}, + } + + SetSSLClientInfoInContext(ctx, cs, "NONE") + + // 验证只设置了 verify 状态 + if v := ctx.UserValue(VarSSLClientVerify); v != "NONE" { + t.Errorf("expected verify=NONE, got %v", v) + } + if v := ctx.UserValue("tls_peer_cert_present"); v != nil { + t.Errorf("expected no peer_cert_present, got %v", v) + } +} + +// TestSetSSLClientInfoInContext_WithPeerCert 测试有客户端证书 +func TestSetSSLClientInfoInContext_WithPeerCert(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + // 创建模拟证书 + now := time.Now() + cert := &x509.Certificate{ + SerialNumber: big.NewInt(12345), + Subject: pkix.Name{ + CommonName: "test.example.com", + Organization: []string{"Test Org"}, + }, + Issuer: pkix.Name{ + CommonName: "Test CA", + }, + NotBefore: now.Add(-24 * time.Hour), + NotAfter: now.Add(365 * 24 * time.Hour), + EmailAddresses: []string{"test@example.com"}, + Raw: make([]byte, 25), // 模拟原始数据(25字节) + } + // 填充可预测的原始数据 + for i := 0; i < 25; i++ { + cert.Raw[i] = byte(i + 1) + } + + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{cert}, + } + + SetSSLClientInfoInContext(ctx, cs, "SUCCESS") + + // 验证所有字段 + tests := []struct { + name string + key string + expected interface{} + }{ + {"verify", VarSSLClientVerify, "SUCCESS"}, + {"peer_cert_present", "tls_peer_cert_present", true}, + {"serial", VarSSLClientSerial, "12345"}, + {"subject", VarSSLClientSubject, cert.Subject.String()}, + {"issuer", VarSSLClientIssuer, cert.Issuer.String()}, + {"email", VarSSLClientEmail, "test@example.com"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := ctx.UserValue(tt.key) + if v != tt.expected { + t.Errorf("%s = %v, want %v", tt.name, v, tt.expected) + } + }) + } + + // 验证时间格式 + notBefore := ctx.UserValue(VarSSLClientNotBefore) + if notBefore == nil || notBefore == "" { + t.Error("notbefore should be set") + } + notAfter := ctx.UserValue(VarSSLClientNotAfter) + if notAfter == nil || notAfter == "" { + t.Error("notafter should be set") + } + + // 验证指纹 + fingerprint := ctx.UserValue(VarSSLClientFingerprint) + if fingerprint == nil || fingerprint == "" { + t.Error("fingerprint should be set") + } +} + +// TestSetSSLClientInfoInContext_NoEmail 测试证书无邮箱 +func TestSetSSLClientInfoInContext_NoEmail(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + cert := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + Issuer: pkix.Name{CommonName: "CA"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + EmailAddresses: []string{}, // 无邮箱 + Raw: []byte{1, 2, 3}, + } + + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{cert}, + } + + SetSSLClientInfoInContext(ctx, cs, "SUCCESS") + + // 验证邮箱未设置 + if v := ctx.UserValue(VarSSLClientEmail); v != nil { + t.Errorf("expected no email, got %v", v) + } +} + +// TestCalculateFingerprint 测试指纹计算 +func TestCalculateFingerprint(t *testing.T) { + tests := []struct { + name string + raw []byte + expected string + }{ + { + name: "empty data", + raw: []byte{}, + expected: "", + }, + { + name: "short data (less than 20 bytes)", + raw: []byte{1, 2, 3, 4, 5}, + expected: "0102030405000000000000000000000000000000", // 5字节+15个零 + }, + { + name: "exactly 20 bytes", + raw: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14}, + expected: "0102030405060708090A0B0C0D0E0F1011121314", + }, + { + name: "more than 20 bytes", + raw: []byte{0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8, 0xF7, 0xF6, 0xF5, 0xF4, 0xF3, 0xF2, 0xF1, 0xF0, 0xEF, 0xEE, 0xED, 0xEC, 0xEB, 0xEA}, + expected: "FFFEFDFCFBFAF9F8F7F6F5F4F3F2F1F0EFEEEDEC", // 只取前20字节 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := calculateFingerprint(tt.raw) + if result != tt.expected { + t.Errorf("calculateFingerprint() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestCalculateFingerprint_Uppercase 测试十六进制输出为大写 +func TestCalculateFingerprint_Uppercase(t *testing.T) { + raw := []byte{0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + result := calculateFingerprint(raw) + + // 验证输出为大写 + for _, c := range result { + if c >= 'a' && c <= 'f' { + t.Errorf("fingerprint should be uppercase, got %q", result) + break + } + } +} + +// TestSSLVariablesInContext 测试通过 VariableContext 访问 SSL 变量 +// 注意:ssl_client_verify 在非 TLS 连接下会返回 NONE(因为 GetSSLClientVerify 检查 ctx.IsTLS()) +func TestSSLVariablesInContext(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + // 设置 SSL 客户端信息 + ctx.SetUserValue(VarSSLClientSerial, "ABC123") + ctx.SetUserValue(VarSSLClientSubject, "CN=test") + ctx.SetUserValue(VarSSLClientIssuer, "CN=CA") + ctx.SetUserValue(VarSSLClientFingerprint, "FINGERPRINT") + ctx.SetUserValue(VarSSLClientNotBefore, "2025-01-01T00:00:00Z") + ctx.SetUserValue(VarSSLClientNotAfter, "2026-01-01T00:00:00Z") + ctx.SetUserValue(VarSSLClientEmail, "test@example.com") + + vc := NewContext(ctx) + defer ReleaseContext(vc) + + tests := []struct { + varName string + expected string + }{ + {VarSSLClientSerial, "ABC123"}, + {VarSSLClientSubject, "CN=test"}, + {VarSSLClientIssuer, "CN=CA"}, + {VarSSLClientFingerprint, "FINGERPRINT"}, + {VarSSLClientNotBefore, "2025-01-01T00:00:00Z"}, + {VarSSLClientNotAfter, "2026-01-01T00:00:00Z"}, + {VarSSLClientEmail, "test@example.com"}, + } + + for _, tt := range tests { + t.Run(tt.varName, func(t *testing.T) { + value, ok := vc.Get(tt.varName) + if !ok { + t.Errorf("variable %s not found", tt.varName) + return + } + if value != tt.expected { + t.Errorf("%s = %q, want %q", tt.varName, value, tt.expected) + } + }) + } +} + +// TestSSLVariablesInContext_VerifyNonTLS 测试 ssl_client_verify 在非 TLS 下返回 NONE +func TestSSLVariablesInContext_VerifyNonTLS(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(VarSSLClientVerify, "SUCCESS") + + vc := NewContext(ctx) + defer ReleaseContext(vc) + + // 非 TLS 连接,ssl_client_verify 应该返回 NONE + value, ok := vc.Get(VarSSLClientVerify) + if !ok { + t.Error("ssl_client_verify not found") + return + } + if value != "NONE" { + t.Errorf("ssl_client_verify = %q, want NONE (non-TLS context)", value) + } +} + +// TestSSLVariablesExpand 测试在模板中展开 SSL 变量 +// 注意:ssl_client_verify 在非 TLS 连接下会返回 NONE +func TestSSLVariablesExpand(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + ctx.SetUserValue(VarSSLClientSerial, "12345") + ctx.SetUserValue(VarSSLClientSubject, "CN=test") + + vc := NewContext(ctx) + defer ReleaseContext(vc) + + tests := []struct { + template string + expected string + }{ + {"$ssl_client_serial", "12345"}, + {"$ssl_client_subject", "CN=test"}, + {"serial=$ssl_client_serial subject=$ssl_client_subject", "serial=12345 subject=CN=test"}, + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestSSLVariablesExpand_VerifyNonTLS 测试 ssl_client_verify 在非 TLS 下展开为 NONE +func TestSSLVariablesExpand_VerifyNonTLS(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(VarSSLClientVerify, "SUCCESS") + + vc := NewContext(ctx) + defer ReleaseContext(vc) + + // 非 TLS 连接,ssl_client_verify 应该展开为 NONE + result := vc.Expand("$ssl_client_verify") + if result != "NONE" { + t.Errorf("Expand($ssl_client_verify) = %q, want NONE (non-TLS context)", result) + } +}