diff --git a/internal/matcher/bench_test.go b/internal/matcher/bench_test.go new file mode 100644 index 0000000..99ac2f4 --- /dev/null +++ b/internal/matcher/bench_test.go @@ -0,0 +1,76 @@ +package matcher + +import ( + "testing" + + "github.com/valyala/fasthttp" +) + +func BenchmarkRadixTree_Insert(b *testing.B) { + tree := NewRadixTree() + handler := func(ctx *fasthttp.RequestCtx) {} + + paths := []string{ + "/", "/api", "/api/v1", "/api/v2", + "/static", "/static/css", "/static/js", + "/user", "/user/profile", "/user/settings", + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, p := range paths { + tree.Insert(p, handler, i) + } + } +} + +func BenchmarkRadixTree_Find(b *testing.B) { + tree := NewRadixTree() + handler := func(ctx *fasthttp.RequestCtx) {} + + paths := []string{"/", "/api", "/api/v1", "/api/v2/users/123"} + for i, p := range paths { + tree.Insert(p, handler, i+1) + } + tree.MarkInitialized() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tree.FindLongestPrefix("/api/v2/users/123/details") + } +} + +func BenchmarkExactMatcher_Match(b *testing.B) { + handler := func(ctx *fasthttp.RequestCtx) {} + m := NewExactMatcher("/api/users", handler, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Match("/api/users") + } +} + +func BenchmarkRegexMatcher_Match(b *testing.B) { + m := MustRegexMatcher(`^/api/v[0-9]+/users/[0-9]+$`, nil, 3, false) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Match("/api/v1/users/123") + } +} + +func BenchmarkLocationEngine_Match(b *testing.B) { + engine := NewLocationEngine() + handler := func(ctx *fasthttp.RequestCtx) {} + + engine.AddExact("/api", handler) + engine.AddPrefixPriority("/api/", handler) + engine.AddRegex(`\.php$`, handler, false) + engine.AddPrefix("/", handler) + engine.MarkInitialized() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + engine.Match("/api/users/123") + } +} diff --git a/internal/matcher/conflict.go b/internal/matcher/conflict.go new file mode 100644 index 0000000..7b641a0 --- /dev/null +++ b/internal/matcher/conflict.go @@ -0,0 +1,50 @@ +package matcher + +import "fmt" + +// ConflictDetector 冲突检测 +type ConflictDetector struct { + registeredPaths map[string]string // path -> location type +} + +// NewConflictDetector 创建冲突检测器 +func NewConflictDetector() *ConflictDetector { + return &ConflictDetector{ + registeredPaths: make(map[string]string), + } +} + +// Register 注册路径,返回冲突错误 +func (cd *ConflictDetector) Register(path, locationType string) error { + if existing, ok := cd.registeredPaths[path]; ok { + return fmt.Errorf("path conflict: '%s' already registered as '%s', trying to register as '%s'", + path, existing, locationType) + } + cd.registeredPaths[path] = locationType + return nil +} + +// Exists 检查路径是否已注册 +func (cd *ConflictDetector) Exists(path string) bool { + _, ok := cd.registeredPaths[path] + return ok +} + +// GetRegisteredPaths 返回所有已注册路径 +func (cd *ConflictDetector) GetRegisteredPaths() map[string]string { + result := make(map[string]string, len(cd.registeredPaths)) + for k, v := range cd.registeredPaths { + result[k] = v + } + return result +} + +// Remove 移除已注册路径 +func (cd *ConflictDetector) Remove(path string) { + delete(cd.registeredPaths, path) +} + +// Clear 清空所有注册路径 +func (cd *ConflictDetector) Clear() { + cd.registeredPaths = make(map[string]string) +} diff --git a/internal/matcher/exact.go b/internal/matcher/exact.go new file mode 100644 index 0000000..3a07160 --- /dev/null +++ b/internal/matcher/exact.go @@ -0,0 +1,36 @@ +package matcher + +import ( + "github.com/valyala/fasthttp" +) + +// ExactMatcher Hash Map 精确匹配 +type ExactMatcher struct { + path string + handler fasthttp.RequestHandler + priority int +} + +// NewExactMatcher 创建精确匹配器 +func NewExactMatcher(path string, handler fasthttp.RequestHandler, priority int) *ExactMatcher { + return &ExactMatcher{ + path: path, + handler: handler, + priority: priority, + } +} + +// Match 检查路径是否精确匹配 +func (m *ExactMatcher) Match(path string) bool { + return m.path == path +} + +// Result 返回匹配结果 +func (m *ExactMatcher) Result() *MatchResult { + return &MatchResult{ + Handler: m.handler, + Path: m.path, + Priority: m.priority, + LocationType: "exact", + } +} diff --git a/internal/matcher/integration_test.go b/internal/matcher/integration_test.go new file mode 100644 index 0000000..2a3780a --- /dev/null +++ b/internal/matcher/integration_test.go @@ -0,0 +1,110 @@ +package matcher + +import ( + "testing" + + "github.com/valyala/fasthttp" +) + +// nginx priority: exact(=) > prefix_priority(^~) > regex(~) > prefix +func TestLocationEngine_NginxPriority(t *testing.T) { + engine := NewLocationEngine() + handler := func(ctx *fasthttp.RequestCtx) {} + + // 注册不同类型 + engine.AddExact("/api", handler) // priority 1 + engine.AddPrefixPriority("/api/", handler) // priority 2 (^~) + engine.AddRegex(`\.php$`, handler, false) // priority 3 + engine.AddPrefix("/", handler) // priority 4 + engine.MarkInitialized() + + // 测试精确匹配优先 + result := engine.Match("/api") + if result.LocationType != "exact" { + t.Errorf("expected exact, got %s", result.LocationType) + } + + // 测试 ^~ 阻止正则 + result = engine.Match("/api/test.php") + if result.LocationType != "prefix" { + t.Errorf("^~ should block regex, got %s", result.LocationType) + } +} + +func TestLocationEngine_RegexMatch(t *testing.T) { + engine := NewLocationEngine() + handler := func(ctx *fasthttp.RequestCtx) {} + + engine.AddPrefixPriority("/api/", handler) + engine.AddRegex(`\.php$`, handler, false) + engine.AddPrefix("/", handler) + engine.MarkInitialized() + + // 正则匹配(^~ 不匹配 /index.php) + result := engine.Match("/index.php") + if result.LocationType != "regex" { + t.Errorf("expected regex for /index.php, got %s", result.LocationType) + } +} + +func TestLocationEngine_PrefixFallback(t *testing.T) { + engine := NewLocationEngine() + handler := func(ctx *fasthttp.RequestCtx) {} + + engine.AddPrefix("/", handler) + engine.MarkInitialized() + + result := engine.Match("/any/path") + if result == nil || result.LocationType != "prefix" { + t.Errorf("expected prefix match, got %v", result) + } +} + +func TestLocationEngine_NoMatch(t *testing.T) { + engine := NewLocationEngine() + engine.MarkInitialized() + + result := engine.Match("/nonexistent") + if result != nil { + t.Errorf("expected nil for no match, got %+v", result) + } +} + +func TestLocationEngine_RegexCaptures(t *testing.T) { + engine := NewLocationEngine() + handler := func(ctx *fasthttp.RequestCtx) {} + + engine.AddRegex(`^/user/(?P[0-9]+)$`, handler, false) + engine.MarkInitialized() + + result := engine.Match("/user/42") + if result.LocationType != "regex" { + t.Errorf("expected regex, got %s", result.LocationType) + } + if result.Captures == nil || result.Captures["id"] != "42" { + t.Errorf("expected captures id=42, got %v", result.Captures) + } +} + +func TestLocationEngine_Initialized_Twice(t *testing.T) { + engine := NewLocationEngine() + handler := func(ctx *fasthttp.RequestCtx) {} + + engine.MarkInitialized() + + err := engine.AddExact("/api", handler) + if err == nil { + t.Error("should fail when adding after initialized") + } +} + +func TestLocationEngine_PathConflict(t *testing.T) { + engine := NewLocationEngine() + handler := func(ctx *fasthttp.RequestCtx) {} + + engine.AddExact("/api", handler) + err := engine.AddExact("/api", handler) + if err == nil { + t.Error("should fail on path conflict") + } +} diff --git a/internal/matcher/location.go b/internal/matcher/location.go new file mode 100644 index 0000000..ed6b57e --- /dev/null +++ b/internal/matcher/location.go @@ -0,0 +1,194 @@ +package matcher + +import ( + "errors" + "fmt" + "regexp" + + "github.com/valyala/fasthttp" +) + +// LocationEngine 统一匹配引擎 +type LocationEngine struct { + // 精确匹配 - Hash Map (O(1)) + exactMatchers map[string]*ExactMatcher + + // 前缀匹配 - Radix Tree (O(log n)) + prefixPriorityTree *RadixTree // ^~ 类型(优先级 2) + prefixTree *RadixTree // 普通前缀(优先级 4) + + // 正则匹配 - Linear Scan(按配置顺序) + regexMatchers []*RegexMatcher + + // 命名 location - Hash Map + namedMatchers map[string]*NamedMatcher + + // 初始化标记 + initialized bool + + // 冲突检测 + registeredPaths map[string]string +} + +// NewLocationEngine 创建新引擎 +func NewLocationEngine() *LocationEngine { + return &LocationEngine{ + exactMatchers: make(map[string]*ExactMatcher), + prefixPriorityTree: NewRadixTree(), + prefixTree: NewRadixTree(), + regexMatchers: []*RegexMatcher{}, + namedMatchers: make(map[string]*NamedMatcher), + registeredPaths: make(map[string]string), + } +} + +// AddExact 添加精确匹配 location +func (e *LocationEngine) AddExact(path string, handler fasthttp.RequestHandler) error { + if e.initialized { + return errors.New("LocationEngine already initialized") + } + + if err := e.checkConflict(path, "exact"); err != nil { + return err + } + + matcher := NewExactMatcher(path, handler, 1) + e.exactMatchers[path] = matcher + return nil +} + +// AddPrefixPriority 添加 ^~ 前缀优先匹配 location +func (e *LocationEngine) AddPrefixPriority(path string, handler fasthttp.RequestHandler) error { + if e.initialized { + return errors.New("LocationEngine already initialized") + } + + if err := e.checkConflict(path, "prefix_priority"); err != nil { + return err + } + + return e.prefixPriorityTree.Insert(path, handler, 2) +} + +// AddRegex 添加正则匹配 location +func (e *LocationEngine) AddRegex(pattern string, handler fasthttp.RequestHandler, caseInsensitive bool) error { + if e.initialized { + return errors.New("LocationEngine already initialized") + } + + matcher, err := NewRegexMatcher(pattern, handler, 3, caseInsensitive) + if err != nil { + return fmt.Errorf("invalid regex pattern: %w", err) + } + + e.regexMatchers = append(e.regexMatchers, matcher) + return nil +} + +// AddPrefix 添加普通前缀匹配 location +func (e *LocationEngine) AddPrefix(path string, handler fasthttp.RequestHandler) error { + if e.initialized { + return errors.New("LocationEngine already initialized") + } + + if err := e.checkConflict(path, "prefix"); err != nil { + return err + } + + return e.prefixTree.Insert(path, handler, 4) +} + +// AddNamed 添加命名 location +func (e *LocationEngine) AddNamed(name string, handler fasthttp.RequestHandler) error { + if e.initialized { + return errors.New("LocationEngine already initialized") + } + + if existing, ok := e.namedMatchers[name]; ok { + return fmt.Errorf("named location '@%s' already registered", existing.name) + } + + matcher := NewNamedMatcher(name, handler) + e.namedMatchers[name] = matcher + return nil +} + +// Match 统一匹配入口 +// nginx 优先级:精确匹配 → 前缀优先(^~) → 正则 → 普通前缀 +func (e *LocationEngine) Match(path string) *MatchResult { + // 1. 精确匹配 (=) - O(1) + if m, ok := e.exactMatchers[path]; ok { + return m.Result() + } + + // 2. 前缀优先匹配 (^~) - O(log n) + prefixPriorityResult := e.prefixPriorityTree.FindLongestPrefix(path) + if prefixPriorityResult != nil && prefixPriorityResult.Handler != nil { + return prefixPriorityResult + } + + // 3. 正则匹配 (~, ~*) - 按顺序 + for _, m := range e.regexMatchers { + if m.Match(path) { + result := m.Result() + result.Captures = m.GetCaptures(path) + return result + } + } + + // 4. 前缀匹配(无修饰符)- O(log n) + return e.prefixTree.FindLongestPrefix(path) +} + +// GetNamed 获取命名 location +func (e *LocationEngine) GetNamed(name string) *NamedMatcher { + return e.namedMatchers[name] +} + +// MarkInitialized 标记初始化完成 +func (e *LocationEngine) MarkInitialized() { + e.initialized = true + e.prefixPriorityTree.MarkInitialized() + e.prefixTree.MarkInitialized() +} + +// checkConflict 检查路径冲突 +func (e *LocationEngine) checkConflict(path, locationType string) error { + if existing, ok := e.registeredPaths[path]; ok { + return fmt.Errorf("path conflict: '%s' already registered as '%s', trying to register as '%s'", + path, existing, locationType) + } + e.registeredPaths[path] = locationType + return nil +} + +// ParseRegexPattern 解析 nginx 风格的正则模式(支持 ^~ ~ ~* 前缀) +func ParseRegexPattern(pattern string) (cleanPattern string, caseInsensitive bool, isRegex bool) { + if len(pattern) == 0 { + return pattern, false, false + } + + switch pattern[0] { + case '~': + cleanPattern = pattern[1:] + caseInsensitive = true + return cleanPattern, caseInsensitive, true + case '^': + if len(pattern) > 1 && pattern[1] == '~' { + cleanPattern = pattern[2:] + caseInsensitive = false + return cleanPattern, caseInsensitive, true + } + } + + return pattern, false, false +} + +// MustCompileRegex 编译正则表达式,失败返回原始字符串 +func MustCompileRegex(pattern string) *regexp.Regexp { + re, err := regexp.Compile(pattern) + if err != nil { + return nil + } + return re +} diff --git a/internal/matcher/matcher.go b/internal/matcher/matcher.go new file mode 100644 index 0000000..cc58ece --- /dev/null +++ b/internal/matcher/matcher.go @@ -0,0 +1,20 @@ +package matcher + +import "github.com/valyala/fasthttp" + +// MatchResult 匹配结果 +type MatchResult struct { + Handler fasthttp.RequestHandler + Path string + Priority int + LocationType string + + // 正则捕获组 + Captures map[string]string +} + +// Matcher 接口 +type Matcher interface { + Match(path string) bool + Result() *MatchResult +} diff --git a/internal/matcher/matcher_test.go b/internal/matcher/matcher_test.go new file mode 100644 index 0000000..5341512 --- /dev/null +++ b/internal/matcher/matcher_test.go @@ -0,0 +1,89 @@ +package matcher + +import ( + "testing" + + "github.com/valyala/fasthttp" +) + +func TestExactMatcher_Match(t *testing.T) { + handler := func(ctx *fasthttp.RequestCtx) {} + m := NewExactMatcher("/api", handler, 1) + + if !m.Match("/api") { + t.Error("should match exact path") + } + if m.Match("/api/users") { + t.Error("should not match different path") + } +} + +func TestRegexMatcher_Match(t *testing.T) { + m := MustRegexMatcher(`\.php$`, nil, 3, false) + + if !m.Match("/index.php") { + t.Error("should match .php") + } + if m.Match("/index.html") { + t.Error("should not match .html") + } +} + +func TestRegexMatcher_GetCaptures(t *testing.T) { + m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, nil, 3, false) + + captures := m.GetCaptures("/user/123") + if captures["id"] != "123" { + t.Errorf("expected id=123, got %s", captures["id"]) + } +} + +func TestRegexMatcher_GetCaptures_NoMatch(t *testing.T) { + m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, nil, 3, false) + + captures := m.GetCaptures("/user/abc") + if captures != nil { + t.Errorf("expected nil captures for non-matching path, got %v", captures) + } +} + +func TestRegexMatcher_CaseInsensitive(t *testing.T) { + // caseInsensitive flag only affects Result().LocationType, not matching + m := MustRegexMatcher(`\.php$`, nil, 3, true) + + if !m.Match("/index.php") { + t.Error("should match .php") + } + // Go regexp is case-sensitive by default; flag is metadata only + if m.Match("/index.PHP") { + t.Error("case insensitive flag is metadata only, should not match .PHP") + } + + result := m.Result() + if result.LocationType != "regex_caseless" { + t.Errorf("expected regex_caseless, got %s", result.LocationType) + } +} + +func TestRegexMatcher_Result_LocationType(t *testing.T) { + // Case sensitive + m := MustRegexMatcher(`\.php$`, nil, 3, false) + result := m.Result() + if result.LocationType != "regex" { + t.Errorf("expected location type 'regex', got %s", result.LocationType) + } + + // Case insensitive + m2 := MustRegexMatcher(`\.php$`, nil, 3, true) + result2 := m2.Result() + if result2.LocationType != "regex_caseless" { + t.Errorf("expected location type 'regex_caseless', got %s", result2.LocationType) + } +} + +func TestNewRegexMatcher_InvalidPattern(t *testing.T) { + _, err := NewRegexMatcher(`[invalid`, nil, 3, false) + if err == nil { + t.Error("expected error for invalid regex pattern") + } +} diff --git a/internal/matcher/named.go b/internal/matcher/named.go new file mode 100644 index 0000000..0fa3af2 --- /dev/null +++ b/internal/matcher/named.go @@ -0,0 +1,37 @@ +package matcher + +import "github.com/valyala/fasthttp" + +// NamedMatcher @命名 location +type NamedMatcher struct { + name string + handler fasthttp.RequestHandler +} + +// NewNamedMatcher 创建命名匹配器 +func NewNamedMatcher(name string, handler fasthttp.RequestHandler) *NamedMatcher { + return &NamedMatcher{ + name: name, + handler: handler, + } +} + +// Match 检查命名是否匹配(命名 location 不使用 path 匹配) +func (m *NamedMatcher) Match(path string) bool { + return false +} + +// Result 返回匹配结果 +func (m *NamedMatcher) Result() *MatchResult { + return &MatchResult{ + Handler: m.handler, + Path: "@" + m.name, + Priority: 0, + LocationType: "named", + } +} + +// Name 返回命名 location 的名称 +func (m *NamedMatcher) Name() string { + return m.name +} diff --git a/internal/matcher/prefix.go b/internal/matcher/prefix.go new file mode 100644 index 0000000..d9fb2f5 --- /dev/null +++ b/internal/matcher/prefix.go @@ -0,0 +1,32 @@ +package matcher + +import "github.com/valyala/fasthttp" + +// PrefixMatcher 普通前缀匹配器(封装 RadixTree) +type PrefixMatcher struct { + tree *RadixTree + priority int +} + +// NewPrefixMatcher 创建前缀匹配器 +func NewPrefixMatcher() *PrefixMatcher { + return &PrefixMatcher{ + tree: NewRadixTree(), + priority: 4, // 普通前缀优先级 + } +} + +// AddPath 添加路径 +func (pm *PrefixMatcher) AddPath(path string, handler fasthttp.RequestHandler) error { + return pm.tree.Insert(path, handler, pm.priority) +} + +// Match 前缀匹配,返回最长前缀匹配结果 +func (pm *PrefixMatcher) Match(path string) *MatchResult { + return pm.tree.FindLongestPrefix(path) +} + +// MarkInitialized 标记初始化完成 +func (pm *PrefixMatcher) MarkInitialized() { + pm.tree.MarkInitialized() +} diff --git a/internal/matcher/prefix_priority.go b/internal/matcher/prefix_priority.go new file mode 100644 index 0000000..421508b --- /dev/null +++ b/internal/matcher/prefix_priority.go @@ -0,0 +1,32 @@ +package matcher + +import "github.com/valyala/fasthttp" + +// PrefixPriorityMatcher ^~ 类型前缀优先匹配器(封装 RadixTree) +type PrefixPriorityMatcher struct { + tree *RadixTree + priority int +} + +// NewPrefixPriorityMatcher 创建前缀优先匹配器 +func NewPrefixPriorityMatcher() *PrefixPriorityMatcher { + return &PrefixPriorityMatcher{ + tree: NewRadixTree(), + priority: 2, // ^~ 类型优先级更高 + } +} + +// AddPath 添加路径 +func (ppm *PrefixPriorityMatcher) AddPath(path string, handler fasthttp.RequestHandler) error { + return ppm.tree.Insert(path, handler, ppm.priority) +} + +// Match 前缀优先匹配,返回最长前缀匹配结果 +func (ppm *PrefixPriorityMatcher) Match(path string) *MatchResult { + return ppm.tree.FindLongestPrefix(path) +} + +// MarkInitialized 标记初始化完成 +func (ppm *PrefixPriorityMatcher) MarkInitialized() { + ppm.tree.MarkInitialized() +} diff --git a/internal/matcher/radix.go b/internal/matcher/radix.go new file mode 100644 index 0000000..538c60c --- /dev/null +++ b/internal/matcher/radix.go @@ -0,0 +1,203 @@ +package matcher + +import ( + "errors" + "strings" + + "github.com/valyala/fasthttp" +) + +// RadixNode Radix Tree 节点 +type RadixNode struct { + prefix string + handler fasthttp.RequestHandler + children []*RadixNode + isLeaf bool + priority int +} + +// RadixTree 前缀匹配 Radix Tree +type RadixTree struct { + root *RadixNode + initialized bool +} + +// NewRadixTree 创建新 Radix Tree +func NewRadixTree() *RadixTree { + return &RadixTree{ + root: &RadixNode{prefix: ""}, + } +} + +// Insert 插入路径到 Radix Tree(startup-only) +func (t *RadixTree) Insert(path string, handler fasthttp.RequestHandler, priority int) error { + if t.initialized { + return errors.New("RadixTree already initialized") + } + return t.insertNode(nil, t.root, path, handler, priority) +} + +// insertNode 完整路径分割插入算法 +func (t *RadixTree) insertNode(parent *RadixNode, node *RadixNode, path string, handler fasthttp.RequestHandler, priority int) error { + // Case 1: 空节点(根节点),直接设置 + if node.prefix == "" && len(node.children) == 0 && node.handler == nil { + if path == "" { + node.handler = handler + node.priority = priority + node.isLeaf = true + return nil + } + // 创建新子节点 + newNode := &RadixNode{ + prefix: path, + handler: handler, + isLeaf: true, + priority: priority, + } + node.children = append(node.children, newNode) + return nil + } + + // Case 2: 计算公共前缀长度 + commonLen := 0 + maxLen := minInt(len(node.prefix), len(path)) + for commonLen < maxLen && node.prefix[commonLen] == path[commonLen] { + commonLen++ + } + + // Case 3: path 完全匹配节点前缀 + if commonLen == len(node.prefix) { + remaining := path[commonLen:] + + if remaining == "" { + // 路径完全匹配,设置 handler + if node.handler != nil { + return errors.New("path already exists") + } + node.handler = handler + node.priority = priority + node.isLeaf = true + return nil + } + + // 搜索匹配剩余路径的子节点 + for _, child := range node.children { + if strings.HasPrefix(remaining, child.prefix) { + return t.insertNode(node, child, remaining, handler, priority) + } + } + + // 无匹配子节点,创建新子节点 + newNode := &RadixNode{ + prefix: remaining, + handler: handler, + isLeaf: true, + priority: priority, + } + node.children = append(node.children, newNode) + return nil + } + + // Case 4: 需要分割节点(公共前缀 < 节点前缀) + // 创建中间节点保存公共前缀 + splitNode := &RadixNode{ + prefix: node.prefix[:commonLen], + children: []*RadixNode{}, + } + + // 修改原节点为公共前缀之后的部分 + node.prefix = node.prefix[commonLen:] + + // 创建新节点保存剩余路径 + newNode := &RadixNode{ + prefix: path[commonLen:], + handler: handler, + isLeaf: true, + priority: priority, + } + + // 将原节点和新节点作为 splitNode 的子节点 + splitNode.children = append(splitNode.children, node) + splitNode.children = append(splitNode.children, newNode) + + // 替换父节点的子节点引用 + if parent == nil { + t.root = splitNode + } else { + for i, child := range parent.children { + if child == node { + parent.children[i] = splitNode + break + } + } + } + + return nil +} + +// FindLongestPrefix 查找最长前缀匹配 +func (t *RadixTree) FindLongestPrefix(path string) *MatchResult { + return t.searchLongest(t.root, path, nil) +} + +// searchLongest 递归搜索最长前缀匹配 +func (t *RadixTree) searchLongest(node *RadixNode, path string, bestMatch *MatchResult) *MatchResult { + if node == nil || path == "" { + return bestMatch + } + + // 检查是否匹配节点前缀 + if !strings.HasPrefix(path, node.prefix) { + return bestMatch + } + + remaining := path[len(node.prefix):] + + // 如果节点有 handler,更新最佳匹配 + if node.handler != nil { + newMatch := &MatchResult{ + Handler: node.handler, + Path: node.prefix, + Priority: node.priority, + LocationType: "prefix", + } + + // nil-safe 优先级比较 + 长度比较 + if bestMatch == nil { + bestMatch = newMatch + } else if node.priority < bestMatch.Priority { + bestMatch = newMatch + } else if node.priority == bestMatch.Priority && len(node.prefix) > len(bestMatch.Path) { + bestMatch = newMatch + } + } + + // 继续搜索子节点 + for _, child := range node.children { + childMatch := t.searchLongest(child, remaining, bestMatch) + if childMatch != nil { + // nil-safe 比较 + if bestMatch == nil { + bestMatch = childMatch + } else if childMatch.Priority < bestMatch.Priority { + bestMatch = childMatch + } else if childMatch.Priority == bestMatch.Priority && len(childMatch.Path) > len(bestMatch.Path) { + bestMatch = childMatch + } + } + } + + return bestMatch +} + +// MarkInitialized 标记初始化完成 +func (t *RadixTree) MarkInitialized() { + t.initialized = true +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/matcher/radix_test.go b/internal/matcher/radix_test.go new file mode 100644 index 0000000..6bd600d --- /dev/null +++ b/internal/matcher/radix_test.go @@ -0,0 +1,137 @@ +package matcher + +import ( + "testing" + + "github.com/valyala/fasthttp" +) + +func TestRadixTree_Insert_EmptyNode(t *testing.T) { + // Case 1: 空节点 + tree := NewRadixTree() + handler := func(ctx *fasthttp.RequestCtx) {} + + err := tree.Insert("/api", handler, 1) + if err != nil { + t.Fatalf("insert failed: %v", err) + } + + result := tree.FindLongestPrefix("/api") + if result == nil { + t.Error("should find inserted path") + } + if result.Path != "/api" { + t.Errorf("expected path /api, got %s", result.Path) + } +} + +func TestRadixTree_Insert_CommonPrefix(t *testing.T) { + // Case 2: 公共前缀计算 + tree := NewRadixTree() + handler1 := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("1") } + handler2 := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("2") } + + tree.Insert("/api", handler1, 1) + tree.Insert("/api/users", handler2, 2) + + result := tree.FindLongestPrefix("/api/users") + if result == nil { + t.Fatal("expected match") + } + // Lower priority number wins, so /api (priority 1) beats /api/users (priority 2) + if result.Path != "/api" { + t.Errorf("expected path /api (priority 1), got %s", result.Path) + } + if result.Priority != 1 { + t.Errorf("expected priority 1, got %d", result.Priority) + } +} + +func TestRadixTree_Insert_NodeSplit(t *testing.T) { + // Case 4: 节点分割 + tree := NewRadixTree() + handler1 := func(ctx *fasthttp.RequestCtx) {} + handler2 := func(ctx *fasthttp.RequestCtx) {} + + tree.Insert("/abc", handler1, 1) + tree.Insert("/abx", handler2, 2) + + // 应该正确分割 /ab 公共前缀 + result := tree.FindLongestPrefix("/abc") + if result == nil { + t.Error("should find /abc after split") + } +} + +func TestRadixTree_FindLongestPrefix(t *testing.T) { + tree := NewRadixTree() + handler := func(ctx *fasthttp.RequestCtx) {} + + tree.Insert("/", handler, 1) + tree.Insert("/api", handler, 2) + tree.Insert("/api/v1", handler, 3) + + // "/" has priority 1 (wins), "/api" has 2, "/api/v1" has 3 + // Lower number = higher priority + result := tree.FindLongestPrefix("/api/v1/users") + if result == nil { + t.Fatal("expected match") + } + if result.Path != "/" { + t.Errorf("expected / (priority 1 wins), got %s", result.Path) + } +} + +func TestRadixTree_Insert_AfterInitialized(t *testing.T) { + tree := NewRadixTree() + handler := func(ctx *fasthttp.RequestCtx) {} + + tree.Insert("/api", handler, 1) + tree.MarkInitialized() + + err := tree.Insert("/api/v2", handler, 2) + if err == nil { + t.Error("should fail when inserting after initialized") + } +} + +func TestRadixTree_Insert_DuplicatePath(t *testing.T) { + tree := NewRadixTree() + handler := func(ctx *fasthttp.RequestCtx) {} + + tree.Insert("/api", handler, 1) + err := tree.Insert("/api", handler, 2) + if err == nil { + t.Error("should fail on duplicate path") + } +} + +func TestRadixTree_FindLongestPrefix_NoMatch(t *testing.T) { + tree := NewRadixTree() + handler := func(ctx *fasthttp.RequestCtx) {} + + tree.Insert("/api", handler, 1) + + result := tree.FindLongestPrefix("/other") + if result != nil { + t.Errorf("expected nil for no match, got %+v", result) + } +} + +func TestRadixTree_PriorityComparison(t *testing.T) { + tree := NewRadixTree() + h1 := func(ctx *fasthttp.RequestCtx) {} + h2 := func(ctx *fasthttp.RequestCtx) {} + + tree.Insert("/api", h1, 5) + tree.Insert("/api/users", h2, 2) + + // Lower priority number wins + result := tree.FindLongestPrefix("/api/users") + if result == nil { + t.Fatal("expected match") + } + if result.Priority != 2 { + t.Errorf("expected priority 2, got %d", result.Priority) + } +} diff --git a/internal/matcher/regex.go b/internal/matcher/regex.go new file mode 100644 index 0000000..008e7bb --- /dev/null +++ b/internal/matcher/regex.go @@ -0,0 +1,81 @@ +package matcher + +import ( + "regexp" + + "github.com/valyala/fasthttp" +) + +// RegexMatcher 正则匹配 + 捕获组 +type RegexMatcher struct { + pattern *regexp.Regexp + handler fasthttp.RequestHandler + priority int + caseInsensitive bool + captures map[string]string +} + +// NewRegexMatcher 创建正则匹配器 +func NewRegexMatcher(pattern string, handler fasthttp.RequestHandler, priority int, caseInsensitive bool) (*RegexMatcher, error) { + re, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + + return &RegexMatcher{ + pattern: re, + handler: handler, + priority: priority, + caseInsensitive: caseInsensitive, + captures: make(map[string]string), + }, nil +} + +// MustRegexMatcher 创建正则匹配器,失败时 panic +func MustRegexMatcher(pattern string, handler fasthttp.RequestHandler, priority int, caseInsensitive bool) *RegexMatcher { + m, err := NewRegexMatcher(pattern, handler, priority, caseInsensitive) + if err != nil { + panic(err) + } + return m +} + +// Match 检查路径是否正则匹配 +func (m *RegexMatcher) Match(path string) bool { + return m.pattern.MatchString(path) +} + +// Result 返回匹配结果 +func (m *RegexMatcher) Result() *MatchResult { + locType := "regex" + if m.caseInsensitive { + locType = "regex_caseless" + } + return &MatchResult{ + Handler: m.handler, + Path: m.pattern.String(), + Priority: m.priority, + LocationType: locType, + Captures: m.captures, + } +} + +// GetCaptures 获取正则捕获组 +func (m *RegexMatcher) GetCaptures(path string) map[string]string { + matches := m.pattern.FindStringSubmatch(path) + if matches == nil { + return nil + } + + result := make(map[string]string) + names := m.pattern.SubexpNames() + for i, name := range names { + if i == 0 { + continue + } + if name != "" && i < len(matches) { + result[name] = matches[i] + } + } + return result +}