diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 694b28a..6a9a971 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -412,6 +412,8 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(" # hide_response: [] # 隐藏的响应头列表\n") buf.WriteString(" # pass_response: [] # 白名单传递的响应头\n") buf.WriteString(" # ignore_headers: [] # 完全忽略的头部(不传递给客户端也不记录)\n") + buf.WriteString(" # set_forwarded_host: true # 是否设置 X-Forwarded-Host(nil/true=设置,false=不设置)\n") + buf.WriteString(" # set_forwarded_proto: true # 是否设置 X-Forwarded-Proto(nil/true=设置,false=不设置)\n") buf.WriteString(" # cookie_domain: \"\" # Cookie 域重写\n") buf.WriteString(" # cookie_path: \"\" # Cookie 路径重写\n") buf.WriteString(" # cache: # 代理缓存\n") diff --git a/internal/config/proxy_config.go b/internal/config/proxy_config.go index bb64dcc..6bfa6be 100644 --- a/internal/config/proxy_config.go +++ b/internal/config/proxy_config.go @@ -334,6 +334,18 @@ type ProxyHeaders struct { // CookiePath Cookie 路径重写 // 将响应中 Set-Cookie 的 path 替换为此值 CookiePath string `yaml:"cookie_path"` + + // SetForwardedHost 控制 X-Forwarded-Host 头的设置 + // nil (默认): 设置 X-Forwarded-Host(向后兼容) + // true: 显式设置 X-Forwarded-Host + // false: 不设置 X-Forwarded-Host + SetForwardedHost *bool `yaml:"set_forwarded_host"` + + // SetForwardedProto 控制 X-Forwarded-Proto 头的设置 + // nil (默认): 设置 X-Forwarded-Proto(向后兼容) + // true: 显式设置 X-Forwarded-Proto + // false: 不设置 X-Forwarded-Proto + SetForwardedProto *bool `yaml:"set_forwarded_proto"` } // ProxySSLConfig 上游 SSL/TLS 配置。 diff --git a/internal/proxy/header_modifier.go b/internal/proxy/header_modifier.go index 7148fa5..6e988ba 100644 --- a/internal/proxy/header_modifier.go +++ b/internal/proxy/header_modifier.go @@ -32,7 +32,18 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan // 提取并设置 X-Forwarded 系列头 fh := ExtractForwardedHeaders(ctx) - SetForwardedHeaders(headers, fh, true) + + // 根据配置决定是否设置 X-Forwarded-Host 和 X-Forwarded-Proto + setHost := true // 默认值(向后兼容) + if p.config.Headers.SetForwardedHost != nil { + setHost = *p.config.Headers.SetForwardedHost + } + setProto := true // 默认值(向后兼容) + if p.config.Headers.SetForwardedProto != nil { + setProto = *p.config.Headers.SetForwardedProto + } + + SetForwardedHeaders(headers, fh, true, setHost, setProto) // 从配置设置自定义请求头(支持变量展开) if p.config.Headers.SetRequest != nil { diff --git a/internal/proxy/headers.go b/internal/proxy/headers.go index 405de6d..11e677a 100644 --- a/internal/proxy/headers.go +++ b/internal/proxy/headers.go @@ -69,7 +69,9 @@ func ExtractForwardedHeaders(ctx *fasthttp.RequestCtx) ForwardedHeaders { // - headers: 目标请求头 // - fh: ForwardedHeaders 结构体 // - appendXFF: 是否追加到已有的 X-Forwarded-For 头 -func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, appendXFF bool) { +// - setHost: 是否设置 X-Forwarded-Host +// - setProto: 是否设置 X-Forwarded-Proto +func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, appendXFF, setHost, setProto bool) { // 设置 X-Real-IP if fh.ClientIP != "" { headers.Set("X-Real-IP", fh.ClientIP) @@ -94,13 +96,13 @@ func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, a } } - // 设置 X-Forwarded-Host - if fh.Host != "" { + // 设置 X-Forwarded-Host(仅在 setHost 为 true 时) + if setHost && fh.Host != "" { headers.Set("X-Forwarded-Host", fh.Host) } - // 设置 X-Forwarded-Proto - if fh.Proto != "" { + // 设置 X-Forwarded-Proto(仅在 setProto 为 true 时) + if setProto && fh.Proto != "" { headers.Set("X-Forwarded-Proto", fh.Proto) } } @@ -111,7 +113,9 @@ func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, a // 参数: // - builder: strings.Builder 实例 // - fh: ForwardedHeaders 结构体 -func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders) { +// - setHost: 是否设置 X-Forwarded-Host +// - setProto: 是否设置 X-Forwarded-Proto +func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders, setHost, setProto bool) { if fh.ClientIP != "" { builder.WriteString("X-Forwarded-For: ") builder.WriteString(fh.ClientIP) @@ -121,13 +125,13 @@ func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders) { builder.WriteString("\r\n") } - if fh.Host != "" { + if setHost && fh.Host != "" { builder.WriteString("X-Forwarded-Host: ") builder.WriteString(fh.Host) builder.WriteString("\r\n") } - if fh.Proto != "" { + if setProto && fh.Proto != "" { builder.WriteString("X-Forwarded-Proto: ") builder.WriteString(fh.Proto) builder.WriteString("\r\n") diff --git a/internal/proxy/headers_test.go b/internal/proxy/headers_test.go new file mode 100644 index 0000000..0e7e2c4 --- /dev/null +++ b/internal/proxy/headers_test.go @@ -0,0 +1,264 @@ +package proxy + +import ( + "strings" + "testing" + + "github.com/valyala/fasthttp" +) + +// TestSetForwardedHeaders_SetHost 测试 SetForwardedHost 配置对 X-Forwarded-Host 头的控制 +func TestSetForwardedHeaders_SetHost(t *testing.T) { + tests := []struct { + name string + setHost bool + expectHost bool + }{ + { + name: "setHost=true sets X-Forwarded-Host", + setHost: true, + expectHost: true, + }, + { + name: "setHost=false does not set X-Forwarded-Host", + setHost: false, + expectHost: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := &fasthttp.RequestHeader{} + fh := ForwardedHeaders{ + ClientIP: "192.168.1.1", + Host: "example.com:8080", + Proto: "https", + } + + // setProto=true 因为我们要测试 setHost 的效果 + SetForwardedHeaders(headers, fh, true, tt.setHost, true) + + hasHost := len(headers.Peek("X-Forwarded-Host")) > 0 + if hasHost != tt.expectHost { + t.Errorf("X-Forwarded-Host presence = %v, want %v", hasHost, tt.expectHost) + } + + // X-Forwarded-For 和 X-Real-IP 应该始终设置 + if len(headers.Peek("X-Forwarded-For")) == 0 { + t.Error("X-Forwarded-For should always be set") + } + if len(headers.Peek("X-Real-IP")) == 0 { + t.Error("X-Real-IP should always be set") + } + }) + } +} + +// TestSetForwardedHeaders_SetProto 测试 SetForwardedProto 配置对 X-Forwarded-Proto 头的控制 +func TestSetForwardedHeaders_SetProto(t *testing.T) { + tests := []struct { + name string + setProto bool + expectProto bool + }{ + { + name: "setProto=true sets X-Forwarded-Proto", + setProto: true, + expectProto: true, + }, + { + name: "setProto=false does not set X-Forwarded-Proto", + setProto: false, + expectProto: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + headers := &fasthttp.RequestHeader{} + fh := ForwardedHeaders{ + ClientIP: "192.168.1.1", + Host: "example.com:8080", + Proto: "https", + } + + // setHost=true 因为我们要测试 setProto 的效果 + SetForwardedHeaders(headers, fh, true, true, tt.setProto) + + hasProto := len(headers.Peek("X-Forwarded-Proto")) > 0 + if hasProto != tt.expectProto { + t.Errorf("X-Forwarded-Proto presence = %v, want %v", hasProto, tt.expectProto) + } + }) + } +} + +// TestSetForwardedHeaders_DefaultBehavior 测试默认行为(所有参数为 true) +func TestSetForwardedHeaders_DefaultBehavior(t *testing.T) { + headers := &fasthttp.RequestHeader{} + fh := ForwardedHeaders{ + ClientIP: "10.0.0.1", + Host: "localhost:8082", + Proto: "http", + } + + SetForwardedHeaders(headers, fh, true, true, true) + + // 验证所有头都被设置 + if string(headers.Peek("X-Forwarded-For")) != "10.0.0.1" { + t.Errorf("X-Forwarded-For = %s, want 10.0.0.1", headers.Peek("X-Forwarded-For")) + } + if string(headers.Peek("X-Real-IP")) != "10.0.0.1" { + t.Errorf("X-Real-IP = %s, want 10.0.0.1", headers.Peek("X-Real-IP")) + } + if string(headers.Peek("X-Forwarded-Host")) != "localhost:8082" { + t.Errorf("X-Forwarded-Host = %s, want localhost:8082", headers.Peek("X-Forwarded-Host")) + } + if string(headers.Peek("X-Forwarded-Proto")) != "http" { + t.Errorf("X-Forwarded-Proto = %s, want http", headers.Peek("X-Forwarded-Proto")) + } +} + +// TestSetForwardedHeaders_AllDisabled 测试所有控制参数为 false +func TestSetForwardedHeaders_AllDisabled(t *testing.T) { + headers := &fasthttp.RequestHeader{} + fh := ForwardedHeaders{ + ClientIP: "10.0.0.1", + Host: "localhost:8082", + Proto: "http", + } + + SetForwardedHeaders(headers, fh, true, false, false) + + // X-Forwarded-For 和 X-Real-IP 应该始终设置 + if len(headers.Peek("X-Forwarded-For")) == 0 { + t.Error("X-Forwarded-For should be set even when setHost/setProto are false") + } + if len(headers.Peek("X-Real-IP")) == 0 { + t.Error("X-Real-IP should be set even when setHost/setProto are false") + } + + // X-Forwarded-Host 和 X-Forwarded-Proto 不应该设置 + if len(headers.Peek("X-Forwarded-Host")) > 0 { + t.Error("X-Forwarded-Host should not be set when setHost=false") + } + if len(headers.Peek("X-Forwarded-Proto")) > 0 { + t.Error("X-Forwarded-Proto should not be set when setProto=false") + } +} + +// TestWriteForwardedHeaders_SetHost 测试 WriteForwardedHeaders 的 setHost 参数 +func TestWriteForwardedHeaders_SetHost(t *testing.T) { + tests := []struct { + name string + setHost bool + expectHost bool + }{ + { + name: "setHost=true writes X-Forwarded-Host", + setHost: true, + expectHost: true, + }, + { + name: "setHost=false does not write X-Forwarded-Host", + setHost: false, + expectHost: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var builder strings.Builder + fh := ForwardedHeaders{ + ClientIP: "192.168.1.1", + Host: "example.com:8080", + Proto: "https", + } + + WriteForwardedHeaders(&builder, fh, tt.setHost, true) + + result := builder.String() + hasHost := strings.Contains(result, "X-Forwarded-Host:") + if hasHost != tt.expectHost { + t.Errorf("X-Forwarded-Host presence = %v, want %v", hasHost, tt.expectHost) + } + + // X-Forwarded-For 和 X-Real-IP 应该始终存在 + if !strings.Contains(result, "X-Forwarded-For:") { + t.Error("X-Forwarded-For should always be written") + } + if !strings.Contains(result, "X-Real-IP:") { + t.Error("X-Real-IP should always be written") + } + }) + } +} + +// TestWriteForwardedHeaders_SetProto 测试 WriteForwardedHeaders 的 setProto 参数 +func TestWriteForwardedHeaders_SetProto(t *testing.T) { + tests := []struct { + name string + setProto bool + expectProto bool + }{ + { + name: "setProto=true writes X-Forwarded-Proto", + setProto: true, + expectProto: true, + }, + { + name: "setProto=false does not write X-Forwarded-Proto", + setProto: false, + expectProto: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var builder strings.Builder + fh := ForwardedHeaders{ + ClientIP: "192.168.1.1", + Host: "example.com:8080", + Proto: "https", + } + + WriteForwardedHeaders(&builder, fh, true, tt.setProto) + + result := builder.String() + hasProto := strings.Contains(result, "X-Forwarded-Proto:") + if hasProto != tt.expectProto { + t.Errorf("X-Forwarded-Proto presence = %v, want %v", hasProto, tt.expectProto) + } + }) + } +} + +// TestWriteForwardedHeaders_AllDisabled 测试 WriteForwardedHeaders 所有控制参数为 false +func TestWriteForwardedHeaders_AllDisabled(t *testing.T) { + var builder strings.Builder + fh := ForwardedHeaders{ + ClientIP: "10.0.0.1", + Host: "localhost:8082", + Proto: "http", + } + + WriteForwardedHeaders(&builder, fh, false, false) + + result := builder.String() + + // X-Forwarded-For 和 X-Real-IP 应该始终存在 + if !strings.Contains(result, "X-Forwarded-For:") { + t.Error("X-Forwarded-For should always be written") + } + if !strings.Contains(result, "X-Real-IP:") { + t.Error("X-Real-IP should always be written") + } + + // X-Forwarded-Host 和 X-Forwarded-Proto 不应该存在 + if strings.Contains(result, "X-Forwarded-Host:") { + t.Error("X-Forwarded-Host should not be written when setHost=false") + } + if strings.Contains(result, "X-Forwarded-Proto:") { + t.Error("X-Forwarded-Proto should not be written when setProto=false") + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 7ca2568..7b86a81 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -572,7 +572,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { // WebSocket 使用 defer 确保连接计数释放 defer loadbalance.DecrementConnections(target) timing.MarkConnectStart() - err := WebSocket(ctx, target, p.config.Timeout.Connect) + err := WebSocket(ctx, target, p.config.Timeout.Connect, &p.config.Headers) timing.MarkConnectEnd() if err != nil { upstreamStatus = 502 diff --git a/internal/proxy/proxy_coverage_extra_test.go b/internal/proxy/proxy_coverage_extra_test.go index decc9ad..fcdeb67 100644 --- a/internal/proxy/proxy_coverage_extra_test.go +++ b/internal/proxy/proxy_coverage_extra_test.go @@ -1055,7 +1055,7 @@ func TestWebSocket_ErrorCases(t *testing.T) { target.Healthy.Store(true) // 使用很短的超时 - err := WebSocket(ctx, target, 10*time.Millisecond) + err := WebSocket(ctx, target, 10*time.Millisecond, nil) if err == nil { t.Error("WebSocket() should return error for invalid backend") } diff --git a/internal/proxy/websocket.go b/internal/proxy/websocket.go index d53ef0e..59e7c17 100644 --- a/internal/proxy/websocket.go +++ b/internal/proxy/websocket.go @@ -32,6 +32,7 @@ import ( "time" "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/netutil" ) @@ -260,10 +261,11 @@ func dialTarget(targetURL string, timeout time.Duration) (net.Conn, error) { // 参数: // - ctx: FastHTTP 请求上下文 // - targetHost: 目标主机地址 +// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置 // // 返回值: // - string: 完整的 HTTP 请求字符串 -func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) string { +func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string, headersConfig *config.ProxyHeaders) string { // 构建请求行 path := string(ctx.Path()) if path == "" { @@ -300,7 +302,18 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s // 添加 X-Forwarded 头 fh := ExtractForwardedHeaders(ctx) - WriteForwardedHeaders(&req, fh) + + // 根据配置决定是否设置 X-Forwarded-Host 和 X-Forwarded-Proto + setHost := true // 默认值(向后兼容) + if headersConfig != nil && headersConfig.SetForwardedHost != nil { + setHost = *headersConfig.SetForwardedHost + } + setProto := true // 默认值(向后兼容) + if headersConfig != nil && headersConfig.SetForwardedProto != nil { + setProto = *headersConfig.SetForwardedProto + } + + WriteForwardedHeaders(&req, fh, setHost, setProto) // 结束请求头 req.WriteString("\r\n") @@ -348,10 +361,11 @@ func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.R // - ctx: FastHTTP 请求上下文 // - target: 负载均衡目标,包含后端 URL // - timeout: 连接和 I/O 超时时间 +// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置 // // 返回值: // - error: 代理过程中的错误 -func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration) error { +func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration, headersConfig *config.ProxyHeaders) error { // 使用 Hijack 获取客户端 TCP 连接 var clientConn net.Conn @@ -380,7 +394,7 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim targetHost := extractHost(target.URL) // 步骤3: 构建并发送 WebSocket 升级请求 - upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost) + upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost, headersConfig) if _, writeErr := targetConn.Write([]byte(upgradeReq)); writeErr != nil { return fmt.Errorf("failed to send upgrade request: %w", writeErr) } diff --git a/internal/proxy/websocket_bench_test.go b/internal/proxy/websocket_bench_test.go index aeb3443..af91032 100644 --- a/internal/proxy/websocket_bench_test.go +++ b/internal/proxy/websocket_bench_test.go @@ -46,7 +46,7 @@ func BenchmarkWebSocketHandshake(b *testing.B) { ctx.Request.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate") ctx.Request.Header.Set("Origin", "https://example.com") - result := buildWebSocketUpgradeRequest(ctx, "backend.example.com:8080") + result := buildWebSocketUpgradeRequest(ctx, "backend.example.com:8080", nil) // 验证握手请求包含关键头 if !strings.Contains(result, "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==") { diff --git a/internal/proxy/websocket_test.go b/internal/proxy/websocket_test.go index cd7932c..78b2eaf 100644 --- a/internal/proxy/websocket_test.go +++ b/internal/proxy/websocket_test.go @@ -373,7 +373,7 @@ func TestBuildWebSocketUpgradeRequest(t *testing.T) { } ctx.Request.Header.SetHost(tt.host) - result := buildWebSocketUpgradeRequest(ctx, tt.targetHost) + result := buildWebSocketUpgradeRequest(ctx, tt.targetHost, nil) for _, want := range tt.wantContains { if !strings.Contains(result, want) { @@ -394,7 +394,7 @@ func TestBuildWebSocketUpgradeRequest_WithHeaders(t *testing.T) { ctx.Request.Header.Set("Sec-WebSocket-Version", "13") ctx.Request.Header.Set("Sec-WebSocket-Protocol", "chat") - result := buildWebSocketUpgradeRequest(ctx, "backend.example.com") + result := buildWebSocketUpgradeRequest(ctx, "backend.example.com", nil) // 验证关键头被复制 expectedHeaders := []string{ @@ -434,7 +434,7 @@ func TestBuildWebSocketUpgradeRequest_TLSProto(t *testing.T) { // 注意:fasthttp.RequestCtx 默认 IsTLS() 返回 false // 无法在单元测试中直接模拟 TLS 连接 - result := buildWebSocketUpgradeRequest(ctx, "backend.example.com") + result := buildWebSocketUpgradeRequest(ctx, "backend.example.com", nil) if !strings.Contains(result, tt.wantProto) { t.Errorf("Missing %q in:\n%s", tt.wantProto, result)