lolly/internal/proxy/proxy_test.go

1128 lines
29 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy 提供反向代理功能的测试。
//
// 该文件测试代理模块的各项功能,包括:
// - 代理创建和配置
// - 目标选择
// - 请求转发
// - 请求头修改
// - 响应头修改
// - 客户端 IP 提取
// - 目标更新
// - WebSocket 请求检测
// - 负载均衡器创建
// - HostClient 创建
// - 健康检查器设置
// - 代理缓存功能
// - 被动健康检查
//
// 作者: xfy
package proxy
import (
	"net"
	"strings"
	"testing"
	"time"

	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/fasthttputil"
	"rua.plus/lolly/internal/config"
	"rua.plus/lolly/internal/loadbalance"
	"rua.plus/lolly/internal/netutil"
	"rua.plus/lolly/internal/testutil"
	"rua.plus/lolly/internal/variable"
)
// TestNewProxy covers NewProxy's argument validation (nil config, empty or nil
// target list, unknown load-balancing algorithm) and checks that a successful
// call stores the given config and initialises a balancer for every supported
// algorithm.
func TestNewProxy(t *testing.T) {
	cases := []struct {
		cfg         *config.ProxyConfig
		name        string
		errContains string
		targets     []*loadbalance.Target
		wantErr     bool
	}{
		{
			name:    "正常创建",
			cfg:     testutil.NewTestProxyConfig("/api"),
			targets: testutil.NewTestTargets("http://localhost:8081", "http://localhost:8082"),
		},
		{
			name:        "nil配置",
			cfg:         nil,
			targets:     testutil.NewTestTargets("http://localhost:8081"),
			wantErr:     true,
			errContains: "proxy config is nil",
		},
		{
			name:        "空目标列表",
			cfg:         &config.ProxyConfig{Path: "/api"},
			targets:     []*loadbalance.Target{},
			wantErr:     true,
			errContains: "no proxy targets provided",
		},
		{
			name:        "nil目标列表",
			cfg:         &config.ProxyConfig{Path: "/api"},
			targets:     nil,
			wantErr:     true,
			errContains: "no proxy targets provided",
		},
		{
			name: "默认负载均衡算法",
			cfg: &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: "",
			},
			targets: testutil.NewTestTargets("http://localhost:8081"),
		},
		{
			name: "加权轮询算法",
			cfg: &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: "weighted_round_robin",
				Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			},
			targets: []*loadbalance.Target{
				{URL: "http://localhost:8081", Weight: 1},
				{URL: "http://localhost:8082", Weight: 2},
			},
		},
		{
			name: "最少连接算法",
			cfg: &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: "least_conn",
				Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			},
			targets: []*loadbalance.Target{
				{URL: "http://localhost:8081"},
			},
		},
		{
			name: "IP哈希算法",
			cfg: &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: "ip_hash",
				Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			},
			targets: []*loadbalance.Target{
				{URL: "http://localhost:8081"},
			},
		},
		{
			name: "无效负载均衡算法",
			cfg: &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: "invalid_algorithm",
			},
			targets:     testutil.NewTestTargets("http://localhost:8081"),
			wantErr:     true,
			errContains: "unsupported load balance algorithm",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			p, err := NewProxy(tc.cfg, tc.targets, nil, nil)

			switch {
			case tc.wantErr && err == nil:
				t.Errorf("NewProxy() expected error containing %q, got nil", tc.errContains)
			case tc.wantErr:
				if !contains(err.Error(), tc.errContains) {
					t.Errorf("NewProxy() error = %v, want containing %q", err, tc.errContains)
				}
			case err != nil:
				t.Errorf("NewProxy() unexpected error: %v", err)
			case p == nil:
				t.Error("NewProxy() returned nil proxy")
			default:
				// Successful construction: the proxy must keep the exact config
				// pointer it was given and have a balancer ready.
				if p.config != tc.cfg {
					t.Error("NewProxy() proxy config not set correctly")
				}
				if p.balancer == nil {
					t.Error("NewProxy() balancer not initialized")
				}
			}
		})
	}
}
// TestServeHTTP_NoHealthyTargets checks that ServeHTTP answers 502 Bad Gateway
// when every configured target has been marked unhealthy.
func TestServeHTTP_NoHealthyTargets(t *testing.T) {
	cfg := testutil.NewTestProxyConfig("/api")
	targets := testutil.NewTestTargets("http://localhost:8081", "http://localhost:8082")
	for _, tgt := range targets {
		tgt.Healthy.Store(false)
	}

	p, err := NewProxy(cfg, targets, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	ctx := testutil.NewRequestCtx("GET", "/api/test")
	p.ServeHTTP(ctx)

	if got, want := ctx.Response.StatusCode(), fasthttp.StatusBadGateway; got != want {
		t.Errorf("ServeHTTP() status code = %d, want %d", got, want)
	}
}
// TestServeHTTP_RequestForwarding 测试请求转发
func TestServeHTTP_RequestForwarding(t *testing.T) {
// 创建本地测试服务器
ln := fasthttputil.NewInmemoryListener()
defer func() { _ = ln.Close() }()
// 启动后端服务器
go func() {
s := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusOK)
ctx.SetBodyString("Hello from backend")
ctx.Response.Header.Set("X-Backend-Header", "test-value")
},
}
_ = s.Serve(ln)
}()
// 等待服务器启动
time.Sleep(10 * time.Millisecond)
cfg := testutil.NewTestProxyConfig("/api")
targets := testutil.NewTestTargets("http://localhost:8080")
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 创建测试请求
ctx := testutil.NewRequestCtx("GET", "/api/test")
ctx.Request.Header.Set("X-Custom-Header", "client-value")
// 执行请求
p.ServeHTTP(ctx)
// 由于没有真实后端应该返回502
// 但在单元测试中我们可以验证错误处理逻辑
if ctx.Response.StatusCode() != fasthttp.StatusBadGateway {
t.Logf("ServeHTTP() status code = %d (expected 502 when no backend available)", ctx.Response.StatusCode())
}
}
// TestSelectTarget exercises selectTarget under different load-balancing
// algorithms and target health states.
//
// Each case declares its targets' health explicitly (healthy[i] applies to
// targets[i]) rather than deriving it from the subtest name. Renaming a case
// therefore cannot silently change its setup.
//
// expectedTarget semantics:
//   - ""    — selectTarget must return nil.
//   - "any" — any target is acceptable (used for ip_hash, where the choice
//     depends on the hash); consistency across calls is checked instead.
//   - URL   — for round_robin, the exact target expected on the first pick.
func TestSelectTarget(t *testing.T) {
	tests := []struct {
		name           string
		loadBalance    string
		clientIP       string
		expectedTarget string
		targets        []*loadbalance.Target
		healthy        []bool
	}{
		{
			name:        "轮询选择",
			loadBalance: "round_robin",
			targets: []*loadbalance.Target{
				{URL: "http://backend1:8080"},
				{URL: "http://backend2:8080"},
			},
			healthy:        []bool{true, true},
			expectedTarget: "http://backend1:8080",
		},
		{
			name:        "跳过不健康目标",
			loadBalance: "round_robin",
			targets: []*loadbalance.Target{
				{URL: "http://backend1:8080"},
				{URL: "http://backend2:8080"},
			},
			healthy:        []bool{false, true},
			expectedTarget: "http://backend2:8080",
		},
		{
			name:        "IP哈希选择",
			loadBalance: "ip_hash",
			targets: []*loadbalance.Target{
				{URL: "http://backend1:8080"},
				{URL: "http://backend2:8080"},
			},
			healthy:        []bool{true, true},
			clientIP:       "192.168.1.100",
			expectedTarget: "any",
		},
		{
			name:        "所有目标都不健康",
			loadBalance: "round_robin",
			targets: []*loadbalance.Target{
				{URL: "http://backend1:8080"},
				{URL: "http://backend2:8080"},
			},
			healthy:        []bool{false, false},
			expectedTarget: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if len(tt.healthy) != len(tt.targets) {
				t.Fatalf("test case misconfigured: %d health flags for %d targets", len(tt.healthy), len(tt.targets))
			}
			for i, target := range tt.targets {
				target.Healthy.Store(tt.healthy[i])
			}

			cfg := &config.ProxyConfig{
				Path:        "/api",
				LoadBalance: tt.loadBalance,
				Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			}
			p, err := NewProxy(cfg, tt.targets, nil, nil)
			if err != nil {
				t.Fatalf("NewProxy() error: %v", err)
			}

			ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
				"X-Forwarded-For": tt.clientIP,
			})
			target := p.selectTarget(ctx)

			if tt.expectedTarget == "" {
				if target != nil {
					t.Errorf("selectTarget() expected nil, got %v", target.URL)
				}
				return
			}
			if target == nil {
				t.Fatal("selectTarget() returned nil for healthy targets")
			}

			switch tt.loadBalance {
			case "round_robin":
				// The first pick must be the first healthy target.
				if target.URL != tt.expectedTarget {
					t.Errorf("selectTarget() = %v, want %v", target.URL, tt.expectedTarget)
				}
			case "ip_hash":
				// The same client IP must map to the same target every time.
				if tt.clientIP != "" {
					if again := p.selectTarget(ctx); again == nil || again.URL != target.URL {
						t.Error("IP hash should consistently return the same target for the same IP")
					}
				}
			}
		})
	}
}
// TestModifyRequestHeaders 测试请求头修改
func TestModifyRequestHeaders(t *testing.T) {
tests := []struct {
name string
clientIP string
existingXFF string
setRequest map[string]string
removeHeaders []string
checkHeaders map[string]string
shouldNotExist []string
}{
{
name: "设置X-Real-IP",
clientIP: "192.168.1.100",
checkHeaders: map[string]string{
"X-Real-IP": "192.168.1.100",
},
},
{
name: "追加X-Forwarded-For",
clientIP: "192.168.1.100",
existingXFF: "10.0.0.1",
checkHeaders: map[string]string{
"X-Forwarded-For": "10.0.0.1, 10.0.0.1",
},
},
{
name: "新建X-Forwarded-For",
clientIP: "192.168.1.100",
checkHeaders: map[string]string{
"X-Forwarded-For": "192.168.1.100",
},
},
{
name: "自定义请求头",
clientIP: "192.168.1.100",
setRequest: map[string]string{
"X-Custom-Header": "custom-value",
"X-Another": "another-value",
},
checkHeaders: map[string]string{
"X-Custom-Header": "custom-value",
"X-Another": "another-value",
},
},
{
name: "移除请求头",
clientIP: "192.168.1.100",
removeHeaders: []string{"X-Remove-Me"},
shouldNotExist: []string{"X-Remove-Me"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Headers: config.ProxyHeaders{
SetRequest: tt.setRequest,
Remove: tt.removeHeaders,
},
}
targets := []*loadbalance.Target{
{URL: "http://localhost:8080"},
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
// 构建 headers map
headers := make(map[string]string)
if tt.clientIP != "" {
headers["X-Real-IP"] = tt.clientIP
}
if tt.existingXFF != "" {
headers["X-Forwarded-For"] = tt.existingXFF
}
if len(tt.removeHeaders) > 0 {
for _, h := range tt.removeHeaders {
headers[h] = "should-be-removed"
}
}
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", headers)
target := &loadbalance.Target{URL: "http://localhost:8080"}
p.modifyRequestHeaders(ctx, target)
// 检查期望存在的头
for key, expectedValue := range tt.checkHeaders {
actualValue := string(ctx.Request.Header.Peek(key))
if actualValue != expectedValue {
t.Errorf("Header %s = %q, want %q", key, actualValue, expectedValue)
}
}
// 检查不应该存在的头
for _, key := range tt.shouldNotExist {
if ctx.Request.Header.Peek(key) != nil {
t.Errorf("Header %s should not exist", key)
}
}
})
}
}
// TestModifyResponseHeaders 测试响应头修改
func TestModifyResponseHeaders(t *testing.T) {
tests := []struct {
setResponse map[string]string
checkHeaders map[string]string
name string
}{
{
name: "设置自定义响应头",
setResponse: map[string]string{
"X-Custom-Response": "custom-value",
"X-Powered-By": "Lolly",
},
checkHeaders: map[string]string{
"X-Custom-Response": "custom-value",
"X-Powered-By": "Lolly",
},
},
{
name: "空响应头配置",
setResponse: nil,
checkHeaders: map[string]string{},
},
{
name: "覆盖已有响应头",
setResponse: map[string]string{
"Content-Type": "application/json",
},
checkHeaders: map[string]string{
"Content-Type": "application/json",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Headers: config.ProxyHeaders{
SetResponse: tt.setResponse,
},
}
targets := testutil.NewTestTargets("http://localhost:8080")
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/")
p.modifyResponseHeaders(ctx)
// 检查期望存在的头
for key, expectedValue := range tt.checkHeaders {
actualValue := string(ctx.Response.Header.Peek(key))
if actualValue != expectedValue {
t.Errorf("Response Header %s = %q, want %q", key, actualValue, expectedValue)
}
}
})
}
}
// TestGetClientIP 测试客户端IP提取
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
xff string
xri string
expected string
}{
{
name: "从X-Forwarded-For提取",
xff: "10.0.0.1, 10.0.0.2",
expected: "10.0.0.1",
},
{
name: "从X-Real-IP提取",
xri: "192.168.1.100",
expected: "192.168.1.100",
},
{
name: "X-Forwarded-For优先",
xff: "10.0.0.1",
xri: "192.168.1.100",
expected: "10.0.0.1",
},
{
name: "单IP",
xff: "10.0.0.1",
expected: "10.0.0.1",
},
{
name: "带空格",
xff: " 10.0.0.1 ",
expected: "10.0.0.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := testutil.NewRequestCtxWithHeader("GET", "/", map[string]string{
"X-Forwarded-For": tt.xff,
"X-Real-IP": tt.xri,
})
ip := netutil.ExtractClientIP(ctx)
if ip != tt.expected {
t.Errorf("ExtractClientIP() = %q, want %q", ip, tt.expected)
}
})
}
}
// TestIsWebSocketRequest checks isWebSocketRequest against combinations of the
// Upgrade and Connection headers:
//   - header values match case-insensitively;
//   - Upgrade must name websocket (h2c is rejected);
//   - "upgrade" is accepted as one token of a comma-separated Connection list.
func TestIsWebSocketRequest(t *testing.T) {
	type wsCase struct {
		name       string
		upgrade    string
		connection string
		expected   bool
	}

	cases := []wsCase{
		{name: "标准WebSocket请求", upgrade: "websocket", connection: "upgrade", expected: true},
		{name: "大小写不敏感", upgrade: "WebSocket", connection: "Upgrade", expected: true},
		{name: "非WebSocket升级", upgrade: "h2c", connection: "upgrade", expected: false},
		{name: "非upgrade连接", upgrade: "websocket", connection: "keep-alive", expected: false},
		{name: "keep-alive, Upgrade", upgrade: "websocket", connection: "keep-alive, Upgrade", expected: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			ctx := testutil.NewRequestCtxWithHeader("GET", "/", map[string]string{
				"Upgrade":    c.upgrade,
				"Connection": c.connection,
			})
			if got := isWebSocketRequest(ctx); got != c.expected {
				t.Errorf("isWebSocketRequest() = %v, want %v", got, c.expected)
			}
		})
	}
}
// TestCreateBalancer checks that createBalancer builds a balancer for each
// supported algorithm name. An empty name is treated as round robin, and an
// unknown name must fail with an "unsupported load balance algorithm" error.
func TestCreateBalancer(t *testing.T) {
	cases := []struct {
		cfg         *config.ProxyConfig
		name        string
		errContains string
		wantErr     bool
	}{
		{name: "轮询", cfg: &config.ProxyConfig{LoadBalance: "round_robin"}},
		{name: "加权轮询", cfg: &config.ProxyConfig{LoadBalance: "weighted_round_robin"}},
		{name: "最少连接", cfg: &config.ProxyConfig{LoadBalance: "least_conn"}},
		{name: "IP哈希", cfg: &config.ProxyConfig{LoadBalance: "ip_hash"}},
		{name: "一致性哈希", cfg: &config.ProxyConfig{LoadBalance: "consistent_hash", HashKey: "ip", VirtualNodes: 150}},
		{name: "空算法(默认轮询)", cfg: &config.ProxyConfig{LoadBalance: ""}},
		{
			name:        "无效算法",
			cfg:         &config.ProxyConfig{LoadBalance: "unknown_algorithm"},
			wantErr:     true,
			errContains: "unsupported load balance algorithm",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			balancer, err := createBalancer(tc.cfg)
			// Read the algorithm name after the call, exactly as the messages
			// always have, in case createBalancer normalises the config.
			algo := tc.cfg.LoadBalance

			if !tc.wantErr {
				if err != nil {
					t.Errorf("createBalancer(%v) unexpected error: %v", algo, err)
					return
				}
				if balancer == nil {
					t.Errorf("createBalancer(%v) returned nil balancer", algo)
				}
				return
			}

			if err == nil {
				t.Errorf("createBalancer(%v) expected error", algo)
				return
			}
			if !contains(err.Error(), tc.errContains) {
				t.Errorf("createBalancer(%v) error = %v, want containing %q", algo, err, tc.errContains)
			}
		})
	}
}
// TestCreateHostClient checks that createHostClient returns a client with a
// non-empty address and with IsTLS set exactly for https targets.
//
// The expected TLS flag is a table column (wantTLS) rather than a comparison
// against one hard-coded URL. This way plain-HTTP targets are also verified
// to be non-TLS, and editing a case's URL cannot silently skip the check.
func TestCreateHostClient(t *testing.T) {
	tests := []struct {
		name      string
		targetURL string
		timeout   config.ProxyTimeout
		wantTLS   bool
	}{
		{
			name:      "HTTP地址",
			targetURL: "http://localhost:8080",
			timeout:   config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
			wantTLS:   false,
		},
		{
			name:      "HTTPS地址",
			targetURL: "https://localhost:8443",
			timeout:   config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second},
			wantTLS:   true,
		},
		{
			name:      "带路径的URL",
			targetURL: "http://localhost:8080/path",
			timeout:   config.ProxyTimeout{Connect: 5 * time.Second},
			wantTLS:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client := createHostClient(tt.targetURL, tt.timeout, nil, nil, "", nil)
			if client == nil {
				t.Fatal("createHostClient() returned nil")
			}
			if client.Addr == "" {
				t.Error("createHostClient() client.Addr is empty")
			}
			if client.IsTLS != tt.wantTLS {
				t.Errorf("createHostClient(%q) IsTLS = %v, want %v", tt.targetURL, client.IsTLS, tt.wantTLS)
			}
		})
	}
}
// TestHandleWebSocket is a smoke test for the prerequisites of WebSocket
// proxying. handleWebSocket hijacks the underlying client connection, which a
// bare test RequestCtx cannot provide, so the bridging itself is left to
// integration tests.
//
// Here we only verify that the proxy holds a HostClient for the target it was
// configured with. The client is looked up via the proxy's own target rather
// than a freshly constructed copy. The previous check that a hand-typed URL
// equalled itself could never fail and has been removed.
func TestHandleWebSocket(t *testing.T) {
	cfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
	}
	targets := testutil.NewTestTargets("http://localhost:8080")
	p, err := NewProxy(cfg, targets, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	if client := p.getClient(targets[0].URL); client == nil {
		t.Errorf("Expected non-nil client for target %s", targets[0].URL)
	}
}
// TestSetHealthChecker attaches a health checker to a proxy and then drives the
// checker's passive path (MarkUnhealthy) on one of the proxy's targets.
//
// The proxy's healthChecker field is unexported, so this only confirms that
// SetHealthChecker can be called. The assertion is on MarkUnhealthy flipping
// the target's health flag.
func TestSetHealthChecker(t *testing.T) {
	cfg := testutil.NewTestProxyConfig("/api")
	targets := testutil.NewTestTargets("http://localhost:8081")
	p, err := NewProxy(cfg, targets, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	hc := NewHealthChecker(targets, &config.HealthCheckConfig{
		Interval: 10 * time.Second,
		Path:     "/health",
		Timeout:  5 * time.Second,
	})
	p.SetHealthChecker(hc)

	victim := targets[0]
	victim.Healthy.Store(true)
	hc.MarkUnhealthy(victim)

	if victim.Healthy.Load() {
		t.Error("MarkUnhealthy() target should be unhealthy after marking")
	}
}
// TestGetClient verifies that getClient returns a client for every configured
// target URL and nil for a URL the proxy was not configured with.
func TestGetClient(t *testing.T) {
	cfg := testutil.NewTestProxyConfig("/api")
	targets := testutil.NewTestTargets("http://localhost:8081", "http://localhost:8082")
	p, err := NewProxy(cfg, targets, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	for _, known := range []string{"http://localhost:8081", "http://localhost:8082"} {
		if p.getClient(known) == nil {
			t.Error("getClient() returned nil for existing client")
		}
	}

	if p.getClient("http://localhost:9999") != nil {
		t.Error("getClient() should return nil for non-existent client")
	}
}
// TestServeHTTP_WithPassiveHealthCheck sends a request to a target with no
// listener behind it. It checks that ServeHTTP answers 502 and that the
// attached health checker marks the target unhealthy (passive health check).
func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
	// Reserve a free loopback port and release it straight away, so the proxy
	// dials an address where nothing is listening. A fixed port (previously
	// 59999) may be taken on a shared CI host and make the test flaky.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to reserve a port: %v", err)
	}
	deadURL := "http://" + ln.Addr().String()
	if err := ln.Close(); err != nil {
		t.Fatalf("failed to release reserved port: %v", err)
	}

	cfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		Timeout:     config.ProxyTimeout{Connect: 100 * time.Millisecond, Read: 100 * time.Millisecond, Write: 100 * time.Millisecond},
	}
	targets := testutil.NewTestHealthyTargets(deadURL)
	p, err := NewProxy(cfg, targets, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}

	hc := NewHealthChecker(targets, &config.HealthCheckConfig{
		Interval: 10 * time.Second,
		Path:     "/health",
		Timeout:  5 * time.Second,
	})
	p.SetHealthChecker(hc)

	// The upstream dial fails, which should trigger the passive health check.
	ctx := testutil.NewRequestCtx("GET", "/api/test")
	p.ServeHTTP(ctx)

	if ctx.Response.StatusCode() != fasthttp.StatusBadGateway {
		t.Errorf("ServeHTTP() status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusBadGateway)
	}
	if targets[0].Healthy.Load() {
		t.Error("Target should be marked unhealthy after failed request")
	}
}
// contains reports whether substr occurs anywhere in s.
//
// It is a thin wrapper over strings.Contains, kept under this name so existing
// call sites (possibly in other test files of this package) keep working.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// containsAt reports whether substr occurs in s at some byte index >= start.
//
// A negative start is treated as 0 (the former hand-rolled loop panicked on
// it). A start past the end of s yields false. An empty substr matches at any
// start in [0, len(s)].
func containsAt(s, substr string, start int) bool {
	if start < 0 {
		start = 0
	}
	if start > len(s) {
		return false
	}
	return strings.Contains(s[start:], substr)
}
// TestUpstreamVariablesCapture proxies one request to a live loopback backend
// and expects a 200. It then exercises the UpstreamTiming mark/get API and
// checks that none of the derived durations is negative.
func TestUpstreamVariablesCapture(t *testing.T) {
	backend := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetStatusCode(200)
			ctx.SetBodyString("OK")
		},
	}

	// Bind a random loopback port. The bound socket queues connections until
	// Serve accepts them, so the proxy can dial immediately; no sleep needed.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	go func() { _ = backend.Serve(backendLn) }()

	targets := []*loadbalance.Target{
		{URL: "http://" + backendLn.Addr().String(), Weight: 1},
	}
	targets[0].Healthy.Store(true)

	p, err := NewProxy(testutil.NewTestProxyConfig("/"), targets, nil, nil)
	if err != nil {
		t.Fatalf("failed to create proxy: %v", err)
	}

	ctx := testutil.NewRequestCtxWithHeader("GET", "/test", map[string]string{
		"Host": "example.com",
	})
	p.ServeHTTP(ctx)

	if ctx.Response.StatusCode() != 200 {
		t.Errorf("expected status 200, got %d", ctx.Response.StatusCode())
	}

	timing := NewUpstreamTiming()
	if timing == nil {
		// Fatal rather than Error: every call below would otherwise be made on
		// a nil timer.
		t.Fatal("NewUpstreamTiming() returned nil")
	}

	timing.MarkConnectStart()
	timing.MarkConnectEnd()
	timing.MarkHeaderReceived()
	timing.MarkResponseEnd()

	if timing.GetConnectTime() < 0 {
		t.Error("GetConnectTime() should be >= 0")
	}
	if timing.GetHeaderTime() < 0 {
		t.Error("GetHeaderTime() should be >= 0")
	}
	if timing.GetResponseTime() < 0 {
		t.Error("GetResponseTime() should be >= 0")
	}
}
// TestUpstreamVariablesErrorPaths 测试上游变量错误路径
func TestUpstreamVariablesErrorPaths(t *testing.T) {
tests := []struct {
name string
backendAddr string
expectedAddr string
expectedCode int
}{
{
name: "no healthy backend",
backendAddr: "",
expectedAddr: "FAILED",
expectedCode: 502,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var targets []*loadbalance.Target
if tt.backendAddr != "" {
targets = []*loadbalance.Target{
{URL: tt.backendAddr, Weight: 1},
}
targets[0].Healthy.Store(true)
} else {
// 创建一个不健康目标
targets = []*loadbalance.Target{
{URL: "http://127.0.0.1:1", Weight: 1},
}
}
cfg := &config.ProxyConfig{
Path: "/",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{
Connect: 1 * time.Millisecond, // 超短超时
Read: 1 * time.Millisecond,
Write: 1 * time.Millisecond,
},
}
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("failed to create proxy: %v", err)
}
ctx := testutil.NewRequestCtxWithHeader("GET", "/test", map[string]string{
"Host": "example.com",
})
p.ServeHTTP(ctx)
// 验证错误状态码
if ctx.Response.StatusCode() != tt.expectedCode &&
ctx.Response.StatusCode() != 502 &&
ctx.Response.StatusCode() != 504 {
t.Errorf("expected status %d or 502/504, got %d", tt.expectedCode, ctx.Response.StatusCode())
}
})
}
}
// TestFinalizeUpstreamVars 测试 FinalizeUpstreamVars 函数
func TestFinalizeUpstreamVars(t *testing.T) {
ctx := testutil.NewRequestCtx("GET", "/test")
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
timing := NewUpstreamTiming()
timing.MarkConnectStart()
time.Sleep(1 * time.Millisecond)
timing.MarkConnectEnd()
timing.MarkHeaderReceived()
time.Sleep(1 * time.Millisecond)
timing.MarkResponseEnd()
// 测试 FinalizeUpstreamVars
FinalizeUpstreamVars(vc, "http://backend:8080", 200, timing)
// 验证变量已设置
addr, ok := vc.Get("upstream_addr")
if !ok || addr != "http://backend:8080" {
t.Errorf("upstream_addr = %q, want 'http://backend:8080'", addr)
}
status, ok := vc.Get("upstream_status")
if !ok || status != "200" {
t.Errorf("upstream_status = %q, want '200'", status)
}
// 测试 nil vc
FinalizeUpstreamVars(nil, "http://backend:8080", 200, timing)
// 不应该 panic
}
// TestUpstreamTimingZero checks that a fresh UpstreamTiming reports zero for
// every duration. It also checks that marking only the connect start still
// yields a zero connect time.
func TestUpstreamTimingZero(t *testing.T) {
	timing := NewUpstreamTiming()

	if got := timing.GetConnectTime(); got != 0 {
		t.Errorf("GetConnectTime() = %v, want 0", got)
	}
	if got := timing.GetHeaderTime(); got != 0 {
		t.Errorf("GetHeaderTime() = %v, want 0", got)
	}
	if got := timing.GetResponseTime(); got != 0 {
		t.Errorf("GetResponseTime() = %v, want 0", got)
	}

	// A start without a matching end must not produce a duration.
	timing.MarkConnectStart()
	if got := timing.GetConnectTime(); got != 0 {
		t.Errorf("GetConnectTime() after MarkConnectStart = %v, want 0", got)
	}
}
// TestUpstreamTiming_ZeroValues checks that a zero-value UpstreamTiming (built
// as a bare literal, bypassing NewUpstreamTiming) reports zero for every
// duration.
func TestUpstreamTiming_ZeroValues(t *testing.T) {
	var timing = &UpstreamTiming{}

	if got := timing.GetConnectTime(); got != 0 {
		t.Errorf("Zero timing GetConnectTime() = %v, want 0", got)
	}
	if got := timing.GetHeaderTime(); got != 0 {
		t.Errorf("Zero timing GetHeaderTime() = %v, want 0", got)
	}
	if got := timing.GetResponseTime(); got != 0 {
		t.Errorf("Zero timing GetResponseTime() = %v, want 0", got)
	}
}
// TestUpstreamTiming_PartialMarks checks that each duration stays zero when
// only its end-side mark has been recorded. Each scenario uses a fresh timer.
func TestUpstreamTiming_PartialMarks(t *testing.T) {
	// Connect end without a connect start.
	endOnly := NewUpstreamTiming()
	endOnly.MarkConnectEnd()
	if got := endOnly.GetConnectTime(); got != 0 {
		t.Errorf("GetConnectTime() with only end marked = %v, want 0", got)
	}

	// Header received without a connect end.
	headerOnly := NewUpstreamTiming()
	headerOnly.MarkHeaderReceived()
	if got := headerOnly.GetHeaderTime(); got != 0 {
		t.Errorf("GetHeaderTime() without connectEnd = %v, want 0", got)
	}

	// Response end without a connect end.
	responseOnly := NewUpstreamTiming()
	responseOnly.MarkResponseEnd()
	if got := responseOnly.GetResponseTime(); got != 0 {
		t.Errorf("GetResponseTime() without connectEnd = %v, want 0", got)
	}
}