Compare commits
10 Commits
88a2c1fc1b
...
afbbc3a951
| Author | SHA1 | Date | |
|---|---|---|---|
| afbbc3a951 | |||
| 66ea93e3c1 | |||
| 7204432ca0 | |||
| f12ffd180f | |||
| 503daf65d3 | |||
| ef871f1d39 | |||
| e5885ce888 | |||
| 72f189bba8 | |||
| 3b6b70a491 | |||
| cb1f86298e |
33
CHANGELOG.md
33
CHANGELOG.md
@ -7,6 +7,39 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [0.4.1] - 2026-06-10
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- **stream**: 修复上游选择逻辑,添加优雅关闭支持
|
||||||
|
- **stream**: 修复 Stop 后 stopCh 未重置导致的重启问题
|
||||||
|
|
||||||
|
## [0.4.0] - 2026-06-09
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
#### 负载均衡
|
||||||
|
|
||||||
|
- **Least Time 负载均衡器**:基于 EWMA(指数加权移动平均)统计的最少响应时间算法,自动选择响应最快上游
|
||||||
|
- **Session Sticky 负载均衡器**:基于 Cookie 的会话粘性,一致性哈希分片,支持 Cookie 过期、域名、路径、Secure/HttpOnly/SameSite 属性
|
||||||
|
- 对应 YAML 配置支持:`least_time`、`sticky` 策略及参数
|
||||||
|
|
||||||
|
#### 平台与构建
|
||||||
|
|
||||||
|
- FreeBSD 部署示例
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Least Time 响应时间记录修正
|
||||||
|
- Sticky Cookie 格式、分片键、过期检查修复
|
||||||
|
- Sticky 双重 Stop 防护和重启支持
|
||||||
|
- 配置验证:`least_time` 的 `default_time` 不允许负值
|
||||||
|
|
||||||
|
### Tests
|
||||||
|
|
||||||
|
- Least Time 和 Sticky 负载均衡器集成测试
|
||||||
|
- Least Time 和 Sticky 基准测试
|
||||||
|
|
||||||
## [0.3.0] - 2026-06-05
|
## [0.3.0] - 2026-06-05
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
2
Makefile
2
Makefile
@ -1,7 +1,7 @@
|
|||||||
# Makefile - Lolly Build Commands
|
# Makefile - Lolly Build Commands
|
||||||
|
|
||||||
APP_NAME := lolly
|
APP_NAME := lolly
|
||||||
FALLBACK_VERSION := 0.3.0
|
FALLBACK_VERSION := 0.4.1
|
||||||
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null | sed 's/^v//' || echo "$(FALLBACK_VERSION)")
|
VERSION := $(shell git describe --tags --always --dirty 2>/dev/null | sed 's/^v//' || echo "$(FALLBACK_VERSION)")
|
||||||
|
|
||||||
GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
GIT_COMMIT := $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||||
|
|||||||
@ -123,6 +123,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
|||||||
timeout = 30 * time.Second
|
timeout = 30 * time.Second
|
||||||
}
|
}
|
||||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
|
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
|
||||||
|
a.shutdownStream()
|
||||||
a.shutdownHTTP2()
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.GracefulStop(timeout)
|
_ = a.srv.GracefulStop(timeout)
|
||||||
@ -139,6 +140,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
|||||||
} else {
|
} else {
|
||||||
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
|
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
|
||||||
}
|
}
|
||||||
|
a.shutdownStream()
|
||||||
a.shutdownHTTP2()
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.StopWithTimeout(timeout)
|
_ = a.srv.StopWithTimeout(timeout)
|
||||||
@ -329,6 +331,7 @@ func (a *App) gracefulUpgrade() {
|
|||||||
if timeout <= 0 {
|
if timeout <= 0 {
|
||||||
timeout = 30 * time.Second
|
timeout = 30 * time.Second
|
||||||
}
|
}
|
||||||
|
a.shutdownStream()
|
||||||
a.shutdownHTTP2()
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.GracefulStop(timeout)
|
_ = a.srv.GracefulStop(timeout)
|
||||||
|
|||||||
@ -252,6 +252,12 @@ func (a *App) shutdownHTTP2() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// reopenLogs reinitializes the logger from current config.
|
// reopenLogs reinitializes the logger from current config.
|
||||||
|
func (a *App) shutdownStream() {
|
||||||
|
if a.streamSrv != nil {
|
||||||
|
a.streamSrv.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) reopenLogs() {
|
func (a *App) reopenLogs() {
|
||||||
if a.cfg != nil {
|
if a.cfg != nil {
|
||||||
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
||||||
|
|||||||
@ -19,7 +19,7 @@ import (
|
|||||||
// 注意事项:
|
// 注意事项:
|
||||||
// - Path 使用前缀匹配,较长路径优先匹配
|
// - Path 使用前缀匹配,较长路径优先匹配
|
||||||
// - 至少配置一个 Target 才能正常工作
|
// - 至少配置一个 Target 才能正常工作
|
||||||
// - 负载均衡算法支持:round_robin、weighted_round_robin、least_conn、ip_hash、consistent_hash、random、least_time、sticky
|
// - 负载均衡算法支持:round_robin、weighted_round_robin、least_conn、ip_hash、consistent_hash、random、least_time、sticky
|
||||||
// - 一致性哈希需要配置 HashKey
|
// - 一致性哈希需要配置 HashKey
|
||||||
//
|
//
|
||||||
// 使用示例:
|
// 使用示例:
|
||||||
|
|||||||
@ -509,6 +509,9 @@ func validateProxy(p *ProxyConfig) error {
|
|||||||
if p.LeastTime.Metric != "" && p.LeastTime.Metric != "header" && p.LeastTime.Metric != "last_byte" {
|
if p.LeastTime.Metric != "" && p.LeastTime.Metric != "header" && p.LeastTime.Metric != "last_byte" {
|
||||||
return fmt.Errorf("无效的 least_time metric: %s(有效值: header, last_byte)", p.LeastTime.Metric)
|
return fmt.Errorf("无效的 least_time metric: %s(有效值: header, last_byte)", p.LeastTime.Metric)
|
||||||
}
|
}
|
||||||
|
if p.LeastTime.DefaultTime < 0 {
|
||||||
|
return fmt.Errorf("least_time default_time 不能为负数")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate sticky config
|
// validate sticky config
|
||||||
@ -519,6 +522,12 @@ func validateProxy(p *ProxyConfig) error {
|
|||||||
if p.Sticky.FallbackAlgo != "" && !loadbalance.IsValidAlgorithm(p.Sticky.FallbackAlgo) {
|
if p.Sticky.FallbackAlgo != "" && !loadbalance.IsValidAlgorithm(p.Sticky.FallbackAlgo) {
|
||||||
return fmt.Errorf("无效的 sticky fallback_balance: %s", p.Sticky.FallbackAlgo)
|
return fmt.Errorf("无效的 sticky fallback_balance: %s", p.Sticky.FallbackAlgo)
|
||||||
}
|
}
|
||||||
|
if p.Sticky.SameSite != "" {
|
||||||
|
validSameSites := []string{"Lax", "Strict", "None"}
|
||||||
|
if !slices.Contains(validSameSites, p.Sticky.SameSite) {
|
||||||
|
return fmt.Errorf("无效的 sticky same_site: %s(有效值: Lax, Strict, None)", p.Sticky.SameSite)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证故障转移配置
|
// 验证故障转移配置
|
||||||
|
|||||||
@ -170,6 +170,80 @@ func TestValidateProxy(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
errMsg: "无效的负载均衡算法",
|
errMsg: "无效的负载均衡算法",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "有效 least_time 配置 metric=header",
|
||||||
|
config: ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||||
|
LoadBalance: "least_time",
|
||||||
|
LeastTime: LeastTimeConfig{Metric: "header"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "有效 least_time 配置 metric=last_byte",
|
||||||
|
config: ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||||
|
LoadBalance: "least_time",
|
||||||
|
LeastTime: LeastTimeConfig{Metric: "last_byte"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效 least_time metric",
|
||||||
|
config: ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||||
|
LoadBalance: "least_time",
|
||||||
|
LeastTime: LeastTimeConfig{Metric: "invalid"},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "无效的 least_time metric",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "有效 sticky 配置",
|
||||||
|
config: ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||||
|
LoadBalance: "sticky",
|
||||||
|
Sticky: StickyConfig{Enabled: true, FallbackAlgo: "round_robin"},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效 sticky enabled=false",
|
||||||
|
config: ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||||
|
LoadBalance: "sticky",
|
||||||
|
Sticky: StickyConfig{Enabled: false},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "sticky.enabled 必须为 true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效 sticky fallback_balance",
|
||||||
|
config: ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||||
|
LoadBalance: "sticky",
|
||||||
|
Sticky: StickyConfig{Enabled: true, FallbackAlgo: "invalid"},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "无效的 sticky fallback_balance",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效 sticky same_site",
|
||||||
|
config: ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
|
||||||
|
LoadBalance: "sticky",
|
||||||
|
Sticky: StickyConfig{Enabled: true, SameSite: "Invalid"},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "无效的 sticky same_site",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@ -11,10 +11,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
gzipType = "gzip"
|
gzipType = "gzip"
|
||||||
offValue = "off"
|
offValue = "off"
|
||||||
redirectType = "redirect"
|
redirectType = "redirect"
|
||||||
staticType = "static"
|
staticType = "static"
|
||||||
|
returnDirective = "return"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Warning represents a conversion warning for unsupported or partially supported directives.
|
// Warning represents a conversion warning for unsupported or partially supported directives.
|
||||||
@ -70,7 +71,7 @@ var unsupportedDirectives = map[string]string{
|
|||||||
"split_clients": "the 'split_clients' directive is not supported",
|
"split_clients": "the 'split_clients' directive is not supported",
|
||||||
"geo": "the 'geo' directive is not supported; use access.geoip config instead",
|
"geo": "the 'geo' directive is not supported; use access.geoip config instead",
|
||||||
"range": "the 'range' directive is not supported",
|
"range": "the 'range' directive is not supported",
|
||||||
"return": "the 'return' directive is not supported for non-redirect status codes; only 301/302 are supported",
|
returnDirective: "the 'return' directive is not supported for non-redirect status codes; only 301/302 are supported",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert converts a parsed nginx configuration to a lolly configuration.
|
// Convert converts a parsed nginx configuration to a lolly configuration.
|
||||||
@ -273,7 +274,7 @@ func convertServerBlock(d *Directive, upstreams map[string]*upstreamInfo, result
|
|||||||
parseAccessLog(bd, result)
|
parseAccessLog(bd, result)
|
||||||
case "error_log":
|
case "error_log":
|
||||||
parseErrorLog(bd, result)
|
parseErrorLog(bd, result)
|
||||||
case "return":
|
case returnDirective:
|
||||||
parseServerReturn(bd, &baseServer, result)
|
parseServerReturn(bd, &baseServer, result)
|
||||||
case "rewrite":
|
case "rewrite":
|
||||||
parseRewrite(bd, &baseServer)
|
parseRewrite(bd, &baseServer)
|
||||||
@ -457,7 +458,7 @@ func parseServerReturn(d *Directive, server *config.ServerConfig, result *Conver
|
|||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
result.Warnings = append(result.Warnings, Warning{
|
result.Warnings = append(result.Warnings, Warning{
|
||||||
Directive: "return",
|
Directive: returnDirective,
|
||||||
Line: d.Line,
|
Line: d.Line,
|
||||||
File: d.File,
|
File: d.File,
|
||||||
Message: fmt.Sprintf("return %d is not a redirect; only 301/302 are supported at server level", code),
|
Message: fmt.Sprintf("return %d is not a redirect; only 301/302 are supported at server level", code),
|
||||||
@ -554,7 +555,7 @@ func classifyLocation(d *Directive, serverRoot string, result *ConvertResult) lo
|
|||||||
hasRootOrAlias = true
|
hasRootOrAlias = true
|
||||||
case "try_files":
|
case "try_files":
|
||||||
hasTryFiles = true
|
hasTryFiles = true
|
||||||
case "return":
|
case returnDirective:
|
||||||
if len(d.Block[i].Args) > 0 {
|
if len(d.Block[i].Args) > 0 {
|
||||||
code, err := strconv.Atoi(d.Block[i].Args[0])
|
code, err := strconv.Atoi(d.Block[i].Args[0])
|
||||||
if err == nil && (code == 301 || code == 302) {
|
if err == nil && (code == 301 || code == 302) {
|
||||||
@ -876,7 +877,7 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
|
|||||||
for i := range directives {
|
for i := range directives {
|
||||||
d := &directives[i]
|
d := &directives[i]
|
||||||
|
|
||||||
if d.Name != "return" {
|
if d.Name != returnDirective {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -900,7 +901,7 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
|
|||||||
Flag: "permanent",
|
Flag: "permanent",
|
||||||
})
|
})
|
||||||
result.Warnings = append(result.Warnings, Warning{
|
result.Warnings = append(result.Warnings, Warning{
|
||||||
Directive: "return",
|
Directive: returnDirective,
|
||||||
Line: d.Line,
|
Line: d.Line,
|
||||||
File: d.File,
|
File: d.File,
|
||||||
Message: "return 301 converted to rewrite rule with permanent flag",
|
Message: "return 301 converted to rewrite rule with permanent flag",
|
||||||
@ -912,14 +913,14 @@ func convertRedirectDirectives(directives []Directive, locPath string, server *c
|
|||||||
Flag: "redirect",
|
Flag: "redirect",
|
||||||
})
|
})
|
||||||
result.Warnings = append(result.Warnings, Warning{
|
result.Warnings = append(result.Warnings, Warning{
|
||||||
Directive: "return",
|
Directive: returnDirective,
|
||||||
Line: d.Line,
|
Line: d.Line,
|
||||||
File: d.File,
|
File: d.File,
|
||||||
Message: "return 302 converted to rewrite rule with redirect flag",
|
Message: "return 302 converted to rewrite rule with redirect flag",
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
result.Warnings = append(result.Warnings, Warning{
|
result.Warnings = append(result.Warnings, Warning{
|
||||||
Directive: "return",
|
Directive: returnDirective,
|
||||||
Line: d.Line,
|
Line: d.Line,
|
||||||
File: d.File,
|
File: d.File,
|
||||||
Message: fmt.Sprintf("return %d in location is not a redirect; only 301/302 are supported", code),
|
Message: fmt.Sprintf("return %d in location is not a redirect; only 301/302 are supported", code),
|
||||||
|
|||||||
@ -71,7 +71,6 @@ func TestProxyRequestHeaders(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证代理配置已设置
|
// 验证代理配置已设置
|
||||||
assert.NotNil(t, cfg.Headers.SetRequest)
|
assert.NotNil(t, cfg.Headers.SetRequest)
|
||||||
assert.Equal(t, "custom-value", cfg.Headers.SetRequest["X-Custom-Header"])
|
assert.Equal(t, "custom-value", cfg.Headers.SetRequest["X-Custom-Header"])
|
||||||
@ -101,7 +100,6 @@ func TestProxyResponseHeaders(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证响应头配置
|
// 验证响应头配置
|
||||||
assert.Equal(t, "lolly", cfg.Headers.SetResponse["X-Server"])
|
assert.Equal(t, "lolly", cfg.Headers.SetResponse["X-Server"])
|
||||||
assert.Contains(t, cfg.Headers.Remove, "X-Powered-By")
|
assert.Contains(t, cfg.Headers.Remove, "X-Powered-By")
|
||||||
@ -125,7 +123,6 @@ func TestProxyTimeout(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证超时配置
|
// 验证超时配置
|
||||||
assert.Equal(t, 1*time.Second, cfg.Timeout.Connect)
|
assert.Equal(t, 1*time.Second, cfg.Timeout.Connect)
|
||||||
assert.Equal(t, 50*time.Millisecond, cfg.Timeout.Read)
|
assert.Equal(t, 50*time.Millisecond, cfg.Timeout.Read)
|
||||||
@ -150,7 +147,6 @@ func TestProxyLoadBalanceRoundRobin(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证负载均衡器类型
|
// 验证负载均衡器类型
|
||||||
assert.Equal(t, "round_robin", cfg.LoadBalance)
|
assert.Equal(t, "round_robin", cfg.LoadBalance)
|
||||||
assert.Len(t, targets, 2)
|
assert.Len(t, targets, 2)
|
||||||
@ -174,7 +170,6 @@ func TestProxyWeightedRoundRobin(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证权重配置
|
// 验证权重配置
|
||||||
assert.Equal(t, 3, targets[0].Weight)
|
assert.Equal(t, 3, targets[0].Weight)
|
||||||
assert.Equal(t, 1, targets[1].Weight)
|
assert.Equal(t, 1, targets[1].Weight)
|
||||||
@ -198,7 +193,6 @@ func TestProxyLeastConn(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
assert.Equal(t, "least_conn", cfg.LoadBalance)
|
assert.Equal(t, "least_conn", cfg.LoadBalance)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -220,7 +214,6 @@ func TestProxyIPHash(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
assert.Equal(t, "ip_hash", cfg.LoadBalance)
|
assert.Equal(t, "ip_hash", cfg.LoadBalance)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,7 +233,6 @@ func TestProxyConsistentHash(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
assert.Equal(t, "consistent_hash", cfg.LoadBalance)
|
assert.Equal(t, "consistent_hash", cfg.LoadBalance)
|
||||||
assert.Equal(t, "uri", cfg.HashKey)
|
assert.Equal(t, "uri", cfg.HashKey)
|
||||||
assert.Equal(t, 150, cfg.VirtualNodes)
|
assert.Equal(t, 150, cfg.VirtualNodes)
|
||||||
@ -268,7 +260,6 @@ func TestProxyErrorHandling(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证 MaxFails 配置 (int64 类型)
|
// 验证 MaxFails 配置 (int64 类型)
|
||||||
assert.Equal(t, int64(3), targets[0].MaxFails)
|
assert.Equal(t, int64(3), targets[0].MaxFails)
|
||||||
assert.Equal(t, 10*time.Second, targets[0].FailTimeout)
|
assert.Equal(t, 10*time.Second, targets[0].FailTimeout)
|
||||||
@ -300,7 +291,6 @@ func TestProxyCacheConfig(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证缓存配置
|
// 验证缓存配置
|
||||||
assert.True(t, cfg.Cache.Enabled)
|
assert.True(t, cfg.Cache.Enabled)
|
||||||
assert.Equal(t, 60*time.Second, cfg.Cache.MaxAge)
|
assert.Equal(t, 60*time.Second, cfg.Cache.MaxAge)
|
||||||
@ -330,7 +320,6 @@ func TestProxyNextUpstream(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证故障转移配置
|
// 验证故障转移配置
|
||||||
assert.Equal(t, 3, cfg.NextUpstream.Tries)
|
assert.Equal(t, 3, cfg.NextUpstream.Tries)
|
||||||
assert.Contains(t, cfg.NextUpstream.HTTPCodes, 502)
|
assert.Contains(t, cfg.NextUpstream.HTTPCodes, 502)
|
||||||
@ -360,7 +349,6 @@ func TestProxyHealthCheck(t *testing.T) {
|
|||||||
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
_, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
||||||
// 验证健康检查配置
|
// 验证健康检查配置
|
||||||
assert.Equal(t, 10*time.Second, cfg.HealthCheck.Interval)
|
assert.Equal(t, 10*time.Second, cfg.HealthCheck.Interval)
|
||||||
assert.Equal(t, "/health", cfg.HealthCheck.Path)
|
assert.Equal(t, "/health", cfg.HealthCheck.Path)
|
||||||
|
|||||||
@ -12,9 +12,12 @@ package loadbalance
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// createHealthyTarget 创建并返回一个指定 URL 和健康状态的 Target 实例。
|
// createHealthyTarget 创建并返回一个指定 URL 和健康状态的 Target 实例。
|
||||||
@ -940,6 +943,8 @@ func TestIsValidAlgorithm(t *testing.T) {
|
|||||||
{"ip_hash", "ip_hash", true},
|
{"ip_hash", "ip_hash", true},
|
||||||
{"consistent_hash", "consistent_hash", true},
|
{"consistent_hash", "consistent_hash", true},
|
||||||
{"random", "random", true},
|
{"random", "random", true},
|
||||||
|
{"least_time", "least_time", true},
|
||||||
|
{"sticky", "sticky", true},
|
||||||
{"invalid", "invalid", false},
|
{"invalid", "invalid", false},
|
||||||
{"empty", "", true}, // 空字符串有效(使用默认值)
|
{"empty", "", true}, // 空字符串有效(使用默认值)
|
||||||
{"unknown", "unknown-algorithm", false},
|
{"unknown", "unknown-algorithm", false},
|
||||||
@ -2091,3 +2096,94 @@ func TestRandomBalancer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestBalancerIntegration_LeastTime 验证 Least Time 算法 consistently 选择更快目标。
|
||||||
|
func TestBalancerIntegration_LeastTime(t *testing.T) {
|
||||||
|
targets := []*Target{
|
||||||
|
NewTargetFromConfig("http://slow:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
NewTargetFromConfig("http://fast:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||||
|
|
||||||
|
// 模拟:slow 目标有 100ms avg,fast 有 10ms avg
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
targets[0].Stats.Record(50*time.Millisecond, 100*time.Millisecond)
|
||||||
|
targets[1].Stats.Record(5*time.Millisecond, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 选择 100 次,应该 mostly 选 fast
|
||||||
|
fastCount := 0
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
selected := lt.Select(targets)
|
||||||
|
if selected.URL == "http://fast:8080" {
|
||||||
|
fastCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if fastCount < 80 {
|
||||||
|
t.Errorf("fast target selected %d/100 times, expected >80", fastCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBalancerIntegration_StickyWithLeastTimeFallback 验证 Sticky 在目标不健康时 fallback。
|
||||||
|
func TestBalancerIntegration_StickyWithLeastTimeFallback(t *testing.T) {
|
||||||
|
fallback := NewLeastTime("last_byte", time.Millisecond)
|
||||||
|
config := StickyConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Name: "test_route",
|
||||||
|
Expires: time.Hour,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
sticky := NewStickySession(config, fallback)
|
||||||
|
sticky.Start()
|
||||||
|
defer sticky.Stop()
|
||||||
|
|
||||||
|
targets := []*Target{
|
||||||
|
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
|
||||||
|
// 首次请求
|
||||||
|
selected1 := sticky.Select(ctx, targets)
|
||||||
|
if selected1 == nil {
|
||||||
|
t.Fatal("expected a target")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 cookie 已设置
|
||||||
|
cookie := ctx.Response.Header.PeekCookie("test_route")
|
||||||
|
if len(cookie) == 0 {
|
||||||
|
t.Fatal("expected cookie")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使 selected1 不健康
|
||||||
|
selected1.Healthy.Store(false)
|
||||||
|
|
||||||
|
// 第二次请求带 cookie 应该 fallback
|
||||||
|
ctx2 := &fasthttp.RequestCtx{}
|
||||||
|
ctx2.Request.Header.SetCookie("test_route", string(extractCookieValue(cookie)))
|
||||||
|
|
||||||
|
selected2 := sticky.Select(ctx2, targets)
|
||||||
|
if selected2 == nil {
|
||||||
|
t.Fatal("expected fallback target")
|
||||||
|
}
|
||||||
|
if selected2.URL == selected1.URL {
|
||||||
|
t.Error("expected different target after fallback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractCookieValue 从 Set-Cookie header 中提取 cookie 值
|
||||||
|
func extractCookieValue(cookieHeader []byte) []byte {
|
||||||
|
s := string(cookieHeader)
|
||||||
|
// Format: "name=value; ..."
|
||||||
|
parts := strings.SplitN(s, "=", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
valueParts := strings.SplitN(parts[1], ";", 2)
|
||||||
|
return []byte(valueParts[0])
|
||||||
|
}
|
||||||
|
|||||||
@ -14,10 +14,12 @@ type EWMAStats struct {
|
|||||||
|
|
||||||
const defaultAlphaScale = 300 // alpha = 0.3
|
const defaultAlphaScale = 300 // alpha = 0.3
|
||||||
|
|
||||||
|
// NewEWMAStats 创建新的 EWMA 统计器。
|
||||||
func NewEWMAStats() *EWMAStats {
|
func NewEWMAStats() *EWMAStats {
|
||||||
return &EWMAStats{}
|
return &EWMAStats{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record 记录一次响应时间样本。
|
||||||
func (e *EWMAStats) Record(headerTime, lastByteTime time.Duration) {
|
func (e *EWMAStats) Record(headerTime, lastByteTime time.Duration) {
|
||||||
e.recordAtomic(&e.headerTime, headerTime)
|
e.recordAtomic(&e.headerTime, headerTime)
|
||||||
e.recordAtomic(&e.lastByteTime, lastByteTime)
|
e.recordAtomic(&e.lastByteTime, lastByteTime)
|
||||||
@ -41,18 +43,22 @@ func (e *EWMAStats) recordAtomic(ptr *atomic.Int64, newValue time.Duration) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HeaderTime 返回首字节时间的 EWMA 值。
|
||||||
func (e *EWMAStats) HeaderTime() time.Duration {
|
func (e *EWMAStats) HeaderTime() time.Duration {
|
||||||
return time.Duration(e.headerTime.Load())
|
return time.Duration(e.headerTime.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LastByteTime 返回完整响应时间的 EWMA 值。
|
||||||
func (e *EWMAStats) LastByteTime() time.Duration {
|
func (e *EWMAStats) LastByteTime() time.Duration {
|
||||||
return time.Duration(e.lastByteTime.Load())
|
return time.Duration(e.lastByteTime.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SampleCount 返回已记录的样本数量。
|
||||||
func (e *EWMAStats) SampleCount() int64 {
|
func (e *EWMAStats) SampleCount() int64 {
|
||||||
return e.sampleCount.Load()
|
return e.sampleCount.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset 重置所有统计数据。
|
||||||
func (e *EWMAStats) Reset() {
|
func (e *EWMAStats) Reset() {
|
||||||
e.headerTime.Store(0)
|
e.headerTime.Store(0)
|
||||||
e.lastByteTime.Store(0)
|
e.lastByteTime.Store(0)
|
||||||
|
|||||||
@ -97,5 +97,7 @@ func (l *LeastTime) GetMetric() string {
|
|||||||
return l.metric
|
return l.metric
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Balancer = (*LeastTime)(nil)
|
var (
|
||||||
var _ ResponseTimeRecorder = (*LeastTime)(nil)
|
_ Balancer = (*LeastTime)(nil)
|
||||||
|
_ ResponseTimeRecorder = (*LeastTime)(nil)
|
||||||
|
)
|
||||||
|
|||||||
96
internal/loadbalance/least_time_bench_test.go
Normal file
96
internal/loadbalance/least_time_bench_test.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package loadbalance
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkLeastTime_Select(b *testing.B) {
|
||||||
|
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||||
|
targets := []*Target{
|
||||||
|
NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
NewTargetFromConfig("http://c:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-populate stats
|
||||||
|
for _, t := range targets {
|
||||||
|
t.Stats.Record(10*time.Millisecond, 20*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
lt.Select(targets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLeastTime_Record(b *testing.B) {
|
||||||
|
stats := NewEWMAStats()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
stats.Record(10*time.Millisecond, 20*time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLeastTime_Concurrent(b *testing.B) {
|
||||||
|
lt := NewLeastTime("last_byte", time.Millisecond)
|
||||||
|
targets := []*Target{
|
||||||
|
NewTargetFromConfig("http://a:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
NewTargetFromConfig("http://b:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
lt.Select(targets)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStickySession_Select(b *testing.B) {
|
||||||
|
fallback := NewRoundRobin()
|
||||||
|
config := DefaultStickyConfig()
|
||||||
|
|
||||||
|
sticky := NewStickySession(config, fallback)
|
||||||
|
sticky.Start()
|
||||||
|
defer sticky.Stop()
|
||||||
|
|
||||||
|
targets := []*Target{
|
||||||
|
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-populate a cookie
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
sticky.Select(ctx, targets)
|
||||||
|
cookie := ctx.Response.Header.PeekCookie(config.Name)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.Header.SetCookie(config.Name, string(extractCookieValue(cookie)))
|
||||||
|
sticky.Select(ctx, targets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkStickySession_SelectNew(b *testing.B) {
|
||||||
|
fallback := NewRoundRobin()
|
||||||
|
config := DefaultStickyConfig()
|
||||||
|
|
||||||
|
sticky := NewStickySession(config, fallback)
|
||||||
|
sticky.Start()
|
||||||
|
defer sticky.Stop()
|
||||||
|
|
||||||
|
targets := []*Target{
|
||||||
|
NewTargetFromConfig("http://backend1:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
NewTargetFromConfig("http://backend2:8080", 1, 0, 0, 0, false, false, ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
sticky.Select(ctx, targets)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,6 +2,7 @@ package loadbalance
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
// StickyConfig 配置 Sticky 负载均衡的 Cookie 参数。
|
||||||
type StickyConfig struct {
|
type StickyConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
@ -13,6 +14,7 @@ type StickyConfig struct {
|
|||||||
SameSite string `yaml:"same_site"`
|
SameSite string `yaml:"same_site"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DefaultStickyConfig 返回 Sticky 负载均衡的默认配置。
|
||||||
func DefaultStickyConfig() StickyConfig {
|
func DefaultStickyConfig() StickyConfig {
|
||||||
return StickyConfig{
|
return StickyConfig{
|
||||||
Name: "lolly_route",
|
Name: "lolly_route",
|
||||||
|
|||||||
@ -90,6 +90,8 @@ const (
|
|||||||
lbIPHash = "ip_hash" // IP 哈希
|
lbIPHash = "ip_hash" // IP 哈希
|
||||||
lbConsistentHash = "consistent_hash" // 一致性哈希
|
lbConsistentHash = "consistent_hash" // 一致性哈希
|
||||||
lbRandom = "random" // 随机(Power of Two Choices)
|
lbRandom = "random" // 随机(Power of Two Choices)
|
||||||
|
lbLeastTime = "least_time" // 最小响应时间
|
||||||
|
lbSticky = "sticky" // 会话粘性
|
||||||
)
|
)
|
||||||
|
|
||||||
// headersPool 复用缓存 headers map,减少分配。
|
// headersPool 复用缓存 headers map,减少分配。
|
||||||
@ -229,6 +231,22 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
|
|||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stickyBalancer wraps StickySession to implement loadbalance.Balancer.
|
||||||
|
// It delegates Select/SelectExcluding to the fallback balancer while
|
||||||
|
// allowing the proxy to access the StickySession for cookie-based routing.
|
||||||
|
type stickyBalancer struct {
|
||||||
|
sticky *loadbalance.StickySession
|
||||||
|
fallback loadbalance.Balancer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *stickyBalancer) Select(targets []*loadbalance.Target) *loadbalance.Target {
|
||||||
|
return b.fallback.Select(targets)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *stickyBalancer) SelectExcluding(targets []*loadbalance.Target, excluded []*loadbalance.Target) *loadbalance.Target {
|
||||||
|
return b.fallback.SelectExcluding(targets, excluded)
|
||||||
|
}
|
||||||
|
|
||||||
// createBalancerByName 根据算法名称创建负载均衡器。
|
// createBalancerByName 根据算法名称创建负载均衡器。
|
||||||
//
|
//
|
||||||
// 支持的算法:
|
// 支持的算法:
|
||||||
@ -263,6 +281,49 @@ func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Bal
|
|||||||
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
|
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
|
||||||
case lbRandom:
|
case lbRandom:
|
||||||
return loadbalance.NewRandom(), nil
|
return loadbalance.NewRandom(), nil
|
||||||
|
case lbLeastTime:
|
||||||
|
metric := cfg.LeastTime.Metric
|
||||||
|
if metric == "" {
|
||||||
|
metric = "last_byte"
|
||||||
|
}
|
||||||
|
defaultTime := cfg.LeastTime.DefaultTime
|
||||||
|
if defaultTime <= 0 {
|
||||||
|
defaultTime = time.Millisecond
|
||||||
|
}
|
||||||
|
return loadbalance.NewLeastTime(metric, defaultTime), nil
|
||||||
|
case lbSticky:
|
||||||
|
stickyCfg := loadbalance.StickyConfig{
|
||||||
|
Enabled: cfg.Sticky.Enabled,
|
||||||
|
Name: cfg.Sticky.Name,
|
||||||
|
Expires: cfg.Sticky.Expires,
|
||||||
|
Domain: cfg.Sticky.Domain,
|
||||||
|
Path: cfg.Sticky.Path,
|
||||||
|
Secure: cfg.Sticky.Secure,
|
||||||
|
HttpOnly: cfg.Sticky.HttpOnly,
|
||||||
|
SameSite: cfg.Sticky.SameSite,
|
||||||
|
}
|
||||||
|
if stickyCfg.Name == "" {
|
||||||
|
stickyCfg.Name = "lolly_route"
|
||||||
|
}
|
||||||
|
if stickyCfg.Expires <= 0 {
|
||||||
|
stickyCfg.Expires = time.Hour
|
||||||
|
}
|
||||||
|
if stickyCfg.Path == "" {
|
||||||
|
stickyCfg.Path = "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbackAlgo := cfg.Sticky.FallbackAlgo
|
||||||
|
if fallbackAlgo == "" {
|
||||||
|
fallbackAlgo = lbRoundRobin
|
||||||
|
}
|
||||||
|
fallbackBalancer, err := createBalancerByName(fallbackAlgo, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sticky fallback balancer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sticky := loadbalance.NewStickySession(stickyCfg, fallbackBalancer)
|
||||||
|
sticky.Start()
|
||||||
|
return &stickyBalancer{sticky: sticky, fallback: fallbackBalancer}, nil
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported load balance algorithm: " + name)
|
return nil, errors.New("unsupported load balance algorithm: " + name)
|
||||||
}
|
}
|
||||||
@ -813,6 +874,12 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
// 记录首字节时间
|
// 记录首字节时间
|
||||||
timing.MarkHeaderReceived()
|
timing.MarkHeaderReceived()
|
||||||
|
|
||||||
|
// 记录首字节响应时间(用于 least_time 负载均衡)
|
||||||
|
if recorder, ok := p.balancer.(loadbalance.ResponseTimeRecorder); ok {
|
||||||
|
headerTime := timing.headerReceived.Sub(timing.connectEnd)
|
||||||
|
recorder.RecordResponseTime(target, headerTime, 0)
|
||||||
|
}
|
||||||
|
|
||||||
// 请求成功,减少连接计数
|
// 请求成功,减少连接计数
|
||||||
loadbalance.DecrementConnections(target)
|
loadbalance.DecrementConnections(target)
|
||||||
|
|
||||||
@ -917,6 +984,15 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
|||||||
|
|
||||||
// 修改响应头
|
// 修改响应头
|
||||||
p.modifyResponseHeaders(ctx)
|
p.modifyResponseHeaders(ctx)
|
||||||
|
|
||||||
|
// 记录完整响应时间(用于 least_time 负载均衡)
|
||||||
|
timing.MarkResponseEnd()
|
||||||
|
if recorder, ok := p.balancer.(loadbalance.ResponseTimeRecorder); ok {
|
||||||
|
headerTime := timing.headerReceived.Sub(timing.connectEnd)
|
||||||
|
lastByteTime := timing.responseEnd.Sub(timing.connectEnd)
|
||||||
|
recorder.RecordResponseTime(target, headerTime, lastByteTime)
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -119,6 +119,11 @@ func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalanc
|
|||||||
balancer := p.balancer
|
balancer := p.balancer
|
||||||
p.mu.RUnlock()
|
p.mu.RUnlock()
|
||||||
|
|
||||||
|
// 对于 StickySession 负载均衡器,需要请求上下文
|
||||||
|
if sb, ok := balancer.(*stickyBalancer); ok {
|
||||||
|
return sb.sticky.Select(ctx, targets)
|
||||||
|
}
|
||||||
|
|
||||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||||
clientIP := netutil.ExtractClientIP(ctx)
|
clientIP := netutil.ExtractClientIP(ctx)
|
||||||
|
|||||||
@ -12,6 +12,7 @@
|
|||||||
package stream
|
package stream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -63,7 +64,7 @@ func TestStart_NoListeners(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, s.running.Load())
|
assert.True(t, s.running.Load())
|
||||||
|
|
||||||
s.running.Store(false)
|
s.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStart_WithTCPListeners(t *testing.T) {
|
func TestStart_WithTCPListeners(t *testing.T) {
|
||||||
@ -81,12 +82,7 @@ func TestStart_WithTCPListeners(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, s.running.Load())
|
assert.True(t, s.running.Load())
|
||||||
|
|
||||||
s.running.Store(false)
|
s.Stop()
|
||||||
s.mu.RLock()
|
|
||||||
for _, ln := range s.listeners {
|
|
||||||
_ = ln.Close()
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStart_AcceptConnections(t *testing.T) {
|
func TestStart_AcceptConnections(t *testing.T) {
|
||||||
@ -115,7 +111,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
|||||||
proxyAddr := ln.Addr().String()
|
proxyAddr := ln.Addr().String()
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.listeners[proxyAddr] = ln
|
s.listeners["test"] = ln
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
err = s.Start()
|
err = s.Start()
|
||||||
@ -136,12 +132,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
|||||||
|
|
||||||
_ = clientConn.Close()
|
_ = clientConn.Close()
|
||||||
|
|
||||||
s.running.Store(false)
|
s.Stop()
|
||||||
s.mu.RLock()
|
|
||||||
for _, l := range s.listeners {
|
|
||||||
_ = l.Close()
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
|
||||||
_ = backendLn.Close()
|
_ = backendLn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -637,3 +628,138 @@ func TestStartCleanupTicker_StopsOnSignal(t *testing.T) {
|
|||||||
t.Fatal("startCleanupTicker did not stop after signal")
|
t.Fatal("startCleanupTicker did not stop after signal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleConnection_MultipleUpstreams(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
backend1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
backend2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer backend1.Close()
|
||||||
|
defer backend2.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := backend1.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
_, _ = conn.Write(append([]byte("backend1:"), buf[:n]...))
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := backend2.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
_, _ = conn.Write(append([]byte("backend2:"), buf[:n]...))
|
||||||
|
}()
|
||||||
|
|
||||||
|
_ = s.AddUpstream("upstream1", []TargetSpec{{Addr: backend1.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||||
|
_ = s.AddUpstream("upstream2", []TargetSpec{{Addr: backend2.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||||
|
s.upstreams["upstream1"].targets[0].healthy.Store(true)
|
||||||
|
s.upstreams["upstream2"].targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
ln1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.listeners["upstream1"] = ln1
|
||||||
|
s.listeners["upstream2"] = ln2
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
err = s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
conn1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn1.Close()
|
||||||
|
|
||||||
|
_, err = conn1.Write([]byte("hello"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
_ = conn1.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err := conn1.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend1:")), "should route to backend1, got: %s", string(buf[:n]))
|
||||||
|
|
||||||
|
conn2, err := net.DialTimeout("tcp", ln2.Addr().String(), 2*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn2.Close()
|
||||||
|
|
||||||
|
_, err = conn2.Write([]byte("world"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err = conn2.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend2:")), "should route to backend2, got: %s", string(buf[:n]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStop(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer backendLn.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := backendLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = io.Copy(conn, conn)
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_ = s.AddUpstream("test", []TargetSpec{{Addr: backendLn.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||||
|
s.upstreams["test"].targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
s.mu.Lock()
|
||||||
|
s.listeners["test"] = ln
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
err = s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_ = clientConn.Close()
|
||||||
|
|
||||||
|
s.Stop()
|
||||||
|
|
||||||
|
assert.False(t, s.running.Load())
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
assert.Empty(t, s.listeners)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStop_Idempotent(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
s.Stop()
|
||||||
|
s.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStart_DoubleStart(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
err := s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = s.Start()
|
||||||
|
assert.Error(t, err)
|
||||||
|
s.Stop()
|
||||||
|
}
|
||||||
|
|||||||
@ -29,6 +29,7 @@
|
|||||||
package stream
|
package stream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -286,18 +287,14 @@ func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target {
|
|||||||
|
|
||||||
// Server TCP/UDP Stream 代理服务器。
|
// Server TCP/UDP Stream 代理服务器。
|
||||||
type Server struct {
|
type Server struct {
|
||||||
// listeners TCP 监听器映射,按 upstream 名称索引
|
listeners map[string]net.Listener
|
||||||
listeners map[string]net.Listener
|
|
||||||
// udpServers UDP 服务器映射
|
|
||||||
udpServers map[string]*udpServer
|
udpServers map[string]*udpServer
|
||||||
// upstreams 上游配置映射
|
upstreams map[string]*Upstream
|
||||||
upstreams map[string]*Upstream
|
connCount atomic.Int64
|
||||||
// connCount 当前连接数
|
mu sync.RWMutex
|
||||||
connCount atomic.Int64
|
running atomic.Bool
|
||||||
// mu 读写锁,保护并发访问
|
wg sync.WaitGroup
|
||||||
mu sync.RWMutex
|
stopCh chan struct{}
|
||||||
// running 运行状态标志
|
|
||||||
running atomic.Bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upstream Stream 上游配置。
|
// Upstream Stream 上游配置。
|
||||||
@ -393,6 +390,7 @@ func NewServer() *Server {
|
|||||||
listeners: make(map[string]net.Listener),
|
listeners: make(map[string]net.Listener),
|
||||||
udpServers: make(map[string]*udpServer),
|
udpServers: make(map[string]*udpServer),
|
||||||
upstreams: make(map[string]*Upstream),
|
upstreams: make(map[string]*Upstream),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -491,25 +489,71 @@ func (s *Server) ListenUDP(addr string, upstreamName string, timeout time.Durati
|
|||||||
|
|
||||||
// Start 启动 Stream 服务器。
|
// Start 启动 Stream 服务器。
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
s.running.Store(true)
|
if !s.running.CompareAndSwap(false, true) {
|
||||||
|
return fmt.Errorf("stream server already running")
|
||||||
|
}
|
||||||
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// 启动 TCP 监听器
|
|
||||||
for addr, listener := range s.listeners {
|
for addr, listener := range s.listeners {
|
||||||
go s.acceptLoop(addr, listener)
|
s.wg.Add(1)
|
||||||
|
go func(a string, ln net.Listener) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
s.acceptLoop(a, ln)
|
||||||
|
}(addr, listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动 UDP 服务器
|
|
||||||
for _, udpSrv := range s.udpServers {
|
for _, udpSrv := range s.udpServers {
|
||||||
go udpSrv.serve()
|
s.wg.Add(1)
|
||||||
go udpSrv.startCleanupTicker()
|
go func(u *udpServer) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
u.serve()
|
||||||
|
}(udpSrv)
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func(u *udpServer) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
u.startCleanupTicker()
|
||||||
|
}(udpSrv)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop stops the stream server, closing all listeners and waiting for goroutines to finish.
|
||||||
|
func (s *Server) Stop() {
|
||||||
|
if !s.running.CompareAndSwap(true, false) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(s.stopCh)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
for _, ln := range s.listeners {
|
||||||
|
_ = ln.Close()
|
||||||
|
}
|
||||||
|
for _, udpSrv := range s.udpServers {
|
||||||
|
close(udpSrv.stopCh)
|
||||||
|
if udpSrv.conn != nil {
|
||||||
|
_ = udpSrv.conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, upstream := range s.upstreams {
|
||||||
|
if upstream.healthChk != nil && upstream.healthChk.stopCh != nil {
|
||||||
|
close(upstream.healthChk.stopCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
s.wg.Wait()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.listeners = make(map[string]net.Listener)
|
||||||
|
s.udpServers = make(map[string]*udpServer)
|
||||||
|
s.stopCh = make(chan struct{})
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// acceptLoop 接受连接循环。
|
// acceptLoop 接受连接循环。
|
||||||
//
|
//
|
||||||
// 在单独的 goroutine 中运行,持续接受 TCP 连接。
|
// 在单独的 goroutine 中运行,持续接受 TCP 连接。
|
||||||
@ -523,7 +567,12 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
|||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !s.running.Load() {
|
if !s.running.Load() {
|
||||||
return // 正常关闭
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -544,19 +593,14 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - clientConn: 客户端连接
|
// - clientConn: 客户端连接
|
||||||
// - addr: 监听地址
|
// - addr: 监听地址
|
||||||
func (s *Server) handleConnection(clientConn net.Conn, _ string) {
|
func (s *Server) handleConnection(clientConn net.Conn, addr string) {
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = clientConn.Close()
|
_ = clientConn.Close()
|
||||||
s.connCount.Add(-1)
|
s.connCount.Add(-1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
// 根据监听地址找到对应 upstream(简化:用第一个)
|
upstream := s.upstreams[addr]
|
||||||
var upstream *Upstream
|
|
||||||
for _, up := range s.upstreams {
|
|
||||||
upstream = up
|
|
||||||
break
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if upstream == nil {
|
if upstream == nil {
|
||||||
|
|||||||
@ -157,7 +157,7 @@ func TestHandleConnection_NoHealthyTarget(t *testing.T) {
|
|||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
s.handleConnection(clientConn, "test2")
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -200,7 +200,7 @@ func TestHandleConnection_DialFail(t *testing.T) {
|
|||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
s.handleConnection(clientConn, "test3")
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user