refactor: remove extractHostFromURL, use netutil.ParseTargetURL

This commit is contained in:
xfy 2026-06-03 17:50:06 +08:00
parent 041bc97578
commit ae3c167cd6
3 changed files with 4 additions and 76 deletions

View File

@ -11,6 +11,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/variable"
)
@ -30,7 +31,7 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
// 设置 Host header 为目标主机
// 从 target.URL 提取 host:portHostClient 连接需要此格式)
targetHost := extractHostFromURL(target.URL)
targetHost, _ := netutil.ParseTargetURL(target.URL, false)
if targetHost != "" {
headers.Set("Host", targetHost)
}

View File

@ -359,7 +359,8 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCf
// 上游 SSL 配置(使用原生 TLSConfig
if sslCfg != nil && sslCfg.Enabled && isTLS {
tlsCfg, err := CreateTLSConfig(sslCfg, extractHostFromURL(targetURL))
host, _ := netutil.ParseTargetURL(targetURL, false)
tlsCfg, err := CreateTLSConfig(sslCfg, host)
if err != nil {
logging.Error().Err(err).Str("target", targetURL).Msg("Failed to create upstream TLS config")
} else {
@ -990,29 +991,4 @@ func isWebSocketRequest(ctx *fasthttp.RequestCtx) bool {
return strings.EqualFold(string(upgrade), "websocket")
}
// extractHostFromURL 从 URL 字符串中提取 host:port 部分。
//
// 移除 http:// 或 https:// 协议前缀,以及路径部分,
// 仅保留主机名和端口(如 "example.com:8080")。
//
// 参数:
// - urlStr: 完整 URL 字符串
//
// 返回值:
// - string: host:port 格式的主机地址
func extractHostFromURL(urlStr string) string {
// 移除协议前缀
host := urlStr
if strings.HasPrefix(host, "http://") {
host = host[7:]
} else if strings.HasPrefix(host, "https://") {
host = host[8:]
}
// 移除路径部分
if idx := strings.Index(host, "/"); idx != -1 {
host = host[:idx]
}
return host
}

View File

@ -1423,55 +1423,6 @@ func TestSelectTarget_LuaEnabled(t *testing.T) {
})
}
// TestExtractHostFromURL 测试 extractHostFromURL 函数。
func TestExtractHostFromURL(t *testing.T) {
tests := []struct {
name string
url string
want string
}{
{
name: "HTTP URL with port",
url: "http://example.com:8080",
want: "example.com:8080",
},
{
name: "HTTPS URL with port",
url: "https://example.com:8443",
want: "example.com:8443",
},
{
name: "HTTP URL without port",
url: "http://example.com",
want: "example.com",
},
{
name: "HTTPS URL without port",
url: "https://example.com",
want: "example.com",
},
{
name: "URL with path",
url: "http://example.com:8080/api/users",
want: "example.com:8080",
},
{
name: "No protocol prefix",
url: "example.com:8080",
want: "example.com:8080",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractHostFromURL(tt.url)
if got != tt.want {
t.Errorf("extractHostFromURL() = %q, want %q", got, tt.want)
}
})
}
}
// TestBackgroundRefresh_304 测试后台刷新收到 304 响应。
func TestBackgroundRefresh_304(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()