Compare commits
10 Commits
88a2c1fc1b
...
afbbc3a951
| Author | SHA1 | Date | |
|---|---|---|---|
| afbbc3a951 | |||
| 66ea93e3c1 | |||
| 7204432ca0 | |||
| f12ffd180f | |||
| 503daf65d3 | |||
| ef871f1d39 | |||
| e5885ce888 | |||
| 72f189bba8 | |||
| 3b6b70a491 | |||
| cb1f86298e |
33
CHANGELOG.md
33
CHANGELOG.md
@ -7,6 +7,39 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.4.1] - 2026-06-10
|
||||
|
||||
### Fixed
|
||||
|
||||
- **stream**: 修复上游选择逻辑,添加优雅关闭支持
|
||||
- **stream**: 修复 Stop 后 stopCh 未重置导致的重启问题
|
||||
|
||||
## [0.4.0] - 2026-06-09
|
||||
|
||||
### Added
|
||||
|
||||
#### 负载均衡
|
||||
|
||||
- **Least Time 负载均衡器**:基于 EWMA(指数加权移动平均)统计的最少响应时间算法,自动选择响应最快上游
|
||||
- **Session Sticky 负载均衡器**:基于 Cookie 的会话粘性,一致性哈希分片,支持 Cookie 过期、域名、路径、Secure/HttpOnly/SameSite 属性
|
||||
- 对应 YAML 配置支持:`least_time`、`sticky` 策略及参数
|
||||
|
||||
#### 平台与构建
|
||||
|
||||
- FreeBSD 部署示例
|
||||
|
||||
### Fixed
|
||||
|
||||
- Least Time 响应时间记录修正
|
||||
- Sticky Cookie 格式、分片键、过期检查修复
|
||||
- Sticky 双重 Stop 防护和重启支持
|
||||
- 配置验证:`least_time` 的 `default_time` 不允许负值
|
||||
|
||||
### Tests
|
||||
|
||||
- Least Time 和 Sticky 负载均衡器集成测试
|
||||
- Least Time 和 Sticky 基准测试
|
||||
|
||||
## [0.3.0] - 2026-06-05
|
||||
|
||||
### Added
|
||||
|
||||
2
Makefile
2
Makefile
@ -1,7 +1,7 @@
|
||||
# Makefile - Lolly Build Commands
|
||||
|
||||
APP_NAME := lolly
|
||||
FALLBACK_VERSION := 0.3.0
|
||||
FALLBACK_VERSION := 0.4.1
|
||||
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null | sed 's/^v//' || echo "$(FALLBACK_VERSION)")
|
||||
|
||||
GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
|
||||
@ -123,6 +123,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
|
||||
a.shutdownStream()
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(timeout)
|
||||
@ -139,6 +140,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
} else {
|
||||
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
|
||||
}
|
||||
a.shutdownStream()
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.StopWithTimeout(timeout)
|
||||
@ -329,6 +331,7 @@ func (a *App) gracefulUpgrade() {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
a.shutdownStream()
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(timeout)
|
||||
|
||||
@ -252,6 +252,12 @@ func (a *App) shutdownHTTP2() {
|
||||
}
|
||||
|
||||
// reopenLogs reinitializes the logger from current config.
|
||||
func (a *App) shutdownStream() {
|
||||
if a.streamSrv != nil {
|
||||
a.streamSrv.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) reopenLogs() {
|
||||
if a.cfg != nil {
|
||||
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
||||
|
||||
@ -19,7 +19,7 @@ import (
|
||||
// 注意事项:
|
||||
// - Path 使用前缀匹配,较长路径优先匹配
|
||||
// - 至少配置一个 Target 才能正常工作
|
||||
// - 负载均衡算法支持:round_robin、weighted_round_robin、least_conn、ip_hash、consistent_hash、random、least_time、sticky
|
||||
// - 负载均衡算法支持:round_robin、weighted_round_robin、least_conn、ip_hash、consistent_hash、random、least_time、sticky
|
||||
// - 一致性哈希需要配置 HashKey
|
||||
//
|
||||
// 使用示例:
|
||||
|
||||
@ -509,6 +509,9 @@ func validateProxy(p *ProxyConfig) error {
|
||||
if p.LeastTime.Metric != "" && p.LeastTime.Metric != "header" && p.LeastTime.Metric != "last_byte" {
|
||||
return fmt.Errorf("无效的 least_time metric: %s(有效值: header, last_byte)", p.LeastTime.Metric)
|
||||
}
|
||||
if p.LeastTime.DefaultTime < 0 {
|
||||
return fmt.Errorf("least_time default_time 不能为负数")
|
||||
}
|
||||
}
|
||||
|
||||
// validate sticky config
|
||||
@ -519,6 +522,12 @@ func validateProxy(p *ProxyConfig) error {
|
||||
if p.Sticky.FallbackAlgo != "" && !loadbalance.IsValidAlgorithm(p.Sticky.FallbackAlgo) {
|
||||
return fmt.Errorf("无效的 sticky fallback_balance: %s", p.Sticky.FallbackAlgo)
|
||||
}
|
||||
if p.Sticky.SameSite != "" {
|
||||
validSameSites := []string{"Lax", "Strict", "None"}
|
||||
if !slices.Contains(validSameSites, p.Sticky.SameSite) {
|
||||
return fmt.Errorf("无效的 sticky same_site: %s(有效值: Lax, Strict, None)", p.Sticky.SameSite)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证故障转移配置
|
||||
|
||||
@ -170,6 +170,80 @@ func TestValidateProxy(t *testing.T) {
|
||||
wantErr: true,
|
||||
errMsg: "无效的负载均衡算法",
|
||||
},
|
||||
{
|
||||
name: "有效 least_time 配置 metric=header",
|
||||
config: ProxyConfig{
|
||||
Path: "/api",
|
||||
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||
LoadBalance: "least_time",
|
||||
LeastTime: LeastTimeConfig{Metric: "header"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效 least_time 配置 metric=last_byte",
|
||||
config: ProxyConfig{
|
||||
Path: "/api",
|
||||
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||
LoadBalance: "least_time",
|
||||
LeastTime: LeastTimeConfig{Metric: "last_byte"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "无效 least_time metric",
|
||||
config: ProxyConfig{
|
||||
Path: "/api",
|
||||
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||
LoadBalance: "least_time",
|
||||
LeastTime: LeastTimeConfig{Metric: "invalid"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "无效的 least_time metric",
|
||||
},
|
||||
{
|
||||
name: "有效 sticky 配置",
|
||||
config: ProxyConfig{
|
||||
Path: "/api",
|
||||
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||
LoadBalance: "sticky",
|
||||
Sticky: StickyConfig{Enabled: true, FallbackAlgo: "round_robin"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "无效 sticky enabled=false",
|
||||
config: ProxyConfig{
|
||||
Path: "/api",
|
||||
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||
LoadBalance: "sticky",
|
||||
Sticky: StickyConfig{Enabled: false},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "sticky.enabled 必须为 true",
|
||||
},
|
||||
{
|
||||
name: "无效 sticky fallback_balance",
|
||||
config: ProxyConfig{
|
||||
Path: "/api",
|
||||
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||
LoadBalance: "sticky",
|
||||
Sticky: StickyConfig{Enabled: true, FallbackAlgo: "invalid"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "无效的 sticky fallback_balance",
|
||||
},
|
||||
{
|
||||
name: "无效 sticky same_site",
|
||||
config: ProxyConfig{
|
||||
Path: "/api",
|
||||
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||
LoadBalance: "sticky",
|
||||
Sticky: StickyConfig{Enabled: true, SameSite: "Invalid"},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "无效的 sticky same_site",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@ -11,10 +11,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
gzipType = "gzip"
|
||||
offValue = "off"
|
||||
redirectType = "redirect"
|
||||
staticType = "static"
|
||||
gzipType = "gzip"
|
||||
offValue = "off"
|
||||
redirectType = "redirect"
|
||||
staticType = "static"
|
||||
returnDirective = "return"
|
||||
)
|
||||
|
||||
// Warning represents a conversion warning for unsupported or partially supported directives.
|
||||
@ -70,7 +71,7 @@ var unsupportedDirectives = map[string]string{
|
||||
"split_clients": "the 'split_clients' directive is not supported",
|
||||
"geo": "the 'geo' directive is not supported; use access.geoip config instead",
|
||||
"range": "the 'range' directive is not supported",
|
||||
"return": "the 'return' directive is not supported for non-redirect status codes; only 301/302 are supported",
|
||||
returnDirective: "the 'return' directive is not supported for non-redirect status codes; only 301/302 are supported",
|
||||
}
|
||||
|
||||
// Convert converts a parsed nginx configuration to a lolly configuration.
|
||||
@ -273,7 +274,7 @@ func convertServerBlock(d *Directive, upstreams map[string]*upstreamInfo, result
|
||||
parseAccessLog(bd, result)
|
||||
case "error_log":
|
||||
parseErrorLog(bd, result)
|
||||
case "return":
|
||||
case returnDirective:
|
||||
parseServerReturn(bd, &baseServer, result)
|
||||
case "rewrite":
|
||||
parseRewrite(bd, &baseServer)
|
||||
@ -457,7 +458,7 @@ func parseServerReturn(d *Directive, server *config.ServerConfig, result *Conver
|
||||
})
|
||||
default:
|
||||
result.Warnings = append(result.Warnings, Warning{
|
||||
Directive: "return",
|
||||
Directive: returnDirective,
|
||||
Line: d.Line,
|
||||
File: d.File,
|
||||
Message: fmt.Sprintf("return %d is not a redirect; only 301/302 are supported at server level", code),
|
||||
@ -554,7 +555,7 @@ func classifyLocation(d *Directive, serverRoot string, result *ConvertResult) lo
|
||||
hasRootOrAlias = true
|
||||
case "try_files":
|
||||
hasTryFiles = true
|
||||
case "return":
|
||||
case returnDirective:
|
||||
if len(d.Block[i].Args) > 0 {
|
||||
code, err := strconv.Atoi(d.Block[i].Args[0])
|
||||
if err == nil && (code == 301 || code == 302) {
|
||||
@ -876,7 +877,7 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
|
||||
for i := range directives {
|
||||
d := &directives[i]
|
||||
|
||||
if d.Name != "return" {
|
||||
if d.Name != returnDirective {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -900,7 +901,7 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
|
||||
Flag: "permanent",
|
||||
})
|
||||
result.Warnings = append(result.Warnings, Warning{
|
||||
Directive: "return",
|
||||
Directive: returnDirective,
|
||||
Line: d.Line,
|
||||
File: d.File,
|
||||
Message: "return 301 converted to rewrite rule with permanent flag",
|
||||
@ -912,14 +913,14 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
|
||||
Flag: "redirect",
|
||||
})
|
||||
result.Warnings = append(result.Warnings, Warning{
|
||||
Directive: "return",
|
||||
Directive: returnDirective,
|
||||
Line: d.Line,
|
||||
File: d.File,
|
||||
Message: "return 302 converted to rewrite rule with redirect flag",
|
||||
})
|
||||
default:
|
||||
result.Warnings = append(result.Warnings, Warning{
|
||||
Directive: "return",
|
||||
Directive: returnDirective,
|
||||
Line: d.Line,
|
||||
File: d.File,
|
||||
Message: fmt.Sprintf("return %d in location is not a redirect; only 301/302 are supported", code),
|
||||
|
||||
@ -71,7 +71,6 @@ func TestProxyRequestHeaders(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证代理配置已设置
|
||||
assert.NotNil(t, cfg.Headers.SetRequest)
|
||||
assert.Equal(t, "custom-value", cfg.Headers.SetRequest["X-Custom-Header"])
|
||||
@ -101,7 +100,6 @@ func TestProxyResponseHeaders(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证响应头配置
|
||||
assert.Equal(t, "lolly", cfg.Headers.SetResponse["X-Server"])
|
||||
assert.Contains(t, cfg.Headers.Remove, "X-Powered-By")
|
||||
@ -125,7 +123,6 @@ func TestProxyTimeout(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证超时配置
|
||||
assert.Equal(t, 1*time.Second, cfg.Timeout.Connect)
|
||||
assert.Equal(t, 50*time.Millisecond, cfg.Timeout.Read)
|
||||
@ -150,7 +147,6 @@ func TestProxyLoadBalanceRoundRobin(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证负载均衡器类型
|
||||
assert.Equal(t, "round_robin", cfg.LoadBalance)
|
||||
assert.Len(t, targets, 2)
|
||||
@ -174,7 +170,6 @@ func TestProxyWeightedRoundRobin(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证权重配置
|
||||
assert.Equal(t, 3, targets[0].Weight)
|
||||
assert.Equal(t, 1, targets[1].Weight)
|
||||
@ -198,7 +193,6 @@ func TestProxyLeastConn(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, "least_conn", cfg.LoadBalance)
|
||||
}
|
||||
|
||||
@ -220,7 +214,6 @@ func TestProxyIPHash(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, "ip_hash", cfg.LoadBalance)
|
||||
}
|
||||
|
||||
@ -240,7 +233,6 @@ func TestProxyConsistentHash(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
assert.Equal(t, "consistent_hash", cfg.LoadBalance)
|
||||
assert.Equal(t, "uri", cfg.HashKey)
|
||||
assert.Equal(t, 150, cfg.VirtualNodes)
|
||||
@ -268,7 +260,6 @@ func TestProxyErrorHandling(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证 MaxFails 配置 (int64 类型)
|
||||
assert.Equal(t, int64(3), targets[0].MaxFails)
|
||||
assert.Equal(t, 10*time.Second, targets[0].FailTimeout)
|
||||
@ -300,7 +291,6 @@ func TestProxyCacheConfig(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证缓存配置
|
||||
assert.True(t, cfg.Cache.Enabled)
|
||||
assert.Equal(t, 60*time.Second, cfg.Cache.MaxAge)
|
||||
@ -330,7 +320,6 @@ func TestProxyNextUpstream(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证故障转移配置
|
||||
assert.Equal(t, 3, cfg.NextUpstream.Tries)
|
||||
assert.Contains(t, cfg.NextUpstream.HTTPCodes, 502)
|
||||
@ -360,7 +349,6 @@ func TestProxyHealthCheck(t *testing.T) {
|
||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
// 验证健康检查配置
|
||||
assert.Equal(t, 10*time.Second, cfg.HealthCheck.Interval)
|
||||
assert.Equal(t, "/health", cfg.HealthCheck.Path)
|
||||
|
||||
@ -12,9 +12,12 @@ package loadbalance
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// createHealthyTarget 创建并返回一个指定 URL 和健康状态的 Target 实例。
|
||||
@ -940,6 +943,8 @@ func TestIsValidAlgorithm(t *testing.T) {
|
||||
{"ip_hash", "ip_hash", true},
|
||||
{"consistent_hash", "consistent_hash", true},
|
||||
{"random", "random", true},
|
||||
{"least_time", "least_time", true},
|
||||
{"sticky", "sticky", true},
|
||||
{"invalid", "invalid", false},
|
||||
{"empty", "", true}, // 空字符串有效(使用默认值)
|
||||
{"unknown", "unknown-algorithm", false},
|
||||
@ -2091,3 +2096,94 @@ func TestRandomBalancer(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBalancerIntegration_LeastTime 验证 Least Time 算法 consistently 选择更快目标。
|
||||
func TestBalancerIntegration_LeastTime(t *testing.T) {
|
||||
targets := []*Target{
|
||||
NewTargetFromConfig("http://slow:8080", 1, 0, 0, 0, false, false, ""),
|
||||
NewTargetFromConfig("http://fast:8080", 1, 0, 0, 0, false, false, ""),
|
||||
}
|
||||
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
|
||||
// 模拟:slow 目标有 100ms avg,fast 有 10ms avg
|
||||
for i := 0; i < 10; i++ {
|
||||
targets[0].Stats.Record(50*time.Millisecond, 100*time.Millisecond)
|
||||
targets[1].Stats.Record(5*time.Millisecond, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
// 选择 100 次,应该 mostly 选 fast
|
||||
fastCount := 0
|
||||
for i := 0; i < 100; i++ {
|
||||
selected := lt.Select(targets)
|
||||
if selected.URL == "http://fast:8080" {
|
||||
fastCount++
|
||||
}
|
||||
}
|
||||
|
||||
if fastCount < 80 {
|
||||
t.Errorf("fast target selected %d/100 times, expected >80", fastCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBalancerIntegration_StickyWithLeastTimeFallback 验证 Sticky 在目标不健康时 fallback。
|
||||
func TestBalancerIntegration_StickyWithLeastTimeFallback(t *testing.T) {
|
||||
fallback := NewLeastTime("last_byte", time.Millisecond)
|
||||
config := StickyConfig{
|
||||
Enabled: true,
|
||||
Name: "test_route",
|
||||
Expires: time.Hour,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
}
|
||||
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
|
||||
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// 首次请求
|
||||
selected1 := sticky.Select(ctx, targets)
|
||||
if selected1 == nil {
|
||||
t.Fatal("expected a target")
|
||||
}
|
||||
|
||||
// 验证 cookie 已设置
|
||||
cookie := ctx.Response.Header.PeekCookie("test_route")
|
||||
if len(cookie) == 0 {
|
||||
t.Fatal("expected cookie")
|
||||
}
|
||||
|
||||
// 使 selected1 不健康
|
||||
selected1.Healthy.Store(false)
|
||||
|
||||
// 第二次请求带 cookie 应该 fallback
|
||||
ctx2 := &fasthttp.RequestCtx{}
|
||||
ctx2.Request.Header.SetCookie("test_route", string(extractCookieValue(cookie)))
|
||||
|
||||
selected2 := sticky.Select(ctx2, targets)
|
||||
if selected2 == nil {
|
||||
t.Fatal("expected fallback target")
|
||||
}
|
||||
if selected2.URL == selected1.URL {
|
||||
t.Error("expected different target after fallback")
|
||||
}
|
||||
}
|
||||
|
||||
// extractCookieValue 从 Set-Cookie header 中提取 cookie 值
|
||||
func extractCookieValue(cookieHeader []byte) []byte {
|
||||
s := string(cookieHeader)
|
||||
// Format: "name=value; ..."
|
||||
parts := strings.SplitN(s, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil
|
||||
}
|
||||
valueParts := strings.SplitN(parts[1], ";", 2)
|
||||
return []byte(valueParts[0])
|
||||
}
|
||||
|
||||
@ -14,10 +14,12 @@ type EWMAStats struct {
|
||||
|
||||
const defaultAlphaScale = 300 // alpha = 0.3
|
||||
|
||||
// NewEWMAStats 创建新的 EWMA 统计器。
|
||||
func NewEWMAStats() *EWMAStats {
|
||||
return &EWMAStats{}
|
||||
}
|
||||
|
||||
// Record 记录一次响应时间样本。
|
||||
func (e *EWMAStats) Record(headerTime, lastByteTime time.Duration) {
|
||||
e.recordAtomic(&e.headerTime, headerTime)
|
||||
e.recordAtomic(&e.lastByteTime, lastByteTime)
|
||||
@ -41,18 +43,22 @@ func (e *EWMAStats) recordAtomic(ptr *atomic.Int64, newValue time.Duration) {
|
||||
}
|
||||
}
|
||||
|
||||
// HeaderTime 返回首字节时间的 EWMA 值。
|
||||
func (e *EWMAStats) HeaderTime() time.Duration {
|
||||
return time.Duration(e.headerTime.Load())
|
||||
}
|
||||
|
||||
// LastByteTime 返回完整响应时间的 EWMA 值。
|
||||
func (e *EWMAStats) LastByteTime() time.Duration {
|
||||
return time.Duration(e.lastByteTime.Load())
|
||||
}
|
||||
|
||||
// SampleCount 返回已记录的样本数量。
|
||||
func (e *EWMAStats) SampleCount() int64 {
|
||||
return e.sampleCount.Load()
|
||||
}
|
||||
|
||||
// Reset 重置所有统计数据。
|
||||
func (e *EWMAStats) Reset() {
|
||||
e.headerTime.Store(0)
|
||||
e.lastByteTime.Store(0)
|
||||
|
||||
@ -97,5 +97,7 @@ func (l *LeastTime) GetMetric() string {
|
||||
return l.metric
|
||||
}
|
||||
|
||||
var _ Balancer = (*LeastTime)(nil)
|
||||
var _ ResponseTimeRecorder = (*LeastTime)(nil)
|
||||
var (
|
||||
_ Balancer = (*LeastTime)(nil)
|
||||
_ ResponseTimeRecorder = (*LeastTime)(nil)
|
||||
)
|
||||
|
||||
96
internal/loadbalance/least_time_bench_test.go
Normal file
96
internal/loadbalance/least_time_bench_test.go
Normal file
@ -0,0 +1,96 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func BenchmarkLeastTime_Select(b *testing.B) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
targets := []*Target{
|
||||
NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""),
|
||||
NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""),
|
||||
NewTargetFromConfig("http://c:8080", 1, 0, 0, 0, false, false, ""),
|
||||
}
|
||||
|
||||
// Pre-populate stats
|
||||
for _, t := range targets {
|
||||
t.Stats.Record(10*time.Millisecond, 20*time.Millisecond)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lt.Select(targets)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLeastTime_Record(b *testing.B) {
|
||||
stats := NewEWMAStats()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
stats.Record(10*time.Millisecond, 20*time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLeastTime_Concurrent(b *testing.B) {
|
||||
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||
targets := []*Target{
|
||||
NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""),
|
||||
NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""),
|
||||
}
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
lt.Select(targets)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkStickySession_Select(b *testing.B) {
|
||||
fallback := NewRoundRobin()
|
||||
config := DefaultStickyConfig()
|
||||
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
|
||||
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
|
||||
}
|
||||
|
||||
// Pre-populate a cookie
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
sticky.Select(ctx, targets)
|
||||
cookie := ctx.Response.Header.PeekCookie(config.Name)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetCookie(config.Name, string(extractCookieValue(cookie)))
|
||||
sticky.Select(ctx, targets)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStickySession_SelectNew(b *testing.B) {
|
||||
fallback := NewRoundRobin()
|
||||
config := DefaultStickyConfig()
|
||||
|
||||
sticky := NewStickySession(config, fallback)
|
||||
sticky.Start()
|
||||
defer sticky.Stop()
|
||||
|
||||
targets := []*Target{
|
||||
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
|
||||
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
sticky.Select(ctx, targets)
|
||||
}
|
||||
}
|
||||
@ -2,6 +2,7 @@ package loadbalance
|
||||
|
||||
import "time"
|
||||
|
||||
// StickyConfig 配置 Sticky 负载均衡的 Cookie 参数。
|
||||
type StickyConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
Name string `yaml:"name"`
|
||||
@ -13,6 +14,7 @@ type StickyConfig struct {
|
||||
SameSite string `yaml:"same_site"`
|
||||
}
|
||||
|
||||
// DefaultStickyConfig 返回 Sticky 负载均衡的默认配置。
|
||||
func DefaultStickyConfig() StickyConfig {
|
||||
return StickyConfig{
|
||||
Name: "lolly_route",
|
||||
|
||||
@ -90,6 +90,8 @@ const (
|
||||
lbIPHash = "ip_hash" // IP 哈希
|
||||
lbConsistentHash = "consistent_hash" // 一致性哈希
|
||||
lbRandom = "random" // 随机(Power of Two Choices)
|
||||
lbLeastTime = "least_time" // 最小响应时间
|
||||
lbSticky = "sticky" // 会话粘性
|
||||
)
|
||||
|
||||
// headersPool 复用缓存 headers map,减少分配。
|
||||
@ -229,6 +231,22 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// stickyBalancer wraps StickySession to implement loadbalance.Balancer.
|
||||
// It delegates Select/SelectExcluding to the fallback balancer while
|
||||
// allowing the proxy to access the StickySession for cookie-based routing.
|
||||
type stickyBalancer struct {
|
||||
sticky *loadbalance.StickySession
|
||||
fallback loadbalance.Balancer
|
||||
}
|
||||
|
||||
func (b *stickyBalancer) Select(targets []*loadbalance.Target) *loadbalance.Target {
|
||||
return b.fallback.Select(targets)
|
||||
}
|
||||
|
||||
func (b *stickyBalancer) SelectExcluding(targets []*loadbalance.Target, excluded []*loadbalance.Target) *loadbalance.Target {
|
||||
return b.fallback.SelectExcluding(targets, excluded)
|
||||
}
|
||||
|
||||
// createBalancerByName 根据算法名称创建负载均衡器。
|
||||
//
|
||||
// 支持的算法:
|
||||
@ -263,6 +281,49 @@ func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Bal
|
||||
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
|
||||
case lbRandom:
|
||||
return loadbalance.NewRandom(), nil
|
||||
case lbLeastTime:
|
||||
metric := cfg.LeastTime.Metric
|
||||
if metric == "" {
|
||||
metric = "last_byte"
|
||||
}
|
||||
defaultTime := cfg.LeastTime.DefaultTime
|
||||
if defaultTime <= 0 {
|
||||
defaultTime = time.Millisecond
|
||||
}
|
||||
return loadbalance.NewLeastTime(metric, defaultTime), nil
|
||||
case lbSticky:
|
||||
stickyCfg := loadbalance.StickyConfig{
|
||||
Enabled: cfg.Sticky.Enabled,
|
||||
Name: cfg.Sticky.Name,
|
||||
Expires: cfg.Sticky.Expires,
|
||||
Domain: cfg.Sticky.Domain,
|
||||
Path: cfg.Sticky.Path,
|
||||
Secure: cfg.Sticky.Secure,
|
||||
HttpOnly: cfg.Sticky.HttpOnly,
|
||||
SameSite: cfg.Sticky.SameSite,
|
||||
}
|
||||
if stickyCfg.Name == "" {
|
||||
stickyCfg.Name = "lolly_route"
|
||||
}
|
||||
if stickyCfg.Expires <= 0 {
|
||||
stickyCfg.Expires = time.Hour
|
||||
}
|
||||
if stickyCfg.Path == "" {
|
||||
stickyCfg.Path = "/"
|
||||
}
|
||||
|
||||
fallbackAlgo := cfg.Sticky.FallbackAlgo
|
||||
if fallbackAlgo == "" {
|
||||
fallbackAlgo = lbRoundRobin
|
||||
}
|
||||
fallbackBalancer, err := createBalancerByName(fallbackAlgo, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sticky fallback balancer: %w", err)
|
||||
}
|
||||
|
||||
sticky := loadbalance.NewStickySession(stickyCfg, fallbackBalancer)
|
||||
sticky.Start()
|
||||
return &stickyBalancer{sticky: sticky, fallback: fallbackBalancer}, nil
|
||||
default:
|
||||
return nil, errors.New("unsupported load balance algorithm: " + name)
|
||||
}
|
||||
@ -813,6 +874,12 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
// 记录首字节时间
|
||||
timing.MarkHeaderReceived()
|
||||
|
||||
// 记录首字节响应时间(用于 least_time 负载均衡)
|
||||
if recorder, ok := p.balancer.(loadbalance.ResponseTimeRecorder); ok {
|
||||
headerTime := timing.headerReceived.Sub(timing.connectEnd)
|
||||
recorder.RecordResponseTime(target, headerTime, 0)
|
||||
}
|
||||
|
||||
// 请求成功,减少连接计数
|
||||
loadbalance.DecrementConnections(target)
|
||||
|
||||
@ -917,6 +984,15 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
|
||||
// 修改响应头
|
||||
p.modifyResponseHeaders(ctx)
|
||||
|
||||
// 记录完整响应时间(用于 least_time 负载均衡)
|
||||
timing.MarkResponseEnd()
|
||||
if recorder, ok := p.balancer.(loadbalance.ResponseTimeRecorder); ok {
|
||||
headerTime := timing.headerReceived.Sub(timing.connectEnd)
|
||||
lastByteTime := timing.responseEnd.Sub(timing.connectEnd)
|
||||
recorder.RecordResponseTime(target, headerTime, lastByteTime)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -119,6 +119,11 @@ func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalanc
|
||||
balancer := p.balancer
|
||||
p.mu.RUnlock()
|
||||
|
||||
// 对于 StickySession 负载均衡器,需要请求上下文
|
||||
if sb, ok := balancer.(*stickyBalancer); ok {
|
||||
return sb.sticky.Select(ctx, targets)
|
||||
}
|
||||
|
||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -63,7 +64,7 @@ func TestStart_NoListeners(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, s.running.Load())
|
||||
|
||||
s.running.Store(false)
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
func TestStart_WithTCPListeners(t *testing.T) {
|
||||
@ -81,12 +82,7 @@ func TestStart_WithTCPListeners(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, s.running.Load())
|
||||
|
||||
s.running.Store(false)
|
||||
s.mu.RLock()
|
||||
for _, ln := range s.listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
func TestStart_AcceptConnections(t *testing.T) {
|
||||
@ -115,7 +111,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
||||
proxyAddr := ln.Addr().String()
|
||||
|
||||
s.mu.Lock()
|
||||
s.listeners[proxyAddr] = ln
|
||||
s.listeners["test"] = ln
|
||||
s.mu.Unlock()
|
||||
|
||||
err = s.Start()
|
||||
@ -136,12 +132,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
||||
|
||||
_ = clientConn.Close()
|
||||
|
||||
s.running.Store(false)
|
||||
s.mu.RLock()
|
||||
for _, l := range s.listeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
s.Stop()
|
||||
_ = backendLn.Close()
|
||||
}
|
||||
|
||||
@ -637,3 +628,138 @@ func TestStartCleanupTicker_StopsOnSignal(t *testing.T) {
|
||||
t.Fatal("startCleanupTicker did not stop after signal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConnection_MultipleUpstreams(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
backend1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
backend2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend1.Close()
|
||||
defer backend2.Close()
|
||||
|
||||
go func() {
|
||||
conn, err := backend1.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := conn.Read(buf)
|
||||
_, _ = conn.Write(append([]byte("backend1:"), buf[:n]...))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
conn, err := backend2.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := conn.Read(buf)
|
||||
_, _ = conn.Write(append([]byte("backend2:"), buf[:n]...))
|
||||
}()
|
||||
|
||||
_ = s.AddUpstream("upstream1", []TargetSpec{{Addr: backend1.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||
_ = s.AddUpstream("upstream2", []TargetSpec{{Addr: backend2.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||
s.upstreams["upstream1"].targets[0].healthy.Store(true)
|
||||
s.upstreams["upstream2"].targets[0].healthy.Store(true)
|
||||
|
||||
ln1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s.mu.Lock()
|
||||
s.listeners["upstream1"] = ln1
|
||||
s.listeners["upstream2"] = ln2
|
||||
s.mu.Unlock()
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
defer s.Stop()
|
||||
|
||||
conn1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
|
||||
_, err = conn1.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
_ = conn1.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err := conn1.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend1:")), "should route to backend1, got: %s", string(buf[:n]))
|
||||
|
||||
conn2, err := net.DialTimeout("tcp", ln2.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
|
||||
_, err = conn2.Write([]byte("world"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err = conn2.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend2:")), "should route to backend2, got: %s", string(buf[:n]))
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backendLn.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := backendLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(conn, conn)
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
_ = s.AddUpstream("test", []TargetSpec{{Addr: backendLn.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||
s.upstreams["test"].targets[0].healthy.Store(true)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
s.mu.Lock()
|
||||
s.listeners["test"] = ln
|
||||
s.mu.Unlock()
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
_ = clientConn.Close()
|
||||
|
||||
s.Stop()
|
||||
|
||||
assert.False(t, s.running.Load())
|
||||
|
||||
s.mu.RLock()
|
||||
assert.Empty(t, s.listeners)
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestStop_Idempotent(t *testing.T) {
|
||||
s := NewServer()
|
||||
s.Stop()
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
func TestStart_DoubleStart(t *testing.T) {
|
||||
s := NewServer()
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
err = s.Start()
|
||||
assert.Error(t, err)
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
@ -29,6 +29,7 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"net"
|
||||
@ -286,18 +287,14 @@ func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target {
|
||||
|
||||
// Server TCP/UDP Stream 代理服务器。
|
||||
type Server struct {
|
||||
// listeners TCP 监听器映射,按 upstream 名称索引
|
||||
listeners map[string]net.Listener
|
||||
// udpServers UDP 服务器映射
|
||||
listeners map[string]net.Listener
|
||||
udpServers map[string]*udpServer
|
||||
// upstreams 上游配置映射
|
||||
upstreams map[string]*Upstream
|
||||
// connCount 当前连接数
|
||||
connCount atomic.Int64
|
||||
// mu 读写锁,保护并发访问
|
||||
mu sync.RWMutex
|
||||
// running 运行状态标志
|
||||
running atomic.Bool
|
||||
upstreams map[string]*Upstream
|
||||
connCount atomic.Int64
|
||||
mu sync.RWMutex
|
||||
running atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// Upstream Stream 上游配置。
|
||||
@ -393,6 +390,7 @@ func NewServer() *Server {
|
||||
listeners: make(map[string]net.Listener),
|
||||
udpServers: make(map[string]*udpServer),
|
||||
upstreams: make(map[string]*Upstream),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -491,25 +489,71 @@ func (s *Server) ListenUDP(addr string, upstreamName string, timeout time.Durati
|
||||
|
||||
// Start 启动 Stream 服务器。
|
||||
func (s *Server) Start() error {
|
||||
s.running.Store(true)
|
||||
if !s.running.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("stream server already running")
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// 启动 TCP 监听器
|
||||
for addr, listener := range s.listeners {
|
||||
go s.acceptLoop(addr, listener)
|
||||
s.wg.Add(1)
|
||||
go func(a string, ln net.Listener) {
|
||||
defer s.wg.Done()
|
||||
s.acceptLoop(a, ln)
|
||||
}(addr, listener)
|
||||
}
|
||||
|
||||
// 启动 UDP 服务器
|
||||
for _, udpSrv := range s.udpServers {
|
||||
go udpSrv.serve()
|
||||
go udpSrv.startCleanupTicker()
|
||||
s.wg.Add(1)
|
||||
go func(u *udpServer) {
|
||||
defer s.wg.Done()
|
||||
u.serve()
|
||||
}(udpSrv)
|
||||
s.wg.Add(1)
|
||||
go func(u *udpServer) {
|
||||
defer s.wg.Done()
|
||||
u.startCleanupTicker()
|
||||
}(udpSrv)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the stream server, closing all listeners and waiting for goroutines to finish.
|
||||
func (s *Server) Stop() {
|
||||
if !s.running.CompareAndSwap(true, false) {
|
||||
return
|
||||
}
|
||||
|
||||
close(s.stopCh)
|
||||
|
||||
s.mu.Lock()
|
||||
for _, ln := range s.listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
for _, udpSrv := range s.udpServers {
|
||||
close(udpSrv.stopCh)
|
||||
if udpSrv.conn != nil {
|
||||
_ = udpSrv.conn.Close()
|
||||
}
|
||||
}
|
||||
for _, upstream := range s.upstreams {
|
||||
if upstream.healthChk != nil && upstream.healthChk.stopCh != nil {
|
||||
close(upstream.healthChk.stopCh)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
s.mu.Lock()
|
||||
s.listeners = make(map[string]net.Listener)
|
||||
s.udpServers = make(map[string]*udpServer)
|
||||
s.stopCh = make(chan struct{})
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// acceptLoop 接受连接循环。
|
||||
//
|
||||
// 在单独的 goroutine 中运行,持续接受 TCP 连接。
|
||||
@ -523,7 +567,12 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if !s.running.Load() {
|
||||
return // 正常关闭
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
@ -544,19 +593,14 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
||||
// 参数:
|
||||
// - clientConn: 客户端连接
|
||||
// - addr: 监听地址
|
||||
func (s *Server) handleConnection(clientConn net.Conn, _ string) {
|
||||
func (s *Server) handleConnection(clientConn net.Conn, addr string) {
|
||||
defer func() {
|
||||
_ = clientConn.Close()
|
||||
s.connCount.Add(-1)
|
||||
}()
|
||||
|
||||
s.mu.RLock()
|
||||
// 根据监听地址找到对应 upstream(简化:用第一个)
|
||||
var upstream *Upstream
|
||||
for _, up := range s.upstreams {
|
||||
upstream = up
|
||||
break
|
||||
}
|
||||
upstream := s.upstreams[addr]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if upstream == nil {
|
||||
|
||||
@ -157,7 +157,7 @@ func TestHandleConnection_NoHealthyTarget(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||
s.handleConnection(clientConn, "test2")
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -200,7 +200,7 @@ func TestHandleConnection_DialFail(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||
s.handleConnection(clientConn, "test3")
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user