- SessionTicketsConfig 支持 TLS 1.3 会话恢复,密钥轮换和持久化 - ClientVerifyConfig 支持双向 TLS 认证,CA 证书池和 CRL - TLSManager 集成 SessionTicketManager 和 ClientVerifier - 新增完整测试覆盖密钥轮换和客户端验证逻辑 Co-Authored-By: Claude <noreply@anthropic.com>
357 lines
8.9 KiB
Go
357 lines
8.9 KiB
Go
// Package ssl 提供 mTLS 客户端验证的单元测试。
|
||
//
|
||
// 测试覆盖:
|
||
// - CA 证书池加载
|
||
// - 验证模式解析
|
||
// - 客户端证书验证
|
||
// - CRL 检查
|
||
// - 变量提取
|
||
//
|
||
// 作者:xfy
|
||
package ssl
|
||
|
||
import (
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/tls"
|
||
"crypto/x509"
|
||
"crypto/x509/pkix"
|
||
"encoding/pem"
|
||
"math/big"
|
||
"os"
|
||
"path/filepath"
|
||
"testing"
|
||
"time"
|
||
|
||
"rua.plus/lolly/internal/config"
|
||
)
|
||
|
||
// generateTestCA creates a self-signed RSA-2048 CA certificate for tests.
//
// The certificate has serial number 1 and subject CN "Test CA". It is marked
// as a CA (BasicConstraintsValid, IsCA) with the CertSign and CRLSign key
// usages, and is valid for 24 hours starting now.
//
// tb is a testing.TB, so the helper works from both tests and benchmarks. Any
// failure aborts the caller via tb.Fatalf.
//
// It returns the parsed certificate, its private key, and the certificate
// encoded as a PEM "CERTIFICATE" block.
func generateTestCA(tb testing.TB) (*x509.Certificate, *rsa.PrivateKey, []byte) {
	tb.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		tb.Fatalf("Failed to generate CA key: %v", err)
	}

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: "Test CA",
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// Self-signed: the template acts as both subject and issuer.
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		tb.Fatalf("Failed to create CA cert: %v", err)
	}

	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		tb.Fatalf("Failed to parse CA cert: %v", err)
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	return cert, key, certPEM
}
|
||
|
||
// generateTestClientCert issues an RSA-2048 client certificate signed by
// caCert/caKey.
//
// The leaf has:
//   - serial number 2
//   - subject CN "Test Client", O "Test Org"
//   - e-mail SAN "test@example.com"
//   - KeyUsageDigitalSignature and the ClientAuth extended key usage
//   - a validity of 24 hours starting now
//
// tb is a testing.TB, so the helper works from both tests and benchmarks. Any
// failure aborts the caller via tb.Fatalf.
//
// It returns the parsed certificate, its private key, and the raw DER bytes of
// the certificate. The third result is DER, not PEM as generateTestCA returns.
func generateTestClientCert(tb testing.TB, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, []byte) {
	tb.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		tb.Fatalf("Failed to generate client key: %v", err)
	}

	template := &x509.Certificate{
		SerialNumber: big.NewInt(2),
		Subject: pkix.Name{
			CommonName:   "Test Client",
			Organization: []string{"Test Org"},
		},
		NotBefore:      time.Now(),
		NotAfter:       time.Now().Add(24 * time.Hour),
		KeyUsage:       x509.KeyUsageDigitalSignature,
		ExtKeyUsage:    []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		EmailAddresses: []string{"test@example.com"},
	}

	// Signed by the CA: caCert is the issuer, caKey produces the signature.
	certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey)
	if err != nil {
		tb.Fatalf("Failed to create client cert: %v", err)
	}

	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		tb.Fatalf("Failed to parse client cert: %v", err)
	}

	return cert, key, certDER
}
|
||
|
||
// TestParseVerifyMode 测试验证模式解析。
|
||
func TestParseVerifyMode(t *testing.T) {
|
||
tests := []struct {
|
||
input string
|
||
expected ClientVerifyMode
|
||
wantErr bool
|
||
}{
|
||
{"off", VerifyOff, false},
|
||
{"", VerifyOff, false},
|
||
{"on", VerifyOn, false},
|
||
{"optional", VerifyOptional, false},
|
||
{"optional_no_ca", VerifyOptionalNoCA, false},
|
||
{"invalid", VerifyOff, true},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.input, func(t *testing.T) {
|
||
mode, err := ParseVerifyMode(tt.input)
|
||
if tt.wantErr {
|
||
if err == nil {
|
||
t.Errorf("ParseVerifyMode(%q) expected error", tt.input)
|
||
}
|
||
return
|
||
}
|
||
if err != nil {
|
||
t.Errorf("ParseVerifyMode(%q) unexpected error: %v", tt.input, err)
|
||
return
|
||
}
|
||
if mode != tt.expected {
|
||
t.Errorf("ParseVerifyMode(%q) = %v, want %v", tt.input, mode, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestClientVerifyMode_TLSClientAuth 测试 TLS 认证类型映射。
|
||
func TestClientVerifyMode_TLSClientAuth(t *testing.T) {
|
||
tests := []struct {
|
||
mode ClientVerifyMode
|
||
expected tls.ClientAuthType
|
||
}{
|
||
{VerifyOff, tls.NoClientCert},
|
||
{VerifyOn, tls.RequireAndVerifyClientCert},
|
||
{VerifyOptional, tls.VerifyClientCertIfGiven},
|
||
{VerifyOptionalNoCA, tls.RequestClientCert},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.mode.String(), func(t *testing.T) {
|
||
auth := tt.mode.TLSClientAuth()
|
||
if auth != tt.expected {
|
||
t.Errorf("TLSClientAuth() = %v, want %v", auth, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestLoadCACertPool 测试 CA 证书池加载。
|
||
func TestLoadCACertPool(t *testing.T) {
|
||
// 创建临时 CA 文件
|
||
tempDir := t.TempDir()
|
||
caFile := filepath.Join(tempDir, "ca.crt")
|
||
|
||
_, _, caPEM := generateTestCA(t)
|
||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||
t.Fatalf("Failed to write CA file: %v", err)
|
||
}
|
||
|
||
// 测试加载
|
||
pool, err := LoadCACertPool(caFile)
|
||
if err != nil {
|
||
t.Fatalf("LoadCACertPool() failed: %v", err)
|
||
}
|
||
if pool == nil {
|
||
t.Fatal("LoadCACertPool() returned nil pool")
|
||
}
|
||
|
||
// 测试文件不存在
|
||
_, err = LoadCACertPool("/nonexistent/ca.crt")
|
||
if err == nil {
|
||
t.Error("LoadCACertPool() should fail for non-existent file")
|
||
}
|
||
|
||
// 测试无效证书
|
||
invalidFile := filepath.Join(tempDir, "invalid.crt")
|
||
if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil {
|
||
t.Fatalf("写入无效证书文件失败: %v", err)
|
||
}
|
||
_, err = LoadCACertPool(invalidFile)
|
||
if err == nil {
|
||
t.Error("LoadCACertPool() should fail for invalid certificate")
|
||
}
|
||
}
|
||
|
||
// TestNewClientVerifier 测试创建客户端验证器。
|
||
func TestNewClientVerifier(t *testing.T) {
|
||
// 测试禁用状态
|
||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||
Enabled: false,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewClientVerifier() failed for disabled config: %v", err)
|
||
}
|
||
if verifier.IsEnabled() {
|
||
t.Error("Verifier should be disabled when Enabled=false")
|
||
}
|
||
|
||
// 测试启用但无 CA(应该失败)
|
||
_, err = NewClientVerifier(config.ClientVerifyConfig{
|
||
Enabled: true,
|
||
Mode: "on",
|
||
})
|
||
if err == nil {
|
||
t.Error("NewClientVerifier() should fail without CA file")
|
||
}
|
||
}
|
||
|
||
// TestNewClientVerifier_WithCA 测试带 CA 的验证器。
|
||
func TestNewClientVerifier_WithCA(t *testing.T) {
|
||
// 创建临时 CA 文件
|
||
tempDir := t.TempDir()
|
||
caFile := filepath.Join(tempDir, "ca.crt")
|
||
|
||
_, _, caPEM := generateTestCA(t)
|
||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||
t.Fatalf("Failed to write CA file: %v", err)
|
||
}
|
||
|
||
// 测试各种模式
|
||
modes := []string{"on", "optional"}
|
||
for _, mode := range modes {
|
||
t.Run(mode, func(t *testing.T) {
|
||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||
Enabled: true,
|
||
Mode: mode,
|
||
ClientCA: caFile,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||
}
|
||
if !verifier.IsEnabled() {
|
||
t.Error("Verifier should be enabled")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestClientVerifier_ConfigureTLS 测试 TLS 配置。
|
||
func TestClientVerifier_ConfigureTLS(t *testing.T) {
|
||
// 创建临时 CA 文件
|
||
tempDir := t.TempDir()
|
||
caFile := filepath.Join(tempDir, "ca.crt")
|
||
|
||
_, _, caPEM := generateTestCA(t)
|
||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||
t.Fatalf("Failed to write CA file: %v", err)
|
||
}
|
||
|
||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||
Enabled: true,
|
||
Mode: "on",
|
||
ClientCA: caFile,
|
||
VerifyDepth: 2,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||
}
|
||
|
||
tlsCfg := &tls.Config{}
|
||
verifier.ConfigureTLS(tlsCfg)
|
||
|
||
if tlsCfg.ClientAuth != tls.RequireAndVerifyClientCert {
|
||
t.Errorf("ClientAuth = %v, want %v", tlsCfg.ClientAuth, tls.RequireAndVerifyClientCert)
|
||
}
|
||
if tlsCfg.ClientCAs == nil {
|
||
t.Error("ClientCAs should be set")
|
||
}
|
||
|
||
// 测试 nil 配置(不应 panic)
|
||
verifier.ConfigureTLS(nil)
|
||
}
|
||
|
||
// TestClientVerifier_ConfigureTLS_Disabled 测试禁用验证器。
|
||
func TestClientVerifier_ConfigureTLS_Disabled(t *testing.T) {
|
||
verifier, _ := NewClientVerifier(config.ClientVerifyConfig{
|
||
Enabled: false,
|
||
})
|
||
|
||
tlsCfg := &tls.Config{}
|
||
verifier.ConfigureTLS(tlsCfg)
|
||
|
||
// 禁用时不应修改配置
|
||
if tlsCfg.ClientAuth != 0 {
|
||
t.Error("Disabled verifier should not modify TLS config")
|
||
}
|
||
}
|
||
|
||
// TestGetClientCertInfo 测试证书信息提取。
|
||
func TestGetClientCertInfo(t *testing.T) {
|
||
// 生成测试证书
|
||
caCert, caKey, _ := generateTestCA(t)
|
||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey)
|
||
|
||
// 创建模拟连接状态
|
||
cs := &tls.ConnectionState{
|
||
PeerCertificates: []*x509.Certificate{clientCert},
|
||
}
|
||
|
||
info := GetClientCertInfo(cs)
|
||
if info == nil {
|
||
t.Fatal("GetClientCertInfo() returned nil")
|
||
}
|
||
|
||
if info.Serial == "" {
|
||
t.Error("Serial should not be empty")
|
||
}
|
||
if info.Subject == "" {
|
||
t.Error("Subject should not be empty")
|
||
}
|
||
if info.Issuer == "" {
|
||
t.Error("Issuer should not be empty")
|
||
}
|
||
|
||
// 测试无证书
|
||
emptyCs := &tls.ConnectionState{}
|
||
info = GetClientCertInfo(emptyCs)
|
||
if info != nil {
|
||
t.Error("GetClientCertInfo() should return nil for no certificates")
|
||
}
|
||
}
|
||
|
||
// TestGetClientCertInfo_Nil 测试 nil 输入。
|
||
func TestGetClientCertInfo_Nil(t *testing.T) {
|
||
info := GetClientCertInfo(nil)
|
||
if info != nil {
|
||
t.Error("GetClientCertInfo(nil) should return nil")
|
||
}
|
||
}
|
||
|
||
// BenchmarkLoadCACertPool 基准测试 CA 证书池加载。
|
||
func BenchmarkLoadCACertPool(b *testing.B) {
|
||
tempDir := b.TempDir()
|
||
caFile := filepath.Join(tempDir, "ca.crt")
|
||
|
||
// 生成 CA
|
||
_, _, caPEM := generateTestCA(&testing.T{})
|
||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||
b.Fatalf("写入 CA 文件失败: %v", err)
|
||
}
|
||
|
||
b.ResetTimer()
|
||
for i := 0; i < b.N; i++ {
|
||
_, err := LoadCACertPool(caFile)
|
||
if err != nil {
|
||
b.Fatal(err)
|
||
}
|
||
}
|
||
}
|