lolly/internal/loadbalance/consistent_hash.go
xfy 7fe1ca6bec perf(loadbalance): eliminate double-lock in ConsistentHash with atomic.Bool rebuild guard
SelectByKey and SelectExcludingByKey previously had a RLock→RUnlock→
rebuildCircle(Lock)→RLock pattern when the hash ring was empty. Under
cold-start concurrency, multiple goroutines could trigger simultaneous
rebuild attempts.

Add atomic.Bool 'rebuilt' flag with ensureRebuilt() check before any
RLock acquisition:
- Fast path: atomic load returns true → skip rebuild, proceed to RLock
- Cold start: first caller rebuilds and sets flag, subsequent callers
  see the flag and skip rebuild
- Rebuild() explicitly resets the flag for explicit ring invalidation

Eliminates the RLock→Unlock→Lock→RLock transition entirely. The ring
is guaranteed ready before RLock is acquired.
2026-06-04 00:20:43 +08:00

282 lines
7.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package loadbalance 提供一致性哈希负载均衡算法实现。
//
// 该文件实现基于虚拟节点的一致性哈希算法,适用于缓存代理场景。
//
// 主要用途:
//
// 用于将相同键的请求始终路由到同一后端服务器,提高缓存命中率。
//
// 算法特点:
// - 使用虚拟节点解决数据倾斜问题
// - 支持 FNV-64a 哈希算法
// - 支持多种哈希键来源IP、URI、Header
//
// 作者xfy
package loadbalance
import (
"fmt"
"slices"
"sort"
"sync"
"sync/atomic"
)
// ConsistentHash 一致性哈希负载均衡器。
//
// 使用虚拟节点将请求均匀分布到各个目标,同时保证相同键的请求
// 始终路由到同一目标。
type ConsistentHash struct {
circle map[uint64]*Target
hashKey string
sortedHashes []uint64
virtualNodes int
mu sync.RWMutex
rebuilt atomic.Bool
}
// NewConsistentHash 创建一致性哈希负载均衡器。
//
// 参数:
// - virtualNodes: 每个目标的虚拟节点数,默认 150
// - hashKey: 哈希键来源,支持 ip、uri、header:X-Name
func NewConsistentHash(virtualNodes int, hashKey string) *ConsistentHash {
if virtualNodes <= 0 {
virtualNodes = 150
}
return &ConsistentHash{
virtualNodes: virtualNodes,
circle: make(map[uint64]*Target),
hashKey: hashKey,
}
}
// Select 根据默认键选择目标。
//
// 由于一致性哈希需要具体键值,此方法返回 nil。
// 请使用 SelectByKey 方法。
func (c *ConsistentHash) Select(targets []*Target) *Target {
return c.SelectByKey(targets, "")
}
// SelectByKey 根据指定键选择目标。
//
// 参数:
// - targets: 可用目标列表
// - key: 哈希键值(如客户端 IP、URI 等)
//
// 返回值:
// - *Target: 选中的目标,如果没有健康目标则返回 nil
func (c *ConsistentHash) SelectByKey(targets []*Target, key string) *Target {
c.ensureRebuilt(targets)
c.mu.RLock()
defer c.mu.RUnlock()
if len(c.sortedHashes) == 0 {
return nil
}
hash := fnvHash64a(key)
idx := sort.Search(len(c.sortedHashes), func(i int) bool {
return c.sortedHashes[i] >= hash
})
if idx >= len(c.sortedHashes) {
idx = 0
}
return c.circle[c.sortedHashes[idx]]
}
// Rebuild 重建哈希环。
//
// 当目标列表发生变化时应调用此方法。
//
// 参数:
// - targets: 新的目标列表
func (c *ConsistentHash) Rebuild(targets []*Target) {
c.rebuilt.Store(false)
c.rebuildCircle(targets)
}
func (c *ConsistentHash) ensureRebuilt(targets []*Target) {
if c.rebuilt.Load() {
return
}
c.rebuildCircle(targets)
}
// rebuildCircle 重建哈希环(内部方法,需要持有锁)。
func (c *ConsistentHash) rebuildCircle(targets []*Target) {
c.mu.Lock()
defer c.mu.Unlock()
// 清空现有环
c.circle = make(map[uint64]*Target)
c.sortedHashes = make([]uint64, 0)
// 为每个目标添加虚拟节点
for _, target := range targets {
if !target.Healthy.Load() {
continue
}
// 确保目标已预计算哈希
if len(target.VirtualHashes) == 0 {
target.VirtualHashes = make([]uint64, c.virtualNodes)
for i := 0; i < c.virtualNodes; i++ {
key := fmt.Sprintf("%s#%d", target.URL, i)
target.VirtualHashes[i] = c.hashKeyString(key)
}
}
// 使用预计算的哈希值
for _, hash := range target.VirtualHashes {
c.circle[hash] = target
c.sortedHashes = append(c.sortedHashes, hash)
}
}
// 排序哈希值
slices.Sort(c.sortedHashes)
c.rebuilt.Store(true)
}
// hashKeyString 计算字符串的哈希值(使用 FNV-64a
func (c *ConsistentHash) hashKeyString(key string) uint64 {
return fnvHash64a(key)
}
// PrecomputeHashes 预计算目标的虚拟节点哈希值。
//
// 此方法应在目标初始化时调用,避免在 SelectExcludingByKey 中重复计算哈希值。
// 预计算的哈希值存储在 Target.VirtualHashes 中,用于故障转移场景。
//
// 参数:
// - targets: 需要预计算哈希的目标列表
// - virtualNodes: 每个目标的虚拟节点数
func (c *ConsistentHash) PrecomputeHashes(targets []*Target, virtualNodes int) {
if virtualNodes <= 0 {
virtualNodes = 150
}
for _, target := range targets {
// 如果已经预计算过且数量匹配,跳过
if len(target.VirtualHashes) == virtualNodes {
continue
}
// 预计算该目标的所有虚拟节点哈希
target.VirtualHashes = make([]uint64, virtualNodes)
for i := 0; i < virtualNodes; i++ {
key := fmt.Sprintf("%s#%d", target.URL, i)
target.VirtualHashes[i] = c.hashKeyString(key)
}
}
}
// GetHashKey 返回哈希键配置。
func (c *ConsistentHash) GetHashKey() string {
return c.hashKey
}
// GetVirtualNodes 返回虚拟节点数。
func (c *ConsistentHash) GetVirtualNodes() int {
return c.virtualNodes
}
// ConsistentHashStats 返回一致性哈希统计信息。
type ConsistentHashStats struct {
// VirtualNodes 每个目标的虚拟节点数量
VirtualNodes int
// CircleSize 哈希环中的节点总数
CircleSize int
// SortedHashes 排序后的哈希值数量
SortedHashes int
}
// GetStats 返回统计信息。
func (c *ConsistentHash) GetStats() ConsistentHashStats {
c.mu.RLock()
defer c.mu.RUnlock()
return ConsistentHashStats{
VirtualNodes: c.virtualNodes,
CircleSize: len(c.circle),
SortedHashes: len(c.sortedHashes),
}
}
// SelectExcluding 根据指定键选择目标,排除指定的目标列表。
//
// 参数:
// - targets: 可用目标列表
// - excluded: 需要排除的目标列表
//
// 返回值:
// - *Target: 选中的目标,如果没有可用目标则返回 nil
func (c *ConsistentHash) SelectExcluding(targets []*Target, excluded []*Target) *Target {
return c.SelectExcludingByKey(targets, excluded, "")
}
// SelectExcludingByKey 根据指定键选择目标,排除指定的目标列表。
//
// 参数:
// - targets: 可用目标列表
// - excluded: 需要排除的目标列表
// - key: 哈希键值(如客户端 IP、URI 等)
//
// 返回值:
// - *Target: 选中的目标,如果没有可用目标则返回 nil
//
// 语义说明:此方法假设传入的 targets 列表与主哈希环的目标列表一致。
// 若不一致如多上游组场景targetSet 校验将拒绝所有候选,返回 nil。
// 调用方应在 targets 列表变化时调用 Rebuild() 更新主哈希环。
func (c *ConsistentHash) SelectExcludingByKey(targets []*Target, excluded []*Target, key string) *Target {
c.ensureRebuilt(targets)
c.mu.RLock()
fc := acquireFilterContext()
defer releaseFilterContext(fc)
targetSet := fc.excludeSet
for _, t := range targets {
if t.Healthy.Load() {
targetSet[t.URL] = true
}
}
for _, t := range excluded {
if t != nil {
targetSet[t.URL] = false
}
}
if len(c.sortedHashes) == 0 {
c.mu.RUnlock()
return nil
}
hash := c.hashKeyString(key)
idx := sort.Search(len(c.sortedHashes), func(i int) bool {
return c.sortedHashes[i] >= hash
})
for i := 0; i < len(c.sortedHashes); i++ {
targetIdx := (idx + i) % len(c.sortedHashes)
target := c.circle[c.sortedHashes[targetIdx]]
if targetSet[target.URL] {
c.mu.RUnlock()
return target
}
}
c.mu.RUnlock()
return nil
}
// 验证接口实现
var _ Balancer = (*ConsistentHash)(nil)