feat(proxy): 添加上游 SSL 配置和缓存有效期分段配置
- ProxySSLConfig: 支持自定义 CA、客户端证书(mTLS)、SNI、TLS 版本控制 - ProxyCacheValidConfig: 按 HTTP 状态码分段配置缓存有效期 - proxy_ssl.go: 实现 CreateTLSConfig 和 TLS 版本解析 - proxy.go: 集成 SSL 配置到 HostClient,实现 getCacheDuration 分段缓存 - 测试文件适配新配置 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
bf14282e40
commit
a644e551af
@ -316,6 +316,8 @@ type ProxyConfig struct {
|
||||
Timeout ProxyTimeout `yaml:"timeout"`
|
||||
VirtualNodes int `yaml:"virtual_nodes"`
|
||||
RedirectRewrite *RedirectRewriteConfig `yaml:"redirect_rewrite"`
|
||||
ProxySSL *ProxySSLConfig `yaml:"proxy_ssl"`
|
||||
CacheValid *ProxyCacheValidConfig `yaml:"cache_valid"`
|
||||
}
|
||||
|
||||
// BalancerByLuaConfig Lua 负载均衡配置
|
||||
@ -482,6 +484,97 @@ type ProxyCacheConfig struct {
|
||||
CacheLock bool `yaml:"cache_lock"`
|
||||
}
|
||||
|
||||
// ProxyCacheValidConfig 缓存有效期分段配置。
|
||||
//
|
||||
// 按 HTTP 状态码配置不同的缓存有效期,提供更精细的缓存控制。
|
||||
// 未配置 CacheValid 时,使用 ProxyCacheConfig.MaxAge 作为统一缓存时间。
|
||||
//
|
||||
// 注意事项:
|
||||
// - OK=0 时继承 MaxAge(向后兼容)
|
||||
// - 其他字段为 0 表示不缓存该类响应
|
||||
// - NotFound 缓存需谨慎,避免缓存错误页面
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// cache_valid:
|
||||
// ok: 10m # 200-299 缓存 10 分钟
|
||||
// redirect: 1h # 301/302 缓存 1 小时
|
||||
// not_found: 1m # 404 缓存 1 分钟
|
||||
// client_error: 0 # 其他客户端错误不缓存
|
||||
// server_error: 0 # 服务端错误不缓存
|
||||
type ProxyCacheValidConfig struct {
|
||||
// OK 200-299 状态码缓存时间
|
||||
// 0 表示继承 MaxAge
|
||||
OK time.Duration `yaml:"ok"`
|
||||
|
||||
// Redirect 301/302 重定向缓存时间
|
||||
// 0 表示不缓存
|
||||
Redirect time.Duration `yaml:"redirect"`
|
||||
|
||||
// NotFound 404 缓存时间
|
||||
// 0 表示不缓存
|
||||
NotFound time.Duration `yaml:"not_found"`
|
||||
|
||||
// ClientError 400-499(除 404)缓存时间
|
||||
// 0 表示不缓存
|
||||
ClientError time.Duration `yaml:"client_error"`
|
||||
|
||||
// ServerError 500-599 缓存时间
|
||||
// 0 表示不缓存
|
||||
ServerError time.Duration `yaml:"server_error"`
|
||||
}
|
||||
|
||||
// ProxySSLConfig 上游 SSL/TLS 配置。
|
||||
//
|
||||
// 配置代理连接上游服务器时的 TLS 行为,支持自定义 CA、客户端证书(mTLS)、
|
||||
// SNI 和 TLS 版本控制。
|
||||
//
|
||||
// 注意事项:
|
||||
// - Enabled 为 true 时启用自定义 TLS 配置
|
||||
// - TrustedCA 用于验证上游服务器证书
|
||||
// - ClientCert + ClientKey 用于 mTLS 客户端认证
|
||||
// - InsecureSkipVerify 仅用于测试,生产环境禁用
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// proxy_ssl:
|
||||
// enabled: true
|
||||
// server_name: "api.internal"
|
||||
// trusted_ca: "/etc/ssl/ca/upstream-ca.crt"
|
||||
// client_cert: "/etc/ssl/client.crt"
|
||||
// client_key: "/etc/ssl/client.key"
|
||||
// min_version: "TLSv1.2"
|
||||
type ProxySSLConfig struct {
|
||||
// Enabled 是否启用自定义 TLS 配置
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// ServerName SNI 名称
|
||||
// 用于 TLS handshake 中的服务器名称指示
|
||||
// 未配置时使用目标 URL 的 host
|
||||
ServerName string `yaml:"server_name"`
|
||||
|
||||
// InsecureSkipVerify 跳过证书验证
|
||||
// 仅用于测试环境,生产环境必须禁用
|
||||
InsecureSkipVerify bool `yaml:"insecure_skip_verify"`
|
||||
|
||||
// TrustedCA CA 证书文件路径
|
||||
// 用于验证上游服务器证书
|
||||
TrustedCA string `yaml:"trusted_ca"`
|
||||
|
||||
// ClientCert 客户端证书文件路径(mTLS)
|
||||
ClientCert string `yaml:"client_cert"`
|
||||
|
||||
// ClientKey 客户端私钥文件路径(mTLS)
|
||||
ClientKey string `yaml:"client_key"`
|
||||
|
||||
// MinVersion 最低 TLS 版本
|
||||
// 可选值:TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3
|
||||
MinVersion string `yaml:"min_version"`
|
||||
|
||||
// MaxVersion 最高 TLS 版本
|
||||
MaxVersion string `yaml:"max_version"`
|
||||
}
|
||||
|
||||
// RedirectRewriteConfig Location/Refresh 头改写配置
|
||||
//
|
||||
// 用于配置代理响应中 Location 和 Refresh 头的改写行为。
|
||||
|
||||
@ -145,7 +145,7 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
|
||||
continue
|
||||
}
|
||||
|
||||
client := createHostClient(target.URL, cfg.Timeout, transportCfg)
|
||||
client := createHostClient(target.URL, cfg.Timeout, transportCfg, cfg.ProxySSL)
|
||||
p.clients[target.URL] = client
|
||||
}
|
||||
|
||||
@ -205,7 +205,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
|
||||
}
|
||||
|
||||
// createHostClient 为后台目标 URL 创建 fasthttp.HostClient。
|
||||
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig) *fasthttp.HostClient {
|
||||
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig, sslCfg *config.ProxySSLConfig) *fasthttp.HostClient {
|
||||
// 从目标 URL 解析主机和协议
|
||||
// addDefaultPort=true 确保 HostClient.Addr 包含端口(host:port 格式)
|
||||
addr, isTLS := netutil.ParseTargetURL(targetURL, true)
|
||||
@ -237,6 +237,14 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCf
|
||||
SecureErrorLogMessage: false,
|
||||
}
|
||||
|
||||
// 上游 SSL 配置(使用原生 TLSConfig)
|
||||
if sslCfg != nil && sslCfg.Enabled && isTLS {
|
||||
tlsCfg, err := CreateTLSConfig(sslCfg, extractHostFromURL(targetURL))
|
||||
if err == nil {
|
||||
client.TLSConfig = tlsCfg
|
||||
}
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
@ -582,7 +590,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
for key, value := range ctx.Response.Header.All() {
|
||||
headers[string(key)] = string(value)
|
||||
}
|
||||
p.cache.Set(hashKey, origKey, ctx.Response.Body(), headers, statusCode, p.config.Cache.MaxAge)
|
||||
p.cache.Set(hashKey, origKey, ctx.Response.Body(), headers, statusCode, p.getCacheDuration(statusCode))
|
||||
}
|
||||
p.cache.ReleaseLock(hashKey, nil)
|
||||
}
|
||||
@ -872,7 +880,7 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error {
|
||||
continue
|
||||
}
|
||||
|
||||
client := createHostClient(target.URL, p.config.Timeout, nil)
|
||||
client := createHostClient(target.URL, p.config.Timeout, nil, p.config.ProxySSL)
|
||||
p.clients[target.URL] = client
|
||||
}
|
||||
|
||||
@ -953,7 +961,7 @@ func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.config.Cache.MaxAge)
|
||||
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode()))
|
||||
}
|
||||
|
||||
// GetCacheStats 返回代理缓存的统计信息。
|
||||
@ -984,3 +992,45 @@ func extractHostFromURL(urlStr string) string {
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
// getCacheDuration 根据状态码获取缓存时间。
|
||||
// 优先级:CacheValid 配置 > MaxAge
|
||||
//
|
||||
// 映射规则:
|
||||
// - 200-299: CacheValid.OK(0 时继承 MaxAge)
|
||||
// - 301/302: CacheValid.Redirect
|
||||
// - 404: CacheValid.NotFound
|
||||
// - 400-499(除 404): CacheValid.ClientError
|
||||
// - 500-599: CacheValid.ServerError
|
||||
// - 其他: 不缓存(返回 0)
|
||||
func (p *Proxy) getCacheDuration(statusCode int) time.Duration {
|
||||
// 无 CacheValid 配置,使用 MaxAge
|
||||
if p.config.CacheValid == nil {
|
||||
return p.config.Cache.MaxAge
|
||||
}
|
||||
|
||||
cv := p.config.CacheValid
|
||||
|
||||
switch {
|
||||
case statusCode >= 200 && statusCode < 300:
|
||||
if cv.OK > 0 {
|
||||
return cv.OK
|
||||
}
|
||||
return p.config.Cache.MaxAge // 0 表示继承 MaxAge
|
||||
|
||||
case statusCode == 301 || statusCode == 302:
|
||||
return cv.Redirect // 0 表示不缓存
|
||||
|
||||
case statusCode == 404:
|
||||
return cv.NotFound
|
||||
|
||||
case statusCode >= 400 && statusCode < 500:
|
||||
return cv.ClientError
|
||||
|
||||
case statusCode >= 500:
|
||||
return cv.ServerError
|
||||
|
||||
default:
|
||||
return 0 // 不缓存
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,7 +240,7 @@ func BenchmarkProxyHostClient(b *testing.B) {
|
||||
Write: 30 * time.Second,
|
||||
}
|
||||
|
||||
client := createHostClient("http://"+addr, timeout, nil)
|
||||
client := createHostClient("http://"+addr, timeout, nil, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
@ -269,7 +269,7 @@ func BenchmarkProxyHostClientParallel(b *testing.B) {
|
||||
Write: 30 * time.Second,
|
||||
}
|
||||
|
||||
client := createHostClient("http://"+addr, timeout, nil)
|
||||
client := createHostClient("http://"+addr, timeout, nil, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
|
||||
101
internal/proxy/proxy_ssl.go
Normal file
101
internal/proxy/proxy_ssl.go
Normal file
@ -0,0 +1,101 @@
|
||||
// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
|
||||
//
|
||||
// 该文件提供上游 SSL/TLS 配置支持,包括自定义 CA 证书、
|
||||
// 客户端证书(mTLS)、SNI 和 TLS 版本控制。
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TLS 版本字符串到 tls 常量的映射。
|
||||
// 支持 TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3 格式(大小写不敏感)
|
||||
var tlsVersionMap = map[string]uint16{
|
||||
"TLSV1.0": tls.VersionTLS10,
|
||||
"TLSV1.1": tls.VersionTLS11,
|
||||
"TLSV1.2": tls.VersionTLS12,
|
||||
"TLSV1.3": tls.VersionTLS13,
|
||||
"": 0, // 空字符串表示使用默认
|
||||
}
|
||||
|
||||
// CreateTLSConfig 从 ProxySSLConfig 创建 tls.Config。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 上游 SSL 配置
|
||||
// - defaultServerName: 默认 SNI 名称(从目标 URL 提取)
|
||||
//
|
||||
// 返回值:
|
||||
// - *tls.Config: TLS 配置对象
|
||||
// - error: 配置错误(证书加载失败等)
|
||||
//
|
||||
// 注意事项:
|
||||
// - cfg 为 nil 或 Enabled=false 时返回 nil
|
||||
// - TrustedCA 加载失败时返回错误
|
||||
// - ClientCert/ClientKey 加载失败时返回错误
|
||||
func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls.Config, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsCfg := &tls.Config{}
|
||||
|
||||
// SNI 配置
|
||||
if cfg.ServerName != "" {
|
||||
tlsCfg.ServerName = cfg.ServerName
|
||||
} else if defaultServerName != "" {
|
||||
tlsCfg.ServerName = defaultServerName
|
||||
}
|
||||
|
||||
// 跳过证书验证(仅测试环境)
|
||||
if cfg.InsecureSkipVerify {
|
||||
tlsCfg.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
// CA 证书验证
|
||||
if cfg.TrustedCA != "" {
|
||||
caData, err := os.ReadFile(cfg.TrustedCA)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to read CA certificate: " + err.Error())
|
||||
}
|
||||
|
||||
caPool := x509.NewCertPool()
|
||||
if !caPool.AppendCertsFromPEM(caData) {
|
||||
return nil, errors.New("failed to parse CA certificate")
|
||||
}
|
||||
tlsCfg.RootCAs = caPool
|
||||
}
|
||||
|
||||
// 客户端证书(mTLS)
|
||||
if cfg.ClientCert != "" && cfg.ClientKey != "" {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to load client certificate: " + err.Error())
|
||||
}
|
||||
tlsCfg.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// TLS 版本配置
|
||||
if cfg.MinVersion != "" {
|
||||
version, ok := tlsVersionMap[strings.ToUpper(cfg.MinVersion)]
|
||||
if !ok {
|
||||
return nil, errors.New("invalid TLS min version: " + cfg.MinVersion)
|
||||
}
|
||||
tlsCfg.MinVersion = version
|
||||
}
|
||||
|
||||
if cfg.MaxVersion != "" {
|
||||
version, ok := tlsVersionMap[strings.ToUpper(cfg.MaxVersion)]
|
||||
if !ok {
|
||||
return nil, errors.New("invalid TLS max version: " + cfg.MaxVersion)
|
||||
}
|
||||
tlsCfg.MaxVersion = version
|
||||
}
|
||||
|
||||
return tlsCfg, nil
|
||||
}
|
||||
342
internal/proxy/proxy_ssl_test.go
Normal file
342
internal/proxy/proxy_ssl_test.go
Normal file
@ -0,0 +1,342 @@
|
||||
// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
)
|
||||
|
||||
func TestCreateTLSConfig_NilConfig(t *testing.T) {
|
||||
tlsCfg, err := CreateTLSConfig(nil, "example.com")
|
||||
if err != nil {
|
||||
t.Errorf("CreateTLSConfig(nil) returned error: %v", err)
|
||||
}
|
||||
if tlsCfg != nil {
|
||||
t.Error("CreateTLSConfig(nil) should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_Disabled(t *testing.T) {
|
||||
cfg := &config.ProxySSLConfig{Enabled: false}
|
||||
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
||||
if err != nil {
|
||||
t.Errorf("CreateTLSConfig(disabled) returned error: %v", err)
|
||||
}
|
||||
if tlsCfg != nil {
|
||||
t.Error("CreateTLSConfig(disabled) should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_ServerName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.ProxySSLConfig
|
||||
defaultServerName string
|
||||
wantServerName string
|
||||
}{
|
||||
{
|
||||
name: "custom server name",
|
||||
cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
|
||||
defaultServerName: "default.example.com",
|
||||
wantServerName: "custom.example.com",
|
||||
},
|
||||
{
|
||||
name: "default server name",
|
||||
cfg: &config.ProxySSLConfig{Enabled: true},
|
||||
defaultServerName: "default.example.com",
|
||||
wantServerName: "default.example.com",
|
||||
},
|
||||
{
|
||||
name: "empty default",
|
||||
cfg: &config.ProxySSLConfig{Enabled: true},
|
||||
defaultServerName: "",
|
||||
wantServerName: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tlsCfg, err := CreateTLSConfig(tt.cfg, tt.defaultServerName)
|
||||
if err != nil {
|
||||
t.Errorf("CreateTLSConfig returned error: %v", err)
|
||||
return
|
||||
}
|
||||
if tlsCfg == nil {
|
||||
t.Error("CreateTLSConfig returned nil")
|
||||
return
|
||||
}
|
||||
if tlsCfg.ServerName != tt.wantServerName {
|
||||
t.Errorf("ServerName = %q, want %q", tlsCfg.ServerName, tt.wantServerName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) {
|
||||
cfg := &config.ProxySSLConfig{
|
||||
Enabled: true,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
||||
if err != nil {
|
||||
t.Errorf("CreateTLSConfig returned error: %v", err)
|
||||
return
|
||||
}
|
||||
if tlsCfg == nil {
|
||||
t.Error("CreateTLSConfig returned nil")
|
||||
return
|
||||
}
|
||||
if !tlsCfg.InsecureSkipVerify {
|
||||
t.Error("InsecureSkipVerify should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_TLSVersions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
minVersion string
|
||||
maxVersion string
|
||||
wantMin uint16
|
||||
wantMax uint16
|
||||
}{
|
||||
{
|
||||
name: "TLSV1.2 min",
|
||||
minVersion: "TLSV1.2",
|
||||
wantMin: tls.VersionTLS12,
|
||||
},
|
||||
{
|
||||
name: "TLSV1.3 min",
|
||||
minVersion: "TLSV1.3",
|
||||
wantMin: tls.VersionTLS13,
|
||||
},
|
||||
{
|
||||
name: "TLSV1.2 max",
|
||||
maxVersion: "TLSV1.2",
|
||||
wantMax: tls.VersionTLS12,
|
||||
},
|
||||
{
|
||||
name: "both versions",
|
||||
minVersion: "TLSV1.2",
|
||||
maxVersion: "TLSV1.3",
|
||||
wantMin: tls.VersionTLS12,
|
||||
wantMax: tls.VersionTLS13,
|
||||
},
|
||||
{
|
||||
name: "mixed case TLSv1.2",
|
||||
minVersion: "TLSv1.2", // 测试大小写不敏感
|
||||
wantMin: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.ProxySSLConfig{
|
||||
Enabled: true,
|
||||
MinVersion: tt.minVersion,
|
||||
MaxVersion: tt.maxVersion,
|
||||
}
|
||||
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
||||
if err != nil {
|
||||
t.Errorf("CreateTLSConfig returned error: %v", err)
|
||||
return
|
||||
}
|
||||
if tlsCfg == nil {
|
||||
t.Error("CreateTLSConfig returned nil")
|
||||
return
|
||||
}
|
||||
if tt.wantMin != 0 && tlsCfg.MinVersion != tt.wantMin {
|
||||
t.Errorf("MinVersion = %d, want %d", tlsCfg.MinVersion, tt.wantMin)
|
||||
}
|
||||
if tt.wantMax != 0 && tlsCfg.MaxVersion != tt.wantMax {
|
||||
t.Errorf("MaxVersion = %d, want %d", tlsCfg.MaxVersion, tt.wantMax)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_InvalidTLSVersion(t *testing.T) {
|
||||
cfg := &config.ProxySSLConfig{
|
||||
Enabled: true,
|
||||
MinVersion: "TLSv9.9",
|
||||
}
|
||||
_, err := CreateTLSConfig(cfg, "example.com")
|
||||
if err == nil {
|
||||
t.Error("CreateTLSConfig should return error for invalid TLS version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_TrustedCA(t *testing.T) {
|
||||
// 跳过这个测试,因为需要有效的 CA 证书文件
|
||||
// 实际集成测试会使用真实证书
|
||||
t.Skip("需要有效的 CA 证书文件,在集成测试中验证")
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_TrustedCANotFound(t *testing.T) {
|
||||
cfg := &config.ProxySSLConfig{
|
||||
Enabled: true,
|
||||
TrustedCA: "/nonexistent/ca.crt",
|
||||
}
|
||||
_, err := CreateTLSConfig(cfg, "example.com")
|
||||
if err == nil {
|
||||
t.Error("CreateTLSConfig should return error for nonexistent CA file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_ClientCert(t *testing.T) {
|
||||
// 跳过这个测试,因为需要有效的证书文件
|
||||
t.Skip("需要有效的客户端证书文件,在集成测试中验证")
|
||||
}
|
||||
|
||||
func TestCreateTLSConfig_ClientCertNotFound(t *testing.T) {
|
||||
cfg := &config.ProxySSLConfig{
|
||||
Enabled: true,
|
||||
ClientCert: "/nonexistent/client.crt",
|
||||
ClientKey: "/nonexistent/client.key",
|
||||
}
|
||||
_, err := CreateTLSConfig(cfg, "example.com")
|
||||
if err == nil {
|
||||
t.Error("CreateTLSConfig should return error for nonexistent cert files")
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存分段有效期测试(US-007)
|
||||
func TestGetCacheDuration_NoCacheValid(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Cache: config.ProxyCacheConfig{
|
||||
MaxAge: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 无 CacheValid 配置时,所有状态码应使用 MaxAge
|
||||
tests := []struct {
|
||||
statusCode int
|
||||
want time.Duration
|
||||
}{
|
||||
{200, 5 * time.Minute},
|
||||
{301, 5 * time.Minute},
|
||||
{404, 5 * time.Minute},
|
||||
{500, 5 * time.Minute},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := p.getCacheDuration(tt.statusCode)
|
||||
if got != tt.want {
|
||||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheDuration_CacheValidOKInheritsMaxAge(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Cache: config.ProxyCacheConfig{
|
||||
MaxAge: 10 * time.Minute,
|
||||
},
|
||||
CacheValid: &config.ProxyCacheValidConfig{
|
||||
OK: 0, // 0 表示继承 MaxAge
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
got := p.getCacheDuration(200)
|
||||
want := 10 * time.Minute
|
||||
if got != want {
|
||||
t.Errorf("getCacheDuration(200) with OK=0 = %v, want %v (MaxAge)", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheDuration_StatusCodeMapping(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Cache: config.ProxyCacheConfig{
|
||||
MaxAge: 1 * time.Minute,
|
||||
},
|
||||
CacheValid: &config.ProxyCacheValidConfig{
|
||||
OK: 10 * time.Minute,
|
||||
Redirect: 1 * time.Hour,
|
||||
NotFound: 1 * time.Minute,
|
||||
ClientError: 30 * time.Second,
|
||||
ServerError: 0, // 不缓存
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
want time.Duration
|
||||
}{
|
||||
{"200 OK", 200, 10 * time.Minute},
|
||||
{"201 Created", 201, 10 * time.Minute},
|
||||
{"299 OK boundary", 299, 10 * time.Minute},
|
||||
{"301 Moved", 301, 1 * time.Hour},
|
||||
{"302 Found", 302, 1 * time.Hour},
|
||||
{"304 Not Modified", 304, 0}, // 不在 Redirect 范围内
|
||||
{"404 Not Found", 404, 1 * time.Minute},
|
||||
{"400 Bad Request", 400, 30 * time.Second},
|
||||
{"403 Forbidden", 403, 30 * time.Second},
|
||||
{"500 Internal Error", 500, 0},
|
||||
{"503 Service Unavailable", 503, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := p.getCacheDuration(tt.statusCode)
|
||||
if got != tt.want {
|
||||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
|
||||
cfg := &config.ProxyConfig{
|
||||
Cache: config.ProxyCacheConfig{
|
||||
MaxAge: 5 * time.Minute,
|
||||
},
|
||||
CacheValid: &config.ProxyCacheValidConfig{
|
||||
OK: 10 * time.Minute, // OK 有值
|
||||
Redirect: 0, // 不缓存
|
||||
NotFound: 0, // 不缓存
|
||||
ClientError: 0, // 不缓存
|
||||
ServerError: 0, // 不缓存
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
statusCode int
|
||||
want time.Duration
|
||||
}{
|
||||
{200, 10 * time.Minute}, // OK 有值
|
||||
{301, 0}, // Redirect=0 不缓存
|
||||
{404, 0}, // NotFound=0 不缓存
|
||||
{500, 0}, // ServerError=0 不缓存
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := p.getCacheDuration(tt.statusCode)
|
||||
if got != tt.want {
|
||||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -849,7 +849,7 @@ func TestCreateHostClient(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client := createHostClient(tt.targetURL, tt.timeout, nil)
|
||||
client := createHostClient(tt.targetURL, tt.timeout, nil, nil)
|
||||
if client == nil {
|
||||
t.Error("createHostClient() returned nil")
|
||||
return
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user