diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index 673e3e9..753060b 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -132,14 +132,11 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { } // 根据配置设置键提取函数 - switch cfg.Key { - case "ip", "": - rl.keyFunc = keyByIP - case rateLimitHeader: - rl.keyFunc = keyByHeader - default: - return nil, fmt.Errorf("unknown key type: %s", cfg.Key) + keyFunc, err := parseKeyFunc(cfg.Key) + if err != nil { + return nil, err } + rl.keyFunc = keyFunc // 启动后台清理 goroutine rl.startCleanup(10 * time.Minute) @@ -155,14 +152,9 @@ type SlidingWindowLimiterWrapper struct { // NewSlidingWindowLimiterWrapper 创建滑动窗口限流器包装。 func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Duration, precise bool) (*SlidingWindowLimiterWrapper, error) { - var keyFunc KeyFunc - switch cfg.Key { - case "ip", "": - keyFunc = keyByIP - case rateLimitHeader: - keyFunc = keyByHeader - default: - return nil, fmt.Errorf("unknown key type: %s", cfg.Key) + keyFunc, err := parseKeyFunc(cfg.Key) + if err != nil { + return nil, err } return &SlidingWindowLimiterWrapper{ @@ -356,6 +348,29 @@ func keyByHeader(ctx *fasthttp.RequestCtx) string { return string(key) } +// parseKeyFunc 根据配置字符串解析键提取函数。 +// +// 支持的键类型: +// - "ip" 或 "": 使用客户端 IP 作为键 +// - "header": 使用 X-RateLimit-Key 头部值作为键 +// +// 参数: +// - keyType: 键类型字符串 +// +// 返回值: +// - KeyFunc: 键提取函数 +// - error: 未知的键类型时返回错误 +func parseKeyFunc(keyType string) (KeyFunc, error) { + switch keyType { + case "ip", "": + return keyByIP, nil + case rateLimitHeader: + return keyByHeader, nil + default: + return nil, fmt.Errorf("unknown key type: %s", keyType) + } +} + // Reset 重置指定键的令牌桶。 // // 删除该键的桶记录,下次请求时将重新创建满载的桶。 @@ -499,14 +514,11 @@ func NewConnLimiter(maxConns int, perKey bool, keyType string) (*ConnLimiter, er } if perKey { - switch keyType { - case "ip", "": - cl.keyFunc = keyByIP - case rateLimitHeader: - cl.keyFunc = keyByHeader - default: - return nil, fmt.Errorf("unknown key type: %s", keyType) + keyFunc, err := parseKeyFunc(keyType) + if err != nil { + return nil, err } + cl.keyFunc = keyFunc } return cl, nil