feat(stream): 新增 TCP/UDP Stream SSL/TLS 支持
- StreamSSLManager 管理服务端 TLS 终端和客户端 TLS 连接 - 支持证书加载、mTLS 客户端验证 - 并发安全的证书配置管理 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
b7de258f4e
commit
1a9059b1ff
310
internal/stream/ssl.go
Normal file
310
internal/stream/ssl.go
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
// Package stream 提供 TCP/UDP Stream 代理功能。
|
||||||
|
//
|
||||||
|
// 该文件实现 Stream 模块的 SSL/TLS 支持,包括:
|
||||||
|
// - 服务端 TLS 终端
|
||||||
|
// - 客户端 TLS 连接(上游 SSL)
|
||||||
|
// - 证书加载和配置
|
||||||
|
// - mTLS 客户端证书验证
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StreamSSLManager 管理 Stream SSL/TLS 配置。
|
||||||
|
//
|
||||||
|
// 负责加载证书、配置 TLS 连接,支持服务端和客户端两种模式。
|
||||||
|
type StreamSSLManager struct {
|
||||||
|
// config SSL 配置
|
||||||
|
config config.StreamSSLConfig
|
||||||
|
|
||||||
|
// cert 服务器证书
|
||||||
|
cert tls.Certificate
|
||||||
|
|
||||||
|
// clientCAPool 客户端 CA 证书池(mTLS)
|
||||||
|
clientCAPool *x509.CertPool
|
||||||
|
|
||||||
|
// mu 保护并发访问
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamProxySSLManager 管理上游 SSL 连接。
|
||||||
|
//
|
||||||
|
// 负责创建到上游服务器的 TLS 连接,支持证书验证和客户端证书。
|
||||||
|
type StreamProxySSLManager struct {
|
||||||
|
// config 代理 SSL 配置
|
||||||
|
config config.StreamProxySSLConfig
|
||||||
|
|
||||||
|
// cert 客户端证书
|
||||||
|
cert tls.Certificate
|
||||||
|
|
||||||
|
// rootCAPool 根 CA 证书池
|
||||||
|
rootCAPool *x509.CertPool
|
||||||
|
|
||||||
|
// mu 保护并发访问
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStreamSSLManager 创建 Stream SSL 管理器。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - cfg: SSL 配置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *StreamSSLManager: SSL 管理器实例
|
||||||
|
// - error: 证书加载失败时返回错误
|
||||||
|
func NewStreamSSLManager(cfg config.StreamSSLConfig) (*StreamSSLManager, error) {
|
||||||
|
if !cfg.Enabled {
|
||||||
|
return &StreamSSLManager{config: cfg}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载服务器证书
|
||||||
|
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load server certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &StreamSSLManager{
|
||||||
|
config: cfg,
|
||||||
|
cert: cert,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载客户端 CA 证书(mTLS)
|
||||||
|
if cfg.ClientCA != "" {
|
||||||
|
pool, err := loadCertPool(cfg.ClientCA)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load client CA: %w", err)
|
||||||
|
}
|
||||||
|
mgr.clientCAPool = pool
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStreamProxySSLManager 创建上游 SSL 管理器。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - cfg: 代理 SSL 配置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *StreamProxySSLManager: 代理 SSL 管理器实例
|
||||||
|
// - error: 证书加载失败时返回错误
|
||||||
|
func NewStreamProxySSLManager(cfg config.StreamProxySSLConfig) (*StreamProxySSLManager, error) {
|
||||||
|
if !cfg.Enabled {
|
||||||
|
return &StreamProxySSLManager{config: cfg}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &StreamProxySSLManager{config: cfg}
|
||||||
|
|
||||||
|
// 加载客户端证书(mTLS)
|
||||||
|
if cfg.Cert != "" && cfg.Key != "" {
|
||||||
|
cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load client certificate: %w", err)
|
||||||
|
}
|
||||||
|
mgr.cert = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载信任的 CA 证书
|
||||||
|
if cfg.TrustedCA != "" {
|
||||||
|
pool, err := loadCertPool(cfg.TrustedCA)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load trusted CA: %w", err)
|
||||||
|
}
|
||||||
|
mgr.rootCAPool = pool
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTLSConfig 获取服务端 TLS 配置。
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *tls.Config: TLS 配置对象
|
||||||
|
func (m *StreamSSLManager) GetTLSConfig() *tls.Config {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !m.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{m.cert},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置协议版本
|
||||||
|
if len(m.config.Protocols) > 0 {
|
||||||
|
tlsConfig.MinVersion = parseMinTLSVersion(m.config.Protocols)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置加密套件
|
||||||
|
if len(m.config.Ciphers) > 0 {
|
||||||
|
tlsConfig.CipherSuites = parseCipherSuites(m.config.Ciphers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配置客户端证书验证(mTLS)
|
||||||
|
if m.clientCAPool != nil {
|
||||||
|
tlsConfig.ClientCAs = m.clientCAPool
|
||||||
|
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientTLSConfig 获取客户端 TLS 配置。
|
||||||
|
//
|
||||||
|
// 用于连接上游服务器。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - serverName: 服务器名称(用于 SNI)
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *tls.Config: TLS 配置对象
|
||||||
|
func (m *StreamProxySSLManager) GetClientTLSConfig(serverName string) *tls.Config {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !m.config.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置服务器名称(SNI)
|
||||||
|
if m.config.ServerName != "" {
|
||||||
|
tlsConfig.ServerName = m.config.ServerName
|
||||||
|
} else if serverName != "" {
|
||||||
|
tlsConfig.ServerName = serverName
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置客户端证书
|
||||||
|
if m.cert.Certificate != nil {
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{m.cert}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置协议版本
|
||||||
|
if len(m.config.Protocols) > 0 {
|
||||||
|
tlsConfig.MinVersion = parseMinTLSVersion(m.config.Protocols)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配置服务器证书验证
|
||||||
|
if m.config.Verify && m.rootCAPool != nil {
|
||||||
|
tlsConfig.RootCAs = m.rootCAPool
|
||||||
|
} else if !m.config.Verify {
|
||||||
|
// 跳过证书验证
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 会话复用
|
||||||
|
if m.config.SessionReuse {
|
||||||
|
tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(100)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled 检查是否启用 SSL。
|
||||||
|
func (m *StreamSSLManager) IsEnabled() bool {
|
||||||
|
return m.config.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled 检查是否启用代理 SSL。
|
||||||
|
func (m *StreamProxySSLManager) IsEnabled() bool {
|
||||||
|
return m.config.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadCertPool 从文件加载证书池。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - certFile: 证书文件路径
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *x509.CertPool: 证书池
|
||||||
|
// - error: 加载失败时返回错误
|
||||||
|
func loadCertPool(certFile string) (*x509.CertPool, error) {
|
||||||
|
data, err := os.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
if !pool.AppendCertsFromPEM(data) {
|
||||||
|
return nil, fmt.Errorf("failed to parse certificates from %s", certFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseMinTLSVersion 解析最小 TLS 版本。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - protocols: 协议版本列表
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - uint16: TLS 版本常量
|
||||||
|
func parseMinTLSVersion(protocols []string) uint16 {
|
||||||
|
for _, p := range protocols {
|
||||||
|
switch p {
|
||||||
|
case "TLSv1.3":
|
||||||
|
return tls.VersionTLS13
|
||||||
|
case "TLSv1.2":
|
||||||
|
return tls.VersionTLS12
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tls.VersionTLS12
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCipherSuites 解析加密套件列表。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - ciphers: 加密套件名称列表
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - []uint16: 加密套件 ID 列表
|
||||||
|
func parseCipherSuites(ciphers []string) []uint16 {
|
||||||
|
var suites []uint16
|
||||||
|
for _, c := range ciphers {
|
||||||
|
if id, ok := cipherNameToID[c]; ok {
|
||||||
|
suites = append(suites, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(suites) == 0 {
|
||||||
|
return nil // 使用默认值
|
||||||
|
}
|
||||||
|
return suites
|
||||||
|
}
|
||||||
|
|
||||||
|
// cipherNameToID 加密套件名称到 ID 的映射
|
||||||
|
var cipherNameToID = map[string]uint16{
|
||||||
|
"ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"ECDHE-RSA-CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||||
|
"ECDHE-ECDSA-CHACHA20-POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||||
|
"AES128-GCM-SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||||
|
"AES256-GCM-SHA384": tls.TLS_AES_256_GCM_SHA384,
|
||||||
|
"CHACHA20-POLY1305": tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||||
|
"ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"RSA-AES128-GCM-SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||||
|
"RSA-AES256-GCM-SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||||
|
"RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||||
|
"RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||||
|
"ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
"RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||||
|
}
|
||||||
464
internal/stream/ssl_test.go
Normal file
464
internal/stream/ssl_test.go
Normal file
@ -0,0 +1,464 @@
|
|||||||
|
// Package stream 提供 TCP/UDP Stream 代理功能。
|
||||||
|
//
|
||||||
|
// 该文件包含 Stream SSL/TLS 支持的单元测试。
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateTestCertificate 生成测试用的自签名证书
|
||||||
|
func generateTestCertificate(t *testing.T, certFile, keyFile string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// 创建证书模板
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成私钥和证书
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入证书文件
|
||||||
|
certOut, err := os.Create(certFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create cert file: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = certOut.Close() }()
|
||||||
|
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||||
|
t.Fatalf("Failed to encode certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入私钥文件
|
||||||
|
keyOut, err := os.Create(keyFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create key file: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = keyOut.Close() }()
|
||||||
|
if err := pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}); err != nil {
|
||||||
|
t.Fatalf("Failed to encode key: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStreamSSLManager_Disabled(t *testing.T) {
|
||||||
|
cfg := config.StreamSSLConfig{
|
||||||
|
Enabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewStreamSSLManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamSSLManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgr.IsEnabled() {
|
||||||
|
t.Error("Expected IsEnabled to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := mgr.GetTLSConfig()
|
||||||
|
if tlsConfig != nil {
|
||||||
|
t.Error("Expected nil TLS config when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStreamSSLManager_Enabled(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
certFile := filepath.Join(tempDir, "server.crt")
|
||||||
|
keyFile := filepath.Join(tempDir, "server.key")
|
||||||
|
|
||||||
|
generateTestCertificate(t, certFile, keyFile)
|
||||||
|
|
||||||
|
cfg := config.StreamSSLConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Cert: certFile,
|
||||||
|
Key: keyFile,
|
||||||
|
Protocols: []string{"TLSv1.2", "TLSv1.3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewStreamSSLManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamSSLManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !mgr.IsEnabled() {
|
||||||
|
t.Error("Expected IsEnabled to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := mgr.GetTLSConfig()
|
||||||
|
if tlsConfig == nil {
|
||||||
|
t.Fatal("Expected non-nil TLS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tlsConfig.Certificates) != 1 {
|
||||||
|
t.Errorf("Expected 1 certificate, got %d", len(tlsConfig.Certificates))
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig.MinVersion < tls.VersionTLS12 {
|
||||||
|
t.Errorf("Expected MinVersion >= TLS 1.2, got %v", tlsConfig.MinVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStreamSSLManager_InvalidCert(t *testing.T) {
|
||||||
|
cfg := config.StreamSSLConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Cert: "/nonexistent/cert.pem",
|
||||||
|
Key: "/nonexistent/key.pem",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewStreamSSLManager(cfg)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid certificate path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStreamProxySSLManager_Disabled(t *testing.T) {
|
||||||
|
cfg := config.StreamProxySSLConfig{
|
||||||
|
Enabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewStreamProxySSLManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgr.IsEnabled() {
|
||||||
|
t.Error("Expected IsEnabled to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := mgr.GetClientTLSConfig("example.com")
|
||||||
|
if tlsConfig != nil {
|
||||||
|
t.Error("Expected nil TLS config when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStreamProxySSLManager_Enabled(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
certFile := filepath.Join(tempDir, "client.crt")
|
||||||
|
keyFile := filepath.Join(tempDir, "client.key")
|
||||||
|
|
||||||
|
generateTestCertificate(t, certFile, keyFile)
|
||||||
|
|
||||||
|
cfg := config.StreamProxySSLConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Cert: certFile,
|
||||||
|
Key: keyFile,
|
||||||
|
ServerName: "backend.example.com",
|
||||||
|
Verify: false,
|
||||||
|
Protocols: []string{"TLSv1.2", "TLSv1.3"},
|
||||||
|
SessionReuse: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewStreamProxySSLManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !mgr.IsEnabled() {
|
||||||
|
t.Error("Expected IsEnabled to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := mgr.GetClientTLSConfig("fallback.example.com")
|
||||||
|
if tlsConfig == nil {
|
||||||
|
t.Fatal("Expected non-nil TLS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应该使用配置中的 ServerName
|
||||||
|
if tlsConfig.ServerName != "backend.example.com" {
|
||||||
|
t.Errorf("Expected ServerName 'backend.example.com', got '%s'", tlsConfig.ServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应该有客户端证书
|
||||||
|
if len(tlsConfig.Certificates) != 1 {
|
||||||
|
t.Errorf("Expected 1 client certificate, got %d", len(tlsConfig.Certificates))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 跳过验证
|
||||||
|
if !tlsConfig.InsecureSkipVerify {
|
||||||
|
t.Error("Expected InsecureSkipVerify to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 会话复用
|
||||||
|
if tlsConfig.ClientSessionCache == nil {
|
||||||
|
t.Error("Expected ClientSessionCache to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewStreamProxySSLManager_WithVerify(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
caFile := filepath.Join(tempDir, "ca.crt")
|
||||||
|
|
||||||
|
// 创建 CA 证书
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create CA certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caOut, err := os.Create(caFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create CA file: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = caOut.Close() }()
|
||||||
|
if err := pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||||
|
t.Fatalf("Failed to encode CA: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := config.StreamProxySSLConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Verify: true,
|
||||||
|
TrustedCA: caFile,
|
||||||
|
ServerName: "backend.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewStreamProxySSLManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := mgr.GetClientTLSConfig("")
|
||||||
|
if tlsConfig == nil {
|
||||||
|
t.Fatal("Expected non-nil TLS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应该验证证书
|
||||||
|
if tlsConfig.InsecureSkipVerify {
|
||||||
|
t.Error("Expected InsecureSkipVerify to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应该有 RootCAs
|
||||||
|
if tlsConfig.RootCAs == nil {
|
||||||
|
t.Error("Expected RootCAs to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseMinTLSVersion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
protocols []string
|
||||||
|
wantVersion uint16
|
||||||
|
}{
|
||||||
|
{[]string{"TLSv1.3"}, tls.VersionTLS13},
|
||||||
|
{[]string{"TLSv1.2"}, tls.VersionTLS12},
|
||||||
|
{[]string{"TLSv1.2", "TLSv1.3"}, tls.VersionTLS12},
|
||||||
|
{[]string{}, tls.VersionTLS12},
|
||||||
|
{[]string{"Unknown"}, tls.VersionTLS12},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := parseMinTLSVersion(tt.protocols)
|
||||||
|
if got != tt.wantVersion {
|
||||||
|
t.Errorf("parseMinTLSVersion(%v) = %v, want %v", tt.protocols, got, tt.wantVersion)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCipherSuites(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ciphers []string
|
||||||
|
wantLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid ciphers",
|
||||||
|
ciphers: []string{"ECDHE-RSA-AES128-GCM-SHA256", "ECDHE-RSA-AES256-GCM-SHA384"},
|
||||||
|
wantLen: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty ciphers",
|
||||||
|
ciphers: []string{},
|
||||||
|
wantLen: 0, // returns nil for empty
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown ciphers",
|
||||||
|
ciphers: []string{"UNKNOWN-CIPHER"},
|
||||||
|
wantLen: 0, // returns nil for no valid ciphers
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := parseCipherSuites(tt.ciphers)
|
||||||
|
if tt.wantLen == 0 && got != nil {
|
||||||
|
t.Errorf("Expected nil, got %v", got)
|
||||||
|
} else if tt.wantLen > 0 && len(got) != tt.wantLen {
|
||||||
|
t.Errorf("Expected %d ciphers, got %d", tt.wantLen, len(got))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadCertPool(t *testing.T) {
|
||||||
|
t.Run("valid cert", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
certFile := filepath.Join(tempDir, "ca.crt")
|
||||||
|
|
||||||
|
// 创建证书
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign,
|
||||||
|
}
|
||||||
|
|
||||||
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||||
|
|
||||||
|
certOut, err := os.Create(certFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create cert file: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = certOut.Close() }()
|
||||||
|
if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||||
|
t.Fatalf("Failed to encode certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool, err := loadCertPool(certFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("loadCertPool failed: %v", err)
|
||||||
|
}
|
||||||
|
if pool == nil {
|
||||||
|
t.Error("Expected non-nil pool")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid path", func(t *testing.T) {
|
||||||
|
_, err := loadCertPool("/nonexistent/cert.pem")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for nonexistent file")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid content", func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
certFile := filepath.Join(tempDir, "invalid.crt")
|
||||||
|
if err := os.WriteFile(certFile, []byte("not a certificate"), 0644); err != nil {
|
||||||
|
t.Fatalf("写入无效证书文件失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := loadCertPool(certFile)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid certificate content")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamSSLManager_GetTLSConfig_WithClientCA(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
certFile := filepath.Join(tempDir, "server.crt")
|
||||||
|
keyFile := filepath.Join(tempDir, "server.key")
|
||||||
|
caFile := filepath.Join(tempDir, "ca.crt")
|
||||||
|
|
||||||
|
generateTestCertificate(t, certFile, keyFile)
|
||||||
|
|
||||||
|
// 创建 CA 证书
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
IsCA: true,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||||
|
|
||||||
|
caOut, err := os.Create(caFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create CA file: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = caOut.Close() }()
|
||||||
|
if err := pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil {
|
||||||
|
t.Fatalf("Failed to encode CA: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := config.StreamSSLConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Cert: certFile,
|
||||||
|
Key: keyFile,
|
||||||
|
ClientCA: caFile,
|
||||||
|
Protocols: []string{"TLSv1.2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewStreamSSLManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamSSLManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := mgr.GetTLSConfig()
|
||||||
|
if tlsConfig == nil {
|
||||||
|
t.Fatal("Expected non-nil TLS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应该配置客户端 CA
|
||||||
|
if tlsConfig.ClientCAs == nil {
|
||||||
|
t.Error("Expected ClientCAs to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应该请求客户端证书
|
||||||
|
if tlsConfig.ClientAuth != tls.VerifyClientCertIfGiven {
|
||||||
|
t.Errorf("Expected ClientAuth VerifyClientCertIfGiven, got %v", tlsConfig.ClientAuth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamProxySSLManager_GetClientTLSConfig_WithServerNameOverride(t *testing.T) {
|
||||||
|
cfg := config.StreamProxySSLConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Verify: false,
|
||||||
|
ServerName: "configured.example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := NewStreamProxySSLManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewStreamProxySSLManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 即使传入不同的 serverName,也应该使用配置的
|
||||||
|
tlsConfig := mgr.GetClientTLSConfig("fallback.example.com")
|
||||||
|
if tlsConfig == nil {
|
||||||
|
t.Fatal("Expected non-nil TLS config")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tlsConfig.ServerName != "configured.example.com" {
|
||||||
|
t.Errorf("Expected ServerName 'configured.example.com', got '%s'", tlsConfig.ServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user