diff --git a/internal/server/vhost.go b/internal/server/vhost.go index e884abd..093c351 100644 --- a/internal/server/vhost.go +++ b/internal/server/vhost.go @@ -17,6 +17,10 @@ package server import ( + "fmt" + "regexp" + "strings" + "github.com/valyala/fasthttp" "rua.plus/lolly/internal/netutil" ) @@ -25,14 +29,30 @@ import ( // // 管理多个虚拟主机,根据请求的 Host 头分发到对应的处理器。 // 支持默认主机作为未匹配请求的 fallback。 +// 支持精确匹配、前缀通配(*.example.com)、后缀通配(example.*)和正则匹配。 type VHostManager struct { - // hosts 虚拟主机映射,按 server name 索引 + // 精确匹配 hosts map[string]*VirtualHost + // 前缀通配 - suffix map(O(1) 查找) + wildcardSuffixMap map[string]*VirtualHost // suffix -> vhost + + // 后缀通配 - TLD map + wildcardTLDMap map[string]*VirtualHost // TLD -> vhost + + // 正则匹配 + regexHosts []*RegexHostMatcher + // defaultHost 默认主机,处理未匹配的 Host 头请求 defaultHost *VirtualHost } +// RegexHostMatcher 正则主机匹配器。 +type RegexHostMatcher struct { + pattern *regexp.Regexp + vhost *VirtualHost +} + // VirtualHost 虚拟主机。 // // 代表一个虚拟主机配置,包含名称和对应的请求处理器。 @@ -50,19 +70,66 @@ type VirtualHost struct { // - *VHostManager: 新创建的管理器实例 func NewVHostManager() *VHostManager { return &VHostManager{ - hosts: make(map[string]*VirtualHost), + hosts: make(map[string]*VirtualHost), + wildcardSuffixMap: make(map[string]*VirtualHost), + wildcardTLDMap: make(map[string]*VirtualHost), + regexHosts: make([]*RegexHostMatcher, 0), } } // AddHost 添加虚拟主机。 // +// 支持以下 server_name 格式: +// - 精确匹配: "example.com" +// - 前缀通配: "*.example.com"(匹配任意子域名) +// - 后缀通配: "example.*"(匹配任意 TLD) +// - 正则匹配: "~regex"(以 ~ 开头,后面是正则表达式) +// // 参数: // - name: 虚拟主机名称(域名) // - handler: 请求处理器 -func (v *VHostManager) AddHost(name string, handler fasthttp.RequestHandler) { - v.hosts[name] = &VirtualHost{ - name: name, - handler: handler, +// +// 返回值: +// - error: 正则表达式无效时返回错误 +func (v *VHostManager) AddHost(name string, handler fasthttp.RequestHandler) error { + if strings.HasPrefix(name, "~") { + // 正则匹配 + pattern := name[1:] + re, err := regexp.Compile(pattern) + if err != nil { + return fmt.Errorf("invalid regex pattern: %w", err) + } + v.regexHosts = append(v.regexHosts, &RegexHostMatcher{ + pattern: re, + vhost: &VirtualHost{ + name: name, + handler: handler, + }, + }) + return nil + } else if strings.HasPrefix(name, "*.") { + // 前缀通配 *.example.com + suffix := name[2:] + v.wildcardSuffixMap[suffix] = &VirtualHost{ + name: name, + handler: handler, + } + return nil + } else if strings.HasSuffix(name, ".*") { + // 后缀通配 example.* + tld := name[:len(name)-2] + v.wildcardTLDMap[tld] = &VirtualHost{ + name: name, + handler: handler, + } + return nil + } else { + // 精确匹配 + v.hosts[name] = &VirtualHost{ + name: name, + handler: handler, + } + return nil } } @@ -77,6 +144,72 @@ func (v *VHostManager) SetDefault(handler fasthttp.RequestHandler) { } } +// findLongestWildcardPrefix 查找最长的通配符前缀匹配。 +// +// 按 nginx 规则,从最长子域名开始匹配,例如: +// "a.b.example.com" 优先匹配 "*.b.example.com",其次 "*.example.com"。 +// +// 参数: +// - host: 主机名 +// +// 返回值: +// - *VirtualHost: 匹配的虚拟主机,未匹配返回 nil +func (v *VHostManager) findLongestWildcardPrefix(host string) *VirtualHost { + parts := strings.Split(host, ".") + for i := 1; i < len(parts); i++ { + suffix := strings.Join(parts[i:], ".") + if vhost, ok := v.wildcardSuffixMap[suffix]; ok { + return vhost + } + } + return nil +} + +// FindHost 根据主机名查找虚拟主机。 +// +// 匹配优先级(nginx server_name 规则): +// 1. 精确匹配 +// 2. 最长前缀通配(*.example.com) +// 3. 后缀通配(example.*) +// 4. 正则匹配(按配置顺序) +// 5. 默认主机 +// +// 参数: +// - host: 主机名 +// +// 返回值: +// - *VirtualHost: 匹配的虚拟主机 +func (v *VHostManager) FindHost(host string) *VirtualHost { + // 1. 精确匹配 + if vhost, ok := v.hosts[host]; ok { + return vhost + } + + // 2. 最长前缀通配 *.example.com + if vhost := v.findLongestWildcardPrefix(host); vhost != nil { + return vhost + } + + // 3. 后缀通配 example.* + parts := strings.Split(host, ".") + if len(parts) >= 2 { + tld := parts[0] + if vhost, ok := v.wildcardTLDMap[tld]; ok { + return vhost + } + } + + // 4. 正则匹配(按配置顺序) + for _, m := range v.regexHosts { + if m.pattern.MatchString(host) { + return m.vhost + } + } + + // 5. 默认主机 + return v.defaultHost +} + // Handler 返回虚拟主机选择器。 // // 返回值: @@ -85,10 +218,8 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { host := netutil.StripPort(string(ctx.Host())) - if vhost, ok := v.hosts[host]; ok { + if vhost := v.FindHost(host); vhost != nil { vhost.handler(ctx) - } else if v.defaultHost != nil { - v.defaultHost.handler(ctx) } else { ctx.Error("Host not found", fasthttp.StatusNotFound) } diff --git a/internal/server/vhost_test.go b/internal/server/vhost_test.go index 97053de..348efbe 100644 --- a/internal/server/vhost_test.go +++ b/internal/server/vhost_test.go @@ -21,7 +21,7 @@ func TestVHostManager_Handler(t *testing.T) { t.Run("匹配已知主机", func(t *testing.T) { manager := NewVHostManager() hostCalled := false - manager.AddHost("example.com", mockHandler("example", &hostCalled)) + _ = manager.AddHost("example.com", mockHandler("example", &hostCalled)) handler := manager.Handler() ctx := &fasthttp.RequestCtx{} @@ -40,7 +40,7 @@ func TestVHostManager_Handler(t *testing.T) { t.Run("匹配带端口的主机", func(t *testing.T) { manager := NewVHostManager() hostCalled := false - manager.AddHost("example.com", mockHandler("example", &hostCalled)) + _ = manager.AddHost("example.com", mockHandler("example", &hostCalled)) handler := manager.Handler() ctx := &fasthttp.RequestCtx{} @@ -60,7 +60,7 @@ func TestVHostManager_Handler(t *testing.T) { manager := NewVHostManager() exampleCalled := false defaultCalled := false - manager.AddHost("example.com", mockHandler("example", &exampleCalled)) + _ = manager.AddHost("example.com", mockHandler("example", &exampleCalled)) manager.SetDefault(mockHandler("default", &defaultCalled)) handler := manager.Handler() @@ -83,7 +83,7 @@ func TestVHostManager_Handler(t *testing.T) { t.Run("无匹配无默认返回404", func(t *testing.T) { manager := NewVHostManager() exampleCalled := false - manager.AddHost("example.com", mockHandler("example", &exampleCalled)) + _ = manager.AddHost("example.com", mockHandler("example", &exampleCalled)) handler := manager.Handler() ctx := &fasthttp.RequestCtx{} @@ -102,7 +102,7 @@ func TestVHostManager_Handler(t *testing.T) { t.Run("IPv6地址Host", func(t *testing.T) { manager := NewVHostManager() ipv6Called := false - manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called)) + _ = manager.AddHost("[::1]", mockHandler("ipv6", &ipv6Called)) handler := manager.Handler() ctx := &fasthttp.RequestCtx{} @@ -157,7 +157,7 @@ func TestVHostManager_AddHost(t *testing.T) { t.Run("添加单个主机", func(t *testing.T) { manager := NewVHostManager() called := false - manager.AddHost("test.com", mockHandler("test", &called)) + _ = manager.AddHost("test.com", mockHandler("test", &called)) handler := manager.Handler() ctx := &fasthttp.RequestCtx{} @@ -174,8 +174,8 @@ func TestVHostManager_AddHost(t *testing.T) { manager := NewVHostManager() host1Called := false host2Called := false - manager.AddHost("host1.com", mockHandler("host1", &host1Called)) - manager.AddHost("host2.com", mockHandler("host2", &host2Called)) + _ = manager.AddHost("host1.com", mockHandler("host1", &host1Called)) + _ = manager.AddHost("host2.com", mockHandler("host2", &host2Called)) handler := manager.Handler() @@ -200,8 +200,8 @@ func TestVHostManager_AddHost(t *testing.T) { manager := NewVHostManager() firstCalled := false secondCalled := false - manager.AddHost("test.com", mockHandler("first", &firstCalled)) - manager.AddHost("test.com", mockHandler("second", &secondCalled)) + _ = manager.AddHost("test.com", mockHandler("first", &firstCalled)) + _ = manager.AddHost("test.com", mockHandler("second", &secondCalled)) handler := manager.Handler() ctx := &fasthttp.RequestCtx{} @@ -281,7 +281,7 @@ func TestVHostManager_PortStripping(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager := NewVHostManager() called := false - manager.AddHost(tt.expected, mockHandler("matched", &called)) + _ = manager.AddHost(tt.expected, mockHandler("matched", &called)) handler := manager.Handler() ctx := &fasthttp.RequestCtx{}