refactor(ssl): 提取证书池加载函数到 sslutil 包
将 LoadCACertPool 和 LoadCertPool 函数提取到独立的 sslutil 包, 消除 ssl 和 stream 模块中的重复实现。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
019bc80aa4
commit
96bd4b0ed5
@ -22,6 +22,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/sslutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientVerifyMode 客户端证书验证模式
|
// ClientVerifyMode 客户端证书验证模式
|
||||||
@ -147,7 +148,7 @@ func NewClientVerifier(cfg config.ClientVerifyConfig) (*ClientVerifier, error) {
|
|||||||
return nil, errors.New("client_ca is required when verify is enabled")
|
return nil, errors.New("client_ca is required when verify is enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
caPool, err := LoadCACertPool(cfg.ClientCA)
|
caPool, err := sslutil.LoadCACertPool(cfg.ClientCA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load CA certificate pool: %w", err)
|
return nil, fmt.Errorf("failed to load CA certificate pool: %w", err)
|
||||||
}
|
}
|
||||||
@ -246,30 +247,6 @@ func (v *ClientVerifier) GetMode() ClientVerifyMode {
|
|||||||
return v.mode
|
return v.mode
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadCACertPool 从文件加载 CA 证书池。
|
|
||||||
//
|
|
||||||
// 支持 PEM 格式的证书文件,可包含多个 CA 证书。
|
|
||||||
//
|
|
||||||
// 参数:
|
|
||||||
// - caFile: CA 证书文件路径
|
|
||||||
//
|
|
||||||
// 返回值:
|
|
||||||
// - *x509.CertPool: CA 证书池
|
|
||||||
// - error: 加载失败时返回错误
|
|
||||||
func LoadCACertPool(caFile string) (*x509.CertPool, error) {
|
|
||||||
data, err := os.ReadFile(caFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read CA file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
caPool := x509.NewCertPool()
|
|
||||||
if !caPool.AppendCertsFromPEM(data) {
|
|
||||||
return nil, errors.New("failed to parse CA certificates")
|
|
||||||
}
|
|
||||||
|
|
||||||
return caPool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadCRL 从文件加载证书吊销列表。
|
// LoadCRL 从文件加载证书吊销列表。
|
||||||
//
|
//
|
||||||
// 支持 PEM 和 DER 格式的 CRL 文件。
|
// 支持 PEM 和 DER 格式的 CRL 文件。
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/sslutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// generateTestCA 生成测试 CA 证书。
|
// generateTestCA 生成测试 CA 证书。
|
||||||
@ -165,7 +166,7 @@ func TestLoadCACertPool(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 测试加载
|
// 测试加载
|
||||||
pool, err := LoadCACertPool(caFile)
|
pool, err := sslutil.LoadCACertPool(caFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("LoadCACertPool() failed: %v", err)
|
t.Fatalf("LoadCACertPool() failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -174,7 +175,7 @@ func TestLoadCACertPool(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 测试文件不存在
|
// 测试文件不存在
|
||||||
_, err = LoadCACertPool("/nonexistent/ca.crt")
|
_, err = sslutil.LoadCACertPool("/nonexistent/ca.crt")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("LoadCACertPool() should fail for non-existent file")
|
t.Error("LoadCACertPool() should fail for non-existent file")
|
||||||
}
|
}
|
||||||
@ -184,7 +185,7 @@ func TestLoadCACertPool(t *testing.T) {
|
|||||||
if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil {
|
if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil {
|
||||||
t.Fatalf("写入无效证书文件失败: %v", err)
|
t.Fatalf("写入无效证书文件失败: %v", err)
|
||||||
}
|
}
|
||||||
_, err = LoadCACertPool(invalidFile)
|
_, err = sslutil.LoadCACertPool(invalidFile)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("LoadCACertPool() should fail for invalid certificate")
|
t.Error("LoadCACertPool() should fail for invalid certificate")
|
||||||
}
|
}
|
||||||
@ -692,7 +693,7 @@ func BenchmarkLoadCACertPool(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_, err := LoadCACertPool(caFile)
|
_, err := sslutil.LoadCACertPool(caFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
56
internal/sslutil/certpool.go
Normal file
56
internal/sslutil/certpool.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
// Package sslutil provides SSL/TLS utility functions.
|
||||||
|
package sslutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadCertPool loads a certificate pool from a file.
|
||||||
|
// Supports PEM format certificate files that may contain multiple certificates.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - certFile: Certificate file path
|
||||||
|
// - context: Context description for error messages
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *x509.CertPool: Certificate pool
|
||||||
|
// - error: Returns error if loading fails
|
||||||
|
func LoadCertPool(certFile string, context string) (*x509.CertPool, error) {
|
||||||
|
data, err := os.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read certificate file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
if !pool.AppendCertsFromPEM(data) {
|
||||||
|
return nil, fmt.Errorf("failed to parse certificates from %s", certFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCACertPool loads a CA certificate pool from a file.
|
||||||
|
// This is a convenience function for loading CA certificates.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - caFile: CA certificate file path
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *x509.CertPool: CA certificate pool
|
||||||
|
// - error: Returns error if loading fails
|
||||||
|
func LoadCACertPool(caFile string) (*x509.CertPool, error) {
|
||||||
|
data, err := os.ReadFile(caFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read CA file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
caPool := x509.NewCertPool()
|
||||||
|
if !caPool.AppendCertsFromPEM(data) {
|
||||||
|
return nil, errors.New("failed to parse CA certificates")
|
||||||
|
}
|
||||||
|
|
||||||
|
return caPool, nil
|
||||||
|
}
|
||||||
@ -13,10 +13,10 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/sslutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SSLManager 管理 Stream SSL/TLS 配置。
|
// SSLManager 管理 Stream SSL/TLS 配置。
|
||||||
@ -65,7 +65,7 @@ func NewSSLManager(cfg config.StreamSSLConfig) (*SSLManager, error) {
|
|||||||
|
|
||||||
// 加载客户端 CA 证书(mTLS)
|
// 加载客户端 CA 证书(mTLS)
|
||||||
if cfg.ClientCA != "" {
|
if cfg.ClientCA != "" {
|
||||||
pool, err := loadCertPool(cfg.ClientCA)
|
pool, err := sslutil.LoadCertPool(cfg.ClientCA, "client CA")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load client CA: %w", err)
|
return nil, fmt.Errorf("failed to load client CA: %w", err)
|
||||||
}
|
}
|
||||||
@ -101,7 +101,7 @@ func NewProxySSLManager(cfg config.StreamProxySSLConfig) (*ProxySSLManager, erro
|
|||||||
|
|
||||||
// 加载信任的 CA 证书
|
// 加载信任的 CA 证书
|
||||||
if cfg.TrustedCA != "" {
|
if cfg.TrustedCA != "" {
|
||||||
pool, err := loadCertPool(cfg.TrustedCA)
|
pool, err := sslutil.LoadCertPool(cfg.TrustedCA, "trusted CA")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load trusted CA: %w", err)
|
return nil, fmt.Errorf("failed to load trusted CA: %w", err)
|
||||||
}
|
}
|
||||||
@ -211,28 +211,6 @@ func (m *ProxySSLManager) IsEnabled() bool {
|
|||||||
return m.config.Enabled
|
return m.config.Enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadCertPool 从文件加载证书池。
|
|
||||||
//
|
|
||||||
// 参数:
|
|
||||||
// - certFile: 证书文件路径
|
|
||||||
//
|
|
||||||
// 返回值:
|
|
||||||
// - *x509.CertPool: 证书池
|
|
||||||
// - error: 加载失败时返回错误
|
|
||||||
func loadCertPool(certFile string) (*x509.CertPool, error) {
|
|
||||||
data, err := os.ReadFile(certFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
if !pool.AppendCertsFromPEM(data) {
|
|
||||||
return nil, fmt.Errorf("failed to parse certificates from %s", certFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
return pool, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMinTLSVersion 解析最小 TLS 版本。
|
// parseMinTLSVersion 解析最小 TLS 版本。
|
||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/sslutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// generateTestCertificate 生成测试用的自签名证书
|
// generateTestCertificate 生成测试用的自签名证书
|
||||||
@ -351,7 +352,7 @@ func TestLoadCertPool(t *testing.T) {
|
|||||||
t.Fatalf("Failed to encode certificate: %v", err)
|
t.Fatalf("Failed to encode certificate: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pool, err := loadCertPool(certFile)
|
pool, err := sslutil.LoadCertPool(certFile, "test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("loadCertPool failed: %v", err)
|
t.Fatalf("loadCertPool failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -361,7 +362,7 @@ func TestLoadCertPool(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("invalid path", func(t *testing.T) {
|
t.Run("invalid path", func(t *testing.T) {
|
||||||
_, err := loadCertPool("/nonexistent/cert.pem")
|
_, err := sslutil.LoadCertPool("/nonexistent/cert.pem", "test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for nonexistent file")
|
t.Error("Expected error for nonexistent file")
|
||||||
}
|
}
|
||||||
@ -374,7 +375,7 @@ func TestLoadCertPool(t *testing.T) {
|
|||||||
t.Fatalf("写入无效证书文件失败: %v", err)
|
t.Fatalf("写入无效证书文件失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := loadCertPool(certFile)
|
_, err := sslutil.LoadCertPool(certFile, "test")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for invalid certificate content")
|
t.Error("Expected error for invalid certificate content")
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user