lolly/internal/proxy/proxy_ssl_test.go
xfy 470c82d940 style(proxy,server): 代码风格优化
- headers.go: 添加协议常量 protoHTTP/protoHTTPS
- redirect_rewrite.go: 添加模式常量,修正缩进
- proxy_ssl_test.go: 表格测试字段对齐
- server.go: 添加 ServerModeAuto 分支防御性处理

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-16 09:54:09 +08:00

343 lines
8.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy 反向代理包,为 Lolly HTTP 服务器提供反向代理功能。
package proxy
import (
"crypto/tls"
"testing"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
)
// TestCreateTLSConfig_NilConfig verifies that passing a nil proxy SSL config
// yields a nil TLS config and no error.
func TestCreateTLSConfig_NilConfig(t *testing.T) {
	got, gotErr := CreateTLSConfig(nil, "example.com")

	// Both properties are checked independently so a single run reports
	// every violation.
	if gotErr != nil {
		t.Errorf("CreateTLSConfig(nil) returned error: %v", gotErr)
	}
	if got != nil {
		t.Error("CreateTLSConfig(nil) should return nil")
	}
}
func TestCreateTLSConfig_Disabled(t *testing.T) {
cfg := &config.ProxySSLConfig{Enabled: false}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig(disabled) returned error: %v", err)
}
if tlsCfg != nil {
t.Error("CreateTLSConfig(disabled) should return nil")
}
}
// TestCreateTLSConfig_ServerName checks which ServerName ends up on the
// returned TLS config: an explicit cfg.ServerName wins, otherwise the
// defaultServerName argument is used verbatim (including the empty string).
func TestCreateTLSConfig_ServerName(t *testing.T) {
	cases := []struct {
		name     string
		cfg      *config.ProxySSLConfig
		fallback string // value passed as defaultServerName
		want     string
	}{
		{
			name:     "custom server name",
			cfg:      &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
			fallback: "default.example.com",
			want:     "custom.example.com",
		},
		{
			name:     "default server name",
			cfg:      &config.ProxySSLConfig{Enabled: true},
			fallback: "default.example.com",
			want:     "default.example.com",
		},
		{
			name:     "empty default",
			cfg:      &config.ProxySSLConfig{Enabled: true},
			fallback: "",
			want:     "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := CreateTLSConfig(tc.cfg, tc.fallback)
			switch {
			case err != nil:
				t.Fatalf("CreateTLSConfig returned error: %v", err)
			case got == nil:
				t.Fatal("CreateTLSConfig returned nil")
			case got.ServerName != tc.want:
				t.Errorf("ServerName = %q, want %q", got.ServerName, tc.want)
			}
		})
	}
}
// TestCreateTLSConfig_InsecureSkipVerify verifies that InsecureSkipVerify=true
// in the proxy SSL config is reflected on the returned TLS config.
func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) {
	sslCfg := &config.ProxySSLConfig{Enabled: true, InsecureSkipVerify: true}

	got, err := CreateTLSConfig(sslCfg, "example.com")
	switch {
	case err != nil:
		t.Fatalf("CreateTLSConfig returned error: %v", err)
	case got == nil:
		t.Fatal("CreateTLSConfig returned nil")
	case !got.InsecureSkipVerify:
		t.Error("InsecureSkipVerify should be true")
	}
}
func TestCreateTLSConfig_TLSVersions(t *testing.T) {
tests := []struct {
name string
minVersion string
maxVersion string
wantMin uint16
wantMax uint16
}{
{
name: "TLSV1.2 min",
minVersion: "TLSV1.2",
wantMin: tls.VersionTLS12,
},
{
name: "TLSV1.3 min",
minVersion: "TLSV1.3",
wantMin: tls.VersionTLS13,
},
{
name: "TLSV1.2 max",
maxVersion: "TLSV1.2",
wantMax: tls.VersionTLS12,
},
{
name: "both versions",
minVersion: "TLSV1.2",
maxVersion: "TLSV1.3",
wantMin: tls.VersionTLS12,
wantMax: tls.VersionTLS13,
},
{
name: "mixed case TLSv1.2",
minVersion: "TLSv1.2", // 测试大小写不敏感
wantMin: tls.VersionTLS12,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxySSLConfig{
Enabled: true,
MinVersion: tt.minVersion,
MaxVersion: tt.maxVersion,
}
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil {
t.Errorf("CreateTLSConfig returned error: %v", err)
return
}
if tlsCfg == nil {
t.Error("CreateTLSConfig returned nil")
return
}
if tt.wantMin != 0 && tlsCfg.MinVersion != tt.wantMin {
t.Errorf("MinVersion = %d, want %d", tlsCfg.MinVersion, tt.wantMin)
}
if tt.wantMax != 0 && tlsCfg.MaxVersion != tt.wantMax {
t.Errorf("MaxVersion = %d, want %d", tlsCfg.MaxVersion, tt.wantMax)
}
})
}
}
// TestCreateTLSConfig_InvalidTLSVersion verifies that an unrecognized
// MinVersion string makes CreateTLSConfig return an error.
func TestCreateTLSConfig_InvalidTLSVersion(t *testing.T) {
	bogus := &config.ProxySSLConfig{Enabled: true, MinVersion: "TLSv9.9"}

	if _, err := CreateTLSConfig(bogus, "example.com"); err == nil {
		t.Error("CreateTLSConfig should return error for invalid TLS version")
	}
}
// TestCreateTLSConfig_TrustedCA is a placeholder for the successful TrustedCA
// path. It is always skipped because it needs a valid CA certificate file.
// Per the original note, that path is meant to be verified by integration
// tests that use real certificates.
func TestCreateTLSConfig_TrustedCA(t *testing.T) {
	const reason = "需要有效的 CA 证书文件,在集成测试中验证"
	t.Skip(reason)
}
// TestCreateTLSConfig_TrustedCANotFound verifies that pointing TrustedCA at a
// path that does not exist makes CreateTLSConfig return an error.
func TestCreateTLSConfig_TrustedCANotFound(t *testing.T) {
	missingCA := &config.ProxySSLConfig{Enabled: true, TrustedCA: "/nonexistent/ca.crt"}

	if _, err := CreateTLSConfig(missingCA, "example.com"); err == nil {
		t.Error("CreateTLSConfig should return error for nonexistent CA file")
	}
}
// TestCreateTLSConfig_ClientCert is a placeholder for the successful client
// certificate (mutual TLS) path. It is always skipped because it needs valid
// certificate files on disk; the skip message defers it to integration tests.
func TestCreateTLSConfig_ClientCert(t *testing.T) {
	const reason = "需要有效的客户端证书文件,在集成测试中验证"
	t.Skip(reason)
}
// TestCreateTLSConfig_ClientCertNotFound verifies that a ClientCert/ClientKey
// pair pointing at missing files makes CreateTLSConfig return an error.
func TestCreateTLSConfig_ClientCertNotFound(t *testing.T) {
	missingPair := &config.ProxySSLConfig{
		Enabled:    true,
		ClientCert: "/nonexistent/client.crt",
		ClientKey:  "/nonexistent/client.key",
	}

	if _, err := CreateTLSConfig(missingPair, "example.com"); err == nil {
		t.Error("CreateTLSConfig should return error for nonexistent cert files")
	}
}
// Per-status-class cache lifetime tests (US-007).
func TestGetCacheDuration_NoCacheValid(t *testing.T) {
cfg := &config.ProxyConfig{
Cache: config.ProxyCacheConfig{
MaxAge: 5 * time.Minute,
},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 无 CacheValid 配置时,所有状态码应使用 MaxAge
tests := []struct {
statusCode int
want time.Duration
}{
{200, 5 * time.Minute},
{301, 5 * time.Minute},
{404, 5 * time.Minute},
{500, 5 * time.Minute},
}
for _, tt := range tests {
got := p.getCacheDuration(tt.statusCode)
if got != tt.want {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
}
}
}
// TestGetCacheDuration_CacheValidOKInheritsMaxAge verifies that when CacheValid
// is present but its OK duration is zero, a 200 response falls back to
// Cache.MaxAge.
func TestGetCacheDuration_CacheValidOKInheritsMaxAge(t *testing.T) {
	const maxAge = 10 * time.Minute

	proxyCfg := &config.ProxyConfig{
		Cache: config.ProxyCacheConfig{MaxAge: maxAge},
		CacheValid: &config.ProxyCacheValidConfig{
			OK: 0, // zero means "inherit Cache.MaxAge"
		},
	}
	p, err := NewProxy(proxyCfg, []*loadbalance.Target{{URL: "http://localhost:8080"}}, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	if got := p.getCacheDuration(200); got != maxAge {
		t.Errorf("getCacheDuration(200) with OK=0 = %v, want %v (MaxAge)", got, maxAge)
	}
}
// TestGetCacheDuration_StatusCodeMapping checks, via the table below, that
// each status code picks up the CacheValid duration of its class:
// 200/201/299 -> OK, 301/302 -> Redirect, 404 -> NotFound,
// 400/403 -> ClientError, 500/503 -> ServerError (zero, i.e. not cached).
// 304 is expected to fall outside the Redirect class and yield 0.
func TestGetCacheDuration_StatusCodeMapping(t *testing.T) {
	const (
		okTTL       = 10 * time.Minute
		redirectTTL = 1 * time.Hour
		notFoundTTL = 1 * time.Minute
		clientTTL   = 30 * time.Second
	)

	proxyCfg := &config.ProxyConfig{
		Cache: config.ProxyCacheConfig{MaxAge: 1 * time.Minute},
		CacheValid: &config.ProxyCacheValidConfig{
			OK:          okTTL,
			Redirect:    redirectTTL,
			NotFound:    notFoundTTL,
			ClientError: clientTTL,
			ServerError: 0, // zero: 5xx responses are not cached
		},
	}
	p, err := NewProxy(proxyCfg, []*loadbalance.Target{{URL: "http://localhost:8080"}}, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	cases := []struct {
		name string
		code int
		want time.Duration
	}{
		{"200 OK", 200, okTTL},
		{"201 Created", 201, okTTL},
		{"299 OK boundary", 299, okTTL},
		{"301 Moved", 301, redirectTTL},
		{"302 Found", 302, redirectTTL},
		{"304 Not Modified", 304, 0}, // not part of the Redirect class
		{"404 Not Found", 404, notFoundTTL},
		{"400 Bad Request", 400, clientTTL},
		{"403 Forbidden", 403, clientTTL},
		{"500 Internal Error", 500, 0},
		{"503 Service Unavailable", 503, 0},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := p.getCacheDuration(tc.code); got != tc.want {
				t.Errorf("getCacheDuration(%d) = %v, want %v", tc.code, got, tc.want)
			}
		})
	}
}
// TestGetCacheDuration_ZeroValuesNoCache verifies that, with CacheValid
// configured, a zero duration for a non-OK status class means "do not cache"
// (getCacheDuration returns 0), while a non-zero OK duration is still used.
// Every class configured with zero is exercised, including ClientError
// (400/403), which is classified as in TestGetCacheDuration_StatusCodeMapping.
func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
	cfg := &config.ProxyConfig{
		Cache: config.ProxyCacheConfig{
			MaxAge: 5 * time.Minute,
		},
		CacheValid: &config.ProxyCacheValidConfig{
			OK:          10 * time.Minute, // OK has an explicit value
			Redirect:    0,                // not cached
			NotFound:    0,                // not cached
			ClientError: 0,                // not cached
			ServerError: 0,                // not cached
		},
	}
	targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
	p, err := NewProxy(cfg, targets, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	tests := []struct {
		statusCode int
		want       time.Duration
	}{
		{200, 10 * time.Minute}, // OK has an explicit value
		{301, 0},                // Redirect=0: not cached
		{404, 0},                // NotFound=0: not cached
		{400, 0},                // ClientError=0: not cached
		{403, 0},                // ClientError=0: not cached
		{500, 0},                // ServerError=0: not cached
	}
	for _, tt := range tests {
		got := p.getCacheDuration(tt.statusCode)
		if got != tt.want {
			t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
		}
	}
}