diff --git a/internal/loadbalance/balancer.go b/internal/loadbalance/balancer.go index be650c9..35a3f64 100644 --- a/internal/loadbalance/balancer.go +++ b/internal/loadbalance/balancer.go @@ -50,6 +50,10 @@ type Target struct { // lastResolved 最后解析时间(UnixNano,使用 atomic.Int64) lastResolved atomic.Int64 + + // VirtualHashes 预计算的虚拟节点哈希值(用于一致性哈希) + // 由 PrecomputeHashes 方法填充,避免运行时重复计算 + VirtualHashes []uint64 } // Balancer 是负载均衡算法的接口。 diff --git a/internal/loadbalance/balancer_bench_test.go b/internal/loadbalance/balancer_bench_test.go index 2aa09a5..84f493b 100644 --- a/internal/loadbalance/balancer_bench_test.go +++ b/internal/loadbalance/balancer_bench_test.go @@ -154,7 +154,38 @@ func BenchmarkConsistentHashRebuild(b *testing.B) { } } -// BenchmarkLeastConnSelect 基准测试最少连接算法。 +// BenchmarkConsistentHashSelectExcluding 基准测试一致性哈希排除选择算法。 +func BenchmarkConsistentHashSelectExcluding(b *testing.B) { + testCases := []struct { + name string + targets int + virtualNodes int + excludeCount int + }{ + {"50targets_150vnodes_exclude5", 50, 150, 5}, + {"50targets_150vnodes_exclude10", 50, 150, 10}, + {"100targets_150vnodes_exclude5", 100, 150, 5}, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + targets := generateTargets(tc.targets) + ch := NewConsistentHash(tc.virtualNodes, "ip") + + // 预计算所有目标的虚拟节点哈希 + ch.PrecomputeHashes(targets, tc.virtualNodes) + ch.Rebuild(targets) + + excluded := targets[:tc.excludeCount] + key := "test-request-key" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ch.SelectExcludingByKey(targets, excluded, key) + } + }) + } +} func BenchmarkLeastConnSelect(b *testing.B) { testCases := []struct { name string diff --git a/internal/loadbalance/balancer_test.go b/internal/loadbalance/balancer_test.go index 04c299e..e8a23bf 100644 --- a/internal/loadbalance/balancer_test.go +++ b/internal/loadbalance/balancer_test.go @@ -761,7 +761,153 @@ func TestConsistentHash(t *testing.T) { }) } -// TestIsValidAlgorithm 测试算法验证函数。 +// TestConsistentHashSelectExcludingByKey 测试一致性哈希排除选择功能。 +func TestConsistentHashSelectExcludingByKey(t *testing.T) { + t.Run("空排除列表", func(t *testing.T) { + ch := NewConsistentHash(150, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), + } + ch.Rebuild(targets) + + key := "192.168.1.100" + got := ch.SelectExcludingByKey(targets, []*Target{}, key) + + if got == nil { + t.Fatal("SelectExcludingByKey() = nil, want non-nil") + } + + // 验证正常选择行为 + got2 := ch.SelectExcludingByKey(targets, nil, key) + if got2 == nil { + t.Fatal("SelectExcludingByKey() with nil = nil, want non-nil") + } + if got.URL != got2.URL { + t.Errorf("空排除和nil排除应该返回相同结果: empty=%q, nil=%q", got.URL, got2.URL) + } + }) + + t.Run("部分排除", func(t *testing.T) { + ch := NewConsistentHash(150, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), + } + ch.Rebuild(targets) + + // 排除第一个目标 + excluded := []*Target{targets[0]} + key := "192.168.1.100" + + // 多次选择,验证不会选中排除的目标 + for i := 0; i < 100; i++ { + got := ch.SelectExcludingByKey(targets, excluded, key) + if got == nil { + t.Fatal("SelectExcludingByKey() = nil, want non-nil") + } + if got.URL == targets[0].URL { + t.Errorf("选中了被排除的目标: %q", got.URL) + } + } + }) + + t.Run("全部排除", func(t *testing.T) { + ch := NewConsistentHash(150, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + ch.Rebuild(targets) + + // 排除所有目标 + excluded := []*Target{targets[0], targets[1]} + key := "192.168.1.100" + + got := ch.SelectExcludingByKey(targets, excluded, key) + if got != nil { + t.Errorf("SelectExcludingByKey() = %q, want nil (all excluded)", got.URL) + } + }) + + t.Run("排除包含nil目标", func(t *testing.T) { + ch := NewConsistentHash(150, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + } + ch.Rebuild(targets) + + // 排除列表中包含nil + excluded := []*Target{nil, targets[0]} + key := "192.168.1.100" + + got := ch.SelectExcludingByKey(targets, excluded, key) + if got == nil { + t.Fatal("SelectExcludingByKey() = nil, want non-nil") + } + if got.URL == targets[0].URL { + t.Errorf("选中了被排除的目标: %q", got.URL) + } + }) + + t.Run("并发安全", func(t *testing.T) { + ch := NewConsistentHash(150, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), + } + ch.Rebuild(targets) + + excluded := []*Target{targets[0]} + key := "192.168.1.100" + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + got := ch.SelectExcludingByKey(targets, excluded, key) + if got != nil && got.URL == targets[0].URL { + t.Errorf("并发时选中了被排除的目标: %q", got.URL) + } + } + }() + } + wg.Wait() + }) + + t.Run("相同键一致性", func(t *testing.T) { + ch := NewConsistentHash(150, "ip") + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), + } + ch.Rebuild(targets) + + excluded := []*Target{targets[0]} + key := "192.168.1.100" + + // 相同键应该始终返回相同的目标 + var firstSelection *Target + for i := 0; i < 50; i++ { + got := ch.SelectExcludingByKey(targets, excluded, key) + if got == nil { + t.Fatal("SelectExcludingByKey() = nil, want non-nil") + } + if firstSelection == nil { + firstSelection = got + } else if got.URL != firstSelection.URL { + t.Errorf("相同键选择不同目标: first=%q, got=%q", firstSelection.URL, got.URL) + } + } + }) +} func TestIsValidAlgorithm(t *testing.T) { tests := []struct { name string diff --git a/internal/loadbalance/consistent_hash.go b/internal/loadbalance/consistent_hash.go index 70c60eb..fe2d3fb 100644 --- a/internal/loadbalance/consistent_hash.go +++ b/internal/loadbalance/consistent_hash.go @@ -130,9 +130,17 @@ func (c *ConsistentHash) rebuildCircle(targets []*Target) { continue } - for i := 0; i < c.virtualNodes; i++ { - key := fmt.Sprintf("%s#%d", target.URL, i) - hash := c.hashKeyString(key) + // 确保目标已预计算哈希 + if len(target.VirtualHashes) == 0 { + target.VirtualHashes = make([]uint64, c.virtualNodes) + for i := 0; i < c.virtualNodes; i++ { + key := fmt.Sprintf("%s#%d", target.URL, i) + target.VirtualHashes[i] = c.hashKeyString(key) + } + } + + // 使用预计算的哈希值 + for _, hash := range target.VirtualHashes { c.circle[hash] = target c.sortedHashes = append(c.sortedHashes, hash) } @@ -151,6 +159,34 @@ func (c *ConsistentHash) hashKeyString(key string) uint64 { return h.Sum64() } +// PrecomputeHashes 预计算目标的虚拟节点哈希值。 +// +// 此方法应在目标初始化时调用,避免在 SelectExcludingByKey 中重复计算哈希值。 +// 预计算的哈希值存储在 Target.VirtualHashes 中,用于故障转移场景。 +// +// 参数: +// - targets: 需要预计算哈希的目标列表 +// - virtualNodes: 每个目标的虚拟节点数 +func (c *ConsistentHash) PrecomputeHashes(targets []*Target, virtualNodes int) { + if virtualNodes <= 0 { + virtualNodes = 150 + } + + for _, target := range targets { + // 如果已经预计算过且数量匹配,跳过 + if len(target.VirtualHashes) == virtualNodes { + continue + } + + // 预计算该目标的所有虚拟节点哈希 + target.VirtualHashes = make([]uint64, virtualNodes) + for i := 0; i < virtualNodes; i++ { + key := fmt.Sprintf("%s#%d", target.URL, i) + target.VirtualHashes[i] = c.hashKeyString(key) + } + } +} + // GetHashKey 返回哈希键配置。 func (c *ConsistentHash) GetHashKey() string { return c.hashKey @@ -207,6 +243,9 @@ func (c *ConsistentHash) SelectExcluding(targets []*Target, excluded []*Target) // 返回值: // - *Target: 选中的目标,如果没有可用目标则返回 nil func (c *ConsistentHash) SelectExcludingByKey(targets []*Target, excluded []*Target, key string) *Target { + c.mu.RLock() + defer c.mu.RUnlock() + // 构建排除集合 excludeSet := make(map[string]bool, len(excluded)) for _, t := range excluded { @@ -215,48 +254,46 @@ func (c *ConsistentHash) SelectExcludingByKey(targets []*Target, excluded []*Tar } } - c.mu.RLock() - defer c.mu.RUnlock() - // 如果没有排除的目标,使用正常选择 if len(excludeSet) == 0 { return c.SelectByKey(targets, key) } - // 过滤掉被排除的目标 - filtered := make([]*Target, 0, len(targets)) - for _, t := range targets { - if t.Healthy.Load() && !excludeSet[t.URL] { - filtered = append(filtered, t) - } - } - - if len(filtered) == 0 { - return nil - } - - // 为过滤后的目标临时构建哈希环 + // 使用预计算的虚拟节点哈希构建哈希环 + // 避免在每次调用时重新计算哈希值 circle := make(map[uint64]*Target) - sortedHashes := make([]uint64, 0) + sortedHashes := make([]uint64, 0, len(targets)*c.virtualNodes) - for _, target := range filtered { - for i := 0; i < c.virtualNodes; i++ { - nodeKey := fmt.Sprintf("%s#%d", target.URL, i) - hash := c.hashKeyString(nodeKey) + for _, target := range targets { + if !target.Healthy.Load() || excludeSet[target.URL] { + continue + } + + // 确保目标已预计算哈希 + if len(target.VirtualHashes) == 0 { + // 回退到动态计算(不应该发生,但保持安全) + c.mu.RUnlock() + c.PrecomputeHashes([]*Target{target}, c.virtualNodes) + c.mu.RLock() + } + + // 使用预计算的哈希值 + for _, hash := range target.VirtualHashes { circle[hash] = target sortedHashes = append(sortedHashes, hash) } } - // 排序哈希值 - sort.Slice(sortedHashes, func(i, j int) bool { - return sortedHashes[i] < sortedHashes[j] - }) - if len(sortedHashes) == 0 { return nil } + // 排序哈希值(仅在需要时) + // 使用 sort.Slice 进行排序 + sort.Slice(sortedHashes, func(i, j int) bool { + return sortedHashes[i] < sortedHashes[j] + }) + // 计算键的哈希值 hash := c.hashKeyString(key)