lolly/internal/ssl/ssl_test.go

828 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

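// Package ssl provides tests for the SSL/TLS functionality.
//
// This file tests the features of the SSL module, including:
//   - TLS manager creation and configuration
//   - TLS version and cipher suite parsing
//   - Certificate validation
//   - Multi-certificate management
//   - Retrieving the TLS configuration
//
// Author: xfy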
package ssl
import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"rua.plus/lolly/internal/config"
	"rua.plus/lolly/internal/sslutil"
)
// TestNewTLSManager feeds NewTLSManager a set of unusable configurations and
// checks that each one is rejected with an error containing the expected text.
func TestNewTLSManager(t *testing.T) {
	type testCase struct {
		cfg     *config.SSLConfig
		name    string
		errMsg  string
		wantErr bool
	}

	cases := []testCase{
		{name: "nil config", cfg: nil, wantErr: true, errMsg: "ssl config is nil"},
		{
			name:    "empty cert path",
			cfg:     &config.SSLConfig{Key: "key.pem"},
			wantErr: true,
			errMsg:  "certificate and key paths are required",
		},
		{
			name:    "empty key path",
			cfg:     &config.SSLConfig{Cert: "cert.pem"},
			wantErr: true,
			errMsg:  "certificate and key paths are required",
		},
		{
			name: "non-existent cert file",
			cfg: &config.SSLConfig{
				Cert: "/nonexistent/cert.pem",
				Key:  "/nonexistent/key.pem",
			},
			wantErr: true,
			errMsg:  "failed to load certificate",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewTLSManager(tc.cfg)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("NewTLSManager() error = %v, wantErr %v", err, tc.wantErr)
			}
			// The message is only compared when an error was actually
			// returned and the case specifies an expected fragment.
			if err == nil || tc.errMsg == "" {
				return
			}
			if !containsString(err.Error(), tc.errMsg) {
				t.Errorf("NewTLSManager() error = %v, want errMsg containing %v", err, tc.errMsg)
			}
		})
	}
}
// TestNewTLSManagerWithCert checks that NewTLSManager succeeds when given a
// freshly generated self-signed certificate, and that the resulting TLS config
// has MinVersion set to TLS 1.2 and a non-empty cipher suite list.
func TestNewTLSManagerWithCert(t *testing.T) {
	// Create temporary test certificate files.
	tmpDir := t.TempDir()
	certPath := filepath.Join(tmpDir, "cert.pem")
	keyPath := filepath.Join(tmpDir, "key.pem")

	// Generate a self-signed certificate for testing.
	cert, key := generateTestCert(t)
	if err := os.WriteFile(certPath, cert, 0o644); err != nil {
		t.Fatalf("Failed to write cert: %v", err)
	}
	if err := os.WriteFile(keyPath, key, 0o600); err != nil {
		t.Fatalf("Failed to write key: %v", err)
	}

	cfg := &config.SSLConfig{
		Cert: certPath,
		Key:  keyPath,
	}
	manager, err := NewTLSManager(cfg)
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}
	if manager == nil {
		t.Fatal("Expected non-nil manager")
	}
	// Release the manager when the test ends, matching the other tests in
	// this file that construct one.
	defer manager.Close()

	tlsCfg := manager.GetTLSConfig()
	if tlsCfg == nil {
		t.Fatal("Expected non-nil TLS config")
	}

	// Check TLS version defaults.
	if tlsCfg.MinVersion != tls.VersionTLS12 {
		t.Errorf("Expected MinVersion TLS 1.2, got %v", tlsCfg.MinVersion)
	}

	// Check cipher suites are set.
	if len(tlsCfg.CipherSuites) == 0 {
		t.Error("Expected cipher suites to be set")
	}
}
// TestParseTLSVersions exercises sslutil.ParseTLSVersions with accepted and
// rejected protocol lists and compares the returned minimum/maximum versions.
func TestParseTLSVersions(t *testing.T) {
	type testCase struct {
		name      string
		protocols []string
		wantMin   uint16
		wantMax   uint16
		wantErr   bool
	}

	cases := []testCase{
		{name: "TLS 1.2 only", protocols: []string{"TLSv1.2"}, wantMin: tls.VersionTLS12, wantMax: tls.VersionTLS13},
		{name: "TLS 1.3 only", protocols: []string{"TLSv1.3"}, wantMin: tls.VersionTLS13, wantMax: tls.VersionTLS13},
		{name: "TLS 1.2 and 1.3", protocols: []string{"TLSv1.2", "TLSv1.3"}, wantMin: tls.VersionTLS12, wantMax: tls.VersionTLS13},
		{name: "insecure TLS 1.0", protocols: []string{"TLSv1.0"}, wantErr: true},
		{name: "insecure TLS 1.1", protocols: []string{"TLSv1.1"}, wantErr: true},
		{name: "unknown protocol", protocols: []string{"TLSv0.9"}, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotMin, gotMax, err := sslutil.ParseTLSVersions(tc.protocols)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("parseTLSVersions() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			// Versions are only compared for cases expected to succeed.
			if tc.wantErr {
				return
			}
			if gotMin != tc.wantMin {
				t.Errorf("parseTLSVersions() minVer = %v, want %v", gotMin, tc.wantMin)
			}
			if gotMax != tc.wantMax {
				t.Errorf("parseTLSVersions() maxVer = %v, want %v", gotMax, tc.wantMax)
			}
		})
	}
}
// TestParseCipherSuites exercises sslutil.ParseCipherSuites with known and
// unknown cipher names and checks how many suites are returned.
func TestParseCipherSuites(t *testing.T) {
	type testCase struct {
		name    string
		ciphers []string
		wantLen int
		wantErr bool
	}

	cases := []testCase{
		{
			name:    "valid cipher",
			ciphers: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"},
			wantLen: 1,
		},
		{
			name: "multiple valid ciphers",
			ciphers: []string{
				"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
				"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
			},
			wantLen: 2,
		},
		{
			name:    "unknown cipher",
			ciphers: []string{"TLS_UNKNOWN_CIPHER"},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			suites, err := sslutil.ParseCipherSuites(tc.ciphers)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("parseCipherSuites() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			// The count is only meaningful for cases expected to succeed.
			if tc.wantErr {
				return
			}
			if got := len(suites); got != tc.wantLen {
				t.Errorf("parseCipherSuites() returned %d ciphers, want %d", got, tc.wantLen)
			}
		})
	}
}
// TestDefaultCipherSuites checks that sslutil.DefaultCipherSuites returns a
// non-empty list and that sslutil.IsInsecureCipher flags none of its entries.
func TestDefaultCipherSuites(t *testing.T) {
	defaults := sslutil.DefaultCipherSuites()
	if len(defaults) == 0 {
		t.Error("Expected non-empty default cipher suites")
	}

	// Every default suite must pass the insecure-cipher check.
	for i := range defaults {
		if id := defaults[i]; sslutil.IsInsecureCipher(id) {
			t.Errorf("Default cipher suite %v is insecure", id)
		}
	}
}
// TestIsInsecureCipher checks sslutil.IsInsecureCipher against suites the test
// expects to be flagged (RC4, 3DES) and suites it expects to be accepted
// (AES-GCM, ChaCha20-Poly1305).
func TestIsInsecureCipher(t *testing.T) {
	checks := []struct {
		suite        uint16
		wantInsecure bool
	}{
		// Known insecure ciphers.
		{suite: tls.TLS_RSA_WITH_RC4_128_SHA, wantInsecure: true},
		{suite: tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, wantInsecure: true},
		// Secure ciphers.
		{suite: tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, wantInsecure: false},
		{suite: tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, wantInsecure: false},
	}

	for _, c := range checks {
		flagged := sslutil.IsInsecureCipher(c.suite)
		switch {
		case c.wantInsecure && !flagged:
			t.Errorf("Expected cipher %v to be insecure", c.suite)
		case !c.wantInsecure && flagged:
			t.Errorf("Expected cipher %v to be secure", c.suite)
		}
	}
}
// generateTestCert creates a self-signed ECDSA P-256 certificate for
// "localhost" that is valid for one hour. It returns the PEM-encoded
// certificate followed by the PEM-encoded EC private key, and fails the test
// immediately if any step returns an error.
func generateTestCert(t *testing.T) ([]byte, []byte) {
	t.Helper()

	// Fresh key pair for every call.
	signer, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("Failed to generate private key: %v", err)
	}

	// The same template is passed as both subject and issuer, which makes
	// the certificate self-signed.
	tmpl := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{Organization: []string{"Test"}},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &signer.PublicKey, signer)
	if err != nil {
		t.Fatalf("Failed to create certificate: %v", err)
	}
	certBlock := pem.Block{Type: "CERTIFICATE", Bytes: der}
	certOut := pem.EncodeToMemory(&certBlock)

	// Serialize the private key in SEC 1 form and wrap it in PEM.
	keyBytes, err := x509.MarshalECPrivateKey(signer)
	if err != nil {
		t.Fatalf("Failed to marshal private key: %v", err)
	}
	keyBlock := pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}

	return certOut, pem.EncodeToMemory(&keyBlock)
}
// containsString reports whether substr occurs anywhere within s.
// An empty substr is always reported as present, including when s is empty.
func containsString(s, substr string) bool {
	return strings.Contains(s, substr)
}
// TestGetTLSConfig checks that a manager built from a default configuration
// returns a non-nil TLS config whose MinVersion is TLS 1.2.
func TestGetTLSConfig(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	// Write a self-signed certificate and its key to disk.
	certPEM, keyPEM := generateTestCert(t)
	if writeErr := os.WriteFile(certFile, certPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, keyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	mgr, err := NewTLSManager(&config.SSLConfig{
		Cert: certFile,
		Key:  keyFile,
	})
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}
	defer mgr.Close()

	// The returned configuration must not be nil.
	got := mgr.GetTLSConfig()
	if got == nil {
		t.Fatal("Expected non-nil TLS config")
	}

	// With no protocols configured, the minimum version is expected to be TLS 1.2.
	if minVer := got.MinVersion; minVer != tls.VersionTLS12 {
		t.Errorf("Expected MinVersion TLS 1.2, got %v", minVer)
	}
}
// TestGetTLSConfig_WithProtocols checks that configuring only "TLSv1.3" pins
// both MinVersion and MaxVersion of the resulting TLS config to TLS 1.3.
func TestGetTLSConfig_WithProtocols(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	certPEM, keyPEM := generateTestCert(t)
	if writeErr := os.WriteFile(certFile, certPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, keyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	mgr, err := NewTLSManager(&config.SSLConfig{
		Cert:      certFile,
		Key:       keyFile,
		Protocols: []string{"TLSv1.3"},
	})
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}
	defer mgr.Close()

	got := mgr.GetTLSConfig()
	if got == nil {
		t.Fatal("Expected non-nil TLS config")
	}

	// Both bounds must be TLS 1.3.
	if minVer := got.MinVersion; minVer != tls.VersionTLS13 {
		t.Errorf("Expected MinVersion TLS 1.3, got %v", minVer)
	}
	if maxVer := got.MaxVersion; maxVer != tls.VersionTLS13 {
		t.Errorf("Expected MaxVersion TLS 1.3, got %v", maxVer)
	}
}
// TestClose checks that Close can be called on a TLS manager several times in
// a row; the test passes as long as none of the calls panics.
func TestClose(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	certPEM, keyPEM := generateTestCert(t)
	if writeErr := os.WriteFile(certFile, certPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, keyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	mgr, err := NewTLSManager(&config.SSLConfig{
		Cert: certFile,
		Key:  keyFile,
	})
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}

	// One initial close followed by two repeats: repeated calls must be safe.
	const closeCalls = 3
	for i := 0; i < closeCalls; i++ {
		mgr.Close()
	}
}
// TestNewTLSManager_Errors covers the error scenarios of NewTLSManager: a nil
// config, a missing certificate path, a missing key path, and certificate
// files that do not exist.
func TestNewTLSManager_Errors(t *testing.T) {
	type testCase struct {
		cfg     *config.SSLConfig
		name    string
		errMsg  string
		wantErr bool
	}

	cases := []testCase{
		{name: "nil config", cfg: nil, wantErr: true, errMsg: "ssl config is nil"},
		{
			// Missing certificate path.
			name:    "缺少证书路径",
			cfg:     &config.SSLConfig{Key: "key.pem"},
			wantErr: true,
			errMsg:  "certificate and key paths are required",
		},
		{
			// Missing key path.
			name:    "缺少密钥路径",
			cfg:     &config.SSLConfig{Cert: "cert.pem"},
			wantErr: true,
			errMsg:  "certificate and key paths are required",
		},
		{
			// Invalid certificate file.
			name: "无效证书文件",
			cfg: &config.SSLConfig{
				Cert: "/nonexistent/cert.pem",
				Key:  "/nonexistent/key.pem",
			},
			wantErr: true,
			errMsg:  "failed to load certificate",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewTLSManager(tc.cfg)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("NewTLSManager() error = %v, wantErr %v", err, tc.wantErr)
			}
			// The message is only compared when an error was actually
			// returned and the case specifies an expected fragment.
			if err == nil || tc.errMsg == "" {
				return
			}
			if !containsString(err.Error(), tc.errMsg) {
				t.Errorf("NewTLSManager() error = %v, want errMsg containing %v", err, tc.errMsg)
			}
		})
	}
}
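// TestNewTLSManager_InvalidCipher tests an invalid cipher suite.
// NOTE(review): no function follows this header here — restore the test or drop the header.
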
// TestNewTLSManager_InsecureCipher checks that NewTLSManager returns an error
// when the configuration requests the RC4-based suite
// TLS_ECDHE_RSA_WITH_RC4_128_SHA.
func TestNewTLSManager_InsecureCipher(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	certPEM, keyPEM := generateTestCert(t)
	if writeErr := os.WriteFile(certFile, certPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, keyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	sslCfg := &config.SSLConfig{
		Cert:    certFile,
		Key:     keyFile,
		Ciphers: []string{"TLS_ECDHE_RSA_WITH_RC4_128_SHA"},
	}
	if _, err := NewTLSManager(sslCfg); err == nil {
		t.Error("Expected error for insecure cipher suite")
	}
}
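// NOTE(review): the test headers below have no functions following them here —
// restore the tests or drop the headers.
// TestNewMultiTLSManager tests the multi-certificate TLS manager.
// TestNewMultiTLSManager_EmptyConfigs tests empty configurations.
// TestNewMultiTLSManager_NilConfig tests a nil configuration entry.
// TestGetCertificate tests the certificate retrieval callback.
// TestAddCertificate tests adding a certificate.
// TestAddCertificate_Error tests errors when adding a certificate.
// TestRemoveCertificate tests removing a certificate.
// TestGetOCSPStatus_NoManager tests the status when there is no OCSP manager.
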
// TestParsePEMChain checks parsePEMChain against a valid PEM certificate,
// empty input, and input that is not PEM at all.
func TestParsePEMChain(t *testing.T) {
	// Valid PEM data must yield at least one certificate.
	validPEM, _ := generateTestCert(t)
	if parsed := parsePEMChain(validPEM); len(parsed) == 0 {
		t.Error("Expected at least one certificate from valid PEM")
	}

	// Empty input must yield nothing.
	if parsed := parsePEMChain([]byte{}); len(parsed) != 0 {
		t.Error("Expected no certificates from empty data")
	}

	// Input without any PEM structure must yield nothing.
	if parsed := parsePEMChain([]byte("not valid pem")); len(parsed) != 0 {
		t.Error("Expected no certificates from invalid PEM")
	}
}
// TestExtractPEMBlock checks extractPEMBlock against a valid certificate and
// against inputs from which no block is expected: empty data, data with no end
// marker, and data with no start marker.
func TestExtractPEMBlock(t *testing.T) {
	// A valid certificate must produce a non-nil, non-empty block; the
	// remainder returned alongside it is not inspected.
	validPEM, _ := generateTestCert(t)
	first, _ := extractPEMBlock(validPEM)
	if first == nil {
		t.Error("Expected non-nil block from valid PEM")
	}
	if len(first) == 0 {
		t.Error("Expected non-empty block")
	}

	// Each of the following inputs must produce a nil block.
	rejected := []struct {
		input []byte
		msg   string
	}{
		// Empty data.
		{input: []byte{}, msg: "Expected nil block from empty data"},
		// Start marker with no end marker.
		{
			input: []byte("-----BEGIN CERTIFICATE-----\nsome data without end"),
			msg:   "Expected nil block from incomplete PEM",
		},
		// End marker with no start marker.
		{
			input: []byte("some data\n-----END CERTIFICATE-----"),
			msg:   "Expected nil block from data without start marker",
		},
	}
	for _, r := range rejected {
		if blk, _ := extractPEMBlock(r.input); blk != nil {
			t.Error(r.msg)
		}
	}
}
// TestFindMarker checks that findMarker returns the byte offset of a marker
// inside the data, and -1 when the marker is absent or the data is empty.
func TestFindMarker(t *testing.T) {
	beginMarker := []byte("-----BEGIN CERTIFICATE-----")
	haystack := []byte("prefix-----BEGIN CERTIFICATE-----suffix")

	// "prefix" is six bytes long, so the marker starts at offset 6.
	if got := findMarker(haystack, beginMarker); got != 6 {
		t.Errorf("Expected index 6, got %d", got)
	}

	// A marker that does not occur in the data.
	if got := findMarker(haystack, []byte("NOTFOUND")); got != -1 {
		t.Errorf("Expected -1 for not found marker, got %d", got)
	}

	// Empty data.
	if got := findMarker([]byte{}, beginMarker); got != -1 {
		t.Errorf("Expected -1 for empty data, got %d", got)
	}
}
// TestMatchMarker checks that matchMarker reports true when the data begins
// with the marker, and false for a different marker or for data shorter than
// the marker.
func TestMatchMarker(t *testing.T) {
	beginMarker := []byte("-----BEGIN CERTIFICATE-----")
	input := []byte("-----BEGIN CERTIFICATE-----suffix")

	if matched := matchMarker(input, beginMarker); !matched {
		t.Error("Expected true for matching marker")
	}

	// A different marker must not match.
	if matched := matchMarker(input, []byte("-----END CERTIFICATE-----")); matched {
		t.Error("Expected false for non-matching marker")
	}

	// Data shorter than the marker must not match.
	if matched := matchMarker([]byte("short"), beginMarker); matched {
		t.Error("Expected false when data is shorter than marker")
	}
}
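// NOTE(review): the test headers below have no functions following them here —
// restore the tests or drop the headers.
// TestGetCertificate_NoCertificate tests the error case when no certificate is available.
// TestGetConfigForClientWithOCSP tests the OCSP configuration callback.
// TestLoadCertificate_WithCertChain tests loading with a certificate chain.
// TestLoadCertificate_InvalidChain tests an invalid certificate chain.
// TestCreateTLSConfig_NilConfig tests a nil configuration.
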
// TestNewTLSManager_WithSessionTickets checks that enabling session tickets in
// the configuration leaves the manager with a non-nil sessionTicketMgr.
func TestNewTLSManager_WithSessionTickets(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")
	ticketKeyFile := filepath.Join(dir, "ticket.key")

	certPEM, keyPEM := generateTestCert(t)
	if writeErr := os.WriteFile(certFile, certPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, keyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	// The ticket key file is only named here; the test never writes it.
	tickets := config.SessionTicketsConfig{
		Enabled:        true,
		KeyFile:        ticketKeyFile,
		RotateInterval: time.Hour,
		RetainKeys:     3,
	}
	mgr, err := NewTLSManager(&config.SSLConfig{
		Cert:           certFile,
		Key:            keyFile,
		SessionTickets: tickets,
	})
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}
	defer mgr.Close()

	// Read the field under the manager's read lock.
	mgr.mu.RLock()
	ticketMgr := mgr.sessionTicketMgr
	mgr.mu.RUnlock()

	if ticketMgr == nil {
		t.Error("Expected session ticket manager to be initialized")
	}
}
// TestNewTLSManager_WithClientVerify checks that enabling client verification
// with a readable CA file leaves the manager with a non-nil clientVerifier.
func TestNewTLSManager_WithClientVerify(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")
	caFile := filepath.Join(dir, "ca.pem")

	certPEM, keyPEM := generateTestCert(t)
	if writeErr := os.WriteFile(certFile, certPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, keyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	// Create the CA certificate; only the third value returned by
	// generateTestCA is used here.
	_, _, caBundle := generateTestCA(t)
	if writeErr := os.WriteFile(caFile, caBundle, 0o644); writeErr != nil {
		t.Fatalf("Failed to write CA: %v", writeErr)
	}

	verify := config.ClientVerifyConfig{
		Enabled:     true,
		Mode:        "on",
		ClientCA:    caFile,
		VerifyDepth: 3,
	}
	mgr, err := NewTLSManager(&config.SSLConfig{
		Cert:         certFile,
		Key:          keyFile,
		ClientVerify: verify,
	})
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}
	defer mgr.Close()

	// Read the field under the manager's read lock.
	mgr.mu.RLock()
	verifier := mgr.clientVerifier
	mgr.mu.RUnlock()

	if verifier == nil {
		t.Error("Expected client verifier to be initialized")
	}
}
// TestNewTLSManager_WithInvalidClientCA checks that a client CA path pointing
// at a missing file does not make NewTLSManager fail, and that the manager's
// clientVerifier stays nil in that case.
func TestNewTLSManager_WithInvalidClientCA(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")

	certPEM, keyPEM := generateTestCert(t)
	if writeErr := os.WriteFile(certFile, certPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, keyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	verify := config.ClientVerifyConfig{
		Enabled:  true,
		Mode:     "on",
		ClientCA: "/nonexistent/ca.pem",
	}

	// A failing client-verification setup must not prevent TLS from working.
	mgr, err := NewTLSManager(&config.SSLConfig{
		Cert:         certFile,
		Key:          keyFile,
		ClientVerify: verify,
	})
	if err != nil {
		t.Fatalf("NewTLSManager() should not fail for invalid client CA: %v", err)
	}
	defer mgr.Close()

	// The verifier is expected to be left uninitialized.
	mgr.mu.RLock()
	verifier := mgr.clientVerifier
	mgr.mu.RUnlock()

	if verifier != nil {
		t.Error("Expected client verifier to be nil for invalid CA")
	}
}
// TestNewTLSManager_WithOCSPAndIssuer checks that enabling OCSP stapling with
// a certificate chain file leaves the manager with a non-nil ocspManager.
func TestNewTLSManager_WithOCSPAndIssuer(t *testing.T) {
	dir := t.TempDir()
	certFile := filepath.Join(dir, "cert.pem")
	keyFile := filepath.Join(dir, "key.pem")
	chainFile := filepath.Join(dir, "chain.pem")

	// Leaf certificate generated with an OCSP server URL.
	leafPEM, leafKeyPEM := generateTestCertWithOCSP(t, []string{"http://ocsp.example.com"})
	if writeErr := os.WriteFile(certFile, leafPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write cert: %v", writeErr)
	}
	if writeErr := os.WriteFile(keyFile, leafKeyPEM, 0o600); writeErr != nil {
		t.Fatalf("Failed to write key: %v", writeErr)
	}

	// A separately generated self-signed certificate stands in for the
	// issuer in the chain file.
	issuerPEM, _ := generateTestCert(t)
	if writeErr := os.WriteFile(chainFile, issuerPEM, 0o644); writeErr != nil {
		t.Fatalf("Failed to write chain: %v", writeErr)
	}

	mgr, err := NewTLSManager(&config.SSLConfig{
		Cert:         certFile,
		Key:          keyFile,
		CertChain:    chainFile,
		OCSPStapling: true,
	})
	if err != nil {
		t.Fatalf("NewTLSManager() failed: %v", err)
	}
	defer mgr.Close()

	// Read the field under the manager's read lock.
	mgr.mu.RLock()
	stapler := mgr.ocspManager
	mgr.mu.RUnlock()

	if stapler == nil {
		t.Error("Expected OCSP manager to be initialized")
	}
}