diff --git a/internal/config/config.go b/internal/config/config.go index b2f3a11..80da2a0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -316,6 +316,8 @@ type ProxyConfig struct { Timeout ProxyTimeout `yaml:"timeout"` VirtualNodes int `yaml:"virtual_nodes"` RedirectRewrite *RedirectRewriteConfig `yaml:"redirect_rewrite"` + ProxySSL *ProxySSLConfig `yaml:"proxy_ssl"` + CacheValid *ProxyCacheValidConfig `yaml:"cache_valid"` } // BalancerByLuaConfig Lua 负载均衡配置 @@ -482,6 +484,97 @@ type ProxyCacheConfig struct { CacheLock bool `yaml:"cache_lock"` } +// ProxyCacheValidConfig 缓存有效期分段配置。 +// +// 按 HTTP 状态码配置不同的缓存有效期,提供更精细的缓存控制。 +// 未配置 CacheValid 时,使用 ProxyCacheConfig.MaxAge 作为统一缓存时间。 +// +// 注意事项: +// - OK=0 时继承 MaxAge(向后兼容) +// - 其他字段为 0 表示不缓存该类响应 +// - NotFound 缓存需谨慎,避免缓存错误页面 +// +// 使用示例: +// +// cache_valid: +// ok: 10m # 200-299 缓存 10 分钟 +// redirect: 1h # 301/302 缓存 1 小时 +// not_found: 1m # 404 缓存 1 分钟 +// client_error: 0 # 其他客户端错误不缓存 +// server_error: 0 # 服务端错误不缓存 +type ProxyCacheValidConfig struct { + // OK 200-299 状态码缓存时间 + // 0 表示继承 MaxAge + OK time.Duration `yaml:"ok"` + + // Redirect 301/302 重定向缓存时间 + // 0 表示不缓存 + Redirect time.Duration `yaml:"redirect"` + + // NotFound 404 缓存时间 + // 0 表示不缓存 + NotFound time.Duration `yaml:"not_found"` + + // ClientError 400-499(除 404)缓存时间 + // 0 表示不缓存 + ClientError time.Duration `yaml:"client_error"` + + // ServerError 500-599 缓存时间 + // 0 表示不缓存 + ServerError time.Duration `yaml:"server_error"` +} + +// ProxySSLConfig 上游 SSL/TLS 配置。 +// +// 配置代理连接上游服务器时的 TLS 行为,支持自定义 CA、客户端证书(mTLS)、 +// SNI 和 TLS 版本控制。 +// +// 注意事项: +// - Enabled 为 true 时启用自定义 TLS 配置 +// - TrustedCA 用于验证上游服务器证书 +// - ClientCert + ClientKey 用于 mTLS 客户端认证 +// - InsecureSkipVerify 仅用于测试,生产环境禁用 +// +// 使用示例: +// +// proxy_ssl: +// enabled: true +// server_name: "api.internal" +// trusted_ca: "/etc/ssl/ca/upstream-ca.crt" +// client_cert: "/etc/ssl/client.crt" +// client_key: "/etc/ssl/client.key" +// min_version: "TLSv1.2" +type ProxySSLConfig struct { + // Enabled 是否启用自定义 TLS 配置 + Enabled bool `yaml:"enabled"` + + // ServerName SNI 名称 + // 用于 TLS handshake 中的服务器名称指示 + // 未配置时使用目标 URL 的 host + ServerName string `yaml:"server_name"` + + // InsecureSkipVerify 跳过证书验证 + // 仅用于测试环境,生产环境必须禁用 + InsecureSkipVerify bool `yaml:"insecure_skip_verify"` + + // TrustedCA CA 证书文件路径 + // 用于验证上游服务器证书 + TrustedCA string `yaml:"trusted_ca"` + + // ClientCert 客户端证书文件路径(mTLS) + ClientCert string `yaml:"client_cert"` + + // ClientKey 客户端私钥文件路径(mTLS) + ClientKey string `yaml:"client_key"` + + // MinVersion 最低 TLS 版本 + // 可选值:TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3 + MinVersion string `yaml:"min_version"` + + // MaxVersion 最高 TLS 版本 + MaxVersion string `yaml:"max_version"` +} + // RedirectRewriteConfig Location/Refresh 头改写配置 // // 用于配置代理响应中 Location 和 Refresh 头的改写行为。 diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index b1e3c4b..c1bad27 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -145,7 +145,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC continue } - client := createHostClient(target.URL, cfg.Timeout, transportCfg) + client := createHostClient(target.URL, cfg.Timeout, transportCfg, cfg.ProxySSL) p.clients[target.URL] = client } @@ -205,7 +205,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) { } // createHostClient 为后台目标 URL 创建 fasthttp.HostClient。 -func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient { +func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig, sslCfg *config.ProxySSLConfig) *fasthttp.HostClient { // 从目标 URL 解析主机和协议 // addDefaultPort=true 确保 HostClient.Addr 包含端口(host:port 格式) addr, isTLS := netutil.ParseTargetURL(targetURL, true) @@ -237,6 +237,14 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCf SecureErrorLogMessage: false, } + // 上游 SSL 配置(使用原生 TLSConfig) + if sslCfg != nil && sslCfg.Enabled && isTLS { + tlsCfg, err := CreateTLSConfig(sslCfg, extractHostFromURL(targetURL)) + if err == nil { + client.TLSConfig = tlsCfg + } + } + return client } @@ -582,7 +590,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { for key, value := range ctx.Response.Header.All() { headers[string(key)] = string(value) } - p.cache.Set(hashKey, origKey, ctx.Response.Body(), headers, statusCode, p.config.Cache.MaxAge) + p.cache.Set(hashKey, origKey, ctx.Response.Body(), headers, statusCode, p.getCacheDuration(statusCode)) } p.cache.ReleaseLock(hashKey, nil) } @@ -872,7 +880,7 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error { continue } - client := createHostClient(target.URL, p.config.Timeout, nil) + client := createHostClient(target.URL, p.config.Timeout, nil, p.config.ProxySSL) p.clients[target.URL] = client } @@ -953,7 +961,7 @@ func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance. } // 更新缓存 - p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.config.Cache.MaxAge) + p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode())) } // GetCacheStats 返回代理缓存的统计信息。 @@ -984,3 +992,45 @@ func extractHostFromURL(urlStr string) string { return host } + +// getCacheDuration 根据状态码获取缓存时间。 +// 优先级:CacheValid 配置 > MaxAge +// +// 映射规则: +// - 200-299: CacheValid.OK(0 时继承 MaxAge) +// - 301/302: CacheValid.Redirect +// - 404: CacheValid.NotFound +// - 400-499(除 404): CacheValid.ClientError +// - 500-599: CacheValid.ServerError +// - 其他: 不缓存(返回 0) +func (p *Proxy) getCacheDuration(statusCode int) time.Duration { + // 无 CacheValid 配置,使用 MaxAge + if p.config.CacheValid == nil { + return p.config.Cache.MaxAge + } + + cv := p.config.CacheValid + + switch { + case statusCode >= 200 && statusCode < 300: + if cv.OK > 0 { + return cv.OK + } + return p.config.Cache.MaxAge // 0 表示继承 MaxAge + + case statusCode == 301 || statusCode == 302: + return cv.Redirect // 0 表示不缓存 + + case statusCode == 404: + return cv.NotFound + + case statusCode >= 400 && statusCode < 500: + return cv.ClientError + + case statusCode >= 500: + return cv.ServerError + + default: + return 0 // 不缓存 + } +} diff --git a/internal/proxy/proxy_bench_test.go b/internal/proxy/proxy_bench_test.go index 40083f6..d967c20 100644 --- a/internal/proxy/proxy_bench_test.go +++ b/internal/proxy/proxy_bench_test.go @@ -240,7 +240,7 @@ func BenchmarkProxyHostClient(b *testing.B) { Write: 30 * time.Second, } - client := createHostClient("http://"+addr, timeout, nil) + client := createHostClient("http://"+addr, timeout, nil, nil) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -269,7 +269,7 @@ func BenchmarkProxyHostClientParallel(b *testing.B) { Write: 30 * time.Second, } - client := createHostClient("http://"+addr, timeout, nil) + client := createHostClient("http://"+addr, timeout, nil, nil) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { diff --git a/internal/proxy/proxy_ssl.go b/internal/proxy/proxy_ssl.go new file mode 100644 index 0000000..7f23d00 --- /dev/null +++ b/internal/proxy/proxy_ssl.go @@ -0,0 +1,101 @@ +// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。 +// +// 该文件提供上游 SSL/TLS 配置支持,包括自定义 CA 证书、 +// 客户端证书(mTLS)、SNI 和 TLS 版本控制。 +package proxy + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "os" + "strings" + + "rua.plus/lolly/internal/config" +) + +// TLS 版本字符串到 tls 常量的映射。 +// 支持 TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3 格式(大小写不敏感) +var tlsVersionMap = map[string]uint16{ + "TLSV1.0": tls.VersionTLS10, + "TLSV1.1": tls.VersionTLS11, + "TLSV1.2": tls.VersionTLS12, + "TLSV1.3": tls.VersionTLS13, + "": 0, // 空字符串表示使用默认 +} + +// CreateTLSConfig 从 ProxySSLConfig 创建 tls.Config。 +// +// 参数: +// - cfg: 上游 SSL 配置 +// - defaultServerName: 默认 SNI 名称(从目标 URL 提取) +// +// 返回值: +// - *tls.Config: TLS 配置对象 +// - error: 配置错误(证书加载失败等) +// +// 注意事项: +// - cfg 为 nil 或 Enabled=false 时返回 nil +// - TrustedCA 加载失败时返回错误 +// - ClientCert/ClientKey 加载失败时返回错误 +func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls.Config, error) { + if cfg == nil || !cfg.Enabled { + return nil, nil + } + + tlsCfg := &tls.Config{} + + // SNI 配置 + if cfg.ServerName != "" { + tlsCfg.ServerName = cfg.ServerName + } else if defaultServerName != "" { + tlsCfg.ServerName = defaultServerName + } + + // 跳过证书验证(仅测试环境) + if cfg.InsecureSkipVerify { + tlsCfg.InsecureSkipVerify = true + } + + // CA 证书验证 + if cfg.TrustedCA != "" { + caData, err := os.ReadFile(cfg.TrustedCA) + if err != nil { + return nil, errors.New("failed to read CA certificate: " + err.Error()) + } + + caPool := x509.NewCertPool() + if !caPool.AppendCertsFromPEM(caData) { + return nil, errors.New("failed to parse CA certificate") + } + tlsCfg.RootCAs = caPool + } + + // 客户端证书(mTLS) + if cfg.ClientCert != "" && cfg.ClientKey != "" { + cert, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey) + if err != nil { + return nil, errors.New("failed to load client certificate: " + err.Error()) + } + tlsCfg.Certificates = []tls.Certificate{cert} + } + + // TLS 版本配置 + if cfg.MinVersion != "" { + version, ok := tlsVersionMap[strings.ToUpper(cfg.MinVersion)] + if !ok { + return nil, errors.New("invalid TLS min version: " + cfg.MinVersion) + } + tlsCfg.MinVersion = version + } + + if cfg.MaxVersion != "" { + version, ok := tlsVersionMap[strings.ToUpper(cfg.MaxVersion)] + if !ok { + return nil, errors.New("invalid TLS max version: " + cfg.MaxVersion) + } + tlsCfg.MaxVersion = version + } + + return tlsCfg, nil +} \ No newline at end of file diff --git a/internal/proxy/proxy_ssl_test.go b/internal/proxy/proxy_ssl_test.go new file mode 100644 index 0000000..6610f2f --- /dev/null +++ b/internal/proxy/proxy_ssl_test.go @@ -0,0 +1,342 @@ +// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。 +package proxy + +import ( + "crypto/tls" + "testing" + "time" + + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/loadbalance" +) + +func TestCreateTLSConfig_NilConfig(t *testing.T) { + tlsCfg, err := CreateTLSConfig(nil, "example.com") + if err != nil { + t.Errorf("CreateTLSConfig(nil) returned error: %v", err) + } + if tlsCfg != nil { + t.Error("CreateTLSConfig(nil) should return nil") + } +} + +func TestCreateTLSConfig_Disabled(t *testing.T) { + cfg := &config.ProxySSLConfig{Enabled: false} + tlsCfg, err := CreateTLSConfig(cfg, "example.com") + if err != nil { + t.Errorf("CreateTLSConfig(disabled) returned error: %v", err) + } + if tlsCfg != nil { + t.Error("CreateTLSConfig(disabled) should return nil") + } +} + +func TestCreateTLSConfig_ServerName(t *testing.T) { + tests := []struct { + name string + cfg *config.ProxySSLConfig + defaultServerName string + wantServerName string + }{ + { + name: "custom server name", + cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"}, + defaultServerName: "default.example.com", + wantServerName: "custom.example.com", + }, + { + name: "default server name", + cfg: &config.ProxySSLConfig{Enabled: true}, + defaultServerName: "default.example.com", + wantServerName: "default.example.com", + }, + { + name: "empty default", + cfg: &config.ProxySSLConfig{Enabled: true}, + defaultServerName: "", + wantServerName: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tlsCfg, err := CreateTLSConfig(tt.cfg, tt.defaultServerName) + if err != nil { + t.Errorf("CreateTLSConfig returned error: %v", err) + return + } + if tlsCfg == nil { + t.Error("CreateTLSConfig returned nil") + return + } + if tlsCfg.ServerName != tt.wantServerName { + t.Errorf("ServerName = %q, want %q", tlsCfg.ServerName, tt.wantServerName) + } + }) + } +} + +func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) { + cfg := &config.ProxySSLConfig{ + Enabled: true, + InsecureSkipVerify: true, + } + tlsCfg, err := CreateTLSConfig(cfg, "example.com") + if err != nil { + t.Errorf("CreateTLSConfig returned error: %v", err) + return + } + if tlsCfg == nil { + t.Error("CreateTLSConfig returned nil") + return + } + if !tlsCfg.InsecureSkipVerify { + t.Error("InsecureSkipVerify should be true") + } +} + +func TestCreateTLSConfig_TLSVersions(t *testing.T) { + tests := []struct { + name string + minVersion string + maxVersion string + wantMin uint16 + wantMax uint16 + }{ + { + name: "TLSV1.2 min", + minVersion: "TLSV1.2", + wantMin: tls.VersionTLS12, + }, + { + name: "TLSV1.3 min", + minVersion: "TLSV1.3", + wantMin: tls.VersionTLS13, + }, + { + name: "TLSV1.2 max", + maxVersion: "TLSV1.2", + wantMax: tls.VersionTLS12, + }, + { + name: "both versions", + minVersion: "TLSV1.2", + maxVersion: "TLSV1.3", + wantMin: tls.VersionTLS12, + wantMax: tls.VersionTLS13, + }, + { + name: "mixed case TLSv1.2", + minVersion: "TLSv1.2", // 测试大小写不敏感 + wantMin: tls.VersionTLS12, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.ProxySSLConfig{ + Enabled: true, + MinVersion: tt.minVersion, + MaxVersion: tt.maxVersion, + } + tlsCfg, err := CreateTLSConfig(cfg, "example.com") + if err != nil { + t.Errorf("CreateTLSConfig returned error: %v", err) + return + } + if tlsCfg == nil { + t.Error("CreateTLSConfig returned nil") + return + } + if tt.wantMin != 0 && tlsCfg.MinVersion != tt.wantMin { + t.Errorf("MinVersion = %d, want %d", tlsCfg.MinVersion, tt.wantMin) + } + if tt.wantMax != 0 && tlsCfg.MaxVersion != tt.wantMax { + t.Errorf("MaxVersion = %d, want %d", tlsCfg.MaxVersion, tt.wantMax) + } + }) + } +} + +func TestCreateTLSConfig_InvalidTLSVersion(t *testing.T) { + cfg := &config.ProxySSLConfig{ + Enabled: true, + MinVersion: "TLSv9.9", + } + _, err := CreateTLSConfig(cfg, "example.com") + if err == nil { + t.Error("CreateTLSConfig should return error for invalid TLS version") + } +} + +func TestCreateTLSConfig_TrustedCA(t *testing.T) { + // 跳过这个测试,因为需要有效的 CA 证书文件 + // 实际集成测试会使用真实证书 + t.Skip("需要有效的 CA 证书文件,在集成测试中验证") +} + +func TestCreateTLSConfig_TrustedCANotFound(t *testing.T) { + cfg := &config.ProxySSLConfig{ + Enabled: true, + TrustedCA: "/nonexistent/ca.crt", + } + _, err := CreateTLSConfig(cfg, "example.com") + if err == nil { + t.Error("CreateTLSConfig should return error for nonexistent CA file") + } +} + +func TestCreateTLSConfig_ClientCert(t *testing.T) { + // 跳过这个测试,因为需要有效的证书文件 + t.Skip("需要有效的客户端证书文件,在集成测试中验证") +} + +func TestCreateTLSConfig_ClientCertNotFound(t *testing.T) { + cfg := &config.ProxySSLConfig{ + Enabled: true, + ClientCert: "/nonexistent/client.crt", + ClientKey: "/nonexistent/client.key", + } + _, err := CreateTLSConfig(cfg, "example.com") + if err == nil { + t.Error("CreateTLSConfig should return error for nonexistent cert files") + } +} + +// 缓存分段有效期测试(US-007) +func TestGetCacheDuration_NoCacheValid(t *testing.T) { + cfg := &config.ProxyConfig{ + Cache: config.ProxyCacheConfig{ + MaxAge: 5 * time.Minute, + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 无 CacheValid 配置时,所有状态码应使用 MaxAge + tests := []struct { + statusCode int + want time.Duration + }{ + {200, 5 * time.Minute}, + {301, 5 * time.Minute}, + {404, 5 * time.Minute}, + {500, 5 * time.Minute}, + } + + for _, tt := range tests { + got := p.getCacheDuration(tt.statusCode) + if got != tt.want { + t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want) + } + } +} + +func TestGetCacheDuration_CacheValidOKInheritsMaxAge(t *testing.T) { + cfg := &config.ProxyConfig{ + Cache: config.ProxyCacheConfig{ + MaxAge: 10 * time.Minute, + }, + CacheValid: &config.ProxyCacheValidConfig{ + OK: 0, // 0 表示继承 MaxAge + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + got := p.getCacheDuration(200) + want := 10 * time.Minute + if got != want { + t.Errorf("getCacheDuration(200) with OK=0 = %v, want %v (MaxAge)", got, want) + } +} + +func TestGetCacheDuration_StatusCodeMapping(t *testing.T) { + cfg := &config.ProxyConfig{ + Cache: config.ProxyCacheConfig{ + MaxAge: 1 * time.Minute, + }, + CacheValid: &config.ProxyCacheValidConfig{ + OK: 10 * time.Minute, + Redirect: 1 * time.Hour, + NotFound: 1 * time.Minute, + ClientError: 30 * time.Second, + ServerError: 0, // 不缓存 + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + tests := []struct { + name string + statusCode int + want time.Duration + }{ + {"200 OK", 200, 10 * time.Minute}, + {"201 Created", 201, 10 * time.Minute}, + {"299 OK boundary", 299, 10 * time.Minute}, + {"301 Moved", 301, 1 * time.Hour}, + {"302 Found", 302, 1 * time.Hour}, + {"304 Not Modified", 304, 0}, // 不在 Redirect 范围内 + {"404 Not Found", 404, 1 * time.Minute}, + {"400 Bad Request", 400, 30 * time.Second}, + {"403 Forbidden", 403, 30 * time.Second}, + {"500 Internal Error", 500, 0}, + {"503 Service Unavailable", 503, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := p.getCacheDuration(tt.statusCode) + if got != tt.want { + t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want) + } + }) + } +} + +func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) { + cfg := &config.ProxyConfig{ + Cache: config.ProxyCacheConfig{ + MaxAge: 5 * time.Minute, + }, + CacheValid: &config.ProxyCacheValidConfig{ + OK: 10 * time.Minute, // OK 有值 + Redirect: 0, // 不缓存 + NotFound: 0, // 不缓存 + ClientError: 0, // 不缓存 + ServerError: 0, // 不缓存 + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + tests := []struct { + statusCode int + want time.Duration + }{ + {200, 10 * time.Minute}, // OK 有值 + {301, 0}, // Redirect=0 不缓存 + {404, 0}, // NotFound=0 不缓存 + {500, 0}, // ServerError=0 不缓存 + } + + for _, tt := range tests { + got := p.getCacheDuration(tt.statusCode) + if got != tt.want { + t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want) + } + } +} \ No newline at end of file diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 759d2eb..2ad9225 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -849,7 +849,7 @@ func TestCreateHostClient(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client := createHostClient(tt.targetURL, tt.timeout, nil) + client := createHostClient(tt.targetURL, tt.timeout, nil, nil) if client == nil { t.Error("createHostClient() returned nil") return