// Package security 提供了 Lolly HTTP 服务器的安全相关中间件。 // // 该文件实现了基于令牌桶算法的请求速率限制中间件, // 支持按 IP 或按键值进行请求限流和连接数限制。 // // 主要功能: // - 请求速率限制:使用令牌桶算法控制请求频率 // - 突发流量处理:允许一定程度的请求突发 // - 多维度限流:支持按 IP、按头部键值等维度 // - 连接数限制:控制最大并发连接数 // // 使用示例: // // cfg := &config.RateLimitConfig{ // RequestRate: 100, // 每秒 100 个请求 // Burst: 200, // 允许突发到 200 个请求 // Key: "ip", // 按 IP 地址限流 // } // // limiter, err := security.NewRateLimiter(cfg) // if err != nil { // log.Fatal(err) // } // // // 作为中间件应用 // chain := middleware.NewChain(limiter) // handler := chain.Apply(finalHandler) // // 作者:xfy package security import ( "errors" "fmt" "sync" "sync/atomic" "time" "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/hash" "rua.plus/lolly/internal/middleware" "rua.plus/lolly/internal/netutil" "rua.plus/lolly/internal/utils" ) const rateLimitHeader = "header" // shardedBucket 分段锁桶结构。 // // 每个分段锁桶包含一个独立的令牌桶映射和读写锁, // 用于减少单一 RWMutex 的锁竞争。 type shardedBucket struct { mu sync.RWMutex buckets map[string]*tokenBucket } // RateLimiter 基于令牌桶算法的请求速率限制器。 // // 实现请求限流功能,支持按 IP 或自定义键值进行限流。 // 令牌按配置的速率持续添加,每个请求消耗一个令牌。 // 采用 16 个分段锁桶结构,减少锁竞争,提高并发性能。 // // 注意事项: // - 所有方法均为并发安全 // - 启动后会自动后台清理过期的桶 type RateLimiter struct { shards [16]shardedBucket keyFunc KeyFunc cleanupTicker *time.Ticker stopCleanupCh chan struct{} cleanupDone chan struct{} stopOnce sync.Once rate float64 burst float64 } // tokenBucket 表示单个限流键的令牌桶。 // // 记录当前令牌数和最后更新时间,用于令牌计算。 type tokenBucket struct { lastUpdate time.Time tokens float64 mu sync.Mutex } // KeyFunc 从请求中提取限流键的函数类型。 // // 用于确定请求属于哪个限流桶,常见的实现包括按 IP、按头部值等。 type KeyFunc func(ctx *fasthttp.RequestCtx) string // NewRateLimiter 根据配置创建新的速率限制器。 // // 验证配置参数的有效性,并设置相应的限流键提取函数。 // // 参数: // - cfg: 限流配置,包含速率、突发量和键类型 // // 返回值: // - *RateLimiter: 配置好的限流器实例 // - error: 配置无效时返回错误(如速率小于 0) func NewRateLimiter(cfg *config.RateLimitConfig) (middleware.Middleware, error) { if cfg == nil { return nil, errors.New("rate limit config is nil") } if cfg.RequestRate <= 0 { return nil, errors.New("request rate must be positive") } // 根据算法选择限流器 algorithm := cfg.Algorithm if algorithm == "" { algorithm = "token_bucket" // 默认令牌桶 } switch algorithm { case "token_bucket", "": return newTokenBucketLimiter(cfg) case "sliding_window": window := time.Duration(cfg.SlidingWindow) * time.Second if window <= 0 { window = time.Second // 默认 1 秒窗口 } precise := cfg.SlidingWindowMode == "precise" return NewSlidingWindowLimiterWrapper(cfg, window, precise) default: return nil, fmt.Errorf("unknown algorithm: %s", algorithm) } } // newTokenBucketLimiter 创建令牌桶限流器。 func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { if cfg.Burst < cfg.RequestRate { return nil, errors.New("burst must be at least equal to request rate") } rl := &RateLimiter{ rate: float64(cfg.RequestRate), burst: float64(cfg.Burst), stopCleanupCh: make(chan struct{}), cleanupDone: make(chan struct{}), } // 初始化 16 个分段锁桶 for i := range 16 { rl.shards[i] = shardedBucket{ buckets: make(map[string]*tokenBucket), } } // 根据配置设置键提取函数 keyFunc, err := parseKeyFunc(cfg.Key) if err != nil { return nil, err } rl.keyFunc = keyFunc // 启动后台清理 goroutine rl.startCleanup(10 * time.Minute) return rl, nil } // SlidingWindowLimiterWrapper 滑动窗口限流器包装,实现 middleware.Middleware 接口。 type SlidingWindowLimiterWrapper struct { limiter *SlidingWindowLimiter keyFunc KeyFunc } // NewSlidingWindowLimiterWrapper 创建滑动窗口限流器包装。 // // 该函数基于配置创建一个滑动窗口限流器,并将其包装为实现 // middleware.Middleware 接口的结构体。滑动窗口模式提供更精确 // 的限流控制,适用于需要平滑流量分布的场景。 // // 参数: // - cfg: 限流配置,包含速率和键类型 // - window: 滑动窗口持续时间 // - precise: 是否使用精确模式(true)或粗略模式(false) // // 返回值: // - *SlidingWindowLimiterWrapper: 限流器包装实例 // - error: 键类型无效时返回错误 func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Duration, precise bool) (*SlidingWindowLimiterWrapper, error) { keyFunc, err := parseKeyFunc(cfg.Key) if err != nil { return nil, err } return &SlidingWindowLimiterWrapper{ limiter: NewSlidingWindowLimiter(window, cfg.RequestRate, precise), keyFunc: keyFunc, }, nil } // Name 返回中间件名称。 func (s *SlidingWindowLimiterWrapper) Name() string { return "sliding_window_limiter" } // Process 包装下一个处理器,添加限流逻辑。 func (s *SlidingWindowLimiterWrapper) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { key := s.keyFunc(ctx) if !s.limiter.Allow(key) { utils.SendError(ctx, utils.ErrTooManyRequests) return } next(ctx) } } // Name 返回中间件名称。 // // 返回值: // - string: 中间件标识名 "rate_limiter" func (rl *RateLimiter) Name() string { return "rate_limiter" } // Process 包装下一个处理器,添加限流逻辑。 // // 超过限流阈值的请求将收到 429 Too Many Requests 响应, // 并在响应头中设置 Retry-After 提示重试等待时间。 // // 参数: // - next: 下一个请求处理器 // // 返回值: // - fasthttp.RequestHandler: 包装后的处理器 func (rl *RateLimiter) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { key := rl.keyFunc(ctx) if !rl.Allow(key) { // 计算重试等待时间 retryAfter := rl.getRetryAfter(key) ctx.Response.Header.Set("Retry-After", fmt.Sprintf("%d", retryAfter)) utils.SendError(ctx, utils.ErrTooManyRequests) return } next(ctx) } } // getShard 根据键获取对应的分段锁桶。 // // 使用 FNV-1a 哈希算法计算键的哈希值,然后取模分配到 16 个桶中的一个。 // // 参数: // - key: 限流键 // // 返回值: // - *shardedBucket: 对应的分段锁桶 func (rl *RateLimiter) getShard(key string) *shardedBucket { return &rl.shards[hash.FNV64a(key)%16] } // Allow 检查给定键的请求是否应被允许。 // // 使用令牌桶算法:每个请求消耗一个令牌,令牌按速率持续补充。 // 如果桶中有足够令牌则允许请求,否则拒绝。 // // 参数: // - key: 限流键(如 IP 地址) // // 返回值: // - bool: true 表示允许请求,false 表示拒绝 func (rl *RateLimiter) Allow(key string) bool { shard := rl.getShard(key) shard.mu.RLock() bucket, exists := shard.buckets[key] shard.mu.RUnlock() if !exists { shard.mu.Lock() // 获取写锁后再次检查 if bucket, exists = shard.buckets[key]; !exists { bucket = &tokenBucket{ tokens: rl.burst, // 初始满桶 lastUpdate: time.Now(), } shard.buckets[key] = bucket } shard.mu.Unlock() } return bucket.consume(rl.rate, rl.burst) } // consume 尝试从桶中消耗一个令牌。 // // 根据时间流逝补充令牌,然后检查是否有足够令牌消耗。 // // 参数: // - rate: 令牌补充速率(每秒) // - burst: 桶的最大容量 // // 返回值: // - bool: true 表示成功消耗令牌,false 表示桶空 func (tb *tokenBucket) consume(rate, burst float64) bool { tb.mu.Lock() defer tb.mu.Unlock() now := time.Now() elapsed := now.Sub(tb.lastUpdate).Seconds() // 根据时间流逝补充令牌 tb.tokens += elapsed * rate if tb.tokens > burst { tb.tokens = burst // 不超过桶容量 } tb.lastUpdate = now // 检查是否有足够令牌 if tb.tokens >= 1.0 { tb.tokens -= 1.0 return true } return false } // getRetryAfter 计算重试前需等待的秒数。 // // 根据令牌桶当前状态计算需要等待的时间, // 包括生成一个令牌的时间和补偿欠缺令牌的时间。 // // 参数: // - key: 限流键 // // 返回值: // - int64: 建议等待的秒数 func (rl *RateLimiter) getRetryAfter(key string) int64 { shard := rl.getShard(key) shard.mu.RLock() bucket, exists := shard.buckets[key] shard.mu.RUnlock() if !exists { return 1 } bucket.mu.Lock() defer bucket.mu.Unlock() // 生成一个令牌的时间 waitTime := 1.0 / rl.rate // 如果桶欠缺令牌,需额外等待时间 if bucket.tokens < 0 { waitTime += -bucket.tokens / rl.rate } return int64(waitTime) + 1 } // keyByIP 提取客户端 IP 作为限流键。 // // 从请求上下文获取客户端的真实 IP 地址。 // // 参数: // - ctx: FastHTTP 请求上下文 // // 返回值: // - string: IP 地址字符串,无法获取时返回 "unknown" func keyByIP(ctx *fasthttp.RequestCtx) string { ip := netutil.ExtractClientIPNet(ctx) if ip == nil { return accessUnknown } return ip.String() } // keyByHeader 提取头部值作为限流键。 // // 默认使用 X-RateLimit-Key 头部,如果不存在则回退到 IP。 // // 参数: // - ctx: FastHTTP 请求上下文 // // 返回值: // - string: 头部值或 IP 地址字符串 func keyByHeader(ctx *fasthttp.RequestCtx) string { key := ctx.Request.Header.Peek("X-RateLimit-Key") if len(key) == 0 { // 头部不存在时回退到 IP return keyByIP(ctx) } return string(key) } // parseKeyFunc 根据配置字符串解析键提取函数。 // // 支持的键类型: // - "ip" 或 "": 使用客户端 IP 作为键 // - "header": 使用 X-RateLimit-Key 头部值作为键 // // 参数: // - keyType: 键类型字符串 // // 返回值: // - KeyFunc: 键提取函数 // - error: 未知的键类型时返回错误 func parseKeyFunc(keyType string) (KeyFunc, error) { switch keyType { case "ip", "": return keyByIP, nil case rateLimitHeader: return keyByHeader, nil default: return nil, fmt.Errorf("unknown key type: %s", keyType) } } // Reset 重置指定键的令牌桶。 // // 删除该键的桶记录,下次请求时将重新创建满载的桶。 // // 参数: // - key: 要重置的限流键 func (rl *RateLimiter) Reset(key string) { shard := rl.getShard(key) shard.mu.Lock() delete(shard.buckets, key) shard.mu.Unlock() } // ResetAll 重置所有令牌桶。 // // 清空所有桶记录,所有客户端将重新开始计数。 func (rl *RateLimiter) ResetAll() { for i := range 16 { rl.shards[i].mu.Lock() rl.shards[i].buckets = make(map[string]*tokenBucket) rl.shards[i].mu.Unlock() } } // Cleanup 清理长时间未使用的令牌桶。 // // 删除超过 maxAge 时间未更新的桶,防止内存无限增长。 // 建议定期调用此方法(如每分钟一次)。 // // 参数: // - maxAge: 未使用桶的最大保留时间 func (rl *RateLimiter) Cleanup(maxAge time.Duration) { now := time.Now() for i := range 16 { shard := &rl.shards[i] shard.mu.Lock() for key, bucket := range shard.buckets { bucket.mu.Lock() if now.Sub(bucket.lastUpdate) > maxAge { delete(shard.buckets, key) } bucket.mu.Unlock() } shard.mu.Unlock() } } // startCleanup 启动后台清理 goroutine。 // // 定期清理超过 24 小时未更新的令牌桶。 // 该方法在创建限流器时自动调用,无需手动调用。 // // 参数: // - interval: 清理间隔时间 func (rl *RateLimiter) startCleanup(interval time.Duration) { rl.cleanupTicker = time.NewTicker(interval) maxAge := 24 * time.Hour // 24 小时未更新则清理 go func() { defer close(rl.cleanupDone) for { select { case <-rl.cleanupTicker.C: rl.Cleanup(maxAge) case <-rl.stopCleanupCh: return } } }() } // StopCleanup 优雅关闭后台清理 goroutine。 // // 发送停止信号并等待 goroutine 完成,确保资源正确释放。 // 该方法应在限流器不再使用时调用(如服务器关闭时)。 func (rl *RateLimiter) StopCleanup() { if rl.cleanupTicker != nil { rl.cleanupTicker.Stop() rl.stopOnce.Do(func() { close(rl.stopCleanupCh) }) <-rl.cleanupDone rl.cleanupTicker = nil // 防止重复关闭 } } // RateLimitStats 速率限制器统计信息。 type RateLimitStats struct { BucketCount int // 当前活跃的桶数量 Rate float64 // 令牌补充速率(每秒) Burst float64 // 桶的最大容量 } // GetStats 返回当前速率限制器的统计信息。 // // 返回值: // - RateLimitStats: 包含桶数量、速率和容量的统计对象 func (rl *RateLimiter) GetStats() RateLimitStats { totalBuckets := 0 for i := range 16 { rl.shards[i].mu.RLock() totalBuckets += len(rl.shards[i].buckets) rl.shards[i].mu.RUnlock() } return RateLimitStats{ BucketCount: totalBuckets, Rate: rl.rate, Burst: rl.burst, } } // ConnLimiter 连接数限制器。 // // 控制最大并发连接数,支持全局限制或按键值限制。 // 与 RateLimiter 不同,此限制器控制并发而非速率。 // // 注意事项: // - 使用后必须调用 Release 释放连接槽 // - 所有方法均为并发安全 type ConnLimiter struct { keyFunc KeyFunc counts map[string]int64 max int current atomic.Int64 mu sync.RWMutex perKey bool } // NewConnLimiter 创建新的连接数限制器。 // // 参数: // - max: 最大并发连接数 // - perKey: true 为按键限制,false 为全局限制 // - keyType: 按键限制时的键类型("ip" 或 "header") // // 返回值: // - *ConnLimiter: 配置好的连接限制器 // - error: 配置无效时返回错误 func NewConnLimiter(maxConns int, perKey bool, keyType string) (*ConnLimiter, error) { if maxConns <= 0 { return nil, errors.New("max connections must be positive") } cl := &ConnLimiter{ max: maxConns, perKey: perKey, counts: make(map[string]int64), } if perKey { keyFunc, err := parseKeyFunc(keyType) if err != nil { return nil, err } cl.keyFunc = keyFunc } return cl, nil } // Acquire 尝试获取一个连接槽。 // // 如果当前连接数已达上限则返回 false。 // 成功获取后必须调用 Release 释放。 // // 参数: // - ctx: FastHTTP 请求上下文 // // 返回值: // - bool: true 表示成功获取,false 表示已达上限 func (cl *ConnLimiter) Acquire(ctx *fasthttp.RequestCtx) bool { if !cl.perKey { // 全局限制(原子递增后检查溢出,避免 TOCTOU 竞态) current := cl.current.Add(1) if current > int64(cl.max) { cl.current.Add(-1) return false } return true } // 按键限制 key := cl.keyFunc(ctx) cl.mu.Lock() defer cl.mu.Unlock() current := cl.counts[key] if current >= int64(cl.max) { return false } cl.counts[key] = current + 1 return true } // Release 释放一个连接槽。 // // 必须在连接结束时调用,否则连接数将持续增长。 // // 参数: // - ctx: FastHTTP 请求上下文 func (cl *ConnLimiter) Release(ctx *fasthttp.RequestCtx) { if !cl.perKey { cl.current.Add(-1) return } key := cl.keyFunc(ctx) cl.mu.Lock() if cl.counts[key] > 0 { cl.counts[key]-- } cl.mu.Unlock() } // Middleware 返回连接限制的中间件包装。 // // 返回值: // - middleware.Middleware: 可用于中间件链的限制器(返回自身) func (cl *ConnLimiter) Middleware() middleware.Middleware { return cl } // Name 返回中间件名称。 // // 返回值: // - string: 中间件标识名 "conn_limiter" func (cl *ConnLimiter) Name() string { return "conn_limiter" } // Process 包装处理器,添加连接限制逻辑。 // // 获取连接槽后执行处理器,完成后自动释放。 // 超过连接限制时返回 503 Service Unavailable。 // // 参数: // - next: 下一个请求处理器 // // 返回值: // - fasthttp.RequestHandler: 包装后的处理器 func (cl *ConnLimiter) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { if !cl.Acquire(ctx) { utils.SendErrorWithDetail(ctx, utils.ErrServiceUnavailable, "Connection limit exceeded") return } defer cl.Release(ctx) next(ctx) } } // 验证接口实现 var ( _ middleware.Middleware = (*RateLimiter)(nil) _ middleware.Middleware = (*ConnLimiter)(nil) )