lolly/internal/ssl/client_verify_test.go
xfy 96bd4b0ed5 refactor(ssl): 提取证书池加载函数到 sslutil 包
将 LoadCACertPool 和 LoadCertPool 函数提取到独立的 sslutil 包,
消除 ssl 和 stream 模块中的重复实现。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-13 10:58:57 +08:00

702 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package ssl 提供 mTLS 客户端验证的单元测试。
//
// 测试覆盖:
// - CA 证书池加载
// - 验证模式解析
// - 客户端证书验证
// - CRL 检查
// - 变量提取
//
// 作者：xfy
package ssl
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/sslutil"
)
// generateTestCA creates a self-signed RSA-2048 CA certificate (serial 1,
// CommonName "Test CA") valid from now for 24 hours. Its key usage allows it
// to sign both certificates and CRLs.
//
// It accepts testing.TB rather than *testing.T so that tests and benchmarks
// can share it; any failure aborts the caller through tb.Fatalf.
//
// Returns the parsed certificate, its private key, and the certificate
// encoded as a PEM "CERTIFICATE" block.
func generateTestCA(tb testing.TB) (*x509.Certificate, *rsa.PrivateKey, []byte) {
	tb.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		tb.Fatalf("Failed to generate CA key: %v", err)
	}

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			CommonName: "Test CA",
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	// Self-signed: the template acts as both subject and issuer.
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		tb.Fatalf("Failed to create CA cert: %v", err)
	}

	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		tb.Fatalf("Failed to parse CA cert: %v", err)
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	return cert, key, certPEM
}
// generateTestClientCert issues an RSA-2048 client certificate signed by
// caCert/caKey. The certificate has the given serial number, CommonName
// "Test Client", Organization "Test Org", email "test@example.com", and
// ExtKeyUsageClientAuth. It is valid from now for 24 hours.
//
// It accepts testing.TB so it can be used from tests and benchmarks alike;
// any failure aborts the caller through tb.Fatalf.
//
// Returns the parsed certificate, its private key, and the raw DER bytes.
// Unlike generateTestCA, the third value is DER, not PEM.
func generateTestClientCert(tb testing.TB, caCert *x509.Certificate, caKey *rsa.PrivateKey, serial int64) (*x509.Certificate, *rsa.PrivateKey, []byte) {
	tb.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		tb.Fatalf("Failed to generate client key: %v", err)
	}

	template := &x509.Certificate{
		SerialNumber: big.NewInt(serial),
		Subject: pkix.Name{
			CommonName:   "Test Client",
			Organization: []string{"Test Org"},
		},
		NotBefore:      time.Now(),
		NotAfter:       time.Now().Add(24 * time.Hour),
		KeyUsage:       x509.KeyUsageDigitalSignature,
		ExtKeyUsage:    []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		EmailAddresses: []string{"test@example.com"},
	}

	certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey)
	if err != nil {
		tb.Fatalf("Failed to create client cert: %v", err)
	}

	cert, err := x509.ParseCertificate(certDER)
	if err != nil {
		tb.Fatalf("Failed to parse client cert: %v", err)
	}

	return cert, key, certDER
}
// TestParseVerifyMode checks that each supported verify-mode string maps to
// the expected ClientVerifyMode (the empty string meaning "off") and that an
// unknown string is rejected with an error.
func TestParseVerifyMode(t *testing.T) {
	cases := []struct {
		input    string
		expected ClientVerifyMode
		wantErr  bool
	}{
		{"off", VerifyOff, false},
		{"", VerifyOff, false},
		{"on", VerifyOn, false},
		{"optional", VerifyOptional, false},
		{"optional_no_ca", VerifyOptionalNoCA, false},
		{"invalid", VerifyOff, true},
	}

	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			got, err := ParseVerifyMode(tc.input)
			switch {
			case tc.wantErr:
				if err == nil {
					t.Errorf("ParseVerifyMode(%q) expected error", tc.input)
				}
			case err != nil:
				t.Errorf("ParseVerifyMode(%q) unexpected error: %v", tc.input, err)
			case got != tc.expected:
				t.Errorf("ParseVerifyMode(%q) = %v, want %v", tc.input, got, tc.expected)
			}
		})
	}
}
// TestClientVerifyMode_TLSClientAuth checks the mapping from each
// ClientVerifyMode to the crypto/tls ClientAuthType it configures.
func TestClientVerifyMode_TLSClientAuth(t *testing.T) {
	cases := []struct {
		mode ClientVerifyMode
		want tls.ClientAuthType
	}{
		{VerifyOff, tls.NoClientCert},
		{VerifyOn, tls.RequireAndVerifyClientCert},
		{VerifyOptional, tls.VerifyClientCertIfGiven},
		{VerifyOptionalNoCA, tls.RequestClientCert},
	}

	for _, tc := range cases {
		t.Run(tc.mode.String(), func(t *testing.T) {
			if got := tc.mode.TLSClientAuth(); got != tc.want {
				t.Errorf("TLSClientAuth() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestLoadCACertPool covers sslutil.LoadCACertPool on three inputs: a valid
// PEM CA certificate, a path that does not exist, and a file with non-PEM
// content.
func TestLoadCACertPool(t *testing.T) {
	dir := t.TempDir()

	// Valid input: a freshly generated CA certificate.
	validPath := filepath.Join(dir, "ca.crt")
	_, _, pemBytes := generateTestCA(t)
	if err := os.WriteFile(validPath, pemBytes, 0644); err != nil {
		t.Fatalf("Failed to write CA file: %v", err)
	}
	pool, err := sslutil.LoadCACertPool(validPath)
	if err != nil {
		t.Fatalf("LoadCACertPool() failed: %v", err)
	}
	if pool == nil {
		t.Fatal("LoadCACertPool() returned nil pool")
	}

	// Missing file.
	if _, err := sslutil.LoadCACertPool("/nonexistent/ca.crt"); err == nil {
		t.Error("LoadCACertPool() should fail for non-existent file")
	}

	// File whose contents are not a certificate.
	garbagePath := filepath.Join(dir, "invalid.crt")
	if err := os.WriteFile(garbagePath, []byte("invalid data"), 0644); err != nil {
		t.Fatalf("写入无效证书文件失败: %v", err)
	}
	if _, err := sslutil.LoadCACertPool(garbagePath); err == nil {
		t.Error("LoadCACertPool() should fail for invalid certificate")
	}
}
// TestNewClientVerifier 测试创建客户端验证器。
func TestNewClientVerifier(t *testing.T) {
// 测试禁用状态
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
Enabled: false,
})
if err != nil {
t.Fatalf("NewClientVerifier() failed for disabled config: %v", err)
}
if verifier.IsEnabled() {
t.Error("Verifier should be disabled when Enabled=false")
}
// 测试启用但无 CA应该失败
_, err = NewClientVerifier(config.ClientVerifyConfig{
Enabled: true,
Mode: "on",
})
if err == nil {
t.Error("NewClientVerifier() should fail without CA file")
}
}
// TestNewClientVerifier_WithCA checks that the "on" and "optional" modes
// produce an enabled verifier when a valid ClientCA file is supplied.
func TestNewClientVerifier_WithCA(t *testing.T) {
	caPath := filepath.Join(t.TempDir(), "ca.crt")
	_, _, caPEM := generateTestCA(t)
	if err := os.WriteFile(caPath, caPEM, 0644); err != nil {
		t.Fatalf("Failed to write CA file: %v", err)
	}

	for _, mode := range []string{"on", "optional"} {
		t.Run(mode, func(t *testing.T) {
			cfg := config.ClientVerifyConfig{
				Enabled:  true,
				Mode:     mode,
				ClientCA: caPath,
			}
			v, err := NewClientVerifier(cfg)
			if err != nil {
				t.Fatalf("NewClientVerifier() failed: %v", err)
			}
			if !v.IsEnabled() {
				t.Error("Verifier should be enabled")
			}
		})
	}
}
// TestClientVerifier_ConfigureTLS checks that an enabled "on" verifier sets
// ClientAuth to RequireAndVerifyClientCert and installs a ClientCAs pool. It
// also checks that ConfigureTLS(nil) does not panic.
func TestClientVerifier_ConfigureTLS(t *testing.T) {
	caPath := filepath.Join(t.TempDir(), "ca.crt")
	_, _, caPEM := generateTestCA(t)
	if err := os.WriteFile(caPath, caPEM, 0644); err != nil {
		t.Fatalf("Failed to write CA file: %v", err)
	}

	v, err := NewClientVerifier(config.ClientVerifyConfig{
		Enabled:     true,
		Mode:        "on",
		ClientCA:    caPath,
		VerifyDepth: 2,
	})
	if err != nil {
		t.Fatalf("NewClientVerifier() failed: %v", err)
	}

	cfg := new(tls.Config)
	v.ConfigureTLS(cfg)

	if want := tls.RequireAndVerifyClientCert; cfg.ClientAuth != want {
		t.Errorf("ClientAuth = %v, want %v", cfg.ClientAuth, want)
	}
	if cfg.ClientCAs == nil {
		t.Error("ClientCAs should be set")
	}

	// The test fails by panicking if a nil *tls.Config is not tolerated.
	v.ConfigureTLS(nil)
}
// TestClientVerifier_ConfigureTLS_Disabled checks that a disabled verifier
// leaves the TLS configuration untouched: no client-auth requirement, no CA
// pool and no VerifyConnection callback.
func TestClientVerifier_ConfigureTLS_Disabled(t *testing.T) {
	verifier, err := NewClientVerifier(config.ClientVerifyConfig{
		Enabled: false,
	})
	if err != nil {
		// Without this check a constructor failure would surface as a
		// nil-pointer panic on the ConfigureTLS call below.
		t.Fatalf("NewClientVerifier() failed for disabled config: %v", err)
	}

	tlsCfg := &tls.Config{}
	verifier.ConfigureTLS(tlsCfg)

	if tlsCfg.ClientAuth != tls.NoClientCert {
		t.Error("Disabled verifier should not modify TLS config")
	}
	if tlsCfg.ClientCAs != nil {
		t.Error("Disabled verifier should not set ClientCAs")
	}
	if tlsCfg.VerifyConnection != nil {
		t.Error("Disabled verifier should not set VerifyConnection")
	}
}
// TestGetClientCertInfo checks that the serial, subject and issuer of the
// presented client certificate are extracted as non-empty values. It also
// checks that a connection state without peer certificates yields nil.
func TestGetClientCertInfo(t *testing.T) {
	caCert, caKey, _ := generateTestCA(t)
	leaf, _, _ := generateTestClientCert(t, caCert, caKey, 2)

	// Hand-built connection state standing in for a completed handshake.
	state := &tls.ConnectionState{
		PeerCertificates: []*x509.Certificate{leaf},
	}

	info := GetClientCertInfo(state)
	if info == nil {
		t.Fatal("GetClientCertInfo() returned nil")
	}
	if info.Serial == "" {
		t.Error("Serial should not be empty")
	}
	if info.Subject == "" {
		t.Error("Subject should not be empty")
	}
	if info.Issuer == "" {
		t.Error("Issuer should not be empty")
	}

	// No peer certificates at all.
	if info = GetClientCertInfo(&tls.ConnectionState{}); info != nil {
		t.Error("GetClientCertInfo() should return nil for no certificates")
	}
}
// TestGetClientCertInfo_Nil checks that a nil connection state yields nil
// rather than panicking.
func TestGetClientCertInfo_Nil(t *testing.T) {
	if got := GetClientCertInfo(nil); got != nil {
		t.Error("GetClientCertInfo(nil) should return nil")
	}
}
// TestGetMode checks that GetMode reports the mode the verifier was built
// with. The "off" case is expressed as a disabled config rather than as
// Mode: "off".
func TestGetMode(t *testing.T) {
	cases := []struct {
		name     string
		mode     string
		expected ClientVerifyMode
	}{
		{"off", "off", VerifyOff},
		{"on", "on", VerifyOn},
		{"optional", "optional", VerifyOptional},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			caPath := filepath.Join(t.TempDir(), "ca.crt")
			_, _, caPEM := generateTestCA(t)
			if err := os.WriteFile(caPath, caPEM, 0644); err != nil {
				t.Fatalf("写入 CA 文件失败: %v", err)
			}

			cfg := config.ClientVerifyConfig{Enabled: false}
			if tc.mode != "off" {
				cfg = config.ClientVerifyConfig{
					Enabled:  true,
					Mode:     tc.mode,
					ClientCA: caPath,
				}
			}

			v, err := NewClientVerifier(cfg)
			if err != nil {
				t.Fatalf("NewClientVerifier() failed: %v", err)
			}
			if got := v.GetMode(); got != tc.expected {
				t.Errorf("GetMode() = %v, want %v", got, tc.expected)
			}
		})
	}
}
// generateTestCRL builds CRL number 1, signed by caCert/caKey. The CRL lists
// every serial in revokedSerials with a revocation time of now and is valid
// from now for 24 hours. It is returned PEM-encoded with block type
// "X509 CRL". A nil or empty revokedSerials produces a CRL with no entries.
//
// caCert must be acceptable to x509.CreateRevocationList as an issuer, for
// example a CA with KeyUsageCRLSign like the one generateTestCA returns.
// It accepts testing.TB so tests and benchmarks can share it; any failure
// aborts the caller through tb.Fatalf.
func generateTestCRL(tb testing.TB, caCert *x509.Certificate, caKey *rsa.PrivateKey, revokedSerials []*big.Int) []byte {
	tb.Helper()

	// A single timestamp keeps ThisUpdate, NextUpdate and every
	// RevocationTime mutually consistent.
	now := time.Now()

	entries := make([]x509.RevocationListEntry, 0, len(revokedSerials))
	for _, serial := range revokedSerials {
		entries = append(entries, x509.RevocationListEntry{
			SerialNumber:   serial,
			RevocationTime: now,
		})
	}

	template := &x509.RevocationList{
		Number:                    big.NewInt(1),
		ThisUpdate:                now,
		NextUpdate:                now.Add(24 * time.Hour),
		RevokedCertificateEntries: entries,
	}

	crlDER, err := x509.CreateRevocationList(rand.Reader, template, caCert, caKey)
	if err != nil {
		tb.Fatalf("Failed to create CRL: %v", err)
	}

	return pem.EncodeToMemory(&pem.Block{Type: "X509 CRL", Bytes: crlDER})
}
// TestLoadCRL checks LoadCRL on three inputs: a CRL with one revoked serial,
// a path that does not exist, and a file whose contents are not a CRL.
func TestLoadCRL(t *testing.T) {
	caCert, caKey, _ := generateTestCA(t)
	crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{big.NewInt(999)})

	dir := t.TempDir()
	crlPath := filepath.Join(dir, "crl.pem")
	if err := os.WriteFile(crlPath, crlPEM, 0644); err != nil {
		t.Fatalf("写入 CRL 文件失败: %v", err)
	}

	crl, err := LoadCRL(crlPath)
	if err != nil {
		t.Fatalf("LoadCRL() failed: %v", err)
	}
	if crl == nil {
		t.Fatal("LoadCRL() returned nil")
	}
	if n := len(crl.RevokedCertificateEntries); n != 1 {
		t.Errorf("CRL should have 1 revoked certificate, got %d", n)
	}

	// Missing file.
	if _, err := LoadCRL("/nonexistent/crl.pem"); err == nil {
		t.Error("LoadCRL() should fail for non-existent file")
	}

	// File whose contents are not a CRL.
	garbagePath := filepath.Join(dir, "invalid.crl")
	if err := os.WriteFile(garbagePath, []byte("invalid data"), 0644); err != nil {
		t.Fatalf("写入无效文件失败: %v", err)
	}
	if _, err := LoadCRL(garbagePath); err == nil {
		t.Error("LoadCRL() should fail for invalid CRL")
	}
}
// TestCheckCRL checks that ValidateClientCertificate consults the configured
// CRL. A certificate whose serial is listed is rejected; another certificate
// from the same CA is accepted.
func TestCheckCRL(t *testing.T) {
	// A single CA issues the client certificates and the CRL, and is also the
	// CA written to ClientCA, so the fixture is internally consistent.
	caCert, caKey, caPEM := generateTestCA(t)

	// Serial 100 will be revoked; serial 200 stays valid.
	revokedCert, _, _ := generateTestClientCert(t, caCert, caKey, 100)
	validCert, _, _ := generateTestClientCert(t, caCert, caKey, 200)

	crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedCert.SerialNumber})

	tempDir := t.TempDir()
	crlFile := filepath.Join(tempDir, "crl.pem")
	caFile := filepath.Join(tempDir, "ca.crt")
	if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil {
		t.Fatalf("写入 CRL 文件失败: %v", err)
	}
	if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
		t.Fatalf("写入 CA 文件失败: %v", err)
	}

	verifier, err := NewClientVerifier(config.ClientVerifyConfig{
		Enabled:  true,
		Mode:     "on",
		ClientCA: caFile,
		CRL:      crlFile,
	})
	if err != nil {
		t.Fatalf("NewClientVerifier() failed: %v", err)
	}

	if err := verifier.ValidateClientCertificate(validCert); err != nil {
		t.Errorf("ValidateClientCertificate() should pass for cert not in CRL: %v", err)
	}
	if err := verifier.ValidateClientCertificate(revokedCert); err == nil {
		t.Error("ValidateClientCertificate() should fail for revoked cert")
	}
}
// TestCheckCRL_EmptyCRL checks that a CRL with no entries does not cause
// ValidateClientCertificate to reject a certificate from the trusted CA.
func TestCheckCRL_EmptyCRL(t *testing.T) {
	// The same CA signs the client certificate and the CRL and is the CA
	// written to ClientCA, so the fixture is internally consistent.
	caCert, caKey, caPEM := generateTestCA(t)
	clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 50)

	// CRL with no revoked certificates.
	crlPEM := generateTestCRL(t, caCert, caKey, nil)

	tempDir := t.TempDir()
	crlFile := filepath.Join(tempDir, "crl.pem")
	caFile := filepath.Join(tempDir, "ca.crt")
	if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil {
		t.Fatalf("写入 CRL 文件失败: %v", err)
	}
	if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
		t.Fatalf("写入 CA 文件失败: %v", err)
	}

	verifier, err := NewClientVerifier(config.ClientVerifyConfig{
		Enabled:  true,
		Mode:     "on",
		ClientCA: caFile,
		CRL:      crlFile,
	})
	if err != nil {
		t.Fatalf("NewClientVerifier() failed: %v", err)
	}

	if err := verifier.ValidateClientCertificate(clientCert); err != nil {
		t.Errorf("ValidateClientCertificate() should pass with empty CRL: %v", err)
	}
}
// TestValidateClientCertificate exercises ValidateClientCertificate directly,
// outside a TLS handshake. It covers a disabled verifier, a missing
// certificate in "on" mode, and a certificate issued by the trusted CA.
func TestValidateClientCertificate(t *testing.T) {
	// A disabled verifier accepts anything, including a missing certificate.
	disabled, err := NewClientVerifier(config.ClientVerifyConfig{Enabled: false})
	if err != nil {
		t.Fatalf("NewClientVerifier() failed for disabled config: %v", err)
	}
	if err := disabled.ValidateClientCertificate(nil); err != nil {
		t.Errorf("Disabled verifier should accept nil cert: %v", err)
	}

	// The enabled verifier trusts the same CA that issues the client
	// certificate below, so "valid cert" really means valid for this CA.
	caCert, caKey, caPEM := generateTestCA(t)
	caFile := filepath.Join(t.TempDir(), "ca.crt")
	if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
		t.Fatalf("写入 CA 文件失败: %v", err)
	}

	verifier, err := NewClientVerifier(config.ClientVerifyConfig{
		Enabled:  true,
		Mode:     "on",
		ClientCA: caFile,
	})
	if err != nil {
		t.Fatalf("NewClientVerifier() failed: %v", err)
	}

	// In "on" mode a missing certificate must be rejected.
	if err := verifier.ValidateClientCertificate(nil); err == nil {
		t.Error("ValidateClientCertificate(nil) should fail in 'on' mode")
	}

	clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 30)
	if err := verifier.ValidateClientCertificate(clientCert); err != nil {
		t.Errorf("ValidateClientCertificate() should pass for valid cert: %v", err)
	}
}
// TestVerifyConnection checks the VerifyConnection callback that
// ConfigureTLS installs when a CRL and a VerifyDepth are configured. A peer
// certificate not in the CRL passes; a revoked one is rejected.
func TestVerifyConnection(t *testing.T) {
	// A single CA issues the client certificates and the CRL, and is also the
	// CA written to ClientCA, so the fixture is internally consistent.
	caCert, caKey, caPEM := generateTestCA(t)
	validCert, _, _ := generateTestClientCert(t, caCert, caKey, 300)
	revokedCert, _, _ := generateTestClientCert(t, caCert, caKey, 400)

	crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedCert.SerialNumber})

	tempDir := t.TempDir()
	crlFile := filepath.Join(tempDir, "crl.pem")
	caFile := filepath.Join(tempDir, "ca.crt")
	if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil {
		t.Fatalf("写入 CRL 文件失败: %v", err)
	}
	if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
		t.Fatalf("写入 CA 文件失败: %v", err)
	}

	verifier, err := NewClientVerifier(config.ClientVerifyConfig{
		Enabled:     true,
		Mode:        "on",
		ClientCA:    caFile,
		CRL:         crlFile,
		VerifyDepth: 3,
	})
	if err != nil {
		t.Fatalf("NewClientVerifier() failed: %v", err)
	}

	tlsCfg := &tls.Config{}
	verifier.ConfigureTLS(tlsCfg)
	if tlsCfg.VerifyConnection == nil {
		t.Fatal("VerifyConnection should be set when VerifyDepth > 0")
	}

	// Hand-built connection states standing in for completed handshakes.
	validCS := tls.ConnectionState{PeerCertificates: []*x509.Certificate{validCert}}
	if err := tlsCfg.VerifyConnection(validCS); err != nil {
		t.Errorf("VerifyConnection() should pass for valid cert: %v", err)
	}

	revokedCS := tls.ConnectionState{PeerCertificates: []*x509.Certificate{revokedCert}}
	if err := tlsCfg.VerifyConnection(revokedCS); err == nil {
		t.Error("VerifyConnection() should fail for revoked cert")
	}
}
// TestVerifyConnection_DepthLimit checks that with VerifyDepth = 1 the
// VerifyConnection callback accepts a lone leaf certificate but rejects a
// presented chain of two certificates.
func TestVerifyConnection_DepthLimit(t *testing.T) {
	// Issue and trust the same CA so the fixture does not mix an untrusted
	// issuer into what is meant to be a chain-length check.
	caCert, caKey, caPEM := generateTestCA(t)
	clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 500)

	caFile := filepath.Join(t.TempDir(), "ca.crt")
	if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
		t.Fatalf("写入 CA 文件失败: %v", err)
	}

	verifier, err := NewClientVerifier(config.ClientVerifyConfig{
		Enabled:     true,
		Mode:        "on",
		ClientCA:    caFile,
		VerifyDepth: 1,
	})
	if err != nil {
		t.Fatalf("NewClientVerifier() failed: %v", err)
	}

	tlsCfg := &tls.Config{}
	verifier.ConfigureTLS(tlsCfg)
	// Fail cleanly instead of panicking on a nil function call below.
	if tlsCfg.VerifyConnection == nil {
		t.Fatal("VerifyConnection should be set when VerifyDepth > 0")
	}

	// A single certificate fits within depth 1.
	single := tls.ConnectionState{PeerCertificates: []*x509.Certificate{clientCert}}
	if err := tlsCfg.VerifyConnection(single); err != nil {
		t.Errorf("VerifyConnection() should pass for single cert with depth 1: %v", err)
	}

	// Leaf plus CA exceeds depth 1.
	longChain := tls.ConnectionState{PeerCertificates: []*x509.Certificate{clientCert, caCert}}
	if err := tlsCfg.VerifyConnection(longChain); err == nil {
		t.Error("VerifyConnection() should fail for chain exceeding depth limit")
	}
}
// BenchmarkLoadCACertPool 基准测试 CA 证书池加载。
func BenchmarkLoadCACertPool(b *testing.B) {
tempDir := b.TempDir()
caFile := filepath.Join(tempDir, "ca.crt")
// 生成 CA
_, _, caPEM := generateTestCA(&testing.T{})
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
b.Fatalf("写入 CA 文件失败: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := sslutil.LoadCACertPool(caFile)
if err != nil {
b.Fatal(err)
}
}
}