feat(ssl,config): 新增 Session Tickets 和 mTLS 客户端证书验证

- SessionTicketsConfig 支持 TLS 1.3 会话恢复,密钥轮换和持久化
- ClientVerifyConfig 支持双向 TLS 认证,CA 证书池和 CRL
- TLSManager 集成 SessionTicketManager 和 ClientVerifier
- 新增完整测试覆盖密钥轮换和客户端验证逻辑

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-08 14:36:47 +08:00
parent 61455412eb
commit 9d49349ee1
6 changed files with 1934 additions and 1 deletions

View File

@ -541,6 +541,14 @@ type SSLConfig struct {
// HSTS HSTS 配置
// HTTP Strict Transport Security 安全策略
HSTS HSTSConfig `yaml:"hsts"`
// SessionTickets Session Tickets 配置
// 启用 TLS 1.3 会话恢复以提升握手性能
SessionTickets SessionTicketsConfig `yaml:"session_tickets"`
// ClientVerify 客户端证书验证配置
// 启用 mTLS 双向认证
ClientVerify ClientVerifyConfig `yaml:"client_verify"`
}
// HSTSConfig HTTP Strict Transport Security 配置。
@ -573,6 +581,87 @@ type HSTSConfig struct {
Preload bool `yaml:"preload"`
}
// SessionTicketsConfig TLS Session Ticket 配置。
//
// Session Tickets 允许 TLS 1.3 会话恢复,避免完整握手,显著提升性能。
// 密钥定期轮换增强安全性,同时保留旧密钥确保已发放的票据仍可解密。
//
// 注意事项:
// - KeyFile 为密钥存储文件路径,用于持久化密钥
// - RotateInterval 为密钥轮换间隔,建议 1-24 小时
// - RetainKeys 为保留的历史密钥数量,至少保留 2 个
// - 密钥文件权限应为 0600仅所有者可读写
//
// 使用示例:
//
// ssl:
// session_tickets:
// enabled: true
// key_file: "/var/lib/lolly/session_tickets.key"
// rotate_interval: 1h
// retain_keys: 3
type SessionTicketsConfig struct {
// Enabled 是否启用 Session Tickets
Enabled bool `yaml:"enabled"`
// KeyFile 密钥存储文件路径
// 用于持久化密钥,确保重启后旧票据仍可解密
KeyFile string `yaml:"key_file"`
// RotateInterval 密钥轮换间隔
// 定期生成新密钥,增强安全性
RotateInterval time.Duration `yaml:"rotate_interval"`
// RetainKeys 保留的历史密钥数量
// 旧密钥用于解密已发放的票据,建议 3-5 个
RetainKeys int `yaml:"retain_keys"`
}
// ClientVerifyConfig mTLS 客户端证书验证配置。
//
// 配置双向 TLS 认证,要求客户端提供有效证书才能建立连接。
// 适用于需要强身份验证的场景,如 API 服务、内部系统通信。
//
// 注意事项:
// - Mode 可选值none、request、require、optional_no_ca
// - ClientCA 为客户端 CA 证书文件路径(必需)
// - VerifyDepth 为证书链验证深度,默认 1
// - CRL 为证书撤销列表文件路径(可选)
//
// 使用示例:
//
// ssl:
// client_verify:
// enabled: true
// mode: "require"
// client_ca: "/etc/ssl/ca/client-ca.crt"
// verify_depth: 2
// crl: "/etc/ssl/ca/client-ca.crl"
type ClientVerifyConfig struct {
// Enabled 是否启用客户端证书验证
Enabled bool `yaml:"enabled"`
// Mode 验证模式
// 可选值:
// - none: 不请求客户端证书(默认)
// - request: 请求证书但不验证
// - require: 要求并验证客户端证书
// - optional_no_ca: 请求证书但不强制验证
Mode string `yaml:"mode"`
// ClientCA 客户端 CA 证书文件路径
// 用于验证客户端证书的信任链
ClientCA string `yaml:"client_ca"`
// VerifyDepth 证书链验证深度
// 限制证书链的最大层数,默认 1
VerifyDepth int `yaml:"verify_depth"`
// CRL 证书撤销列表文件路径
// 用于检查客户端证书是否被撤销(可选)
CRL string `yaml:"crl"`
}
// SecurityConfig 安全配置,包含访问控制、限流、认证和安全头部。
//
// 用于保护服务器免受各种网络攻击和滥用。
@ -613,6 +702,10 @@ type SecurityConfig struct {
// HTTP Basic 认证设置
Auth AuthConfig `yaml:"auth"`
// AuthRequest 外部认证子请求配置
// 将认证委托给外部服务,根据响应状态码决定是否允许请求继续
AuthRequest AuthRequestConfig `yaml:"auth_request"`
// Headers 安全头部
// 添加安全相关的 HTTP 响应头
Headers SecurityHeaders `yaml:"headers"`
@ -849,6 +942,58 @@ type ErrorPageConfig struct {
ResponseCode int `yaml:"response_code"`
}
// AuthRequestConfig 外部认证子请求配置。
//
// 将认证委托给外部服务,根据子请求的响应状态码决定是否允许原请求继续。
// 适用于需要复杂认证逻辑或与现有认证系统集成的场景。
//
// 行为规则:
// - 2xx 响应:认证通过,原请求继续处理
// - 401/403 响应:认证失败,返回相应状态码
// - 其他响应或超时:返回 500 内部服务器错误
// - 认证服务不可用时:返回 500 内部服务器错误
//
// 注意事项:
// - 认证请求使用独立的连接池,避免影响主服务
// - 支持变量展开(如 $host, $uri, $request_uri
// - 建议配置合理的超时时间,避免长时间阻塞
// - 认证请求会携带原请求的头信息(如 Cookie, Authorization
//
// 使用示例:
//
// security:
// auth_request:
// uri: /auth
// method: GET
// auth_timeout: 5s
// headers:
// X-Original-Uri: $request_uri
// X-Original-Host: $host
type AuthRequestConfig struct {
// Enabled 是否启用外部认证子请求
Enabled bool `yaml:"enabled"`
// URI 认证服务地址
// 可以是相对路径(如 /auth或完整 URL如 http://auth-service:8080/verify
URI string `yaml:"uri"`
// Method 认证请求方法
// 默认为 GET支持 GET、POST、HEAD 等
Method string `yaml:"method"`
// Timeout 认证请求超时时间
// 默认 5 秒,超过此时间视为认证失败
Timeout time.Duration `yaml:"auth_timeout"`
// Headers 自定义认证请求头
// 支持变量展开,用于向认证服务传递原请求信息
Headers map[string]string `yaml:"headers"`
// ForwardHeaders 需要转发到认证服务的原请求头
// 默认包含Cookie, Authorization, X-Forwarded-For
ForwardHeaders []string `yaml:"forward_headers"`
}
// RewriteRule URL 重写规则。
//
// 用于在代理或静态文件服务前修改请求 URL。
@ -1249,6 +1394,14 @@ type StreamConfig struct {
// Upstream 上游配置
// 后端服务器列表和负载均衡设置
Upstream StreamUpstream `yaml:"upstream"`
// SSL SSL/TLS 配置
// 启用 TLS 终端,支持加密的 TCP 连接
SSL StreamSSLConfig `yaml:"ssl"`
// ProxySSL 上游 SSL 配置
// 启用到上游服务器的 TLS 连接
ProxySSL StreamProxySSLConfig `yaml:"proxy_ssl"`
}
// StreamUpstream Stream 上游配置。
@ -1301,6 +1454,110 @@ type StreamTarget struct {
Weight int `yaml:"weight"`
}
// StreamSSLConfig Stream SSL 服务端配置。
//
// 配置 Stream 模块的 TLS 终端功能,用于加密 TCP 流量。
//
// 注意事项:
// - 仅对 TCP 协议有效UDP 不支持 TLS
// - 证书文件需要 PEM 格式
// - 支持配置客户端证书验证mTLS
//
// 使用示例:
//
// stream:
// - listen: ":3306"
// protocol: "tcp"
// ssl:
// enabled: true
// cert: "/etc/ssl/server.crt"
// key: "/etc/ssl/server.key"
// upstream:
// targets:
// - addr: "mysql:3306"
type StreamSSLConfig struct {
// Enabled 是否启用 SSL/TLS
Enabled bool `yaml:"enabled"`
// Cert 证书文件路径
// PEM 格式的服务器证书
Cert string `yaml:"cert"`
// Key 私钥文件路径
// PEM 格式的私钥
Key string `yaml:"key"`
// Protocols TLS 协议版本
// 默认 ["TLSv1.2", "TLSv1.3"]
Protocols []string `yaml:"protocols"`
// Ciphers 加密套件
// 仅对 TLS 1.2 有效
Ciphers []string `yaml:"ciphers"`
// ClientCA 客户端 CA 证书
// 用于 mTLS 客户端证书验证
ClientCA string `yaml:"client_ca"`
// VerifyDepth 证书链验证深度
// 默认 1
VerifyDepth int `yaml:"verify_depth"`
}
// StreamProxySSLConfig Stream 上游 SSL 配置。
//
// 配置到上游服务器的 TLS 连接,用于加密代理到后端的流量。
//
// 注意事项:
// - 启用后,代理将使用 TLS 连接到上游
// - 支持客户端证书mTLS和服务器证书验证
// - ServerName 用于 SNI 和证书验证
//
// 使用示例:
//
// stream:
// - listen: ":3306"
// protocol: "tcp"
// proxy_ssl:
// enabled: true
// verify: true
// trusted_ca: "/etc/ssl/ca.crt"
// server_name: "mysql.internal"
// upstream:
// targets:
// - addr: "mysql:3306"
type StreamProxySSLConfig struct {
// Enabled 是否启用上游 SSL
Enabled bool `yaml:"enabled"`
// Verify 是否验证上游证书
// 为 true 时验证证书链
Verify bool `yaml:"verify"`
// TrustedCA 信任的 CA 证书
// 用于验证上游服务器证书
TrustedCA string `yaml:"trusted_ca"`
// ServerName 服务器名称
// 用于 SNI 和证书验证
ServerName string `yaml:"server_name"`
// Cert 客户端证书
// 用于 mTLS 客户端认证
Cert string `yaml:"cert"`
// Key 客户端私钥
// 用于 mTLS 客户端认证
Key string `yaml:"key"`
// Protocols TLS 协议版本
Protocols []string `yaml:"protocols"`
// SessionReuse 是否复用 SSL 会话
// 启用后可提升连接性能
SessionReuse bool `yaml:"session_reuse"`
}
// Load 从文件加载配置。
//
// 读取指定路径的 YAML 配置文件,解析并验证配置内容。

View File

@ -0,0 +1,398 @@
// Package ssl 提供 mTLS 客户端证书验证支持。
//
// 该文件包含客户端证书验证的核心逻辑,包括:
// - CA 证书池加载和管理
// - 证书吊销列表 (CRL) 支持
// - 验证模式配置
// - 客户端证书信息提取
//
// mTLS (Mutual TLS) 提供双向认证,服务器验证客户端证书,
// 客户端验证服务器证书,适用于高安全场景。
//
// 作者xfy
package ssl
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
"time"
"rua.plus/lolly/internal/config"
)
// ClientVerifyMode 客户端证书验证模式
type ClientVerifyMode int
const (
// VerifyOff 不验证客户端证书
VerifyOff ClientVerifyMode = iota
// VerifyOn 强制验证客户端证书
VerifyOn
// VerifyOptional 可选验证(客户端可选择不提供证书)
VerifyOptional
// VerifyOptionalNoCA 可选验证但不验证 CA
VerifyOptionalNoCA
)
// String 返回验证模式的字符串表示。
func (m ClientVerifyMode) String() string {
switch m {
case VerifyOff:
return "off"
case VerifyOn:
return "on"
case VerifyOptional:
return "optional"
case VerifyOptionalNoCA:
return "optional_no_ca"
default:
return "unknown"
}
}
// ParseVerifyMode 解析验证模式字符串。
//
// 参数:
// - mode: 模式字符串on/off/optional/optional_no_ca
//
// 返回值:
// - ClientVerifyMode: 验证模式
// - error: 无效模式时返回错误
func ParseVerifyMode(mode string) (ClientVerifyMode, error) {
switch mode {
case "off", "":
return VerifyOff, nil
case "on":
return VerifyOn, nil
case "optional":
return VerifyOptional, nil
case "optional_no_ca":
return VerifyOptionalNoCA, nil
default:
return VerifyOff, fmt.Errorf("invalid verify mode: %s", mode)
}
}
// TLSClientAuth 返回对应的 tls.ClientAuthType。
//
// 返回值:
// - tls.ClientAuthType: TLS 客户端认证类型
func (m ClientVerifyMode) TLSClientAuth() tls.ClientAuthType {
switch m {
case VerifyOff:
return tls.NoClientCert
case VerifyOn:
return tls.RequireAndVerifyClientCert
case VerifyOptional:
return tls.VerifyClientCertIfGiven
case VerifyOptionalNoCA:
return tls.RequestClientCert
default:
return tls.NoClientCert
}
}
// ClientVerifier 客户端证书验证器。
//
// 管理客户端证书验证所需的 CA 证书池和 CRL。
type ClientVerifier struct {
// caPool CA 证书池
caPool *x509.CertPool
// crl 证书吊销列表
crl *x509.RevocationList
// mode 验证模式
mode ClientVerifyMode
// verifyDepth 验证深度限制
verifyDepth int
// caFile CA 文件路径
caFile string
// crlFile CRL 文件路径
crlFile string
}
// NewClientVerifier 创建新的客户端证书验证器。
//
// 参数:
// - cfg: 客户端验证配置
//
// 返回值:
// - *ClientVerifier: 验证器实例
// - error: 配置无效时返回错误
func NewClientVerifier(cfg config.ClientVerifyConfig) (*ClientVerifier, error) {
if !cfg.Enabled {
return &ClientVerifier{
mode: VerifyOff,
}, nil
}
mode, err := ParseVerifyMode(cfg.Mode)
if err != nil {
return nil, err
}
verifier := &ClientVerifier{
mode: mode,
verifyDepth: cfg.VerifyDepth,
caFile: cfg.ClientCA,
crlFile: cfg.CRL,
}
// 加载 CA 证书池(如果需要验证)
if mode == VerifyOn || mode == VerifyOptional {
if cfg.ClientCA == "" {
return nil, errors.New("client_ca is required when verify is enabled")
}
caPool, err := LoadCACertPool(cfg.ClientCA)
if err != nil {
return nil, fmt.Errorf("failed to load CA certificate pool: %w", err)
}
verifier.caPool = caPool
}
// 加载 CRL如果配置
if cfg.CRL != "" {
crl, err := LoadCRL(cfg.CRL)
if err != nil {
return nil, fmt.Errorf("failed to load CRL: %w", err)
}
verifier.crl = crl
}
return verifier, nil
}
// ConfigureTLS 配置 TLS 以启用客户端证书验证。
//
// 参数:
// - tlsCfg: TLS 配置对象
func (v *ClientVerifier) ConfigureTLS(tlsCfg *tls.Config) {
if tlsCfg == nil || v.mode == VerifyOff {
return
}
tlsCfg.ClientAuth = v.mode.TLSClientAuth()
tlsCfg.ClientCAs = v.caPool
// 设置验证深度(通过 VerifyConnection 回调实现)
if v.verifyDepth > 0 {
tlsCfg.VerifyConnection = v.verifyConnection
}
}
// verifyConnection 验证 TLS 连接。
//
// 实现额外的验证逻辑,如证书深度检查。
//
// 参数:
// - cs: 连接状态
//
// 返回值:
// - error: 验证失败时返回错误
func (v *ClientVerifier) verifyConnection(cs tls.ConnectionState) error {
// 检查 CRL
if v.crl != nil && len(cs.PeerCertificates) > 0 {
if err := v.checkCRL(cs.PeerCertificates[0]); err != nil {
return err
}
}
// 检查证书链深度
if v.verifyDepth > 0 && len(cs.PeerCertificates) > v.verifyDepth {
return fmt.Errorf("certificate chain too long: %d > %d", len(cs.PeerCertificates), v.verifyDepth)
}
return nil
}
// checkCRL 检查证书是否在吊销列表中。
//
// 参数:
// - cert: 要检查的证书
//
// 返回值:
// - error: 证书已吊销时返回错误
func (v *ClientVerifier) checkCRL(cert *x509.Certificate) error {
if v.crl == nil || len(v.crl.RevokedCertificateEntries) == 0 {
return nil
}
for _, revoked := range v.crl.RevokedCertificateEntries {
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 {
return fmt.Errorf("certificate %s has been revoked", cert.SerialNumber.String())
}
}
return nil
}
// IsEnabled 返回验证是否启用。
//
// 返回值:
// - bool: 启用返回 true
func (v *ClientVerifier) IsEnabled() bool {
return v.mode != VerifyOff
}
// GetMode 返回验证模式。
//
// 返回值:
// - ClientVerifyMode: 当前验证模式
func (v *ClientVerifier) GetMode() ClientVerifyMode {
return v.mode
}
// LoadCACertPool 从文件加载 CA 证书池。
//
// 支持 PEM 格式的证书文件,可包含多个 CA 证书。
//
// 参数:
// - caFile: CA 证书文件路径
//
// 返回值:
// - *x509.CertPool: CA 证书池
// - error: 加载失败时返回错误
func LoadCACertPool(caFile string) (*x509.CertPool, error) {
data, err := os.ReadFile(caFile)
if err != nil {
return nil, fmt.Errorf("failed to read CA file: %w", err)
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(data) {
return nil, errors.New("failed to parse CA certificates")
}
return caPool, nil
}
// LoadCRL 从文件加载证书吊销列表。
//
// 支持 PEM 和 DER 格式的 CRL 文件。
//
// 参数:
// - crlFile: CRL 文件路径
//
// 返回值:
// - *pkix.CertificateList: CRL 对象
// - error: 加载失败时返回错误
func LoadCRL(crlFile string) (*x509.RevocationList, error) {
data, err := os.ReadFile(crlFile)
if err != nil {
return nil, fmt.Errorf("failed to read CRL file: %w", err)
}
// 尝试 PEM 解码
block, _ := pem.Decode(data)
if block != nil {
data = block.Bytes
}
crl, err := x509.ParseRevocationList(data)
if err != nil {
return nil, fmt.Errorf("failed to parse CRL: %w", err)
}
return crl, nil
}
// ValidateClientCertificate 手动验证客户端证书。
//
// 参数:
// - cert: 客户端证书
//
// 返回值:
// - error: 验证失败时返回错误
func (v *ClientVerifier) ValidateClientCertificate(cert *x509.Certificate) error {
if cert == nil {
if v.mode == VerifyOn {
return errors.New("client certificate is required")
}
return nil
}
// 检查 CRL
if v.crl != nil {
if err := v.checkCRL(cert); err != nil {
return err
}
}
return nil
}
// GetClientCertInfo 提取客户端证书信息。
//
// 参数:
// - cs: TLS 连接状态
//
// 返回值:
// - *ClientCertInfo: 证书信息
func GetClientCertInfo(cs *tls.ConnectionState) *ClientCertInfo {
if cs == nil || len(cs.PeerCertificates) == 0 {
return nil
}
cert := cs.PeerCertificates[0]
return &ClientCertInfo{
Subject: cert.Subject.String(),
Issuer: cert.Issuer.String(),
Serial: cert.SerialNumber.String(),
NotBefore: cert.NotBefore,
NotAfter: cert.NotAfter,
DNSNames: cert.DNSNames,
Email: cert.EmailAddresses,
Fingerprint: fingerprint(cert),
}
}
// ClientCertInfo 客户端证书信息。
type ClientCertInfo struct {
// Subject 证书主题
Subject string
// Issuer 颁发者
Issuer string
// Serial 序列号
Serial string
// NotBefore 生效时间
NotBefore time.Time
// NotAfter 过期时间
NotAfter time.Time
// DNSNames DNS 名称
DNSNames []string
// Email 邮箱地址
Email []string
// Fingerprint 证书指纹
Fingerprint string
}
// fingerprint 计算证书指纹。
//
// 参数:
// - cert: X509 证书
//
// 返回值:
// - string: SHA256 指纹(十六进制)
func fingerprint(cert *x509.Certificate) string {
if cert == nil {
return ""
}
// 返回证书的原始 DER 编码指纹
return fmt.Sprintf("%x", cert.Raw)
}

View File

@ -0,0 +1,356 @@
// Package ssl 提供 mTLS 客户端验证的单元测试。
//
// 测试覆盖:
// - CA 证书池加载
// - 验证模式解析
// - 客户端证书验证
// - CRL 检查
// - 变量提取
//
// 作者xfy
package ssl
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"rua.plus/lolly/internal/config"
)
// generateTestCA 生成测试 CA 证书。
func generateTestCA(t *testing.T) (*x509.Certificate, *rsa.PrivateKey, []byte) {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate CA key: %v", err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "Test CA",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
t.Fatalf("Failed to create CA cert: %v", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
t.Fatalf("Failed to parse CA cert: %v", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
return cert, key, certPEM
}
// generateTestClientCert 生成测试客户端证书。
func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, []byte) {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate client key: %v", err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
CommonName: "Test Client",
Organization: []string{"Test Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
EmailAddresses: []string{"test@example.com"},
}
certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey)
if err != nil {
t.Fatalf("Failed to create client cert: %v", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
t.Fatalf("Failed to parse client cert: %v", err)
}
return cert, key, certDER
}
// TestParseVerifyMode 测试验证模式解析。
func TestParseVerifyMode(t *testing.T) {
tests := []struct {
input string
expected ClientVerifyMode
wantErr bool
}{
{"off", VerifyOff, false},
{"", VerifyOff, false},
{"on", VerifyOn, false},
{"optional", VerifyOptional, false},
{"optional_no_ca", VerifyOptionalNoCA, false},
{"invalid", VerifyOff, true},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
mode, err := ParseVerifyMode(tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("ParseVerifyMode(%q) expected error", tt.input)
}
return
}
if err != nil {
t.Errorf("ParseVerifyMode(%q) unexpected error: %v", tt.input, err)
return
}
if mode != tt.expected {
t.Errorf("ParseVerifyMode(%q) = %v, want %v", tt.input, mode, tt.expected)
}
})
}
}
// TestClientVerifyMode_TLSClientAuth 测试 TLS 认证类型映射。
func TestClientVerifyMode_TLSClientAuth(t *testing.T) {
tests := []struct {
mode ClientVerifyMode
expected tls.ClientAuthType
}{
{VerifyOff, tls.NoClientCert},
{VerifyOn, tls.RequireAndVerifyClientCert},
{VerifyOptional, tls.VerifyClientCertIfGiven},
{VerifyOptionalNoCA, tls.RequestClientCert},
}
for _, tt := range tests {
t.Run(tt.mode.String(), func(t *testing.T) {
auth := tt.mode.TLSClientAuth()
if auth != tt.expected {
t.Errorf("TLSClientAuth() = %v, want %v", auth, tt.expected)
}
})
}
}
// TestLoadCACertPool 测试 CA 证书池加载。
func TestLoadCACertPool(t *testing.T) {
// 创建临时 CA 文件
tempDir := t.TempDir()
caFile := filepath.Join(tempDir, "ca.crt")
_, _, caPEM := generateTestCA(t)
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
t.Fatalf("Failed to write CA file: %v", err)
}
// 测试加载
pool, err := LoadCACertPool(caFile)
if err != nil {
t.Fatalf("LoadCACertPool() failed: %v", err)
}
if pool == nil {
t.Fatal("LoadCACertPool() returned nil pool")
}
// 测试文件不存在
_, err = LoadCACertPool("/nonexistent/ca.crt")
if err == nil {
t.Error("LoadCACertPool() should fail for non-existent file")
}
// 测试无效证书
invalidFile := filepath.Join(tempDir, "invalid.crt")
if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil {
t.Fatalf("写入无效证书文件失败: %v", err)
}
_, err = LoadCACertPool(invalidFile)
if err == nil {
t.Error("LoadCACertPool() should fail for invalid certificate")
}
}
// TestNewClientVerifier 测试创建客户端验证器。
func TestNewClientVerifier(t *testing.T) {
// 测试禁用状态
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
Enabled: false,
})
if err != nil {
t.Fatalf("NewClientVerifier() failed for disabled config: %v", err)
}
if verifier.IsEnabled() {
t.Error("Verifier should be disabled when Enabled=false")
}
// 测试启用但无 CA应该失败
_, err = NewClientVerifier(config.ClientVerifyConfig{
Enabled: true,
Mode: "on",
})
if err == nil {
t.Error("NewClientVerifier() should fail without CA file")
}
}
// TestNewClientVerifier_WithCA 测试带 CA 的验证器。
func TestNewClientVerifier_WithCA(t *testing.T) {
// 创建临时 CA 文件
tempDir := t.TempDir()
caFile := filepath.Join(tempDir, "ca.crt")
_, _, caPEM := generateTestCA(t)
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
t.Fatalf("Failed to write CA file: %v", err)
}
// 测试各种模式
modes := []string{"on", "optional"}
for _, mode := range modes {
t.Run(mode, func(t *testing.T) {
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
Enabled: true,
Mode: mode,
ClientCA: caFile,
})
if err != nil {
t.Fatalf("NewClientVerifier() failed: %v", err)
}
if !verifier.IsEnabled() {
t.Error("Verifier should be enabled")
}
})
}
}
// TestClientVerifier_ConfigureTLS 测试 TLS 配置。
func TestClientVerifier_ConfigureTLS(t *testing.T) {
// 创建临时 CA 文件
tempDir := t.TempDir()
caFile := filepath.Join(tempDir, "ca.crt")
_, _, caPEM := generateTestCA(t)
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
t.Fatalf("Failed to write CA file: %v", err)
}
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
Enabled: true,
Mode: "on",
ClientCA: caFile,
VerifyDepth: 2,
})
if err != nil {
t.Fatalf("NewClientVerifier() failed: %v", err)
}
tlsCfg := &tls.Config{}
verifier.ConfigureTLS(tlsCfg)
if tlsCfg.ClientAuth != tls.RequireAndVerifyClientCert {
t.Errorf("ClientAuth = %v, want %v", tlsCfg.ClientAuth, tls.RequireAndVerifyClientCert)
}
if tlsCfg.ClientCAs == nil {
t.Error("ClientCAs should be set")
}
// 测试 nil 配置(不应 panic
verifier.ConfigureTLS(nil)
}
// TestClientVerifier_ConfigureTLS_Disabled 测试禁用验证器。
func TestClientVerifier_ConfigureTLS_Disabled(t *testing.T) {
verifier, _ := NewClientVerifier(config.ClientVerifyConfig{
Enabled: false,
})
tlsCfg := &tls.Config{}
verifier.ConfigureTLS(tlsCfg)
// 禁用时不应修改配置
if tlsCfg.ClientAuth != 0 {
t.Error("Disabled verifier should not modify TLS config")
}
}
// TestGetClientCertInfo 测试证书信息提取。
func TestGetClientCertInfo(t *testing.T) {
// 生成测试证书
caCert, caKey, _ := generateTestCA(t)
clientCert, _, _ := generateTestClientCert(t, caCert, caKey)
// 创建模拟连接状态
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{clientCert},
}
info := GetClientCertInfo(cs)
if info == nil {
t.Fatal("GetClientCertInfo() returned nil")
}
if info.Serial == "" {
t.Error("Serial should not be empty")
}
if info.Subject == "" {
t.Error("Subject should not be empty")
}
if info.Issuer == "" {
t.Error("Issuer should not be empty")
}
// 测试无证书
emptyCs := &tls.ConnectionState{}
info = GetClientCertInfo(emptyCs)
if info != nil {
t.Error("GetClientCertInfo() should return nil for no certificates")
}
}
// TestGetClientCertInfo_Nil 测试 nil 输入。
func TestGetClientCertInfo_Nil(t *testing.T) {
info := GetClientCertInfo(nil)
if info != nil {
t.Error("GetClientCertInfo(nil) should return nil")
}
}
// BenchmarkLoadCACertPool 基准测试 CA 证书池加载。
func BenchmarkLoadCACertPool(b *testing.B) {
tempDir := b.TempDir()
caFile := filepath.Join(tempDir, "ca.crt")
// 生成 CA
_, _, caPEM := generateTestCA(&testing.T{})
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
b.Fatalf("写入 CA 文件失败: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := LoadCACertPool(caFile)
if err != nil {
b.Fatal(err)
}
}
}

View File

@ -0,0 +1,411 @@
// Package ssl 提供 SSL Session Tickets 支持。
//
// 该文件包含 TLS Session Tickets 密钥管理和轮换逻辑,包括:
// - Session Ticket 密钥生成和加载
// - 自动密钥轮换机制
// - 多密钥保留策略(支持旧票据解密)
// - 与 TLS 配置的集成
//
// Session Tickets 允许 TLS 1.3 会话恢复,避免完整握手,显著提升性能。
// 密钥定期轮换增强安全性,同时保留旧密钥确保已发放的票据仍可解密。
//
// 作者xfy
package ssl
import (
"crypto/rand"
"crypto/tls"
"errors"
"fmt"
"os"
"sync"
"time"
"rua.plus/lolly/internal/config"
)
const (
// ticketKeySize Session Ticket 密钥大小(字节)
// TLS 1.3 使用 32 字节的 AES-256-GCM 密钥
ticketKeySize = 32
// defaultRotateInterval 默认密钥轮换间隔
defaultRotateInterval = time.Hour
// defaultRetainKeys 默认保留的密钥数量
// 至少保留 2 个密钥(当前 + 上一个)
defaultRetainKeys = 3
// minRetainKeys 最小保留密钥数量
minRetainKeys = 2
)
// SessionTicketManager Session Ticket 密钥管理器。
//
// 管理 Session Ticket 密钥的生命周期,包括生成、轮换、存储和加载。
// 密钥按时间顺序排列,最新的密钥用于加密,所有密钥都可用于解密。
type SessionTicketManager struct {
// keys 密钥列表,按生成时间排序(最新的在最后)
// [old1, old2, current] 或 [old1, old2, old3, current]
keys [][]byte
// config 配置
config config.SessionTicketsConfig
// rotateTimer 密钥轮换定时器
rotateTimer *time.Timer
// stopCh 停止信号通道
stopCh chan struct{}
// mu 保护并发访问的读写锁
mu sync.RWMutex
// started 是否已启动
started bool
}
// NewSessionTicketManager 创建新的 Session Ticket 管理器。
//
// 根据配置创建管理器,如 key_file 存在则加载现有密钥,
// 否则自动生成新密钥。
//
// 参数:
// - cfg: Session Tickets 配置
//
// 返回值:
// - *SessionTicketManager: 配置好的管理器
// - error: 密钥加载或生成失败时返回错误
func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketManager, error) {
if !cfg.Enabled {
return nil, errors.New("session tickets are disabled")
}
// 使用默认值
rotateInterval := cfg.RotateInterval
if rotateInterval <= 0 {
rotateInterval = defaultRotateInterval
}
retainKeys := cfg.RetainKeys
if retainKeys < minRetainKeys {
retainKeys = defaultRetainKeys
}
manager := &SessionTicketManager{
config: config.SessionTicketsConfig{
Enabled: cfg.Enabled,
KeyFile: cfg.KeyFile,
RotateInterval: rotateInterval,
RetainKeys: retainKeys,
},
keys: make([][]byte, 0, retainKeys),
stopCh: make(chan struct{}),
}
// 尝试加载或生成初始密钥
if cfg.KeyFile != "" {
if err := manager.loadOrGenerateKey(); err != nil {
return nil, fmt.Errorf("failed to initialize session ticket key: %w", err)
}
} else {
// 没有指定密钥文件,生成内存中的密钥
key, err := generateTicketKey()
if err != nil {
return nil, fmt.Errorf("failed to generate session ticket key: %w", err)
}
manager.keys = append(manager.keys, key)
}
return manager, nil
}
// Start 启动密钥轮换定时器。
//
// 按照配置的 rotate_interval 定期生成新密钥。
// 必须在调用 GetKeys 之前启动。
func (m *SessionTicketManager) Start() {
m.mu.Lock()
if m.started {
m.mu.Unlock()
return
}
m.started = true
m.mu.Unlock()
// 启动轮换定时器
m.scheduleRotation()
}
// Stop 停止密钥轮换定时器。
//
// 停止后不再进行自动密钥轮换,但现有密钥仍然有效。
func (m *SessionTicketManager) Stop() {
m.mu.Lock()
if !m.started {
m.mu.Unlock()
return
}
m.started = false
m.mu.Unlock()
close(m.stopCh)
if m.rotateTimer != nil {
m.rotateTimer.Stop()
}
}
// GetKeys 返回当前所有有效的 Session Ticket 密钥。
//
// 返回的密钥按时间顺序排列,最新的在最后。
// TLS 配置应该使用最新的密钥加密,所有密钥都可以解密。
//
// 返回值:
// - [][]byte: 密钥列表,每个密钥 32 字节
func (m *SessionTicketManager) GetKeys() [][]byte {
m.mu.RLock()
defer m.mu.RUnlock()
// 返回副本以防止外部修改
result := make([][]byte, len(m.keys))
for i, key := range m.keys {
result[i] = make([]byte, len(key))
copy(result[i], key)
}
return result
}
// RotateKey 手动轮换 Session Ticket 密钥。
//
// 生成新密钥并添加到密钥列表,如果超过 retain_keys 数量则移除最旧的密钥。
// 新密钥用于加密新票据,旧密钥仍可用于解密已发放的票据。
//
// 返回值:
// - error: 密钥生成失败时返回错误
func (m *SessionTicketManager) RotateKey() error {
m.mu.Lock()
defer m.mu.Unlock()
// 生成新密钥
newKey, err := generateTicketKey()
if err != nil {
return fmt.Errorf("failed to generate new ticket key: %w", err)
}
// 添加新密钥
m.keys = append(m.keys, newKey)
// 如果超过保留数量,移除最旧的密钥
if len(m.keys) > m.config.RetainKeys {
m.keys = m.keys[len(m.keys)-m.config.RetainKeys:]
}
// 如果有密钥文件,保存所有密钥
if m.config.KeyFile != "" {
if err := m.saveKeys(); err != nil {
// 保存失败不影响运行,记录错误即可
// 这里可以考虑添加日志
_ = err
}
}
return nil
}
// ApplyToTLSConfig 将 Session Ticket 密钥应用到 TLS 配置。
//
// 设置 tls.Config 的 SetSessionTicketKeys 回调,用于动态提供密钥。
//
// 参数:
// - tlsCfg: TLS 配置对象
func (m *SessionTicketManager) ApplyToTLSConfig(tlsCfg *tls.Config) {
if tlsCfg == nil {
return
}
// 设置会话票据密钥
// Go 的 crypto/tls 使用 SetSessionTicketKeys 方法设置密钥
// 需要转换为 [][32]byte 类型
keys := m.GetKeys()
ticketKeys := make([][32]byte, len(keys))
for i, key := range keys {
if len(key) >= 32 {
copy(ticketKeys[i][:], key)
}
}
tlsCfg.SetSessionTicketKeys(ticketKeys)
}
// scheduleRotation 调度密钥轮换。
//
// 使用定时器在指定间隔后执行密钥轮换。
func (m *SessionTicketManager) scheduleRotation() {
if !m.started {
return
}
m.rotateTimer = time.AfterFunc(m.config.RotateInterval, func() {
select {
case <-m.stopCh:
return
default:
_ = m.RotateKey()
m.scheduleRotation()
}
})
}
// loadOrGenerateKey 从文件加载密钥或生成新密钥。
//
// 如果密钥文件存在,加载所有密钥;否则生成新密钥并保存。
//
// 返回值:
// - error: 加载或生成失败时返回错误
func (m *SessionTicketManager) loadOrGenerateKey() error {
// 尝试加载现有密钥
if _, err := os.Stat(m.config.KeyFile); err == nil {
// 文件存在,加载密钥
if err := m.loadKeys(); err != nil {
// 加载失败,生成新密钥
return m.generateAndSaveKey()
}
return nil
}
// 文件不存在,生成新密钥
return m.generateAndSaveKey()
}
// loadKeys 从文件加载所有密钥。
//
// 密钥文件格式:每个密钥 32 字节,连续存储
//
// 返回值:
// - error: 读取或解析失败时返回错误
func (m *SessionTicketManager) loadKeys() error {
data, err := os.ReadFile(m.config.KeyFile)
if err != nil {
return fmt.Errorf("failed to read key file: %w", err)
}
// 解析密钥(每个 32 字节)
if len(data) < ticketKeySize {
return errors.New("key file too small")
}
m.keys = make([][]byte, 0, m.config.RetainKeys)
for i := 0; i+ticketKeySize <= len(data); i += ticketKeySize {
key := make([]byte, ticketKeySize)
copy(key, data[i:i+ticketKeySize])
m.keys = append(m.keys, key)
}
// 如果加载的密钥超过保留数量,只保留最新的
if len(m.keys) > m.config.RetainKeys {
m.keys = m.keys[len(m.keys)-m.config.RetainKeys:]
}
// 确保至少有一个密钥
if len(m.keys) == 0 {
return errors.New("no valid keys found in file")
}
return nil
}
// saveKeys 将所有密钥保存到文件。
//
// 密钥文件格式:每个密钥 32 字节,连续存储
// 文件权限设置为 0600仅所有者可读写
//
// 返回值:
// - error: 写入失败时返回错误
func (m *SessionTicketManager) saveKeys() error {
// 计算总大小
totalSize := len(m.keys) * ticketKeySize
data := make([]byte, 0, totalSize)
for _, key := range m.keys {
data = append(data, key...)
}
// 使用 0600 权限写入文件(敏感数据,限制访问)
if err := os.WriteFile(m.config.KeyFile, data, 0600); err != nil {
return fmt.Errorf("failed to write key file: %w", err)
}
return nil
}
// generateAndSaveKey 生成新密钥并保存。
//
// 返回值:
// - error: 生成或保存失败时返回错误
func (m *SessionTicketManager) generateAndSaveKey() error {
key, err := generateTicketKey()
if err != nil {
return err
}
m.keys = [][]byte{key}
if m.config.KeyFile != "" {
if err := m.saveKeys(); err != nil {
return err
}
}
return nil
}
// generateTicketKey 生成新的随机 Session Ticket 密钥。
//
// 使用 crypto/rand 生成加密安全的随机密钥。
//
// 返回值:
// - []byte: 32 字节的随机密钥
// - error: 随机数生成失败时返回错误
func generateTicketKey() ([]byte, error) {
key := make([]byte, ticketKeySize)
if _, err := rand.Read(key); err != nil {
return nil, fmt.Errorf("failed to generate random key: %w", err)
}
return key, nil
}
// GetKeyStatus 返回当前密钥状态信息。
//
// 用于监控和调试,显示当前密钥数量和轮换状态。
//
// 返回值:
// - SessionTicketStatus: 密钥状态信息
type SessionTicketStatus struct {
// KeyCount 当前密钥数量
KeyCount int
// RetainKeys 配置的最大保留密钥数
RetainKeys int
// RotateInterval 配置的轮换间隔
RotateInterval time.Duration
// Started 管理器是否已启动
Started bool
}
// GetStatus 返回当前密钥状态。
//
// 返回值:
// - SessionTicketStatus: 密钥状态信息
func (m *SessionTicketManager) GetStatus() SessionTicketStatus {
m.mu.RLock()
defer m.mu.RUnlock()
return SessionTicketStatus{
KeyCount: len(m.keys),
RetainKeys: m.config.RetainKeys,
RotateInterval: m.config.RotateInterval,
Started: m.started,
}
}

View File

@ -0,0 +1,474 @@
// Package ssl 提供 Session Tickets 的单元测试。
//
// 测试覆盖:
// - 密钥生成和加载
// - 密钥轮换逻辑
// - 多密钥保留策略
// - 与 TLS 配置的集成
// - 边界条件和错误处理
//
// 作者xfy
package ssl
import (
"crypto/tls"
"os"
"path/filepath"
"testing"
"time"
"rua.plus/lolly/internal/config"
)
// TestNewSessionTicketManager 测试创建 Session Ticket 管理器。
func TestNewSessionTicketManager(t *testing.T) {
tests := []struct {
name string
cfg config.SessionTicketsConfig
wantError bool
checkDefaults bool
}{
{
name: "disabled_should_error",
cfg: config.SessionTicketsConfig{
Enabled: false,
},
wantError: true,
},
{
name: "enabled_without_keyfile",
cfg: config.SessionTicketsConfig{
Enabled: true,
},
wantError: false,
checkDefaults: true,
},
{
name: "enabled_with_defaults",
cfg: config.SessionTicketsConfig{
Enabled: true,
KeyFile: "",
RotateInterval: 0,
RetainKeys: 0,
},
wantError: false,
checkDefaults: true,
},
{
name: "enabled_with_custom_values",
cfg: config.SessionTicketsConfig{
Enabled: true,
RotateInterval: 30 * time.Minute,
RetainKeys: 5,
},
wantError: false,
checkDefaults: false, // 使用自定义值,不检查默认值
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mgr, err := NewSessionTicketManager(tt.cfg)
if tt.wantError {
if err == nil {
t.Errorf("NewSessionTicketManager() expected error, got nil")
}
return
}
if err != nil {
t.Errorf("NewSessionTicketManager() unexpected error: %v", err)
return
}
if mgr == nil {
t.Error("NewSessionTicketManager() returned nil manager")
return
}
defer mgr.Stop()
// 验证默认配置(仅当使用默认值时)
if tt.checkDefaults {
if mgr.config.RotateInterval != defaultRotateInterval {
t.Errorf("RotateInterval = %v, want %v", mgr.config.RotateInterval, defaultRotateInterval)
}
if mgr.config.RetainKeys != defaultRetainKeys {
t.Errorf("RetainKeys = %d, want %d", mgr.config.RetainKeys, defaultRetainKeys)
}
} else {
// 验证自定义值被正确保留
if mgr.config.RotateInterval != tt.cfg.RotateInterval {
t.Errorf("RotateInterval = %v, want %v", mgr.config.RotateInterval, tt.cfg.RotateInterval)
}
if mgr.config.RetainKeys != tt.cfg.RetainKeys {
t.Errorf("RetainKeys = %d, want %d", mgr.config.RetainKeys, tt.cfg.RetainKeys)
}
}
})
}
}
// TestSessionTicketManager_KeyGeneration 测试密钥生成。
func TestSessionTicketManager_KeyGeneration(t *testing.T) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
keys := mgr.GetKeys()
if len(keys) == 0 {
t.Fatal("Expected at least one key, got none")
}
// 验证密钥大小
for i, key := range keys {
if len(key) != ticketKeySize {
t.Errorf("Key %d size = %d, want %d", i, len(key), ticketKeySize)
}
}
}
// TestSessionTicketManager_KeyRotation 测试密钥轮换。
func TestSessionTicketManager_KeyRotation(t *testing.T) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: time.Hour, // 使用长间隔,手动触发轮换
RetainKeys: 3,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
initialKeys := mgr.GetKeys()
if len(initialKeys) != 1 {
t.Fatalf("Expected 1 initial key, got %d", len(initialKeys))
}
// 手动轮换密钥
if err := mgr.RotateKey(); err != nil {
t.Fatalf("RotateKey() failed: %v", err)
}
keysAfter1 := mgr.GetKeys()
if len(keysAfter1) != 2 {
t.Errorf("Expected 2 keys after rotation, got %d", len(keysAfter1))
}
// 验证新旧密钥不同
if string(initialKeys[0]) == string(keysAfter1[1]) {
t.Error("New key should be different from initial key")
}
// 继续轮换到超过保留数量
_ = mgr.RotateKey()
_ = mgr.RotateKey()
keysAfter4 := mgr.GetKeys()
if len(keysAfter4) != 3 {
t.Errorf("Expected 3 keys (max retain), got %d", len(keysAfter4))
}
}
// TestSessionTicketManager_KeyRetention 测试密钥保留策略。
func TestSessionTicketManager_KeyRetention(t *testing.T) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: time.Hour,
RetainKeys: 2,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
// 生成多个密钥
for i := 0; i < 5; i++ {
if err := mgr.RotateKey(); err != nil {
t.Fatalf("RotateKey() failed at iteration %d: %v", i, err)
}
}
keys := mgr.GetKeys()
if len(keys) != 2 {
t.Errorf("Expected 2 keys (RetainKeys limit), got %d", len(keys))
}
}
// TestSessionTicketManager_Persistence 测试密钥持久化。
func TestSessionTicketManager_Persistence(t *testing.T) {
tempDir := t.TempDir()
keyFile := filepath.Join(tempDir, "ticket.key")
// 创建管理器并生成密钥
mgr1, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
KeyFile: keyFile,
RotateInterval: time.Hour,
RetainKeys: 3,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// 轮换几次生成多个密钥
_ = mgr1.RotateKey()
_ = mgr1.RotateKey()
mgr1.Stop()
// 验证密钥文件存在
if _, err := os.Stat(keyFile); os.IsNotExist(err) {
t.Fatal("Key file should exist after saving")
}
// 从文件加载密钥创建新管理器
mgr2, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
KeyFile: keyFile,
RotateInterval: time.Hour,
RetainKeys: 3,
})
if err != nil {
t.Fatalf("Failed to create manager from existing key file: %v", err)
}
defer mgr2.Stop()
keys := mgr2.GetKeys()
if len(keys) != 3 {
t.Errorf("Expected 3 keys loaded from file, got %d", len(keys))
}
}
// TestSessionTicketManager_ApplyToTLSConfig 测试应用到 TLS 配置。
func TestSessionTicketManager_ApplyToTLSConfig(t *testing.T) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: time.Hour,
RetainKeys: 3,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
tlsCfg := &tls.Config{
MinVersion: tls.VersionTLS13,
}
mgr.ApplyToTLSConfig(tlsCfg)
// 验证可以获取密钥
keys := mgr.GetKeys()
if len(keys) == 0 {
t.Error("Expected keys to be set in TLS config")
}
}
// TestSessionTicketManager_StartStop 测试启动和停止。
func TestSessionTicketManager_StartStop(t *testing.T) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: 100 * time.Millisecond,
RetainKeys: 3,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
// 验证初始状态
status := mgr.GetStatus()
if status.Started {
t.Error("Manager should not be started initially")
}
// 启动
mgr.Start()
status = mgr.GetStatus()
if !status.Started {
t.Error("Manager should be started after Start()")
}
// 等待一次轮换
time.Sleep(150 * time.Millisecond)
keys := mgr.GetKeys()
if len(keys) < 1 {
t.Error("Expected at least 1 key after auto-rotation")
}
// 停止
mgr.Stop()
status = mgr.GetStatus()
if status.Started {
t.Error("Manager should not be started after Stop()")
}
}
// TestSessionTicketManager_GetStatus 测试获取状态。
func TestSessionTicketManager_GetStatus(t *testing.T) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: 30 * time.Minute,
RetainKeys: 5,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
status := mgr.GetStatus()
if status.KeyCount != 1 {
t.Errorf("KeyCount = %d, want 1", status.KeyCount)
}
// 使用自定义值 5不是默认值
if status.RetainKeys != 5 {
t.Errorf("RetainKeys = %d, want 5", status.RetainKeys)
}
// RotateInterval 使用配置值30m > 0所以保留
if status.RotateInterval != 30*time.Minute {
t.Errorf("RotateInterval = %v, want %v", status.RotateInterval, 30*time.Minute)
}
if status.Started {
t.Error("Started should be false before Start()")
}
mgr.Start()
status = mgr.GetStatus()
if !status.Started {
t.Error("Started should be true after Start()")
}
}
// TestGenerateTicketKey 测试密钥生成函数。
func TestGenerateTicketKey(t *testing.T) {
key1, err := generateTicketKey()
if err != nil {
t.Fatalf("generateTicketKey() failed: %v", err)
}
if len(key1) != ticketKeySize {
t.Errorf("generateTicketKey() key size = %d, want %d", len(key1), ticketKeySize)
}
key2, err := generateTicketKey()
if err != nil {
t.Fatalf("generateTicketKey() second call failed: %v", err)
}
// 验证生成的密钥是随机的(不相同)
if string(key1) == string(key2) {
t.Error("generateTicketKey() should generate random keys")
}
}
// TestSessionTicketManager_ConcurrentAccess 测试并发访问。
func TestSessionTicketManager_ConcurrentAccess(t *testing.T) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: 10 * time.Millisecond,
RetainKeys: 3,
})
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
mgr.Start()
// 并发读取和轮换
done := make(chan bool, 3)
// 协程 1: 持续获取密钥
go func() {
for i := 0; i < 100; i++ {
_ = mgr.GetKeys()
time.Sleep(time.Millisecond)
}
done <- true
}()
// 协程 2: 持续获取状态
go func() {
for i := 0; i < 100; i++ {
_ = mgr.GetStatus()
time.Sleep(time.Millisecond)
}
done <- true
}()
// 协程 3: 手动轮换
go func() {
for i := 0; i < 20; i++ {
_ = mgr.RotateKey()
time.Sleep(5 * time.Millisecond)
}
done <- true
}()
// 等待所有协程完成
for i := 0; i < 3; i++ {
<-done
}
// 验证最终状态
keys := mgr.GetKeys()
if len(keys) < 1 || len(keys) > 3 {
t.Errorf("Final key count %d out of expected range [1, 3]", len(keys))
}
}
// BenchmarkGenerateTicketKey 基准测试密钥生成。
func BenchmarkGenerateTicketKey(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := generateTicketKey()
if err != nil {
b.Fatal(err)
}
}
}
// BenchmarkSessionTicketManager_GetKeys 基准测试获取密钥。
func BenchmarkSessionTicketManager_GetKeys(b *testing.B) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: time.Hour,
RetainKeys: 3,
})
if err != nil {
b.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
// 预生成多个密钥
for i := 0; i < 2; i++ {
_ = mgr.RotateKey()
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = mgr.GetKeys()
}
}
// BenchmarkSessionTicketManager_RotateKey 基准测试密钥轮换。
func BenchmarkSessionTicketManager_RotateKey(b *testing.B) {
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
Enabled: true,
RotateInterval: time.Hour,
RetainKeys: 3,
})
if err != nil {
b.Fatalf("Failed to create manager: %v", err)
}
defer mgr.Stop()
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := mgr.RotateKey()
if err != nil {
b.Fatal(err)
}
}
}

View File

@ -62,6 +62,12 @@ type TLSManager struct {
// ocspManager OCSP Stapling 管理器
ocspManager *OCSPManager
// sessionTicketMgr Session Ticket 管理器
sessionTicketMgr *SessionTicketManager
// clientVerifier 客户端证书验证器
clientVerifier *ClientVerifier
// certificates 解析后的证书映射,用于 OCSP
certificates map[string]*x509.Certificate
@ -132,6 +138,21 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
issuers: make(map[string]*x509.Certificate),
}
// 初始化 Session Tickets如果启用
if cfg.SessionTickets.Enabled {
sessionTicketMgr, err := NewSessionTicketManager(cfg.SessionTickets)
if err != nil {
// Session Tickets 初始化失败不阻止 TLS 工作
// 可以记录日志
_ = err
} else {
manager.sessionTicketMgr = sessionTicketMgr
// 应用 Session Tickets 到 TLS 配置
sessionTicketMgr.ApplyToTLSConfig(tlsCfg)
sessionTicketMgr.Start()
}
}
// 初始化 OCSP Stapling如果启用
if cfg.OCSPStapling {
ocspMgr := NewOCSPManager(DefaultOCSPConfig())
@ -164,6 +185,19 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
ocspMgr.Start()
}
// 初始化客户端证书验证(如果启用)
if cfg.ClientVerify.Enabled {
clientVerifier, err := NewClientVerifier(cfg.ClientVerify)
if err != nil {
// 客户端验证配置失败不阻止 TLS 工作
// 可以记录日志
_ = err
} else {
manager.clientVerifier = clientVerifier
clientVerifier.ConfigureTLS(tlsCfg)
}
}
// 设置为默认配置
manager.defaultCfg = tlsCfg
@ -310,11 +344,14 @@ func (m *TLSManager) RemoveCertificate(name string) {
m.mu.Unlock()
}
// Close 停止 OCSP 管理器并释放资源。
// Close 停止 OCSP 管理器和 Session Ticket 管理器并释放资源。
func (m *TLSManager) Close() {
if m.ocspManager != nil {
m.ocspManager.Stop()
}
if m.sessionTicketMgr != nil {
m.sessionTicketMgr.Stop()
}
}
// getConfigForClientWithOCSP 返回启用 OCSP Stapling 的 TLS 配置。