Patch summary (reviewer notes; everything before the first "diff --git" header
is commentary and is not part of any hunk).

internal/loadbalance/balancer.go
  Target: failCount and failedUntil change from atomic.Int64 to plain int64
  fields, and a new failMu sync.Mutex field is added alongside them.
  IsAvailable, RecordFailure, RecordSuccess: both fields are now read and
  written while holding failMu, replacing the atomic Load/Store/Add/
  CompareAndSwap calls.
  filterHealthy, filterHealthyAndExclude: backups is created with
  make([]*Target, 0, len(targets)) instead of starting as a nil slice.

internal/loadbalance/balancer_test.go
  TestIsValidAlgorithm: adds a "random" row that is expected to be valid.
  TestTargetRecordSuccess: reads failCount directly while holding failMu.

NOTE(review): backups is now allocated on every call, including calls in which
no backup target ends up being appended; confirm the extra allocation is wanted.
NOTE(review): this patch contains no non-test change for the "random"
algorithm; presumably it is registered elsewhere. Confirm before merging.

diff --git a/internal/loadbalance/balancer.go b/internal/loadbalance/balancer.go
index e7e2f44..f0e63e7 100644
--- a/internal/loadbalance/balancer.go
+++ b/internal/loadbalance/balancer.go
@@ -63,10 +63,10 @@ type Target struct {
 	// ProxyURI 代理传递的 URI 路径
 	ProxyURI string
 
-	// failCount 失败计数(原子操作)
-	failCount atomic.Int64
-	// failedUntil 失败冷却截止时间(UnixNano,原子操作)
-	failedUntil atomic.Int64
+	// failMu 保护 failCount 和 failedUntil 的协调更新
+	failMu      sync.Mutex
+	failCount   int64
+	failedUntil int64
 }
 
 // Balancer 是 HTTP 代理(L7 层)负载均衡算法的接口。
@@ -280,35 +280,39 @@ func (t *Target) IsAvailable() bool {
 		return false
 	}
 	if t.MaxFails > 0 {
-		deadline := t.failedUntil.Load()
-		if deadline > 0 {
-			if time.Now().UnixNano() < deadline {
-				return false
-			}
-			// 冷却已过期,CAS 重置防止丢失并发的 RecordFailure
-			if t.failedUntil.CompareAndSwap(deadline, 0) {
-				t.failCount.Store(0)
-			}
+		t.failMu.Lock()
+		if t.failCount >= t.MaxFails && time.Now().UnixNano() < t.failedUntil {
+			t.failMu.Unlock()
+			return false
 		}
+		// 冷却已过期,重置软状态
+		if t.failCount >= t.MaxFails && t.failedUntil > 0 {
+			t.failCount = 0
+			t.failedUntil = 0
+		}
+		t.failMu.Unlock()
 	}
 	return true
 }
 
 // RecordFailure 记录一次失败。
-// 使用原子递增 failCount,当达到 MaxFails 时设置冷却截止时间。
+// 使用互斥锁保护 failCount/failedUntil 的协调更新。
 // 返回当前失败计数。
 func (t *Target) RecordFailure() int64 {
 	if t.MaxFails <= 0 {
 		return 0
 	}
-	count := t.failCount.Add(1)
+	t.failMu.Lock()
+	t.failCount++
+	count := t.failCount
 	if count >= t.MaxFails {
 		timeout := t.FailTimeout
 		if timeout <= 0 {
 			timeout = 10 * time.Second
 		}
-		t.failedUntil.Store(time.Now().Add(timeout).UnixNano())
+		t.failedUntil = time.Now().Add(timeout).UnixNano()
 	}
+	t.failMu.Unlock()
 	return count
 }
 
@@ -318,18 +322,10 @@ func (t *Target) RecordSuccess() {
 	if t.MaxFails <= 0 {
 		return
 	}
-	// CAS 重置:仅在当前 goroutine 持有 deadline 时才清零,
-	// 防止丢失并发的 RecordFailure 设置的新 deadline。
-	for {
-		deadline := t.failedUntil.Load()
-		if deadline == 0 {
-			break
-		}
-		if t.failedUntil.CompareAndSwap(deadline, 0) {
-			break
-		}
-	}
-	t.failCount.Store(0)
+	t.failMu.Lock()
+	t.failCount = 0
+	t.failedUntil = 0
+	t.failMu.Unlock()
 }
 
 // IsBackup 返回目标是否为备份服务器。
@@ -346,7 +342,7 @@ func (t *Target) IsBackup() bool {
 // 返回的切片容量与输入相同,避免多次内存分配。
 func filterHealthy(targets []*Target) []*Target {
 	available := make([]*Target, 0, len(targets))
-	var backups []*Target
+	backups := make([]*Target, 0, len(targets))
 
 	for _, t := range targets {
 		if !t.IsAvailable() {
@@ -393,7 +389,7 @@ func filterHealthyAndExclude(targets []*Target, excluded []*Target) []*Target {
 	}
 
 	available := make([]*Target, 0, len(targets))
-	var backups []*Target
+	backups := make([]*Target, 0, len(targets))
 
 	for _, t := range targets {
 		if !t.IsAvailable() || excludeSet[t.URL] {
diff --git a/internal/loadbalance/balancer_test.go b/internal/loadbalance/balancer_test.go
index d82d8cc..807e0d7 100644
--- a/internal/loadbalance/balancer_test.go
+++ b/internal/loadbalance/balancer_test.go
@@ -931,6 +931,7 @@ func TestIsValidAlgorithm(t *testing.T) {
 		{"least_conn", "least_conn", true},
 		{"ip_hash", "ip_hash", true},
 		{"consistent_hash", "consistent_hash", true},
+		{"random", "random", true},
 		{"invalid", "invalid", false},
 		{"empty", "", true}, // 空字符串有效(使用默认值)
 		{"unknown", "unknown-algorithm", false},
@@ -1820,9 +1821,11 @@ func TestTargetRecordSuccess(t *testing.T) {
 	target.RecordFailure()
 	target.RecordFailure()
 	target.RecordSuccess()
-	if target.failCount.Load() != 0 {
+	target.failMu.Lock()
+	if target.failCount != 0 {
 		t.Error("fail count should be reset after success")
 	}
+	target.failMu.Unlock()
 	if !target.IsAvailable() {
 		t.Error("target should be available after success resets cooldown")
 	}