lolly/internal/ssl/ocsp_test.go
xfy ac9153f09d fix(proxy,stream,server): Phase 8 问题修复与功能完善
- WebSocket 代理集成:handleWebSocket 现调用 ProxyWebSocket 实现
- 删除 UDP Stream 冗余代码:移除 udpListener 类型及相关测试
- 热升级监听器继承:改用 net.Listen + Serve 模式支持监听器传递
- 代码格式修复:注释格式调整、字段对齐、文件末尾换行符

Co-Authored-By: Claude <noreply@anthropic.com>
2026-04-03 14:28:00 +08:00

451 lines
11 KiB
Go

package ssl
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"rua.plus/lolly/internal/config"
)
func TestNewOCSPManager(t *testing.T) {
cfg := DefaultOCSPConfig()
mgr := NewOCSPManager(cfg)
if mgr == nil {
t.Fatal("Expected non-nil OCSP manager")
}
if mgr.refreshInterval != 1*time.Hour {
t.Errorf("Expected refresh interval 1h, got %v", mgr.refreshInterval)
}
if mgr.timeout != 10*time.Second {
t.Errorf("Expected timeout 10s, got %v", mgr.timeout)
}
if mgr.maxRetries != 3 {
t.Errorf("Expected max retries 3, got %d", mgr.maxRetries)
}
}
// TestNewOCSPManagerWithCustomConfig checks that the refresh interval,
// timeout and retry count supplied in an explicit OCSPConfig end up on the
// manager unchanged.
func TestNewOCSPManagerWithCustomConfig(t *testing.T) {
	custom := OCSPConfig{
		Enabled:         true,
		RefreshInterval: 30 * time.Minute,
		Timeout:         5 * time.Second,
		MaxRetries:      5,
	}
	mgr := NewOCSPManager(&custom)

	if got := mgr.refreshInterval; got != 30*time.Minute {
		t.Errorf("Expected refresh interval 30m, got %v", got)
	}
	if got := mgr.timeout; got != 5*time.Second {
		t.Errorf("Expected timeout 5s, got %v", got)
	}
	if got := mgr.maxRetries; got != 5 {
		t.Errorf("Expected max retries 5, got %d", got)
	}
}
// TestOCSPManagerStartStop checks that Start sets the manager's running flag
// and that Stop clears it again.
func TestOCSPManagerStartStop(t *testing.T) {
	mgr := NewOCSPManager(nil)

	// isRunning reads the running flag while holding the manager's read lock.
	isRunning := func() bool {
		mgr.mu.RLock()
		defer mgr.mu.RUnlock()
		return mgr.running
	}

	mgr.Start()
	if !isRunning() {
		t.Error("Expected OCSP manager to be running")
	}

	mgr.Stop()
	if isRunning() {
		t.Error("Expected OCSP manager to be stopped")
	}
}
// TestOCSPGetOCSPResponse checks GetOCSPResponse for two cases: an unknown
// serial yields nil, and a cached entry marked statusValid whose nextUpdate
// lies in the future yields the stored bytes.
func TestOCSPGetOCSPResponse(t *testing.T) {
	mgr := NewOCSPManager(nil)

	// Unknown serial: nothing is cached, so nothing may be returned.
	if got := mgr.GetOCSPResponse("nonexistent"); got != nil {
		t.Error("Expected nil response for non-existent serial")
	}

	// Seed the cache directly with an entry that is still fresh.
	const serial = "12345"
	payload := []byte("test-ocsp-response")
	entry := &ocspResponse{
		response:   payload,
		thisUpdate: time.Now(),
		nextUpdate: time.Now().Add(1 * time.Hour),
		status:     statusValid,
		fetchedAt:  time.Now(),
	}
	mgr.mu.Lock()
	mgr.responses[serial] = entry
	mgr.mu.Unlock()

	got := mgr.GetOCSPResponse(serial)
	if got == nil {
		t.Error("Expected non-nil response")
	}
	if text := string(got); text != "test-ocsp-response" {
		t.Errorf("Expected 'test-ocsp-response', got '%s'", text)
	}
}
// TestOCSPGetStatus checks GetStatus for two cases: an unknown serial reports
// statusFailed with no response, and a cached statusValid entry reports
// statusValid with a response present.
func TestOCSPGetStatus(t *testing.T) {
	mgr := NewOCSPManager(nil)

	// Unknown serial.
	gotStatus, gotResponse := mgr.GetStatus("nonexistent")
	if !(gotStatus == statusFailed && !gotResponse) {
		t.Error("Expected statusFailed and no response for non-existent serial")
	}

	// Seed the cache with a valid entry, then query it.
	const serial = "12345"
	entry := &ocspResponse{
		response:   []byte("test"),
		status:     statusValid,
		nextUpdate: time.Now().Add(1 * time.Hour),
	}
	mgr.mu.Lock()
	mgr.responses[serial] = entry
	mgr.mu.Unlock()

	gotStatus, gotResponse = mgr.GetStatus(serial)
	if !(gotStatus == statusValid && gotResponse) {
		t.Error("Expected statusValid and hasResponse true")
	}
}
// TestOCSPStaleResponse checks the graceful-degradation path: when a cached
// entry's nextUpdate is already in the past, GetOCSPResponse is expected to
// still hand back the stored bytes and the entry's status is expected to be
// switched to statusStale.
func TestOCSPStaleResponse(t *testing.T) {
	mgr := NewOCSPManager(nil)

	const serial = "12345"
	testResp := []byte("stale-response")

	// Seed an entry that expired one hour ago.
	mgr.mu.Lock()
	mgr.responses[serial] = &ocspResponse{
		response:   testResp,
		thisUpdate: time.Now().Add(-2 * time.Hour),
		nextUpdate: time.Now().Add(-1 * time.Hour), // Expired
		status:     statusValid,
		fetchedAt:  time.Now().Add(-2 * time.Hour),
	}
	mgr.mu.Unlock()

	// Should return stale response (graceful degradation).
	if resp := mgr.GetOCSPResponse(serial); resp == nil {
		t.Error("Expected stale response for graceful degradation")
	}

	// Inspect the entry while still holding the lock, and tolerate it having
	// been removed: a missing entry must fail the test, not crash the test
	// binary with a nil-pointer dereference.
	mgr.mu.RLock()
	stored := mgr.responses[serial]
	markedStale := stored != nil && stored.status == statusStale
	mgr.mu.RUnlock()

	if stored == nil {
		t.Fatal("Expected stale entry to remain in the response cache")
	}
	if !markedStale {
		t.Error("Expected status to be marked as stale")
	}
}
// TestOCSPFailedResponse checks that GetOCSPResponse returns nil for a cached
// entry whose status is statusFailed (seeded here with three recorded errors
// and no response bytes).
func TestOCSPFailedResponse(t *testing.T) {
	const serial = "12345"
	mgr := NewOCSPManager(nil)

	failed := &ocspResponse{
		status:    statusFailed,
		fetchedAt: time.Now(),
		errors:    3,
	}
	mgr.mu.Lock()
	mgr.responses[serial] = failed
	mgr.mu.Unlock()

	if got := mgr.GetOCSPResponse(serial); got != nil {
		t.Error("Expected nil response for failed status")
	}
}
// TestTLSManagerWithOCSPDisabled writes a throwaway certificate and key to a
// temporary directory, builds a TLS manager with OCSPStapling set to false,
// and checks that no OCSP manager was attached to it.
func TestTLSManagerWithOCSPDisabled(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	certBytes, keyBytes := generateTestCertWithOCSP(t, nil)
	err := os.WriteFile(certFile, certBytes, 0644)
	if err != nil {
		t.Fatalf("Failed to write cert: %v", err)
	}
	err = os.WriteFile(keyFile, keyBytes, 0600)
	if err != nil {
		t.Fatalf("Failed to write key: %v", err)
	}

	sslCfg := config.SSLConfig{
		Cert:         certFile,
		Key:          keyFile,
		OCSPStapling: false,
	}
	manager, err := NewTLSManager(&sslCfg)
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}

	// With stapling switched off the manager must not carry an OCSP manager.
	if manager.ocspManager != nil {
		t.Error("Expected OCSP manager to be nil when disabled")
	}

	manager.Close()
}
// TestTLSManagerGetOCSPStatus builds a TLS manager with OCSP stapling enabled
// for a certificate that names no OCSP responder, then checks that every
// entry returned by GetOCSPStatus (if any) is keyed by the same serial it
// reports in its Serial field. An empty result is accepted.
func TestTLSManagerGetOCSPStatus(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	// The certificate is generated without an OCSP server URL.
	certBytes, keyBytes := generateTestCertWithOCSP(t, nil)
	err := os.WriteFile(certFile, certBytes, 0644)
	if err != nil {
		t.Fatalf("Failed to write cert: %v", err)
	}
	err = os.WriteFile(keyFile, keyBytes, 0600)
	if err != nil {
		t.Fatalf("Failed to write key: %v", err)
	}

	sslCfg := config.SSLConfig{
		Cert:         certFile,
		Key:          keyFile,
		OCSPStapling: true,
	}
	manager, err := NewTLSManager(&sslCfg)
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}
	defer manager.Close()

	// Ranging over an empty map is a no-op, so no length guard is needed.
	for serial, info := range manager.GetOCSPStatus() {
		if info.Serial == serial {
			continue
		}
		t.Errorf("Serial mismatch: expected %s, got %s", serial, info.Serial)
	}
}
// TestTLSManagerClose builds a TLS manager with OCSP stapling enabled and
// calls Close on it twice; the test passes as long as neither call panics.
func TestTLSManagerClose(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	certBytes, keyBytes := generateTestCertWithOCSP(t, nil)
	err := os.WriteFile(certFile, certBytes, 0644)
	if err != nil {
		t.Fatalf("Failed to write cert: %v", err)
	}
	err = os.WriteFile(keyFile, keyBytes, 0600)
	if err != nil {
		t.Fatalf("Failed to write key: %v", err)
	}

	sslCfg := config.SSLConfig{
		Cert:         certFile,
		Key:          keyFile,
		OCSPStapling: true,
	}
	manager, err := NewTLSManager(&sslCfg)
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}

	// First Close: must work whether or not an OCSP manager is attached.
	manager.Close()
	// Second Close: must be harmless on an already closed manager.
	manager.Close()
}
// TestExtractCertificates feeds a freshly generated PEM certificate to
// extractCertificates and expects no error and at least one certificate.
func TestExtractCertificates(t *testing.T) {
	// Only the certificate half of the generated pair is needed here.
	pemData, _ := generateTestCertWithOCSP(t, nil)

	parsed, err := extractCertificates(pemData)
	if err != nil {
		t.Fatalf("extractCertificates() failed: %v", err)
	}
	if len(parsed) < 1 {
		t.Error("Expected at least one certificate")
	}
}
// TestExtractCertificatesInvalidPEM feeds non-PEM bytes to
// extractCertificates and expects an error together with a nil result.
func TestExtractCertificatesInvalidPEM(t *testing.T) {
	parsed, err := extractCertificates([]byte("not valid pem data"))

	if err == nil {
		t.Error("Expected error for invalid PEM data")
	}
	if parsed != nil {
		t.Error("Expected nil certs for invalid PEM data")
	}
}
// TestOCSPManagerRegisterCertificate registers a self-signed certificate
// whose OCSP responder URL points at a local mock HTTP server. The mock
// answers with bytes that are not a real OCSP response, so registration is
// allowed to fail; the only hard requirement is that a successful
// registration leaves an entry for the certificate's serial in the cache.
func TestOCSPManagerRegisterCertificate(t *testing.T) {
	mgr := NewOCSPManager(nil)
	defer mgr.Stop()

	// Key for the self-signed test certificate.
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("Failed to generate key: %v", err)
	}

	// Mock OCSP responder returning a fixed, non-OCSP payload.
	ocspServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		// Best effort: a failed write only means the client has gone away,
		// which is irrelevant to what this test verifies.
		_, _ = w.Write([]byte("mock-ocsp-response"))
	}))
	defer ocspServer.Close()

	template := x509.Certificate{
		SerialNumber: big.NewInt(12345),
		Subject:      pkix.Name{CommonName: "test"},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(1 * time.Hour),
		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:     []string{"localhost"},
		OCSPServer:   []string{ocspServer.URL},
	}

	// Self-signed: the template is used as both subject and issuer.
	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		t.Fatalf("Failed to create certificate: %v", err)
	}
	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		t.Fatalf("Failed to parse certificate: %v", err)
	}

	// The certificate is passed as its own issuer. A failure is tolerated
	// (graceful degradation) but logged so it is visible with `go test -v`
	// rather than being swallowed silently.
	if err := mgr.RegisterCertificate(cert, cert); err != nil {
		t.Logf("RegisterCertificate failed (tolerated): %v", err)
		return
	}

	// Registration succeeded: a response must have been stored.
	serial := cert.SerialNumber.String()
	mgr.mu.RLock()
	_, exists := mgr.responses[serial]
	mgr.mu.RUnlock()
	if !exists {
		t.Error("Expected response to be stored")
	}
}
// generateTestCertWithOCSP builds a throwaway self-signed ECDSA P-256
// certificate, valid for one hour from now, and returns it together with its
// private key as PEM-encoded bytes (certificate first, key second).
//
// ocspServer is placed in the certificate's OCSPServer field; pass nil for a
// certificate that names no responder. Any failure aborts the calling test
// through t.Fatalf.
func generateTestCertWithOCSP(t *testing.T, ocspServer []string) ([]byte, []byte) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("Failed to generate private key: %v", err)
	}

	subject := pkix.Name{
		Organization: []string{"Test"},
		CommonName:   "test.example.com",
	}
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               subject,
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost", "test.example.com"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
		OCSPServer:            ocspServer,
	}

	// Self-signed: the template doubles as its own issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("Failed to create certificate: %v", err)
	}
	certBlock := pem.Block{Type: "CERTIFICATE", Bytes: der}
	certOut := pem.EncodeToMemory(&certBlock)

	keyBytes, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("Failed to marshal private key: %v", err)
	}
	keyBlock := pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}
	keyOut := pem.EncodeToMemory(&keyBlock)

	return certOut, keyOut
}
// TestOCSPConfigDefaults checks the values returned by DefaultOCSPConfig:
// enabled, 1h refresh interval, 10s timeout and 3 retries.
func TestOCSPConfigDefaults(t *testing.T) {
	defaults := DefaultOCSPConfig()

	if enabled := defaults.Enabled; !enabled {
		t.Error("Expected OCSP to be enabled by default")
	}
	if got := defaults.RefreshInterval; got != 1*time.Hour {
		t.Errorf("Expected default refresh interval 1h, got %v", got)
	}
	if got := defaults.Timeout; got != 10*time.Second {
		t.Errorf("Expected default timeout 10s, got %v", got)
	}
	if got := defaults.MaxRetries; got != 3 {
		t.Errorf("Expected default max retries 3, got %d", got)
	}
}