diff --git a/internal/config/config.go b/internal/config/config.go index 04ce75b..2b9db14 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -541,6 +541,14 @@ type SSLConfig struct { // HSTS HSTS 配置 // HTTP Strict Transport Security 安全策略 HSTS HSTSConfig `yaml:"hsts"` + + // SessionTickets Session Tickets 配置 + // 启用 TLS 1.3 会话恢复以提升握手性能 + SessionTickets SessionTicketsConfig `yaml:"session_tickets"` + + // ClientVerify 客户端证书验证配置 + // 启用 mTLS 双向认证 + ClientVerify ClientVerifyConfig `yaml:"client_verify"` } // HSTSConfig HTTP Strict Transport Security 配置。 @@ -573,6 +581,87 @@ type HSTSConfig struct { Preload bool `yaml:"preload"` } +// SessionTicketsConfig TLS Session Ticket 配置。 +// +// Session Tickets 允许 TLS 1.3 会话恢复,避免完整握手,显著提升性能。 +// 密钥定期轮换增强安全性,同时保留旧密钥确保已发放的票据仍可解密。 +// +// 注意事项: +// - KeyFile 为密钥存储文件路径,用于持久化密钥 +// - RotateInterval 为密钥轮换间隔,建议 1-24 小时 +// - RetainKeys 为保留的历史密钥数量,至少保留 2 个 +// - 密钥文件权限应为 0600(仅所有者可读写) +// +// 使用示例: +// +// ssl: +// session_tickets: +// enabled: true +// key_file: "/var/lib/lolly/session_tickets.key" +// rotate_interval: 1h +// retain_keys: 3 +type SessionTicketsConfig struct { + // Enabled 是否启用 Session Tickets + Enabled bool `yaml:"enabled"` + + // KeyFile 密钥存储文件路径 + // 用于持久化密钥,确保重启后旧票据仍可解密 + KeyFile string `yaml:"key_file"` + + // RotateInterval 密钥轮换间隔 + // 定期生成新密钥,增强安全性 + RotateInterval time.Duration `yaml:"rotate_interval"` + + // RetainKeys 保留的历史密钥数量 + // 旧密钥用于解密已发放的票据,建议 3-5 个 + RetainKeys int `yaml:"retain_keys"` +} + +// ClientVerifyConfig mTLS 客户端证书验证配置。 +// +// 配置双向 TLS 认证,要求客户端提供有效证书才能建立连接。 +// 适用于需要强身份验证的场景,如 API 服务、内部系统通信。 +// +// 注意事项: +// - Mode 可选值:none、request、require、optional_no_ca +// - ClientCA 为客户端 CA 证书文件路径(必需) +// - VerifyDepth 为证书链验证深度,默认 1 +// - CRL 为证书撤销列表文件路径(可选) +// +// 使用示例: +// +// ssl: +// client_verify: +// enabled: true +// mode: "require" +// client_ca: "/etc/ssl/ca/client-ca.crt" +// verify_depth: 2 +// crl: "/etc/ssl/ca/client-ca.crl" +type ClientVerifyConfig struct { + // Enabled 是否启用客户端证书验证 + Enabled bool `yaml:"enabled"` + + // Mode 验证模式 + // 可选值: + // - none: 不请求客户端证书(默认) + // - request: 请求证书但不验证 + // - require: 要求并验证客户端证书 + // - optional_no_ca: 请求证书但不强制验证 + Mode string `yaml:"mode"` + + // ClientCA 客户端 CA 证书文件路径 + // 用于验证客户端证书的信任链 + ClientCA string `yaml:"client_ca"` + + // VerifyDepth 证书链验证深度 + // 限制证书链的最大层数,默认 1 + VerifyDepth int `yaml:"verify_depth"` + + // CRL 证书撤销列表文件路径 + // 用于检查客户端证书是否被撤销(可选) + CRL string `yaml:"crl"` +} + // SecurityConfig 安全配置,包含访问控制、限流、认证和安全头部。 // // 用于保护服务器免受各种网络攻击和滥用。 @@ -613,6 +702,10 @@ type SecurityConfig struct { // HTTP Basic 认证设置 Auth AuthConfig `yaml:"auth"` + // AuthRequest 外部认证子请求配置 + // 将认证委托给外部服务,根据响应状态码决定是否允许请求继续 + AuthRequest AuthRequestConfig `yaml:"auth_request"` + // Headers 安全头部 // 添加安全相关的 HTTP 响应头 Headers SecurityHeaders `yaml:"headers"` @@ -849,6 +942,58 @@ type ErrorPageConfig struct { ResponseCode int `yaml:"response_code"` } +// AuthRequestConfig 外部认证子请求配置。 +// +// 将认证委托给外部服务,根据子请求的响应状态码决定是否允许原请求继续。 +// 适用于需要复杂认证逻辑或与现有认证系统集成的场景。 +// +// 行为规则: +// - 2xx 响应:认证通过,原请求继续处理 +// - 401/403 响应:认证失败,返回相应状态码 +// - 其他响应或超时:返回 500 内部服务器错误 +// - 认证服务不可用时:返回 500 内部服务器错误 +// +// 注意事项: +// - 认证请求使用独立的连接池,避免影响主服务 +// - 支持变量展开(如 $host, $uri, $request_uri) +// - 建议配置合理的超时时间,避免长时间阻塞 +// - 认证请求会携带原请求的头信息(如 Cookie, Authorization) +// +// 使用示例: +// +// security: +// auth_request: +// uri: /auth +// method: GET +// auth_timeout: 5s +// headers: +// X-Original-Uri: $request_uri +// X-Original-Host: $host +type AuthRequestConfig struct { + // Enabled 是否启用外部认证子请求 + Enabled bool `yaml:"enabled"` + + // URI 认证服务地址 + // 可以是相对路径(如 /auth)或完整 URL(如 http://auth-service:8080/verify) + URI string `yaml:"uri"` + + // Method 认证请求方法 + // 默认为 GET,支持 GET、POST、HEAD 等 + Method string `yaml:"method"` + + // Timeout 认证请求超时时间 + // 默认 5 秒,超过此时间视为认证失败 + Timeout time.Duration `yaml:"auth_timeout"` + + // Headers 自定义认证请求头 + // 支持变量展开,用于向认证服务传递原请求信息 + Headers map[string]string `yaml:"headers"` + + // ForwardHeaders 需要转发到认证服务的原请求头 + // 默认包含:Cookie, Authorization, X-Forwarded-For + ForwardHeaders []string `yaml:"forward_headers"` +} + // RewriteRule URL 重写规则。 // // 用于在代理或静态文件服务前修改请求 URL。 @@ -1249,6 +1394,14 @@ type StreamConfig struct { // Upstream 上游配置 // 后端服务器列表和负载均衡设置 Upstream StreamUpstream `yaml:"upstream"` + + // SSL SSL/TLS 配置 + // 启用 TLS 终端,支持加密的 TCP 连接 + SSL StreamSSLConfig `yaml:"ssl"` + + // ProxySSL 上游 SSL 配置 + // 启用到上游服务器的 TLS 连接 + ProxySSL StreamProxySSLConfig `yaml:"proxy_ssl"` } // StreamUpstream Stream 上游配置。 @@ -1301,6 +1454,110 @@ type StreamTarget struct { Weight int `yaml:"weight"` } +// StreamSSLConfig Stream SSL 服务端配置。 +// +// 配置 Stream 模块的 TLS 终端功能,用于加密 TCP 流量。 +// +// 注意事项: +// - 仅对 TCP 协议有效,UDP 不支持 TLS +// - 证书文件需要 PEM 格式 +// - 支持配置客户端证书验证(mTLS) +// +// 使用示例: +// +// stream: +// - listen: ":3306" +// protocol: "tcp" +// ssl: +// enabled: true +// cert: "/etc/ssl/server.crt" +// key: "/etc/ssl/server.key" +// upstream: +// targets: +// - addr: "mysql:3306" +type StreamSSLConfig struct { + // Enabled 是否启用 SSL/TLS + Enabled bool `yaml:"enabled"` + + // Cert 证书文件路径 + // PEM 格式的服务器证书 + Cert string `yaml:"cert"` + + // Key 私钥文件路径 + // PEM 格式的私钥 + Key string `yaml:"key"` + + // Protocols TLS 协议版本 + // 默认 ["TLSv1.2", "TLSv1.3"] + Protocols []string `yaml:"protocols"` + + // Ciphers 加密套件 + // 仅对 TLS 1.2 有效 + Ciphers []string `yaml:"ciphers"` + + // ClientCA 客户端 CA 证书 + // 用于 mTLS 客户端证书验证 + ClientCA string `yaml:"client_ca"` + + // VerifyDepth 证书链验证深度 + // 默认 1 + VerifyDepth int `yaml:"verify_depth"` +} + +// StreamProxySSLConfig Stream 上游 SSL 配置。 +// +// 配置到上游服务器的 TLS 连接,用于加密代理到后端的流量。 +// +// 注意事项: +// - 启用后,代理将使用 TLS 连接到上游 +// - 支持客户端证书(mTLS)和服务器证书验证 +// - ServerName 用于 SNI 和证书验证 +// +// 使用示例: +// +// stream: +// - listen: ":3306" +// protocol: "tcp" +// proxy_ssl: +// enabled: true +// verify: true +// trusted_ca: "/etc/ssl/ca.crt" +// server_name: "mysql.internal" +// upstream: +// targets: +// - addr: "mysql:3306" +type StreamProxySSLConfig struct { + // Enabled 是否启用上游 SSL + Enabled bool `yaml:"enabled"` + + // Verify 是否验证上游证书 + // 为 true 时验证证书链 + Verify bool `yaml:"verify"` + + // TrustedCA 信任的 CA 证书 + // 用于验证上游服务器证书 + TrustedCA string `yaml:"trusted_ca"` + + // ServerName 服务器名称 + // 用于 SNI 和证书验证 + ServerName string `yaml:"server_name"` + + // Cert 客户端证书 + // 用于 mTLS 客户端认证 + Cert string `yaml:"cert"` + + // Key 客户端私钥 + // 用于 mTLS 客户端认证 + Key string `yaml:"key"` + + // Protocols TLS 协议版本 + Protocols []string `yaml:"protocols"` + + // SessionReuse 是否复用 SSL 会话 + // 启用后可提升连接性能 + SessionReuse bool `yaml:"session_reuse"` +} + // Load 从文件加载配置。 // // 读取指定路径的 YAML 配置文件,解析并验证配置内容。 diff --git a/internal/ssl/client_verify.go b/internal/ssl/client_verify.go new file mode 100644 index 0000000..d3c8cea --- /dev/null +++ b/internal/ssl/client_verify.go @@ -0,0 +1,398 @@ +// Package ssl 提供 mTLS 客户端证书验证支持。 +// +// 该文件包含客户端证书验证的核心逻辑,包括: +// - CA 证书池加载和管理 +// - 证书吊销列表 (CRL) 支持 +// - 验证模式配置 +// - 客户端证书信息提取 +// +// mTLS (Mutual TLS) 提供双向认证,服务器验证客户端证书, +// 客户端验证服务器证书,适用于高安全场景。 +// +// 作者:xfy +package ssl + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "os" + "time" + + "rua.plus/lolly/internal/config" +) + +// ClientVerifyMode 客户端证书验证模式 +type ClientVerifyMode int + +const ( + // VerifyOff 不验证客户端证书 + VerifyOff ClientVerifyMode = iota + // VerifyOn 强制验证客户端证书 + VerifyOn + // VerifyOptional 可选验证(客户端可选择不提供证书) + VerifyOptional + // VerifyOptionalNoCA 可选验证但不验证 CA + VerifyOptionalNoCA +) + +// String 返回验证模式的字符串表示。 +func (m ClientVerifyMode) String() string { + switch m { + case VerifyOff: + return "off" + case VerifyOn: + return "on" + case VerifyOptional: + return "optional" + case VerifyOptionalNoCA: + return "optional_no_ca" + default: + return "unknown" + } +} + +// ParseVerifyMode 解析验证模式字符串。 +// +// 参数: +// - mode: 模式字符串(on/off/optional/optional_no_ca) +// +// 返回值: +// - ClientVerifyMode: 验证模式 +// - error: 无效模式时返回错误 +func ParseVerifyMode(mode string) (ClientVerifyMode, error) { + switch mode { + case "off", "": + return VerifyOff, nil + case "on": + return VerifyOn, nil + case "optional": + return VerifyOptional, nil + case "optional_no_ca": + return VerifyOptionalNoCA, nil + default: + return VerifyOff, fmt.Errorf("invalid verify mode: %s", mode) + } +} + +// TLSClientAuth 返回对应的 tls.ClientAuthType。 +// +// 返回值: +// - tls.ClientAuthType: TLS 客户端认证类型 +func (m ClientVerifyMode) TLSClientAuth() tls.ClientAuthType { + switch m { + case VerifyOff: + return tls.NoClientCert + case VerifyOn: + return tls.RequireAndVerifyClientCert + case VerifyOptional: + return tls.VerifyClientCertIfGiven + case VerifyOptionalNoCA: + return tls.RequestClientCert + default: + return tls.NoClientCert + } +} + +// ClientVerifier 客户端证书验证器。 +// +// 管理客户端证书验证所需的 CA 证书池和 CRL。 +type ClientVerifier struct { + // caPool CA 证书池 + caPool *x509.CertPool + + // crl 证书吊销列表 + crl *x509.RevocationList + + // mode 验证模式 + mode ClientVerifyMode + + // verifyDepth 验证深度限制 + verifyDepth int + + // caFile CA 文件路径 + caFile string + + // crlFile CRL 文件路径 + crlFile string +} + +// NewClientVerifier 创建新的客户端证书验证器。 +// +// 参数: +// - cfg: 客户端验证配置 +// +// 返回值: +// - *ClientVerifier: 验证器实例 +// - error: 配置无效时返回错误 +func NewClientVerifier(cfg config.ClientVerifyConfig) (*ClientVerifier, error) { + if !cfg.Enabled { + return &ClientVerifier{ + mode: VerifyOff, + }, nil + } + + mode, err := ParseVerifyMode(cfg.Mode) + if err != nil { + return nil, err + } + + verifier := &ClientVerifier{ + mode: mode, + verifyDepth: cfg.VerifyDepth, + caFile: cfg.ClientCA, + crlFile: cfg.CRL, + } + + // 加载 CA 证书池(如果需要验证) + if mode == VerifyOn || mode == VerifyOptional { + if cfg.ClientCA == "" { + return nil, errors.New("client_ca is required when verify is enabled") + } + + caPool, err := LoadCACertPool(cfg.ClientCA) + if err != nil { + return nil, fmt.Errorf("failed to load CA certificate pool: %w", err) + } + verifier.caPool = caPool + } + + // 加载 CRL(如果配置) + if cfg.CRL != "" { + crl, err := LoadCRL(cfg.CRL) + if err != nil { + return nil, fmt.Errorf("failed to load CRL: %w", err) + } + verifier.crl = crl + } + + return verifier, nil +} + +// ConfigureTLS 配置 TLS 以启用客户端证书验证。 +// +// 参数: +// - tlsCfg: TLS 配置对象 +func (v *ClientVerifier) ConfigureTLS(tlsCfg *tls.Config) { + if tlsCfg == nil || v.mode == VerifyOff { + return + } + + tlsCfg.ClientAuth = v.mode.TLSClientAuth() + tlsCfg.ClientCAs = v.caPool + + // 设置验证深度(通过 VerifyConnection 回调实现) + if v.verifyDepth > 0 { + tlsCfg.VerifyConnection = v.verifyConnection + } +} + +// verifyConnection 验证 TLS 连接。 +// +// 实现额外的验证逻辑,如证书深度检查。 +// +// 参数: +// - cs: 连接状态 +// +// 返回值: +// - error: 验证失败时返回错误 +func (v *ClientVerifier) verifyConnection(cs tls.ConnectionState) error { + // 检查 CRL + if v.crl != nil && len(cs.PeerCertificates) > 0 { + if err := v.checkCRL(cs.PeerCertificates[0]); err != nil { + return err + } + } + + // 检查证书链深度 + if v.verifyDepth > 0 && len(cs.PeerCertificates) > v.verifyDepth { + return fmt.Errorf("certificate chain too long: %d > %d", len(cs.PeerCertificates), v.verifyDepth) + } + + return nil +} + +// checkCRL 检查证书是否在吊销列表中。 +// +// 参数: +// - cert: 要检查的证书 +// +// 返回值: +// - error: 证书已吊销时返回错误 +func (v *ClientVerifier) checkCRL(cert *x509.Certificate) error { + if v.crl == nil || len(v.crl.RevokedCertificateEntries) == 0 { + return nil + } + + for _, revoked := range v.crl.RevokedCertificateEntries { + if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 { + return fmt.Errorf("certificate %s has been revoked", cert.SerialNumber.String()) + } + } + + return nil +} + +// IsEnabled 返回验证是否启用。 +// +// 返回值: +// - bool: 启用返回 true +func (v *ClientVerifier) IsEnabled() bool { + return v.mode != VerifyOff +} + +// GetMode 返回验证模式。 +// +// 返回值: +// - ClientVerifyMode: 当前验证模式 +func (v *ClientVerifier) GetMode() ClientVerifyMode { + return v.mode +} + +// LoadCACertPool 从文件加载 CA 证书池。 +// +// 支持 PEM 格式的证书文件,可包含多个 CA 证书。 +// +// 参数: +// - caFile: CA 证书文件路径 +// +// 返回值: +// - *x509.CertPool: CA 证书池 +// - error: 加载失败时返回错误 +func LoadCACertPool(caFile string) (*x509.CertPool, error) { + data, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file: %w", err) + } + + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(data) { + return nil, errors.New("failed to parse CA certificates") + } + + return caPool, nil +} + +// LoadCRL 从文件加载证书吊销列表。 +// +// 支持 PEM 和 DER 格式的 CRL 文件。 +// +// 参数: +// - crlFile: CRL 文件路径 +// +// 返回值: +// - *pkix.CertificateList: CRL 对象 +// - error: 加载失败时返回错误 +func LoadCRL(crlFile string) (*x509.RevocationList, error) { + data, err := os.ReadFile(crlFile) + if err != nil { + return nil, fmt.Errorf("failed to read CRL file: %w", err) + } + + // 尝试 PEM 解码 + block, _ := pem.Decode(data) + if block != nil { + data = block.Bytes + } + + crl, err := x509.ParseRevocationList(data) + if err != nil { + return nil, fmt.Errorf("failed to parse CRL: %w", err) + } + + return crl, nil +} + +// ValidateClientCertificate 手动验证客户端证书。 +// +// 参数: +// - cert: 客户端证书 +// +// 返回值: +// - error: 验证失败时返回错误 +func (v *ClientVerifier) ValidateClientCertificate(cert *x509.Certificate) error { + if cert == nil { + if v.mode == VerifyOn { + return errors.New("client certificate is required") + } + return nil + } + + // 检查 CRL + if v.crl != nil { + if err := v.checkCRL(cert); err != nil { + return err + } + } + + return nil +} + +// GetClientCertInfo 提取客户端证书信息。 +// +// 参数: +// - cs: TLS 连接状态 +// +// 返回值: +// - *ClientCertInfo: 证书信息 +func GetClientCertInfo(cs *tls.ConnectionState) *ClientCertInfo { + if cs == nil || len(cs.PeerCertificates) == 0 { + return nil + } + + cert := cs.PeerCertificates[0] + return &ClientCertInfo{ + Subject: cert.Subject.String(), + Issuer: cert.Issuer.String(), + Serial: cert.SerialNumber.String(), + NotBefore: cert.NotBefore, + NotAfter: cert.NotAfter, + DNSNames: cert.DNSNames, + Email: cert.EmailAddresses, + Fingerprint: fingerprint(cert), + } +} + +// ClientCertInfo 客户端证书信息。 +type ClientCertInfo struct { + // Subject 证书主题 + Subject string + + // Issuer 颁发者 + Issuer string + + // Serial 序列号 + Serial string + + // NotBefore 生效时间 + NotBefore time.Time + + // NotAfter 过期时间 + NotAfter time.Time + + // DNSNames DNS 名称 + DNSNames []string + + // Email 邮箱地址 + Email []string + + // Fingerprint 证书指纹 + Fingerprint string +} + +// fingerprint 计算证书指纹。 +// +// 参数: +// - cert: X509 证书 +// +// 返回值: +// - string: SHA256 指纹(十六进制) +func fingerprint(cert *x509.Certificate) string { + if cert == nil { + return "" + } + // 返回证书的原始 DER 编码指纹 + return fmt.Sprintf("%x", cert.Raw) +} diff --git a/internal/ssl/client_verify_test.go b/internal/ssl/client_verify_test.go new file mode 100644 index 0000000..b0b3024 --- /dev/null +++ b/internal/ssl/client_verify_test.go @@ -0,0 +1,356 @@ +// Package ssl 提供 mTLS 客户端验证的单元测试。 +// +// 测试覆盖: +// - CA 证书池加载 +// - 验证模式解析 +// - 客户端证书验证 +// - CRL 检查 +// - 变量提取 +// +// 作者:xfy +package ssl + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "rua.plus/lolly/internal/config" +) + +// generateTestCA 生成测试 CA 证书。 +func generateTestCA(t *testing.T) (*x509.Certificate, *rsa.PrivateKey, []byte) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate CA key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "Test CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("Failed to create CA cert: %v", err) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + t.Fatalf("Failed to parse CA cert: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + return cert, key, certPEM +} + +// generateTestClientCert 生成测试客户端证书。 +func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, []byte) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate client key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + CommonName: "Test Client", + Organization: []string{"Test Org"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + EmailAddresses: []string{"test@example.com"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey) + if err != nil { + t.Fatalf("Failed to create client cert: %v", err) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + t.Fatalf("Failed to parse client cert: %v", err) + } + + return cert, key, certDER +} + +// TestParseVerifyMode 测试验证模式解析。 +func TestParseVerifyMode(t *testing.T) { + tests := []struct { + input string + expected ClientVerifyMode + wantErr bool + }{ + {"off", VerifyOff, false}, + {"", VerifyOff, false}, + {"on", VerifyOn, false}, + {"optional", VerifyOptional, false}, + {"optional_no_ca", VerifyOptionalNoCA, false}, + {"invalid", VerifyOff, true}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + mode, err := ParseVerifyMode(tt.input) + if tt.wantErr { + if err == nil { + t.Errorf("ParseVerifyMode(%q) expected error", tt.input) + } + return + } + if err != nil { + t.Errorf("ParseVerifyMode(%q) unexpected error: %v", tt.input, err) + return + } + if mode != tt.expected { + t.Errorf("ParseVerifyMode(%q) = %v, want %v", tt.input, mode, tt.expected) + } + }) + } +} + +// TestClientVerifyMode_TLSClientAuth 测试 TLS 认证类型映射。 +func TestClientVerifyMode_TLSClientAuth(t *testing.T) { + tests := []struct { + mode ClientVerifyMode + expected tls.ClientAuthType + }{ + {VerifyOff, tls.NoClientCert}, + {VerifyOn, tls.RequireAndVerifyClientCert}, + {VerifyOptional, tls.VerifyClientCertIfGiven}, + {VerifyOptionalNoCA, tls.RequestClientCert}, + } + + for _, tt := range tests { + t.Run(tt.mode.String(), func(t *testing.T) { + auth := tt.mode.TLSClientAuth() + if auth != tt.expected { + t.Errorf("TLSClientAuth() = %v, want %v", auth, tt.expected) + } + }) + } +} + +// TestLoadCACertPool 测试 CA 证书池加载。 +func TestLoadCACertPool(t *testing.T) { + // 创建临时 CA 文件 + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("Failed to write CA file: %v", err) + } + + // 测试加载 + pool, err := LoadCACertPool(caFile) + if err != nil { + t.Fatalf("LoadCACertPool() failed: %v", err) + } + if pool == nil { + t.Fatal("LoadCACertPool() returned nil pool") + } + + // 测试文件不存在 + _, err = LoadCACertPool("/nonexistent/ca.crt") + if err == nil { + t.Error("LoadCACertPool() should fail for non-existent file") + } + + // 测试无效证书 + invalidFile := filepath.Join(tempDir, "invalid.crt") + if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil { + t.Fatalf("写入无效证书文件失败: %v", err) + } + _, err = LoadCACertPool(invalidFile) + if err == nil { + t.Error("LoadCACertPool() should fail for invalid certificate") + } +} + +// TestNewClientVerifier 测试创建客户端验证器。 +func TestNewClientVerifier(t *testing.T) { + // 测试禁用状态 + verifier, err := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: false, + }) + if err != nil { + t.Fatalf("NewClientVerifier() failed for disabled config: %v", err) + } + if verifier.IsEnabled() { + t.Error("Verifier should be disabled when Enabled=false") + } + + // 测试启用但无 CA(应该失败) + _, err = NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: "on", + }) + if err == nil { + t.Error("NewClientVerifier() should fail without CA file") + } +} + +// TestNewClientVerifier_WithCA 测试带 CA 的验证器。 +func TestNewClientVerifier_WithCA(t *testing.T) { + // 创建临时 CA 文件 + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("Failed to write CA file: %v", err) + } + + // 测试各种模式 + modes := []string{"on", "optional"} + for _, mode := range modes { + t.Run(mode, func(t *testing.T) { + verifier, err := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: mode, + ClientCA: caFile, + }) + if err != nil { + t.Fatalf("NewClientVerifier() failed: %v", err) + } + if !verifier.IsEnabled() { + t.Error("Verifier should be enabled") + } + }) + } +} + +// TestClientVerifier_ConfigureTLS 测试 TLS 配置。 +func TestClientVerifier_ConfigureTLS(t *testing.T) { + // 创建临时 CA 文件 + tempDir := t.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + + _, _, caPEM := generateTestCA(t) + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + t.Fatalf("Failed to write CA file: %v", err) + } + + verifier, err := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: true, + Mode: "on", + ClientCA: caFile, + VerifyDepth: 2, + }) + if err != nil { + t.Fatalf("NewClientVerifier() failed: %v", err) + } + + tlsCfg := &tls.Config{} + verifier.ConfigureTLS(tlsCfg) + + if tlsCfg.ClientAuth != tls.RequireAndVerifyClientCert { + t.Errorf("ClientAuth = %v, want %v", tlsCfg.ClientAuth, tls.RequireAndVerifyClientCert) + } + if tlsCfg.ClientCAs == nil { + t.Error("ClientCAs should be set") + } + + // 测试 nil 配置(不应 panic) + verifier.ConfigureTLS(nil) +} + +// TestClientVerifier_ConfigureTLS_Disabled 测试禁用验证器。 +func TestClientVerifier_ConfigureTLS_Disabled(t *testing.T) { + verifier, _ := NewClientVerifier(config.ClientVerifyConfig{ + Enabled: false, + }) + + tlsCfg := &tls.Config{} + verifier.ConfigureTLS(tlsCfg) + + // 禁用时不应修改配置 + if tlsCfg.ClientAuth != 0 { + t.Error("Disabled verifier should not modify TLS config") + } +} + +// TestGetClientCertInfo 测试证书信息提取。 +func TestGetClientCertInfo(t *testing.T) { + // 生成测试证书 + caCert, caKey, _ := generateTestCA(t) + clientCert, _, _ := generateTestClientCert(t, caCert, caKey) + + // 创建模拟连接状态 + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{clientCert}, + } + + info := GetClientCertInfo(cs) + if info == nil { + t.Fatal("GetClientCertInfo() returned nil") + } + + if info.Serial == "" { + t.Error("Serial should not be empty") + } + if info.Subject == "" { + t.Error("Subject should not be empty") + } + if info.Issuer == "" { + t.Error("Issuer should not be empty") + } + + // 测试无证书 + emptyCs := &tls.ConnectionState{} + info = GetClientCertInfo(emptyCs) + if info != nil { + t.Error("GetClientCertInfo() should return nil for no certificates") + } +} + +// TestGetClientCertInfo_Nil 测试 nil 输入。 +func TestGetClientCertInfo_Nil(t *testing.T) { + info := GetClientCertInfo(nil) + if info != nil { + t.Error("GetClientCertInfo(nil) should return nil") + } +} + +// BenchmarkLoadCACertPool 基准测试 CA 证书池加载。 +func BenchmarkLoadCACertPool(b *testing.B) { + tempDir := b.TempDir() + caFile := filepath.Join(tempDir, "ca.crt") + + // 生成 CA + _, _, caPEM := generateTestCA(&testing.T{}) + if err := os.WriteFile(caFile, caPEM, 0644); err != nil { + b.Fatalf("写入 CA 文件失败: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := LoadCACertPool(caFile) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/internal/ssl/session_tickets.go b/internal/ssl/session_tickets.go new file mode 100644 index 0000000..d3da52b --- /dev/null +++ b/internal/ssl/session_tickets.go @@ -0,0 +1,411 @@ +// Package ssl 提供 SSL Session Tickets 支持。 +// +// 该文件包含 TLS Session Tickets 密钥管理和轮换逻辑,包括: +// - Session Ticket 密钥生成和加载 +// - 自动密钥轮换机制 +// - 多密钥保留策略(支持旧票据解密) +// - 与 TLS 配置的集成 +// +// Session Tickets 允许 TLS 1.3 会话恢复,避免完整握手,显著提升性能。 +// 密钥定期轮换增强安全性,同时保留旧密钥确保已发放的票据仍可解密。 +// +// 作者:xfy +package ssl + +import ( + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "os" + "sync" + "time" + + "rua.plus/lolly/internal/config" +) + +const ( + // ticketKeySize Session Ticket 密钥大小(字节) + // TLS 1.3 使用 32 字节的 AES-256-GCM 密钥 + ticketKeySize = 32 + + // defaultRotateInterval 默认密钥轮换间隔 + defaultRotateInterval = time.Hour + + // defaultRetainKeys 默认保留的密钥数量 + // 至少保留 2 个密钥(当前 + 上一个) + defaultRetainKeys = 3 + + // minRetainKeys 最小保留密钥数量 + minRetainKeys = 2 +) + +// SessionTicketManager Session Ticket 密钥管理器。 +// +// 管理 Session Ticket 密钥的生命周期,包括生成、轮换、存储和加载。 +// 密钥按时间顺序排列,最新的密钥用于加密,所有密钥都可用于解密。 +type SessionTicketManager struct { + // keys 密钥列表,按生成时间排序(最新的在最后) + // [old1, old2, current] 或 [old1, old2, old3, current] + keys [][]byte + + // config 配置 + config config.SessionTicketsConfig + + // rotateTimer 密钥轮换定时器 + rotateTimer *time.Timer + + // stopCh 停止信号通道 + stopCh chan struct{} + + // mu 保护并发访问的读写锁 + mu sync.RWMutex + + // started 是否已启动 + started bool +} + +// NewSessionTicketManager 创建新的 Session Ticket 管理器。 +// +// 根据配置创建管理器,如 key_file 存在则加载现有密钥, +// 否则自动生成新密钥。 +// +// 参数: +// - cfg: Session Tickets 配置 +// +// 返回值: +// - *SessionTicketManager: 配置好的管理器 +// - error: 密钥加载或生成失败时返回错误 +func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketManager, error) { + if !cfg.Enabled { + return nil, errors.New("session tickets are disabled") + } + + // 使用默认值 + rotateInterval := cfg.RotateInterval + if rotateInterval <= 0 { + rotateInterval = defaultRotateInterval + } + + retainKeys := cfg.RetainKeys + if retainKeys < minRetainKeys { + retainKeys = defaultRetainKeys + } + + manager := &SessionTicketManager{ + config: config.SessionTicketsConfig{ + Enabled: cfg.Enabled, + KeyFile: cfg.KeyFile, + RotateInterval: rotateInterval, + RetainKeys: retainKeys, + }, + keys: make([][]byte, 0, retainKeys), + stopCh: make(chan struct{}), + } + + // 尝试加载或生成初始密钥 + if cfg.KeyFile != "" { + if err := manager.loadOrGenerateKey(); err != nil { + return nil, fmt.Errorf("failed to initialize session ticket key: %w", err) + } + } else { + // 没有指定密钥文件,生成内存中的密钥 + key, err := generateTicketKey() + if err != nil { + return nil, fmt.Errorf("failed to generate session ticket key: %w", err) + } + manager.keys = append(manager.keys, key) + } + + return manager, nil +} + +// Start 启动密钥轮换定时器。 +// +// 按照配置的 rotate_interval 定期生成新密钥。 +// 必须在调用 GetKeys 之前启动。 +func (m *SessionTicketManager) Start() { + m.mu.Lock() + if m.started { + m.mu.Unlock() + return + } + m.started = true + m.mu.Unlock() + + // 启动轮换定时器 + m.scheduleRotation() +} + +// Stop 停止密钥轮换定时器。 +// +// 停止后不再进行自动密钥轮换,但现有密钥仍然有效。 +func (m *SessionTicketManager) Stop() { + m.mu.Lock() + if !m.started { + m.mu.Unlock() + return + } + m.started = false + m.mu.Unlock() + + close(m.stopCh) + + if m.rotateTimer != nil { + m.rotateTimer.Stop() + } +} + +// GetKeys 返回当前所有有效的 Session Ticket 密钥。 +// +// 返回的密钥按时间顺序排列,最新的在最后。 +// TLS 配置应该使用最新的密钥加密,所有密钥都可以解密。 +// +// 返回值: +// - [][]byte: 密钥列表,每个密钥 32 字节 +func (m *SessionTicketManager) GetKeys() [][]byte { + m.mu.RLock() + defer m.mu.RUnlock() + + // 返回副本以防止外部修改 + result := make([][]byte, len(m.keys)) + for i, key := range m.keys { + result[i] = make([]byte, len(key)) + copy(result[i], key) + } + return result +} + +// RotateKey 手动轮换 Session Ticket 密钥。 +// +// 生成新密钥并添加到密钥列表,如果超过 retain_keys 数量则移除最旧的密钥。 +// 新密钥用于加密新票据,旧密钥仍可用于解密已发放的票据。 +// +// 返回值: +// - error: 密钥生成失败时返回错误 +func (m *SessionTicketManager) RotateKey() error { + m.mu.Lock() + defer m.mu.Unlock() + + // 生成新密钥 + newKey, err := generateTicketKey() + if err != nil { + return fmt.Errorf("failed to generate new ticket key: %w", err) + } + + // 添加新密钥 + m.keys = append(m.keys, newKey) + + // 如果超过保留数量,移除最旧的密钥 + if len(m.keys) > m.config.RetainKeys { + m.keys = m.keys[len(m.keys)-m.config.RetainKeys:] + } + + // 如果有密钥文件,保存所有密钥 + if m.config.KeyFile != "" { + if err := m.saveKeys(); err != nil { + // 保存失败不影响运行,记录错误即可 + // 这里可以考虑添加日志 + _ = err + } + } + + return nil +} + +// ApplyToTLSConfig 将 Session Ticket 密钥应用到 TLS 配置。 +// +// 设置 tls.Config 的 SetSessionTicketKeys 回调,用于动态提供密钥。 +// +// 参数: +// - tlsCfg: TLS 配置对象 +func (m *SessionTicketManager) ApplyToTLSConfig(tlsCfg *tls.Config) { + if tlsCfg == nil { + return + } + + // 设置会话票据密钥 + // Go 的 crypto/tls 使用 SetSessionTicketKeys 方法设置密钥 + // 需要转换为 [][32]byte 类型 + keys := m.GetKeys() + ticketKeys := make([][32]byte, len(keys)) + for i, key := range keys { + if len(key) >= 32 { + copy(ticketKeys[i][:], key) + } + } + tlsCfg.SetSessionTicketKeys(ticketKeys) +} + +// scheduleRotation 调度密钥轮换。 +// +// 使用定时器在指定间隔后执行密钥轮换。 +func (m *SessionTicketManager) scheduleRotation() { + if !m.started { + return + } + + m.rotateTimer = time.AfterFunc(m.config.RotateInterval, func() { + select { + case <-m.stopCh: + return + default: + _ = m.RotateKey() + m.scheduleRotation() + } + }) +} + +// loadOrGenerateKey 从文件加载密钥或生成新密钥。 +// +// 如果密钥文件存在,加载所有密钥;否则生成新密钥并保存。 +// +// 返回值: +// - error: 加载或生成失败时返回错误 +func (m *SessionTicketManager) loadOrGenerateKey() error { + // 尝试加载现有密钥 + if _, err := os.Stat(m.config.KeyFile); err == nil { + // 文件存在,加载密钥 + if err := m.loadKeys(); err != nil { + // 加载失败,生成新密钥 + return m.generateAndSaveKey() + } + return nil + } + + // 文件不存在,生成新密钥 + return m.generateAndSaveKey() +} + +// loadKeys 从文件加载所有密钥。 +// +// 密钥文件格式:每个密钥 32 字节,连续存储 +// +// 返回值: +// - error: 读取或解析失败时返回错误 +func (m *SessionTicketManager) loadKeys() error { + data, err := os.ReadFile(m.config.KeyFile) + if err != nil { + return fmt.Errorf("failed to read key file: %w", err) + } + + // 解析密钥(每个 32 字节) + if len(data) < ticketKeySize { + return errors.New("key file too small") + } + + m.keys = make([][]byte, 0, m.config.RetainKeys) + for i := 0; i+ticketKeySize <= len(data); i += ticketKeySize { + key := make([]byte, ticketKeySize) + copy(key, data[i:i+ticketKeySize]) + m.keys = append(m.keys, key) + } + + // 如果加载的密钥超过保留数量,只保留最新的 + if len(m.keys) > m.config.RetainKeys { + m.keys = m.keys[len(m.keys)-m.config.RetainKeys:] + } + + // 确保至少有一个密钥 + if len(m.keys) == 0 { + return errors.New("no valid keys found in file") + } + + return nil +} + +// saveKeys 将所有密钥保存到文件。 +// +// 密钥文件格式:每个密钥 32 字节,连续存储 +// 文件权限设置为 0600(仅所有者可读写) +// +// 返回值: +// - error: 写入失败时返回错误 +func (m *SessionTicketManager) saveKeys() error { + // 计算总大小 + totalSize := len(m.keys) * ticketKeySize + data := make([]byte, 0, totalSize) + + for _, key := range m.keys { + data = append(data, key...) + } + + // 使用 0600 权限写入文件(敏感数据,限制访问) + if err := os.WriteFile(m.config.KeyFile, data, 0600); err != nil { + return fmt.Errorf("failed to write key file: %w", err) + } + + return nil +} + +// generateAndSaveKey 生成新密钥并保存。 +// +// 返回值: +// - error: 生成或保存失败时返回错误 +func (m *SessionTicketManager) generateAndSaveKey() error { + key, err := generateTicketKey() + if err != nil { + return err + } + + m.keys = [][]byte{key} + + if m.config.KeyFile != "" { + if err := m.saveKeys(); err != nil { + return err + } + } + + return nil +} + +// generateTicketKey 生成新的随机 Session Ticket 密钥。 +// +// 使用 crypto/rand 生成加密安全的随机密钥。 +// +// 返回值: +// - []byte: 32 字节的随机密钥 +// - error: 随机数生成失败时返回错误 +func generateTicketKey() ([]byte, error) { + key := make([]byte, ticketKeySize) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("failed to generate random key: %w", err) + } + return key, nil +} + +// GetKeyStatus 返回当前密钥状态信息。 +// +// 用于监控和调试,显示当前密钥数量和轮换状态。 +// +// 返回值: +// - SessionTicketStatus: 密钥状态信息 +type SessionTicketStatus struct { + // KeyCount 当前密钥数量 + KeyCount int + + // RetainKeys 配置的最大保留密钥数 + RetainKeys int + + // RotateInterval 配置的轮换间隔 + RotateInterval time.Duration + + // Started 管理器是否已启动 + Started bool +} + +// GetStatus 返回当前密钥状态。 +// +// 返回值: +// - SessionTicketStatus: 密钥状态信息 +func (m *SessionTicketManager) GetStatus() SessionTicketStatus { + m.mu.RLock() + defer m.mu.RUnlock() + + return SessionTicketStatus{ + KeyCount: len(m.keys), + RetainKeys: m.config.RetainKeys, + RotateInterval: m.config.RotateInterval, + Started: m.started, + } +} diff --git a/internal/ssl/session_tickets_test.go b/internal/ssl/session_tickets_test.go new file mode 100644 index 0000000..a90bc8e --- /dev/null +++ b/internal/ssl/session_tickets_test.go @@ -0,0 +1,474 @@ +// Package ssl 提供 Session Tickets 的单元测试。 +// +// 测试覆盖: +// - 密钥生成和加载 +// - 密钥轮换逻辑 +// - 多密钥保留策略 +// - 与 TLS 配置的集成 +// - 边界条件和错误处理 +// +// 作者:xfy +package ssl + +import ( + "crypto/tls" + "os" + "path/filepath" + "testing" + "time" + + "rua.plus/lolly/internal/config" +) + +// TestNewSessionTicketManager 测试创建 Session Ticket 管理器。 +func TestNewSessionTicketManager(t *testing.T) { + tests := []struct { + name string + cfg config.SessionTicketsConfig + wantError bool + checkDefaults bool + }{ + { + name: "disabled_should_error", + cfg: config.SessionTicketsConfig{ + Enabled: false, + }, + wantError: true, + }, + { + name: "enabled_without_keyfile", + cfg: config.SessionTicketsConfig{ + Enabled: true, + }, + wantError: false, + checkDefaults: true, + }, + { + name: "enabled_with_defaults", + cfg: config.SessionTicketsConfig{ + Enabled: true, + KeyFile: "", + RotateInterval: 0, + RetainKeys: 0, + }, + wantError: false, + checkDefaults: true, + }, + { + name: "enabled_with_custom_values", + cfg: config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: 30 * time.Minute, + RetainKeys: 5, + }, + wantError: false, + checkDefaults: false, // 使用自定义值,不检查默认值 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mgr, err := NewSessionTicketManager(tt.cfg) + if tt.wantError { + if err == nil { + t.Errorf("NewSessionTicketManager() expected error, got nil") + } + return + } + if err != nil { + t.Errorf("NewSessionTicketManager() unexpected error: %v", err) + return + } + if mgr == nil { + t.Error("NewSessionTicketManager() returned nil manager") + return + } + defer mgr.Stop() + + // 验证默认配置(仅当使用默认值时) + if tt.checkDefaults { + if mgr.config.RotateInterval != defaultRotateInterval { + t.Errorf("RotateInterval = %v, want %v", mgr.config.RotateInterval, defaultRotateInterval) + } + if mgr.config.RetainKeys != defaultRetainKeys { + t.Errorf("RetainKeys = %d, want %d", mgr.config.RetainKeys, defaultRetainKeys) + } + } else { + // 验证自定义值被正确保留 + if mgr.config.RotateInterval != tt.cfg.RotateInterval { + t.Errorf("RotateInterval = %v, want %v", mgr.config.RotateInterval, tt.cfg.RotateInterval) + } + if mgr.config.RetainKeys != tt.cfg.RetainKeys { + t.Errorf("RetainKeys = %d, want %d", mgr.config.RetainKeys, tt.cfg.RetainKeys) + } + } + }) + } +} + +// TestSessionTicketManager_KeyGeneration 测试密钥生成。 +func TestSessionTicketManager_KeyGeneration(t *testing.T) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + keys := mgr.GetKeys() + if len(keys) == 0 { + t.Fatal("Expected at least one key, got none") + } + + // 验证密钥大小 + for i, key := range keys { + if len(key) != ticketKeySize { + t.Errorf("Key %d size = %d, want %d", i, len(key), ticketKeySize) + } + } +} + +// TestSessionTicketManager_KeyRotation 测试密钥轮换。 +func TestSessionTicketManager_KeyRotation(t *testing.T) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: time.Hour, // 使用长间隔,手动触发轮换 + RetainKeys: 3, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + initialKeys := mgr.GetKeys() + if len(initialKeys) != 1 { + t.Fatalf("Expected 1 initial key, got %d", len(initialKeys)) + } + + // 手动轮换密钥 + if err := mgr.RotateKey(); err != nil { + t.Fatalf("RotateKey() failed: %v", err) + } + + keysAfter1 := mgr.GetKeys() + if len(keysAfter1) != 2 { + t.Errorf("Expected 2 keys after rotation, got %d", len(keysAfter1)) + } + + // 验证新旧密钥不同 + if string(initialKeys[0]) == string(keysAfter1[1]) { + t.Error("New key should be different from initial key") + } + + // 继续轮换到超过保留数量 + _ = mgr.RotateKey() + _ = mgr.RotateKey() + + keysAfter4 := mgr.GetKeys() + if len(keysAfter4) != 3 { + t.Errorf("Expected 3 keys (max retain), got %d", len(keysAfter4)) + } +} + +// TestSessionTicketManager_KeyRetention 测试密钥保留策略。 +func TestSessionTicketManager_KeyRetention(t *testing.T) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: time.Hour, + RetainKeys: 2, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + // 生成多个密钥 + for i := 0; i < 5; i++ { + if err := mgr.RotateKey(); err != nil { + t.Fatalf("RotateKey() failed at iteration %d: %v", i, err) + } + } + + keys := mgr.GetKeys() + if len(keys) != 2 { + t.Errorf("Expected 2 keys (RetainKeys limit), got %d", len(keys)) + } +} + +// TestSessionTicketManager_Persistence 测试密钥持久化。 +func TestSessionTicketManager_Persistence(t *testing.T) { + tempDir := t.TempDir() + keyFile := filepath.Join(tempDir, "ticket.key") + + // 创建管理器并生成密钥 + mgr1, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + KeyFile: keyFile, + RotateInterval: time.Hour, + RetainKeys: 3, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + // 轮换几次生成多个密钥 + _ = mgr1.RotateKey() + _ = mgr1.RotateKey() + mgr1.Stop() + + // 验证密钥文件存在 + if _, err := os.Stat(keyFile); os.IsNotExist(err) { + t.Fatal("Key file should exist after saving") + } + + // 从文件加载密钥创建新管理器 + mgr2, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + KeyFile: keyFile, + RotateInterval: time.Hour, + RetainKeys: 3, + }) + if err != nil { + t.Fatalf("Failed to create manager from existing key file: %v", err) + } + defer mgr2.Stop() + + keys := mgr2.GetKeys() + if len(keys) != 3 { + t.Errorf("Expected 3 keys loaded from file, got %d", len(keys)) + } +} + +// TestSessionTicketManager_ApplyToTLSConfig 测试应用到 TLS 配置。 +func TestSessionTicketManager_ApplyToTLSConfig(t *testing.T) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: time.Hour, + RetainKeys: 3, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + } + + mgr.ApplyToTLSConfig(tlsCfg) + + // 验证可以获取密钥 + keys := mgr.GetKeys() + if len(keys) == 0 { + t.Error("Expected keys to be set in TLS config") + } +} + +// TestSessionTicketManager_StartStop 测试启动和停止。 +func TestSessionTicketManager_StartStop(t *testing.T) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: 100 * time.Millisecond, + RetainKeys: 3, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + // 验证初始状态 + status := mgr.GetStatus() + if status.Started { + t.Error("Manager should not be started initially") + } + + // 启动 + mgr.Start() + status = mgr.GetStatus() + if !status.Started { + t.Error("Manager should be started after Start()") + } + + // 等待一次轮换 + time.Sleep(150 * time.Millisecond) + + keys := mgr.GetKeys() + if len(keys) < 1 { + t.Error("Expected at least 1 key after auto-rotation") + } + + // 停止 + mgr.Stop() + status = mgr.GetStatus() + if status.Started { + t.Error("Manager should not be started after Stop()") + } +} + +// TestSessionTicketManager_GetStatus 测试获取状态。 +func TestSessionTicketManager_GetStatus(t *testing.T) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: 30 * time.Minute, + RetainKeys: 5, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + status := mgr.GetStatus() + + if status.KeyCount != 1 { + t.Errorf("KeyCount = %d, want 1", status.KeyCount) + } + // 使用自定义值 5,不是默认值 + if status.RetainKeys != 5 { + t.Errorf("RetainKeys = %d, want 5", status.RetainKeys) + } + // RotateInterval 使用配置值(30m > 0,所以保留) + if status.RotateInterval != 30*time.Minute { + t.Errorf("RotateInterval = %v, want %v", status.RotateInterval, 30*time.Minute) + } + if status.Started { + t.Error("Started should be false before Start()") + } + + mgr.Start() + status = mgr.GetStatus() + if !status.Started { + t.Error("Started should be true after Start()") + } +} + +// TestGenerateTicketKey 测试密钥生成函数。 +func TestGenerateTicketKey(t *testing.T) { + key1, err := generateTicketKey() + if err != nil { + t.Fatalf("generateTicketKey() failed: %v", err) + } + + if len(key1) != ticketKeySize { + t.Errorf("generateTicketKey() key size = %d, want %d", len(key1), ticketKeySize) + } + + key2, err := generateTicketKey() + if err != nil { + t.Fatalf("generateTicketKey() second call failed: %v", err) + } + + // 验证生成的密钥是随机的(不相同) + if string(key1) == string(key2) { + t.Error("generateTicketKey() should generate random keys") + } +} + +// TestSessionTicketManager_ConcurrentAccess 测试并发访问。 +func TestSessionTicketManager_ConcurrentAccess(t *testing.T) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: 10 * time.Millisecond, + RetainKeys: 3, + }) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + mgr.Start() + + // 并发读取和轮换 + done := make(chan bool, 3) + + // 协程 1: 持续获取密钥 + go func() { + for i := 0; i < 100; i++ { + _ = mgr.GetKeys() + time.Sleep(time.Millisecond) + } + done <- true + }() + + // 协程 2: 持续获取状态 + go func() { + for i := 0; i < 100; i++ { + _ = mgr.GetStatus() + time.Sleep(time.Millisecond) + } + done <- true + }() + + // 协程 3: 手动轮换 + go func() { + for i := 0; i < 20; i++ { + _ = mgr.RotateKey() + time.Sleep(5 * time.Millisecond) + } + done <- true + }() + + // 等待所有协程完成 + for i := 0; i < 3; i++ { + <-done + } + + // 验证最终状态 + keys := mgr.GetKeys() + if len(keys) < 1 || len(keys) > 3 { + t.Errorf("Final key count %d out of expected range [1, 3]", len(keys)) + } +} + +// BenchmarkGenerateTicketKey 基准测试密钥生成。 +func BenchmarkGenerateTicketKey(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := generateTicketKey() + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkSessionTicketManager_GetKeys 基准测试获取密钥。 +func BenchmarkSessionTicketManager_GetKeys(b *testing.B) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: time.Hour, + RetainKeys: 3, + }) + if err != nil { + b.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + // 预生成多个密钥 + for i := 0; i < 2; i++ { + _ = mgr.RotateKey() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = mgr.GetKeys() + } +} + +// BenchmarkSessionTicketManager_RotateKey 基准测试密钥轮换。 +func BenchmarkSessionTicketManager_RotateKey(b *testing.B) { + mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{ + Enabled: true, + RotateInterval: time.Hour, + RetainKeys: 3, + }) + if err != nil { + b.Fatalf("Failed to create manager: %v", err) + } + defer mgr.Stop() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := mgr.RotateKey() + if err != nil { + b.Fatal(err) + } + } +} diff --git a/internal/ssl/ssl.go b/internal/ssl/ssl.go index 1258830..0a26d34 100644 --- a/internal/ssl/ssl.go +++ b/internal/ssl/ssl.go @@ -62,6 +62,12 @@ type TLSManager struct { // ocspManager OCSP Stapling 管理器 ocspManager *OCSPManager + // sessionTicketMgr Session Ticket 管理器 + sessionTicketMgr *SessionTicketManager + + // clientVerifier 客户端证书验证器 + clientVerifier *ClientVerifier + // certificates 解析后的证书映射,用于 OCSP certificates map[string]*x509.Certificate @@ -132,6 +138,21 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) { issuers: make(map[string]*x509.Certificate), } + // 初始化 Session Tickets(如果启用) + if cfg.SessionTickets.Enabled { + sessionTicketMgr, err := NewSessionTicketManager(cfg.SessionTickets) + if err != nil { + // Session Tickets 初始化失败不阻止 TLS 工作 + // 可以记录日志 + _ = err + } else { + manager.sessionTicketMgr = sessionTicketMgr + // 应用 Session Tickets 到 TLS 配置 + sessionTicketMgr.ApplyToTLSConfig(tlsCfg) + sessionTicketMgr.Start() + } + } + // 初始化 OCSP Stapling(如果启用) if cfg.OCSPStapling { ocspMgr := NewOCSPManager(DefaultOCSPConfig()) @@ -164,6 +185,19 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) { ocspMgr.Start() } + // 初始化客户端证书验证(如果启用) + if cfg.ClientVerify.Enabled { + clientVerifier, err := NewClientVerifier(cfg.ClientVerify) + if err != nil { + // 客户端验证配置失败不阻止 TLS 工作 + // 可以记录日志 + _ = err + } else { + manager.clientVerifier = clientVerifier + clientVerifier.ConfigureTLS(tlsCfg) + } + } + // 设置为默认配置 manager.defaultCfg = tlsCfg @@ -310,11 +344,14 @@ func (m *TLSManager) RemoveCertificate(name string) { m.mu.Unlock() } -// Close 停止 OCSP 管理器并释放资源。 +// Close 停止 OCSP 管理器和 Session Ticket 管理器并释放资源。 func (m *TLSManager) Close() { if m.ocspManager != nil { m.ocspManager.Stop() } + if m.sessionTicketMgr != nil { + m.sessionTicketMgr.Stop() + } } // getConfigForClientWithOCSP 返回启用 OCSP Stapling 的 TLS 配置。