diff --git a/internal/ssl/client_verify.go b/internal/ssl/client_verify.go index df9c8e8..43cb2ea 100644 --- a/internal/ssl/client_verify.go +++ b/internal/ssl/client_verify.go @@ -22,6 +22,7 @@ import ( "time" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/sslutil" ) // ClientVerifyMode 客户端证书验证模式 @@ -147,7 +148,7 @@ func NewClientVerifier(cfg config.ClientVerifyConfig) (*ClientVerifier, error) { return nil, errors.New("client_ca is required when verify is enabled") } - caPool, err := LoadCACertPool(cfg.ClientCA) + caPool, err := sslutil.LoadCACertPool(cfg.ClientCA) if err != nil { return nil, fmt.Errorf("failed to load CA certificate pool: %w", err) } @@ -246,30 +247,6 @@ func (v *ClientVerifier) GetMode() ClientVerifyMode { return v.mode } -// LoadCACertPool 从文件加载 CA 证书池。 -// -// 支持 PEM 格式的证书文件,可包含多个 CA 证书。 -// -// 参数: -// - caFile: CA 证书文件路径 -// -// 返回值: -// - *x509.CertPool: CA 证书池 -// - error: 加载失败时返回错误 -func LoadCACertPool(caFile string) (*x509.CertPool, error) { - data, err := os.ReadFile(caFile) - if err != nil { - return nil, fmt.Errorf("failed to read CA file: %w", err) - } - - caPool := x509.NewCertPool() - if !caPool.AppendCertsFromPEM(data) { - return nil, errors.New("failed to parse CA certificates") - } - - return caPool, nil -} - // LoadCRL 从文件加载证书吊销列表。 // // 支持 PEM 和 DER 格式的 CRL 文件。 diff --git a/internal/ssl/client_verify_test.go b/internal/ssl/client_verify_test.go index 4636895..c4be7f2 100644 --- a/internal/ssl/client_verify_test.go +++ b/internal/ssl/client_verify_test.go @@ -24,6 +24,7 @@ import ( "time" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/sslutil" ) // generateTestCA 生成测试 CA 证书。 @@ -165,7 +166,7 @@ func TestLoadCACertPool(t *testing.T) { } // 测试加载 - pool, err := LoadCACertPool(caFile) + pool, err := sslutil.LoadCACertPool(caFile) if err != nil { t.Fatalf("LoadCACertPool() failed: %v", err) } @@ -174,7 +175,7 @@ func TestLoadCACertPool(t *testing.T) { } // 测试文件不存在 - _, err = LoadCACertPool("/nonexistent/ca.crt") + _, err = sslutil.LoadCACertPool("/nonexistent/ca.crt") if err == nil { t.Error("LoadCACertPool() should fail for non-existent file") } @@ -184,7 +185,7 @@ func TestLoadCACertPool(t *testing.T) { if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil { t.Fatalf("写入无效证书文件失败: %v", err) } - _, err = LoadCACertPool(invalidFile) + _, err = sslutil.LoadCACertPool(invalidFile) if err == nil { t.Error("LoadCACertPool() should fail for invalid certificate") } @@ -692,7 +693,7 @@ func BenchmarkLoadCACertPool(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := LoadCACertPool(caFile) + _, err := sslutil.LoadCACertPool(caFile) if err != nil { b.Fatal(err) } diff --git a/internal/sslutil/certpool.go b/internal/sslutil/certpool.go new file mode 100644 index 0000000..6b09d0c --- /dev/null +++ b/internal/sslutil/certpool.go @@ -0,0 +1,56 @@ +// Package sslutil provides SSL/TLS utility functions. +package sslutil + +import ( + "crypto/x509" + "errors" + "fmt" + "os" +) + +// LoadCertPool loads a certificate pool from a file. +// Supports PEM format certificate files that may contain multiple certificates. +// +// Parameters: +// - certFile: Certificate file path +// - context: Context description for error messages +// +// Returns: +// - *x509.CertPool: Certificate pool +// - error: Returns error if loading fails +func LoadCertPool(certFile string, context string) (*x509.CertPool, error) { + data, err := os.ReadFile(certFile) + if err != nil { + return nil, fmt.Errorf("failed to read certificate file: %w", err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(data) { + return nil, fmt.Errorf("failed to parse certificates from %s", certFile) + } + + return pool, nil +} + +// LoadCACertPool loads a CA certificate pool from a file. +// This is a convenience function for loading CA certificates. +// +// Parameters: +// - caFile: CA certificate file path +// +// Returns: +// - *x509.CertPool: CA certificate pool +// - error: Returns error if loading fails +func LoadCACertPool(caFile string) (*x509.CertPool, error) { + data, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file: %w", err) + } + + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(data) { + return nil, errors.New("failed to parse CA certificates") + } + + return caPool, nil +} diff --git a/internal/stream/ssl.go b/internal/stream/ssl.go index e4a4346..9bb5810 100644 --- a/internal/stream/ssl.go +++ b/internal/stream/ssl.go @@ -13,10 +13,10 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "os" "sync" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/sslutil" ) // SSLManager 管理 Stream SSL/TLS 配置。 @@ -65,7 +65,7 @@ func NewSSLManager(cfg config.StreamSSLConfig) (*SSLManager, error) { // 加载客户端 CA 证书(mTLS) if cfg.ClientCA != "" { - pool, err := loadCertPool(cfg.ClientCA) + pool, err := sslutil.LoadCertPool(cfg.ClientCA, "client CA") if err != nil { return nil, fmt.Errorf("failed to load client CA: %w", err) } @@ -101,7 +101,7 @@ func NewProxySSLManager(cfg config.StreamProxySSLConfig) (*ProxySSLManager, erro // 加载信任的 CA 证书 if cfg.TrustedCA != "" { - pool, err := loadCertPool(cfg.TrustedCA) + pool, err := sslutil.LoadCertPool(cfg.TrustedCA, "trusted CA") if err != nil { return nil, fmt.Errorf("failed to load trusted CA: %w", err) } @@ -211,28 +211,6 @@ func (m *ProxySSLManager) IsEnabled() bool { return m.config.Enabled } -// loadCertPool 从文件加载证书池。 -// -// 参数: -// - certFile: 证书文件路径 -// -// 返回值: -// - *x509.CertPool: 证书池 -// - error: 加载失败时返回错误 -func loadCertPool(certFile string) (*x509.CertPool, error) { - data, err := os.ReadFile(certFile) - if err != nil { - return nil, err - } - - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(data) { - return nil, fmt.Errorf("failed to parse certificates from %s", certFile) - } - - return pool, nil -} - // parseMinTLSVersion 解析最小 TLS 版本。 // // 参数: diff --git a/internal/stream/ssl_test.go b/internal/stream/ssl_test.go index b7a95e0..cbfb855 100644 --- a/internal/stream/ssl_test.go +++ b/internal/stream/ssl_test.go @@ -18,6 +18,7 @@ import ( "time" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/sslutil" ) // generateTestCertificate 生成测试用的自签名证书 @@ -351,7 +352,7 @@ func TestLoadCertPool(t *testing.T) { t.Fatalf("Failed to encode certificate: %v", err) } - pool, err := loadCertPool(certFile) + pool, err := sslutil.LoadCertPool(certFile, "test") if err != nil { t.Fatalf("loadCertPool failed: %v", err) } @@ -361,7 +362,7 @@ func TestLoadCertPool(t *testing.T) { }) t.Run("invalid path", func(t *testing.T) { - _, err := loadCertPool("/nonexistent/cert.pem") + _, err := sslutil.LoadCertPool("/nonexistent/cert.pem", "test") if err == nil { t.Error("Expected error for nonexistent file") } @@ -374,7 +375,7 @@ func TestLoadCertPool(t *testing.T) { t.Fatalf("写入无效证书文件失败: %v", err) } - _, err := loadCertPool(certFile) + _, err := sslutil.LoadCertPool(certFile, "test") if err == nil { t.Error("Expected error for invalid certificate content") }