428 lines
10 KiB
Go
428 lines
10 KiB
Go
// Package ssl 提供 SSL/TLS 支持。
|
||
//
|
||
// 该文件实现 OCSP Stapling 功能,用于在 TLS 握手时附加证书状态信息,
|
||
// 提高证书验证效率并减少客户端对 OCSP 服务器的直接查询。
|
||
//
|
||
// 主要功能:
|
||
// - OCSP 响应缓存:缓存证书状态响应,定期自动刷新
|
||
// - 优雅降级:OCSP 查询失败时仍允许 TLS 连接
|
||
// - 自动重试:支持配置最大重试次数
|
||
// - 多服务器支持:尝试证书配置的多个 OCSP 服务器
|
||
//
|
||
// 使用示例:
|
||
//
|
||
// mgr := ssl.NewOCSPManager(ssl.DefaultOCSPConfig())
|
||
// mgr.Start()
|
||
// defer mgr.Stop()
|
||
//
|
||
// // 注册证书
|
||
// err := mgr.RegisterCertificate(cert, issuer)
|
||
//
|
||
// 作者:xfy
|
||
package ssl
|
||
|
||
import (
|
||
"bytes"
|
||
"crypto"
|
||
"crypto/x509"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"golang.org/x/crypto/ocsp"
|
||
)
|
||
|
||
// OCSPManager OCSP Stapling 管理器。
|
||
//
|
||
// 管理 TLS 证书的 OCSP 响应缓存,支持定期自动刷新和优雅降级。
|
||
// 当 OCSP 查询失败时,TLS 握手仍可继续进行。
|
||
type OCSPManager struct {
|
||
responses map[string]*ocspResponse
|
||
certs map[string]certPair
|
||
client *http.Client
|
||
stopChan chan struct{}
|
||
refreshInterval time.Duration
|
||
timeout time.Duration
|
||
maxRetries int
|
||
mu sync.RWMutex
|
||
running bool
|
||
}
|
||
|
||
type certPair struct {
|
||
cert *x509.Certificate
|
||
issuer *x509.Certificate
|
||
}
|
||
|
||
// ocspResponse OCSP 响应缓存条目。
|
||
//
|
||
// 保存 OCSP 响应数据及其元数据,用于证书状态验证。
|
||
type ocspResponse struct {
|
||
thisUpdate time.Time
|
||
nextUpdate time.Time
|
||
fetchedAt time.Time
|
||
response []byte
|
||
status OCSPStatus
|
||
errors int
|
||
}
|
||
|
||
// OCSPStatus OCSP 响应状态类型。
|
||
type OCSPStatus int
|
||
|
||
const (
|
||
statusValid OCSPStatus = iota // 响应有效且新鲜
|
||
statusStale // 响应过期但可用(优雅降级)
|
||
statusFailed // 无有效响应可用
|
||
)
|
||
|
||
// OCSPConfig OCSP 管理器配置。
|
||
type OCSPConfig struct {
|
||
Enabled bool // 是否启用 OCSP Stapling
|
||
RefreshInterval time.Duration // 刷新间隔(默认:1 小时)
|
||
Timeout time.Duration // HTTP 超时(默认:10 秒)
|
||
MaxRetries int // 失败时最大重试次数(默认:3)
|
||
}
|
||
|
||
// DefaultOCSPConfig 返回默认的 OCSP 配置。
|
||
//
|
||
// 该函数提供一组适用于大多数生产环境的默认 OCSP Stapling 设置:
|
||
// - Enabled: true(默认启用 OCSP Stapling)
|
||
// - RefreshInterval: 1 小时(响应刷新间隔)
|
||
// - Timeout: 10 秒(HTTP 请求超时)
|
||
// - MaxRetries: 3(失败重试次数)
|
||
//
|
||
// 返回值:
|
||
// - *OCSPConfig: 包含默认值的 OCSP 配置指针
|
||
func DefaultOCSPConfig() *OCSPConfig {
|
||
return &OCSPConfig{
|
||
Enabled: true,
|
||
RefreshInterval: 1 * time.Hour,
|
||
Timeout: 10 * time.Second,
|
||
MaxRetries: 3,
|
||
}
|
||
}
|
||
|
||
// NewOCSPManager 创建新的 OCSP 管理器。
|
||
//
|
||
// 如果配置为 nil,则使用默认配置。
|
||
//
|
||
// 参数:
|
||
// - cfg: OCSP 配置
|
||
//
|
||
// 返回值:
|
||
// - *OCSPManager: 初始化的 OCSP 管理器
|
||
func NewOCSPManager(cfg *OCSPConfig) *OCSPManager {
|
||
if cfg == nil {
|
||
cfg = DefaultOCSPConfig()
|
||
}
|
||
|
||
// 应用默认值
|
||
refreshInterval := cfg.RefreshInterval
|
||
if refreshInterval == 0 {
|
||
refreshInterval = 1 * time.Hour
|
||
}
|
||
timeout := cfg.Timeout
|
||
if timeout == 0 {
|
||
timeout = 10 * time.Second
|
||
}
|
||
maxRetries := cfg.MaxRetries
|
||
if maxRetries == 0 {
|
||
maxRetries = 3
|
||
}
|
||
|
||
return &OCSPManager{
|
||
responses: make(map[string]*ocspResponse),
|
||
certs: make(map[string]certPair),
|
||
client: &http.Client{
|
||
Timeout: timeout,
|
||
},
|
||
refreshInterval: refreshInterval,
|
||
timeout: timeout,
|
||
maxRetries: maxRetries,
|
||
stopChan: make(chan struct{}),
|
||
}
|
||
}
|
||
|
||
// Start 启动 OCSP 定期刷新进程。
|
||
func (m *OCSPManager) Start() {
|
||
m.mu.Lock()
|
||
if m.running {
|
||
m.mu.Unlock()
|
||
return
|
||
}
|
||
m.running = true
|
||
m.mu.Unlock()
|
||
|
||
go m.refreshLoop()
|
||
}
|
||
|
||
// Stop 停止 OCSP 刷新进程。
|
||
func (m *OCSPManager) Stop() {
|
||
m.mu.Lock()
|
||
if !m.running {
|
||
m.mu.Unlock()
|
||
return
|
||
}
|
||
m.running = false
|
||
m.mu.Unlock()
|
||
|
||
close(m.stopChan)
|
||
}
|
||
|
||
// refreshLoop 定期刷新所有 OCSP 响应。
|
||
func (m *OCSPManager) refreshLoop() {
|
||
ticker := time.NewTicker(m.refreshInterval)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-m.stopChan:
|
||
return
|
||
case <-ticker.C:
|
||
m.refreshAll()
|
||
}
|
||
}
|
||
}
|
||
|
||
// refreshAll 刷新所有缓存的 OCSP 响应。
|
||
func (m *OCSPManager) refreshAll() {
|
||
var toRefresh []string
|
||
m.mu.RLock()
|
||
for serial, resp := range m.responses {
|
||
if resp == nil || time.Now().After(resp.nextUpdate) || resp.status != statusValid {
|
||
toRefresh = append(toRefresh, serial)
|
||
}
|
||
}
|
||
m.mu.RUnlock()
|
||
|
||
for _, serial := range toRefresh {
|
||
m.mu.RLock()
|
||
pair, ok := m.certs[serial]
|
||
m.mu.RUnlock()
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
response, err := m.fetchOCSP(pair.cert, pair.issuer)
|
||
m.mu.Lock()
|
||
if err != nil {
|
||
if existing, exists := m.responses[serial]; exists {
|
||
existing.errors++
|
||
if existing.errors >= m.maxRetries {
|
||
existing.status = statusFailed
|
||
}
|
||
}
|
||
} else {
|
||
m.responses[serial] = response
|
||
}
|
||
m.mu.Unlock()
|
||
}
|
||
}
|
||
|
||
// RegisterCertificate 注册证书以进行 OCSP Stapling。
|
||
//
|
||
// 从证书中提取 OCSP URL 并获取初始响应。如果获取失败,仍注册为失败状态
|
||
// 以支持优雅降级,允许 TLS 连接继续进行。
|
||
//
|
||
// 参数:
|
||
// - cert: 待注册的证书
|
||
// - issuer: 颁发者证书
|
||
//
|
||
// 返回值:
|
||
// - error: 证书为空或无 OCSP 服务器时返回错误
|
||
func (m *OCSPManager) RegisterCertificate(cert *x509.Certificate, issuer *x509.Certificate) error {
|
||
if cert == nil {
|
||
return errors.New("certificate is nil")
|
||
}
|
||
|
||
// 检查证书是否有 OCSP 服务器 URL
|
||
if len(cert.OCSPServer) == 0 {
|
||
return errors.New("certificate has no OCSP server URL")
|
||
}
|
||
|
||
serial := cert.SerialNumber.String()
|
||
|
||
m.mu.Lock()
|
||
m.certs[serial] = certPair{cert: cert, issuer: issuer}
|
||
m.mu.Unlock()
|
||
|
||
// 获取初始 OCSP 响应
|
||
response, err := m.fetchOCSP(cert, issuer)
|
||
if err != nil {
|
||
// 优雅降级:注册为失败状态但允许 TLS 继续
|
||
m.mu.Lock()
|
||
m.responses[serial] = &ocspResponse{
|
||
status: statusFailed,
|
||
fetchedAt: time.Now(),
|
||
errors: 1,
|
||
}
|
||
m.mu.Unlock()
|
||
return fmt.Errorf("failed to fetch OCSP response: %w", err)
|
||
}
|
||
|
||
m.mu.Lock()
|
||
m.responses[serial] = response
|
||
m.mu.Unlock()
|
||
|
||
return nil
|
||
}
|
||
|
||
// fetchOCSP 从证书的 OCSP 服务器获取 OCSP 响应。
|
||
//
|
||
// 参数:
|
||
// - cert: 待查询的证书
|
||
// - issuer: 颁发者证书
|
||
//
|
||
// 返回值:
|
||
// - *ocspResponse: OCSP 响应
|
||
// - error: 获取失败时返回错误
|
||
func (m *OCSPManager) fetchOCSP(cert, issuer *x509.Certificate) (*ocspResponse, error) {
|
||
if len(cert.OCSPServer) == 0 {
|
||
return nil, errors.New("no OCSP server in certificate")
|
||
}
|
||
|
||
// 创建 OCSP 请求
|
||
ocspReq, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{
|
||
Hash: crypto.SHA256,
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create OCSP request: %w", err)
|
||
}
|
||
|
||
// 尝试每个 OCSP 服务器 URL
|
||
var lastErr error
|
||
for _, url := range cert.OCSPServer {
|
||
resp, err := m.sendOCSPRequest(url, ocspReq)
|
||
if err != nil {
|
||
lastErr = err
|
||
continue
|
||
}
|
||
|
||
// 解析并验证响应
|
||
ocspResp, err := ocsp.ParseResponse(resp, issuer)
|
||
if err != nil {
|
||
lastErr = fmt.Errorf("failed to parse OCSP response: %w", err)
|
||
continue
|
||
}
|
||
|
||
// 检查响应状态
|
||
if ocspResp.Status != ocsp.Good {
|
||
return nil, fmt.Errorf("certificate status is not good: %d", ocspResp.Status)
|
||
}
|
||
|
||
// 检查响应是否匹配证书
|
||
if !bytes.Equal(ocspResp.SerialNumber.Bytes(), cert.SerialNumber.Bytes()) {
|
||
return nil, errors.New("OCSP response serial number mismatch")
|
||
}
|
||
|
||
return &ocspResponse{
|
||
response: resp,
|
||
thisUpdate: ocspResp.ThisUpdate,
|
||
nextUpdate: ocspResp.NextUpdate,
|
||
status: statusValid,
|
||
fetchedAt: time.Now(),
|
||
errors: 0,
|
||
}, nil
|
||
}
|
||
|
||
return nil, fmt.Errorf("all OCSP servers failed: %w", lastErr)
|
||
}
|
||
|
||
// sendOCSPRequest 向指定 URL 发送 OCSP 请求。
|
||
//
|
||
// 参数:
|
||
// - url: OCSP 服务器 URL
|
||
// - req: OCSP 请求数据
|
||
//
|
||
// 返回值:
|
||
// - []byte: OCSP 响应数据
|
||
// - error: 请求失败时返回错误
|
||
func (m *OCSPManager) sendOCSPRequest(url string, req []byte) ([]byte, error) {
|
||
var lastErr error
|
||
for i := 0; i < m.maxRetries; i++ {
|
||
body, err := m.singleOCSPAttempt(url, req)
|
||
if err != nil {
|
||
lastErr = err
|
||
if i < m.maxRetries-1 {
|
||
time.Sleep(time.Duration(i+1) * time.Second)
|
||
continue
|
||
}
|
||
return nil, lastErr
|
||
}
|
||
return body, nil
|
||
}
|
||
return nil, lastErr
|
||
}
|
||
|
||
// singleOCSPAttempt 发送单次 OCSP HTTP 请求。
|
||
//
|
||
// 将单次请求逻辑提取为独立函数,确保 resp.Body 在每次调用结束时
|
||
// 通过 defer 正确关闭,避免在重试循环中累积未关闭的 response body。
|
||
func (m *OCSPManager) singleOCSPAttempt(url string, req []byte) ([]byte, error) {
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(req))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||
}
|
||
httpReq.Header.Set("Content-Type", "application/ocsp-request")
|
||
|
||
resp, err := m.client.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("HTTP request failed: %w", err)
|
||
}
|
||
defer func() { _ = resp.Body.Close() }()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("OCSP server returned status %d", resp.StatusCode)
|
||
}
|
||
|
||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||
}
|
||
|
||
return body, nil
|
||
}
|
||
|
||
// GetOCSPResponse 返回证书的 OCSP 响应。
|
||
//
|
||
// 如果无有效响应则返回 nil(优雅降级,TLS 可继续)。
|
||
//
|
||
// 参数:
|
||
// - serial: 证书序列号字符串
|
||
//
|
||
// 返回值:
|
||
// - []byte: OCSP 响应数据,无响应时返回 nil
|
||
func (m *OCSPManager) GetOCSPResponse(serial string) []byte {
|
||
m.mu.RLock()
|
||
resp, ok := m.responses[serial]
|
||
m.mu.RUnlock()
|
||
|
||
if !ok || resp == nil {
|
||
return nil
|
||
}
|
||
|
||
// 检查响应是否仍可用
|
||
switch resp.status {
|
||
case statusValid:
|
||
// 检查过期
|
||
if time.Now().After(resp.nextUpdate) {
|
||
// 标记为过期但仍返回(优雅降级)
|
||
m.mu.Lock()
|
||
resp.status = statusStale
|
||
m.mu.Unlock()
|
||
}
|
||
return resp.response
|
||
case statusStale:
|
||
// 返回过期响应用于优雅降级
|
||
// 这允许即使 OCSP 刷新失败也能继续 TLS 握手
|
||
return resp.response
|
||
case statusFailed:
|
||
// 无响应可用
|
||
return nil
|
||
default:
|
||
return nil
|
||
}
|
||
}
|