diff --git a/internal/loadbalance/ewma.go b/internal/loadbalance/ewma.go new file mode 100644 index 0000000..383c7d1 --- /dev/null +++ b/internal/loadbalance/ewma.go @@ -0,0 +1,60 @@ +package loadbalance + +import ( + "sync/atomic" + "time" +) + +// EWMAStats 使用原子操作实现的 EWMA(指数加权移动平均)统计器。 +type EWMAStats struct { + headerTime atomic.Int64 // 首字节时间的 EWMA(纳秒) + lastByteTime atomic.Int64 // 完整响应时间的 EWMA(纳秒) + sampleCount atomic.Int64 // 样本计数 +} + +const defaultAlphaScale = 300 // alpha = 0.3 + +func NewEWMAStats() *EWMAStats { + return &EWMAStats{} +} + +func (e *EWMAStats) Record(headerTime, lastByteTime time.Duration) { + e.recordAtomic(&e.headerTime, headerTime) + e.recordAtomic(&e.lastByteTime, lastByteTime) + e.sampleCount.Add(1) +} + +func (e *EWMAStats) recordAtomic(ptr *atomic.Int64, newValue time.Duration) { + newNano := newValue.Nanoseconds() + for { + old := ptr.Load() + if old == 0 { + if ptr.CompareAndSwap(0, newNano) { + return + } + continue + } + updated := (defaultAlphaScale*newNano + (1000-defaultAlphaScale)*old) / 1000 + if ptr.CompareAndSwap(old, updated) { + return + } + } +} + +func (e *EWMAStats) HeaderTime() time.Duration { + return time.Duration(e.headerTime.Load()) +} + +func (e *EWMAStats) LastByteTime() time.Duration { + return time.Duration(e.lastByteTime.Load()) +} + +func (e *EWMAStats) SampleCount() int64 { + return e.sampleCount.Load() +} + +func (e *EWMAStats) Reset() { + e.headerTime.Store(0) + e.lastByteTime.Store(0) + e.sampleCount.Store(0) +} diff --git a/internal/loadbalance/ewma_test.go b/internal/loadbalance/ewma_test.go new file mode 100644 index 0000000..95b2eb3 --- /dev/null +++ b/internal/loadbalance/ewma_test.go @@ -0,0 +1,61 @@ +package loadbalance + +import ( + "sync" + "testing" + "time" +) + +func TestEWMAStats_BasicRecord(t *testing.T) { + stats := NewEWMAStats() + + stats.Record(50*time.Millisecond, 100*time.Millisecond) + + if stats.HeaderTime() != 50*time.Millisecond { + t.Errorf("expected header time %v, got %v", 50*time.Millisecond, stats.HeaderTime()) + } + if stats.LastByteTime() != 100*time.Millisecond { + t.Errorf("expected last byte time %v, got %v", 100*time.Millisecond, stats.LastByteTime()) + } + if stats.SampleCount() != 1 { + t.Errorf("expected sample count 1, got %d", stats.SampleCount()) + } +} + +func TestEWMAStats_Convergence(t *testing.T) { + stats := NewEWMAStats() + + value := 100 * time.Millisecond + for i := 0; i < 10; i++ { + stats.Record(value, value) + } + + // alpha=0.3, after 10 samples should be within 10ms of 100ms + diff := stats.LastByteTime() - value + if diff < 0 { + diff = -diff + } + if diff > 10*time.Millisecond { + t.Errorf("expected convergence within 10ms, got diff=%v, value=%v", diff, stats.LastByteTime()) + } +} + +func TestEWMAStats_Concurrent(t *testing.T) { + stats := NewEWMAStats() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + stats.Record(time.Millisecond, 2*time.Millisecond) + } + }() + } + wg.Wait() + + if stats.SampleCount() != 100*100 { + t.Errorf("expected sample count %d, got %d", 100*100, stats.SampleCount()) + } +}