lolly/internal/proxy/proxy_ssl_test.go
xfy a644e551af feat(proxy): 添加上游 SSL 配置和缓存有效期分段配置
- ProxySSLConfig: 支持自定义 CA、客户端证书(mTLS)、SNI、TLS 版本控制
- ProxyCacheValidConfig: 按 HTTP 状态码分段配置缓存有效期
- proxy_ssl.go: 实现 CreateTLSConfig 和 TLS 版本解析
- proxy.go: 集成 SSL 配置到 HostClient,实现 getCacheDuration 分段缓存
- 测试文件适配新配置

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-15 18:27:50 +08:00

342 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
package proxy
import (
"crypto/tls"
"testing"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
)
func TestCreateTLSConfig_NilConfig(t *testing.T) {
tlsCfg, err := CreateTLSConfig(nil, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig(nil) returned error: %v", err)
}
if tlsCfg != nil {
t.Error("CreateTLSConfig(nil) should return nil")
}
}
func TestCreateTLSConfig_Disabled(t *testing.T) {
cfg := &config.ProxySSLConfig{Enabled: false}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig(disabled) returned error: %v", err)
}
if tlsCfg != nil {
t.Error("CreateTLSConfig(disabled) should return nil")
}
}
func TestCreateTLSConfig_ServerName(t *testing.T) {
tests := []struct {
name string
cfg *config.ProxySSLConfig
defaultServerName string
wantServerName string
}{
{
name: "custom server name",
cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
defaultServerName: "default.example.com",
wantServerName: "custom.example.com",
},
{
name: "default server name",
cfg: &config.ProxySSLConfig{Enabled: true},
defaultServerName: "default.example.com",
wantServerName: "default.example.com",
},
{
name: "empty default",
cfg: &config.ProxySSLConfig{Enabled: true},
defaultServerName: "",
wantServerName: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tlsCfg, err := CreateTLSConfig(tt.cfg, tt.defaultServerName)
if err != nil {
t.Errorf("CreateTLSConfig returned error: %v", err)
return
}
if tlsCfg == nil {
t.Error("CreateTLSConfig returned nil")
return
}
if tlsCfg.ServerName != tt.wantServerName {
t.Errorf("ServerName = %q, want %q", tlsCfg.ServerName, tt.wantServerName)
}
})
}
}
func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
InsecureSkipVerify: true,
}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig returned error: %v", err)
return
}
if tlsCfg == nil {
t.Error("CreateTLSConfig returned nil")
return
}
if !tlsCfg.InsecureSkipVerify {
t.Error("InsecureSkipVerify should be true")
}
}
func TestCreateTLSConfig_TLSVersions(t *testing.T) {
tests := []struct {
name string
minVersion string
maxVersion string
wantMin uint16
wantMax uint16
}{
{
name: "TLSV1.2 min",
minVersion: "TLSV1.2",
wantMin: tls.VersionTLS12,
},
{
name: "TLSV1.3 min",
minVersion: "TLSV1.3",
wantMin: tls.VersionTLS13,
},
{
name: "TLSV1.2 max",
maxVersion: "TLSV1.2",
wantMax: tls.VersionTLS12,
},
{
name: "both versions",
minVersion: "TLSV1.2",
maxVersion: "TLSV1.3",
wantMin: tls.VersionTLS12,
wantMax: tls.VersionTLS13,
},
{
name: "mixed case TLSv1.2",
minVersion: "TLSv1.2", // 测试大小写不敏感
wantMin: tls.VersionTLS12,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
MinVersion: tt.minVersion,
MaxVersion: tt.maxVersion,
}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig returned error: %v", err)
return
}
if tlsCfg == nil {
t.Error("CreateTLSConfig returned nil")
return
}
if tt.wantMin != 0 && tlsCfg.MinVersion != tt.wantMin {
t.Errorf("MinVersion = %d, want %d", tlsCfg.MinVersion, tt.wantMin)
}
if tt.wantMax != 0 && tlsCfg.MaxVersion != tt.wantMax {
t.Errorf("MaxVersion = %d, want %d", tlsCfg.MaxVersion, tt.wantMax)
}
})
}
}
func TestCreateTLSConfig_InvalidTLSVersion(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
MinVersion: "TLSv9.9",
}
_, err := CreateTLSConfig(cfg, "example.com")
if err == nil {
t.Error("CreateTLSConfig should return error for invalid TLS version")
}
}
func TestCreateTLSConfig_TrustedCA(t *testing.T) {
// 跳过这个测试,因为需要有效的 CA 证书文件
// 实际集成测试会使用真实证书
t.Skip("需要有效的 CA 证书文件,在集成测试中验证")
}
func TestCreateTLSConfig_TrustedCANotFound(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
TrustedCA: "/nonexistent/ca.crt",
}
_, err := CreateTLSConfig(cfg, "example.com")
if err == nil {
t.Error("CreateTLSConfig should return error for nonexistent CA file")
}
}
func TestCreateTLSConfig_ClientCert(t *testing.T) {
// 跳过这个测试,因为需要有效的证书文件
t.Skip("需要有效的客户端证书文件,在集成测试中验证")
}
func TestCreateTLSConfig_ClientCertNotFound(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
ClientCert: "/nonexistent/client.crt",
ClientKey: "/nonexistent/client.key",
}
_, err := CreateTLSConfig(cfg, "example.com")
if err == nil {
t.Error("CreateTLSConfig should return error for nonexistent cert files")
}
}
// 缓存分段有效期测试US-007
func TestGetCacheDuration_NoCacheValid(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 5 * time.Minute,
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 无 CacheValid 配置时,所有状态码应使用 MaxAge
tests := []struct {
statusCode int
want time.Duration
}{
{200, 5 * time.Minute},
{301, 5 * time.Minute},
{404, 5 * time.Minute},
{500, 5 * time.Minute},
}
for _, tt := range tests {
got := p.getCacheDuration(tt.statusCode)
if got != tt.want {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
}
}
}
func TestGetCacheDuration_CacheValidOKInheritsMaxAge(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 10 * time.Minute,
},
CacheValid: &config.ProxyCacheValidConfig{
OK: 0, // 0 表示继承 MaxAge
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
got := p.getCacheDuration(200)
want := 10 * time.Minute
if got != want {
t.Errorf("getCacheDuration(200) with OK=0 = %v, want %v (MaxAge)", got, want)
}
}
func TestGetCacheDuration_StatusCodeMapping(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 1 * time.Minute,
},
CacheValid: &config.ProxyCacheValidConfig{
OK: 10 * time.Minute,
Redirect: 1 * time.Hour,
NotFound: 1 * time.Minute,
ClientError: 30 * time.Second,
ServerError: 0, // 不缓存
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
tests := []struct {
name string
statusCode int
want time.Duration
}{
{"200 OK", 200, 10 * time.Minute},
{"201 Created", 201, 10 * time.Minute},
{"299 OK boundary", 299, 10 * time.Minute},
{"301 Moved", 301, 1 * time.Hour},
{"302 Found", 302, 1 * time.Hour},
{"304 Not Modified", 304, 0}, // 不在 Redirect 范围内
{"404 Not Found", 404, 1 * time.Minute},
{"400 Bad Request", 400, 30 * time.Second},
{"403 Forbidden", 403, 30 * time.Second},
{"500 Internal Error", 500, 0},
{"503 Service Unavailable", 503, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := p.getCacheDuration(tt.statusCode)
if got != tt.want {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
}
})
}
}
func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 5 * time.Minute,
},
CacheValid: &config.ProxyCacheValidConfig{
OK: 10 * time.Minute, // OK 有值
Redirect: 0, // 不缓存
NotFound: 0, // 不缓存
ClientError: 0, // 不缓存
ServerError: 0, // 不缓存
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
tests := []struct {
statusCode int
want time.Duration
}{
{200, 10 * time.Minute}, // OK 有值
{301, 0}, // Redirect=0 不缓存
{404, 0}, // NotFound=0 不缓存
{500, 0}, // ServerError=0 不缓存
}
for _, tt := range tests {
got := p.getCacheDuration(tt.statusCode)
if got != tt.want {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
}
}
}