feat(proxy): add configurable X-Forwarded-Host and X-Forwarded-Proto headers
Add `set_forwarded_host` and `set_forwarded_proto` options to control whether the proxy automatically sets these headers. This fixes issues with upstream servers that validate X-Forwarded-Host against known hosts. Changes: - Add SetForwardedHost/SetForwardedProto fields to ProxyHeaders struct - Modify SetForwardedHeaders and WriteForwardedHeaders function signatures - Update modifyRequestHeaders to read config and pass control parameters - Update WebSocket call chain to support new config - Add unit tests for new functionality - Update default config generation (-g) to include new options Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
c02008cc6a
commit
144e101c09
@ -412,6 +412,8 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
buf.WriteString(" # hide_response: [] # 隐藏的响应头列表\n")
|
buf.WriteString(" # hide_response: [] # 隐藏的响应头列表\n")
|
||||||
buf.WriteString(" # pass_response: [] # 白名单传递的响应头\n")
|
buf.WriteString(" # pass_response: [] # 白名单传递的响应头\n")
|
||||||
buf.WriteString(" # ignore_headers: [] # 完全忽略的头部(不传递给客户端也不记录)\n")
|
buf.WriteString(" # ignore_headers: [] # 完全忽略的头部(不传递给客户端也不记录)\n")
|
||||||
|
buf.WriteString(" # set_forwarded_host: true # 是否设置 X-Forwarded-Host(nil/true=设置,false=不设置)\n")
|
||||||
|
buf.WriteString(" # set_forwarded_proto: true # 是否设置 X-Forwarded-Proto(nil/true=设置,false=不设置)\n")
|
||||||
buf.WriteString(" # cookie_domain: \"\" # Cookie 域重写\n")
|
buf.WriteString(" # cookie_domain: \"\" # Cookie 域重写\n")
|
||||||
buf.WriteString(" # cookie_path: \"\" # Cookie 路径重写\n")
|
buf.WriteString(" # cookie_path: \"\" # Cookie 路径重写\n")
|
||||||
buf.WriteString(" # cache: # 代理缓存\n")
|
buf.WriteString(" # cache: # 代理缓存\n")
|
||||||
|
|||||||
@ -334,6 +334,18 @@ type ProxyHeaders struct {
|
|||||||
// CookiePath Cookie 路径重写
|
// CookiePath Cookie 路径重写
|
||||||
// 将响应中 Set-Cookie 的 path 替换为此值
|
// 将响应中 Set-Cookie 的 path 替换为此值
|
||||||
CookiePath string `yaml:"cookie_path"`
|
CookiePath string `yaml:"cookie_path"`
|
||||||
|
|
||||||
|
// SetForwardedHost 控制 X-Forwarded-Host 头的设置
|
||||||
|
// nil (默认): 设置 X-Forwarded-Host(向后兼容)
|
||||||
|
// true: 显式设置 X-Forwarded-Host
|
||||||
|
// false: 不设置 X-Forwarded-Host
|
||||||
|
SetForwardedHost *bool `yaml:"set_forwarded_host"`
|
||||||
|
|
||||||
|
// SetForwardedProto 控制 X-Forwarded-Proto 头的设置
|
||||||
|
// nil (默认): 设置 X-Forwarded-Proto(向后兼容)
|
||||||
|
// true: 显式设置 X-Forwarded-Proto
|
||||||
|
// false: 不设置 X-Forwarded-Proto
|
||||||
|
SetForwardedProto *bool `yaml:"set_forwarded_proto"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxySSLConfig 上游 SSL/TLS 配置。
|
// ProxySSLConfig 上游 SSL/TLS 配置。
|
||||||
|
|||||||
@ -32,7 +32,18 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
|
|||||||
|
|
||||||
// 提取并设置 X-Forwarded 系列头
|
// 提取并设置 X-Forwarded 系列头
|
||||||
fh := ExtractForwardedHeaders(ctx)
|
fh := ExtractForwardedHeaders(ctx)
|
||||||
SetForwardedHeaders(headers, fh, true)
|
|
||||||
|
// 根据配置决定是否设置 X-Forwarded-Host 和 X-Forwarded-Proto
|
||||||
|
setHost := true // 默认值(向后兼容)
|
||||||
|
if p.config.Headers.SetForwardedHost != nil {
|
||||||
|
setHost = *p.config.Headers.SetForwardedHost
|
||||||
|
}
|
||||||
|
setProto := true // 默认值(向后兼容)
|
||||||
|
if p.config.Headers.SetForwardedProto != nil {
|
||||||
|
setProto = *p.config.Headers.SetForwardedProto
|
||||||
|
}
|
||||||
|
|
||||||
|
SetForwardedHeaders(headers, fh, true, setHost, setProto)
|
||||||
|
|
||||||
// 从配置设置自定义请求头(支持变量展开)
|
// 从配置设置自定义请求头(支持变量展开)
|
||||||
if p.config.Headers.SetRequest != nil {
|
if p.config.Headers.SetRequest != nil {
|
||||||
|
|||||||
@ -69,7 +69,9 @@ func ExtractForwardedHeaders(ctx *fasthttp.RequestCtx) ForwardedHeaders {
|
|||||||
// - headers: 目标请求头
|
// - headers: 目标请求头
|
||||||
// - fh: ForwardedHeaders 结构体
|
// - fh: ForwardedHeaders 结构体
|
||||||
// - appendXFF: 是否追加到已有的 X-Forwarded-For 头
|
// - appendXFF: 是否追加到已有的 X-Forwarded-For 头
|
||||||
func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, appendXFF bool) {
|
// - setHost: 是否设置 X-Forwarded-Host
|
||||||
|
// - setProto: 是否设置 X-Forwarded-Proto
|
||||||
|
func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, appendXFF, setHost, setProto bool) {
|
||||||
// 设置 X-Real-IP
|
// 设置 X-Real-IP
|
||||||
if fh.ClientIP != "" {
|
if fh.ClientIP != "" {
|
||||||
headers.Set("X-Real-IP", fh.ClientIP)
|
headers.Set("X-Real-IP", fh.ClientIP)
|
||||||
@ -94,13 +96,13 @@ func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, a
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置 X-Forwarded-Host
|
// 设置 X-Forwarded-Host(仅在 setHost 为 true 时)
|
||||||
if fh.Host != "" {
|
if setHost && fh.Host != "" {
|
||||||
headers.Set("X-Forwarded-Host", fh.Host)
|
headers.Set("X-Forwarded-Host", fh.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置 X-Forwarded-Proto
|
// 设置 X-Forwarded-Proto(仅在 setProto 为 true 时)
|
||||||
if fh.Proto != "" {
|
if setProto && fh.Proto != "" {
|
||||||
headers.Set("X-Forwarded-Proto", fh.Proto)
|
headers.Set("X-Forwarded-Proto", fh.Proto)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -111,7 +113,9 @@ func SetForwardedHeaders(headers *fasthttp.RequestHeader, fh ForwardedHeaders, a
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - builder: strings.Builder 实例
|
// - builder: strings.Builder 实例
|
||||||
// - fh: ForwardedHeaders 结构体
|
// - fh: ForwardedHeaders 结构体
|
||||||
func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders) {
|
// - setHost: 是否设置 X-Forwarded-Host
|
||||||
|
// - setProto: 是否设置 X-Forwarded-Proto
|
||||||
|
func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders, setHost, setProto bool) {
|
||||||
if fh.ClientIP != "" {
|
if fh.ClientIP != "" {
|
||||||
builder.WriteString("X-Forwarded-For: ")
|
builder.WriteString("X-Forwarded-For: ")
|
||||||
builder.WriteString(fh.ClientIP)
|
builder.WriteString(fh.ClientIP)
|
||||||
@ -121,13 +125,13 @@ func WriteForwardedHeaders(builder *strings.Builder, fh ForwardedHeaders) {
|
|||||||
builder.WriteString("\r\n")
|
builder.WriteString("\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
if fh.Host != "" {
|
if setHost && fh.Host != "" {
|
||||||
builder.WriteString("X-Forwarded-Host: ")
|
builder.WriteString("X-Forwarded-Host: ")
|
||||||
builder.WriteString(fh.Host)
|
builder.WriteString(fh.Host)
|
||||||
builder.WriteString("\r\n")
|
builder.WriteString("\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
if fh.Proto != "" {
|
if setProto && fh.Proto != "" {
|
||||||
builder.WriteString("X-Forwarded-Proto: ")
|
builder.WriteString("X-Forwarded-Proto: ")
|
||||||
builder.WriteString(fh.Proto)
|
builder.WriteString(fh.Proto)
|
||||||
builder.WriteString("\r\n")
|
builder.WriteString("\r\n")
|
||||||
|
|||||||
264
internal/proxy/headers_test.go
Normal file
264
internal/proxy/headers_test.go
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSetForwardedHeaders_SetHost 测试 SetForwardedHost 配置对 X-Forwarded-Host 头的控制
|
||||||
|
func TestSetForwardedHeaders_SetHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setHost bool
|
||||||
|
expectHost bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "setHost=true sets X-Forwarded-Host",
|
||||||
|
setHost: true,
|
||||||
|
expectHost: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "setHost=false does not set X-Forwarded-Host",
|
||||||
|
setHost: false,
|
||||||
|
expectHost: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
headers := &fasthttp.RequestHeader{}
|
||||||
|
fh := ForwardedHeaders{
|
||||||
|
ClientIP: "192.168.1.1",
|
||||||
|
Host: "example.com:8080",
|
||||||
|
Proto: "https",
|
||||||
|
}
|
||||||
|
|
||||||
|
// setProto=true 因为我们要测试 setHost 的效果
|
||||||
|
SetForwardedHeaders(headers, fh, true, tt.setHost, true)
|
||||||
|
|
||||||
|
hasHost := len(headers.Peek("X-Forwarded-Host")) > 0
|
||||||
|
if hasHost != tt.expectHost {
|
||||||
|
t.Errorf("X-Forwarded-Host presence = %v, want %v", hasHost, tt.expectHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
// X-Forwarded-For 和 X-Real-IP 应该始终设置
|
||||||
|
if len(headers.Peek("X-Forwarded-For")) == 0 {
|
||||||
|
t.Error("X-Forwarded-For should always be set")
|
||||||
|
}
|
||||||
|
if len(headers.Peek("X-Real-IP")) == 0 {
|
||||||
|
t.Error("X-Real-IP should always be set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSetForwardedHeaders_SetProto 测试 SetForwardedProto 配置对 X-Forwarded-Proto 头的控制
|
||||||
|
func TestSetForwardedHeaders_SetProto(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setProto bool
|
||||||
|
expectProto bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "setProto=true sets X-Forwarded-Proto",
|
||||||
|
setProto: true,
|
||||||
|
expectProto: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "setProto=false does not set X-Forwarded-Proto",
|
||||||
|
setProto: false,
|
||||||
|
expectProto: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
headers := &fasthttp.RequestHeader{}
|
||||||
|
fh := ForwardedHeaders{
|
||||||
|
ClientIP: "192.168.1.1",
|
||||||
|
Host: "example.com:8080",
|
||||||
|
Proto: "https",
|
||||||
|
}
|
||||||
|
|
||||||
|
// setHost=true 因为我们要测试 setProto 的效果
|
||||||
|
SetForwardedHeaders(headers, fh, true, true, tt.setProto)
|
||||||
|
|
||||||
|
hasProto := len(headers.Peek("X-Forwarded-Proto")) > 0
|
||||||
|
if hasProto != tt.expectProto {
|
||||||
|
t.Errorf("X-Forwarded-Proto presence = %v, want %v", hasProto, tt.expectProto)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSetForwardedHeaders_DefaultBehavior 测试默认行为(所有参数为 true)
|
||||||
|
func TestSetForwardedHeaders_DefaultBehavior(t *testing.T) {
|
||||||
|
headers := &fasthttp.RequestHeader{}
|
||||||
|
fh := ForwardedHeaders{
|
||||||
|
ClientIP: "10.0.0.1",
|
||||||
|
Host: "localhost:8082",
|
||||||
|
Proto: "http",
|
||||||
|
}
|
||||||
|
|
||||||
|
SetForwardedHeaders(headers, fh, true, true, true)
|
||||||
|
|
||||||
|
// 验证所有头都被设置
|
||||||
|
if string(headers.Peek("X-Forwarded-For")) != "10.0.0.1" {
|
||||||
|
t.Errorf("X-Forwarded-For = %s, want 10.0.0.1", headers.Peek("X-Forwarded-For"))
|
||||||
|
}
|
||||||
|
if string(headers.Peek("X-Real-IP")) != "10.0.0.1" {
|
||||||
|
t.Errorf("X-Real-IP = %s, want 10.0.0.1", headers.Peek("X-Real-IP"))
|
||||||
|
}
|
||||||
|
if string(headers.Peek("X-Forwarded-Host")) != "localhost:8082" {
|
||||||
|
t.Errorf("X-Forwarded-Host = %s, want localhost:8082", headers.Peek("X-Forwarded-Host"))
|
||||||
|
}
|
||||||
|
if string(headers.Peek("X-Forwarded-Proto")) != "http" {
|
||||||
|
t.Errorf("X-Forwarded-Proto = %s, want http", headers.Peek("X-Forwarded-Proto"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSetForwardedHeaders_AllDisabled 测试所有控制参数为 false
|
||||||
|
func TestSetForwardedHeaders_AllDisabled(t *testing.T) {
|
||||||
|
headers := &fasthttp.RequestHeader{}
|
||||||
|
fh := ForwardedHeaders{
|
||||||
|
ClientIP: "10.0.0.1",
|
||||||
|
Host: "localhost:8082",
|
||||||
|
Proto: "http",
|
||||||
|
}
|
||||||
|
|
||||||
|
SetForwardedHeaders(headers, fh, true, false, false)
|
||||||
|
|
||||||
|
// X-Forwarded-For 和 X-Real-IP 应该始终设置
|
||||||
|
if len(headers.Peek("X-Forwarded-For")) == 0 {
|
||||||
|
t.Error("X-Forwarded-For should be set even when setHost/setProto are false")
|
||||||
|
}
|
||||||
|
if len(headers.Peek("X-Real-IP")) == 0 {
|
||||||
|
t.Error("X-Real-IP should be set even when setHost/setProto are false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// X-Forwarded-Host 和 X-Forwarded-Proto 不应该设置
|
||||||
|
if len(headers.Peek("X-Forwarded-Host")) > 0 {
|
||||||
|
t.Error("X-Forwarded-Host should not be set when setHost=false")
|
||||||
|
}
|
||||||
|
if len(headers.Peek("X-Forwarded-Proto")) > 0 {
|
||||||
|
t.Error("X-Forwarded-Proto should not be set when setProto=false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteForwardedHeaders_SetHost 测试 WriteForwardedHeaders 的 setHost 参数
|
||||||
|
func TestWriteForwardedHeaders_SetHost(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setHost bool
|
||||||
|
expectHost bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "setHost=true writes X-Forwarded-Host",
|
||||||
|
setHost: true,
|
||||||
|
expectHost: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "setHost=false does not write X-Forwarded-Host",
|
||||||
|
setHost: false,
|
||||||
|
expectHost: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
fh := ForwardedHeaders{
|
||||||
|
ClientIP: "192.168.1.1",
|
||||||
|
Host: "example.com:8080",
|
||||||
|
Proto: "https",
|
||||||
|
}
|
||||||
|
|
||||||
|
WriteForwardedHeaders(&builder, fh, tt.setHost, true)
|
||||||
|
|
||||||
|
result := builder.String()
|
||||||
|
hasHost := strings.Contains(result, "X-Forwarded-Host:")
|
||||||
|
if hasHost != tt.expectHost {
|
||||||
|
t.Errorf("X-Forwarded-Host presence = %v, want %v", hasHost, tt.expectHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
// X-Forwarded-For 和 X-Real-IP 应该始终存在
|
||||||
|
if !strings.Contains(result, "X-Forwarded-For:") {
|
||||||
|
t.Error("X-Forwarded-For should always be written")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "X-Real-IP:") {
|
||||||
|
t.Error("X-Real-IP should always be written")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteForwardedHeaders_SetProto 测试 WriteForwardedHeaders 的 setProto 参数
|
||||||
|
func TestWriteForwardedHeaders_SetProto(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setProto bool
|
||||||
|
expectProto bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "setProto=true writes X-Forwarded-Proto",
|
||||||
|
setProto: true,
|
||||||
|
expectProto: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "setProto=false does not write X-Forwarded-Proto",
|
||||||
|
setProto: false,
|
||||||
|
expectProto: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
fh := ForwardedHeaders{
|
||||||
|
ClientIP: "192.168.1.1",
|
||||||
|
Host: "example.com:8080",
|
||||||
|
Proto: "https",
|
||||||
|
}
|
||||||
|
|
||||||
|
WriteForwardedHeaders(&builder, fh, true, tt.setProto)
|
||||||
|
|
||||||
|
result := builder.String()
|
||||||
|
hasProto := strings.Contains(result, "X-Forwarded-Proto:")
|
||||||
|
if hasProto != tt.expectProto {
|
||||||
|
t.Errorf("X-Forwarded-Proto presence = %v, want %v", hasProto, tt.expectProto)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteForwardedHeaders_AllDisabled 测试 WriteForwardedHeaders 所有控制参数为 false
|
||||||
|
func TestWriteForwardedHeaders_AllDisabled(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
fh := ForwardedHeaders{
|
||||||
|
ClientIP: "10.0.0.1",
|
||||||
|
Host: "localhost:8082",
|
||||||
|
Proto: "http",
|
||||||
|
}
|
||||||
|
|
||||||
|
WriteForwardedHeaders(&builder, fh, false, false)
|
||||||
|
|
||||||
|
result := builder.String()
|
||||||
|
|
||||||
|
// X-Forwarded-For 和 X-Real-IP 应该始终存在
|
||||||
|
if !strings.Contains(result, "X-Forwarded-For:") {
|
||||||
|
t.Error("X-Forwarded-For should always be written")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "X-Real-IP:") {
|
||||||
|
t.Error("X-Real-IP should always be written")
|
||||||
|
}
|
||||||
|
|
||||||
|
// X-Forwarded-Host 和 X-Forwarded-Proto 不应该存在
|
||||||
|
if strings.Contains(result, "X-Forwarded-Host:") {
|
||||||
|
t.Error("X-Forwarded-Host should not be written when setHost=false")
|
||||||
|
}
|
||||||
|
if strings.Contains(result, "X-Forwarded-Proto:") {
|
||||||
|
t.Error("X-Forwarded-Proto should not be written when setProto=false")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -572,7 +572,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
// WebSocket 使用 defer 确保连接计数释放
|
// WebSocket 使用 defer 确保连接计数释放
|
||||||
defer loadbalance.DecrementConnections(target)
|
defer loadbalance.DecrementConnections(target)
|
||||||
timing.MarkConnectStart()
|
timing.MarkConnectStart()
|
||||||
err := WebSocket(ctx, target, p.config.Timeout.Connect)
|
err := WebSocket(ctx, target, p.config.Timeout.Connect, &p.config.Headers)
|
||||||
timing.MarkConnectEnd()
|
timing.MarkConnectEnd()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
upstreamStatus = 502
|
upstreamStatus = 502
|
||||||
|
|||||||
@ -1055,7 +1055,7 @@ func TestWebSocket_ErrorCases(t *testing.T) {
|
|||||||
target.Healthy.Store(true)
|
target.Healthy.Store(true)
|
||||||
|
|
||||||
// 使用很短的超时
|
// 使用很短的超时
|
||||||
err := WebSocket(ctx, target, 10*time.Millisecond)
|
err := WebSocket(ctx, target, 10*time.Millisecond, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("WebSocket() should return error for invalid backend")
|
t.Error("WebSocket() should return error for invalid backend")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,6 +32,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
"rua.plus/lolly/internal/loadbalance"
|
"rua.plus/lolly/internal/loadbalance"
|
||||||
"rua.plus/lolly/internal/netutil"
|
"rua.plus/lolly/internal/netutil"
|
||||||
)
|
)
|
||||||
@ -260,10 +261,11 @@ func dialTarget(targetURL string, timeout time.Duration) (net.Conn, error) {
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - ctx: FastHTTP 请求上下文
|
// - ctx: FastHTTP 请求上下文
|
||||||
// - targetHost: 目标主机地址
|
// - targetHost: 目标主机地址
|
||||||
|
// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - string: 完整的 HTTP 请求字符串
|
// - string: 完整的 HTTP 请求字符串
|
||||||
func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) string {
|
func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string, headersConfig *config.ProxyHeaders) string {
|
||||||
// 构建请求行
|
// 构建请求行
|
||||||
path := string(ctx.Path())
|
path := string(ctx.Path())
|
||||||
if path == "" {
|
if path == "" {
|
||||||
@ -300,7 +302,18 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
|
|||||||
|
|
||||||
// 添加 X-Forwarded 头
|
// 添加 X-Forwarded 头
|
||||||
fh := ExtractForwardedHeaders(ctx)
|
fh := ExtractForwardedHeaders(ctx)
|
||||||
WriteForwardedHeaders(&req, fh)
|
|
||||||
|
// 根据配置决定是否设置 X-Forwarded-Host 和 X-Forwarded-Proto
|
||||||
|
setHost := true // 默认值(向后兼容)
|
||||||
|
if headersConfig != nil && headersConfig.SetForwardedHost != nil {
|
||||||
|
setHost = *headersConfig.SetForwardedHost
|
||||||
|
}
|
||||||
|
setProto := true // 默认值(向后兼容)
|
||||||
|
if headersConfig != nil && headersConfig.SetForwardedProto != nil {
|
||||||
|
setProto = *headersConfig.SetForwardedProto
|
||||||
|
}
|
||||||
|
|
||||||
|
WriteForwardedHeaders(&req, fh, setHost, setProto)
|
||||||
|
|
||||||
// 结束请求头
|
// 结束请求头
|
||||||
req.WriteString("\r\n")
|
req.WriteString("\r\n")
|
||||||
@ -348,10 +361,11 @@ func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.R
|
|||||||
// - ctx: FastHTTP 请求上下文
|
// - ctx: FastHTTP 请求上下文
|
||||||
// - target: 负载均衡目标,包含后端 URL
|
// - target: 负载均衡目标,包含后端 URL
|
||||||
// - timeout: 连接和 I/O 超时时间
|
// - timeout: 连接和 I/O 超时时间
|
||||||
|
// - headersConfig: 代理头配置,控制 X-Forwarded-Host/Proto 的设置
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - error: 代理过程中的错误
|
// - error: 代理过程中的错误
|
||||||
func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration) error {
|
func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration, headersConfig *config.ProxyHeaders) error {
|
||||||
// 使用 Hijack 获取客户端 TCP 连接
|
// 使用 Hijack 获取客户端 TCP 连接
|
||||||
var clientConn net.Conn
|
var clientConn net.Conn
|
||||||
|
|
||||||
@ -380,7 +394,7 @@ func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout tim
|
|||||||
targetHost := extractHost(target.URL)
|
targetHost := extractHost(target.URL)
|
||||||
|
|
||||||
// 步骤3: 构建并发送 WebSocket 升级请求
|
// 步骤3: 构建并发送 WebSocket 升级请求
|
||||||
upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost)
|
upgradeReq := buildWebSocketUpgradeRequest(ctx, targetHost, headersConfig)
|
||||||
if _, writeErr := targetConn.Write([]byte(upgradeReq)); writeErr != nil {
|
if _, writeErr := targetConn.Write([]byte(upgradeReq)); writeErr != nil {
|
||||||
return fmt.Errorf("failed to send upgrade request: %w", writeErr)
|
return fmt.Errorf("failed to send upgrade request: %w", writeErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -46,7 +46,7 @@ func BenchmarkWebSocketHandshake(b *testing.B) {
|
|||||||
ctx.Request.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate")
|
ctx.Request.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate")
|
||||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||||
|
|
||||||
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com:8080")
|
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com:8080", nil)
|
||||||
|
|
||||||
// 验证握手请求包含关键头
|
// 验证握手请求包含关键头
|
||||||
if !strings.Contains(result, "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==") {
|
if !strings.Contains(result, "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==") {
|
||||||
|
|||||||
@ -373,7 +373,7 @@ func TestBuildWebSocketUpgradeRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
ctx.Request.Header.SetHost(tt.host)
|
ctx.Request.Header.SetHost(tt.host)
|
||||||
|
|
||||||
result := buildWebSocketUpgradeRequest(ctx, tt.targetHost)
|
result := buildWebSocketUpgradeRequest(ctx, tt.targetHost, nil)
|
||||||
|
|
||||||
for _, want := range tt.wantContains {
|
for _, want := range tt.wantContains {
|
||||||
if !strings.Contains(result, want) {
|
if !strings.Contains(result, want) {
|
||||||
@ -394,7 +394,7 @@ func TestBuildWebSocketUpgradeRequest_WithHeaders(t *testing.T) {
|
|||||||
ctx.Request.Header.Set("Sec-WebSocket-Version", "13")
|
ctx.Request.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
ctx.Request.Header.Set("Sec-WebSocket-Protocol", "chat")
|
ctx.Request.Header.Set("Sec-WebSocket-Protocol", "chat")
|
||||||
|
|
||||||
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com")
|
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com", nil)
|
||||||
|
|
||||||
// 验证关键头被复制
|
// 验证关键头被复制
|
||||||
expectedHeaders := []string{
|
expectedHeaders := []string{
|
||||||
@ -434,7 +434,7 @@ func TestBuildWebSocketUpgradeRequest_TLSProto(t *testing.T) {
|
|||||||
// 注意:fasthttp.RequestCtx 默认 IsTLS() 返回 false
|
// 注意:fasthttp.RequestCtx 默认 IsTLS() 返回 false
|
||||||
// 无法在单元测试中直接模拟 TLS 连接
|
// 无法在单元测试中直接模拟 TLS 连接
|
||||||
|
|
||||||
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com")
|
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com", nil)
|
||||||
|
|
||||||
if !strings.Contains(result, tt.wantProto) {
|
if !strings.Contains(result, tt.wantProto) {
|
||||||
t.Errorf("Missing %q in:\n%s", tt.wantProto, result)
|
t.Errorf("Missing %q in:\n%s", tt.wantProto, result)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user