feat(loadbalance): implement Least Time balancer
- Add atomic EWMA Stats field to Target - Implement LeastTime balancer with header_time and last_byte metrics - Support Select and SelectExcluding with zero-lock design - Add ResponseTimeRecorder interface for proxy integration
This commit is contained in:
parent
c6bb75cffe
commit
fa95b2a76e
@ -106,6 +106,9 @@ type Target struct {
|
||||
EffectiveWeight atomic.Int64
|
||||
// SlowStart 慢启动时间(配置)
|
||||
SlowStart time.Duration `yaml:"slow_start"`
|
||||
|
||||
// Stats 响应时间统计(用于 least_time 算法)
|
||||
Stats *EWMAStats
|
||||
}
|
||||
|
||||
// Balancer 是 HTTP 代理(L7 层)负载均衡算法的接口。
|
||||
@ -615,6 +618,7 @@ func NewTargetFromConfig(url string, weight int, maxConns int64, maxFails int64,
|
||||
Backup: backup,
|
||||
Down: down,
|
||||
ProxyURI: proxyURI,
|
||||
Stats: NewEWMAStats(),
|
||||
}
|
||||
t.initHostname()
|
||||
if !down {
|
||||
|
||||
@ -28,7 +28,7 @@ import (
|
||||
// 返回值:
|
||||
// - 初始化完成的 Target 指针
|
||||
func createHealthyTarget(url string, healthy bool) *Target {
|
||||
t := &Target{URL: url}
|
||||
t := &Target{URL: url, Stats: NewEWMAStats()}
|
||||
t.Healthy.Store(healthy)
|
||||
return t
|
||||
}
|
||||
|
||||
101
internal/loadbalance/least_time.go
Normal file
101
internal/loadbalance/least_time.go
Normal file
@ -0,0 +1,101 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ResponseTimeRecorder 响应时间记录接口。
|
||||
type ResponseTimeRecorder interface {
|
||||
RecordResponseTime(target *Target, headerTime, lastByteTime time.Duration)
|
||||
}
|
||||
|
||||
// LeastTime 基于响应时间 EWMA 的负载均衡器。
|
||||
type LeastTime struct {
|
||||
metric string
|
||||
defaultTime time.Duration
|
||||
}
|
||||
|
||||
// NewLeastTime 创建一个新的基于响应时间的负载均衡器。
|
||||
//
|
||||
// 参数:
|
||||
// - metric: 选择指标,"header" 表示使用首字节时间,其他值使用完整响应时间
|
||||
// - defaultTime: 无统计信息时的默认响应时间,必须 > 0
|
||||
func NewLeastTime(metric string, defaultTime time.Duration) *LeastTime {
|
||||
if metric != "header" {
|
||||
metric = "last_byte"
|
||||
}
|
||||
if defaultTime <= 0 {
|
||||
defaultTime = time.Millisecond
|
||||
}
|
||||
return &LeastTime{
|
||||
metric: metric,
|
||||
defaultTime: defaultTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Select 根据响应时间 EWMA 选择一个目标。
|
||||
// 只考虑可用目标。如果没有可用目标则返回 nil。
|
||||
func (l *LeastTime) Select(targets []*Target) *Target {
|
||||
fc := acquireFilterContext()
|
||||
defer releaseFilterContext(fc)
|
||||
available := filterInto(fc, targets)
|
||||
return l.selectFrom(available)
|
||||
}
|
||||
|
||||
// SelectExcluding 根据响应时间 EWMA 选择一个目标,排除指定的目标列表。
|
||||
// 用于故障转移场景,避免选择已失败的目标。
|
||||
func (l *LeastTime) SelectExcluding(targets []*Target, excluded []*Target) *Target {
|
||||
fc := acquireFilterContext()
|
||||
defer releaseFilterContext(fc)
|
||||
available := filterIntoExcluding(fc, targets, excluded)
|
||||
return l.selectFrom(available)
|
||||
}
|
||||
|
||||
// selectFrom 从可用目标列表中选择响应时间最短的目标。
|
||||
func (l *LeastTime) selectFrom(available []*Target) *Target {
|
||||
if len(available) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var selected *Target
|
||||
var minTime int64 = -1
|
||||
defaultNano := l.defaultTime.Nanoseconds()
|
||||
|
||||
for _, t := range available {
|
||||
var currentTime int64
|
||||
if t.Stats != nil {
|
||||
if l.metric == "header" {
|
||||
currentTime = int64(t.Stats.HeaderTime())
|
||||
} else {
|
||||
currentTime = int64(t.Stats.LastByteTime())
|
||||
}
|
||||
}
|
||||
|
||||
if currentTime == 0 {
|
||||
currentTime = defaultNano
|
||||
}
|
||||
|
||||
if selected == nil || currentTime < minTime {
|
||||
selected = t
|
||||
minTime = currentTime
|
||||
}
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
// RecordResponseTime 记录目标服务器的响应时间。
|
||||
// 更新目标的 EWMA 统计信息。
|
||||
func (l *LeastTime) RecordResponseTime(target *Target, headerTime, lastByteTime time.Duration) {
|
||||
if target != nil && target.Stats != nil {
|
||||
target.Stats.Record(headerTime, lastByteTime)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetric 返回当前使用的响应时间指标。
|
||||
func (l *LeastTime) GetMetric() string {
|
||||
return l.metric
|
||||
}
|
||||
|
||||
var _ Balancer = (*LeastTime)(nil)
|
||||
var _ ResponseTimeRecorder = (*LeastTime)(nil)
|
||||
311
internal/loadbalance/least_time_test.go
Normal file
311
internal/loadbalance/least_time_test.go
Normal file
@ -0,0 +1,311 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestLeastTime_BasicSelect 测试基本的响应时间选择。
|
||||
// 两个目标,不同响应时间,应选择更快的目标。
|
||||
func TestLeastTime_BasicSelect(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("选择响应时间最短的目标", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
|
||||
// 记录响应时间:backend1 慢,backend2 快
|
||||
lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond)
|
||||
lt.RecordResponseTime(target2, 10*time.Millisecond, 10*time.Millisecond)
|
||||
|
||||
got := lt.Select(targets)
|
||||
if got == nil {
|
||||
t.Fatal("Select() = nil, want non-nil")
|
||||
}
|
||||
if got.URL != "http://backend2:8080" {
|
||||
t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("空目标", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
got := lt.Select([]*Target{})
|
||||
if got != nil {
|
||||
t.Errorf("Select() = %v, want nil", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("跳过不健康目标", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", false)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
|
||||
lt.RecordResponseTime(target2, 10*time.Millisecond, 100*time.Millisecond)
|
||||
|
||||
got := lt.Select(targets)
|
||||
if got == nil {
|
||||
t.Fatal("Select() = nil, want non-nil")
|
||||
}
|
||||
if got.URL != "http://backend2:8080" {
|
||||
t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLeastTime_NoStats 测试无统计信息的目标。
|
||||
// 目标没有记录过响应时间时,应使用默认值选择。
|
||||
func TestLeastTime_NoStats(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("无统计信息时使用默认值", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
|
||||
// 不记录任何响应时间
|
||||
got := lt.Select(targets)
|
||||
if got == nil {
|
||||
t.Fatal("Select() = nil, want non-nil")
|
||||
}
|
||||
// 使用默认值时,应返回第一个可用目标
|
||||
if got.URL != "http://backend1:8080" {
|
||||
t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("部分目标有统计", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
|
||||
// 只记录一个目标的响应时间(非常快)
|
||||
lt.RecordResponseTime(target1, 1*time.Nanosecond, 1*time.Nanosecond)
|
||||
// target2 无统计,使用默认值(1ms = 1,000,000ns)
|
||||
|
||||
got := lt.Select(targets)
|
||||
if got == nil {
|
||||
t.Fatal("Select() = nil, want non-nil")
|
||||
}
|
||||
// target1 的 EWMA 应该远小于默认值,所以选择 target1
|
||||
if got.URL != "http://backend1:8080" {
|
||||
t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLeastTime_HeaderMetric 测试 header 指标选择。
|
||||
// 使用 "header" 指标时,应基于 header_time 而非 last_byte_time。
|
||||
func TestLeastTime_HeaderMetric(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("header指标选择", func(_ *testing.T) {
|
||||
lt := NewLeastTime("header", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
|
||||
// target1: header快但last_byte慢
|
||||
lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond)
|
||||
// target2: header慢但last_byte快
|
||||
lt.RecordResponseTime(target2, 100*time.Millisecond, 10*time.Millisecond)
|
||||
|
||||
// 使用 header 指标,应该选择 header_time 更小的 target1
|
||||
got := lt.Select(targets)
|
||||
if got == nil {
|
||||
t.Fatal("Select() = nil, want non-nil")
|
||||
}
|
||||
if got.URL != "http://backend1:8080" {
|
||||
t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("last_byte指标选择", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
|
||||
// target1: header快但last_byte慢
|
||||
lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond)
|
||||
// target2: header慢但last_byte快
|
||||
lt.RecordResponseTime(target2, 100*time.Millisecond, 10*time.Millisecond)
|
||||
|
||||
// 使用 last_byte 指标,应该选择 last_byte_time 更小的 target2
|
||||
got := lt.Select(targets)
|
||||
if got == nil {
|
||||
t.Fatal("Select() = nil, want non-nil")
|
||||
}
|
||||
if got.URL != "http://backend2:8080" {
|
||||
t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLeastTime_SelectExcluding 测试排除选择。
|
||||
// 排除最快的目标后,应选择次快的目标。
|
||||
func TestLeastTime_SelectExcluding(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("排除最快目标", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
target3 := createHealthyTarget("http://backend3:8080", true)
|
||||
targets := []*Target{target1, target2, target3}
|
||||
|
||||
// target1 最快
|
||||
lt.RecordResponseTime(target1, 10*time.Millisecond, 10*time.Millisecond)
|
||||
// target2 中等
|
||||
lt.RecordResponseTime(target2, 10*time.Millisecond, 50*time.Millisecond)
|
||||
// target3 最慢
|
||||
lt.RecordResponseTime(target3, 10*time.Millisecond, 100*time.Millisecond)
|
||||
|
||||
// 排除最快的 target1
|
||||
excluded := []*Target{target1}
|
||||
got := lt.SelectExcluding(targets, excluded)
|
||||
if got == nil {
|
||||
t.Fatal("SelectExcluding() = nil, want non-nil")
|
||||
}
|
||||
// 应该选择次快的 target2
|
||||
if got.URL != "http://backend2:8080" {
|
||||
t.Errorf("SelectExcluding() = %q, want %q", got.URL, "http://backend2:8080")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("排除所有目标", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
excluded := []*Target{target1, target2}
|
||||
|
||||
got := lt.SelectExcluding(targets, excluded)
|
||||
if got != nil {
|
||||
t.Errorf("SelectExcluding() = %v, want nil", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("排除列表含nil", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
target1 := createHealthyTarget("http://backend1:8080", true)
|
||||
target2 := createHealthyTarget("http://backend2:8080", true)
|
||||
targets := []*Target{target1, target2}
|
||||
|
||||
lt.RecordResponseTime(target1, 10*time.Millisecond, 10*time.Millisecond)
|
||||
lt.RecordResponseTime(target2, 10*time.Millisecond, 50*time.Millisecond)
|
||||
|
||||
excluded := []*Target{nil, target1}
|
||||
got := lt.SelectExcluding(targets, excluded)
|
||||
if got == nil {
|
||||
t.Fatal("SelectExcluding() = nil, want non-nil")
|
||||
}
|
||||
if got.URL == target1.URL {
|
||||
t.Errorf("选中了被排除的目标: %q", got.URL)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLeastTime_Concurrent 测试并发安全。
|
||||
// 50 goroutines 记录响应时间,50 goroutines 选择目标。
|
||||
func TestLeastTime_Concurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
targets := []*Target{
|
||||
createHealthyTarget("http://backend1:8080", true),
|
||||
createHealthyTarget("http://backend2:8080", true),
|
||||
createHealthyTarget("http://backend3:8080", true),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// 50 goroutines 记录响应时间
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
target := targets[idx%len(targets)]
|
||||
headerTime := time.Duration(10+idx) * time.Millisecond
|
||||
lastByteTime := time.Duration(50+idx) * time.Millisecond
|
||||
lt.RecordResponseTime(target, headerTime, lastByteTime)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 50 goroutines 选择目标
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
got := lt.Select(targets)
|
||||
if got == nil {
|
||||
t.Error("并发Select() = nil, want non-nil")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestLeastTime_GetMetric 测试 GetMetric 方法。
|
||||
func TestLeastTime_GetMetric(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("header metric", func(_ *testing.T) {
|
||||
lt := NewLeastTime("header", time.Millisecond)
|
||||
if lt.GetMetric() != "header" {
|
||||
t.Errorf("GetMetric() = %q, want %q", lt.GetMetric(), "header")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("last_byte metric", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
if lt.GetMetric() != "last_byte" {
|
||||
t.Errorf("GetMetric() = %q, want %q", lt.GetMetric(), "last_byte")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("默认metric", func(_ *testing.T) {
|
||||
lt := NewLeastTime("", time.Millisecond)
|
||||
if lt.GetMetric() != "last_byte" {
|
||||
t.Errorf("GetMetric() = %q, want %q", lt.GetMetric(), "last_byte")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLeastTime_BalancerInterface 验证 LeastTime 实现了 Balancer 接口。
|
||||
func TestLeastTime_BalancerInterface(t *testing.T) {
|
||||
t.Parallel()
|
||||
var _ Balancer = (*LeastTime)(nil)
|
||||
var _ ResponseTimeRecorder = (*LeastTime)(nil)
|
||||
}
|
||||
|
||||
// TestLeastTime_RecordResponseTimeNil 测试 RecordResponseTime 的 nil 处理。
|
||||
func TestLeastTime_RecordResponseTimeNil(t *testing.T) {
|
||||
t.Parallel()
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
|
||||
// nil target 不应 panic
|
||||
lt.RecordResponseTime(nil, 10*time.Millisecond, 100*time.Millisecond)
|
||||
|
||||
// nil Stats 不应 panic
|
||||
target := &Target{URL: "http://backend1:8080"}
|
||||
lt.RecordResponseTime(target, 10*time.Millisecond, 100*time.Millisecond)
|
||||
}
|
||||
|
||||
// TestLeastTime_DefaultTimeValidation 测试默认值参数验证。
|
||||
func TestLeastTime_DefaultTimeValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("零值默认值", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", 0)
|
||||
if lt.defaultTime != time.Millisecond {
|
||||
t.Errorf("defaultTime = %v, want %v", lt.defaultTime, time.Millisecond)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("负值默认值", func(_ *testing.T) {
|
||||
lt := NewLeastTime("last_byte", -1*time.Millisecond)
|
||||
if lt.defaultTime != time.Millisecond {
|
||||
t.Errorf("defaultTime = %v, want %v", lt.defaultTime, time.Millisecond)
|
||||
}
|
||||
})
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user