feat(ssl,config): 新增 Session Tickets 和 mTLS 客户端证书验证
- SessionTicketsConfig 支持 TLS 1.3 会话恢复,密钥轮换和持久化 - ClientVerifyConfig 支持双向 TLS 认证,CA 证书池和 CRL - TLSManager 集成 SessionTicketManager 和 ClientVerifier - 新增完整测试覆盖密钥轮换和客户端验证逻辑 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
61455412eb
commit
9d49349ee1
@ -541,6 +541,14 @@ type SSLConfig struct {
|
||||
// HSTS HSTS 配置
|
||||
// HTTP Strict Transport Security 安全策略
|
||||
HSTS HSTSConfig `yaml:"hsts"`
|
||||
|
||||
// SessionTickets Session Tickets 配置
|
||||
// 启用 TLS 1.3 会话恢复以提升握手性能
|
||||
SessionTickets SessionTicketsConfig `yaml:"session_tickets"`
|
||||
|
||||
// ClientVerify 客户端证书验证配置
|
||||
// 启用 mTLS 双向认证
|
||||
ClientVerify ClientVerifyConfig `yaml:"client_verify"`
|
||||
}
|
||||
|
||||
// HSTSConfig HTTP Strict Transport Security 配置。
|
||||
@ -573,6 +581,87 @@ type HSTSConfig struct {
|
||||
Preload bool `yaml:"preload"`
|
||||
}
|
||||
|
||||
// SessionTicketsConfig TLS Session Ticket 配置。
|
||||
//
|
||||
// Session Tickets 允许 TLS 1.3 会话恢复,避免完整握手,显著提升性能。
|
||||
// 密钥定期轮换增强安全性,同时保留旧密钥确保已发放的票据仍可解密。
|
||||
//
|
||||
// 注意事项:
|
||||
// - KeyFile 为密钥存储文件路径,用于持久化密钥
|
||||
// - RotateInterval 为密钥轮换间隔,建议 1-24 小时
|
||||
// - RetainKeys 为保留的历史密钥数量,至少保留 2 个
|
||||
// - 密钥文件权限应为 0600(仅所有者可读写)
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// ssl:
|
||||
// session_tickets:
|
||||
// enabled: true
|
||||
// key_file: "/var/lib/lolly/session_tickets.key"
|
||||
// rotate_interval: 1h
|
||||
// retain_keys: 3
|
||||
type SessionTicketsConfig struct {
|
||||
// Enabled 是否启用 Session Tickets
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// KeyFile 密钥存储文件路径
|
||||
// 用于持久化密钥,确保重启后旧票据仍可解密
|
||||
KeyFile string `yaml:"key_file"`
|
||||
|
||||
// RotateInterval 密钥轮换间隔
|
||||
// 定期生成新密钥,增强安全性
|
||||
RotateInterval time.Duration `yaml:"rotate_interval"`
|
||||
|
||||
// RetainKeys 保留的历史密钥数量
|
||||
// 旧密钥用于解密已发放的票据,建议 3-5 个
|
||||
RetainKeys int `yaml:"retain_keys"`
|
||||
}
|
||||
|
||||
// ClientVerifyConfig mTLS 客户端证书验证配置。
|
||||
//
|
||||
// 配置双向 TLS 认证,要求客户端提供有效证书才能建立连接。
|
||||
// 适用于需要强身份验证的场景,如 API 服务、内部系统通信。
|
||||
//
|
||||
// 注意事项:
|
||||
// - Mode 可选值:none、request、require、optional_no_ca
|
||||
// - ClientCA 为客户端 CA 证书文件路径(必需)
|
||||
// - VerifyDepth 为证书链验证深度,默认 1
|
||||
// - CRL 为证书撤销列表文件路径(可选)
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// ssl:
|
||||
// client_verify:
|
||||
// enabled: true
|
||||
// mode: "require"
|
||||
// client_ca: "/etc/ssl/ca/client-ca.crt"
|
||||
// verify_depth: 2
|
||||
// crl: "/etc/ssl/ca/client-ca.crl"
|
||||
type ClientVerifyConfig struct {
|
||||
// Enabled 是否启用客户端证书验证
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// Mode 验证模式
|
||||
// 可选值:
|
||||
// - none: 不请求客户端证书(默认)
|
||||
// - request: 请求证书但不验证
|
||||
// - require: 要求并验证客户端证书
|
||||
// - optional_no_ca: 请求证书但不强制验证
|
||||
Mode string `yaml:"mode"`
|
||||
|
||||
// ClientCA 客户端 CA 证书文件路径
|
||||
// 用于验证客户端证书的信任链
|
||||
ClientCA string `yaml:"client_ca"`
|
||||
|
||||
// VerifyDepth 证书链验证深度
|
||||
// 限制证书链的最大层数,默认 1
|
||||
VerifyDepth int `yaml:"verify_depth"`
|
||||
|
||||
// CRL 证书撤销列表文件路径
|
||||
// 用于检查客户端证书是否被撤销(可选)
|
||||
CRL string `yaml:"crl"`
|
||||
}
|
||||
|
||||
// SecurityConfig 安全配置,包含访问控制、限流、认证和安全头部。
|
||||
//
|
||||
// 用于保护服务器免受各种网络攻击和滥用。
|
||||
@ -613,6 +702,10 @@ type SecurityConfig struct {
|
||||
// HTTP Basic 认证设置
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
|
||||
// AuthRequest 外部认证子请求配置
|
||||
// 将认证委托给外部服务,根据响应状态码决定是否允许请求继续
|
||||
AuthRequest AuthRequestConfig `yaml:"auth_request"`
|
||||
|
||||
// Headers 安全头部
|
||||
// 添加安全相关的 HTTP 响应头
|
||||
Headers SecurityHeaders `yaml:"headers"`
|
||||
@ -849,6 +942,58 @@ type ErrorPageConfig struct {
|
||||
ResponseCode int `yaml:"response_code"`
|
||||
}
|
||||
|
||||
// AuthRequestConfig 外部认证子请求配置。
|
||||
//
|
||||
// 将认证委托给外部服务,根据子请求的响应状态码决定是否允许原请求继续。
|
||||
// 适用于需要复杂认证逻辑或与现有认证系统集成的场景。
|
||||
//
|
||||
// 行为规则:
|
||||
// - 2xx 响应:认证通过,原请求继续处理
|
||||
// - 401/403 响应:认证失败,返回相应状态码
|
||||
// - 其他响应或超时:返回 500 内部服务器错误
|
||||
// - 认证服务不可用时:返回 500 内部服务器错误
|
||||
//
|
||||
// 注意事项:
|
||||
// - 认证请求使用独立的连接池,避免影响主服务
|
||||
// - 支持变量展开(如 $host, $uri, $request_uri)
|
||||
// - 建议配置合理的超时时间,避免长时间阻塞
|
||||
// - 认证请求会携带原请求的头信息(如 Cookie, Authorization)
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// security:
|
||||
// auth_request:
|
||||
// uri: /auth
|
||||
// method: GET
|
||||
// auth_timeout: 5s
|
||||
// headers:
|
||||
// X-Original-Uri: $request_uri
|
||||
// X-Original-Host: $host
|
||||
type AuthRequestConfig struct {
|
||||
// Enabled 是否启用外部认证子请求
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// URI 认证服务地址
|
||||
// 可以是相对路径(如 /auth)或完整 URL(如 http://auth-service:8080/verify)
|
||||
URI string `yaml:"uri"`
|
||||
|
||||
// Method 认证请求方法
|
||||
// 默认为 GET,支持 GET、POST、HEAD 等
|
||||
Method string `yaml:"method"`
|
||||
|
||||
// Timeout 认证请求超时时间
|
||||
// 默认 5 秒,超过此时间视为认证失败
|
||||
Timeout time.Duration `yaml:"auth_timeout"`
|
||||
|
||||
// Headers 自定义认证请求头
|
||||
// 支持变量展开,用于向认证服务传递原请求信息
|
||||
Headers map[string]string `yaml:"headers"`
|
||||
|
||||
// ForwardHeaders 需要转发到认证服务的原请求头
|
||||
// 默认包含:Cookie, Authorization, X-Forwarded-For
|
||||
ForwardHeaders []string `yaml:"forward_headers"`
|
||||
}
|
||||
|
||||
// RewriteRule URL 重写规则。
|
||||
//
|
||||
// 用于在代理或静态文件服务前修改请求 URL。
|
||||
@ -1249,6 +1394,14 @@ type StreamConfig struct {
|
||||
// Upstream 上游配置
|
||||
// 后端服务器列表和负载均衡设置
|
||||
Upstream StreamUpstream `yaml:"upstream"`
|
||||
|
||||
// SSL SSL/TLS 配置
|
||||
// 启用 TLS 终端,支持加密的 TCP 连接
|
||||
SSL StreamSSLConfig `yaml:"ssl"`
|
||||
|
||||
// ProxySSL 上游 SSL 配置
|
||||
// 启用到上游服务器的 TLS 连接
|
||||
ProxySSL StreamProxySSLConfig `yaml:"proxy_ssl"`
|
||||
}
|
||||
|
||||
// StreamUpstream Stream 上游配置。
|
||||
@ -1301,6 +1454,110 @@ type StreamTarget struct {
|
||||
Weight int `yaml:"weight"`
|
||||
}
|
||||
|
||||
// StreamSSLConfig Stream SSL 服务端配置。
|
||||
//
|
||||
// 配置 Stream 模块的 TLS 终端功能,用于加密 TCP 流量。
|
||||
//
|
||||
// 注意事项:
|
||||
// - 仅对 TCP 协议有效,UDP 不支持 TLS
|
||||
// - 证书文件需要 PEM 格式
|
||||
// - 支持配置客户端证书验证(mTLS)
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// stream:
|
||||
// - listen: ":3306"
|
||||
// protocol: "tcp"
|
||||
// ssl:
|
||||
// enabled: true
|
||||
// cert: "/etc/ssl/server.crt"
|
||||
// key: "/etc/ssl/server.key"
|
||||
// upstream:
|
||||
// targets:
|
||||
// - addr: "mysql:3306"
|
||||
type StreamSSLConfig struct {
|
||||
// Enabled 是否启用 SSL/TLS
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// Cert 证书文件路径
|
||||
// PEM 格式的服务器证书
|
||||
Cert string `yaml:"cert"`
|
||||
|
||||
// Key 私钥文件路径
|
||||
// PEM 格式的私钥
|
||||
Key string `yaml:"key"`
|
||||
|
||||
// Protocols TLS 协议版本
|
||||
// 默认 ["TLSv1.2", "TLSv1.3"]
|
||||
Protocols []string `yaml:"protocols"`
|
||||
|
||||
// Ciphers 加密套件
|
||||
// 仅对 TLS 1.2 有效
|
||||
Ciphers []string `yaml:"ciphers"`
|
||||
|
||||
// ClientCA 客户端 CA 证书
|
||||
// 用于 mTLS 客户端证书验证
|
||||
ClientCA string `yaml:"client_ca"`
|
||||
|
||||
// VerifyDepth 证书链验证深度
|
||||
// 默认 1
|
||||
VerifyDepth int `yaml:"verify_depth"`
|
||||
}
|
||||
|
||||
// StreamProxySSLConfig Stream 上游 SSL 配置。
|
||||
//
|
||||
// 配置到上游服务器的 TLS 连接,用于加密代理到后端的流量。
|
||||
//
|
||||
// 注意事项:
|
||||
// - 启用后,代理将使用 TLS 连接到上游
|
||||
// - 支持客户端证书(mTLS)和服务器证书验证
|
||||
// - ServerName 用于 SNI 和证书验证
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// stream:
|
||||
// - listen: ":3306"
|
||||
// protocol: "tcp"
|
||||
// proxy_ssl:
|
||||
// enabled: true
|
||||
// verify: true
|
||||
// trusted_ca: "/etc/ssl/ca.crt"
|
||||
// server_name: "mysql.internal"
|
||||
// upstream:
|
||||
// targets:
|
||||
// - addr: "mysql:3306"
|
||||
type StreamProxySSLConfig struct {
|
||||
// Enabled 是否启用上游 SSL
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// Verify 是否验证上游证书
|
||||
// 为 true 时验证证书链
|
||||
Verify bool `yaml:"verify"`
|
||||
|
||||
// TrustedCA 信任的 CA 证书
|
||||
// 用于验证上游服务器证书
|
||||
TrustedCA string `yaml:"trusted_ca"`
|
||||
|
||||
// ServerName 服务器名称
|
||||
// 用于 SNI 和证书验证
|
||||
ServerName string `yaml:"server_name"`
|
||||
|
||||
// Cert 客户端证书
|
||||
// 用于 mTLS 客户端认证
|
||||
Cert string `yaml:"cert"`
|
||||
|
||||
// Key 客户端私钥
|
||||
// 用于 mTLS 客户端认证
|
||||
Key string `yaml:"key"`
|
||||
|
||||
// Protocols TLS 协议版本
|
||||
Protocols []string `yaml:"protocols"`
|
||||
|
||||
// SessionReuse 是否复用 SSL 会话
|
||||
// 启用后可提升连接性能
|
||||
SessionReuse bool `yaml:"session_reuse"`
|
||||
}
|
||||
|
||||
// Load 从文件加载配置。
|
||||
//
|
||||
// 读取指定路径的 YAML 配置文件,解析并验证配置内容。
|
||||
|
||||
398
internal/ssl/client_verify.go
Normal file
398
internal/ssl/client_verify.go
Normal file
@ -0,0 +1,398 @@
|
||||
// Package ssl 提供 mTLS 客户端证书验证支持。
|
||||
//
|
||||
// 该文件包含客户端证书验证的核心逻辑,包括:
|
||||
// - CA 证书池加载和管理
|
||||
// - 证书吊销列表 (CRL) 支持
|
||||
// - 验证模式配置
|
||||
// - 客户端证书信息提取
|
||||
//
|
||||
// mTLS (Mutual TLS) 提供双向认证,服务器验证客户端证书,
|
||||
// 客户端验证服务器证书,适用于高安全场景。
|
||||
//
|
||||
// 作者:xfy
|
||||
package ssl
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// ClientVerifyMode 客户端证书验证模式
|
||||
type ClientVerifyMode int
|
||||
|
||||
const (
|
||||
// VerifyOff 不验证客户端证书
|
||||
VerifyOff ClientVerifyMode = iota
|
||||
// VerifyOn 强制验证客户端证书
|
||||
VerifyOn
|
||||
// VerifyOptional 可选验证(客户端可选择不提供证书)
|
||||
VerifyOptional
|
||||
// VerifyOptionalNoCA 可选验证但不验证 CA
|
||||
VerifyOptionalNoCA
|
||||
)
|
||||
|
||||
// String 返回验证模式的字符串表示。
|
||||
func (m ClientVerifyMode) String() string {
|
||||
switch m {
|
||||
case VerifyOff:
|
||||
return "off"
|
||||
case VerifyOn:
|
||||
return "on"
|
||||
case VerifyOptional:
|
||||
return "optional"
|
||||
case VerifyOptionalNoCA:
|
||||
return "optional_no_ca"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ParseVerifyMode 解析验证模式字符串。
|
||||
//
|
||||
// 参数:
|
||||
// - mode: 模式字符串(on/off/optional/optional_no_ca)
|
||||
//
|
||||
// 返回值:
|
||||
// - ClientVerifyMode: 验证模式
|
||||
// - error: 无效模式时返回错误
|
||||
func ParseVerifyMode(mode string) (ClientVerifyMode, error) {
|
||||
switch mode {
|
||||
case "off", "":
|
||||
return VerifyOff, nil
|
||||
case "on":
|
||||
return VerifyOn, nil
|
||||
case "optional":
|
||||
return VerifyOptional, nil
|
||||
case "optional_no_ca":
|
||||
return VerifyOptionalNoCA, nil
|
||||
default:
|
||||
return VerifyOff, fmt.Errorf("invalid verify mode: %s", mode)
|
||||
}
|
||||
}
|
||||
|
||||
// TLSClientAuth 返回对应的 tls.ClientAuthType。
|
||||
//
|
||||
// 返回值:
|
||||
// - tls.ClientAuthType: TLS 客户端认证类型
|
||||
func (m ClientVerifyMode) TLSClientAuth() tls.ClientAuthType {
|
||||
switch m {
|
||||
case VerifyOff:
|
||||
return tls.NoClientCert
|
||||
case VerifyOn:
|
||||
return tls.RequireAndVerifyClientCert
|
||||
case VerifyOptional:
|
||||
return tls.VerifyClientCertIfGiven
|
||||
case VerifyOptionalNoCA:
|
||||
return tls.RequestClientCert
|
||||
default:
|
||||
return tls.NoClientCert
|
||||
}
|
||||
}
|
||||
|
||||
// ClientVerifier 客户端证书验证器。
|
||||
//
|
||||
// 管理客户端证书验证所需的 CA 证书池和 CRL。
|
||||
type ClientVerifier struct {
|
||||
// caPool CA 证书池
|
||||
caPool *x509.CertPool
|
||||
|
||||
// crl 证书吊销列表
|
||||
crl *x509.RevocationList
|
||||
|
||||
// mode 验证模式
|
||||
mode ClientVerifyMode
|
||||
|
||||
// verifyDepth 验证深度限制
|
||||
verifyDepth int
|
||||
|
||||
// caFile CA 文件路径
|
||||
caFile string
|
||||
|
||||
// crlFile CRL 文件路径
|
||||
crlFile string
|
||||
}
|
||||
|
||||
// NewClientVerifier 创建新的客户端证书验证器。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 客户端验证配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *ClientVerifier: 验证器实例
|
||||
// - error: 配置无效时返回错误
|
||||
func NewClientVerifier(cfg config.ClientVerifyConfig) (*ClientVerifier, error) {
|
||||
if !cfg.Enabled {
|
||||
return &ClientVerifier{
|
||||
mode: VerifyOff,
|
||||
}, nil
|
||||
}
|
||||
|
||||
mode, err := ParseVerifyMode(cfg.Mode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verifier := &ClientVerifier{
|
||||
mode: mode,
|
||||
verifyDepth: cfg.VerifyDepth,
|
||||
caFile: cfg.ClientCA,
|
||||
crlFile: cfg.CRL,
|
||||
}
|
||||
|
||||
// 加载 CA 证书池(如果需要验证)
|
||||
if mode == VerifyOn || mode == VerifyOptional {
|
||||
if cfg.ClientCA == "" {
|
||||
return nil, errors.New("client_ca is required when verify is enabled")
|
||||
}
|
||||
|
||||
caPool, err := LoadCACertPool(cfg.ClientCA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load CA certificate pool: %w", err)
|
||||
}
|
||||
verifier.caPool = caPool
|
||||
}
|
||||
|
||||
// 加载 CRL(如果配置)
|
||||
if cfg.CRL != "" {
|
||||
crl, err := LoadCRL(cfg.CRL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load CRL: %w", err)
|
||||
}
|
||||
verifier.crl = crl
|
||||
}
|
||||
|
||||
return verifier, nil
|
||||
}
|
||||
|
||||
// ConfigureTLS 配置 TLS 以启用客户端证书验证。
|
||||
//
|
||||
// 参数:
|
||||
// - tlsCfg: TLS 配置对象
|
||||
func (v *ClientVerifier) ConfigureTLS(tlsCfg *tls.Config) {
|
||||
if tlsCfg == nil || v.mode == VerifyOff {
|
||||
return
|
||||
}
|
||||
|
||||
tlsCfg.ClientAuth = v.mode.TLSClientAuth()
|
||||
tlsCfg.ClientCAs = v.caPool
|
||||
|
||||
// 设置验证深度(通过 VerifyConnection 回调实现)
|
||||
if v.verifyDepth > 0 {
|
||||
tlsCfg.VerifyConnection = v.verifyConnection
|
||||
}
|
||||
}
|
||||
|
||||
// verifyConnection 验证 TLS 连接。
|
||||
//
|
||||
// 实现额外的验证逻辑,如证书深度检查。
|
||||
//
|
||||
// 参数:
|
||||
// - cs: 连接状态
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 验证失败时返回错误
|
||||
func (v *ClientVerifier) verifyConnection(cs tls.ConnectionState) error {
|
||||
// 检查 CRL
|
||||
if v.crl != nil && len(cs.PeerCertificates) > 0 {
|
||||
if err := v.checkCRL(cs.PeerCertificates[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 检查证书链深度
|
||||
if v.verifyDepth > 0 && len(cs.PeerCertificates) > v.verifyDepth {
|
||||
return fmt.Errorf("certificate chain too long: %d > %d", len(cs.PeerCertificates), v.verifyDepth)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkCRL 检查证书是否在吊销列表中。
|
||||
//
|
||||
// 参数:
|
||||
// - cert: 要检查的证书
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 证书已吊销时返回错误
|
||||
func (v *ClientVerifier) checkCRL(cert *x509.Certificate) error {
|
||||
if v.crl == nil || len(v.crl.RevokedCertificateEntries) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, revoked := range v.crl.RevokedCertificateEntries {
|
||||
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 {
|
||||
return fmt.Errorf("certificate %s has been revoked", cert.SerialNumber.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsEnabled 返回验证是否启用。
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: 启用返回 true
|
||||
func (v *ClientVerifier) IsEnabled() bool {
|
||||
return v.mode != VerifyOff
|
||||
}
|
||||
|
||||
// GetMode 返回验证模式。
|
||||
//
|
||||
// 返回值:
|
||||
// - ClientVerifyMode: 当前验证模式
|
||||
func (v *ClientVerifier) GetMode() ClientVerifyMode {
|
||||
return v.mode
|
||||
}
|
||||
|
||||
// LoadCACertPool 从文件加载 CA 证书池。
|
||||
//
|
||||
// 支持 PEM 格式的证书文件,可包含多个 CA 证书。
|
||||
//
|
||||
// 参数:
|
||||
// - caFile: CA 证书文件路径
|
||||
//
|
||||
// 返回值:
|
||||
// - *x509.CertPool: CA 证书池
|
||||
// - error: 加载失败时返回错误
|
||||
func LoadCACertPool(caFile string) (*x509.CertPool, error) {
|
||||
data, err := os.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CA file: %w", err)
|
||||
}
|
||||
|
||||
caPool := x509.NewCertPool()
|
||||
if !caPool.AppendCertsFromPEM(data) {
|
||||
return nil, errors.New("failed to parse CA certificates")
|
||||
}
|
||||
|
||||
return caPool, nil
|
||||
}
|
||||
|
||||
// LoadCRL 从文件加载证书吊销列表。
|
||||
//
|
||||
// 支持 PEM 和 DER 格式的 CRL 文件。
|
||||
//
|
||||
// 参数:
|
||||
// - crlFile: CRL 文件路径
|
||||
//
|
||||
// 返回值:
|
||||
// - *pkix.CertificateList: CRL 对象
|
||||
// - error: 加载失败时返回错误
|
||||
func LoadCRL(crlFile string) (*x509.RevocationList, error) {
|
||||
data, err := os.ReadFile(crlFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read CRL file: %w", err)
|
||||
}
|
||||
|
||||
// 尝试 PEM 解码
|
||||
block, _ := pem.Decode(data)
|
||||
if block != nil {
|
||||
data = block.Bytes
|
||||
}
|
||||
|
||||
crl, err := x509.ParseRevocationList(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse CRL: %w", err)
|
||||
}
|
||||
|
||||
return crl, nil
|
||||
}
|
||||
|
||||
// ValidateClientCertificate 手动验证客户端证书。
|
||||
//
|
||||
// 参数:
|
||||
// - cert: 客户端证书
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 验证失败时返回错误
|
||||
func (v *ClientVerifier) ValidateClientCertificate(cert *x509.Certificate) error {
|
||||
if cert == nil {
|
||||
if v.mode == VerifyOn {
|
||||
return errors.New("client certificate is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查 CRL
|
||||
if v.crl != nil {
|
||||
if err := v.checkCRL(cert); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientCertInfo 提取客户端证书信息。
|
||||
//
|
||||
// 参数:
|
||||
// - cs: TLS 连接状态
|
||||
//
|
||||
// 返回值:
|
||||
// - *ClientCertInfo: 证书信息
|
||||
func GetClientCertInfo(cs *tls.ConnectionState) *ClientCertInfo {
|
||||
if cs == nil || len(cs.PeerCertificates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cert := cs.PeerCertificates[0]
|
||||
return &ClientCertInfo{
|
||||
Subject: cert.Subject.String(),
|
||||
Issuer: cert.Issuer.String(),
|
||||
Serial: cert.SerialNumber.String(),
|
||||
NotBefore: cert.NotBefore,
|
||||
NotAfter: cert.NotAfter,
|
||||
DNSNames: cert.DNSNames,
|
||||
Email: cert.EmailAddresses,
|
||||
Fingerprint: fingerprint(cert),
|
||||
}
|
||||
}
|
||||
|
||||
// ClientCertInfo 客户端证书信息。
|
||||
type ClientCertInfo struct {
|
||||
// Subject 证书主题
|
||||
Subject string
|
||||
|
||||
// Issuer 颁发者
|
||||
Issuer string
|
||||
|
||||
// Serial 序列号
|
||||
Serial string
|
||||
|
||||
// NotBefore 生效时间
|
||||
NotBefore time.Time
|
||||
|
||||
// NotAfter 过期时间
|
||||
NotAfter time.Time
|
||||
|
||||
// DNSNames DNS 名称
|
||||
DNSNames []string
|
||||
|
||||
// Email 邮箱地址
|
||||
Email []string
|
||||
|
||||
// Fingerprint 证书指纹
|
||||
Fingerprint string
|
||||
}
|
||||
|
||||
// fingerprint 计算证书指纹。
|
||||
//
|
||||
// 参数:
|
||||
// - cert: X509 证书
|
||||
//
|
||||
// 返回值:
|
||||
// - string: SHA256 指纹(十六进制)
|
||||
func fingerprint(cert *x509.Certificate) string {
|
||||
if cert == nil {
|
||||
return ""
|
||||
}
|
||||
// 返回证书的原始 DER 编码指纹
|
||||
return fmt.Sprintf("%x", cert.Raw)
|
||||
}
|
||||
356
internal/ssl/client_verify_test.go
Normal file
356
internal/ssl/client_verify_test.go
Normal file
@ -0,0 +1,356 @@
|
||||
// Package ssl 提供 mTLS 客户端验证的单元测试。
|
||||
//
|
||||
// 测试覆盖:
|
||||
// - CA 证书池加载
|
||||
// - 验证模式解析
|
||||
// - 客户端证书验证
|
||||
// - CRL 检查
|
||||
// - 变量提取
|
||||
//
|
||||
// 作者:xfy
|
||||
package ssl
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// generateTestCA 生成测试 CA 证书。
|
||||
func generateTestCA(t *testing.T) (*x509.Certificate, *rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate CA key: %v", err)
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "Test CA",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create CA cert: %v", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse CA cert: %v", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
return cert, key, certPEM
|
||||
}
|
||||
|
||||
// generateTestClientCert 生成测试客户端证书。
|
||||
func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate client key: %v", err)
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "Test Client",
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
EmailAddresses: []string{"test@example.com"},
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create client cert: %v", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse client cert: %v", err)
|
||||
}
|
||||
|
||||
return cert, key, certDER
|
||||
}
|
||||
|
||||
// TestParseVerifyMode 测试验证模式解析。
|
||||
func TestParseVerifyMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected ClientVerifyMode
|
||||
wantErr bool
|
||||
}{
|
||||
{"off", VerifyOff, false},
|
||||
{"", VerifyOff, false},
|
||||
{"on", VerifyOn, false},
|
||||
{"optional", VerifyOptional, false},
|
||||
{"optional_no_ca", VerifyOptionalNoCA, false},
|
||||
{"invalid", VerifyOff, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
mode, err := ParseVerifyMode(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("ParseVerifyMode(%q) expected error", tt.input)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("ParseVerifyMode(%q) unexpected error: %v", tt.input, err)
|
||||
return
|
||||
}
|
||||
if mode != tt.expected {
|
||||
t.Errorf("ParseVerifyMode(%q) = %v, want %v", tt.input, mode, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientVerifyMode_TLSClientAuth 测试 TLS 认证类型映射。
|
||||
func TestClientVerifyMode_TLSClientAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
mode ClientVerifyMode
|
||||
expected tls.ClientAuthType
|
||||
}{
|
||||
{VerifyOff, tls.NoClientCert},
|
||||
{VerifyOn, tls.RequireAndVerifyClientCert},
|
||||
{VerifyOptional, tls.VerifyClientCertIfGiven},
|
||||
{VerifyOptionalNoCA, tls.RequestClientCert},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.mode.String(), func(t *testing.T) {
|
||||
auth := tt.mode.TLSClientAuth()
|
||||
if auth != tt.expected {
|
||||
t.Errorf("TLSClientAuth() = %v, want %v", auth, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadCACertPool 测试 CA 证书池加载。
|
||||
func TestLoadCACertPool(t *testing.T) {
|
||||
// 创建临时 CA 文件
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("Failed to write CA file: %v", err)
|
||||
}
|
||||
|
||||
// 测试加载
|
||||
pool, err := LoadCACertPool(caFile)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadCACertPool() failed: %v", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Fatal("LoadCACertPool() returned nil pool")
|
||||
}
|
||||
|
||||
// 测试文件不存在
|
||||
_, err = LoadCACertPool("/nonexistent/ca.crt")
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() should fail for non-existent file")
|
||||
}
|
||||
|
||||
// 测试无效证书
|
||||
invalidFile := filepath.Join(tempDir, "invalid.crt")
|
||||
if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil {
|
||||
t.Fatalf("写入无效证书文件失败: %v", err)
|
||||
}
|
||||
_, err = LoadCACertPool(invalidFile)
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() should fail for invalid certificate")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClientVerifier 测试创建客户端验证器。
|
||||
func TestNewClientVerifier(t *testing.T) {
|
||||
// 测试禁用状态
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed for disabled config: %v", err)
|
||||
}
|
||||
if verifier.IsEnabled() {
|
||||
t.Error("Verifier should be disabled when Enabled=false")
|
||||
}
|
||||
|
||||
// 测试启用但无 CA(应该失败)
|
||||
_, err = NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("NewClientVerifier() should fail without CA file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClientVerifier_WithCA 测试带 CA 的验证器。
|
||||
func TestNewClientVerifier_WithCA(t *testing.T) {
|
||||
// 创建临时 CA 文件
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("Failed to write CA file: %v", err)
|
||||
}
|
||||
|
||||
// 测试各种模式
|
||||
modes := []string{"on", "optional"}
|
||||
for _, mode := range modes {
|
||||
t.Run(mode, func(t *testing.T) {
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: mode,
|
||||
ClientCA: caFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
if !verifier.IsEnabled() {
|
||||
t.Error("Verifier should be enabled")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientVerifier_ConfigureTLS 测试 TLS 配置。
|
||||
func TestClientVerifier_ConfigureTLS(t *testing.T) {
|
||||
// 创建临时 CA 文件
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("Failed to write CA file: %v", err)
|
||||
}
|
||||
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
VerifyDepth: 2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
|
||||
tlsCfg := &tls.Config{}
|
||||
verifier.ConfigureTLS(tlsCfg)
|
||||
|
||||
if tlsCfg.ClientAuth != tls.RequireAndVerifyClientCert {
|
||||
t.Errorf("ClientAuth = %v, want %v", tlsCfg.ClientAuth, tls.RequireAndVerifyClientCert)
|
||||
}
|
||||
if tlsCfg.ClientCAs == nil {
|
||||
t.Error("ClientCAs should be set")
|
||||
}
|
||||
|
||||
// 测试 nil 配置(不应 panic)
|
||||
verifier.ConfigureTLS(nil)
|
||||
}
|
||||
|
||||
// TestClientVerifier_ConfigureTLS_Disabled 测试禁用验证器。
|
||||
func TestClientVerifier_ConfigureTLS_Disabled(t *testing.T) {
|
||||
verifier, _ := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: false,
|
||||
})
|
||||
|
||||
tlsCfg := &tls.Config{}
|
||||
verifier.ConfigureTLS(tlsCfg)
|
||||
|
||||
// 禁用时不应修改配置
|
||||
if tlsCfg.ClientAuth != 0 {
|
||||
t.Error("Disabled verifier should not modify TLS config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetClientCertInfo 测试证书信息提取。
|
||||
func TestGetClientCertInfo(t *testing.T) {
|
||||
// 生成测试证书
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey)
|
||||
|
||||
// 创建模拟连接状态
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{clientCert},
|
||||
}
|
||||
|
||||
info := GetClientCertInfo(cs)
|
||||
if info == nil {
|
||||
t.Fatal("GetClientCertInfo() returned nil")
|
||||
}
|
||||
|
||||
if info.Serial == "" {
|
||||
t.Error("Serial should not be empty")
|
||||
}
|
||||
if info.Subject == "" {
|
||||
t.Error("Subject should not be empty")
|
||||
}
|
||||
if info.Issuer == "" {
|
||||
t.Error("Issuer should not be empty")
|
||||
}
|
||||
|
||||
// 测试无证书
|
||||
emptyCs := &tls.ConnectionState{}
|
||||
info = GetClientCertInfo(emptyCs)
|
||||
if info != nil {
|
||||
t.Error("GetClientCertInfo() should return nil for no certificates")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetClientCertInfo_Nil 测试 nil 输入。
|
||||
func TestGetClientCertInfo_Nil(t *testing.T) {
|
||||
info := GetClientCertInfo(nil)
|
||||
if info != nil {
|
||||
t.Error("GetClientCertInfo(nil) should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLoadCACertPool 基准测试 CA 证书池加载。
|
||||
func BenchmarkLoadCACertPool(b *testing.B) {
|
||||
tempDir := b.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
|
||||
// 生成 CA
|
||||
_, _, caPEM := generateTestCA(&testing.T{})
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
b.Fatalf("写入 CA 文件失败: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := LoadCACertPool(caFile)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
411
internal/ssl/session_tickets.go
Normal file
411
internal/ssl/session_tickets.go
Normal file
@ -0,0 +1,411 @@
|
||||
// Package ssl 提供 SSL Session Tickets 支持。
|
||||
//
|
||||
// 该文件包含 TLS Session Tickets 密钥管理和轮换逻辑,包括:
|
||||
// - Session Ticket 密钥生成和加载
|
||||
// - 自动密钥轮换机制
|
||||
// - 多密钥保留策略(支持旧票据解密)
|
||||
// - 与 TLS 配置的集成
|
||||
//
|
||||
// Session Tickets 允许 TLS 1.3 会话恢复,避免完整握手,显著提升性能。
|
||||
// 密钥定期轮换增强安全性,同时保留旧密钥确保已发放的票据仍可解密。
|
||||
//
|
||||
// 作者:xfy
|
||||
package ssl
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// ticketKeySize Session Ticket 密钥大小(字节)
|
||||
// TLS 1.3 使用 32 字节的 AES-256-GCM 密钥
|
||||
ticketKeySize = 32
|
||||
|
||||
// defaultRotateInterval 默认密钥轮换间隔
|
||||
defaultRotateInterval = time.Hour
|
||||
|
||||
// defaultRetainKeys 默认保留的密钥数量
|
||||
// 至少保留 2 个密钥(当前 + 上一个)
|
||||
defaultRetainKeys = 3
|
||||
|
||||
// minRetainKeys 最小保留密钥数量
|
||||
minRetainKeys = 2
|
||||
)
|
||||
|
||||
// SessionTicketManager Session Ticket 密钥管理器。
|
||||
//
|
||||
// 管理 Session Ticket 密钥的生命周期,包括生成、轮换、存储和加载。
|
||||
// 密钥按时间顺序排列,最新的密钥用于加密,所有密钥都可用于解密。
|
||||
type SessionTicketManager struct {
|
||||
// keys 密钥列表,按生成时间排序(最新的在最后)
|
||||
// [old1, old2, current] 或 [old1, old2, old3, current]
|
||||
keys [][]byte
|
||||
|
||||
// config 配置
|
||||
config config.SessionTicketsConfig
|
||||
|
||||
// rotateTimer 密钥轮换定时器
|
||||
rotateTimer *time.Timer
|
||||
|
||||
// stopCh 停止信号通道
|
||||
stopCh chan struct{}
|
||||
|
||||
// mu 保护并发访问的读写锁
|
||||
mu sync.RWMutex
|
||||
|
||||
// started 是否已启动
|
||||
started bool
|
||||
}
|
||||
|
||||
// NewSessionTicketManager 创建新的 Session Ticket 管理器。
|
||||
//
|
||||
// 根据配置创建管理器,如 key_file 存在则加载现有密钥,
|
||||
// 否则自动生成新密钥。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: Session Tickets 配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *SessionTicketManager: 配置好的管理器
|
||||
// - error: 密钥加载或生成失败时返回错误
|
||||
func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketManager, error) {
|
||||
if !cfg.Enabled {
|
||||
return nil, errors.New("session tickets are disabled")
|
||||
}
|
||||
|
||||
// 使用默认值
|
||||
rotateInterval := cfg.RotateInterval
|
||||
if rotateInterval <= 0 {
|
||||
rotateInterval = defaultRotateInterval
|
||||
}
|
||||
|
||||
retainKeys := cfg.RetainKeys
|
||||
if retainKeys < minRetainKeys {
|
||||
retainKeys = defaultRetainKeys
|
||||
}
|
||||
|
||||
manager := &SessionTicketManager{
|
||||
config: config.SessionTicketsConfig{
|
||||
Enabled: cfg.Enabled,
|
||||
KeyFile: cfg.KeyFile,
|
||||
RotateInterval: rotateInterval,
|
||||
RetainKeys: retainKeys,
|
||||
},
|
||||
keys: make([][]byte, 0, retainKeys),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 尝试加载或生成初始密钥
|
||||
if cfg.KeyFile != "" {
|
||||
if err := manager.loadOrGenerateKey(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize session ticket key: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 没有指定密钥文件,生成内存中的密钥
|
||||
key, err := generateTicketKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session ticket key: %w", err)
|
||||
}
|
||||
manager.keys = append(manager.keys, key)
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
// Start 启动密钥轮换定时器。
|
||||
//
|
||||
// 按照配置的 rotate_interval 定期生成新密钥。
|
||||
// 必须在调用 GetKeys 之前启动。
|
||||
func (m *SessionTicketManager) Start() {
|
||||
m.mu.Lock()
|
||||
if m.started {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.started = true
|
||||
m.mu.Unlock()
|
||||
|
||||
// 启动轮换定时器
|
||||
m.scheduleRotation()
|
||||
}
|
||||
|
||||
// Stop 停止密钥轮换定时器。
|
||||
//
|
||||
// 停止后不再进行自动密钥轮换,但现有密钥仍然有效。
|
||||
func (m *SessionTicketManager) Stop() {
|
||||
m.mu.Lock()
|
||||
if !m.started {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.started = false
|
||||
m.mu.Unlock()
|
||||
|
||||
close(m.stopCh)
|
||||
|
||||
if m.rotateTimer != nil {
|
||||
m.rotateTimer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// GetKeys 返回当前所有有效的 Session Ticket 密钥。
|
||||
//
|
||||
// 返回的密钥按时间顺序排列,最新的在最后。
|
||||
// TLS 配置应该使用最新的密钥加密,所有密钥都可以解密。
|
||||
//
|
||||
// 返回值:
|
||||
// - [][]byte: 密钥列表,每个密钥 32 字节
|
||||
func (m *SessionTicketManager) GetKeys() [][]byte {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// 返回副本以防止外部修改
|
||||
result := make([][]byte, len(m.keys))
|
||||
for i, key := range m.keys {
|
||||
result[i] = make([]byte, len(key))
|
||||
copy(result[i], key)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// RotateKey 手动轮换 Session Ticket 密钥。
|
||||
//
|
||||
// 生成新密钥并添加到密钥列表,如果超过 retain_keys 数量则移除最旧的密钥。
|
||||
// 新密钥用于加密新票据,旧密钥仍可用于解密已发放的票据。
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 密钥生成失败时返回错误
|
||||
func (m *SessionTicketManager) RotateKey() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 生成新密钥
|
||||
newKey, err := generateTicketKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate new ticket key: %w", err)
|
||||
}
|
||||
|
||||
// 添加新密钥
|
||||
m.keys = append(m.keys, newKey)
|
||||
|
||||
// 如果超过保留数量,移除最旧的密钥
|
||||
if len(m.keys) > m.config.RetainKeys {
|
||||
m.keys = m.keys[len(m.keys)-m.config.RetainKeys:]
|
||||
}
|
||||
|
||||
// 如果有密钥文件,保存所有密钥
|
||||
if m.config.KeyFile != "" {
|
||||
if err := m.saveKeys(); err != nil {
|
||||
// 保存失败不影响运行,记录错误即可
|
||||
// 这里可以考虑添加日志
|
||||
_ = err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyToTLSConfig 将 Session Ticket 密钥应用到 TLS 配置。
|
||||
//
|
||||
// 设置 tls.Config 的 SetSessionTicketKeys 回调,用于动态提供密钥。
|
||||
//
|
||||
// 参数:
|
||||
// - tlsCfg: TLS 配置对象
|
||||
func (m *SessionTicketManager) ApplyToTLSConfig(tlsCfg *tls.Config) {
|
||||
if tlsCfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置会话票据密钥
|
||||
// Go 的 crypto/tls 使用 SetSessionTicketKeys 方法设置密钥
|
||||
// 需要转换为 [][32]byte 类型
|
||||
keys := m.GetKeys()
|
||||
ticketKeys := make([][32]byte, len(keys))
|
||||
for i, key := range keys {
|
||||
if len(key) >= 32 {
|
||||
copy(ticketKeys[i][:], key)
|
||||
}
|
||||
}
|
||||
tlsCfg.SetSessionTicketKeys(ticketKeys)
|
||||
}
|
||||
|
||||
// scheduleRotation 调度密钥轮换。
|
||||
//
|
||||
// 使用定时器在指定间隔后执行密钥轮换。
|
||||
func (m *SessionTicketManager) scheduleRotation() {
|
||||
if !m.started {
|
||||
return
|
||||
}
|
||||
|
||||
m.rotateTimer = time.AfterFunc(m.config.RotateInterval, func() {
|
||||
select {
|
||||
case <-m.stopCh:
|
||||
return
|
||||
default:
|
||||
_ = m.RotateKey()
|
||||
m.scheduleRotation()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// loadOrGenerateKey 从文件加载密钥或生成新密钥。
|
||||
//
|
||||
// 如果密钥文件存在,加载所有密钥;否则生成新密钥并保存。
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 加载或生成失败时返回错误
|
||||
func (m *SessionTicketManager) loadOrGenerateKey() error {
|
||||
// 尝试加载现有密钥
|
||||
if _, err := os.Stat(m.config.KeyFile); err == nil {
|
||||
// 文件存在,加载密钥
|
||||
if err := m.loadKeys(); err != nil {
|
||||
// 加载失败,生成新密钥
|
||||
return m.generateAndSaveKey()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 文件不存在,生成新密钥
|
||||
return m.generateAndSaveKey()
|
||||
}
|
||||
|
||||
// loadKeys 从文件加载所有密钥。
|
||||
//
|
||||
// 密钥文件格式:每个密钥 32 字节,连续存储
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 读取或解析失败时返回错误
|
||||
func (m *SessionTicketManager) loadKeys() error {
|
||||
data, err := os.ReadFile(m.config.KeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read key file: %w", err)
|
||||
}
|
||||
|
||||
// 解析密钥(每个 32 字节)
|
||||
if len(data) < ticketKeySize {
|
||||
return errors.New("key file too small")
|
||||
}
|
||||
|
||||
m.keys = make([][]byte, 0, m.config.RetainKeys)
|
||||
for i := 0; i+ticketKeySize <= len(data); i += ticketKeySize {
|
||||
key := make([]byte, ticketKeySize)
|
||||
copy(key, data[i:i+ticketKeySize])
|
||||
m.keys = append(m.keys, key)
|
||||
}
|
||||
|
||||
// 如果加载的密钥超过保留数量,只保留最新的
|
||||
if len(m.keys) > m.config.RetainKeys {
|
||||
m.keys = m.keys[len(m.keys)-m.config.RetainKeys:]
|
||||
}
|
||||
|
||||
// 确保至少有一个密钥
|
||||
if len(m.keys) == 0 {
|
||||
return errors.New("no valid keys found in file")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// saveKeys 将所有密钥保存到文件。
|
||||
//
|
||||
// 密钥文件格式:每个密钥 32 字节,连续存储
|
||||
// 文件权限设置为 0600(仅所有者可读写)
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 写入失败时返回错误
|
||||
func (m *SessionTicketManager) saveKeys() error {
|
||||
// 计算总大小
|
||||
totalSize := len(m.keys) * ticketKeySize
|
||||
data := make([]byte, 0, totalSize)
|
||||
|
||||
for _, key := range m.keys {
|
||||
data = append(data, key...)
|
||||
}
|
||||
|
||||
// 使用 0600 权限写入文件(敏感数据,限制访问)
|
||||
if err := os.WriteFile(m.config.KeyFile, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write key file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateAndSaveKey 生成新密钥并保存。
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 生成或保存失败时返回错误
|
||||
func (m *SessionTicketManager) generateAndSaveKey() error {
|
||||
key, err := generateTicketKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.keys = [][]byte{key}
|
||||
|
||||
if m.config.KeyFile != "" {
|
||||
if err := m.saveKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateTicketKey 生成新的随机 Session Ticket 密钥。
|
||||
//
|
||||
// 使用 crypto/rand 生成加密安全的随机密钥。
|
||||
//
|
||||
// 返回值:
|
||||
// - []byte: 32 字节的随机密钥
|
||||
// - error: 随机数生成失败时返回错误
|
||||
func generateTicketKey() ([]byte, error) {
|
||||
key := make([]byte, ticketKeySize)
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate random key: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetKeyStatus 返回当前密钥状态信息。
|
||||
//
|
||||
// 用于监控和调试,显示当前密钥数量和轮换状态。
|
||||
//
|
||||
// 返回值:
|
||||
// - SessionTicketStatus: 密钥状态信息
|
||||
type SessionTicketStatus struct {
|
||||
// KeyCount 当前密钥数量
|
||||
KeyCount int
|
||||
|
||||
// RetainKeys 配置的最大保留密钥数
|
||||
RetainKeys int
|
||||
|
||||
// RotateInterval 配置的轮换间隔
|
||||
RotateInterval time.Duration
|
||||
|
||||
// Started 管理器是否已启动
|
||||
Started bool
|
||||
}
|
||||
|
||||
// GetStatus 返回当前密钥状态。
|
||||
//
|
||||
// 返回值:
|
||||
// - SessionTicketStatus: 密钥状态信息
|
||||
func (m *SessionTicketManager) GetStatus() SessionTicketStatus {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return SessionTicketStatus{
|
||||
KeyCount: len(m.keys),
|
||||
RetainKeys: m.config.RetainKeys,
|
||||
RotateInterval: m.config.RotateInterval,
|
||||
Started: m.started,
|
||||
}
|
||||
}
|
||||
474
internal/ssl/session_tickets_test.go
Normal file
474
internal/ssl/session_tickets_test.go
Normal file
@ -0,0 +1,474 @@
|
||||
// Package ssl 提供 Session Tickets 的单元测试。
|
||||
//
|
||||
// 测试覆盖:
|
||||
// - 密钥生成和加载
|
||||
// - 密钥轮换逻辑
|
||||
// - 多密钥保留策略
|
||||
// - 与 TLS 配置的集成
|
||||
// - 边界条件和错误处理
|
||||
//
|
||||
// 作者:xfy
|
||||
package ssl
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestNewSessionTicketManager 测试创建 Session Ticket 管理器。
|
||||
func TestNewSessionTicketManager(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.SessionTicketsConfig
|
||||
wantError bool
|
||||
checkDefaults bool
|
||||
}{
|
||||
{
|
||||
name: "disabled_should_error",
|
||||
cfg: config.SessionTicketsConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "enabled_without_keyfile",
|
||||
cfg: config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
wantError: false,
|
||||
checkDefaults: true,
|
||||
},
|
||||
{
|
||||
name: "enabled_with_defaults",
|
||||
cfg: config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
KeyFile: "",
|
||||
RotateInterval: 0,
|
||||
RetainKeys: 0,
|
||||
},
|
||||
wantError: false,
|
||||
checkDefaults: true,
|
||||
},
|
||||
{
|
||||
name: "enabled_with_custom_values",
|
||||
cfg: config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: 30 * time.Minute,
|
||||
RetainKeys: 5,
|
||||
},
|
||||
wantError: false,
|
||||
checkDefaults: false, // 使用自定义值,不检查默认值
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(tt.cfg)
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("NewSessionTicketManager() expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("NewSessionTicketManager() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if mgr == nil {
|
||||
t.Error("NewSessionTicketManager() returned nil manager")
|
||||
return
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
// 验证默认配置(仅当使用默认值时)
|
||||
if tt.checkDefaults {
|
||||
if mgr.config.RotateInterval != defaultRotateInterval {
|
||||
t.Errorf("RotateInterval = %v, want %v", mgr.config.RotateInterval, defaultRotateInterval)
|
||||
}
|
||||
if mgr.config.RetainKeys != defaultRetainKeys {
|
||||
t.Errorf("RetainKeys = %d, want %d", mgr.config.RetainKeys, defaultRetainKeys)
|
||||
}
|
||||
} else {
|
||||
// 验证自定义值被正确保留
|
||||
if mgr.config.RotateInterval != tt.cfg.RotateInterval {
|
||||
t.Errorf("RotateInterval = %v, want %v", mgr.config.RotateInterval, tt.cfg.RotateInterval)
|
||||
}
|
||||
if mgr.config.RetainKeys != tt.cfg.RetainKeys {
|
||||
t.Errorf("RetainKeys = %d, want %d", mgr.config.RetainKeys, tt.cfg.RetainKeys)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_KeyGeneration 测试密钥生成。
|
||||
func TestSessionTicketManager_KeyGeneration(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
keys := mgr.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
t.Fatal("Expected at least one key, got none")
|
||||
}
|
||||
|
||||
// 验证密钥大小
|
||||
for i, key := range keys {
|
||||
if len(key) != ticketKeySize {
|
||||
t.Errorf("Key %d size = %d, want %d", i, len(key), ticketKeySize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_KeyRotation 测试密钥轮换。
|
||||
func TestSessionTicketManager_KeyRotation(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: time.Hour, // 使用长间隔,手动触发轮换
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
initialKeys := mgr.GetKeys()
|
||||
if len(initialKeys) != 1 {
|
||||
t.Fatalf("Expected 1 initial key, got %d", len(initialKeys))
|
||||
}
|
||||
|
||||
// 手动轮换密钥
|
||||
if err := mgr.RotateKey(); err != nil {
|
||||
t.Fatalf("RotateKey() failed: %v", err)
|
||||
}
|
||||
|
||||
keysAfter1 := mgr.GetKeys()
|
||||
if len(keysAfter1) != 2 {
|
||||
t.Errorf("Expected 2 keys after rotation, got %d", len(keysAfter1))
|
||||
}
|
||||
|
||||
// 验证新旧密钥不同
|
||||
if string(initialKeys[0]) == string(keysAfter1[1]) {
|
||||
t.Error("New key should be different from initial key")
|
||||
}
|
||||
|
||||
// 继续轮换到超过保留数量
|
||||
_ = mgr.RotateKey()
|
||||
_ = mgr.RotateKey()
|
||||
|
||||
keysAfter4 := mgr.GetKeys()
|
||||
if len(keysAfter4) != 3 {
|
||||
t.Errorf("Expected 3 keys (max retain), got %d", len(keysAfter4))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_KeyRetention 测试密钥保留策略。
|
||||
func TestSessionTicketManager_KeyRetention(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: time.Hour,
|
||||
RetainKeys: 2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
// 生成多个密钥
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := mgr.RotateKey(); err != nil {
|
||||
t.Fatalf("RotateKey() failed at iteration %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
keys := mgr.GetKeys()
|
||||
if len(keys) != 2 {
|
||||
t.Errorf("Expected 2 keys (RetainKeys limit), got %d", len(keys))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_Persistence 测试密钥持久化。
|
||||
func TestSessionTicketManager_Persistence(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
keyFile := filepath.Join(tempDir, "ticket.key")
|
||||
|
||||
// 创建管理器并生成密钥
|
||||
mgr1, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
KeyFile: keyFile,
|
||||
RotateInterval: time.Hour,
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// 轮换几次生成多个密钥
|
||||
_ = mgr1.RotateKey()
|
||||
_ = mgr1.RotateKey()
|
||||
mgr1.Stop()
|
||||
|
||||
// 验证密钥文件存在
|
||||
if _, err := os.Stat(keyFile); os.IsNotExist(err) {
|
||||
t.Fatal("Key file should exist after saving")
|
||||
}
|
||||
|
||||
// 从文件加载密钥创建新管理器
|
||||
mgr2, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
KeyFile: keyFile,
|
||||
RotateInterval: time.Hour,
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager from existing key file: %v", err)
|
||||
}
|
||||
defer mgr2.Stop()
|
||||
|
||||
keys := mgr2.GetKeys()
|
||||
if len(keys) != 3 {
|
||||
t.Errorf("Expected 3 keys loaded from file, got %d", len(keys))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_ApplyToTLSConfig 测试应用到 TLS 配置。
|
||||
func TestSessionTicketManager_ApplyToTLSConfig(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: time.Hour,
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
mgr.ApplyToTLSConfig(tlsCfg)
|
||||
|
||||
// 验证可以获取密钥
|
||||
keys := mgr.GetKeys()
|
||||
if len(keys) == 0 {
|
||||
t.Error("Expected keys to be set in TLS config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_StartStop 测试启动和停止。
|
||||
func TestSessionTicketManager_StartStop(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: 100 * time.Millisecond,
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// 验证初始状态
|
||||
status := mgr.GetStatus()
|
||||
if status.Started {
|
||||
t.Error("Manager should not be started initially")
|
||||
}
|
||||
|
||||
// 启动
|
||||
mgr.Start()
|
||||
status = mgr.GetStatus()
|
||||
if !status.Started {
|
||||
t.Error("Manager should be started after Start()")
|
||||
}
|
||||
|
||||
// 等待一次轮换
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
keys := mgr.GetKeys()
|
||||
if len(keys) < 1 {
|
||||
t.Error("Expected at least 1 key after auto-rotation")
|
||||
}
|
||||
|
||||
// 停止
|
||||
mgr.Stop()
|
||||
status = mgr.GetStatus()
|
||||
if status.Started {
|
||||
t.Error("Manager should not be started after Stop()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_GetStatus 测试获取状态。
|
||||
func TestSessionTicketManager_GetStatus(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: 30 * time.Minute,
|
||||
RetainKeys: 5,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
status := mgr.GetStatus()
|
||||
|
||||
if status.KeyCount != 1 {
|
||||
t.Errorf("KeyCount = %d, want 1", status.KeyCount)
|
||||
}
|
||||
// 使用自定义值 5,不是默认值
|
||||
if status.RetainKeys != 5 {
|
||||
t.Errorf("RetainKeys = %d, want 5", status.RetainKeys)
|
||||
}
|
||||
// RotateInterval 使用配置值(30m > 0,所以保留)
|
||||
if status.RotateInterval != 30*time.Minute {
|
||||
t.Errorf("RotateInterval = %v, want %v", status.RotateInterval, 30*time.Minute)
|
||||
}
|
||||
if status.Started {
|
||||
t.Error("Started should be false before Start()")
|
||||
}
|
||||
|
||||
mgr.Start()
|
||||
status = mgr.GetStatus()
|
||||
if !status.Started {
|
||||
t.Error("Started should be true after Start()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateTicketKey 测试密钥生成函数。
|
||||
func TestGenerateTicketKey(t *testing.T) {
|
||||
key1, err := generateTicketKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generateTicketKey() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(key1) != ticketKeySize {
|
||||
t.Errorf("generateTicketKey() key size = %d, want %d", len(key1), ticketKeySize)
|
||||
}
|
||||
|
||||
key2, err := generateTicketKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generateTicketKey() second call failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证生成的密钥是随机的(不相同)
|
||||
if string(key1) == string(key2) {
|
||||
t.Error("generateTicketKey() should generate random keys")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionTicketManager_ConcurrentAccess 测试并发访问。
|
||||
func TestSessionTicketManager_ConcurrentAccess(t *testing.T) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: 10 * time.Millisecond,
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
mgr.Start()
|
||||
|
||||
// 并发读取和轮换
|
||||
done := make(chan bool, 3)
|
||||
|
||||
// 协程 1: 持续获取密钥
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
_ = mgr.GetKeys()
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// 协程 2: 持续获取状态
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
_ = mgr.GetStatus()
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// 协程 3: 手动轮换
|
||||
go func() {
|
||||
for i := 0; i < 20; i++ {
|
||||
_ = mgr.RotateKey()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// 等待所有协程完成
|
||||
for i := 0; i < 3; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// 验证最终状态
|
||||
keys := mgr.GetKeys()
|
||||
if len(keys) < 1 || len(keys) > 3 {
|
||||
t.Errorf("Final key count %d out of expected range [1, 3]", len(keys))
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkGenerateTicketKey 基准测试密钥生成。
|
||||
func BenchmarkGenerateTicketKey(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := generateTicketKey()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSessionTicketManager_GetKeys 基准测试获取密钥。
|
||||
func BenchmarkSessionTicketManager_GetKeys(b *testing.B) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: time.Hour,
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
// 预生成多个密钥
|
||||
for i := 0; i < 2; i++ {
|
||||
_ = mgr.RotateKey()
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = mgr.GetKeys()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSessionTicketManager_RotateKey 基准测试密钥轮换。
|
||||
func BenchmarkSessionTicketManager_RotateKey(b *testing.B) {
|
||||
mgr, err := NewSessionTicketManager(config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
RotateInterval: time.Hour,
|
||||
RetainKeys: 3,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := mgr.RotateKey()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -62,6 +62,12 @@ type TLSManager struct {
|
||||
// ocspManager OCSP Stapling 管理器
|
||||
ocspManager *OCSPManager
|
||||
|
||||
// sessionTicketMgr Session Ticket 管理器
|
||||
sessionTicketMgr *SessionTicketManager
|
||||
|
||||
// clientVerifier 客户端证书验证器
|
||||
clientVerifier *ClientVerifier
|
||||
|
||||
// certificates 解析后的证书映射,用于 OCSP
|
||||
certificates map[string]*x509.Certificate
|
||||
|
||||
@ -132,6 +138,21 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
||||
issuers: make(map[string]*x509.Certificate),
|
||||
}
|
||||
|
||||
// 初始化 Session Tickets(如果启用)
|
||||
if cfg.SessionTickets.Enabled {
|
||||
sessionTicketMgr, err := NewSessionTicketManager(cfg.SessionTickets)
|
||||
if err != nil {
|
||||
// Session Tickets 初始化失败不阻止 TLS 工作
|
||||
// 可以记录日志
|
||||
_ = err
|
||||
} else {
|
||||
manager.sessionTicketMgr = sessionTicketMgr
|
||||
// 应用 Session Tickets 到 TLS 配置
|
||||
sessionTicketMgr.ApplyToTLSConfig(tlsCfg)
|
||||
sessionTicketMgr.Start()
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化 OCSP Stapling(如果启用)
|
||||
if cfg.OCSPStapling {
|
||||
ocspMgr := NewOCSPManager(DefaultOCSPConfig())
|
||||
@ -164,6 +185,19 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
||||
ocspMgr.Start()
|
||||
}
|
||||
|
||||
// 初始化客户端证书验证(如果启用)
|
||||
if cfg.ClientVerify.Enabled {
|
||||
clientVerifier, err := NewClientVerifier(cfg.ClientVerify)
|
||||
if err != nil {
|
||||
// 客户端验证配置失败不阻止 TLS 工作
|
||||
// 可以记录日志
|
||||
_ = err
|
||||
} else {
|
||||
manager.clientVerifier = clientVerifier
|
||||
clientVerifier.ConfigureTLS(tlsCfg)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置为默认配置
|
||||
manager.defaultCfg = tlsCfg
|
||||
|
||||
@ -310,11 +344,14 @@ func (m *TLSManager) RemoveCertificate(name string) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Close 停止 OCSP 管理器并释放资源。
|
||||
// Close 停止 OCSP 管理器和 Session Ticket 管理器并释放资源。
|
||||
func (m *TLSManager) Close() {
|
||||
if m.ocspManager != nil {
|
||||
m.ocspManager.Stop()
|
||||
}
|
||||
if m.sessionTicketMgr != nil {
|
||||
m.sessionTicketMgr.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// getConfigForClientWithOCSP 返回启用 OCSP Stapling 的 TLS 配置。
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user