- headers.go: 添加协议常量 protoHTTP/protoHTTPS - redirect_rewrite.go: 添加模式常量,修正缩进 - proxy_ssl_test.go: 表格测试字段对齐 - server.go: 添加 ServerModeAuto 分支防御性处理 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
343 lines
8.9 KiB
Go
343 lines
8.9 KiB
Go
// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
|
||
package proxy
|
||
|
||
import (
|
||
"crypto/tls"
|
||
"testing"
|
||
"time"
|
||
|
||
"rua.plus/lolly/internal/config"
|
||
"rua.plus/lolly/internal/loadbalance"
|
||
)
|
||
|
||
func TestCreateTLSConfig_NilConfig(t *testing.T) {
|
||
tlsCfg, err := CreateTLSConfig(nil, "example.com")
|
||
if err != nil {
|
||
t.Errorf("CreateTLSConfig(nil) returned error: %v", err)
|
||
}
|
||
if tlsCfg != nil {
|
||
t.Error("CreateTLSConfig(nil) should return nil")
|
||
}
|
||
}
|
||
|
||
func TestCreateTLSConfig_Disabled(t *testing.T) {
|
||
cfg := &config.ProxySSLConfig{Enabled: false}
|
||
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
||
if err != nil {
|
||
t.Errorf("CreateTLSConfig(disabled) returned error: %v", err)
|
||
}
|
||
if tlsCfg != nil {
|
||
t.Error("CreateTLSConfig(disabled) should return nil")
|
||
}
|
||
}
|
||
|
||
func TestCreateTLSConfig_ServerName(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
cfg *config.ProxySSLConfig
|
||
defaultServerName string
|
||
wantServerName string
|
||
}{
|
||
{
|
||
name: "custom server name",
|
||
cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
|
||
defaultServerName: "default.example.com",
|
||
wantServerName: "custom.example.com",
|
||
},
|
||
{
|
||
name: "default server name",
|
||
cfg: &config.ProxySSLConfig{Enabled: true},
|
||
defaultServerName: "default.example.com",
|
||
wantServerName: "default.example.com",
|
||
},
|
||
{
|
||
name: "empty default",
|
||
cfg: &config.ProxySSLConfig{Enabled: true},
|
||
defaultServerName: "",
|
||
wantServerName: "",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
tlsCfg, err := CreateTLSConfig(tt.cfg, tt.defaultServerName)
|
||
if err != nil {
|
||
t.Errorf("CreateTLSConfig returned error: %v", err)
|
||
return
|
||
}
|
||
if tlsCfg == nil {
|
||
t.Error("CreateTLSConfig returned nil")
|
||
return
|
||
}
|
||
if tlsCfg.ServerName != tt.wantServerName {
|
||
t.Errorf("ServerName = %q, want %q", tlsCfg.ServerName, tt.wantServerName)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) {
|
||
cfg := &config.ProxySSLConfig{
|
||
Enabled: true,
|
||
InsecureSkipVerify: true,
|
||
}
|
||
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
||
if err != nil {
|
||
t.Errorf("CreateTLSConfig returned error: %v", err)
|
||
return
|
||
}
|
||
if tlsCfg == nil {
|
||
t.Error("CreateTLSConfig returned nil")
|
||
return
|
||
}
|
||
if !tlsCfg.InsecureSkipVerify {
|
||
t.Error("InsecureSkipVerify should be true")
|
||
}
|
||
}
|
||
|
||
func TestCreateTLSConfig_TLSVersions(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
minVersion string
|
||
maxVersion string
|
||
wantMin uint16
|
||
wantMax uint16
|
||
}{
|
||
{
|
||
name: "TLSV1.2 min",
|
||
minVersion: "TLSV1.2",
|
||
wantMin: tls.VersionTLS12,
|
||
},
|
||
{
|
||
name: "TLSV1.3 min",
|
||
minVersion: "TLSV1.3",
|
||
wantMin: tls.VersionTLS13,
|
||
},
|
||
{
|
||
name: "TLSV1.2 max",
|
||
maxVersion: "TLSV1.2",
|
||
wantMax: tls.VersionTLS12,
|
||
},
|
||
{
|
||
name: "both versions",
|
||
minVersion: "TLSV1.2",
|
||
maxVersion: "TLSV1.3",
|
||
wantMin: tls.VersionTLS12,
|
||
wantMax: tls.VersionTLS13,
|
||
},
|
||
{
|
||
name: "mixed case TLSv1.2",
|
||
minVersion: "TLSv1.2", // 测试大小写不敏感
|
||
wantMin: tls.VersionTLS12,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
cfg := &config.ProxySSLConfig{
|
||
Enabled: true,
|
||
MinVersion: tt.minVersion,
|
||
MaxVersion: tt.maxVersion,
|
||
}
|
||
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
||
if err != nil {
|
||
t.Errorf("CreateTLSConfig returned error: %v", err)
|
||
return
|
||
}
|
||
if tlsCfg == nil {
|
||
t.Error("CreateTLSConfig returned nil")
|
||
return
|
||
}
|
||
if tt.wantMin != 0 && tlsCfg.MinVersion != tt.wantMin {
|
||
t.Errorf("MinVersion = %d, want %d", tlsCfg.MinVersion, tt.wantMin)
|
||
}
|
||
if tt.wantMax != 0 && tlsCfg.MaxVersion != tt.wantMax {
|
||
t.Errorf("MaxVersion = %d, want %d", tlsCfg.MaxVersion, tt.wantMax)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestCreateTLSConfig_InvalidTLSVersion(t *testing.T) {
|
||
cfg := &config.ProxySSLConfig{
|
||
Enabled: true,
|
||
MinVersion: "TLSv9.9",
|
||
}
|
||
_, err := CreateTLSConfig(cfg, "example.com")
|
||
if err == nil {
|
||
t.Error("CreateTLSConfig should return error for invalid TLS version")
|
||
}
|
||
}
|
||
|
||
func TestCreateTLSConfig_TrustedCA(t *testing.T) {
|
||
// 跳过这个测试,因为需要有效的 CA 证书文件
|
||
// 实际集成测试会使用真实证书
|
||
t.Skip("需要有效的 CA 证书文件,在集成测试中验证")
|
||
}
|
||
|
||
func TestCreateTLSConfig_TrustedCANotFound(t *testing.T) {
|
||
cfg := &config.ProxySSLConfig{
|
||
Enabled: true,
|
||
TrustedCA: "/nonexistent/ca.crt",
|
||
}
|
||
_, err := CreateTLSConfig(cfg, "example.com")
|
||
if err == nil {
|
||
t.Error("CreateTLSConfig should return error for nonexistent CA file")
|
||
}
|
||
}
|
||
|
||
func TestCreateTLSConfig_ClientCert(t *testing.T) {
|
||
// 跳过这个测试,因为需要有效的证书文件
|
||
t.Skip("需要有效的客户端证书文件,在集成测试中验证")
|
||
}
|
||
|
||
func TestCreateTLSConfig_ClientCertNotFound(t *testing.T) {
|
||
cfg := &config.ProxySSLConfig{
|
||
Enabled: true,
|
||
ClientCert: "/nonexistent/client.crt",
|
||
ClientKey: "/nonexistent/client.key",
|
||
}
|
||
_, err := CreateTLSConfig(cfg, "example.com")
|
||
if err == nil {
|
||
t.Error("CreateTLSConfig should return error for nonexistent cert files")
|
||
}
|
||
}
|
||
|
||
// 缓存分段有效期测试(US-007)
|
||
func TestGetCacheDuration_NoCacheValid(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Cache: config.ProxyCacheConfig{
|
||
MaxAge: 5 * time.Minute,
|
||
},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewProxy() error: %v", err)
|
||
}
|
||
|
||
// 无 CacheValid 配置时,所有状态码应使用 MaxAge
|
||
tests := []struct {
|
||
statusCode int
|
||
want time.Duration
|
||
}{
|
||
{200, 5 * time.Minute},
|
||
{301, 5 * time.Minute},
|
||
{404, 5 * time.Minute},
|
||
{500, 5 * time.Minute},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
got := p.getCacheDuration(tt.statusCode)
|
||
if got != tt.want {
|
||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestGetCacheDuration_CacheValidOKInheritsMaxAge(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Cache: config.ProxyCacheConfig{
|
||
MaxAge: 10 * time.Minute,
|
||
},
|
||
CacheValid: &config.ProxyCacheValidConfig{
|
||
OK: 0, // 0 表示继承 MaxAge
|
||
},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewProxy() error: %v", err)
|
||
}
|
||
|
||
got := p.getCacheDuration(200)
|
||
want := 10 * time.Minute
|
||
if got != want {
|
||
t.Errorf("getCacheDuration(200) with OK=0 = %v, want %v (MaxAge)", got, want)
|
||
}
|
||
}
|
||
|
||
func TestGetCacheDuration_StatusCodeMapping(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Cache: config.ProxyCacheConfig{
|
||
MaxAge: 1 * time.Minute,
|
||
},
|
||
CacheValid: &config.ProxyCacheValidConfig{
|
||
OK: 10 * time.Minute,
|
||
Redirect: 1 * time.Hour,
|
||
NotFound: 1 * time.Minute,
|
||
ClientError: 30 * time.Second,
|
||
ServerError: 0, // 不缓存
|
||
},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewProxy() error: %v", err)
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
statusCode int
|
||
want time.Duration
|
||
}{
|
||
{"200 OK", 200, 10 * time.Minute},
|
||
{"201 Created", 201, 10 * time.Minute},
|
||
{"299 OK boundary", 299, 10 * time.Minute},
|
||
{"301 Moved", 301, 1 * time.Hour},
|
||
{"302 Found", 302, 1 * time.Hour},
|
||
{"304 Not Modified", 304, 0}, // 不在 Redirect 范围内
|
||
{"404 Not Found", 404, 1 * time.Minute},
|
||
{"400 Bad Request", 400, 30 * time.Second},
|
||
{"403 Forbidden", 403, 30 * time.Second},
|
||
{"500 Internal Error", 500, 0},
|
||
{"503 Service Unavailable", 503, 0},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got := p.getCacheDuration(tt.statusCode)
|
||
if got != tt.want {
|
||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Cache: config.ProxyCacheConfig{
|
||
MaxAge: 5 * time.Minute,
|
||
},
|
||
CacheValid: &config.ProxyCacheValidConfig{
|
||
OK: 10 * time.Minute, // OK 有值
|
||
Redirect: 0, // 不缓存
|
||
NotFound: 0, // 不缓存
|
||
ClientError: 0, // 不缓存
|
||
ServerError: 0, // 不缓存
|
||
},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewProxy() error: %v", err)
|
||
}
|
||
|
||
tests := []struct {
|
||
statusCode int
|
||
want time.Duration
|
||
}{
|
||
{200, 10 * time.Minute}, // OK 有值
|
||
{301, 0}, // Redirect=0 不缓存
|
||
{404, 0}, // NotFound=0 不缓存
|
||
{500, 0}, // ServerError=0 不缓存
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
got := p.getCacheDuration(tt.statusCode)
|
||
if got != tt.want {
|
||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
||
}
|
||
}
|
||
}
|