diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index e18260c..bcc4fd3 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -49,13 +49,16 @@ import ( // // 注意事项: // - 所有方法均为并发安全 -// - 应定期调用 Cleanup 清理过期的桶 +// - 启动后会自动后台清理过期的桶 type RateLimiter struct { - rate float64 // 每秒添加的令牌数 - burst float64 // 桶的最大容量 - keyFunc KeyFunc // 提取限流键的函数 - buckets map[string]*tokenBucket // 各键的令牌桶映射 - mu sync.RWMutex // 读写锁,保护并发访问 + rate float64 // 每秒添加的令牌数 + burst float64 // 桶的最大容量 + keyFunc KeyFunc // 提取限流键的函数 + buckets map[string]*tokenBucket // 各键的令牌桶映射 + mu sync.RWMutex // 读写锁,保护并发访问 + cleanupTicker *time.Ticker // 清理定时器 + stopCleanupCh chan struct{} // 停止清理的信号通道 + cleanupDone chan struct{} // 清理 goroutine 完成的信号 } // tokenBucket 表示单个限流键的令牌桶。 @@ -119,9 +122,11 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { } rl := &RateLimiter{ - rate: float64(cfg.RequestRate), - burst: float64(cfg.Burst), - buckets: make(map[string]*tokenBucket), + rate: float64(cfg.RequestRate), + burst: float64(cfg.Burst), + buckets: make(map[string]*tokenBucket), + stopCleanupCh: make(chan struct{}), + cleanupDone: make(chan struct{}), } // 根据配置设置键提取函数 @@ -134,6 +139,9 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { return nil, fmt.Errorf("unknown key type: %s", cfg.Key) } + // 启动后台清理 goroutine + rl.startCleanup(10 * time.Minute) + return rl, nil } @@ -388,6 +396,46 @@ func (rl *RateLimiter) Cleanup(maxAge time.Duration) { } } +// startCleanup 启动后台清理 goroutine。 +// +// 定期清理超过 24 小时未更新的令牌桶。 +// 该方法在创建限流器时自动调用,无需手动调用。 +// +// 参数: +// - interval: 清理间隔时间 +func (rl *RateLimiter) startCleanup(interval time.Duration) { + rl.cleanupTicker = time.NewTicker(interval) + maxAge := 24 * time.Hour // 24 小时未更新则清理 + + go func() { + defer close(rl.cleanupDone) + for { + select { + case <-rl.cleanupTicker.C: + rl.Cleanup(maxAge) + case <-rl.stopCleanupCh: + return + } + } + }() +} + +// StopCleanup 优雅关闭后台清理 goroutine。 +// +// 发送停止信号并等待 goroutine 完成,确保资源正确释放。 +// 该方法应在限流器不再使用时调用(如服务器关闭时)。 +func (rl *RateLimiter) StopCleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + if rl.cleanupTicker != nil { + rl.cleanupTicker.Stop() + close(rl.stopCleanupCh) + <-rl.cleanupDone + rl.cleanupTicker = nil // 防止重复关闭 + } +} + // RateLimitStats 速率限制器统计信息。 type RateLimitStats struct { BucketCount int // 当前活跃的桶数量 diff --git a/internal/middleware/security/ratelimit_test.go b/internal/middleware/security/ratelimit_test.go index 89e39b4..f97324f 100644 --- a/internal/middleware/security/ratelimit_test.go +++ b/internal/middleware/security/ratelimit_test.go @@ -289,6 +289,72 @@ func TestRateLimiterGetStats(t *testing.T) { if stats.Burst != 200 { t.Errorf("Expected Burst 200, got %f", stats.Burst) } + + // 测试优雅关闭 + rl.StopCleanup() +} + +func TestRateLimiterAutoCleanup(t *testing.T) { + // 使用自定义的创建方式,方便测试 + cfg := &config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + Key: "ip", + } + + mw, err := NewRateLimiter(cfg) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + + // 创建一些桶 + rl.Allow("key1") + rl.Allow("key2") + rl.Allow("key3") + + // 验证桶已创建 + stats := rl.GetStats() + if stats.BucketCount != 3 { + t.Errorf("Expected 3 buckets, got %d", stats.BucketCount) + } + + // 手动调用 Cleanup 模拟过期清理(使用很短的过期时间) + rl.Cleanup(1 * time.Nanosecond) + + // 验证所有桶已被清理 + stats = rl.GetStats() + if stats.BucketCount != 0 { + t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount) + } + + // 测试优雅关闭 + rl.StopCleanup() +} + +func TestRateLimiterStopCleanup(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + + // 验证可以正常关闭 + rl.StopCleanup() + + // 再次调用不应 panic + rl.StopCleanup() } func TestNewConnLimiter(t *testing.T) {