refactor(netutil): 提取通用主机名处理函数
- 新增 StripPort() 函数用于移除主机名中的端口 - 新增 HasPort() 函数用于检测主机名是否包含端口 - 替代 vhost 和 ssl 模块中的内联端口处理逻辑 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f0f0f7b821
commit
616762e840
69
internal/netutil/host.go
Normal file
69
internal/netutil/host.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Package netutil 提供网络相关的工具函数。
|
||||
//
|
||||
// 该文件包含主机名处理相关的工具函数。
|
||||
//
|
||||
// 作者:xfy
|
||||
package netutil
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StripPort 从 Host 头中移除端口号。
|
||||
//
|
||||
// 支持 IPv4 和 IPv6 格式:
|
||||
// - example.com:8080 -> example.com
|
||||
// - [::1]:8080 -> [::1]
|
||||
// - [2001:db8::1]:443 -> [2001:db8::1]
|
||||
// - example.com -> example.com
|
||||
//
|
||||
// 参数:
|
||||
// - host: 主机名(可能包含端口)
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 移除端口后的主机名
|
||||
func StripPort(host string) string {
|
||||
if len(host) == 0 {
|
||||
return host
|
||||
}
|
||||
|
||||
// IPv6 格式:以 '[' 开头,找 ']:' 作为分隔点
|
||||
if host[0] == '[' {
|
||||
for i := 0; i < len(host)-1; i++ {
|
||||
if host[i] == ']' && host[i+1] == ':' {
|
||||
return host[:i+1]
|
||||
}
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// IPv4 或域名格式:找第一个 ':' 作为分隔点
|
||||
for i := 0; i < len(host); i++ {
|
||||
if host[i] == ':' {
|
||||
return host[:i]
|
||||
}
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
// HasPort 检查主机名是否包含端口号。
|
||||
//
|
||||
// 参数:
|
||||
// - host: 主机名
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: true 表示包含端口
|
||||
func HasPort(host string) bool {
|
||||
if len(host) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// IPv6 格式
|
||||
if host[0] == '[' {
|
||||
return strings.Contains(host, "]:")
|
||||
}
|
||||
|
||||
// IPv4 或域名格式
|
||||
return strings.Contains(host, ":")
|
||||
}
|
||||
61
internal/netutil/host_test.go
Normal file
61
internal/netutil/host_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
package netutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestStripPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected string
|
||||
}{
|
||||
// IPv4 格式
|
||||
{"IPv4 with port", "example.com:8080", "example.com"},
|
||||
{"IPv4 with port 443", "example.com:443", "example.com"},
|
||||
{"IPv4 no port", "example.com", "example.com"},
|
||||
{"IPv4 with port and path", "example.com:8080", "example.com"},
|
||||
|
||||
// IPv6 格式
|
||||
{"IPv6 localhost with port", "[::1]:443", "[::1]"},
|
||||
{"IPv6 full with port", "[2001:db8::1]:8443", "[2001:db8::1]"},
|
||||
{"IPv6 no port", "[::1]", "[::1]"},
|
||||
{"IPv6 full no port", "[2001:db8::1]", "[2001:db8::1]"},
|
||||
|
||||
// 边界情况
|
||||
{"empty string", "", ""},
|
||||
{"just port", ":8080", ""},
|
||||
{"IPv6 with empty brackets", "[]", "[]"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := StripPort(tt.host)
|
||||
if result != tt.expected {
|
||||
t.Errorf("StripPort(%q) = %q, want %q", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
expected bool
|
||||
}{
|
||||
{"IPv4 with port", "example.com:8080", true},
|
||||
{"IPv4 no port", "example.com", false},
|
||||
{"IPv6 with port", "[::1]:443", true},
|
||||
{"IPv6 no port", "[::1]", false},
|
||||
{"empty string", "", false},
|
||||
{"just port", ":8080", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasPort(tt.host)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasPort(%q) = %v, want %v", tt.host, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -18,6 +18,7 @@ package server
|
||||
|
||||
import (
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// VHostManager 虚拟主机管理器。
|
||||
@ -82,7 +83,7 @@ func (v *VHostManager) SetDefault(handler fasthttp.RequestHandler) {
|
||||
// - fasthttp.RequestHandler: 根据 Host 头分发请求的处理器
|
||||
func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
host := stripPort(string(ctx.Host()))
|
||||
host := netutil.StripPort(string(ctx.Host()))
|
||||
|
||||
if vhost, ok := v.hosts[host]; ok {
|
||||
vhost.handler(ctx)
|
||||
@ -93,38 +94,3 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stripPort 从 Host 头中移除端口号。
|
||||
//
|
||||
// 支持 IPv4 和 IPv6 格式:
|
||||
// - example.com:8080 -> example.com
|
||||
// - [::1]:8080 -> [::1]
|
||||
// - [2001:db8::1]:443 -> [2001:db8::1]
|
||||
// - example.com -> example.com
|
||||
func stripPort(host string) string {
|
||||
// 空字符串直接返回
|
||||
if len(host) == 0 {
|
||||
return host
|
||||
}
|
||||
|
||||
// IPv6 格式:以 '[' 开头,找 ']:' 作为分隔点
|
||||
if host[0] == '[' {
|
||||
// 查找 ']:' 分隔符
|
||||
for i := 0; i < len(host)-1; i++ {
|
||||
if host[i] == ']' && host[i+1] == ':' {
|
||||
return host[:i+1] // 返回包含 ']' 的部分,如 "[::1]"
|
||||
}
|
||||
}
|
||||
// 没有 ']:' 分隔符,可能是纯 IPv6 地址(如 "[::1]")
|
||||
return host
|
||||
}
|
||||
|
||||
// IPv4 或域名格式:找第一个 ':' 作为分隔点
|
||||
for i := 0; i < len(host); i++ {
|
||||
if host[i] == ':' {
|
||||
return host[:i]
|
||||
}
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
@ -46,6 +46,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// TLSManager TLS 配置管理器。
|
||||
@ -272,12 +273,7 @@ func (m *TLSManager) GetTLSConfigForHost(host string) *tls.Config {
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// 从主机名中移除端口(如果存在)
|
||||
for i := 0; i < len(host); i++ {
|
||||
if host[i] == ':' {
|
||||
host = host[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
host = netutil.StripPort(host)
|
||||
|
||||
if cfg, ok := m.configs[host]; ok {
|
||||
return cfg
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user