refactor: remove extractHostFromURL, use netutil.ParseTargetURL
This commit is contained in:
parent
041bc97578
commit
ae3c167cd6
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
"rua.plus/lolly/internal/loadbalance"
|
"rua.plus/lolly/internal/loadbalance"
|
||||||
"rua.plus/lolly/internal/logging"
|
"rua.plus/lolly/internal/logging"
|
||||||
|
"rua.plus/lolly/internal/netutil"
|
||||||
"rua.plus/lolly/internal/variable"
|
"rua.plus/lolly/internal/variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,7 +31,7 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
|
|||||||
|
|
||||||
// 设置 Host header 为目标主机
|
// 设置 Host header 为目标主机
|
||||||
// 从 target.URL 提取 host:port(HostClient 连接需要此格式)
|
// 从 target.URL 提取 host:port(HostClient 连接需要此格式)
|
||||||
targetHost := extractHostFromURL(target.URL)
|
targetHost, _ := netutil.ParseTargetURL(target.URL, false)
|
||||||
if targetHost != "" {
|
if targetHost != "" {
|
||||||
headers.Set("Host", targetHost)
|
headers.Set("Host", targetHost)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -359,7 +359,8 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCf
|
|||||||
|
|
||||||
// 上游 SSL 配置(使用原生 TLSConfig)
|
// 上游 SSL 配置(使用原生 TLSConfig)
|
||||||
if sslCfg != nil && sslCfg.Enabled && isTLS {
|
if sslCfg != nil && sslCfg.Enabled && isTLS {
|
||||||
tlsCfg, err := CreateTLSConfig(sslCfg, extractHostFromURL(targetURL))
|
host, _ := netutil.ParseTargetURL(targetURL, false)
|
||||||
|
tlsCfg, err := CreateTLSConfig(sslCfg, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error().Err(err).Str("target", targetURL).Msg("Failed to create upstream TLS config")
|
logging.Error().Err(err).Str("target", targetURL).Msg("Failed to create upstream TLS config")
|
||||||
} else {
|
} else {
|
||||||
@ -990,29 +991,4 @@ func isWebSocketRequest(ctx *fasthttp.RequestCtx) bool {
|
|||||||
return strings.EqualFold(string(upgrade), "websocket")
|
return strings.EqualFold(string(upgrade), "websocket")
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractHostFromURL 从 URL 字符串中提取 host:port 部分。
|
|
||||||
//
|
|
||||||
// 移除 http:// 或 https:// 协议前缀,以及路径部分,
|
|
||||||
// 仅保留主机名和端口(如 "example.com:8080")。
|
|
||||||
//
|
|
||||||
// 参数:
|
|
||||||
// - urlStr: 完整 URL 字符串
|
|
||||||
//
|
|
||||||
// 返回值:
|
|
||||||
// - string: host:port 格式的主机地址
|
|
||||||
func extractHostFromURL(urlStr string) string {
|
|
||||||
// 移除协议前缀
|
|
||||||
host := urlStr
|
|
||||||
if strings.HasPrefix(host, "http://") {
|
|
||||||
host = host[7:]
|
|
||||||
} else if strings.HasPrefix(host, "https://") {
|
|
||||||
host = host[8:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// 移除路径部分
|
|
||||||
if idx := strings.Index(host, "/"); idx != -1 {
|
|
||||||
host = host[:idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1423,55 +1423,6 @@ func TestSelectTarget_LuaEnabled(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestExtractHostFromURL 测试 extractHostFromURL 函数。
|
|
||||||
func TestExtractHostFromURL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
url string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "HTTP URL with port",
|
|
||||||
url: "http://example.com:8080",
|
|
||||||
want: "example.com:8080",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "HTTPS URL with port",
|
|
||||||
url: "https://example.com:8443",
|
|
||||||
want: "example.com:8443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "HTTP URL without port",
|
|
||||||
url: "http://example.com",
|
|
||||||
want: "example.com",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "HTTPS URL without port",
|
|
||||||
url: "https://example.com",
|
|
||||||
want: "example.com",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "URL with path",
|
|
||||||
url: "http://example.com:8080/api/users",
|
|
||||||
want: "example.com:8080",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No protocol prefix",
|
|
||||||
url: "example.com:8080",
|
|
||||||
want: "example.com:8080",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got := extractHostFromURL(tt.url)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("extractHostFromURL() = %q, want %q", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBackgroundRefresh_304 测试后台刷新收到 304 响应。
|
// TestBackgroundRefresh_304 测试后台刷新收到 304 响应。
|
||||||
func TestBackgroundRefresh_304(t *testing.T) {
|
func TestBackgroundRefresh_304(t *testing.T) {
|
||||||
ln := fasthttputil.NewInmemoryListener()
|
ln := fasthttputil.NewInmemoryListener()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user