feat(proxy): add configurable X-Forwarded-Host and X-Forwarded-Proto headers

Add `set_forwarded_host` and `set_forwarded_proto` options to control
whether the proxy automatically sets these headers. This fixes issues
with upstream servers that validate X-Forwarded-Host against known hosts.

Changes:
- Add SetForwardedHost/SetForwardedProto fields to ProxyHeaders struct
- Modify SetForwardedHeaders and WriteForwardedHeaders function signatures
- Update modifyRequestHeaders to read config and pass control parameters
- Update WebSocket call chain to support new config
- Add unit tests for new functionality
- Update default config generation (-g) to include new options

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-05-07 13:28:28 +08:00
parent c02008cc6a
commit 144e101c09
10 changed files with 326 additions and 19 deletions

View File

@ -412,6 +412,8 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString(" # hide_response: [] # 隐藏的响应头列表\n")
buf.WriteString(" # pass_response: [] # 白名单传递的响应头\n")
buf.WriteString(" # ignore_headers: [] # 完全忽略的头部(不传递给客户端也不记录)\n")
buf.WriteString(" # set_forwarded_host: true # 是否设置 X-Forwarded-Hostnil/true=设置false=不设置)\n")
buf.WriteString(" # set_forwarded_proto: true # 是否设置 X-Forwarded-Protonil/true=设置false=不设置)\n")
buf.WriteString(" # cookie_domain: \"\" # Cookie 域重写\n")
buf.WriteString(" # cookie_path: \"\" # Cookie 路径重写\n")
buf.WriteString(" # cache: # 代理缓存\n")

View File

@ -334,6 +334,18 @@ type ProxyHeaders struct {
// CookiePath Cookie 路径重写
// 将响应中 Set-Cookie 的 path 替换为此值
CookiePath string `yaml:"cookie_path"`
// SetForwardedHost 控制 X-Forwarded-Host 头的设置
// nil (默认): 设置 X-Forwarded-Host向后兼容
// true: 显式设置 X-Forwarded-Host
// false: 不设置 X-Forwarded-Host
SetForwardedHost *bool `yaml:"set_forwarded_host"`
// SetForwardedProto 控制 X-Forwarded-Proto 头的设置
// nil (默认): 设置 X-Forwarded-Proto向后兼容
// true: 显式设置 X-Forwarded-Proto
// false: 不设置 X-Forwarded-Proto
SetForwardedProto *bool `yaml:"set_forwarded_proto"`
}
// ProxySSLConfig 上游 SSL/TLS 配置。

View File

@ -32,7 +32,18 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
// 提取并设置 X-Forwarded 系列头
fh := ExtractForwardedHeaders(ctx)
SetForwardedHeaders(headers, fh, true)
// 根据配置决定是否设置 X-Forwarded-Host 和 X-Forwarded-Proto
setHost := true // 默认值(向后兼容)
if p.config.Headers.SetForwardedHost != nil {
setHost = *p.config.Headers.SetForwardedHost
}
setProto := true // 默认值(向后兼容)
if p.config.Headers.SetForwardedProto != nil {
setProto = *p.config.Headers.SetForwardedProto
}
SetForwardedHeaders(headers, fh, true, setHost, setProto)
// 从配置设置自定义请求头(支持变量展开)
if p.config.Headers.SetRequest != nil {

View File

@ -69,7 +69,9 @@ func ExtractForwardedHeaders(ctx *fasthttp.RequestCtx) ForwardedHeaders {
// - headers: 目标请求头
// - fh: ForwardedHeaders 结构体
// - appendXFF: 是否追加到已有的 X-Forwarded-For 头
func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, appendXFF bool) {
// - setHost: 是否设置 X-Forwarded-Host
// - setProto: 是否设置 X-Forwarded-Proto
func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, appendXFF, setHost, setProto bool) {
// 设置 X-Real-IP
if fh.ClientIP != "" {
headers.Set("X-Real-IP", fh.ClientIP)
@ -94,13 +96,13 @@ func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, a
}
}
// 设置 X-Forwarded-Host
if fh.Host != "" {
// 设置 X-Forwarded-Host(仅在 setHost 为 true 时)
if setHost && fh.Host != "" {
headers.Set("X-Forwarded-Host", fh.Host)
}
// 设置 X-Forwarded-Proto
if fh.Proto != "" {
// 设置 X-Forwarded-Proto(仅在 setProto 为 true 时)
if setProto && fh.Proto != "" {
headers.Set("X-Forwarded-Proto", fh.Proto)
}
}
@ -111,7 +113,9 @@ func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, a
// 参数:
// - builder: strings.Builder 实例
// - fh: ForwardedHeaders 结构体
func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders) {
// - setHost: 是否设置 X-Forwarded-Host
// - setProto: 是否设置 X-Forwarded-Proto
func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders, setHost, setProto bool) {
if fh.ClientIP != "" {
builder.WriteString("X-Forwarded-For: ")
builder.WriteString(fh.ClientIP)
@ -121,13 +125,13 @@ func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders) {
builder.WriteString("\r\n")
}
if fh.Host != "" {
if setHost && fh.Host != "" {
builder.WriteString("X-Forwarded-Host: ")
builder.WriteString(fh.Host)
builder.WriteString("\r\n")
}
if fh.Proto != "" {
if setProto && fh.Proto != "" {
builder.WriteString("X-Forwarded-Proto: ")
builder.WriteString(fh.Proto)
builder.WriteString("\r\n")

View File

@ -0,0 +1,264 @@
package proxy
import (
"strings"
"testing"
"github.com/valyala/fasthttp"
)
// TestSetForwardedHeaders_SetHost 测试 SetForwardedHost 配置对 X-Forwarded-Host 头的控制
func TestSetForwardedHeaders_SetHost(t *testing.T) {
tests := []struct {
name string
setHost bool
expectHost bool
}{
{
name: "setHost=true sets X-Forwarded-Host",
setHost: true,
expectHost: true,
},
{
name: "setHost=false does not set X-Forwarded-Host",
setHost: false,
expectHost: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headers := &fasthttp.RequestHeader{}
fh := ForwardedHeaders{
ClientIP: "192.168.1.1",
Host: "example.com:8080",
Proto: "https",
}
// setProto=true 因为我们要测试 setHost 的效果
SetForwardedHeaders(headers, fh, true, tt.setHost, true)
hasHost := len(headers.Peek("X-Forwarded-Host")) > 0
if hasHost != tt.expectHost {
t.Errorf("X-Forwarded-Host presence = %v, want %v", hasHost, tt.expectHost)
}
// X-Forwarded-For 和 X-Real-IP 应该始终设置
if len(headers.Peek("X-Forwarded-For")) == 0 {
t.Error("X-Forwarded-For should always be set")
}
if len(headers.Peek("X-Real-IP")) == 0 {
t.Error("X-Real-IP should always be set")
}
})
}
}
// TestSetForwardedHeaders_SetProto 测试 SetForwardedProto 配置对 X-Forwarded-Proto 头的控制
func TestSetForwardedHeaders_SetProto(t *testing.T) {
tests := []struct {
name string
setProto bool
expectProto bool
}{
{
name: "setProto=true sets X-Forwarded-Proto",
setProto: true,
expectProto: true,
},
{
name: "setProto=false does not set X-Forwarded-Proto",
setProto: false,
expectProto: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
headers := &fasthttp.RequestHeader{}
fh := ForwardedHeaders{
ClientIP: "192.168.1.1",
Host: "example.com:8080",
Proto: "https",
}
// setHost=true 因为我们要测试 setProto 的效果
SetForwardedHeaders(headers, fh, true, true, tt.setProto)
hasProto := len(headers.Peek("X-Forwarded-Proto")) > 0
if hasProto != tt.expectProto {
t.Errorf("X-Forwarded-Proto presence = %v, want %v", hasProto, tt.expectProto)
}
})
}
}
// TestSetForwardedHeaders_DefaultBehavior 测试默认行为(所有参数为 true
func TestSetForwardedHeaders_DefaultBehavior(t *testing.T) {
headers := &fasthttp.RequestHeader{}
fh := ForwardedHeaders{
ClientIP: "10.0.0.1",
Host: "localhost:8082",
Proto: "http",
}
SetForwardedHeaders(headers, fh, true, true, true)
// 验证所有头都被设置
if string(headers.Peek("X-Forwarded-For")) != "10.0.0.1" {
t.Errorf("X-Forwarded-For = %s, want 10.0.0.1", headers.Peek("X-Forwarded-For"))
}
if string(headers.Peek("X-Real-IP")) != "10.0.0.1" {
t.Errorf("X-Real-IP = %s, want 10.0.0.1", headers.Peek("X-Real-IP"))
}
if string(headers.Peek("X-Forwarded-Host")) != "localhost:8082" {
t.Errorf("X-Forwarded-Host = %s, want localhost:8082", headers.Peek("X-Forwarded-Host"))
}
if string(headers.Peek("X-Forwarded-Proto")) != "http" {
t.Errorf("X-Forwarded-Proto = %s, want http", headers.Peek("X-Forwarded-Proto"))
}
}
// TestSetForwardedHeaders_AllDisabled 测试所有控制参数为 false
func TestSetForwardedHeaders_AllDisabled(t *testing.T) {
headers := &fasthttp.RequestHeader{}
fh := ForwardedHeaders{
ClientIP: "10.0.0.1",
Host: "localhost:8082",
Proto: "http",
}
SetForwardedHeaders(headers, fh, true, false, false)
// X-Forwarded-For 和 X-Real-IP 应该始终设置
if len(headers.Peek("X-Forwarded-For")) == 0 {
t.Error("X-Forwarded-For should be set even when setHost/setProto are false")
}
if len(headers.Peek("X-Real-IP")) == 0 {
t.Error("X-Real-IP should be set even when setHost/setProto are false")
}
// X-Forwarded-Host 和 X-Forwarded-Proto 不应该设置
if len(headers.Peek("X-Forwarded-Host")) > 0 {
t.Error("X-Forwarded-Host should not be set when setHost=false")
}
if len(headers.Peek("X-Forwarded-Proto")) > 0 {
t.Error("X-Forwarded-Proto should not be set when setProto=false")
}
}
// TestWriteForwardedHeaders_SetHost 测试 WriteForwardedHeaders 的 setHost 参数
func TestWriteForwardedHeaders_SetHost(t *testing.T) {
tests := []struct {
name string
setHost bool
expectHost bool
}{
{
name: "setHost=true writes X-Forwarded-Host",
setHost: true,
expectHost: true,
},
{
name: "setHost=false does not write X-Forwarded-Host",
setHost: false,
expectHost: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var builder strings.Builder
fh := ForwardedHeaders{
ClientIP: "192.168.1.1",
Host: "example.com:8080",
Proto: "https",
}
WriteForwardedHeaders(&builder, fh, tt.setHost, true)
result := builder.String()
hasHost := strings.Contains(result, "X-Forwarded-Host:")
if hasHost != tt.expectHost {
t.Errorf("X-Forwarded-Host presence = %v, want %v", hasHost, tt.expectHost)
}
// X-Forwarded-For 和 X-Real-IP 应该始终存在
if !strings.Contains(result, "X-Forwarded-For:") {
t.Error("X-Forwarded-For should always be written")
}
if !strings.Contains(result, "X-Real-IP:") {
t.Error("X-Real-IP should always be written")
}
})
}
}
// TestWriteForwardedHeaders_SetProto 测试 WriteForwardedHeaders 的 setProto 参数
func TestWriteForwardedHeaders_SetProto(t *testing.T) {
tests := []struct {
name string
setProto bool
expectProto bool
}{
{
name: "setProto=true writes X-Forwarded-Proto",
setProto: true,
expectProto: true,
},
{
name: "setProto=false does not write X-Forwarded-Proto",
setProto: false,
expectProto: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var builder strings.Builder
fh := ForwardedHeaders{
ClientIP: "192.168.1.1",
Host: "example.com:8080",
Proto: "https",
}
WriteForwardedHeaders(&builder, fh, true, tt.setProto)
result := builder.String()
hasProto := strings.Contains(result, "X-Forwarded-Proto:")
if hasProto != tt.expectProto {
t.Errorf("X-Forwarded-Proto presence = %v, want %v", hasProto, tt.expectProto)
}
})
}
}
// TestWriteForwardedHeaders_AllDisabled 测试 WriteForwardedHeaders 所有控制参数为 false
func TestWriteForwardedHeaders_AllDisabled(t *testing.T) {
var builder strings.Builder
fh := ForwardedHeaders{
ClientIP: "10.0.0.1",
Host: "localhost:8082",
Proto: "http",
}
WriteForwardedHeaders(&builder, fh, false, false)
result := builder.String()
// X-Forwarded-For 和 X-Real-IP 应该始终存在
if !strings.Contains(result, "X-Forwarded-For:") {
t.Error("X-Forwarded-For should always be written")
}
if !strings.Contains(result, "X-Real-IP:") {
t.Error("X-Real-IP should always be written")
}
// X-Forwarded-Host 和 X-Forwarded-Proto 不应该存在
if strings.Contains(result, "X-Forwarded-Host:") {
t.Error("X-Forwarded-Host should not be written when setHost=false")
}
if strings.Contains(result, "X-Forwarded-Proto:") {
t.Error("X-Forwarded-Proto should not be written when setProto=false")
}
}

View File

@ -572,7 +572,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// WebSocket 使用 defer 确保连接计数释放
defer loadbalance.DecrementConnections(target)
timing.MarkConnectStart()
err := WebSocket(ctx, target, p.config.Timeout.Connect)
err := WebSocket(ctx, target, p.config.Timeout.Connect, &p.config.Headers)
timing.MarkConnectEnd()
if err != nil {
upstreamStatus = 502

View File

@ -1055,7 +1055,7 @@ func TestWebSocket_ErrorCases(t *testing.T) {
target.Healthy.Store(true)
// 使用很短的超时
err := WebSocket(ctx, target, 10*time.Millisecond)
err := WebSocket(ctx, target, 10*time.Millisecond, nil)
if err == nil {
t.Error("WebSocket() should return error for invalid backend")
}

View File

@ -32,6 +32,7 @@ import (
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/netutil"
)
@ -260,10 +261,11 @@ func dialTarget(targetURL string, timeout time.Duration) (net.Conn, error) {
// 参数:
// - ctx: FastHTTP 请求上下文
// - targetHost: 目标主机地址
// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置
//
// 返回值:
// - string: 完整的 HTTP 请求字符串
func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) string {
func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string, headersConfig *config.ProxyHeaders) string {
// 构建请求行
path := string(ctx.Path())
if path == "" {
@ -300,7 +302,18 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
// 添加 X-Forwarded 头
fh := ExtractForwardedHeaders(ctx)
WriteForwardedHeaders(&req, fh)
// 根据配置决定是否设置 X-Forwarded-Host 和 X-Forwarded-Proto
setHost := true // 默认值(向后兼容)
if headersConfig != nil && headersConfig.SetForwardedHost != nil {
setHost = *headersConfig.SetForwardedHost
}
setProto := true // 默认值(向后兼容)
if headersConfig != nil && headersConfig.SetForwardedProto != nil {
setProto = *headersConfig.SetForwardedProto
}
WriteForwardedHeaders(&req, fh, setHost, setProto)
// 结束请求头
req.WriteString("\r\n")
@ -348,10 +361,11 @@ func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.R
// - ctx: FastHTTP 请求上下文
// - target: 负载均衡目标,包含后端 URL
// - timeout: 连接和 I/O 超时时间
// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置
//
// 返回值:
// - error: 代理过程中的错误
func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration) error {
func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration, headersConfig *config.ProxyHeaders) error {
// 使用 Hijack 获取客户端 TCP 连接
var clientConn net.Conn
@ -380,7 +394,7 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim
targetHost := extractHost(target.URL)
// 步骤3: 构建并发送 WebSocket 升级请求
upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost)
upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost, headersConfig)
if _, writeErr := targetConn.Write([]byte(upgradeReq)); writeErr != nil {
return fmt.Errorf("failed to send upgrade request: %w", writeErr)
}

View File

@ -46,7 +46,7 @@ func BenchmarkWebSocketHandshake(b *testing.B) {
ctx.Request.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate")
ctx.Request.Header.Set("Origin", "https://example.com")
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com:8080")
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com:8080", nil)
// 验证握手请求包含关键头
if !strings.Contains(result, "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==") {

View File

@ -373,7 +373,7 @@ func TestBuildWebSocketUpgradeRequest(t *testing.T) {
}
ctx.Request.Header.SetHost(tt.host)
result := buildWebSocketUpgradeRequest(ctx, tt.targetHost)
result := buildWebSocketUpgradeRequest(ctx, tt.targetHost, nil)
for _, want := range tt.wantContains {
if !strings.Contains(result, want) {
@ -394,7 +394,7 @@ func TestBuildWebSocketUpgradeRequest_WithHeaders(t *testing.T) {
ctx.Request.Header.Set("Sec-WebSocket-Version", "13")
ctx.Request.Header.Set("Sec-WebSocket-Protocol", "chat")
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com")
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com", nil)
// 验证关键头被复制
expectedHeaders := []string{
@ -434,7 +434,7 @@ func TestBuildWebSocketUpgradeRequest_TLSProto(t *testing.T) {
// 注意fasthttp.RequestCtx 默认 IsTLS() 返回 false
// 无法在单元测试中直接模拟 TLS 连接
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com")
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com", nil)
if !strings.Contains(result, tt.wantProto) {
t.Errorf("Missing %q in:\n%s", tt.wantProto, result)