From 1a9059b1ffb53235a335e100881489153d07e9c0 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 8 Apr 2026 14:37:02 +0800 Subject: [PATCH] =?UTF-8?q?feat(stream):=20=E6=96=B0=E5=A2=9E=20TCP/UDP=20?= =?UTF-8?q?Stream=20SSL/TLS=20=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - StreamSSLManager 管理服务端 TLS 终端和客户端 TLS 连接 - 支持证书加载、mTLS 客户端验证 - 并发安全的证书配置管理 Co-Authored-By: Claude --- internal/stream/ssl.go | 310 ++++++++++++++++++++++++ internal/stream/ssl_test.go | 464 ++++++++++++++++++++++++++++++++++++ 2 files changed, 774 insertions(+) create mode 100644 internal/stream/ssl.go create mode 100644 internal/stream/ssl_test.go diff --git a/internal/stream/ssl.go b/internal/stream/ssl.go new file mode 100644 index 0000000..a094197 --- /dev/null +++ b/internal/stream/ssl.go @@ -0,0 +1,310 @@ +// Package stream 提供 TCP/UDP Stream 代理功能。 +// +// 该文件实现 Stream 模块的 SSL/TLS 支持,包括: +// - 服务端 TLS 终端 +// - 客户端 TLS 连接(上游 SSL) +// - 证书加载和配置 +// - mTLS 客户端证书验证 +// +// 作者:xfy +package stream + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "sync" + + "rua.plus/lolly/internal/config" +) + +// StreamSSLManager 管理 Stream SSL/TLS 配置。 +// +// 负责加载证书、配置 TLS 连接,支持服务端和客户端两种模式。 +type StreamSSLManager struct { + // config SSL 配置 + config config.StreamSSLConfig + + // cert 服务器证书 + cert tls.Certificate + + // clientCAPool 客户端 CA 证书池(mTLS) + clientCAPool *x509.CertPool + + // mu 保护并发访问 + mu sync.RWMutex +} + +// StreamProxySSLManager 管理上游 SSL 连接。 +// +// 负责创建到上游服务器的 TLS 连接,支持证书验证和客户端证书。 +type StreamProxySSLManager struct { + // config 代理 SSL 配置 + config config.StreamProxySSLConfig + + // cert 客户端证书 + cert tls.Certificate + + // rootCAPool 根 CA 证书池 + rootCAPool *x509.CertPool + + // mu 保护并发访问 + mu sync.RWMutex +} + +// NewStreamSSLManager 创建 Stream SSL 管理器。 +// +// 参数: +// - cfg: SSL 配置 +// +// 返回值: +// - *StreamSSLManager: SSL 管理器实例 +// - error: 证书加载失败时返回错误 +func NewStreamSSLManager(cfg config.StreamSSLConfig) (*StreamSSLManager, error) { + if !cfg.Enabled { + return &StreamSSLManager{config: cfg}, nil + } + + // 加载服务器证书 + cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key) + if err != nil { + return nil, fmt.Errorf("failed to load server certificate: %w", err) + } + + mgr := &StreamSSLManager{ + config: cfg, + cert: cert, + } + + // 加载客户端 CA 证书(mTLS) + if cfg.ClientCA != "" { + pool, err := loadCertPool(cfg.ClientCA) + if err != nil { + return nil, fmt.Errorf("failed to load client CA: %w", err) + } + mgr.clientCAPool = pool + } + + return mgr, nil +} + +// NewStreamProxySSLManager 创建上游 SSL 管理器。 +// +// 参数: +// - cfg: 代理 SSL 配置 +// +// 返回值: +// - *StreamProxySSLManager: 代理 SSL 管理器实例 +// - error: 证书加载失败时返回错误 +func NewStreamProxySSLManager(cfg config.StreamProxySSLConfig) (*StreamProxySSLManager, error) { + if !cfg.Enabled { + return &StreamProxySSLManager{config: cfg}, nil + } + + mgr := &StreamProxySSLManager{config: cfg} + + // 加载客户端证书(mTLS) + if cfg.Cert != "" && cfg.Key != "" { + cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate: %w", err) + } + mgr.cert = cert + } + + // 加载信任的 CA 证书 + if cfg.TrustedCA != "" { + pool, err := loadCertPool(cfg.TrustedCA) + if err != nil { + return nil, fmt.Errorf("failed to load trusted CA: %w", err) + } + mgr.rootCAPool = pool + } + + return mgr, nil +} + +// GetTLSConfig 获取服务端 TLS 配置。 +// +// 返回值: +// - *tls.Config: TLS 配置对象 +func (m *StreamSSLManager) GetTLSConfig() *tls.Config { + m.mu.RLock() + defer m.mu.RUnlock() + + if !m.config.Enabled { + return nil + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{m.cert}, + MinVersion: tls.VersionTLS12, + } + + // 设置协议版本 + if len(m.config.Protocols) > 0 { + tlsConfig.MinVersion = parseMinTLSVersion(m.config.Protocols) + } + + // 设置加密套件 + if len(m.config.Ciphers) > 0 { + tlsConfig.CipherSuites = parseCipherSuites(m.config.Ciphers) + } + + // 配置客户端证书验证(mTLS) + if m.clientCAPool != nil { + tlsConfig.ClientCAs = m.clientCAPool + tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven + } + + return tlsConfig +} + +// GetClientTLSConfig 获取客户端 TLS 配置。 +// +// 用于连接上游服务器。 +// +// 参数: +// - serverName: 服务器名称(用于 SNI) +// +// 返回值: +// - *tls.Config: TLS 配置对象 +func (m *StreamProxySSLManager) GetClientTLSConfig(serverName string) *tls.Config { + m.mu.RLock() + defer m.mu.RUnlock() + + if !m.config.Enabled { + return nil + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + // 设置服务器名称(SNI) + if m.config.ServerName != "" { + tlsConfig.ServerName = m.config.ServerName + } else if serverName != "" { + tlsConfig.ServerName = serverName + } + + // 设置客户端证书 + if m.cert.Certificate != nil { + tlsConfig.Certificates = []tls.Certificate{m.cert} + } + + // 设置协议版本 + if len(m.config.Protocols) > 0 { + tlsConfig.MinVersion = parseMinTLSVersion(m.config.Protocols) + } + + // 配置服务器证书验证 + if m.config.Verify && m.rootCAPool != nil { + tlsConfig.RootCAs = m.rootCAPool + } else if !m.config.Verify { + // 跳过证书验证 + tlsConfig.InsecureSkipVerify = true + } + + // 会话复用 + if m.config.SessionReuse { + tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(100) + } + + return tlsConfig +} + +// IsEnabled 检查是否启用 SSL。 +func (m *StreamSSLManager) IsEnabled() bool { + return m.config.Enabled +} + +// IsEnabled 检查是否启用代理 SSL。 +func (m *StreamProxySSLManager) IsEnabled() bool { + return m.config.Enabled +} + +// loadCertPool 从文件加载证书池。 +// +// 参数: +// - certFile: 证书文件路径 +// +// 返回值: +// - *x509.CertPool: 证书池 +// - error: 加载失败时返回错误 +func loadCertPool(certFile string) (*x509.CertPool, error) { + data, err := os.ReadFile(certFile) + if err != nil { + return nil, err + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(data) { + return nil, fmt.Errorf("failed to parse certificates from %s", certFile) + } + + return pool, nil +} + +// parseMinTLSVersion 解析最小 TLS 版本。 +// +// 参数: +// - protocols: 协议版本列表 +// +// 返回值: +// - uint16: TLS 版本常量 +func parseMinTLSVersion(protocols []string) uint16 { + for _, p := range protocols { + switch p { + case "TLSv1.3": + return tls.VersionTLS13 + case "TLSv1.2": + return tls.VersionTLS12 + } + } + return tls.VersionTLS12 +} + +// parseCipherSuites 解析加密套件列表。 +// +// 参数: +// - ciphers: 加密套件名称列表 +// +// 返回值: +// - []uint16: 加密套件 ID 列表 +func parseCipherSuites(ciphers []string) []uint16 { + var suites []uint16 + for _, c := range ciphers { + if id, ok := cipherNameToID[c]; ok { + suites = append(suites, id) + } + } + if len(suites) == 0 { + return nil // 使用默认值 + } + return suites +} + +// cipherNameToID 加密套件名称到 ID 的映射 +var cipherNameToID = map[string]uint16{ + "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "ECDHE-RSA-CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "ECDHE-ECDSA-CHACHA20-POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + "AES128-GCM-SHA256": tls.TLS_AES_128_GCM_SHA256, + "AES256-GCM-SHA384": tls.TLS_AES_256_GCM_SHA384, + "CHACHA20-POLY1305": tls.TLS_CHACHA20_POLY1305_SHA256, + "ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + "ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + "RSA-AES128-GCM-SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + "RSA-AES256-GCM-SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, + "RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, + "ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, +} diff --git a/internal/stream/ssl_test.go b/internal/stream/ssl_test.go new file mode 100644 index 0000000..f3dc277 --- /dev/null +++ b/internal/stream/ssl_test.go @@ -0,0 +1,464 @@ +// Package stream 提供 TCP/UDP Stream 代理功能。 +// +// 该文件包含 Stream SSL/TLS 支持的单元测试。 +// +// 作者:xfy +package stream + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "rua.plus/lolly/internal/config" +) + +// generateTestCertificate 生成测试用的自签名证书 +func generateTestCertificate(t *testing.T, certFile, keyFile string) { + t.Helper() + + // 创建证书模板 + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + // 生成私钥和证书 + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("Failed to create certificate: %v", err) + } + + // 写入证书文件 + certOut, err := os.Create(certFile) + if err != nil { + t.Fatalf("Failed to create cert file: %v", err) + } + defer func() { _ = certOut.Close() }() + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + t.Fatalf("Failed to encode certificate: %v", err) + } + + // 写入私钥文件 + keyOut, err := os.Create(keyFile) + if err != nil { + t.Fatalf("Failed to create key file: %v", err) + } + defer func() { _ = keyOut.Close() }() + if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}); err != nil { + t.Fatalf("Failed to encode key: %v", err) + } +} + +func TestNewStreamSSLManager_Disabled(t *testing.T) { + cfg := config.StreamSSLConfig{ + Enabled: false, + } + + mgr, err := NewStreamSSLManager(cfg) + if err != nil { + t.Fatalf("NewStreamSSLManager failed: %v", err) + } + + if mgr.IsEnabled() { + t.Error("Expected IsEnabled to be false") + } + + tlsConfig := mgr.GetTLSConfig() + if tlsConfig != nil { + t.Error("Expected nil TLS config when disabled") + } +} + +func TestNewStreamSSLManager_Enabled(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "server.crt") + keyFile := filepath.Join(tempDir, "server.key") + + generateTestCertificate(t, certFile, keyFile) + + cfg := config.StreamSSLConfig{ + Enabled: true, + Cert: certFile, + Key: keyFile, + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + } + + mgr, err := NewStreamSSLManager(cfg) + if err != nil { + t.Fatalf("NewStreamSSLManager failed: %v", err) + } + + if !mgr.IsEnabled() { + t.Error("Expected IsEnabled to be true") + } + + tlsConfig := mgr.GetTLSConfig() + if tlsConfig == nil { + t.Fatal("Expected non-nil TLS config") + } + + if len(tlsConfig.Certificates) != 1 { + t.Errorf("Expected 1 certificate, got %d", len(tlsConfig.Certificates)) + } + + if tlsConfig.MinVersion < tls.VersionTLS12 { + t.Errorf("Expected MinVersion >= TLS 1.2, got %v", tlsConfig.MinVersion) + } +} + +func TestNewStreamSSLManager_InvalidCert(t *testing.T) { + cfg := config.StreamSSLConfig{ + Enabled: true, + Cert: "/nonexistent/cert.pem", + Key: "/nonexistent/key.pem", + } + + _, err := NewStreamSSLManager(cfg) + if err == nil { + t.Error("Expected error for invalid certificate path") + } +} + +func TestNewStreamProxySSLManager_Disabled(t *testing.T) { + cfg := config.StreamProxySSLConfig{ + Enabled: false, + } + + mgr, err := NewStreamProxySSLManager(cfg) + if err != nil { + t.Fatalf("NewStreamProxySSLManager failed: %v", err) + } + + if mgr.IsEnabled() { + t.Error("Expected IsEnabled to be false") + } + + tlsConfig := mgr.GetClientTLSConfig("example.com") + if tlsConfig != nil { + t.Error("Expected nil TLS config when disabled") + } +} + +func TestNewStreamProxySSLManager_Enabled(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "client.crt") + keyFile := filepath.Join(tempDir, "client.key") + + generateTestCertificate(t, certFile, keyFile) + + cfg := config.StreamProxySSLConfig{ + Enabled: true, + Cert: certFile, + Key: keyFile, + ServerName: "backend.example.com", + Verify: false, + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + SessionReuse: true, + } + + mgr, err := NewStreamProxySSLManager(cfg) + if err != nil { + t.Fatalf("NewStreamProxySSLManager failed: %v", err) + } + + if !mgr.IsEnabled() { + t.Error("Expected IsEnabled to be true") + } + + tlsConfig := mgr.GetClientTLSConfig("fallback.example.com") + if tlsConfig == nil { + t.Fatal("Expected non-nil TLS config") + } + + // 应该使用配置中的 ServerName + if tlsConfig.ServerName != "backend.example.com" { + t.Errorf("Expected ServerName 'backend.example.com', got '%s'", tlsConfig.ServerName) + } + + // 应该有客户端证书 + if len(tlsConfig.Certificates) != 1 { + t.Errorf("Expected 1 client certificate, got %d", len(tlsConfig.Certificates)) + } + + // 跳过验证 + if !tlsConfig.InsecureSkipVerify { + t.Error("Expected InsecureSkipVerify to be true") + } + + // 会话复用 + if tlsConfig.ClientSessionCache == nil { + t.Error("Expected ClientSessionCache to be set") + } +} + +func TestNewStreamProxySSLManager_WithVerify(t *testing.T) { + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + + // 创建 CA 证书 + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + } + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("Failed to create CA certificate: %v", err) + } + + caOut, err := os.Create(caFile) + if err != nil { + t.Fatalf("Failed to create CA file: %v", err) + } + defer func() { _ = caOut.Close() }() + if err := pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + t.Fatalf("Failed to encode CA: %v", err) + } + + cfg := config.StreamProxySSLConfig{ + Enabled: true, + Verify: true, + TrustedCA: caFile, + ServerName: "backend.example.com", + } + + mgr, err := NewStreamProxySSLManager(cfg) + if err != nil { + t.Fatalf("NewStreamProxySSLManager failed: %v", err) + } + + tlsConfig := mgr.GetClientTLSConfig("") + if tlsConfig == nil { + t.Fatal("Expected non-nil TLS config") + } + + // 应该验证证书 + if tlsConfig.InsecureSkipVerify { + t.Error("Expected InsecureSkipVerify to be false") + } + + // 应该有 RootCAs + if tlsConfig.RootCAs == nil { + t.Error("Expected RootCAs to be set") + } +} + +func TestParseMinTLSVersion(t *testing.T) { + tests := []struct { + protocols []string + wantVersion uint16 + }{ + {[]string{"TLSv1.3"}, tls.VersionTLS13}, + {[]string{"TLSv1.2"}, tls.VersionTLS12}, + {[]string{"TLSv1.2", "TLSv1.3"}, tls.VersionTLS12}, + {[]string{}, tls.VersionTLS12}, + {[]string{"Unknown"}, tls.VersionTLS12}, + } + + for _, tt := range tests { + got := parseMinTLSVersion(tt.protocols) + if got != tt.wantVersion { + t.Errorf("parseMinTLSVersion(%v) = %v, want %v", tt.protocols, got, tt.wantVersion) + } + } +} + +func TestParseCipherSuites(t *testing.T) { + tests := []struct { + name string + ciphers []string + wantLen int + }{ + { + name: "valid ciphers", + ciphers: []string{"ECDHE-RSA-AES128-GCM-SHA256", "ECDHE-RSA-AES256-GCM-SHA384"}, + wantLen: 2, + }, + { + name: "empty ciphers", + ciphers: []string{}, + wantLen: 0, // returns nil for empty + }, + { + name: "unknown ciphers", + ciphers: []string{"UNKNOWN-CIPHER"}, + wantLen: 0, // returns nil for no valid ciphers + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseCipherSuites(tt.ciphers) + if tt.wantLen == 0 && got != nil { + t.Errorf("Expected nil, got %v", got) + } else if tt.wantLen > 0 && len(got) != tt.wantLen { + t.Errorf("Expected %d ciphers, got %d", tt.wantLen, len(got)) + } + }) + } +} + +func TestLoadCertPool(t *testing.T) { + t.Run("valid cert", func(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "ca.crt") + + // 创建证书 + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign, + } + + key, _ := rsa.GenerateKey(rand.Reader, 2048) + certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + + certOut, err := os.Create(certFile) + if err != nil { + t.Fatalf("Failed to create cert file: %v", err) + } + defer func() { _ = certOut.Close() }() + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + t.Fatalf("Failed to encode certificate: %v", err) + } + + pool, err := loadCertPool(certFile) + if err != nil { + t.Fatalf("loadCertPool failed: %v", err) + } + if pool == nil { + t.Error("Expected non-nil pool") + } + }) + + t.Run("invalid path", func(t *testing.T) { + _, err := loadCertPool("/nonexistent/cert.pem") + if err == nil { + t.Error("Expected error for nonexistent file") + } + }) + + t.Run("invalid content", func(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "invalid.crt") + if err := os.WriteFile(certFile, []byte("not a certificate"), 0644); err != nil { + t.Fatalf("写入无效证书文件失败: %v", err) + } + + _, err := loadCertPool(certFile) + if err == nil { + t.Error("Expected error for invalid certificate content") + } + }) +} + +func TestStreamSSLManager_GetTLSConfig_WithClientCA(t *testing.T) { + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "server.crt") + keyFile := filepath.Join(tempDir, "server.key") + caFile := filepath.Join(tempDir, "ca.crt") + + generateTestCertificate(t, certFile, keyFile) + + // 创建 CA 证书 + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + key, _ := rsa.GenerateKey(rand.Reader, 2048) + certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + + caOut, err := os.Create(caFile) + if err != nil { + t.Fatalf("Failed to create CA file: %v", err) + } + defer func() { _ = caOut.Close() }() + if err := pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { + t.Fatalf("Failed to encode CA: %v", err) + } + + cfg := config.StreamSSLConfig{ + Enabled: true, + Cert: certFile, + Key: keyFile, + ClientCA: caFile, + Protocols: []string{"TLSv1.2"}, + } + + mgr, err := NewStreamSSLManager(cfg) + if err != nil { + t.Fatalf("NewStreamSSLManager failed: %v", err) + } + + tlsConfig := mgr.GetTLSConfig() + if tlsConfig == nil { + t.Fatal("Expected non-nil TLS config") + } + + // 应该配置客户端 CA + if tlsConfig.ClientCAs == nil { + t.Error("Expected ClientCAs to be set") + } + + // 应该请求客户端证书 + if tlsConfig.ClientAuth != tls.VerifyClientCertIfGiven { + t.Errorf("Expected ClientAuth VerifyClientCertIfGiven, got %v", tlsConfig.ClientAuth) + } +} + +func TestStreamProxySSLManager_GetClientTLSConfig_WithServerNameOverride(t *testing.T) { + cfg := config.StreamProxySSLConfig{ + Enabled: true, + Verify: false, + ServerName: "configured.example.com", + } + + mgr, err := NewStreamProxySSLManager(cfg) + if err != nil { + t.Fatalf("NewStreamProxySSLManager failed: %v", err) + } + + // 即使传入不同的 serverName,也应该使用配置的 + tlsConfig := mgr.GetClientTLSConfig("fallback.example.com") + if tlsConfig == nil { + t.Fatal("Expected non-nil TLS config") + } + + if tlsConfig.ServerName != "configured.example.com" { + t.Errorf("Expected ServerName 'configured.example.com', got '%s'", tlsConfig.ServerName) + } +}