From fa95b2a76e49cdda1d5cc1046aba690bd720e925 Mon Sep 17 00:00:00 2001 From: xfy Date: Mon, 8 Jun 2026 17:21:20 +0800 Subject: [PATCH] feat(loadbalance): implement Least Time balancer - Add atomic EWMA Stats field to Target - Implement LeastTime balancer with header_time and last_byte metrics - Support Select and SelectExcluding with zero-lock design - Add ResponseTimeRecorder interface for proxy integration --- internal/loadbalance/balancer.go | 4 + internal/loadbalance/balancer_test.go | 2 +- internal/loadbalance/least_time.go | 101 ++++++++ internal/loadbalance/least_time_test.go | 311 ++++++++++++++++++++++++ 4 files changed, 417 insertions(+), 1 deletion(-) create mode 100644 internal/loadbalance/least_time.go create mode 100644 internal/loadbalance/least_time_test.go diff --git a/internal/loadbalance/balancer.go b/internal/loadbalance/balancer.go index e54f133..ac7a2d1 100644 --- a/internal/loadbalance/balancer.go +++ b/internal/loadbalance/balancer.go @@ -106,6 +106,9 @@ type Target struct { EffectiveWeight atomic.Int64 // SlowStart 慢启动时间(配置) SlowStart time.Duration `yaml:"slow_start"` + + // Stats 响应时间统计(用于 least_time 算法) + Stats *EWMAStats } // Balancer 是 HTTP 代理(L7 层)负载均衡算法的接口。 @@ -615,6 +618,7 @@ func NewTargetFromConfig(url string, weight int, maxConns int64, maxFails int64, Backup: backup, Down: down, ProxyURI: proxyURI, + Stats: NewEWMAStats(), } t.initHostname() if !down { diff --git a/internal/loadbalance/balancer_test.go b/internal/loadbalance/balancer_test.go index 1c18c2c..c6d1222 100644 --- a/internal/loadbalance/balancer_test.go +++ b/internal/loadbalance/balancer_test.go @@ -28,7 +28,7 @@ import ( // 返回值: // - 初始化完成的 Target 指针 func createHealthyTarget(url string, healthy bool) *Target { - t := &Target{URL: url} + t := &Target{URL: url, Stats: NewEWMAStats()} t.Healthy.Store(healthy) return t } diff --git a/internal/loadbalance/least_time.go b/internal/loadbalance/least_time.go new file mode 100644 index 0000000..8b43a7f --- /dev/null +++ b/internal/loadbalance/least_time.go @@ -0,0 +1,101 @@ +package loadbalance + +import ( + "time" +) + +// ResponseTimeRecorder 响应时间记录接口。 +type ResponseTimeRecorder interface { + RecordResponseTime(target *Target, headerTime, lastByteTime time.Duration) +} + +// LeastTime 基于响应时间 EWMA 的负载均衡器。 +type LeastTime struct { + metric string + defaultTime time.Duration +} + +// NewLeastTime 创建一个新的基于响应时间的负载均衡器。 +// +// 参数: +// - metric: 选择指标,"header" 表示使用首字节时间,其他值使用完整响应时间 +// - defaultTime: 无统计信息时的默认响应时间,必须 > 0 +func NewLeastTime(metric string, defaultTime time.Duration) *LeastTime { + if metric != "header" { + metric = "last_byte" + } + if defaultTime <= 0 { + defaultTime = time.Millisecond + } + return &LeastTime{ + metric: metric, + defaultTime: defaultTime, + } +} + +// Select 根据响应时间 EWMA 选择一个目标。 +// 只考虑可用目标。如果没有可用目标则返回 nil。 +func (l *LeastTime) Select(targets []*Target) *Target { + fc := acquireFilterContext() + defer releaseFilterContext(fc) + available := filterInto(fc, targets) + return l.selectFrom(available) +} + +// SelectExcluding 根据响应时间 EWMA 选择一个目标,排除指定的目标列表。 +// 用于故障转移场景,避免选择已失败的目标。 +func (l *LeastTime) SelectExcluding(targets []*Target, excluded []*Target) *Target { + fc := acquireFilterContext() + defer releaseFilterContext(fc) + available := filterIntoExcluding(fc, targets, excluded) + return l.selectFrom(available) +} + +// selectFrom 从可用目标列表中选择响应时间最短的目标。 +func (l *LeastTime) selectFrom(available []*Target) *Target { + if len(available) == 0 { + return nil + } + + var selected *Target + var minTime int64 = -1 + defaultNano := l.defaultTime.Nanoseconds() + + for _, t := range available { + var currentTime int64 + if t.Stats != nil { + if l.metric == "header" { + currentTime = int64(t.Stats.HeaderTime()) + } else { + currentTime = int64(t.Stats.LastByteTime()) + } + } + + if currentTime == 0 { + currentTime = defaultNano + } + + if selected == nil || currentTime < minTime { + selected = t + minTime = currentTime + } + } + + return selected +} + +// RecordResponseTime 记录目标服务器的响应时间。 +// 更新目标的 EWMA 统计信息。 +func (l *LeastTime) RecordResponseTime(target *Target, headerTime, lastByteTime time.Duration) { + if target != nil && target.Stats != nil { + target.Stats.Record(headerTime, lastByteTime) + } +} + +// GetMetric 返回当前使用的响应时间指标。 +func (l *LeastTime) GetMetric() string { + return l.metric +} + +var _ Balancer = (*LeastTime)(nil) +var _ ResponseTimeRecorder = (*LeastTime)(nil) diff --git a/internal/loadbalance/least_time_test.go b/internal/loadbalance/least_time_test.go new file mode 100644 index 0000000..84fbada --- /dev/null +++ b/internal/loadbalance/least_time_test.go @@ -0,0 +1,311 @@ +package loadbalance + +import ( + "sync" + "testing" + "time" +) + +// TestLeastTime_BasicSelect 测试基本的响应时间选择。 +// 两个目标,不同响应时间,应选择更快的目标。 +func TestLeastTime_BasicSelect(t *testing.T) { + t.Parallel() + t.Run("选择响应时间最短的目标", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + + // 记录响应时间:backend1 慢,backend2 快 + lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond) + lt.RecordResponseTime(target2, 10*time.Millisecond, 10*time.Millisecond) + + got := lt.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + if got.URL != "http://backend2:8080" { + t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080") + } + }) + + t.Run("空目标", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + got := lt.Select([]*Target{}) + if got != nil { + t.Errorf("Select() = %v, want nil", got) + } + }) + + t.Run("跳过不健康目标", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", false) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + + lt.RecordResponseTime(target2, 10*time.Millisecond, 100*time.Millisecond) + + got := lt.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + if got.URL != "http://backend2:8080" { + t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080") + } + }) +} + +// TestLeastTime_NoStats 测试无统计信息的目标。 +// 目标没有记录过响应时间时,应使用默认值选择。 +func TestLeastTime_NoStats(t *testing.T) { + t.Parallel() + t.Run("无统计信息时使用默认值", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + + // 不记录任何响应时间 + got := lt.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + // 使用默认值时,应返回第一个可用目标 + if got.URL != "http://backend1:8080" { + t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080") + } + }) + + t.Run("部分目标有统计", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + + // 只记录一个目标的响应时间(非常快) + lt.RecordResponseTime(target1, 1*time.Nanosecond, 1*time.Nanosecond) + // target2 无统计,使用默认值(1ms = 1,000,000ns) + + got := lt.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + // target1 的 EWMA 应该远小于默认值,所以选择 target1 + if got.URL != "http://backend1:8080" { + t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080") + } + }) +} + +// TestLeastTime_HeaderMetric 测试 header 指标选择。 +// 使用 "header" 指标时,应基于 header_time 而非 last_byte_time。 +func TestLeastTime_HeaderMetric(t *testing.T) { + t.Parallel() + t.Run("header指标选择", func(_ *testing.T) { + lt := NewLeastTime("header", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + + // target1: header快但last_byte慢 + lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond) + // target2: header慢但last_byte快 + lt.RecordResponseTime(target2, 100*time.Millisecond, 10*time.Millisecond) + + // 使用 header 指标,应该选择 header_time 更小的 target1 + got := lt.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + if got.URL != "http://backend1:8080" { + t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080") + } + }) + + t.Run("last_byte指标选择", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + + // target1: header快但last_byte慢 + lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond) + // target2: header慢但last_byte快 + lt.RecordResponseTime(target2, 100*time.Millisecond, 10*time.Millisecond) + + // 使用 last_byte 指标,应该选择 last_byte_time 更小的 target2 + got := lt.Select(targets) + if got == nil { + t.Fatal("Select() = nil, want non-nil") + } + if got.URL != "http://backend2:8080" { + t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080") + } + }) +} + +// TestLeastTime_SelectExcluding 测试排除选择。 +// 排除最快的目标后,应选择次快的目标。 +func TestLeastTime_SelectExcluding(t *testing.T) { + t.Parallel() + t.Run("排除最快目标", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + target3 := createHealthyTarget("http://backend3:8080", true) + targets := []*Target{target1, target2, target3} + + // target1 最快 + lt.RecordResponseTime(target1, 10*time.Millisecond, 10*time.Millisecond) + // target2 中等 + lt.RecordResponseTime(target2, 10*time.Millisecond, 50*time.Millisecond) + // target3 最慢 + lt.RecordResponseTime(target3, 10*time.Millisecond, 100*time.Millisecond) + + // 排除最快的 target1 + excluded := []*Target{target1} + got := lt.SelectExcluding(targets, excluded) + if got == nil { + t.Fatal("SelectExcluding() = nil, want non-nil") + } + // 应该选择次快的 target2 + if got.URL != "http://backend2:8080" { + t.Errorf("SelectExcluding() = %q, want %q", got.URL, "http://backend2:8080") + } + }) + + t.Run("排除所有目标", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + excluded := []*Target{target1, target2} + + got := lt.SelectExcluding(targets, excluded) + if got != nil { + t.Errorf("SelectExcluding() = %v, want nil", got) + } + }) + + t.Run("排除列表含nil", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + target1 := createHealthyTarget("http://backend1:8080", true) + target2 := createHealthyTarget("http://backend2:8080", true) + targets := []*Target{target1, target2} + + lt.RecordResponseTime(target1, 10*time.Millisecond, 10*time.Millisecond) + lt.RecordResponseTime(target2, 10*time.Millisecond, 50*time.Millisecond) + + excluded := []*Target{nil, target1} + got := lt.SelectExcluding(targets, excluded) + if got == nil { + t.Fatal("SelectExcluding() = nil, want non-nil") + } + if got.URL == target1.URL { + t.Errorf("选中了被排除的目标: %q", got.URL) + } + }) +} + +// TestLeastTime_Concurrent 测试并发安全。 +// 50 goroutines 记录响应时间,50 goroutines 选择目标。 +func TestLeastTime_Concurrent(t *testing.T) { + t.Parallel() + lt := NewLeastTime("last_byte", time.Millisecond) + targets := []*Target{ + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), + } + + var wg sync.WaitGroup + + // 50 goroutines 记录响应时间 + for i := 0; i < 50; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + target := targets[idx%len(targets)] + headerTime := time.Duration(10+idx) * time.Millisecond + lastByteTime := time.Duration(50+idx) * time.Millisecond + lt.RecordResponseTime(target, headerTime, lastByteTime) + }(i) + } + + // 50 goroutines 选择目标 + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + got := lt.Select(targets) + if got == nil { + t.Error("并发Select() = nil, want non-nil") + } + }() + } + + wg.Wait() +} + +// TestLeastTime_GetMetric 测试 GetMetric 方法。 +func TestLeastTime_GetMetric(t *testing.T) { + t.Parallel() + t.Run("header metric", func(_ *testing.T) { + lt := NewLeastTime("header", time.Millisecond) + if lt.GetMetric() != "header" { + t.Errorf("GetMetric() = %q, want %q", lt.GetMetric(), "header") + } + }) + + t.Run("last_byte metric", func(_ *testing.T) { + lt := NewLeastTime("last_byte", time.Millisecond) + if lt.GetMetric() != "last_byte" { + t.Errorf("GetMetric() = %q, want %q", lt.GetMetric(), "last_byte") + } + }) + + t.Run("默认metric", func(_ *testing.T) { + lt := NewLeastTime("", time.Millisecond) + if lt.GetMetric() != "last_byte" { + t.Errorf("GetMetric() = %q, want %q", lt.GetMetric(), "last_byte") + } + }) +} + +// TestLeastTime_BalancerInterface 验证 LeastTime 实现了 Balancer 接口。 +func TestLeastTime_BalancerInterface(t *testing.T) { + t.Parallel() + var _ Balancer = (*LeastTime)(nil) + var _ ResponseTimeRecorder = (*LeastTime)(nil) +} + +// TestLeastTime_RecordResponseTimeNil 测试 RecordResponseTime 的 nil 处理。 +func TestLeastTime_RecordResponseTimeNil(t *testing.T) { + t.Parallel() + lt := NewLeastTime("last_byte", time.Millisecond) + + // nil target 不应 panic + lt.RecordResponseTime(nil, 10*time.Millisecond, 100*time.Millisecond) + + // nil Stats 不应 panic + target := &Target{URL: "http://backend1:8080"} + lt.RecordResponseTime(target, 10*time.Millisecond, 100*time.Millisecond) +} + +// TestLeastTime_DefaultTimeValidation 测试默认值参数验证。 +func TestLeastTime_DefaultTimeValidation(t *testing.T) { + t.Parallel() + t.Run("零值默认值", func(_ *testing.T) { + lt := NewLeastTime("last_byte", 0) + if lt.defaultTime != time.Millisecond { + t.Errorf("defaultTime = %v, want %v", lt.defaultTime, time.Millisecond) + } + }) + + t.Run("负值默认值", func(_ *testing.T) { + lt := NewLeastTime("last_byte", -1*time.Millisecond) + if lt.defaultTime != time.Millisecond { + t.Errorf("defaultTime = %v, want %v", lt.defaultTime, time.Millisecond) + } + }) +}