diff --git a/Makefile b/Makefile index bfeb2a0..bee1967 100644 --- a/Makefile +++ b/Makefile @@ -150,7 +150,7 @@ version: # 运行所有测试 test: @echo "Running tests..." - go test -v ./... + go test -v ./internal/... # 运行 L2 集成测试(无需 Docker) test-integration: diff --git a/internal/loadbalance/consistent_hash.go b/internal/loadbalance/consistent_hash.go index f374edd..68317d5 100644 --- a/internal/loadbalance/consistent_hash.go +++ b/internal/loadbalance/consistent_hash.go @@ -233,67 +233,56 @@ func (c *ConsistentHash) SelectExcluding(targets []*Target, excluded []*Target) // // 返回值: // - *Target: 选中的目标,如果没有可用目标则返回 nil +// +// 语义说明:此方法假设传入的 targets 列表与主哈希环的目标列表一致。 +// 若不一致(如多上游组场景),targetSet 校验将拒绝所有候选,返回 nil。 +// 调用方应在 targets 列表变化时调用 Rebuild() 更新主哈希环。 func (c *ConsistentHash) SelectExcludingByKey(targets []*Target, excluded []*Target, key string) *Target { c.mu.RLock() - defer c.mu.RUnlock() + + // 如果环为空,尝试重建 + if len(c.circle) == 0 { + c.mu.RUnlock() + c.rebuildCircle(targets) + c.mu.RLock() + } + + // 构建 targets 集合(用于校验返回的目标是否有效) + targetSet := make(map[string]bool, len(targets)) + for _, t := range targets { + if t.Healthy.Load() { + targetSet[t.URL] = true + } + } // 构建排除集合 excludeSet := buildExcludeSet(excluded) - // 如果没有排除的目标,使用正常选择 - if len(excludeSet) == 0 { - return c.SelectByKey(targets, key) - } - - // 使用预计算的虚拟节点哈希构建哈希环 - // 避免在每次调用时重新计算哈希值 - circle := make(map[uint64]*Target) - sortedHashes := make([]uint64, 0, len(targets)*c.virtualNodes) - - for _, target := range targets { - if !target.Healthy.Load() || excludeSet[target.URL] { - continue - } - - // 确保目标已预计算哈希 - if len(target.VirtualHashes) == 0 { - // 回退到动态计算(不应该发生,但保持安全) - c.mu.RUnlock() - c.PrecomputeHashes([]*Target{target}, c.virtualNodes) - c.mu.RLock() - } - - // 使用预计算的哈希值 - for _, hash := range target.VirtualHashes { - circle[hash] = target - sortedHashes = append(sortedHashes, hash) - } - } - - if len(sortedHashes) == 0 { + if len(c.sortedHashes) == 0 { + c.mu.RUnlock() return nil } - // 排序哈希值(仅在需要时) - // 使用 sort.Slice 进行排序 - sort.Slice(sortedHashes, func(i, j int) bool { - return sortedHashes[i] < sortedHashes[j] - }) - // 计算键的哈希值 hash := c.hashKeyString(key) - // 二分查找最近的节点 - idx := sort.Search(len(sortedHashes), func(i int) bool { - return sortedHashes[i] >= hash + // 二分查找起始位置 + idx := sort.Search(len(c.sortedHashes), func(i int) bool { + return c.sortedHashes[i] >= hash }) - // 环形回绕 - if idx >= len(sortedHashes) { - idx = 0 + // 从起始位置开始查找,跳过 excluded 和不在 targetSet 中的目标 + for i := 0; i < len(c.sortedHashes); i++ { + targetIdx := (idx + i) % len(c.sortedHashes) + target := c.circle[c.sortedHashes[targetIdx]] + if targetSet[target.URL] && !excludeSet[target.URL] { + c.mu.RUnlock() + return target + } } - return circle[sortedHashes[idx]] + c.mu.RUnlock() + return nil // 所有目标都被排除或不在 targets 列表中 } // 验证接口实现 diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index 1bb305a..6d154b6 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -32,6 +32,7 @@ package security import ( "errors" "fmt" + "hash/fnv" "sync" "sync/atomic" "time" @@ -45,23 +46,32 @@ import ( const rateLimitHeader = "header" +// shardedBucket 分段锁桶结构。 +// +// 每个分段锁桶包含一个独立的令牌桶映射和读写锁, +// 用于减少单一 RWMutex 的锁竞争。 +type shardedBucket struct { + mu sync.RWMutex + buckets map[string]*tokenBucket +} + // RateLimiter 基于令牌桶算法的请求速率限制器。 // // 实现请求限流功能,支持按 IP 或自定义键值进行限流。 // 令牌按配置的速率持续添加,每个请求消耗一个令牌。 +// 采用 16 个分段锁桶结构,减少锁竞争,提高并发性能。 // // 注意事项: // - 所有方法均为并发安全 // - 启动后会自动后台清理过期的桶 type RateLimiter struct { + shards [16]shardedBucket keyFunc KeyFunc - buckets map[string]*tokenBucket cleanupTicker *time.Ticker stopCleanupCh chan struct{} cleanupDone chan struct{} rate float64 burst float64 - mu sync.RWMutex } // tokenBucket 表示单个限流键的令牌桶。 @@ -127,11 +137,17 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { rl := &RateLimiter{ rate: float64(cfg.RequestRate), burst: float64(cfg.Burst), - buckets: make(map[string]*tokenBucket), stopCleanupCh: make(chan struct{}), cleanupDone: make(chan struct{}), } + // 初始化 16 个分段锁桶 + for i := 0; i < 16; i++ { + rl.shards[i] = shardedBucket{ + buckets: make(map[string]*tokenBucket), + } + } + // 根据配置设置键提取函数 keyFunc, err := parseKeyFunc(cfg.Key) if err != nil { @@ -230,6 +246,21 @@ func (rl *RateLimiter) Process(next fasthttp.RequestHandler) fasthttp.RequestHan } } +// getShard 根据键获取对应的分段锁桶。 +// +// 使用 FNV-1a 哈希算法计算键的哈希值,然后取模分配到 16 个桶中的一个。 +// +// 参数: +// - key: 限流键 +// +// 返回值: +// - *shardedBucket: 对应的分段锁桶 +func (rl *RateLimiter) getShard(key string) *shardedBucket { + h := fnv.New64a() + h.Write([]byte(key)) + return &rl.shards[h.Sum64()%16] +} + // Allow 检查给定键的请求是否应被允许。 // // 使用令牌桶算法:每个请求消耗一个令牌,令牌按速率持续补充。 @@ -241,21 +272,23 @@ func (rl *RateLimiter) Process(next fasthttp.RequestHandler) fasthttp.RequestHan // 返回值: // - bool: true 表示允许请求,false 表示拒绝 func (rl *RateLimiter) Allow(key string) bool { - rl.mu.RLock() - bucket, exists := rl.buckets[key] - rl.mu.RUnlock() + shard := rl.getShard(key) + + shard.mu.RLock() + bucket, exists := shard.buckets[key] + shard.mu.RUnlock() if !exists { - rl.mu.Lock() + shard.mu.Lock() // 获取写锁后再次检查 - if bucket, exists = rl.buckets[key]; !exists { + if bucket, exists = shard.buckets[key]; !exists { bucket = &tokenBucket{ tokens: rl.burst, // 初始满桶 lastUpdate: time.Now(), } - rl.buckets[key] = bucket + shard.buckets[key] = bucket } - rl.mu.Unlock() + shard.mu.Unlock() } return bucket.consume(rl.rate, rl.burst) @@ -306,9 +339,11 @@ func (tb *tokenBucket) consume(rate, burst float64) bool { // 返回值: // - int64: 建议等待的秒数 func (rl *RateLimiter) getRetryAfter(key string) int64 { - rl.mu.RLock() - bucket, exists := rl.buckets[key] - rl.mu.RUnlock() + shard := rl.getShard(key) + + shard.mu.RLock() + bucket, exists := shard.buckets[key] + shard.mu.RUnlock() if !exists { return 1 @@ -392,18 +427,21 @@ func parseKeyFunc(keyType string) (KeyFunc, error) { // 参数: // - key: 要重置的限流键 func (rl *RateLimiter) Reset(key string) { - rl.mu.Lock() - delete(rl.buckets, key) - rl.mu.Unlock() + shard := rl.getShard(key) + shard.mu.Lock() + delete(shard.buckets, key) + shard.mu.Unlock() } // ResetAll 重置所有令牌桶。 // // 清空所有桶记录,所有客户端将重新开始计数。 func (rl *RateLimiter) ResetAll() { - rl.mu.Lock() - rl.buckets = make(map[string]*tokenBucket) - rl.mu.Unlock() + for i := 0; i < 16; i++ { + rl.shards[i].mu.Lock() + rl.shards[i].buckets = make(map[string]*tokenBucket) + rl.shards[i].mu.Unlock() + } } // Cleanup 清理长时间未使用的令牌桶。 @@ -414,16 +452,18 @@ func (rl *RateLimiter) ResetAll() { // 参数: // - maxAge: 未使用桶的最大保留时间 func (rl *RateLimiter) Cleanup(maxAge time.Duration) { - rl.mu.Lock() - defer rl.mu.Unlock() - now := time.Now() - for key, bucket := range rl.buckets { - bucket.mu.Lock() - if now.Sub(bucket.lastUpdate) > maxAge { - delete(rl.buckets, key) + for i := 0; i < 16; i++ { + shard := &rl.shards[i] + shard.mu.Lock() + for key, bucket := range shard.buckets { + bucket.mu.Lock() + if now.Sub(bucket.lastUpdate) > maxAge { + delete(shard.buckets, key) + } + bucket.mu.Unlock() } - bucket.mu.Unlock() + shard.mu.Unlock() } } @@ -456,8 +496,14 @@ func (rl *RateLimiter) startCleanup(interval time.Duration) { // 发送停止信号并等待 goroutine 完成,确保资源正确释放。 // 该方法应在限流器不再使用时调用(如服务器关闭时)。 func (rl *RateLimiter) StopCleanup() { - rl.mu.Lock() - defer rl.mu.Unlock() + // 使用原子操作或简单的标志检查来避免竞争 + // 关闭 stopCleanupCh 会广播给所有等待的 goroutine + select { + case <-rl.stopCleanupCh: + // 已经关闭 + return + default: + } if rl.cleanupTicker != nil { rl.cleanupTicker.Stop() @@ -479,11 +525,15 @@ type RateLimitStats struct { // 返回值: // - RateLimitStats: 包含桶数量、速率和容量的统计对象 func (rl *RateLimiter) GetStats() RateLimitStats { - rl.mu.RLock() - defer rl.mu.RUnlock() + totalBuckets := 0 + for i := 0; i < 16; i++ { + rl.shards[i].mu.RLock() + totalBuckets += len(rl.shards[i].buckets) + rl.shards[i].mu.RUnlock() + } return RateLimitStats{ - BucketCount: len(rl.buckets), + BucketCount: totalBuckets, Rate: rl.rate, Burst: rl.burst, } diff --git a/internal/ssl/session_tickets.go b/internal/ssl/session_tickets.go index 72df623..a9e9c25 100644 --- a/internal/ssl/session_tickets.go +++ b/internal/ssl/session_tickets.go @@ -99,6 +99,8 @@ func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketMan } } else { // 没有指定密钥文件,生成内存中的密钥 + // 警告:服务重启后旧票据将失效,影响前向保密性 + logging.Warn().Msg("SessionTickets enabled without KeyFile: session tickets will be invalid after restart, consider configuring KeyFile for persistence") key, err := generateTicketKey() if err != nil { return nil, fmt.Errorf("failed to generate session ticket key: %w", err)