refactor(cache): 统一路径匹配函数并增强通配符支持
- 删除 file_cache.go 中的 pathMatch() 函数 - 导出 purge.go 中的 MatchPattern() 函数 - 增强 MatchPattern() 支持中间通配符(如 /api/*/users) - 使用 netutil.ExtractClientIPNet() 替代内联 IP 提取逻辑 - 适配 status 模块使用新的工具函数 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
616762e840
commit
931144dd08
4
internal/cache/cache_test.go
vendored
4
internal/cache/cache_test.go
vendored
@ -339,9 +339,9 @@ func TestPathMatch(t *testing.T) {
|
||||
if pattern[0] != '/' && pattern != "*" {
|
||||
pattern = "/" + pattern
|
||||
}
|
||||
result := pathMatch(pattern, tt.path)
|
||||
result := MatchPattern(pattern, tt.path)
|
||||
if result != tt.want {
|
||||
t.Errorf("pathMatch(%s, %s) = %v, want %v", pattern, tt.path, result, tt.want)
|
||||
t.Errorf("MatchPattern(%s, %s) = %v, want %v", pattern, tt.path, result, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
34
internal/cache/file_cache.go
vendored
34
internal/cache/file_cache.go
vendored
@ -19,7 +19,6 @@ package cache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@ -425,7 +424,7 @@ func (c *ProxyCache) ReleaseLock(hashKey uint64, err error) {
|
||||
func (c *ProxyCache) MatchRule(path, method string, status int) *ProxyCacheRule {
|
||||
for _, rule := range c.rules {
|
||||
// 检查路径匹配(简单前缀匹配)
|
||||
if rule.Path != "" && !pathMatch(rule.Path, path) {
|
||||
if rule.Path != "" && !MatchPattern(rule.Path, path) {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -444,37 +443,6 @@ func (c *ProxyCache) MatchRule(path, method string, status int) *ProxyCacheRule
|
||||
return nil
|
||||
}
|
||||
|
||||
// pathMatch 检查路径是否匹配指定模式。
|
||||
//
|
||||
// 支持以下匹配模式:
|
||||
// - "*":匹配所有路径
|
||||
// - 以 "*" 结尾:前缀匹配(如 "/api/*" 匹配 "/api/xxx")
|
||||
// - 以 "/" 结尾:目录前缀匹配
|
||||
// - 其他:精确匹配
|
||||
//
|
||||
// 参数:
|
||||
// - pattern: 匹配模式,支持通配符
|
||||
// - path: 待检查的路径
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: true 表示匹配,false 表示不匹配
|
||||
func pathMatch(pattern, path string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// 通配符匹配
|
||||
if pattern[len(pattern)-1] == '*' {
|
||||
prefix := pattern[:len(pattern)-1]
|
||||
return strings.HasPrefix(path, prefix)
|
||||
}
|
||||
// 前缀匹配(pattern 以 / 结尾)
|
||||
if pattern[len(pattern)-1] == '/' {
|
||||
return strings.HasPrefix(path, pattern)
|
||||
}
|
||||
// 精确匹配
|
||||
return path == pattern
|
||||
}
|
||||
|
||||
// contains 检查字符串切片是否包含某值。
|
||||
//
|
||||
// 参数:
|
||||
|
||||
59
internal/cache/purge.go
vendored
59
internal/cache/purge.go
vendored
@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// PurgeAPI 缓存清理 API 处理器。
|
||||
@ -211,34 +212,7 @@ func (p *PurgeAPI) checkAuth(ctx *fasthttp.RequestCtx) bool {
|
||||
|
||||
// getClientIP 从请求上下文提取客户端 IP。
|
||||
func (p *PurgeAPI) getClientIP(ctx *fasthttp.RequestCtx) net.IP {
|
||||
// 检查 X-Forwarded-For 头部
|
||||
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
|
||||
ips := strings.Split(string(xff), ",")
|
||||
if len(ips) > 0 {
|
||||
ipStr := strings.TrimSpace(ips[0])
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 头部
|
||||
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
|
||||
ip := net.ParseIP(string(xri))
|
||||
if ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 RemoteAddr
|
||||
if addr := ctx.RemoteAddr(); addr != nil {
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
|
||||
return tcpAddr.IP
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return netutil.ExtractClientIPNet(ctx)
|
||||
}
|
||||
|
||||
// purgeByPath 按精确路径清理缓存。
|
||||
@ -293,14 +267,32 @@ func hashPath(path string) uint64 {
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// matchPattern 检查路径是否匹配通配符模式。
|
||||
// 仅支持 * 通配符,匹配任意字符。
|
||||
func matchPattern(pattern, path string) bool {
|
||||
// MatchPattern 检查路径是否匹配通配符模式。
|
||||
//
|
||||
// 支持以下匹配模式:
|
||||
// - "*":匹配所有路径
|
||||
// - 以 "*" 结尾:前缀匹配(如 "/api/*" 匹配 "/api/xxx")
|
||||
// - 以 "/" 结尾:目录前缀匹配
|
||||
// - 中间通配符:"/api/*/users" 匹配 "/api/v1/users"
|
||||
// - 其他:精确匹配
|
||||
//
|
||||
// 参数:
|
||||
// - pattern: 匹配模式,支持通配符
|
||||
// - path: 待检查的路径
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: true 表示匹配,false 表示不匹配
|
||||
func MatchPattern(pattern, path string) bool {
|
||||
// 特殊情况:* 匹配所有
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 目录前缀匹配(pattern 以 / 结尾)
|
||||
if strings.HasSuffix(pattern, "/") {
|
||||
return strings.HasPrefix(path, pattern)
|
||||
}
|
||||
|
||||
// 检查是否有通配符
|
||||
if !strings.Contains(pattern, "*") {
|
||||
return path == pattern
|
||||
@ -321,6 +313,11 @@ func matchPattern(pattern, path string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchPattern 是 MatchPattern 的内部别名,保持向后兼容。
|
||||
func matchPattern(pattern, path string) bool {
|
||||
return MatchPattern(pattern, path)
|
||||
}
|
||||
|
||||
// sendError 发送错误响应。
|
||||
func (p *PurgeAPI) sendError(ctx *fasthttp.RequestCtx, status int, errMsg string) {
|
||||
ctx.SetContentType("application/json; charset=utf-8")
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// StatusHandler 状态监控处理器。
|
||||
@ -333,7 +334,7 @@ func (h *StatusHandler) checkAccess(ctx *fasthttp.RequestCtx) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
clientIP := getClientIPForStatus(ctx)
|
||||
clientIP := netutil.ExtractClientIPNet(ctx)
|
||||
|
||||
// 检查是否在允许列表中
|
||||
for _, network := range h.allowed {
|
||||
@ -345,47 +346,6 @@ func (h *StatusHandler) checkAccess(ctx *fasthttp.RequestCtx) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// getClientIPForStatus 从请求上下文提取客户端 IP。
|
||||
//
|
||||
// 按优先级依次检查:X-Forwarded-For、X-Real-IP、RemoteAddr。
|
||||
// 用于状态端点的 IP 访问控制。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - net.IP: 客户端 IP 地址,无法获取时返回 nil
|
||||
func getClientIPForStatus(ctx *fasthttp.RequestCtx) net.IP {
|
||||
// 检查 X-Forwarded-For 头部
|
||||
if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 {
|
||||
ips := strings.Split(string(xff), ",")
|
||||
if len(ips) > 0 {
|
||||
ipStr := strings.TrimSpace(ips[0])
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 头部
|
||||
if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 {
|
||||
ip := net.ParseIP(string(xri))
|
||||
if ip != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 RemoteAddr
|
||||
if addr := ctx.RemoteAddr(); addr != nil {
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); ok {
|
||||
return tcpAddr.IP
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectStatus 收集服务器状态数据。
|
||||
//
|
||||
// 从服务器实例读取各项统计指标,构建状态响应对象。
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
func TestNewStatusHandler_CIDR(t *testing.T) {
|
||||
@ -340,7 +341,7 @@ func TestGetClientIPForStatus_XForwardedFor(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("X-Forwarded-For", tt.xff)
|
||||
|
||||
gotIP := getClientIPForStatus(ctx)
|
||||
gotIP := netutil.ExtractClientIPNet(ctx)
|
||||
if gotIP == nil {
|
||||
t.Errorf("expected IP %s, got nil", tt.wantIP)
|
||||
} else if gotIP.String() != tt.wantIP {
|
||||
@ -377,7 +378,7 @@ func TestGetClientIPForStatus_XRealIP(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("X-Real-IP", tt.xri)
|
||||
|
||||
gotIP := getClientIPForStatus(ctx)
|
||||
gotIP := netutil.ExtractClientIPNet(ctx)
|
||||
if gotIP == nil {
|
||||
t.Errorf("expected IP %s, got nil", tt.wantIP)
|
||||
} else if gotIP.String() != tt.wantIP {
|
||||
@ -393,7 +394,7 @@ func TestGetClientIPForStatus_Priority(t *testing.T) {
|
||||
ctx.Request.Header.Set("X-Forwarded-For", "10.0.0.1")
|
||||
ctx.Request.Header.Set("X-Real-IP", "10.0.0.2")
|
||||
|
||||
gotIP := getClientIPForStatus(ctx)
|
||||
gotIP := netutil.ExtractClientIPNet(ctx)
|
||||
if gotIP == nil {
|
||||
t.Error("expected IP, got nil")
|
||||
} else if gotIP.String() != "10.0.0.1" {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user