lolly/internal/loadbalance/least_time_test.go
xfy fa95b2a76e feat(loadbalance): implement Least Time balancer
- Add atomic EWMA Stats field to Target
- Implement LeastTime balancer with header_time and last_byte metrics
- Support Select and SelectExcluding with zero-lock design
- Add ResponseTimeRecorder interface for proxy integration
2026-06-08 17:21:20 +08:00

312 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package loadbalance
import (
"sync"
"testing"
"time"
)
// TestLeastTime_BasicSelect verifies basic response-time based selection.
// With two targets that have different recorded response times, Select must
// pick the faster one; it must return nil for an empty target list and skip
// unhealthy targets.
//
// Each subtest uses its own *testing.T so that failures are attributed to the
// subtest and t.Fatal only stops the subtest (calling FailNow on the parent T
// from a subtest goroutine is reported as an error by the testing package).
func TestLeastTime_BasicSelect(t *testing.T) {
	t.Parallel()

	t.Run("选择响应时间最短的目标", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}

		// backend1 is slow, backend2 is fast (by last-byte time).
		lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond)
		lt.RecordResponseTime(target2, 10*time.Millisecond, 10*time.Millisecond)

		got := lt.Select(targets)
		if got == nil {
			t.Fatal("Select() = nil, want non-nil")
		}
		if got.URL != "http://backend2:8080" {
			t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080")
		}
	})

	t.Run("空目标", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		if got := lt.Select([]*Target{}); got != nil {
			t.Errorf("Select() = %v, want nil", got)
		}
	})

	t.Run("跳过不健康目标", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", false)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}

		// Only the healthy target has samples; the unhealthy one must be skipped
		// regardless of its (absent) statistics.
		lt.RecordResponseTime(target2, 10*time.Millisecond, 100*time.Millisecond)

		got := lt.Select(targets)
		if got == nil {
			t.Fatal("Select() = nil, want non-nil")
		}
		if got.URL != "http://backend2:8080" {
			t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080")
		}
	})
}
// TestLeastTime_NoStats verifies selection for targets with no recorded
// response times. Such targets are scored with the balancer's default time.
//
// Each subtest uses its own *testing.T so failures are attributed to the
// subtest and t.Fatal does not call FailNow on the parent test.
func TestLeastTime_NoStats(t *testing.T) {
	t.Parallel()

	t.Run("无统计信息时使用默认值", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}

		// No response times are recorded for either target.
		got := lt.Select(targets)
		if got == nil {
			t.Fatal("Select() = nil, want non-nil")
		}
		// When every target uses the default, the first available target wins.
		if got.URL != "http://backend1:8080" {
			t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080")
		}
	})

	t.Run("部分目标有统计", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}

		// Record a very fast sample for target1 only.
		lt.RecordResponseTime(target1, 1*time.Nanosecond, 1*time.Nanosecond)
		// target2 has no stats, so it falls back to the default (1ms = 1,000,000ns).

		got := lt.Select(targets)
		if got == nil {
			t.Fatal("Select() = nil, want non-nil")
		}
		// target1's EWMA should be far below the default, so target1 is chosen.
		if got.URL != "http://backend1:8080" {
			t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080")
		}
	})
}
// TestLeastTime_HeaderMetric verifies that the configured metric drives
// selection. With "header" the balancer ranks targets by header time; with
// "last_byte" it ranks them by last-byte time. Both subtests use the same
// samples, with header and last-byte ordering deliberately opposed.
//
// Each subtest uses its own *testing.T so failures are attributed to the
// subtest and t.Fatal does not call FailNow on the parent test.
func TestLeastTime_HeaderMetric(t *testing.T) {
	t.Parallel()

	t.Run("header指标选择", func(t *testing.T) {
		lt := NewLeastTime("header", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}

		// target1: fast header, slow last byte.
		lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond)
		// target2: slow header, fast last byte.
		lt.RecordResponseTime(target2, 100*time.Millisecond, 10*time.Millisecond)

		// With the header metric, target1 (smaller header time) must be chosen.
		got := lt.Select(targets)
		if got == nil {
			t.Fatal("Select() = nil, want non-nil")
		}
		if got.URL != "http://backend1:8080" {
			t.Errorf("Select() = %q, want %q", got.URL, "http://backend1:8080")
		}
	})

	t.Run("last_byte指标选择", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}

		// target1: fast header, slow last byte.
		lt.RecordResponseTime(target1, 10*time.Millisecond, 100*time.Millisecond)
		// target2: slow header, fast last byte.
		lt.RecordResponseTime(target2, 100*time.Millisecond, 10*time.Millisecond)

		// With the last_byte metric, target2 (smaller last-byte time) must be chosen.
		got := lt.Select(targets)
		if got == nil {
			t.Fatal("Select() = nil, want non-nil")
		}
		if got.URL != "http://backend2:8080" {
			t.Errorf("Select() = %q, want %q", got.URL, "http://backend2:8080")
		}
	})
}
// TestLeastTime_SelectExcluding verifies SelectExcluding. Excluding the fastest
// target must yield the next fastest. Excluding every target must yield nil.
// A nil entry in the exclusion list must be tolerated.
//
// Each subtest uses its own *testing.T so failures are attributed to the
// subtest and t.Fatal does not call FailNow on the parent test.
func TestLeastTime_SelectExcluding(t *testing.T) {
	t.Parallel()

	t.Run("排除最快目标", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		target3 := createHealthyTarget("http://backend3:8080", true)
		targets := []*Target{target1, target2, target3}

		lt.RecordResponseTime(target1, 10*time.Millisecond, 10*time.Millisecond)  // fastest
		lt.RecordResponseTime(target2, 10*time.Millisecond, 50*time.Millisecond)  // medium
		lt.RecordResponseTime(target3, 10*time.Millisecond, 100*time.Millisecond) // slowest

		// Exclude the fastest target; the next fastest (target2) must be chosen.
		got := lt.SelectExcluding(targets, []*Target{target1})
		if got == nil {
			t.Fatal("SelectExcluding() = nil, want non-nil")
		}
		if got.URL != "http://backend2:8080" {
			t.Errorf("SelectExcluding() = %q, want %q", got.URL, "http://backend2:8080")
		}
	})

	t.Run("排除所有目标", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}

		if got := lt.SelectExcluding(targets, []*Target{target1, target2}); got != nil {
			t.Errorf("SelectExcluding() = %v, want nil", got)
		}
	})

	t.Run("排除列表含nil", func(t *testing.T) {
		lt := NewLeastTime("last_byte", time.Millisecond)
		target1 := createHealthyTarget("http://backend1:8080", true)
		target2 := createHealthyTarget("http://backend2:8080", true)
		targets := []*Target{target1, target2}
		lt.RecordResponseTime(target1, 10*time.Millisecond, 10*time.Millisecond)
		lt.RecordResponseTime(target2, 10*time.Millisecond, 50*time.Millisecond)

		// The nil entry must be ignored; target1 is still excluded.
		got := lt.SelectExcluding(targets, []*Target{nil, target1})
		if got == nil {
			t.Fatal("SelectExcluding() = nil, want non-nil")
		}
		// target2 is the only candidate left, so it is the only correct answer.
		if got.URL == target1.URL {
			t.Errorf("选中了被排除的目标: %q", got.URL)
		} else if got.URL != target2.URL {
			t.Errorf("SelectExcluding() = %q, want %q", got.URL, target2.URL)
		}
	})
}
// TestLeastTime_Concurrent exercises concurrent use of a single LeastTime.
// Fifty goroutines record response times while fifty others call Select on
// the same target list. Run with -race to detect unsynchronized access.
func TestLeastTime_Concurrent(t *testing.T) {
	t.Parallel()

	const workers = 50

	balancer := NewLeastTime("last_byte", time.Millisecond)
	pool := []*Target{
		createHealthyTarget("http://backend1:8080", true),
		createHealthyTarget("http://backend2:8080", true),
		createHealthyTarget("http://backend3:8080", true),
	}

	var wg sync.WaitGroup
	wg.Add(2 * workers) // writers + readers

	// Writers: goroutine idx records one sample against pool[idx % len(pool)].
	for n := 0; n < workers; n++ {
		go func(idx int) {
			defer wg.Done()
			header := time.Duration(10+idx) * time.Millisecond
			lastByte := time.Duration(50+idx) * time.Millisecond
			balancer.RecordResponseTime(pool[idx%len(pool)], header, lastByte)
		}(n)
	}

	// Readers: each goroutine selects while writers may still be running.
	// t.Error (unlike t.Fatal) is safe to call from these goroutines.
	for n := 0; n < workers; n++ {
		go func() {
			defer wg.Done()
			if balancer.Select(pool) == nil {
				t.Error("并发Select() = nil, want non-nil")
			}
		}()
	}

	wg.Wait()
}
// TestLeastTime_GetMetric verifies that GetMetric reports the configured
// metric and that an empty metric falls back to "last_byte".
//
// Each subtest uses its own *testing.T so failures are attributed to the
// subtest rather than the parent test.
func TestLeastTime_GetMetric(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		metric string
		want   string
	}{
		{name: "header metric", metric: "header", want: "header"},
		{name: "last_byte metric", metric: "last_byte", want: "last_byte"},
		// An empty metric selects the default, "last_byte".
		{name: "默认metric", metric: "", want: "last_byte"},
	}

	for _, tc := range tests {
		// t.Run runs the subtest synchronously here (no t.Parallel inside),
		// so capturing tc is safe on every Go version.
		t.Run(tc.name, func(t *testing.T) {
			lt := NewLeastTime(tc.metric, time.Millisecond)
			if got := lt.GetMetric(); got != tc.want {
				t.Errorf("GetMetric() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestLeastTime_BalancerInterface checks at compile time that *LeastTime
// implements both Balancer and ResponseTimeRecorder. If either assertion
// stops holding, this file fails to compile.
func TestLeastTime_BalancerInterface(t *testing.T) {
	t.Parallel()

	var (
		_ Balancer             = (*LeastTime)(nil)
		_ ResponseTimeRecorder = (*LeastTime)(nil)
	)
}
// TestLeastTime_RecordResponseTimeNil verifies that RecordResponseTime does
// not panic for a nil target or for a Target built without its Stats set.
func TestLeastTime_RecordResponseTimeNil(t *testing.T) {
	t.Parallel()

	const header, lastByte = 10 * time.Millisecond, 100 * time.Millisecond
	balancer := NewLeastTime("last_byte", time.Millisecond)

	// A nil target must be ignored without panicking.
	balancer.RecordResponseTime(nil, header, lastByte)

	// A Target constructed directly (no Stats set) must not panic either.
	bare := &Target{URL: "http://backend1:8080"}
	balancer.RecordResponseTime(bare, header, lastByte)
}
// TestLeastTime_DefaultTimeValidation verifies that NewLeastTime replaces a
// zero or negative default time with time.Millisecond.
//
// Each subtest uses its own *testing.T so failures are attributed to the
// subtest rather than the parent test.
func TestLeastTime_DefaultTimeValidation(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		defaultTime time.Duration
	}{
		{name: "零值默认值", defaultTime: 0},
		{name: "负值默认值", defaultTime: -1 * time.Millisecond},
	}

	for _, tc := range tests {
		// Subtests run synchronously (no t.Parallel inside), so capturing tc is safe.
		t.Run(tc.name, func(t *testing.T) {
			lt := NewLeastTime("last_byte", tc.defaultTime)
			if lt.defaultTime != time.Millisecond {
				t.Errorf("defaultTime = %v, want %v", lt.defaultTime, time.Millisecond)
			}
		})
	}
}