feat(proxy): 添加上游 SSL 配置和缓存有效期分段配置

- ProxySSLConfig: 支持自定义 CA、客户端证书(mTLS)、SNI、TLS 版本控制
- ProxyCacheValidConfig: 按 HTTP 状态码分段配置缓存有效期
- proxy_ssl.go: 实现 CreateTLSConfig 和 TLS 版本解析
- proxy.go: 集成 SSL 配置到 HostClient,实现 getCacheDuration 分段缓存
- 测试文件适配新配置

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-15 18:27:50 +08:00
parent bf14282e40
commit a644e551af
6 changed files with 594 additions and 8 deletions

View File

@ -316,6 +316,8 @@ type ProxyConfig struct {
Timeout ProxyTimeout `yaml:"timeout"`
VirtualNodes int `yaml:"virtual_nodes"`
RedirectRewrite *RedirectRewriteConfig `yaml:"redirect_rewrite"`
ProxySSL *ProxySSLConfig `yaml:"proxy_ssl"`
CacheValid *ProxyCacheValidConfig `yaml:"cache_valid"`
}
// BalancerByLuaConfig Lua 负载均衡配置
@ -482,6 +484,97 @@ type ProxyCacheConfig struct {
CacheLock bool `yaml:"cache_lock"`
}
// ProxyCacheValidConfig 缓存有效期分段配置。
//
// 按 HTTP 状态码配置不同的缓存有效期,提供更精细的缓存控制。
// 未配置 CacheValid 时,使用 ProxyCacheConfig.MaxAge 作为统一缓存时间。
//
// 注意事项:
// - OK=0 时继承 MaxAge向后兼容
// - 其他字段为 0 表示不缓存该类响应
// - NotFound 缓存需谨慎,避免缓存错误页面
//
// 使用示例:
//
// cache_valid:
// ok: 10m # 200-299 缓存 10 分钟
// redirect: 1h # 301/302 缓存 1 小时
// not_found: 1m # 404 缓存 1 分钟
// client_error: 0 # 其他客户端错误不缓存
// server_error: 0 # 服务端错误不缓存
type ProxyCacheValidConfig struct {
// OK 200-299 状态码缓存时间
// 0 表示继承 MaxAge
OK time.Duration `yaml:"ok"`
// Redirect 301/302 重定向缓存时间
// 0 表示不缓存
Redirect time.Duration `yaml:"redirect"`
// NotFound 404 缓存时间
// 0 表示不缓存
NotFound time.Duration `yaml:"not_found"`
// ClientError 400-499除 404缓存时间
// 0 表示不缓存
ClientError time.Duration `yaml:"client_error"`
// ServerError 500-599 缓存时间
// 0 表示不缓存
ServerError time.Duration `yaml:"server_error"`
}
// ProxySSLConfig 上游 SSL/TLS 配置。
//
// 配置代理连接上游服务器时的 TLS 行为,支持自定义 CA、客户端证书mTLS
// SNI 和 TLS 版本控制。
//
// 注意事项:
// - Enabled 为 true 时启用自定义 TLS 配置
// - TrustedCA 用于验证上游服务器证书
// - ClientCert + ClientKey 用于 mTLS 客户端认证
// - InsecureSkipVerify 仅用于测试,生产环境禁用
//
// 使用示例:
//
// proxy_ssl:
// enabled: true
// server_name: "api.internal"
// trusted_ca: "/etc/ssl/ca/upstream-ca.crt"
// client_cert: "/etc/ssl/client.crt"
// client_key: "/etc/ssl/client.key"
// min_version: "TLSv1.2"
type ProxySSLConfig struct {
// Enabled 是否启用自定义 TLS 配置
Enabled bool `yaml:"enabled"`
// ServerName SNI 名称
// 用于 TLS handshake 中的服务器名称指示
// 未配置时使用目标 URL 的 host
ServerName string `yaml:"server_name"`
// InsecureSkipVerify 跳过证书验证
// 仅用于测试环境,生产环境必须禁用
InsecureSkipVerify bool `yaml:"insecure_skip_verify"`
// TrustedCA CA 证书文件路径
// 用于验证上游服务器证书
TrustedCA string `yaml:"trusted_ca"`
// ClientCert 客户端证书文件路径mTLS
ClientCert string `yaml:"client_cert"`
// ClientKey 客户端私钥文件路径mTLS
ClientKey string `yaml:"client_key"`
// MinVersion 最低 TLS 版本
// 可选值TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3
MinVersion string `yaml:"min_version"`
// MaxVersion 最高 TLS 版本
MaxVersion string `yaml:"max_version"`
}
// RedirectRewriteConfig Location/Refresh 头改写配置
//
// 用于配置代理响应中 Location 和 Refresh 头的改写行为。

View File

@ -145,7 +145,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
continue
}
client := createHostClient(target.URL, cfg.Timeout, transportCfg)
client := createHostClient(target.URL, cfg.Timeout, transportCfg, cfg.ProxySSL)
p.clients[target.URL] = client
}
@ -205,7 +205,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
}
// createHostClient 为后台目标 URL 创建 fasthttp.HostClient。
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient {
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig, sslCfg *config.ProxySSLConfig) *fasthttp.HostClient {
// 从目标 URL 解析主机和协议
// addDefaultPort=true 确保 HostClient.Addr 包含端口host:port 格式)
addr, isTLS := netutil.ParseTargetURL(targetURL, true)
@ -237,6 +237,14 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCf
SecureErrorLogMessage: false,
}
// 上游 SSL 配置(使用原生 TLSConfig
if sslCfg != nil && sslCfg.Enabled && isTLS {
tlsCfg, err := CreateTLSConfig(sslCfg, extractHostFromURL(targetURL))
if err == nil {
client.TLSConfig = tlsCfg
}
}
return client
}
@ -582,7 +590,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
for key, value := range ctx.Response.Header.All() {
headers[string(key)] = string(value)
}
p.cache.Set(hashKey, origKey, ctx.Response.Body(), headers, statusCode, p.config.Cache.MaxAge)
p.cache.Set(hashKey, origKey, ctx.Response.Body(), headers, statusCode, p.getCacheDuration(statusCode))
}
p.cache.ReleaseLock(hashKey, nil)
}
@ -872,7 +880,7 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error {
continue
}
client := createHostClient(target.URL, p.config.Timeout, nil)
client := createHostClient(target.URL, p.config.Timeout, nil, p.config.ProxySSL)
p.clients[target.URL] = client
}
@ -953,7 +961,7 @@ func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.
}
// 更新缓存
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.config.Cache.MaxAge)
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode()))
}
// GetCacheStats 返回代理缓存的统计信息。
@ -984,3 +992,45 @@ func extractHostFromURL(urlStr string) string {
return host
}
// getCacheDuration 根据状态码获取缓存时间。
// 优先级CacheValid 配置 > MaxAge
//
// 映射规则:
// - 200-299: CacheValid.OK0 时继承 MaxAge
// - 301/302: CacheValid.Redirect
// - 404: CacheValid.NotFound
// - 400-499除 404: CacheValid.ClientError
// - 500-599: CacheValid.ServerError
// - 其他: 不缓存(返回 0
func (p *Proxy) getCacheDuration(statusCode int) time.Duration {
// 无 CacheValid 配置,使用 MaxAge
if p.config.CacheValid == nil {
return p.config.Cache.MaxAge
}
cv := p.config.CacheValid
switch {
case statusCode >= 200 && statusCode < 300:
if cv.OK > 0 {
return cv.OK
}
return p.config.Cache.MaxAge // 0 表示继承 MaxAge
case statusCode == 301 || statusCode == 302:
return cv.Redirect // 0 表示不缓存
case statusCode == 404:
return cv.NotFound
case statusCode >= 400 && statusCode < 500:
return cv.ClientError
case statusCode >= 500:
return cv.ServerError
default:
return 0 // 不缓存
}
}

View File

@ -240,7 +240,7 @@ func BenchmarkProxyHostClient(b *testing.B) {
Write: 30 * time.Second,
}
client := createHostClient("http://"+addr, timeout, nil)
client := createHostClient("http://"+addr, timeout, nil, nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
@ -269,7 +269,7 @@ func BenchmarkProxyHostClientParallel(b *testing.B) {
Write: 30 * time.Second,
}
client := createHostClient("http://"+addr, timeout, nil)
client := createHostClient("http://"+addr, timeout, nil, nil)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {

101
internal/proxy/proxy_ssl.go Normal file
View File

@ -0,0 +1,101 @@
// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
//
// 该文件提供上游 SSL/TLS 配置支持,包括自定义 CA 证书、
// 客户端证书mTLS、SNI 和 TLS 版本控制。
package proxy
import (
"crypto/tls"
"crypto/x509"
"errors"
"os"
"strings"
"rua.plus/lolly/internal/config"
)
// TLS 版本字符串到 tls 常量的映射。
// 支持 TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3 格式(大小写不敏感)
var tlsVersionMap = map[string]uint16{
"TLSV1.0": tls.VersionTLS10,
"TLSV1.1": tls.VersionTLS11,
"TLSV1.2": tls.VersionTLS12,
"TLSV1.3": tls.VersionTLS13,
"": 0, // 空字符串表示使用默认
}
// CreateTLSConfig 从 ProxySSLConfig 创建 tls.Config。
//
// 参数:
// - cfg: 上游 SSL 配置
// - defaultServerName: 默认 SNI 名称(从目标 URL 提取)
//
// 返回值:
// - *tls.Config: TLS 配置对象
// - error: 配置错误(证书加载失败等)
//
// 注意事项:
// - cfg 为 nil 或 Enabled=false 时返回 nil
// - TrustedCA 加载失败时返回错误
// - ClientCert/ClientKey 加载失败时返回错误
func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls.Config, error) {
if cfg == nil || !cfg.Enabled {
return nil, nil
}
tlsCfg := &tls.Config{}
// SNI 配置
if cfg.ServerName != "" {
tlsCfg.ServerName = cfg.ServerName
} else if defaultServerName != "" {
tlsCfg.ServerName = defaultServerName
}
// 跳过证书验证(仅测试环境)
if cfg.InsecureSkipVerify {
tlsCfg.InsecureSkipVerify = true
}
// CA 证书验证
if cfg.TrustedCA != "" {
caData, err := os.ReadFile(cfg.TrustedCA)
if err != nil {
return nil, errors.New("failed to read CA certificate: " + err.Error())
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caData) {
return nil, errors.New("failed to parse CA certificate")
}
tlsCfg.RootCAs = caPool
}
// 客户端证书mTLS
if cfg.ClientCert != "" && cfg.ClientKey != "" {
cert, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, errors.New("failed to load client certificate: " + err.Error())
}
tlsCfg.Certificates = []tls.Certificate{cert}
}
// TLS 版本配置
if cfg.MinVersion != "" {
version, ok := tlsVersionMap[strings.ToUpper(cfg.MinVersion)]
if !ok {
return nil, errors.New("invalid TLS min version: " + cfg.MinVersion)
}
tlsCfg.MinVersion = version
}
if cfg.MaxVersion != "" {
version, ok := tlsVersionMap[strings.ToUpper(cfg.MaxVersion)]
if !ok {
return nil, errors.New("invalid TLS max version: " + cfg.MaxVersion)
}
tlsCfg.MaxVersion = version
}
return tlsCfg, nil
}

View File

@ -0,0 +1,342 @@
// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
package proxy
import (
"crypto/tls"
"testing"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
)
func TestCreateTLSConfig_NilConfig(t *testing.T) {
tlsCfg, err := CreateTLSConfig(nil, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig(nil) returned error: %v", err)
}
if tlsCfg != nil {
t.Error("CreateTLSConfig(nil) should return nil")
}
}
func TestCreateTLSConfig_Disabled(t *testing.T) {
cfg := &config.ProxySSLConfig{Enabled: false}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig(disabled) returned error: %v", err)
}
if tlsCfg != nil {
t.Error("CreateTLSConfig(disabled) should return nil")
}
}
func TestCreateTLSConfig_ServerName(t *testing.T) {
tests := []struct {
name string
cfg *config.ProxySSLConfig
defaultServerName string
wantServerName string
}{
{
name: "custom server name",
cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
defaultServerName: "default.example.com",
wantServerName: "custom.example.com",
},
{
name: "default server name",
cfg: &config.ProxySSLConfig{Enabled: true},
defaultServerName: "default.example.com",
wantServerName: "default.example.com",
},
{
name: "empty default",
cfg: &config.ProxySSLConfig{Enabled: true},
defaultServerName: "",
wantServerName: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tlsCfg, err := CreateTLSConfig(tt.cfg, tt.defaultServerName)
if err != nil {
t.Errorf("CreateTLSConfig returned error: %v", err)
return
}
if tlsCfg == nil {
t.Error("CreateTLSConfig returned nil")
return
}
if tlsCfg.ServerName != tt.wantServerName {
t.Errorf("ServerName = %q, want %q", tlsCfg.ServerName, tt.wantServerName)
}
})
}
}
func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
InsecureSkipVerify: true,
}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig returned error: %v", err)
return
}
if tlsCfg == nil {
t.Error("CreateTLSConfig returned nil")
return
}
if !tlsCfg.InsecureSkipVerify {
t.Error("InsecureSkipVerify should be true")
}
}
func TestCreateTLSConfig_TLSVersions(t *testing.T) {
tests := []struct {
name string
minVersion string
maxVersion string
wantMin uint16
wantMax uint16
}{
{
name: "TLSV1.2 min",
minVersion: "TLSV1.2",
wantMin: tls.VersionTLS12,
},
{
name: "TLSV1.3 min",
minVersion: "TLSV1.3",
wantMin: tls.VersionTLS13,
},
{
name: "TLSV1.2 max",
maxVersion: "TLSV1.2",
wantMax: tls.VersionTLS12,
},
{
name: "both versions",
minVersion: "TLSV1.2",
maxVersion: "TLSV1.3",
wantMin: tls.VersionTLS12,
wantMax: tls.VersionTLS13,
},
{
name: "mixed case TLSv1.2",
minVersion: "TLSv1.2", // 测试大小写不敏感
wantMin: tls.VersionTLS12,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
MinVersion: tt.minVersion,
MaxVersion: tt.maxVersion,
}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig returned error: %v", err)
return
}
if tlsCfg == nil {
t.Error("CreateTLSConfig returned nil")
return
}
if tt.wantMin != 0 && tlsCfg.MinVersion != tt.wantMin {
t.Errorf("MinVersion = %d, want %d", tlsCfg.MinVersion, tt.wantMin)
}
if tt.wantMax != 0 && tlsCfg.MaxVersion != tt.wantMax {
t.Errorf("MaxVersion = %d, want %d", tlsCfg.MaxVersion, tt.wantMax)
}
})
}
}
func TestCreateTLSConfig_InvalidTLSVersion(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
MinVersion: "TLSv9.9",
}
_, err := CreateTLSConfig(cfg, "example.com")
if err == nil {
t.Error("CreateTLSConfig should return error for invalid TLS version")
}
}
func TestCreateTLSConfig_TrustedCA(t *testing.T) {
// 跳过这个测试,因为需要有效的 CA 证书文件
// 实际集成测试会使用真实证书
t.Skip("需要有效的 CA 证书文件,在集成测试中验证")
}
func TestCreateTLSConfig_TrustedCANotFound(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
TrustedCA: "/nonexistent/ca.crt",
}
_, err := CreateTLSConfig(cfg, "example.com")
if err == nil {
t.Error("CreateTLSConfig should return error for nonexistent CA file")
}
}
func TestCreateTLSConfig_ClientCert(t *testing.T) {
// 跳过这个测试,因为需要有效的证书文件
t.Skip("需要有效的客户端证书文件,在集成测试中验证")
}
func TestCreateTLSConfig_ClientCertNotFound(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
ClientCert: "/nonexistent/client.crt",
ClientKey: "/nonexistent/client.key",
}
_, err := CreateTLSConfig(cfg, "example.com")
if err == nil {
t.Error("CreateTLSConfig should return error for nonexistent cert files")
}
}
// 缓存分段有效期测试US-007
func TestGetCacheDuration_NoCacheValid(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 5 * time.Minute,
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 无 CacheValid 配置时,所有状态码应使用 MaxAge
tests := []struct {
statusCode int
want time.Duration
}{
{200, 5 * time.Minute},
{301, 5 * time.Minute},
{404, 5 * time.Minute},
{500, 5 * time.Minute},
}
for _, tt := range tests {
got := p.getCacheDuration(tt.statusCode)
if got != tt.want {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
}
}
}
func TestGetCacheDuration_CacheValidOKInheritsMaxAge(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 10 * time.Minute,
},
CacheValid: &config.ProxyCacheValidConfig{
OK: 0, // 0 表示继承 MaxAge
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
got := p.getCacheDuration(200)
want := 10 * time.Minute
if got != want {
t.Errorf("getCacheDuration(200) with OK=0 = %v, want %v (MaxAge)", got, want)
}
}
func TestGetCacheDuration_StatusCodeMapping(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 1 * time.Minute,
},
CacheValid: &config.ProxyCacheValidConfig{
OK: 10 * time.Minute,
Redirect: 1 * time.Hour,
NotFound: 1 * time.Minute,
ClientError: 30 * time.Second,
ServerError: 0, // 不缓存
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
tests := []struct {
name string
statusCode int
want time.Duration
}{
{"200 OK", 200, 10 * time.Minute},
{"201 Created", 201, 10 * time.Minute},
{"299 OK boundary", 299, 10 * time.Minute},
{"301 Moved", 301, 1 * time.Hour},
{"302 Found", 302, 1 * time.Hour},
{"304 Not Modified", 304, 0}, // 不在 Redirect 范围内
{"404 Not Found", 404, 1 * time.Minute},
{"400 Bad Request", 400, 30 * time.Second},
{"403 Forbidden", 403, 30 * time.Second},
{"500 Internal Error", 500, 0},
{"503 Service Unavailable", 503, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := p.getCacheDuration(tt.statusCode)
if got != tt.want {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
}
})
}
}
func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 5 * time.Minute,
},
CacheValid: &config.ProxyCacheValidConfig{
OK: 10 * time.Minute, // OK 有值
Redirect: 0, // 不缓存
NotFound: 0, // 不缓存
ClientError: 0, // 不缓存
ServerError: 0, // 不缓存
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
tests := []struct {
statusCode int
want time.Duration
}{
{200, 10 * time.Minute}, // OK 有值
{301, 0}, // Redirect=0 不缓存
{404, 0}, // NotFound=0 不缓存
{500, 0}, // ServerError=0 不缓存
}
for _, tt := range tests {
got := p.getCacheDuration(tt.statusCode)
if got != tt.want {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
}
}
}

View File

@ -849,7 +849,7 @@ func TestCreateHostClient(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := createHostClient(tt.targetURL, tt.timeout, nil)
client := createHostClient(tt.targetURL, tt.timeout, nil, nil)
if client == nil {
t.Error("createHostClient() returned nil")
return