feat(stream): 新增 TCP/UDP Stream SSL/TLS 支持
- StreamSSLManager 管理服务端 TLS 终端和客户端 TLS 连接 - 支持证书加载、mTLS 客户端验证 - 并发安全的证书配置管理 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
b7de258f4e
commit
1a9059b1ff
310
internal/stream/ssl.go
Normal file
310
internal/stream/ssl.go
Normal file
@ -0,0 +1,310 @@
|
||||
// Package stream 提供 TCP/UDP Stream 代理功能。
|
||||
//
|
||||
// 该文件实现 Stream 模块的 SSL/TLS 支持,包括:
|
||||
// - 服务端 TLS 终端
|
||||
// - 客户端 TLS 连接(上游 SSL)
|
||||
// - 证书加载和配置
|
||||
// - mTLS 客户端证书验证
|
||||
//
|
||||
// 作者:xfy
|
||||
package stream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// StreamSSLManager 管理 Stream SSL/TLS 配置。
|
||||
//
|
||||
// 负责加载证书、配置 TLS 连接,支持服务端和客户端两种模式。
|
||||
type StreamSSLManager struct {
|
||||
// config SSL 配置
|
||||
config config.StreamSSLConfig
|
||||
|
||||
// cert 服务器证书
|
||||
cert tls.Certificate
|
||||
|
||||
// clientCAPool 客户端 CA 证书池(mTLS)
|
||||
clientCAPool *x509.CertPool
|
||||
|
||||
// mu 保护并发访问
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// StreamProxySSLManager 管理上游 SSL 连接。
|
||||
//
|
||||
// 负责创建到上游服务器的 TLS 连接,支持证书验证和客户端证书。
|
||||
type StreamProxySSLManager struct {
|
||||
// config 代理 SSL 配置
|
||||
config config.StreamProxySSLConfig
|
||||
|
||||
// cert 客户端证书
|
||||
cert tls.Certificate
|
||||
|
||||
// rootCAPool 根 CA 证书池
|
||||
rootCAPool *x509.CertPool
|
||||
|
||||
// mu 保护并发访问
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewStreamSSLManager 创建 Stream SSL 管理器。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: SSL 配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *StreamSSLManager: SSL 管理器实例
|
||||
// - error: 证书加载失败时返回错误
|
||||
func NewStreamSSLManager(cfg config.StreamSSLConfig) (*StreamSSLManager, error) {
|
||||
if !cfg.Enabled {
|
||||
return &StreamSSLManager{config: cfg}, nil
|
||||
}
|
||||
|
||||
// 加载服务器证书
|
||||
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load server certificate: %w", err)
|
||||
}
|
||||
|
||||
mgr := &StreamSSLManager{
|
||||
config: cfg,
|
||||
cert: cert,
|
||||
}
|
||||
|
||||
// 加载客户端 CA 证书(mTLS)
|
||||
if cfg.ClientCA != "" {
|
||||
pool, err := loadCertPool(cfg.ClientCA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client CA: %w", err)
|
||||
}
|
||||
mgr.clientCAPool = pool
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// NewStreamProxySSLManager 创建上游 SSL 管理器。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 代理 SSL 配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *StreamProxySSLManager: 代理 SSL 管理器实例
|
||||
// - error: 证书加载失败时返回错误
|
||||
func NewStreamProxySSLManager(cfg config.StreamProxySSLConfig) (*StreamProxySSLManager, error) {
|
||||
if !cfg.Enabled {
|
||||
return &StreamProxySSLManager{config: cfg}, nil
|
||||
}
|
||||
|
||||
mgr := &StreamProxySSLManager{config: cfg}
|
||||
|
||||
// 加载客户端证书(mTLS)
|
||||
if cfg.Cert != "" && cfg.Key != "" {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client certificate: %w", err)
|
||||
}
|
||||
mgr.cert = cert
|
||||
}
|
||||
|
||||
// 加载信任的 CA 证书
|
||||
if cfg.TrustedCA != "" {
|
||||
pool, err := loadCertPool(cfg.TrustedCA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load trusted CA: %w", err)
|
||||
}
|
||||
mgr.rootCAPool = pool
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// GetTLSConfig 获取服务端 TLS 配置。
|
||||
//
|
||||
// 返回值:
|
||||
// - *tls.Config: TLS 配置对象
|
||||
func (m *StreamSSLManager) GetTLSConfig() *tls.Config {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if !m.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{m.cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// 设置协议版本
|
||||
if len(m.config.Protocols) > 0 {
|
||||
tlsConfig.MinVersion = parseMinTLSVersion(m.config.Protocols)
|
||||
}
|
||||
|
||||
// 设置加密套件
|
||||
if len(m.config.Ciphers) > 0 {
|
||||
tlsConfig.CipherSuites = parseCipherSuites(m.config.Ciphers)
|
||||
}
|
||||
|
||||
// 配置客户端证书验证(mTLS)
|
||||
if m.clientCAPool != nil {
|
||||
tlsConfig.ClientCAs = m.clientCAPool
|
||||
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
}
|
||||
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
// GetClientTLSConfig 获取客户端 TLS 配置。
|
||||
//
|
||||
// 用于连接上游服务器。
|
||||
//
|
||||
// 参数:
|
||||
// - serverName: 服务器名称(用于 SNI)
|
||||
//
|
||||
// 返回值:
|
||||
// - *tls.Config: TLS 配置对象
|
||||
func (m *StreamProxySSLManager) GetClientTLSConfig(serverName string) *tls.Config {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if !m.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// 设置服务器名称(SNI)
|
||||
if m.config.ServerName != "" {
|
||||
tlsConfig.ServerName = m.config.ServerName
|
||||
} else if serverName != "" {
|
||||
tlsConfig.ServerName = serverName
|
||||
}
|
||||
|
||||
// 设置客户端证书
|
||||
if m.cert.Certificate != nil {
|
||||
tlsConfig.Certificates = []tls.Certificate{m.cert}
|
||||
}
|
||||
|
||||
// 设置协议版本
|
||||
if len(m.config.Protocols) > 0 {
|
||||
tlsConfig.MinVersion = parseMinTLSVersion(m.config.Protocols)
|
||||
}
|
||||
|
||||
// 配置服务器证书验证
|
||||
if m.config.Verify && m.rootCAPool != nil {
|
||||
tlsConfig.RootCAs = m.rootCAPool
|
||||
} else if !m.config.Verify {
|
||||
// 跳过证书验证
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
// 会话复用
|
||||
if m.config.SessionReuse {
|
||||
tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(100)
|
||||
}
|
||||
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
// IsEnabled 检查是否启用 SSL。
|
||||
func (m *StreamSSLManager) IsEnabled() bool {
|
||||
return m.config.Enabled
|
||||
}
|
||||
|
||||
// IsEnabled 检查是否启用代理 SSL。
|
||||
func (m *StreamProxySSLManager) IsEnabled() bool {
|
||||
return m.config.Enabled
|
||||
}
|
||||
|
||||
// loadCertPool 从文件加载证书池。
|
||||
//
|
||||
// 参数:
|
||||
// - certFile: 证书文件路径
|
||||
//
|
||||
// 返回值:
|
||||
// - *x509.CertPool: 证书池
|
||||
// - error: 加载失败时返回错误
|
||||
func loadCertPool(certFile string) (*x509.CertPool, error) {
|
||||
data, err := os.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(data) {
|
||||
return nil, fmt.Errorf("failed to parse certificates from %s", certFile)
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// parseMinTLSVersion 解析最小 TLS 版本。
|
||||
//
|
||||
// 参数:
|
||||
// - protocols: 协议版本列表
|
||||
//
|
||||
// 返回值:
|
||||
// - uint16: TLS 版本常量
|
||||
func parseMinTLSVersion(protocols []string) uint16 {
|
||||
for _, p := range protocols {
|
||||
switch p {
|
||||
case "TLSv1.3":
|
||||
return tls.VersionTLS13
|
||||
case "TLSv1.2":
|
||||
return tls.VersionTLS12
|
||||
}
|
||||
}
|
||||
return tls.VersionTLS12
|
||||
}
|
||||
|
||||
// parseCipherSuites 解析加密套件列表。
|
||||
//
|
||||
// 参数:
|
||||
// - ciphers: 加密套件名称列表
|
||||
//
|
||||
// 返回值:
|
||||
// - []uint16: 加密套件 ID 列表
|
||||
func parseCipherSuites(ciphers []string) []uint16 {
|
||||
var suites []uint16
|
||||
for _, c := range ciphers {
|
||||
if id, ok := cipherNameToID[c]; ok {
|
||||
suites = append(suites, id)
|
||||
}
|
||||
}
|
||||
if len(suites) == 0 {
|
||||
return nil // 使用默认值
|
||||
}
|
||||
return suites
|
||||
}
|
||||
|
||||
// cipherNameToID 加密套件名称到 ID 的映射
|
||||
var cipherNameToID = map[string]uint16{
|
||||
"ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-RSA-CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"ECDHE-ECDSA-CHACHA20-POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"AES128-GCM-SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||
"AES256-GCM-SHA384": tls.TLS_AES_256_GCM_SHA384,
|
||||
"CHACHA20-POLY1305": tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
"ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
"ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||
"ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
"ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
"RSA-AES128-GCM-SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"RSA-AES256-GCM-SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
"RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
"ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
"RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
}
|
||||
464
internal/stream/ssl_test.go
Normal file
464
internal/stream/ssl_test.go
Normal file
@ -0,0 +1,464 @@
|
||||
// Package stream 提供 TCP/UDP Stream 代理功能。
|
||||
//
|
||||
// 该文件包含 Stream SSL/TLS 支持的单元测试。
|
||||
//
|
||||
// 作者:xfy
|
||||
package stream
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// generateTestCertificate 生成测试用的自签名证书
|
||||
func generateTestCertificate(t *testing.T, certFile, keyFile string) {
|
||||
t.Helper()
|
||||
|
||||
// 创建证书模板
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
// 生成私钥和证书
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
// 写入证书文件
|
||||
certOut, err := os.Create(certFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cert file: %v", err)
|
||||
}
|
||||
defer func() { _ = certOut.Close() }()
|
||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||
t.Fatalf("Failed to encode certificate: %v", err)
|
||||
}
|
||||
|
||||
// 写入私钥文件
|
||||
keyOut, err := os.Create(keyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create key file: %v", err)
|
||||
}
|
||||
defer func() { _ = keyOut.Close() }()
|
||||
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}); err != nil {
|
||||
t.Fatalf("Failed to encode key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStreamSSLManager_Disabled(t *testing.T) {
|
||||
cfg := config.StreamSSLConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
mgr, err := NewStreamSSLManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamSSLManager failed: %v", err)
|
||||
}
|
||||
|
||||
if mgr.IsEnabled() {
|
||||
t.Error("Expected IsEnabled to be false")
|
||||
}
|
||||
|
||||
tlsConfig := mgr.GetTLSConfig()
|
||||
if tlsConfig != nil {
|
||||
t.Error("Expected nil TLS config when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStreamSSLManager_Enabled(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certFile := filepath.Join(tempDir, "server.crt")
|
||||
keyFile := filepath.Join(tempDir, "server.key")
|
||||
|
||||
generateTestCertificate(t, certFile, keyFile)
|
||||
|
||||
cfg := config.StreamSSLConfig{
|
||||
Enabled: true,
|
||||
Cert: certFile,
|
||||
Key: keyFile,
|
||||
Protocols: []string{"TLSv1.2", "TLSv1.3"},
|
||||
}
|
||||
|
||||
mgr, err := NewStreamSSLManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamSSLManager failed: %v", err)
|
||||
}
|
||||
|
||||
if !mgr.IsEnabled() {
|
||||
t.Error("Expected IsEnabled to be true")
|
||||
}
|
||||
|
||||
tlsConfig := mgr.GetTLSConfig()
|
||||
if tlsConfig == nil {
|
||||
t.Fatal("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
if len(tlsConfig.Certificates) != 1 {
|
||||
t.Errorf("Expected 1 certificate, got %d", len(tlsConfig.Certificates))
|
||||
}
|
||||
|
||||
if tlsConfig.MinVersion < tls.VersionTLS12 {
|
||||
t.Errorf("Expected MinVersion >= TLS 1.2, got %v", tlsConfig.MinVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStreamSSLManager_InvalidCert(t *testing.T) {
|
||||
cfg := config.StreamSSLConfig{
|
||||
Enabled: true,
|
||||
Cert: "/nonexistent/cert.pem",
|
||||
Key: "/nonexistent/key.pem",
|
||||
}
|
||||
|
||||
_, err := NewStreamSSLManager(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid certificate path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStreamProxySSLManager_Disabled(t *testing.T) {
|
||||
cfg := config.StreamProxySSLConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
mgr, err := NewStreamProxySSLManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||
}
|
||||
|
||||
if mgr.IsEnabled() {
|
||||
t.Error("Expected IsEnabled to be false")
|
||||
}
|
||||
|
||||
tlsConfig := mgr.GetClientTLSConfig("example.com")
|
||||
if tlsConfig != nil {
|
||||
t.Error("Expected nil TLS config when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStreamProxySSLManager_Enabled(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certFile := filepath.Join(tempDir, "client.crt")
|
||||
keyFile := filepath.Join(tempDir, "client.key")
|
||||
|
||||
generateTestCertificate(t, certFile, keyFile)
|
||||
|
||||
cfg := config.StreamProxySSLConfig{
|
||||
Enabled: true,
|
||||
Cert: certFile,
|
||||
Key: keyFile,
|
||||
ServerName: "backend.example.com",
|
||||
Verify: false,
|
||||
Protocols: []string{"TLSv1.2", "TLSv1.3"},
|
||||
SessionReuse: true,
|
||||
}
|
||||
|
||||
mgr, err := NewStreamProxySSLManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||
}
|
||||
|
||||
if !mgr.IsEnabled() {
|
||||
t.Error("Expected IsEnabled to be true")
|
||||
}
|
||||
|
||||
tlsConfig := mgr.GetClientTLSConfig("fallback.example.com")
|
||||
if tlsConfig == nil {
|
||||
t.Fatal("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
// 应该使用配置中的 ServerName
|
||||
if tlsConfig.ServerName != "backend.example.com" {
|
||||
t.Errorf("Expected ServerName 'backend.example.com', got '%s'", tlsConfig.ServerName)
|
||||
}
|
||||
|
||||
// 应该有客户端证书
|
||||
if len(tlsConfig.Certificates) != 1 {
|
||||
t.Errorf("Expected 1 client certificate, got %d", len(tlsConfig.Certificates))
|
||||
}
|
||||
|
||||
// 跳过验证
|
||||
if !tlsConfig.InsecureSkipVerify {
|
||||
t.Error("Expected InsecureSkipVerify to be true")
|
||||
}
|
||||
|
||||
// 会话复用
|
||||
if tlsConfig.ClientSessionCache == nil {
|
||||
t.Error("Expected ClientSessionCache to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStreamProxySSLManager_WithVerify(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
|
||||
// 创建 CA 证书
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
IsCA: true,
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create CA certificate: %v", err)
|
||||
}
|
||||
|
||||
caOut, err := os.Create(caFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create CA file: %v", err)
|
||||
}
|
||||
defer func() { _ = caOut.Close() }()
|
||||
if err := pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||
t.Fatalf("Failed to encode CA: %v", err)
|
||||
}
|
||||
|
||||
cfg := config.StreamProxySSLConfig{
|
||||
Enabled: true,
|
||||
Verify: true,
|
||||
TrustedCA: caFile,
|
||||
ServerName: "backend.example.com",
|
||||
}
|
||||
|
||||
mgr, err := NewStreamProxySSLManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||
}
|
||||
|
||||
tlsConfig := mgr.GetClientTLSConfig("")
|
||||
if tlsConfig == nil {
|
||||
t.Fatal("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
// 应该验证证书
|
||||
if tlsConfig.InsecureSkipVerify {
|
||||
t.Error("Expected InsecureSkipVerify to be false")
|
||||
}
|
||||
|
||||
// 应该有 RootCAs
|
||||
if tlsConfig.RootCAs == nil {
|
||||
t.Error("Expected RootCAs to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMinTLSVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
protocols []string
|
||||
wantVersion uint16
|
||||
}{
|
||||
{[]string{"TLSv1.3"}, tls.VersionTLS13},
|
||||
{[]string{"TLSv1.2"}, tls.VersionTLS12},
|
||||
{[]string{"TLSv1.2", "TLSv1.3"}, tls.VersionTLS12},
|
||||
{[]string{}, tls.VersionTLS12},
|
||||
{[]string{"Unknown"}, tls.VersionTLS12},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := parseMinTLSVersion(tt.protocols)
|
||||
if got != tt.wantVersion {
|
||||
t.Errorf("parseMinTLSVersion(%v) = %v, want %v", tt.protocols, got, tt.wantVersion)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCipherSuites(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ciphers []string
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "valid ciphers",
|
||||
ciphers: []string{"ECDHE-RSA-AES128-GCM-SHA256", "ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "empty ciphers",
|
||||
ciphers: []string{},
|
||||
wantLen: 0, // returns nil for empty
|
||||
},
|
||||
{
|
||||
name: "unknown ciphers",
|
||||
ciphers: []string{"UNKNOWN-CIPHER"},
|
||||
wantLen: 0, // returns nil for no valid ciphers
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := parseCipherSuites(tt.ciphers)
|
||||
if tt.wantLen == 0 && got != nil {
|
||||
t.Errorf("Expected nil, got %v", got)
|
||||
} else if tt.wantLen > 0 && len(got) != tt.wantLen {
|
||||
t.Errorf("Expected %d ciphers, got %d", tt.wantLen, len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool(t *testing.T) {
|
||||
t.Run("valid cert", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certFile := filepath.Join(tempDir, "ca.crt")
|
||||
|
||||
// 创建证书
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
IsCA: true,
|
||||
KeyUsage: x509.KeyUsageCertSign,
|
||||
}
|
||||
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
|
||||
certOut, err := os.Create(certFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create cert file: %v", err)
|
||||
}
|
||||
defer func() { _ = certOut.Close() }()
|
||||
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||
t.Fatalf("Failed to encode certificate: %v", err)
|
||||
}
|
||||
|
||||
pool, err := loadCertPool(certFile)
|
||||
if err != nil {
|
||||
t.Fatalf("loadCertPool failed: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Error("Expected non-nil pool")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid path", func(t *testing.T) {
|
||||
_, err := loadCertPool("/nonexistent/cert.pem")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent file")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid content", func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certFile := filepath.Join(tempDir, "invalid.crt")
|
||||
if err := os.WriteFile(certFile, []byte("not a certificate"), 0644); err != nil {
|
||||
t.Fatalf("写入无效证书文件失败: %v", err)
|
||||
}
|
||||
|
||||
_, err := loadCertPool(certFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid certificate content")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamSSLManager_GetTLSConfig_WithClientCA(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
certFile := filepath.Join(tempDir, "server.crt")
|
||||
keyFile := filepath.Join(tempDir, "server.key")
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
|
||||
generateTestCertificate(t, certFile, keyFile)
|
||||
|
||||
// 创建 CA 证书
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
IsCA: true,
|
||||
KeyUsage: x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
|
||||
caOut, err := os.Create(caFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create CA file: %v", err)
|
||||
}
|
||||
defer func() { _ = caOut.Close() }()
|
||||
if err := pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||
t.Fatalf("Failed to encode CA: %v", err)
|
||||
}
|
||||
|
||||
cfg := config.StreamSSLConfig{
|
||||
Enabled: true,
|
||||
Cert: certFile,
|
||||
Key: keyFile,
|
||||
ClientCA: caFile,
|
||||
Protocols: []string{"TLSv1.2"},
|
||||
}
|
||||
|
||||
mgr, err := NewStreamSSLManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamSSLManager failed: %v", err)
|
||||
}
|
||||
|
||||
tlsConfig := mgr.GetTLSConfig()
|
||||
if tlsConfig == nil {
|
||||
t.Fatal("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
// 应该配置客户端 CA
|
||||
if tlsConfig.ClientCAs == nil {
|
||||
t.Error("Expected ClientCAs to be set")
|
||||
}
|
||||
|
||||
// 应该请求客户端证书
|
||||
if tlsConfig.ClientAuth != tls.VerifyClientCertIfGiven {
|
||||
t.Errorf("Expected ClientAuth VerifyClientCertIfGiven, got %v", tlsConfig.ClientAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamProxySSLManager_GetClientTLSConfig_WithServerNameOverride(t *testing.T) {
|
||||
cfg := config.StreamProxySSLConfig{
|
||||
Enabled: true,
|
||||
Verify: false,
|
||||
ServerName: "configured.example.com",
|
||||
}
|
||||
|
||||
mgr, err := NewStreamProxySSLManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||
}
|
||||
|
||||
// 即使传入不同的 serverName,也应该使用配置的
|
||||
tlsConfig := mgr.GetClientTLSConfig("fallback.example.com")
|
||||
if tlsConfig == nil {
|
||||
t.Fatal("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
if tlsConfig.ServerName != "configured.example.com" {
|
||||
t.Errorf("Expected ServerName 'configured.example.com', got '%s'", tlsConfig.ServerName)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user