From 4d608c4284e23be9ea2883b9e69beb93fac87bf0 Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 10 Apr 2026 11:20:43 +0800 Subject: [PATCH] =?UTF-8?q?refactor(middleware):=20=E6=8F=90=E5=8F=96?= =?UTF-8?q?=E9=99=90=E6=B5=81=20key=20=E8=A7=A3=E6=9E=90=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 parseKeyFunc() 统一处理 key type 解析逻辑 - 消除重复代码,提高可维护性 Co-Authored-By: Claude Opus 4.6 --- internal/middleware/security/ratelimit.go | 56 ++++++++++++++--------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index 673e3e9..753060b 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -132,14 +132,11 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { } // 根据配置设置键提取函数 - switch cfg.Key { - case "ip", "": - rl.keyFunc = keyByIP - case rateLimitHeader: - rl.keyFunc = keyByHeader - default: - return nil, fmt.Errorf("unknown key type: %s", cfg.Key) + keyFunc, err := parseKeyFunc(cfg.Key) + if err != nil { + return nil, err } + rl.keyFunc = keyFunc // 启动后台清理 goroutine rl.startCleanup(10 * time.Minute) @@ -155,14 +152,9 @@ type SlidingWindowLimiterWrapper struct { // NewSlidingWindowLimiterWrapper 创建滑动窗口限流器包装。 func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Duration, precise bool) (*SlidingWindowLimiterWrapper, error) { - var keyFunc KeyFunc - switch cfg.Key { - case "ip", "": - keyFunc = keyByIP - case rateLimitHeader: - keyFunc = keyByHeader - default: - return nil, fmt.Errorf("unknown key type: %s", cfg.Key) + keyFunc, err := parseKeyFunc(cfg.Key) + if err != nil { + return nil, err } return &SlidingWindowLimiterWrapper{ @@ -356,6 +348,29 @@ func keyByHeader(ctx *fasthttp.RequestCtx) string { return string(key) } +// parseKeyFunc 根据配置字符串解析键提取函数。 +// +// 支持的键类型: +// - "ip" 或 "": 使用客户端 IP 作为键 +// - "header": 使用 X-RateLimit-Key 头部值作为键 +// +// 参数: +// - keyType: 键类型字符串 +// +// 返回值: +// - KeyFunc: 键提取函数 +// - error: 未知的键类型时返回错误 +func parseKeyFunc(keyType string) (KeyFunc, error) { + switch keyType { + case "ip", "": + return keyByIP, nil + case rateLimitHeader: + return keyByHeader, nil + default: + return nil, fmt.Errorf("unknown key type: %s", keyType) + } +} + // Reset 重置指定键的令牌桶。 // // 删除该键的桶记录,下次请求时将重新创建满载的桶。 @@ -499,14 +514,11 @@ func NewConnLimiter(maxConns int, perKey bool, keyType string) (*ConnLimiter, er } if perKey { - switch keyType { - case "ip", "": - cl.keyFunc = keyByIP - case rateLimitHeader: - cl.keyFunc = keyByHeader - default: - return nil, fmt.Errorf("unknown key type: %s", keyType) + keyFunc, err := parseKeyFunc(keyType) + if err != nil { + return nil, err } + cl.keyFunc = keyFunc } return cl, nil