From 616762e840d5f6fb8e5d25f445f87db83a19662a Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 10 Apr 2026 11:20:10 +0800 Subject: [PATCH] =?UTF-8?q?refactor(netutil):=20=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E9=80=9A=E7=94=A8=E4=B8=BB=E6=9C=BA=E5=90=8D=E5=A4=84=E7=90=86?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 StripPort() 函数用于移除主机名中的端口 - 新增 HasPort() 函数用于检测主机名是否包含端口 - 替代 vhost 和 ssl 模块中的内联端口处理逻辑 Co-Authored-By: Claude Opus 4.6 --- internal/netutil/host.go | 69 +++++++++++++++++++++++++++++++++++ internal/netutil/host_test.go | 61 +++++++++++++++++++++++++++++++ internal/server/vhost.go | 38 +------------------ internal/ssl/ssl.go | 8 +--- 4 files changed, 134 insertions(+), 42 deletions(-) create mode 100644 internal/netutil/host.go create mode 100644 internal/netutil/host_test.go diff --git a/internal/netutil/host.go b/internal/netutil/host.go new file mode 100644 index 0000000..5673851 --- /dev/null +++ b/internal/netutil/host.go @@ -0,0 +1,69 @@ +// Package netutil 提供网络相关的工具函数。 +// +// 该文件包含主机名处理相关的工具函数。 +// +// 作者:xfy +package netutil + +import ( + "strings" +) + +// StripPort 从 Host 头中移除端口号。 +// +// 支持 IPv4 和 IPv6 格式: +// - example.com:8080 -> example.com +// - [::1]:8080 -> [::1] +// - [2001:db8::1]:443 -> [2001:db8::1] +// - example.com -> example.com +// +// 参数: +// - host: 主机名(可能包含端口) +// +// 返回值: +// - string: 移除端口后的主机名 +func StripPort(host string) string { + if len(host) == 0 { + return host + } + + // IPv6 格式:以 '[' 开头,找 ']:' 作为分隔点 + if host[0] == '[' { + for i := 0; i < len(host)-1; i++ { + if host[i] == ']' && host[i+1] == ':' { + return host[:i+1] + } + } + return host + } + + // IPv4 或域名格式:找第一个 ':' 作为分隔点 + for i := 0; i < len(host); i++ { + if host[i] == ':' { + return host[:i] + } + } + + return host +} + +// HasPort 检查主机名是否包含端口号。 +// +// 参数: +// - host: 主机名 +// +// 返回值: +// - bool: true 表示包含端口 +func HasPort(host string) bool { + if len(host) == 0 { + return false + } + + // IPv6 格式 + if host[0] == '[' { + return strings.Contains(host, "]:") + } + + // IPv4 或域名格式 + return strings.Contains(host, ":") +} diff --git a/internal/netutil/host_test.go b/internal/netutil/host_test.go new file mode 100644 index 0000000..bff38f4 --- /dev/null +++ b/internal/netutil/host_test.go @@ -0,0 +1,61 @@ +package netutil + +import "testing" + +func TestStripPort(t *testing.T) { + tests := []struct { + name string + host string + expected string + }{ + // IPv4 格式 + {"IPv4 with port", "example.com:8080", "example.com"}, + {"IPv4 with port 443", "example.com:443", "example.com"}, + {"IPv4 no port", "example.com", "example.com"}, + {"IPv4 with port and path", "example.com:8080", "example.com"}, + + // IPv6 格式 + {"IPv6 localhost with port", "[::1]:443", "[::1]"}, + {"IPv6 full with port", "[2001:db8::1]:8443", "[2001:db8::1]"}, + {"IPv6 no port", "[::1]", "[::1]"}, + {"IPv6 full no port", "[2001:db8::1]", "[2001:db8::1]"}, + + // 边界情况 + {"empty string", "", ""}, + {"just port", ":8080", ""}, + {"IPv6 with empty brackets", "[]", "[]"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := StripPort(tt.host) + if result != tt.expected { + t.Errorf("StripPort(%q) = %q, want %q", tt.host, result, tt.expected) + } + }) + } +} + +func TestHasPort(t *testing.T) { + tests := []struct { + name string + host string + expected bool + }{ + {"IPv4 with port", "example.com:8080", true}, + {"IPv4 no port", "example.com", false}, + {"IPv6 with port", "[::1]:443", true}, + {"IPv6 no port", "[::1]", false}, + {"empty string", "", false}, + {"just port", ":8080", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := HasPort(tt.host) + if result != tt.expected { + t.Errorf("HasPort(%q) = %v, want %v", tt.host, result, tt.expected) + } + }) + } +} diff --git a/internal/server/vhost.go b/internal/server/vhost.go index a320129..da2b778 100644 --- a/internal/server/vhost.go +++ b/internal/server/vhost.go @@ -18,6 +18,7 @@ package server import ( "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/netutil" ) // VHostManager 虚拟主机管理器。 @@ -82,7 +83,7 @@ func (v *VHostManager) SetDefault(handler fasthttp.RequestHandler) { // - fasthttp.RequestHandler: 根据 Host 头分发请求的处理器 func (v *VHostManager) Handler() fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - host := stripPort(string(ctx.Host())) + host := netutil.StripPort(string(ctx.Host())) if vhost, ok := v.hosts[host]; ok { vhost.handler(ctx) @@ -93,38 +94,3 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler { } } } - -// stripPort 从 Host 头中移除端口号。 -// -// 支持 IPv4 和 IPv6 格式: -// - example.com:8080 -> example.com -// - [::1]:8080 -> [::1] -// - [2001:db8::1]:443 -> [2001:db8::1] -// - example.com -> example.com -func stripPort(host string) string { - // 空字符串直接返回 - if len(host) == 0 { - return host - } - - // IPv6 格式:以 '[' 开头,找 ']:' 作为分隔点 - if host[0] == '[' { - // 查找 ']:' 分隔符 - for i := 0; i < len(host)-1; i++ { - if host[i] == ']' && host[i+1] == ':' { - return host[:i+1] // 返回包含 ']' 的部分,如 "[::1]" - } - } - // 没有 ']:' 分隔符,可能是纯 IPv6 地址(如 "[::1]") - return host - } - - // IPv4 或域名格式:找第一个 ':' 作为分隔点 - for i := 0; i < len(host); i++ { - if host[i] == ':' { - return host[:i] - } - } - - return host -} diff --git a/internal/ssl/ssl.go b/internal/ssl/ssl.go index 45b30e3..e455074 100644 --- a/internal/ssl/ssl.go +++ b/internal/ssl/ssl.go @@ -46,6 +46,7 @@ import ( "sync" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/netutil" ) // TLSManager TLS 配置管理器。 @@ -272,12 +273,7 @@ func (m *TLSManager) GetTLSConfigForHost(host string) *tls.Config { defer m.mu.RUnlock() // 从主机名中移除端口(如果存在) - for i := 0; i < len(host); i++ { - if host[i] == ':' { - host = host[:i] - break - } - } + host = netutil.StripPort(host) if cfg, ok := m.configs[host]; ok { return cfg