Compare commits

..

10 Commits

Author SHA1 Message Date
xfy
afbbc3a951 chore: release v0.4.1 2026-06-10 13:54:26 +08:00
xfy
66ea93e3c1 fix(stream): reset stopCh after Stop for restartability 2026-06-10 13:48:07 +08:00
xfy
7204432ca0 fix(stream): correct upstream selection and add graceful shutdown
- Fix handleConnection to use addr parameter for direct upstream map
  lookup instead of always selecting the first upstream
- Add Server.Stop() for graceful shutdown with listener closing, UDP
  server cleanup, health checker termination, and goroutine joining
- Add shutdownStream() to App and call it in SIGTERM/SIGQUIT/SIGUSR2
  signal handlers to prevent goroutine and port leaks on shutdown
2026-06-10 13:45:35 +08:00
xfy
f12ffd180f chore: release v0.4.0
- Update CHANGELOG.md for v0.4.0
- Update Makefile FALLBACK_VERSION to 0.4.0
- Fix lint warnings (godoc comments, goconst)
- Clean up code formatting
2026-06-09 15:59:36 +08:00
xfy
503daf65d3 perf(loadbalance): add benchmarks for Least Time and Sticky
- Benchmark Select and Record operations
- Concurrent benchmark for realistic load testing
- Baseline performance:
  - LeastTime.Select: ~33ns/op, 0 allocs
  - LeastTime.Record: ~5.6ns/op, 0 allocs
  - StickySession.Select: ~205ns/op (with cookie lookup)
2026-06-08 18:21:03 +08:00
xfy
ef871f1d39 test(loadbalance): add integration tests for Least Time and Sticky
- Verify Least Time picks faster target consistently
- Verify Sticky fallback when target becomes unhealthy
- Test cookie encoding and session persistence
2026-06-08 18:19:20 +08:00
xfy
e5885ce888 fix(proxy): correct response time recording for Least Time
- Record headerTime when header is received
- Record lastByteTime when response is complete
- Use correct timing calculations (headerReceived/connectEnd/responseEnd)
2026-06-08 18:17:08 +08:00
xfy
72f189bba8 feat(proxy): integrate Least Time and Sticky balancers
- Add least_time and sticky to createBalancerByName
- Implement response time recording for Least Time
- Support StickySession in target selector with request context
- StickySession auto-starts when created
2026-06-08 18:11:47 +08:00
xfy
3b6b70a491 fix(config): validate least_time default_time is not negative 2026-06-08 18:03:52 +08:00
xfy
cb1f86298e fix: add missing test coverage for Task 4 config integration
- Add validation tests for least_time and sticky configs
- Add algorithm tests for least_time and sticky
- Add SameSite validation in validateProxy
2026-06-08 18:01:21 +08:00
19 changed files with 636 additions and 69 deletions

View File

@ -7,6 +7,39 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
## [Unreleased] ## [Unreleased]
## [0.4.1] - 2026-06-10
### Fixed
- **stream**: 修复上游选择逻辑,添加优雅关闭支持
- **stream**: 修复 Stop 后 stopCh 未重置导致的重启问题
## [0.4.0] - 2026-06-09
### Added
#### 负载均衡
- **Least Time 负载均衡器**:基于 EWMA指数加权移动平均统计的最少响应时间算法自动选择响应最快上游
- **Session Sticky 负载均衡器**:基于 Cookie 的会话粘性,一致性哈希分片,支持 Cookie 过期、域名、路径、Secure/HttpOnly/SameSite 属性
- 对应 YAML 配置支持:`least_time``sticky` 策略及参数
#### 平台与构建
- FreeBSD 部署示例
### Fixed
- Least Time 响应时间记录修正
- Sticky Cookie 格式、分片键、过期检查修复
- Sticky 双重 Stop 防护和重启支持
- 配置验证:`least_time``default_time` 不允许负值
### Tests
- Least Time 和 Sticky 负载均衡器集成测试
- Least Time 和 Sticky 基准测试
## [0.3.0] - 2026-06-05 ## [0.3.0] - 2026-06-05
### Added ### Added

View File

@ -1,7 +1,7 @@
# Makefile - Lolly Build Commands # Makefile - Lolly Build Commands
APP_NAME := lolly APP_NAME := lolly
FALLBACK_VERSION := 0.3.0 FALLBACK_VERSION := 0.4.1
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null | sed 's/^v//' || echo "$(FALLBACK_VERSION)") VERSION := $(shell git describe --tags --always --dirty 2>/dev/null | sed 's/^v//' || echo "$(FALLBACK_VERSION)")
GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown") GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")

View File

@ -123,6 +123,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
timeout = 30 * time.Second timeout = 30 * time.Second
} }
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout)) a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
a.shutdownStream()
a.shutdownHTTP2() a.shutdownHTTP2()
a.shutdownHTTP3() a.shutdownHTTP3()
_ = a.srv.GracefulStop(timeout) _ = a.srv.GracefulStop(timeout)
@ -139,6 +140,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
} else { } else {
a.logger.LogSignal(sigName(sigTyped), "Stopping server") a.logger.LogSignal(sigName(sigTyped), "Stopping server")
} }
a.shutdownStream()
a.shutdownHTTP2() a.shutdownHTTP2()
a.shutdownHTTP3() a.shutdownHTTP3()
_ = a.srv.StopWithTimeout(timeout) _ = a.srv.StopWithTimeout(timeout)
@ -329,6 +331,7 @@ func (a *App) gracefulUpgrade() {
if timeout <= 0 { if timeout <= 0 {
timeout = 30 * time.Second timeout = 30 * time.Second
} }
a.shutdownStream()
a.shutdownHTTP2() a.shutdownHTTP2()
a.shutdownHTTP3() a.shutdownHTTP3()
_ = a.srv.GracefulStop(timeout) _ = a.srv.GracefulStop(timeout)

View File

@ -252,6 +252,12 @@ func (a *App) shutdownHTTP2() {
} }
// reopenLogs reinitializes the logger from current config. // reopenLogs reinitializes the logger from current config.
func (a *App) shutdownStream() {
if a.streamSrv != nil {
a.streamSrv.Stop()
}
}
func (a *App) reopenLogs() { func (a *App) reopenLogs() {
if a.cfg != nil { if a.cfg != nil {
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format) logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)

View File

@ -19,7 +19,7 @@ import (
// 注意事项: // 注意事项:
// - Path 使用前缀匹配,较长路径优先匹配 // - Path 使用前缀匹配,较长路径优先匹配
// - 至少配置一个 Target 才能正常工作 // - 至少配置一个 Target 才能正常工作
// - 负载均衡算法支持round_robin、weighted_round_robin、least_conn、ip_hash、consistent_hash、random、least_time、sticky // - 负载均衡算法支持round_robin、weighted_round_robin、least_conn、ip_hash、consistent_hash、random、least_time、sticky
// - 一致性哈希需要配置 HashKey // - 一致性哈希需要配置 HashKey
// //
// 使用示例: // 使用示例:

View File

@ -509,6 +509,9 @@ func validateProxy(p *ProxyConfig) error {
if p.LeastTime.Metric != "" && p.LeastTime.Metric != "header" && p.LeastTime.Metric != "last_byte" { if p.LeastTime.Metric != "" && p.LeastTime.Metric != "header" && p.LeastTime.Metric != "last_byte" {
return fmt.Errorf("无效的 least_time metric: %s有效值: header, last_byte", p.LeastTime.Metric) return fmt.Errorf("无效的 least_time metric: %s有效值: header, last_byte", p.LeastTime.Metric)
} }
if p.LeastTime.DefaultTime < 0 {
return fmt.Errorf("least_time default_time 不能为负数")
}
} }
// validate sticky config // validate sticky config
@ -519,6 +522,12 @@ func validateProxy(p *ProxyConfig) error {
if p.Sticky.FallbackAlgo != "" && !loadbalance.IsValidAlgorithm(p.Sticky.FallbackAlgo) { if p.Sticky.FallbackAlgo != "" && !loadbalance.IsValidAlgorithm(p.Sticky.FallbackAlgo) {
return fmt.Errorf("无效的 sticky fallback_balance: %s", p.Sticky.FallbackAlgo) return fmt.Errorf("无效的 sticky fallback_balance: %s", p.Sticky.FallbackAlgo)
} }
if p.Sticky.SameSite != "" {
validSameSites := []string{"Lax", "Strict", "None"}
if !slices.Contains(validSameSites, p.Sticky.SameSite) {
return fmt.Errorf("无效的 sticky same_site: %s有效值: Lax, Strict, None", p.Sticky.SameSite)
}
}
} }
// 验证故障转移配置 // 验证故障转移配置

View File

@ -170,6 +170,80 @@ func TestValidateProxy(t *testing.T) {
wantErr: true, wantErr: true,
errMsg: "无效的负载均衡算法", errMsg: "无效的负载均衡算法",
}, },
{
name: "有效 least_time 配置 metric=header",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "least_time",
LeastTime: LeastTimeConfig{Metric: "header"},
},
wantErr: false,
},
{
name: "有效 least_time 配置 metric=last_byte",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "least_time",
LeastTime: LeastTimeConfig{Metric: "last_byte"},
},
wantErr: false,
},
{
name: "无效 least_time metric",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "least_time",
LeastTime: LeastTimeConfig{Metric: "invalid"},
},
wantErr: true,
errMsg: "无效的 least_time metric",
},
{
name: "有效 sticky 配置",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "sticky",
Sticky: StickyConfig{Enabled: true, FallbackAlgo: "round_robin"},
},
wantErr: false,
},
{
name: "无效 sticky enabled=false",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "sticky",
Sticky: StickyConfig{Enabled: false},
},
wantErr: true,
errMsg: "sticky.enabled 必须为 true",
},
{
name: "无效 sticky fallback_balance",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "sticky",
Sticky: StickyConfig{Enabled: true, FallbackAlgo: "invalid"},
},
wantErr: true,
errMsg: "无效的 sticky fallback_balance",
},
{
name: "无效 sticky same_site",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "sticky",
Sticky: StickyConfig{Enabled: true, SameSite: "Invalid"},
},
wantErr: true,
errMsg: "无效的 sticky same_site",
},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -11,10 +11,11 @@ import (
) )
const ( const (
gzipType = "gzip" gzipType = "gzip"
offValue = "off" offValue = "off"
redirectType = "redirect" redirectType = "redirect"
staticType = "static" staticType = "static"
returnDirective = "return"
) )
// Warning represents a conversion warning for unsupported or partially supported directives. // Warning represents a conversion warning for unsupported or partially supported directives.
@ -70,7 +71,7 @@ var unsupportedDirectives = map[string]string{
"split_clients": "the 'split_clients' directive is not supported", "split_clients": "the 'split_clients' directive is not supported",
"geo": "the 'geo' directive is not supported; use access.geoip config instead", "geo": "the 'geo' directive is not supported; use access.geoip config instead",
"range": "the 'range' directive is not supported", "range": "the 'range' directive is not supported",
"return": "the 'return' directive is not supported for non-redirect status codes; only 301/302 are supported", returnDirective: "the 'return' directive is not supported for non-redirect status codes; only 301/302 are supported",
} }
// Convert converts a parsed nginx configuration to a lolly configuration. // Convert converts a parsed nginx configuration to a lolly configuration.
@ -273,7 +274,7 @@ func convertServerBlock(d *Directive, upstreams map[string]*upstreamInfo, result
parseAccessLog(bd, result) parseAccessLog(bd, result)
case "error_log": case "error_log":
parseErrorLog(bd, result) parseErrorLog(bd, result)
case "return": case returnDirective:
parseServerReturn(bd, &baseServer, result) parseServerReturn(bd, &baseServer, result)
case "rewrite": case "rewrite":
parseRewrite(bd, &baseServer) parseRewrite(bd, &baseServer)
@ -457,7 +458,7 @@ func parseServerReturn(d *Directive, server *config.ServerConfig, result *Conver
}) })
default: default:
result.Warnings = append(result.Warnings, Warning{ result.Warnings = append(result.Warnings, Warning{
Directive: "return", Directive: returnDirective,
Line: d.Line, Line: d.Line,
File: d.File, File: d.File,
Message: fmt.Sprintf("return %d is not a redirect; only 301/302 are supported at server level", code), Message: fmt.Sprintf("return %d is not a redirect; only 301/302 are supported at server level", code),
@ -554,7 +555,7 @@ func classifyLocation(d *Directive, serverRoot string, result *ConvertResult) lo
hasRootOrAlias = true hasRootOrAlias = true
case "try_files": case "try_files":
hasTryFiles = true hasTryFiles = true
case "return": case returnDirective:
if len(d.Block[i].Args) > 0 { if len(d.Block[i].Args) > 0 {
code, err := strconv.Atoi(d.Block[i].Args[0]) code, err := strconv.Atoi(d.Block[i].Args[0])
if err == nil && (code == 301 || code == 302) { if err == nil && (code == 301 || code == 302) {
@ -876,7 +877,7 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
for i := range directives { for i := range directives {
d := &directives[i] d := &directives[i]
if d.Name != "return" { if d.Name != returnDirective {
continue continue
} }
@ -900,7 +901,7 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
Flag: "permanent", Flag: "permanent",
}) })
result.Warnings = append(result.Warnings, Warning{ result.Warnings = append(result.Warnings, Warning{
Directive: "return", Directive: returnDirective,
Line: d.Line, Line: d.Line,
File: d.File, File: d.File,
Message: "return 301 converted to rewrite rule with permanent flag", Message: "return 301 converted to rewrite rule with permanent flag",
@ -912,14 +913,14 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
Flag: "redirect", Flag: "redirect",
}) })
result.Warnings = append(result.Warnings, Warning{ result.Warnings = append(result.Warnings, Warning{
Directive: "return", Directive: returnDirective,
Line: d.Line, Line: d.Line,
File: d.File, File: d.File,
Message: "return 302 converted to rewrite rule with redirect flag", Message: "return 302 converted to rewrite rule with redirect flag",
}) })
default: default:
result.Warnings = append(result.Warnings, Warning{ result.Warnings = append(result.Warnings, Warning{
Directive: "return", Directive: returnDirective,
Line: d.Line, Line: d.Line,
File: d.File, File: d.File,
Message: fmt.Sprintf("return %d in location is not a redirect; only 301/302 are supported", code), Message: fmt.Sprintf("return %d in location is not a redirect; only 301/302 are supported", code),

View File

@ -71,7 +71,6 @@ func TestProxyRequestHeaders(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证代理配置已设置 // 验证代理配置已设置
assert.NotNil(t, cfg.Headers.SetRequest) assert.NotNil(t, cfg.Headers.SetRequest)
assert.Equal(t, "custom-value", cfg.Headers.SetRequest["X-Custom-Header"]) assert.Equal(t, "custom-value", cfg.Headers.SetRequest["X-Custom-Header"])
@ -101,7 +100,6 @@ func TestProxyResponseHeaders(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证响应头配置 // 验证响应头配置
assert.Equal(t, "lolly", cfg.Headers.SetResponse["X-Server"]) assert.Equal(t, "lolly", cfg.Headers.SetResponse["X-Server"])
assert.Contains(t, cfg.Headers.Remove, "X-Powered-By") assert.Contains(t, cfg.Headers.Remove, "X-Powered-By")
@ -125,7 +123,6 @@ func TestProxyTimeout(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证超时配置 // 验证超时配置
assert.Equal(t, 1*time.Second, cfg.Timeout.Connect) assert.Equal(t, 1*time.Second, cfg.Timeout.Connect)
assert.Equal(t, 50*time.Millisecond, cfg.Timeout.Read) assert.Equal(t, 50*time.Millisecond, cfg.Timeout.Read)
@ -150,7 +147,6 @@ func TestProxyLoadBalanceRoundRobin(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证负载均衡器类型 // 验证负载均衡器类型
assert.Equal(t, "round_robin", cfg.LoadBalance) assert.Equal(t, "round_robin", cfg.LoadBalance)
assert.Len(t, targets, 2) assert.Len(t, targets, 2)
@ -174,7 +170,6 @@ func TestProxyWeightedRoundRobin(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证权重配置 // 验证权重配置
assert.Equal(t, 3, targets[0].Weight) assert.Equal(t, 3, targets[0].Weight)
assert.Equal(t, 1, targets[1].Weight) assert.Equal(t, 1, targets[1].Weight)
@ -198,7 +193,6 @@ func TestProxyLeastConn(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "least_conn", cfg.LoadBalance) assert.Equal(t, "least_conn", cfg.LoadBalance)
} }
@ -220,7 +214,6 @@ func TestProxyIPHash(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "ip_hash", cfg.LoadBalance) assert.Equal(t, "ip_hash", cfg.LoadBalance)
} }
@ -240,7 +233,6 @@ func TestProxyConsistentHash(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "consistent_hash", cfg.LoadBalance) assert.Equal(t, "consistent_hash", cfg.LoadBalance)
assert.Equal(t, "uri", cfg.HashKey) assert.Equal(t, "uri", cfg.HashKey)
assert.Equal(t, 150, cfg.VirtualNodes) assert.Equal(t, 150, cfg.VirtualNodes)
@ -268,7 +260,6 @@ func TestProxyErrorHandling(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证 MaxFails 配置 (int64 类型) // 验证 MaxFails 配置 (int64 类型)
assert.Equal(t, int64(3), targets[0].MaxFails) assert.Equal(t, int64(3), targets[0].MaxFails)
assert.Equal(t, 10*time.Second, targets[0].FailTimeout) assert.Equal(t, 10*time.Second, targets[0].FailTimeout)
@ -300,7 +291,6 @@ func TestProxyCacheConfig(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证缓存配置 // 验证缓存配置
assert.True(t, cfg.Cache.Enabled) assert.True(t, cfg.Cache.Enabled)
assert.Equal(t, 60*time.Second, cfg.Cache.MaxAge) assert.Equal(t, 60*time.Second, cfg.Cache.MaxAge)
@ -330,7 +320,6 @@ func TestProxyNextUpstream(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证故障转移配置 // 验证故障转移配置
assert.Equal(t, 3, cfg.NextUpstream.Tries) assert.Equal(t, 3, cfg.NextUpstream.Tries)
assert.Contains(t, cfg.NextUpstream.HTTPCodes, 502) assert.Contains(t, cfg.NextUpstream.HTTPCodes, 502)
@ -360,7 +349,6 @@ func TestProxyHealthCheck(t *testing.T) {
_, err := proxy.NewProxy(cfg, targets, nil, nil) _, err := proxy.NewProxy(cfg, targets, nil, nil)
require.NoError(t, err) require.NoError(t, err)
// 验证健康检查配置 // 验证健康检查配置
assert.Equal(t, 10*time.Second, cfg.HealthCheck.Interval) assert.Equal(t, 10*time.Second, cfg.HealthCheck.Interval)
assert.Equal(t, "/health", cfg.HealthCheck.Path) assert.Equal(t, "/health", cfg.HealthCheck.Path)

View File

@ -12,9 +12,12 @@ package loadbalance
import ( import (
"net" "net"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/valyala/fasthttp"
) )
// createHealthyTarget 创建并返回一个指定 URL 和健康状态的 Target 实例。 // createHealthyTarget 创建并返回一个指定 URL 和健康状态的 Target 实例。
@ -940,6 +943,8 @@ func TestIsValidAlgorithm(t *testing.T) {
{"ip_hash", "ip_hash", true}, {"ip_hash", "ip_hash", true},
{"consistent_hash", "consistent_hash", true}, {"consistent_hash", "consistent_hash", true},
{"random", "random", true}, {"random", "random", true},
{"least_time", "least_time", true},
{"sticky", "sticky", true},
{"invalid", "invalid", false}, {"invalid", "invalid", false},
{"empty", "", true}, // 空字符串有效(使用默认值) {"empty", "", true}, // 空字符串有效(使用默认值)
{"unknown", "unknown-algorithm", false}, {"unknown", "unknown-algorithm", false},
@ -2091,3 +2096,94 @@ func TestRandomBalancer(t *testing.T) {
} }
}) })
} }
// TestBalancerIntegration_LeastTime 验证 Least Time 算法 consistently 选择更快目标。
func TestBalancerIntegration_LeastTime(t *testing.T) {
targets := []*Target{
NewTargetFromConfig("http://slow:8080", 1, 0, 0, 0, false, false, ""),
NewTargetFromConfig("http://fast:8080", 1, 0, 0, 0, false, false, ""),
}
lt := NewLeastTime("last_byte", time.Millisecond)
// 模拟slow 目标有 100ms avgfast 有 10ms avg
for i := 0; i < 10; i++ {
targets[0].Stats.Record(50*time.Millisecond, 100*time.Millisecond)
targets[1].Stats.Record(5*time.Millisecond, 10*time.Millisecond)
}
// 选择 100 次,应该 mostly 选 fast
fastCount := 0
for i := 0; i < 100; i++ {
selected := lt.Select(targets)
if selected.URL == "http://fast:8080" {
fastCount++
}
}
if fastCount < 80 {
t.Errorf("fast target selected %d/100 times, expected >80", fastCount)
}
}
// TestBalancerIntegration_StickyWithLeastTimeFallback 验证 Sticky 在目标不健康时 fallback。
func TestBalancerIntegration_StickyWithLeastTimeFallback(t *testing.T) {
fallback := NewLeastTime("last_byte", time.Millisecond)
config := StickyConfig{
Enabled: true,
Name: "test_route",
Expires: time.Hour,
Path: "/",
HttpOnly: true,
}
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
}
ctx := &fasthttp.RequestCtx{}
// 首次请求
selected1 := sticky.Select(ctx, targets)
if selected1 == nil {
t.Fatal("expected a target")
}
// 验证 cookie 已设置
cookie := ctx.Response.Header.PeekCookie("test_route")
if len(cookie) == 0 {
t.Fatal("expected cookie")
}
// 使 selected1 不健康
selected1.Healthy.Store(false)
// 第二次请求带 cookie 应该 fallback
ctx2 := &fasthttp.RequestCtx{}
ctx2.Request.Header.SetCookie("test_route", string(extractCookieValue(cookie)))
selected2 := sticky.Select(ctx2, targets)
if selected2 == nil {
t.Fatal("expected fallback target")
}
if selected2.URL == selected1.URL {
t.Error("expected different target after fallback")
}
}
// extractCookieValue 从 Set-Cookie header 中提取 cookie 值
func extractCookieValue(cookieHeader []byte) []byte {
s := string(cookieHeader)
// Format: "name=value; ..."
parts := strings.SplitN(s, "=", 2)
if len(parts) != 2 {
return nil
}
valueParts := strings.SplitN(parts[1], ";", 2)
return []byte(valueParts[0])
}

View File

@ -14,10 +14,12 @@ type EWMAStats struct {
const defaultAlphaScale = 300 // alpha = 0.3 const defaultAlphaScale = 300 // alpha = 0.3
// NewEWMAStats 创建新的 EWMA 统计器。
func NewEWMAStats() *EWMAStats { func NewEWMAStats() *EWMAStats {
return &EWMAStats{} return &EWMAStats{}
} }
// Record 记录一次响应时间样本。
func (e *EWMAStats) Record(headerTime, lastByteTime time.Duration) { func (e *EWMAStats) Record(headerTime, lastByteTime time.Duration) {
e.recordAtomic(&e.headerTime, headerTime) e.recordAtomic(&e.headerTime, headerTime)
e.recordAtomic(&e.lastByteTime, lastByteTime) e.recordAtomic(&e.lastByteTime, lastByteTime)
@ -41,18 +43,22 @@ func (e *EWMAStats) recordAtomic(ptr *atomic.Int64, newValue time.Duration) {
} }
} }
// HeaderTime 返回首字节时间的 EWMA 值。
func (e *EWMAStats) HeaderTime() time.Duration { func (e *EWMAStats) HeaderTime() time.Duration {
return time.Duration(e.headerTime.Load()) return time.Duration(e.headerTime.Load())
} }
// LastByteTime 返回完整响应时间的 EWMA 值。
func (e *EWMAStats) LastByteTime() time.Duration { func (e *EWMAStats) LastByteTime() time.Duration {
return time.Duration(e.lastByteTime.Load()) return time.Duration(e.lastByteTime.Load())
} }
// SampleCount 返回已记录的样本数量。
func (e *EWMAStats) SampleCount() int64 { func (e *EWMAStats) SampleCount() int64 {
return e.sampleCount.Load() return e.sampleCount.Load()
} }
// Reset 重置所有统计数据。
func (e *EWMAStats) Reset() { func (e *EWMAStats) Reset() {
e.headerTime.Store(0) e.headerTime.Store(0)
e.lastByteTime.Store(0) e.lastByteTime.Store(0)

View File

@ -97,5 +97,7 @@ func (l *LeastTime) GetMetric() string {
return l.metric return l.metric
} }
var _ Balancer = (*LeastTime)(nil) var (
var _ ResponseTimeRecorder = (*LeastTime)(nil) _ Balancer = (*LeastTime)(nil)
_ ResponseTimeRecorder = (*LeastTime)(nil)
)

View File

@ -0,0 +1,96 @@
package loadbalance
import (
"testing"
"time"
"github.com/valyala/fasthttp"
)
func BenchmarkLeastTime_Select(b *testing.B) {
lt := NewLeastTime("last_byte", time.Millisecond)
targets := []*Target{
NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""),
NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""),
NewTargetFromConfig("http://c:8080", 1, 0, 0, 0, false, false, ""),
}
// Pre-populate stats
for _, t := range targets {
t.Stats.Record(10*time.Millisecond, 20*time.Millisecond)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lt.Select(targets)
}
}
func BenchmarkLeastTime_Record(b *testing.B) {
stats := NewEWMAStats()
b.ResetTimer()
for i := 0; i < b.N; i++ {
stats.Record(10*time.Millisecond, 20*time.Millisecond)
}
}
func BenchmarkLeastTime_Concurrent(b *testing.B) {
lt := NewLeastTime("last_byte", time.Millisecond)
targets := []*Target{
NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""),
NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""),
}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
lt.Select(targets)
}
})
}
func BenchmarkStickySession_Select(b *testing.B) {
fallback := NewRoundRobin()
config := DefaultStickyConfig()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
}
// Pre-populate a cookie
ctx := &fasthttp.RequestCtx{}
sticky.Select(ctx, targets)
cookie := ctx.Response.Header.PeekCookie(config.Name)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetCookie(config.Name, string(extractCookieValue(cookie)))
sticky.Select(ctx, targets)
}
}
func BenchmarkStickySession_SelectNew(b *testing.B) {
fallback := NewRoundRobin()
config := DefaultStickyConfig()
sticky := NewStickySession(config, fallback)
sticky.Start()
defer sticky.Stop()
targets := []*Target{
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ctx := &fasthttp.RequestCtx{}
sticky.Select(ctx, targets)
}
}

View File

@ -2,6 +2,7 @@ package loadbalance
import "time" import "time"
// StickyConfig 配置 Sticky 负载均衡的 Cookie 参数。
type StickyConfig struct { type StickyConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
Name string `yaml:"name"` Name string `yaml:"name"`
@ -13,6 +14,7 @@ type StickyConfig struct {
SameSite string `yaml:"same_site"` SameSite string `yaml:"same_site"`
} }
// DefaultStickyConfig 返回 Sticky 负载均衡的默认配置。
func DefaultStickyConfig() StickyConfig { func DefaultStickyConfig() StickyConfig {
return StickyConfig{ return StickyConfig{
Name: "lolly_route", Name: "lolly_route",

View File

@ -90,6 +90,8 @@ const (
lbIPHash = "ip_hash" // IP 哈希 lbIPHash = "ip_hash" // IP 哈希
lbConsistentHash = "consistent_hash" // 一致性哈希 lbConsistentHash = "consistent_hash" // 一致性哈希
lbRandom = "random" // 随机Power of Two Choices lbRandom = "random" // 随机Power of Two Choices
lbLeastTime = "least_time" // 最小响应时间
lbSticky = "sticky" // 会话粘性
) )
// headersPool 复用缓存 headers map减少分配。 // headersPool 复用缓存 headers map减少分配。
@ -229,6 +231,22 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
return p, nil return p, nil
} }
// stickyBalancer wraps StickySession to implement loadbalance.Balancer.
// It delegates Select/SelectExcluding to the fallback balancer while
// allowing the proxy to access the StickySession for cookie-based routing.
type stickyBalancer struct {
sticky *loadbalance.StickySession
fallback loadbalance.Balancer
}
func (b *stickyBalancer) Select(targets []*loadbalance.Target) *loadbalance.Target {
return b.fallback.Select(targets)
}
func (b *stickyBalancer) SelectExcluding(targets []*loadbalance.Target, excluded []*loadbalance.Target) *loadbalance.Target {
return b.fallback.SelectExcluding(targets, excluded)
}
// createBalancerByName 根据算法名称创建负载均衡器。 // createBalancerByName 根据算法名称创建负载均衡器。
// //
// 支持的算法: // 支持的算法:
@ -263,6 +281,49 @@ func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Bal
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
case lbRandom: case lbRandom:
return loadbalance.NewRandom(), nil return loadbalance.NewRandom(), nil
case lbLeastTime:
metric := cfg.LeastTime.Metric
if metric == "" {
metric = "last_byte"
}
defaultTime := cfg.LeastTime.DefaultTime
if defaultTime <= 0 {
defaultTime = time.Millisecond
}
return loadbalance.NewLeastTime(metric, defaultTime), nil
case lbSticky:
stickyCfg := loadbalance.StickyConfig{
Enabled: cfg.Sticky.Enabled,
Name: cfg.Sticky.Name,
Expires: cfg.Sticky.Expires,
Domain: cfg.Sticky.Domain,
Path: cfg.Sticky.Path,
Secure: cfg.Sticky.Secure,
HttpOnly: cfg.Sticky.HttpOnly,
SameSite: cfg.Sticky.SameSite,
}
if stickyCfg.Name == "" {
stickyCfg.Name = "lolly_route"
}
if stickyCfg.Expires <= 0 {
stickyCfg.Expires = time.Hour
}
if stickyCfg.Path == "" {
stickyCfg.Path = "/"
}
fallbackAlgo := cfg.Sticky.FallbackAlgo
if fallbackAlgo == "" {
fallbackAlgo = lbRoundRobin
}
fallbackBalancer, err := createBalancerByName(fallbackAlgo, cfg)
if err != nil {
return nil, fmt.Errorf("sticky fallback balancer: %w", err)
}
sticky := loadbalance.NewStickySession(stickyCfg, fallbackBalancer)
sticky.Start()
return &stickyBalancer{sticky: sticky, fallback: fallbackBalancer}, nil
default: default:
return nil, errors.New("unsupported load balance algorithm: " + name) return nil, errors.New("unsupported load balance algorithm: " + name)
} }
@ -813,6 +874,12 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 记录首字节时间 // 记录首字节时间
timing.MarkHeaderReceived() timing.MarkHeaderReceived()
// 记录首字节响应时间(用于 least_time 负载均衡)
if recorder, ok := p.balancer.(loadbalance.ResponseTimeRecorder); ok {
headerTime := timing.headerReceived.Sub(timing.connectEnd)
recorder.RecordResponseTime(target, headerTime, 0)
}
// 请求成功,减少连接计数 // 请求成功,减少连接计数
loadbalance.DecrementConnections(target) loadbalance.DecrementConnections(target)
@ -917,6 +984,15 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 修改响应头 // 修改响应头
p.modifyResponseHeaders(ctx) p.modifyResponseHeaders(ctx)
// 记录完整响应时间(用于 least_time 负载均衡)
timing.MarkResponseEnd()
if recorder, ok := p.balancer.(loadbalance.ResponseTimeRecorder); ok {
headerTime := timing.headerReceived.Sub(timing.connectEnd)
lastByteTime := timing.responseEnd.Sub(timing.connectEnd)
recorder.RecordResponseTime(target, headerTime, lastByteTime)
}
return return
} }

View File

@ -119,6 +119,11 @@ func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalanc
balancer := p.balancer balancer := p.balancer
p.mu.RUnlock() p.mu.RUnlock()
// 对于 StickySession 负载均衡器,需要请求上下文
if sb, ok := balancer.(*stickyBalancer); ok {
return sb.sticky.Select(ctx, targets)
}
// 对于 IPHash 负载均衡器,提取客户端 IP // 对于 IPHash 负载均衡器,提取客户端 IP
if ipHash, ok := balancer.(*loadbalance.IPHash); ok { if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx) clientIP := netutil.ExtractClientIP(ctx)

View File

@ -12,6 +12,7 @@
package stream package stream
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -63,7 +64,7 @@ func TestStart_NoListeners(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, s.running.Load()) assert.True(t, s.running.Load())
s.running.Store(false) s.Stop()
} }
func TestStart_WithTCPListeners(t *testing.T) { func TestStart_WithTCPListeners(t *testing.T) {
@ -81,12 +82,7 @@ func TestStart_WithTCPListeners(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, s.running.Load()) assert.True(t, s.running.Load())
s.running.Store(false) s.Stop()
s.mu.RLock()
for _, ln := range s.listeners {
_ = ln.Close()
}
s.mu.RUnlock()
} }
func TestStart_AcceptConnections(t *testing.T) { func TestStart_AcceptConnections(t *testing.T) {
@ -115,7 +111,7 @@ func TestStart_AcceptConnections(t *testing.T) {
proxyAddr := ln.Addr().String() proxyAddr := ln.Addr().String()
s.mu.Lock() s.mu.Lock()
s.listeners[proxyAddr] = ln s.listeners["test"] = ln
s.mu.Unlock() s.mu.Unlock()
err = s.Start() err = s.Start()
@ -136,12 +132,7 @@ func TestStart_AcceptConnections(t *testing.T) {
_ = clientConn.Close() _ = clientConn.Close()
s.running.Store(false) s.Stop()
s.mu.RLock()
for _, l := range s.listeners {
_ = l.Close()
}
s.mu.RUnlock()
_ = backendLn.Close() _ = backendLn.Close()
} }
@ -637,3 +628,138 @@ func TestStartCleanupTicker_StopsOnSignal(t *testing.T) {
t.Fatal("startCleanupTicker did not stop after signal") t.Fatal("startCleanupTicker did not stop after signal")
} }
} }
func TestHandleConnection_MultipleUpstreams(t *testing.T) {
s := NewServer()
backend1, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
backend2, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer backend1.Close()
defer backend2.Close()
go func() {
conn, err := backend1.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(append([]byte("backend1:"), buf[:n]...))
}()
go func() {
conn, err := backend2.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(append([]byte("backend2:"), buf[:n]...))
}()
_ = s.AddUpstream("upstream1", []TargetSpec{{Addr: backend1.Addr().String()}}, "round_robin", HealthCheckSpec{})
_ = s.AddUpstream("upstream2", []TargetSpec{{Addr: backend2.Addr().String()}}, "round_robin", HealthCheckSpec{})
s.upstreams["upstream1"].targets[0].healthy.Store(true)
s.upstreams["upstream2"].targets[0].healthy.Store(true)
ln1, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ln2, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s.mu.Lock()
s.listeners["upstream1"] = ln1
s.listeners["upstream2"] = ln2
s.mu.Unlock()
err = s.Start()
require.NoError(t, err)
defer s.Stop()
conn1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn1.Close()
_, err = conn1.Write([]byte("hello"))
require.NoError(t, err)
buf := make([]byte, 1024)
_ = conn1.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := conn1.Read(buf)
require.NoError(t, err)
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend1:")), "should route to backend1, got: %s", string(buf[:n]))
conn2, err := net.DialTimeout("tcp", ln2.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn2.Close()
_, err = conn2.Write([]byte("world"))
require.NoError(t, err)
_ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err = conn2.Read(buf)
require.NoError(t, err)
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend2:")), "should route to backend2, got: %s", string(buf[:n]))
}
func TestStop(t *testing.T) {
s := NewServer()
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer backendLn.Close()
go func() {
for {
conn, err := backendLn.Accept()
if err != nil {
return
}
_, _ = io.Copy(conn, conn)
_ = conn.Close()
}
}()
_ = s.AddUpstream("test", []TargetSpec{{Addr: backendLn.Addr().String()}}, "round_robin", HealthCheckSpec{})
s.upstreams["test"].targets[0].healthy.Store(true)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s.mu.Lock()
s.listeners["test"] = ln
s.mu.Unlock()
err = s.Start()
require.NoError(t, err)
clientConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
_ = clientConn.Close()
s.Stop()
assert.False(t, s.running.Load())
s.mu.RLock()
assert.Empty(t, s.listeners)
s.mu.RUnlock()
}
func TestStop_Idempotent(t *testing.T) {
s := NewServer()
s.Stop()
s.Stop()
}
func TestStart_DoubleStart(t *testing.T) {
s := NewServer()
err := s.Start()
require.NoError(t, err)
err = s.Start()
assert.Error(t, err)
s.Stop()
}

View File

@ -29,6 +29,7 @@
package stream package stream
import ( import (
"fmt"
"hash/fnv" "hash/fnv"
"io" "io"
"net" "net"
@ -286,18 +287,14 @@ func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target {
// Server TCP/UDP Stream 代理服务器。 // Server TCP/UDP Stream 代理服务器。
type Server struct { type Server struct {
// listeners TCP 监听器映射,按 upstream 名称索引 listeners map[string]net.Listener
listeners map[string]net.Listener
// udpServers UDP 服务器映射
udpServers map[string]*udpServer udpServers map[string]*udpServer
// upstreams 上游配置映射 upstreams map[string]*Upstream
upstreams map[string]*Upstream connCount atomic.Int64
// connCount 当前连接数 mu sync.RWMutex
connCount atomic.Int64 running atomic.Bool
// mu 读写锁,保护并发访问 wg sync.WaitGroup
mu sync.RWMutex stopCh chan struct{}
// running 运行状态标志
running atomic.Bool
} }
// Upstream Stream 上游配置。 // Upstream Stream 上游配置。
@ -393,6 +390,7 @@ func NewServer() *Server {
listeners: make(map[string]net.Listener), listeners: make(map[string]net.Listener),
udpServers: make(map[string]*udpServer), udpServers: make(map[string]*udpServer),
upstreams: make(map[string]*Upstream), upstreams: make(map[string]*Upstream),
stopCh: make(chan struct{}),
} }
} }
@ -491,25 +489,71 @@ func (s *Server) ListenUDP(addr string, upstreamName string, timeout time.Durati
// Start 启动 Stream 服务器。 // Start 启动 Stream 服务器。
func (s *Server) Start() error { func (s *Server) Start() error {
s.running.Store(true) if !s.running.CompareAndSwap(false, true) {
return fmt.Errorf("stream server already running")
}
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// 启动 TCP 监听器
for addr, listener := range s.listeners { for addr, listener := range s.listeners {
go s.acceptLoop(addr, listener) s.wg.Add(1)
go func(a string, ln net.Listener) {
defer s.wg.Done()
s.acceptLoop(a, ln)
}(addr, listener)
} }
// 启动 UDP 服务器
for _, udpSrv := range s.udpServers { for _, udpSrv := range s.udpServers {
go udpSrv.serve() s.wg.Add(1)
go udpSrv.startCleanupTicker() go func(u *udpServer) {
defer s.wg.Done()
u.serve()
}(udpSrv)
s.wg.Add(1)
go func(u *udpServer) {
defer s.wg.Done()
u.startCleanupTicker()
}(udpSrv)
} }
return nil return nil
} }
// Stop stops the stream server, closing all listeners and waiting for goroutines to finish.
func (s *Server) Stop() {
if !s.running.CompareAndSwap(true, false) {
return
}
close(s.stopCh)
s.mu.Lock()
for _, ln := range s.listeners {
_ = ln.Close()
}
for _, udpSrv := range s.udpServers {
close(udpSrv.stopCh)
if udpSrv.conn != nil {
_ = udpSrv.conn.Close()
}
}
for _, upstream := range s.upstreams {
if upstream.healthChk != nil && upstream.healthChk.stopCh != nil {
close(upstream.healthChk.stopCh)
}
}
s.mu.Unlock()
s.wg.Wait()
s.mu.Lock()
s.listeners = make(map[string]net.Listener)
s.udpServers = make(map[string]*udpServer)
s.stopCh = make(chan struct{})
s.mu.Unlock()
}
// acceptLoop 接受连接循环。 // acceptLoop 接受连接循环。
// //
// 在单独的 goroutine 中运行,持续接受 TCP 连接。 // 在单独的 goroutine 中运行,持续接受 TCP 连接。
@ -523,7 +567,12 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
if !s.running.Load() { if !s.running.Load() {
return // 正常关闭 return
}
select {
case <-s.stopCh:
return
default:
} }
continue continue
} }
@ -544,19 +593,14 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
// 参数: // 参数:
// - clientConn: 客户端连接 // - clientConn: 客户端连接
// - addr: 监听地址 // - addr: 监听地址
func (s *Server) handleConnection(clientConn net.Conn, _ string) { func (s *Server) handleConnection(clientConn net.Conn, addr string) {
defer func() { defer func() {
_ = clientConn.Close() _ = clientConn.Close()
s.connCount.Add(-1) s.connCount.Add(-1)
}() }()
s.mu.RLock() s.mu.RLock()
// 根据监听地址找到对应 upstream简化用第一个 upstream := s.upstreams[addr]
var upstream *Upstream
for _, up := range s.upstreams {
upstream = up
break
}
s.mu.RUnlock() s.mu.RUnlock()
if upstream == nil { if upstream == nil {

View File

@ -157,7 +157,7 @@ func TestHandleConnection_NoHealthyTarget(t *testing.T) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
s.handleConnection(clientConn, "127.0.0.1:0") s.handleConnection(clientConn, "test2")
close(done) close(done)
}() }()
@ -200,7 +200,7 @@ func TestHandleConnection_DialFail(t *testing.T) {
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
s.handleConnection(clientConn, "127.0.0.1:0") s.handleConnection(clientConn, "test3")
close(done) close(done)
}() }()