refactor(middleware): 提取限流 key 解析函数

- 新增 parseKeyFunc() 统一处理 key type 解析逻辑
- 消除重复代码,提高可维护性

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-10 11:20:43 +08:00
parent a965040eff
commit 4d608c4284

View File

@ -132,14 +132,11 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
}
// 根据配置设置键提取函数
switch cfg.Key {
case "ip", "":
rl.keyFunc = keyByIP
case rateLimitHeader:
rl.keyFunc = keyByHeader
default:
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
keyFunc, err := parseKeyFunc(cfg.Key)
if err != nil {
return nil, err
}
rl.keyFunc = keyFunc
// 启动后台清理 goroutine
rl.startCleanup(10 * time.Minute)
@ -155,14 +152,9 @@ type SlidingWindowLimiterWrapper struct {
// NewSlidingWindowLimiterWrapper 创建滑动窗口限流器包装。
func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Duration, precise bool) (*SlidingWindowLimiterWrapper, error) {
var keyFunc KeyFunc
switch cfg.Key {
case "ip", "":
keyFunc = keyByIP
case rateLimitHeader:
keyFunc = keyByHeader
default:
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
keyFunc, err := parseKeyFunc(cfg.Key)
if err != nil {
return nil, err
}
return &SlidingWindowLimiterWrapper{
@ -356,6 +348,29 @@ func keyByHeader(ctx *fasthttp.RequestCtx) string {
return string(key)
}
// parseKeyFunc 根据配置字符串解析键提取函数。
//
// 支持的键类型:
// - "ip" 或 "": 使用客户端 IP 作为键
// - "header": 使用 X-RateLimit-Key 头部值作为键
//
// 参数:
// - keyType: 键类型字符串
//
// 返回值:
// - KeyFunc: 键提取函数
// - error: 未知的键类型时返回错误
func parseKeyFunc(keyType string) (KeyFunc, error) {
switch keyType {
case "ip", "":
return keyByIP, nil
case rateLimitHeader:
return keyByHeader, nil
default:
return nil, fmt.Errorf("unknown key type: %s", keyType)
}
}
// Reset 重置指定键的令牌桶。
//
// 删除该键的桶记录,下次请求时将重新创建满载的桶。
@ -499,14 +514,11 @@ func NewConnLimiter(maxConns int, perKey bool, keyType string) (*ConnLimiter, er
}
if perKey {
switch keyType {
case "ip", "":
cl.keyFunc = keyByIP
case rateLimitHeader:
cl.keyFunc = keyByHeader
default:
return nil, fmt.Errorf("unknown key type: %s", keyType)
keyFunc, err := parseKeyFunc(keyType)
if err != nil {
return nil, err
}
cl.keyFunc = keyFunc
}
return cl, nil