lolly/internal/ssl/ocsp.go

428 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package ssl 提供 SSL/TLS 支持。
//
// 该文件实现 OCSP Stapling 功能,用于在 TLS 握手时附加证书状态信息,
// 提高证书验证效率并减少客户端对 OCSP 服务器的直接查询。
//
// 主要功能:
// - OCSP 响应缓存:缓存证书状态响应,定期自动刷新
// - 优雅降级OCSP 查询失败时仍允许 TLS 连接
// - 自动重试:支持配置最大重试次数
// - 多服务器支持:尝试证书配置的多个 OCSP 服务器
//
// 使用示例:
//
// mgr := ssl.NewOCSPManager(ssl.DefaultOCSPConfig())
// mgr.Start()
// defer mgr.Stop()
//
// // 注册证书
// err := mgr.RegisterCertificate(cert, issuer)
//
// 作者xfy
package ssl
import (
"bytes"
"crypto"
"crypto/x509"
"errors"
"fmt"
"io"
"net/http"
"sync"
"time"
"golang.org/x/crypto/ocsp"
)
// OCSPManager OCSP Stapling 管理器。
//
// 管理 TLS 证书的 OCSP 响应缓存,支持定期自动刷新和优雅降级。
// 当 OCSP 查询失败时TLS 握手仍可继续进行。
type OCSPManager struct {
responses map[string]*ocspResponse
certs map[string]certPair
client *http.Client
stopChan chan struct{}
refreshInterval time.Duration
timeout time.Duration
maxRetries int
mu sync.RWMutex
running bool
}
type certPair struct {
cert *x509.Certificate
issuer *x509.Certificate
}
// ocspResponse OCSP 响应缓存条目。
//
// 保存 OCSP 响应数据及其元数据,用于证书状态验证。
type ocspResponse struct {
thisUpdate time.Time
nextUpdate time.Time
fetchedAt time.Time
response []byte
status OCSPStatus
errors int
}
// OCSPStatus OCSP 响应状态类型。
type OCSPStatus int
const (
statusValid OCSPStatus = iota // 响应有效且新鲜
statusStale // 响应过期但可用(优雅降级)
statusFailed // 无有效响应可用
)
// OCSPConfig OCSP 管理器配置。
type OCSPConfig struct {
Enabled bool // 是否启用 OCSP Stapling
RefreshInterval time.Duration // 刷新间隔默认1 小时)
Timeout time.Duration // HTTP 超时默认10 秒)
MaxRetries int // 失败时最大重试次数默认3
}
// DefaultOCSPConfig 返回默认的 OCSP 配置。
//
// 该函数提供一组适用于大多数生产环境的默认 OCSP Stapling 设置:
// - Enabled: true默认启用 OCSP Stapling
// - RefreshInterval: 1 小时(响应刷新间隔)
// - Timeout: 10 秒HTTP 请求超时)
// - MaxRetries: 3失败重试次数
//
// 返回值:
// - *OCSPConfig: 包含默认值的 OCSP 配置指针
func DefaultOCSPConfig() *OCSPConfig {
return &OCSPConfig{
Enabled: true,
RefreshInterval: 1 * time.Hour,
Timeout: 10 * time.Second,
MaxRetries: 3,
}
}
// NewOCSPManager 创建新的 OCSP 管理器。
//
// 如果配置为 nil则使用默认配置。
//
// 参数:
// - cfg: OCSP 配置
//
// 返回值:
// - *OCSPManager: 初始化的 OCSP 管理器
func NewOCSPManager(cfg *OCSPConfig) *OCSPManager {
if cfg == nil {
cfg = DefaultOCSPConfig()
}
// 应用默认值
refreshInterval := cfg.RefreshInterval
if refreshInterval == 0 {
refreshInterval = 1 * time.Hour
}
timeout := cfg.Timeout
if timeout == 0 {
timeout = 10 * time.Second
}
maxRetries := cfg.MaxRetries
if maxRetries == 0 {
maxRetries = 3
}
return &OCSPManager{
responses: make(map[string]*ocspResponse),
certs: make(map[string]certPair),
client: &http.Client{
Timeout: timeout,
},
refreshInterval: refreshInterval,
timeout: timeout,
maxRetries: maxRetries,
stopChan: make(chan struct{}),
}
}
// Start 启动 OCSP 定期刷新进程。
func (m *OCSPManager) Start() {
m.mu.Lock()
if m.running {
m.mu.Unlock()
return
}
m.running = true
m.mu.Unlock()
go m.refreshLoop()
}
// Stop 停止 OCSP 刷新进程。
func (m *OCSPManager) Stop() {
m.mu.Lock()
if !m.running {
m.mu.Unlock()
return
}
m.running = false
m.mu.Unlock()
close(m.stopChan)
}
// refreshLoop 定期刷新所有 OCSP 响应。
func (m *OCSPManager) refreshLoop() {
ticker := time.NewTicker(m.refreshInterval)
defer ticker.Stop()
for {
select {
case <-m.stopChan:
return
case <-ticker.C:
m.refreshAll()
}
}
}
// refreshAll 刷新所有缓存的 OCSP 响应。
func (m *OCSPManager) refreshAll() {
var toRefresh []string
m.mu.RLock()
for serial, resp := range m.responses {
if resp == nil || time.Now().After(resp.nextUpdate) || resp.status != statusValid {
toRefresh = append(toRefresh, serial)
}
}
m.mu.RUnlock()
for _, serial := range toRefresh {
m.mu.RLock()
pair, ok := m.certs[serial]
m.mu.RUnlock()
if !ok {
continue
}
response, err := m.fetchOCSP(pair.cert, pair.issuer)
m.mu.Lock()
if err != nil {
if existing, exists := m.responses[serial]; exists {
existing.errors++
if existing.errors >= m.maxRetries {
existing.status = statusFailed
}
}
} else {
m.responses[serial] = response
}
m.mu.Unlock()
}
}
// RegisterCertificate 注册证书以进行 OCSP Stapling。
//
// 从证书中提取 OCSP URL 并获取初始响应。如果获取失败,仍注册为失败状态
// 以支持优雅降级,允许 TLS 连接继续进行。
//
// 参数:
// - cert: 待注册的证书
// - issuer: 颁发者证书
//
// 返回值:
// - error: 证书为空或无 OCSP 服务器时返回错误
func (m *OCSPManager) RegisterCertificate(cert *x509.Certificate, issuer *x509.Certificate) error {
if cert == nil {
return errors.New("certificate is nil")
}
// 检查证书是否有 OCSP 服务器 URL
if len(cert.OCSPServer) == 0 {
return errors.New("certificate has no OCSP server URL")
}
serial := cert.SerialNumber.String()
m.mu.Lock()
m.certs[serial] = certPair{cert: cert, issuer: issuer}
m.mu.Unlock()
// 获取初始 OCSP 响应
response, err := m.fetchOCSP(cert, issuer)
if err != nil {
// 优雅降级:注册为失败状态但允许 TLS 继续
m.mu.Lock()
m.responses[serial] = &ocspResponse{
status: statusFailed,
fetchedAt: time.Now(),
errors: 1,
}
m.mu.Unlock()
return fmt.Errorf("failed to fetch OCSP response: %w", err)
}
m.mu.Lock()
m.responses[serial] = response
m.mu.Unlock()
return nil
}
// fetchOCSP 从证书的 OCSP 服务器获取 OCSP 响应。
//
// 参数:
// - cert: 待查询的证书
// - issuer: 颁发者证书
//
// 返回值:
// - *ocspResponse: OCSP 响应
// - error: 获取失败时返回错误
func (m *OCSPManager) fetchOCSP(cert, issuer *x509.Certificate) (*ocspResponse, error) {
if len(cert.OCSPServer) == 0 {
return nil, errors.New("no OCSP server in certificate")
}
// 创建 OCSP 请求
ocspReq, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{
Hash: crypto.SHA256,
})
if err != nil {
return nil, fmt.Errorf("failed to create OCSP request: %w", err)
}
// 尝试每个 OCSP 服务器 URL
var lastErr error
for _, url := range cert.OCSPServer {
resp, err := m.sendOCSPRequest(url, ocspReq)
if err != nil {
lastErr = err
continue
}
// 解析并验证响应
ocspResp, err := ocsp.ParseResponse(resp, issuer)
if err != nil {
lastErr = fmt.Errorf("failed to parse OCSP response: %w", err)
continue
}
// 检查响应状态
if ocspResp.Status != ocsp.Good {
return nil, fmt.Errorf("certificate status is not good: %d", ocspResp.Status)
}
// 检查响应是否匹配证书
if !bytes.Equal(ocspResp.SerialNumber.Bytes(), cert.SerialNumber.Bytes()) {
return nil, errors.New("OCSP response serial number mismatch")
}
return &ocspResponse{
response: resp,
thisUpdate: ocspResp.ThisUpdate,
nextUpdate: ocspResp.NextUpdate,
status: statusValid,
fetchedAt: time.Now(),
errors: 0,
}, nil
}
return nil, fmt.Errorf("all OCSP servers failed: %w", lastErr)
}
// sendOCSPRequest 向指定 URL 发送 OCSP 请求。
//
// 参数:
// - url: OCSP 服务器 URL
// - req: OCSP 请求数据
//
// 返回值:
// - []byte: OCSP 响应数据
// - error: 请求失败时返回错误
func (m *OCSPManager) sendOCSPRequest(url string, req []byte) ([]byte, error) {
var lastErr error
for i := 0; i < m.maxRetries; i++ {
body, err := m.singleOCSPAttempt(url, req)
if err != nil {
lastErr = err
if i < m.maxRetries-1 {
time.Sleep(time.Duration(i+1) * time.Second)
continue
}
return nil, lastErr
}
return body, nil
}
return nil, lastErr
}
// singleOCSPAttempt 发送单次 OCSP HTTP 请求。
//
// 将单次请求逻辑提取为独立函数,确保 resp.Body 在每次调用结束时
// 通过 defer 正确关闭,避免在重试循环中累积未关闭的 response body。
func (m *OCSPManager) singleOCSPAttempt(url string, req []byte) ([]byte, error) {
httpReq, err := http.NewRequest("POST", url, bytes.NewReader(req))
if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/ocsp-request")
resp, err := m.client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("HTTP request failed: %w", err)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("OCSP server returned status %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024))
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
return body, nil
}
// GetOCSPResponse 返回证书的 OCSP 响应。
//
// 如果无有效响应则返回 nil优雅降级TLS 可继续)。
//
// 参数:
// - serial: 证书序列号字符串
//
// 返回值:
// - []byte: OCSP 响应数据,无响应时返回 nil
func (m *OCSPManager) GetOCSPResponse(serial string) []byte {
m.mu.RLock()
resp, ok := m.responses[serial]
m.mu.RUnlock()
if !ok || resp == nil {
return nil
}
// 检查响应是否仍可用
switch resp.status {
case statusValid:
// 检查过期
if time.Now().After(resp.nextUpdate) {
// 标记为过期但仍返回(优雅降级)
m.mu.Lock()
resp.status = statusStale
m.mu.Unlock()
}
return resp.response
case statusStale:
// 返回过期响应用于优雅降级
// 这允许即使 OCSP 刷新失败也能继续 TLS 握手
return resp.response
case statusFailed:
// 无响应可用
return nil
default:
return nil
}
}