// Package proxy provides reverse proxy functionality for the Lolly HTTP server. // //go:generate go test -v ./... package proxy import ( "testing" "time" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" ) // TestNewProxy 测试 NewProxy 函数 func TestNewProxy(t *testing.T) { tests := []struct { name string cfg *config.ProxyConfig targets []*loadbalance.Target wantErr bool errContains string }{ { name: "正常创建", cfg: &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, }, targets: []*loadbalance.Target{ {URL: "http://localhost:8081", Healthy: true}, {URL: "http://localhost:8082", Healthy: true}, }, wantErr: false, }, { name: "nil配置", cfg: nil, targets: []*loadbalance.Target{{URL: "http://localhost:8081"}}, wantErr: true, errContains: "proxy config is nil", }, { name: "空目标列表", cfg: &config.ProxyConfig{Path: "/api"}, targets: []*loadbalance.Target{}, wantErr: true, errContains: "no proxy targets provided", }, { name: "nil目标列表", cfg: &config.ProxyConfig{Path: "/api"}, targets: nil, wantErr: true, errContains: "no proxy targets provided", }, { name: "默认负载均衡算法", cfg: &config.ProxyConfig{ Path: "/api", LoadBalance: "", }, targets: []*loadbalance.Target{ {URL: "http://localhost:8081", Healthy: true}, }, wantErr: false, }, { name: "加权轮询算法", cfg: &config.ProxyConfig{ Path: "/api", LoadBalance: "weighted_round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, }, targets: []*loadbalance.Target{ {URL: "http://localhost:8081", Weight: 1, Healthy: true}, {URL: "http://localhost:8082", Weight: 2, Healthy: true}, }, wantErr: false, }, { name: "最少连接算法", cfg: &config.ProxyConfig{ Path: "/api", LoadBalance: "least_conn", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, }, targets: []*loadbalance.Target{ {URL: "http://localhost:8081", Healthy: true}, }, wantErr: false, }, { name: "IP哈希算法", cfg: &config.ProxyConfig{ Path: "/api", LoadBalance: "ip_hash", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, }, targets: []*loadbalance.Target{ {URL: "http://localhost:8081", Healthy: true}, }, wantErr: false, }, { name: "无效负载均衡算法", cfg: &config.ProxyConfig{ Path: "/api", LoadBalance: "invalid_algorithm", }, targets: []*loadbalance.Target{ {URL: "http://localhost:8081", Healthy: true}, }, wantErr: true, errContains: "unsupported load balance algorithm", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p, err := NewProxy(tt.cfg, tt.targets) if tt.wantErr { if err == nil { t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains) return } if !contains(err.Error(), tt.errContains) { t.Errorf("NewProxy() error = %v, want containing %q", err, tt.errContains) } return } if err != nil { t.Errorf("NewProxy() unexpected error: %v", err) return } if p == nil { t.Error("NewProxy() returned nil proxy") return } if p.config != tt.cfg { t.Error("NewProxy() proxy config not set correctly") } if p.balancer == nil { t.Error("NewProxy() balancer not initialized") } }) } } // TestServeHTTP_NoHealthyTargets 测试没有健康目标时返回502 func TestServeHTTP_NoHealthyTargets(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, } // 所有目标都不健康 targets := []*loadbalance.Target{ {URL: "http://localhost:8081", Healthy: false}, {URL: "http://localhost:8082", Healthy: false}, } p, err := NewProxy(cfg, targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } // 创建测试请求 ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fasthttp.MethodGet) ctx.Request.SetRequestURI("/api/test") // 执行请求 p.ServeHTTP(ctx) // 应该返回502 if ctx.Response.StatusCode() != fasthttp.StatusBadGateway { t.Errorf("ServeHTTP() status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusBadGateway) } } // TestServeHTTP_RequestForwarding 测试请求转发 func TestServeHTTP_RequestForwarding(t *testing.T) { // 创建本地测试服务器 ln := fasthttputil.NewInmemoryListener() defer ln.Close() // 启动后端服务器 go func() { s := &fasthttp.Server{ Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetBodyString("Hello from backend") ctx.Response.Header.Set("X-Backend-Header", "test-value") }, } s.Serve(ln) }() // 等待服务器启动 time.Sleep(10 * time.Millisecond) cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, } targets := []*loadbalance.Target{ {URL: "http://localhost:8080", Healthy: true}, } p, err := NewProxy(cfg, targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } // 创建测试请求 ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fasthttp.MethodGet) ctx.Request.SetRequestURI("/api/test") ctx.Request.Header.Set("X-Custom-Header", "client-value") // 执行请求 p.ServeHTTP(ctx) // 由于没有真实后端,应该返回502 // 但在单元测试中我们可以验证错误处理逻辑 if ctx.Response.StatusCode() != fasthttp.StatusBadGateway { t.Logf("ServeHTTP() status code = %d (expected 502 when no backend available)", ctx.Response.StatusCode()) } } // TestSelectTarget 测试目标选择 func TestSelectTarget(t *testing.T) { tests := []struct { name string loadBalance string targets []*loadbalance.Target clientIP string expectedTarget string }{ { name: "轮询选择", loadBalance: "round_robin", targets: []*loadbalance.Target{ {URL: "http://backend1:8080", Healthy: true}, {URL: "http://backend2:8080", Healthy: true}, }, expectedTarget: "http://backend1:8080", }, { name: "跳过不健康目标", loadBalance: "round_robin", targets: []*loadbalance.Target{ {URL: "http://backend1:8080", Healthy: false}, {URL: "http://backend2:8080", Healthy: true}, }, expectedTarget: "http://backend2:8080", }, { name: "IP哈希选择", loadBalance: "ip_hash", targets: []*loadbalance.Target{ {URL: "http://backend1:8080", Healthy: true}, {URL: "http://backend2:8080", Healthy: true}, }, clientIP: "192.168.1.100", expectedTarget: "any", // IP哈希应该返回一个目标,具体是哪个取决于哈希值 }, { name: "所有目标都不健康", loadBalance: "round_robin", targets: []*loadbalance.Target{ {URL: "http://backend1:8080", Healthy: false}, {URL: "http://backend2:8080", Healthy: false}, }, expectedTarget: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: tt.loadBalance, Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } p, err := NewProxy(cfg, tt.targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } ctx := &fasthttp.RequestCtx{} if tt.clientIP != "" { // 设置远程地址模拟客户端IP ctx.Request.Header.Set("X-Forwarded-For", tt.clientIP) } ctx.Request.SetRequestURI("/api/test") target := p.selectTarget(ctx) if tt.expectedTarget == "" { if target != nil { t.Errorf("selectTarget() expected nil, got %v", target.URL) } return } if tt.loadBalance == "round_robin" && tt.expectedTarget != "" { // 轮询应该选择第一个健康目标 if target == nil { t.Error("selectTarget() returned nil for healthy targets") return } if target.URL != tt.expectedTarget { t.Errorf("selectTarget() = %v, want %v", target.URL, tt.expectedTarget) } } // IP哈希应该始终返回同一个目标给同一个IP if tt.loadBalance == "ip_hash" && tt.clientIP != "" { if target == nil { t.Error("selectTarget() returned nil for IP hash") return } // 再次选择,应该返回相同的目标 target2 := p.selectTarget(ctx) if target2 == nil || target2.URL != target.URL { t.Error("IP hash should consistently return the same target for the same IP") } } }) } } // TestModifyRequestHeaders 测试请求头修改 func TestModifyRequestHeaders(t *testing.T) { tests := []struct { name string clientIP string existingXFF string setRequest map[string]string removeHeaders []string checkHeaders map[string]string shouldNotExist []string }{ { name: "设置X-Real-IP", clientIP: "192.168.1.100", checkHeaders: map[string]string{ "X-Real-IP": "192.168.1.100", }, }, { name: "追加X-Forwarded-For", clientIP: "192.168.1.100", existingXFF: "10.0.0.1", checkHeaders: map[string]string{ "X-Forwarded-For": "10.0.0.1, 10.0.0.1", }, }, { name: "新建X-Forwarded-For", clientIP: "192.168.1.100", checkHeaders: map[string]string{ "X-Forwarded-For": "192.168.1.100", }, }, { name: "自定义请求头", clientIP: "192.168.1.100", setRequest: map[string]string{ "X-Custom-Header": "custom-value", "X-Another": "another-value", }, checkHeaders: map[string]string{ "X-Custom-Header": "custom-value", "X-Another": "another-value", }, }, { name: "移除请求头", clientIP: "192.168.1.100", removeHeaders: []string{"X-Remove-Me"}, shouldNotExist: []string{"X-Remove-Me"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, Headers: config.ProxyHeaders{ SetRequest: tt.setRequest, Remove: tt.removeHeaders, }, } targets := []*loadbalance.Target{ {URL: "http://localhost:8080", Healthy: true}, } p, err := NewProxy(cfg, targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/api/test") // 设置客户端IP if tt.clientIP != "" { ctx.Request.Header.Set("X-Real-IP", tt.clientIP) } // 设置已有的X-Forwarded-For if tt.existingXFF != "" { ctx.Request.Header.Set("X-Forwarded-For", tt.existingXFF) } // 设置需要被移除的头 if len(tt.removeHeaders) > 0 { for _, h := range tt.removeHeaders { ctx.Request.Header.Set(h, "should-be-removed") } } target := &loadbalance.Target{URL: "http://localhost:8080"} p.modifyRequestHeaders(ctx, target) // 检查期望存在的头 for key, expectedValue := range tt.checkHeaders { actualValue := string(ctx.Request.Header.Peek(key)) if actualValue != expectedValue { t.Errorf("Header %s = %q, want %q", key, actualValue, expectedValue) } } // 检查不应该存在的头 for _, key := range tt.shouldNotExist { if ctx.Request.Header.Peek(key) != nil { t.Errorf("Header %s should not exist", key) } } }) } } // TestModifyResponseHeaders 测试响应头修改 func TestModifyResponseHeaders(t *testing.T) { tests := []struct { name string setResponse map[string]string checkHeaders map[string]string }{ { name: "设置自定义响应头", setResponse: map[string]string{ "X-Custom-Response": "custom-value", "X-Powered-By": "Lolly", }, checkHeaders: map[string]string{ "X-Custom-Response": "custom-value", "X-Powered-By": "Lolly", }, }, { name: "空响应头配置", setResponse: nil, checkHeaders: map[string]string{}, }, { name: "覆盖已有响应头", setResponse: map[string]string{ "Content-Type": "application/json", }, checkHeaders: map[string]string{ "Content-Type": "application/json", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, Headers: config.ProxyHeaders{ SetResponse: tt.setResponse, }, } targets := []*loadbalance.Target{ {URL: "http://localhost:8080", Healthy: true}, } p, err := NewProxy(cfg, targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } ctx := &fasthttp.RequestCtx{} ctx.Response.SetStatusCode(fasthttp.StatusOK) p.modifyResponseHeaders(ctx) // 检查期望存在的头 for key, expectedValue := range tt.checkHeaders { actualValue := string(ctx.Response.Header.Peek(key)) if actualValue != expectedValue { t.Errorf("Response Header %s = %q, want %q", key, actualValue, expectedValue) } } }) } } // TestGetClientIP 测试客户端IP提取 func TestGetClientIP(t *testing.T) { tests := []struct { name string xff string xri string expected string }{ { name: "从X-Forwarded-For提取", xff: "10.0.0.1, 10.0.0.2", expected: "10.0.0.1", }, { name: "从X-Real-IP提取", xri: "192.168.1.100", expected: "192.168.1.100", }, { name: "X-Forwarded-For优先", xff: "10.0.0.1", xri: "192.168.1.100", expected: "10.0.0.1", }, { name: "单IP", xff: "10.0.0.1", expected: "10.0.0.1", }, { name: "带空格", xff: " 10.0.0.1 ", expected: "10.0.0.1", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &fasthttp.RequestCtx{} if tt.xff != "" { ctx.Request.Header.Set("X-Forwarded-For", tt.xff) } if tt.xri != "" { ctx.Request.Header.Set("X-Real-IP", tt.xri) } ip := getClientIP(ctx) if ip != tt.expected { t.Errorf("getClientIP() = %q, want %q", ip, tt.expected) } }) } } // TestUpdateTargets 测试更新目标 func TestUpdateTargets(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } initialTargets := []*loadbalance.Target{ {URL: "http://old1:8080", Healthy: true}, {URL: "http://old2:8080", Healthy: true}, } p, err := NewProxy(cfg, initialTargets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } // 更新目标 newTargets := []*loadbalance.Target{ {URL: "http://new1:8080", Healthy: true}, {URL: "http://new2:8080", Healthy: true}, {URL: "http://new3:8080", Healthy: true}, } err = p.UpdateTargets(newTargets) if err != nil { t.Errorf("UpdateTargets() error: %v", err) } // 验证目标已更新 targets := p.GetTargets() if len(targets) != len(newTargets) { t.Errorf("UpdateTargets() targets count = %d, want %d", len(targets), len(newTargets)) } // 验证空目标列表返回错误 err = p.UpdateTargets([]*loadbalance.Target{}) if err == nil { t.Error("UpdateTargets([]) should return error") } // 验证nil目标列表返回错误 err = p.UpdateTargets(nil) if err == nil { t.Error("UpdateTargets(nil) should return error") } } // TestGetTargets 测试获取目标列表 func TestGetTargets(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } targets := []*loadbalance.Target{ {URL: "http://backend1:8080", Healthy: true}, {URL: "http://backend2:8080", Healthy: true}, } p, err := NewProxy(cfg, targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } gotTargets := p.GetTargets() if len(gotTargets) != len(targets) { t.Errorf("GetTargets() returned %d targets, want %d", len(gotTargets), len(targets)) } for i, target := range gotTargets { if target.URL != targets[i].URL { t.Errorf("GetTargets()[%d].URL = %q, want %q", i, target.URL, targets[i].URL) } } } // TestGetConfig 测试获取配置 func TestGetConfig(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } targets := []*loadbalance.Target{ {URL: "http://localhost:8080", Healthy: true}, } p, err := NewProxy(cfg, targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } gotConfig := p.GetConfig() if gotConfig != cfg { t.Error("GetConfig() returned different config") } if gotConfig.Path != cfg.Path { t.Errorf("GetConfig().Path = %q, want %q", gotConfig.Path, cfg.Path) } } // TestIsWebSocketRequest 测试WebSocket请求检测 func TestIsWebSocketRequest(t *testing.T) { tests := []struct { name string upgrade string connection string expected bool }{ { name: "标准WebSocket请求", upgrade: "websocket", connection: "upgrade", expected: true, }, { name: "大小写不敏感", upgrade: "WebSocket", connection: "Upgrade", expected: true, }, { name: "非WebSocket升级", upgrade: "h2c", connection: "upgrade", expected: false, }, { name: "非upgrade连接", upgrade: "websocket", connection: "keep-alive", expected: false, }, { name: "keep-alive, Upgrade", upgrade: "websocket", connection: "keep-alive, Upgrade", expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := &fasthttp.RequestCtx{} if tt.upgrade != "" { ctx.Request.Header.Set("Upgrade", tt.upgrade) } if tt.connection != "" { ctx.Request.Header.Set("Connection", tt.connection) } result := isWebSocketRequest(ctx) if result != tt.expected { t.Errorf("isWebSocketRequest() = %v, want %v", result, tt.expected) } }) } } // TestCreateBalancer 测试负载均衡器创建 func TestCreateBalancer(t *testing.T) { tests := []struct { name string algorithm string wantErr bool errContains string }{ { name: "轮询", algorithm: "round_robin", wantErr: false, }, { name: "加权轮询", algorithm: "weighted_round_robin", wantErr: false, }, { name: "最少连接", algorithm: "least_conn", wantErr: false, }, { name: "IP哈希", algorithm: "ip_hash", wantErr: false, }, { name: "空算法(默认轮询)", algorithm: "", wantErr: false, }, { name: "无效算法", algorithm: "unknown_algorithm", wantErr: true, errContains: "unsupported load balance algorithm", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { balancer, err := createBalancer(tt.algorithm) if tt.wantErr { if err == nil { t.Errorf("createBalancer(%q) expected error", tt.algorithm) return } if !contains(err.Error(), tt.errContains) { t.Errorf("createBalancer(%q) error = %v, want containing %q", tt.algorithm, err, tt.errContains) } return } if err != nil { t.Errorf("createBalancer(%q) unexpected error: %v", tt.algorithm, err) return } if balancer == nil { t.Errorf("createBalancer(%q) returned nil balancer", tt.algorithm) } }) } } // TestCreateHostClient 测试HostClient创建 func TestCreateHostClient(t *testing.T) { tests := []struct { name string targetURL string timeout config.ProxyTimeout }{ { name: "HTTP地址", targetURL: "http://localhost:8080", timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, }, { name: "HTTPS地址", targetURL: "https://localhost:8443", timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, }, { name: "带路径的URL", targetURL: "http://localhost:8080/path", timeout: config.ProxyTimeout{Connect: 5 * time.Second}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := createHostClient(tt.targetURL, tt.timeout) if client == nil { t.Error("createHostClient() returned nil") return } // 检查基本属性 if client.Addr == "" { t.Error("createHostClient() client.Addr is empty") } if tt.targetURL == "https://localhost:8443" && !client.IsTLS { t.Error("createHostClient() IsTLS should be true for HTTPS") } }) } } // TestHandleWebSocket 测试WebSocket处理 func TestHandleWebSocket(t *testing.T) { cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } targets := []*loadbalance.Target{ {URL: "http://localhost:8080", Healthy: true}, } p, err := NewProxy(cfg, targets) if err != nil { t.Fatalf("NewProxy() error: %v", err) } ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("Upgrade", "websocket") ctx.Request.Header.Set("Connection", "upgrade") target := &loadbalance.Target{URL: "http://localhost:8080"} client := p.getClient(target.URL) p.handleWebSocket(ctx, target, client) // WebSocket应该返回501 Not Implemented if ctx.Response.StatusCode() != fasthttp.StatusNotImplemented { t.Errorf("handleWebSocket() status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusNotImplemented) } } // 辅助函数 func contains(s, substr string) bool { return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsAt(s, substr, 0)) } func containsAt(s, substr string, start int) bool { for i := start; i <= len(s)-len(substr); i++ { if s[i:i+len(substr)] == substr { return true } } return false }