feat(lua): 添加 balancer_by_lua 动态负载均衡功能
- 新增 BalancerByLuaConfig 配置,支持 Lua 脚本控制后端选择 - 实现 api_balancer.go Lua API,暴露 set_current_peer 等函数 - Proxy 集成 Lua 引擎,fallback 到标准算法确保可靠性 - 添加负载均衡算法常量提取,消除魔法字符串 - 支持超时控制和备用算法配置 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
103e8ff0cf
commit
f31e8afeff
3
go.mod
3
go.mod
@ -6,7 +6,9 @@ require (
|
||||
github.com/andybalholm/brotli v1.2.1
|
||||
github.com/fasthttp/router v1.5.4
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/klauspost/compress v1.18.5
|
||||
github.com/oschwald/geoip2-golang v1.13.0
|
||||
github.com/quic-go/quic-go v0.59.0
|
||||
github.com/rs/zerolog v1.35.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
@ -22,6 +24,7 @@ require (
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/oschwald/maxminddb-golang v1.13.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761 // indirect
|
||||
|
||||
25
go.sum
25
go.sum
@ -1,5 +1,3 @@
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro=
|
||||
github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
@ -9,8 +7,8 @@ github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8
|
||||
github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
|
||||
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
@ -19,10 +17,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/oschwald/geoip2-golang v1.13.0 h1:Q44/Ldc703pasJeP5V9+aFSZFmBN7DKHbNsSFzQATJI=
|
||||
github.com/oschwald/geoip2-golang v1.13.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo=
|
||||
github.com/oschwald/maxminddb-golang v1.13.0 h1:R8xBorY71s84yO06NgTmQvqvTvlS/bnYZrrWX1MElnU=
|
||||
github.com/oschwald/maxminddb-golang v1.13.0/go.mod h1:BU0z8BfFVhi1LQaonTwwGQlsHUEu9pWNdMfmq4ztm0o=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
@ -33,16 +33,12 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
|
||||
github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
|
||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc=
|
||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
|
||||
github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761 h1:McifyVxygw1d67y6vxUqls2D46J8W9nrki9c8c0eVvE=
|
||||
github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761/go.mod h1:Vi9gvHvTw4yCUHIznFl5TPULS7aXwgaTByGeBY75Wko=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
|
||||
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
|
||||
github.com/valyala/fasthttp v1.70.0 h1:LAhMGcWk13QZWm85+eg8ZBNbrq5mnkWFGbHMUJHIdXA=
|
||||
github.com/valyala/fasthttp v1.70.0/go.mod h1:oDZEHHkJ/Buyklg6uURmYs19442zFSnCIfX3j1FY3pE=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
@ -51,21 +47,12 @@ github.com/yuin/gopher-lua v1.1.2 h1:yF/FjE3hD65tBbt0VXLE13HWS9h34fdzJmrWRXwobGA
|
||||
github.com/yuin/gopher-lua v1.1.2/go.mod h1:7aRmXIWl37SqRf0koeyylBEzJ+aPt8A+mmkQ4f1ntR8=
|
||||
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
|
||||
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@ -116,7 +116,7 @@ func generateConfig(outputPath string) int {
|
||||
if outputPath == "" {
|
||||
fmt.Print(string(yamlData))
|
||||
} else {
|
||||
if err := os.WriteFile(outputPath, yamlData, 0644); err != nil {
|
||||
if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "写入文件失败: %v\n", err)
|
||||
return 1
|
||||
}
|
||||
@ -285,7 +285,7 @@ func (a *App) Run() int {
|
||||
a.upgradeMgr = server.NewUpgradeManager(a.srv)
|
||||
if a.pidFile != "" {
|
||||
a.upgradeMgr.SetPidFile(a.pidFile)
|
||||
_ = a.upgradeMgr.WritePid() //nolint:errcheck
|
||||
_ = a.upgradeMgr.WritePid()
|
||||
}
|
||||
|
||||
// 启动信号处理
|
||||
@ -365,7 +365,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v)", timeout))
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(timeout) //nolint:errcheck
|
||||
_ = a.srv.GracefulStop(timeout)
|
||||
return false
|
||||
|
||||
case syscall.SIGTERM, syscall.SIGINT:
|
||||
@ -374,11 +374,15 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second // 默认值
|
||||
}
|
||||
sigTyped := sig.(syscall.Signal) //nolint:errcheck // 类型断言
|
||||
a.logger.LogSignal(sigName(sigTyped), "停止服务器")
|
||||
sigTyped, ok := sig.(syscall.Signal)
|
||||
if !ok {
|
||||
a.logger.LogSignal("unknown", "停止服务器")
|
||||
} else {
|
||||
a.logger.LogSignal(sigName(sigTyped), "停止服务器")
|
||||
}
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.StopWithTimeout(timeout) //nolint:errcheck // 使用新方法
|
||||
_ = a.srv.StopWithTimeout(timeout)
|
||||
return false
|
||||
|
||||
case syscall.SIGHUP:
|
||||
@ -488,7 +492,7 @@ func (a *App) gracefulUpgrade() {
|
||||
}
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(timeout) //nolint:errcheck
|
||||
_ = a.srv.GracefulStop(timeout)
|
||||
}
|
||||
|
||||
// sigName 返回信号名称(用于日志输出)。
|
||||
|
||||
@ -24,7 +24,19 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Target 表示负载均衡的后端服务器目标。
|
||||
// Target 表示 HTTP 代理(L7 层)的负载均衡后端服务器目标。
|
||||
//
|
||||
// HTTP Target 特性(区别于 Stream Target):
|
||||
// - URL 解析:支持完整 URL(如 http://backend:8080),包含协议、路径、查询参数
|
||||
// - DNS 动态解析:resolvedIPs 和 lastResolved 字段支持 DNS TTL 缓存和动态重解析
|
||||
// - Failover 支持:配合 Balancer.SelectExcluding 实现失败节点排除重试
|
||||
// - 一致性哈希:VirtualHashes 支持一致性哈希算法的虚拟节点
|
||||
//
|
||||
// 语义差异说明:
|
||||
// - HTTP 代理工作在应用层(L7),需要处理 URL 和 DNS 解析
|
||||
// - Stream 代理工作在传输层(L4),只需简单 host:port,无需 DNS 缓存
|
||||
// - 因此 HTTP Target 和 Stream Target 必须保持独立定义,不可合并
|
||||
//
|
||||
// 所有字段都设计为使用原子操作进行并发访问(如适用)。
|
||||
type Target struct {
|
||||
resolvedIPs atomic.Pointer[[]string]
|
||||
@ -37,7 +49,17 @@ type Target struct {
|
||||
Healthy atomic.Bool
|
||||
}
|
||||
|
||||
// Balancer 是负载均衡算法的接口。
|
||||
// Balancer 是 HTTP 代理(L7 层)负载均衡算法的接口。
|
||||
//
|
||||
// HTTP Balancer 特性(区别于 Stream Balancer):
|
||||
// - Select(): 标准选择方法,按算法策略选择健康目标
|
||||
// - SelectExcluding(): 故障转移支持,排除失败节点后选择替代目标
|
||||
//
|
||||
// 语义差异说明:
|
||||
// - HTTP 代理需要 failover 重试能力(next_upstream 配置),因此需要 SelectExcluding
|
||||
// - Stream 代理工作在传输层(L4),无重试机制,仅需要 Select 方法
|
||||
// - 因此 HTTP Balancer 和 Stream Balancer 接口签名不同,不可合并
|
||||
//
|
||||
// 实现必须是并发安全的。
|
||||
type Balancer interface {
|
||||
// Select 根据算法策略从提供的列表中选择一个目标。
|
||||
|
||||
140
internal/lua/api_balancer.go
Normal file
140
internal/lua/api_balancer.go
Normal file
@ -0,0 +1,140 @@
|
||||
// Package lua 提供 ngx.balancer API 实现
|
||||
// 本文件实现负载均衡相关的 Lua API,用于在 Lua 脚本中选择后端目标
|
||||
package lua
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
)
|
||||
|
||||
// BalancerContext Lua Balancer 上下文
|
||||
type BalancerContext struct {
|
||||
LastError error
|
||||
Selected *loadbalance.Target
|
||||
ClientIP string
|
||||
Targets []*loadbalance.Target
|
||||
Retries int
|
||||
selected bool
|
||||
}
|
||||
|
||||
// RegisterBalancerAPI 注册 ngx.balancer API
|
||||
func RegisterBalancerAPI(L *glua.LState, bctx *BalancerContext, ngx *glua.LTable) {
|
||||
balancer := L.NewTable()
|
||||
|
||||
// set_current_peer(host, port) 或 set_current_peer(url)
|
||||
L.SetField(balancer, "set_current_peer", L.NewFunction(func(L *glua.LState) int {
|
||||
nargs := L.GetTop()
|
||||
|
||||
var host, port string
|
||||
if nargs >= 2 {
|
||||
// set_current_peer(host, port) 形式
|
||||
host = L.CheckString(1)
|
||||
port = L.CheckString(2)
|
||||
if !strings.HasPrefix(port, ":") {
|
||||
port = ":" + port
|
||||
}
|
||||
} else if nargs == 1 {
|
||||
// set_current_peer(url) 形式
|
||||
targetURL := L.CheckString(1)
|
||||
u, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
L.Push(glua.LBool(false))
|
||||
L.Push(glua.LString("invalid url: " + err.Error()))
|
||||
return 2
|
||||
}
|
||||
host = u.Hostname()
|
||||
port = ":" + u.Port()
|
||||
if u.Port() == "" {
|
||||
if u.Scheme == "https" {
|
||||
port = ":443"
|
||||
} else {
|
||||
port = ":80"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
L.RaiseError("set_current_peer requires 1 or 2 arguments")
|
||||
return 0
|
||||
}
|
||||
|
||||
// 在 Targets 中查找匹配的目标
|
||||
targetURL := "http://" + host + port
|
||||
for _, t := range bctx.Targets {
|
||||
if t.URL == targetURL || strings.HasPrefix(t.URL, targetURL) {
|
||||
bctx.Selected = t
|
||||
bctx.selected = true
|
||||
L.Push(glua.LBool(true))
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
L.Push(glua.LBool(false))
|
||||
L.Push(glua.LString("target not found: " + host + port))
|
||||
return 2
|
||||
}))
|
||||
|
||||
// set_more_tries(count)
|
||||
L.SetField(balancer, "set_more_tries", L.NewFunction(func(L *glua.LState) int {
|
||||
count := L.CheckInt(1)
|
||||
bctx.Retries = count
|
||||
L.Push(glua.LBool(true))
|
||||
return 1
|
||||
}))
|
||||
|
||||
// get_last_failure()
|
||||
L.SetField(balancer, "get_last_failure", L.NewFunction(func(L *glua.LState) int {
|
||||
if bctx.LastError == nil {
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
// 返回失败类型: "failed", "timeout", "next"
|
||||
failType := classifyError(bctx.LastError)
|
||||
L.Push(glua.LString(failType))
|
||||
return 1
|
||||
}))
|
||||
|
||||
// get_targets() - 返回所有可用目标
|
||||
L.SetField(balancer, "get_targets", L.NewFunction(func(L *glua.LState) int {
|
||||
targetsTable := L.NewTable()
|
||||
for i, t := range bctx.Targets {
|
||||
targetTable := L.NewTable()
|
||||
L.SetField(targetTable, "url", glua.LString(t.URL))
|
||||
L.SetField(targetTable, "weight", glua.LNumber(t.Weight))
|
||||
L.SetField(targetTable, "healthy", glua.LBool(t.Healthy.Load()))
|
||||
targetsTable.RawSetInt(i+1, targetTable)
|
||||
}
|
||||
L.Push(targetsTable)
|
||||
return 1
|
||||
}))
|
||||
|
||||
// get_client_ip()
|
||||
L.SetField(balancer, "get_client_ip", L.NewFunction(func(L *glua.LState) int {
|
||||
L.Push(glua.LString(bctx.ClientIP))
|
||||
return 1
|
||||
}))
|
||||
|
||||
L.SetField(ngx, "balancer", balancer)
|
||||
}
|
||||
|
||||
// IsSelected 检查是否调用了 set_current_peer
|
||||
func (bctx *BalancerContext) IsSelected() bool {
|
||||
return bctx.selected
|
||||
}
|
||||
|
||||
// classifyError 分类错误类型
|
||||
func classifyError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
// 根据错误类型返回字符串
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "timeout") {
|
||||
return "timeout"
|
||||
}
|
||||
if strings.Contains(errStr, "connection") {
|
||||
return "failed"
|
||||
}
|
||||
return "failed"
|
||||
}
|
||||
@ -33,6 +33,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
@ -43,10 +44,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
"rua.plus/lolly/internal/resolver"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
@ -57,6 +60,13 @@ const (
|
||||
// upstreamCache 上游缓存标识
|
||||
// 用于标记请求可直接使用缓存响应,无需转发到上游
|
||||
upstreamCache = "CACHE"
|
||||
|
||||
// 负载均衡算法名称
|
||||
lbRoundRobin = "round_robin"
|
||||
lbWeightedRoundRobin = "weighted_round_robin"
|
||||
lbLeastConn = "least_conn"
|
||||
lbIPHash = "ip_hash"
|
||||
lbConsistentHash = "consistent_hash"
|
||||
)
|
||||
|
||||
// Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。
|
||||
@ -67,16 +77,18 @@ const (
|
||||
// - 所有公开方法均为并发安全
|
||||
// - 使用前需确保 targets 中至少有一个健康目标
|
||||
type Proxy struct {
|
||||
balancer loadbalance.Balancer
|
||||
resolver resolver.Resolver
|
||||
clients map[string]*fasthttp.HostClient
|
||||
config *config.ProxyConfig
|
||||
cache *cache.ProxyCache
|
||||
healthChecker *HealthChecker
|
||||
stopCh chan struct{}
|
||||
targets []*loadbalance.Target
|
||||
mu sync.RWMutex
|
||||
started atomic.Bool
|
||||
balancer loadbalance.Balancer
|
||||
fallbackBalancer loadbalance.Balancer // Lua 失败时的备用均衡器
|
||||
resolver resolver.Resolver
|
||||
clients map[string]*fasthttp.HostClient
|
||||
config *config.ProxyConfig
|
||||
cache *cache.ProxyCache
|
||||
healthChecker *HealthChecker
|
||||
luaEngine *lua.LuaEngine // Lua 引擎引用
|
||||
stopCh chan struct{}
|
||||
targets []*loadbalance.Target
|
||||
mu sync.RWMutex
|
||||
started atomic.Bool
|
||||
}
|
||||
|
||||
// NewProxy 使用给定的配置和后台目标创建一个新的反向代理实例。
|
||||
@ -86,11 +98,12 @@ type Proxy struct {
|
||||
// - cfg: 代理配置,包括超时时间、请求头和负载均衡策略
|
||||
// - targets: 要代理请求的后端目标列表
|
||||
// - transportCfg: 可选的 Transport 连接池配置,nil 时使用默认值
|
||||
// - luaEngine: 可选的 Lua 引擎,用于 balancer_by_lua 功能
|
||||
//
|
||||
// 返回值:
|
||||
// - *Proxy: 配置完成并可处理请求的代理实例
|
||||
// - error: 初始化失败时非空(无效配置、没有健康目标等)
|
||||
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig) (*Proxy, error) {
|
||||
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig, luaEngine *lua.LuaEngine) (*Proxy, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("proxy config is nil")
|
||||
}
|
||||
@ -105,12 +118,24 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建 fallback 负载均衡器
|
||||
fallbackAlgo := cfg.BalancerByLua.Fallback
|
||||
if fallbackAlgo == "" {
|
||||
fallbackAlgo = lbRoundRobin
|
||||
}
|
||||
fallbackBalancer, err := createBalancerByName(fallbackAlgo, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create fallback balancer: %w", err)
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
targets: targets,
|
||||
clients: make(map[string]*fasthttp.HostClient),
|
||||
balancer: balancer,
|
||||
config: cfg,
|
||||
stopCh: make(chan struct{}),
|
||||
targets: targets,
|
||||
clients: make(map[string]*fasthttp.HostClient),
|
||||
balancer: balancer,
|
||||
fallbackBalancer: fallbackBalancer,
|
||||
config: cfg,
|
||||
luaEngine: luaEngine,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 为每个后端目标初始化 HostClient
|
||||
@ -138,6 +163,28 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// createBalancerByName 根据算法名称创建负载均衡器
|
||||
func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
|
||||
switch name {
|
||||
case lbRoundRobin, "":
|
||||
return loadbalance.NewRoundRobin(), nil
|
||||
case lbWeightedRoundRobin:
|
||||
return loadbalance.NewWeightedRoundRobin(), nil
|
||||
case lbLeastConn:
|
||||
return loadbalance.NewLeastConnections(), nil
|
||||
case lbIPHash:
|
||||
return loadbalance.NewIPHash(), nil
|
||||
case lbConsistentHash:
|
||||
virtualNodes := cfg.VirtualNodes
|
||||
if virtualNodes <= 0 {
|
||||
virtualNodes = 150
|
||||
}
|
||||
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
|
||||
default:
|
||||
return nil, errors.New("unsupported load balance algorithm: " + name)
|
||||
}
|
||||
}
|
||||
|
||||
// SetHealthChecker 设置健康检查器用于被动健康检查。
|
||||
// 当代理请求失败时,将调用健康检查器的 MarkUnhealthy 方法。
|
||||
func (p *Proxy) SetHealthChecker(hc *HealthChecker) {
|
||||
@ -147,15 +194,15 @@ func (p *Proxy) SetHealthChecker(hc *HealthChecker) {
|
||||
// createBalancer 根据配置的算法创建负载均衡器。
|
||||
func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
|
||||
switch cfg.LoadBalance {
|
||||
case "round_robin", "":
|
||||
case lbRoundRobin, "":
|
||||
return loadbalance.NewRoundRobin(), nil
|
||||
case "weighted_round_robin":
|
||||
case lbWeightedRoundRobin:
|
||||
return loadbalance.NewWeightedRoundRobin(), nil
|
||||
case "least_conn":
|
||||
case lbLeastConn:
|
||||
return loadbalance.NewLeastConnections(), nil
|
||||
case "ip_hash":
|
||||
case lbIPHash:
|
||||
return loadbalance.NewIPHash(), nil
|
||||
case "consistent_hash":
|
||||
case lbConsistentHash:
|
||||
virtualNodes := cfg.VirtualNodes
|
||||
if virtualNodes <= 0 {
|
||||
virtualNodes = 150
|
||||
@ -529,13 +576,10 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
}
|
||||
}
|
||||
|
||||
// selectTarget 使用配置的负载均衡器选择后端目标。
|
||||
// 对于 IP 哈希负载均衡,从请求中提取客户端 IP。
|
||||
// 对于一致性哈希,根据配置的 hash_key 选择目标。
|
||||
// 如果没有可用的健康目标则返回 nil。
|
||||
// selectTarget 使用配置的负载均衡器选择后端目标
|
||||
// 如果启用 Lua balancer,先尝试 Lua 脚本选择
|
||||
func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.balancer
|
||||
targets := p.targets
|
||||
p.mu.RUnlock()
|
||||
|
||||
@ -543,6 +587,96 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否启用 Lua balancer
|
||||
if p.config.BalancerByLua.Enabled && p.config.BalancerByLua.Script != "" && p.luaEngine != nil {
|
||||
target, err := p.selectByLua(ctx, targets)
|
||||
if err != nil {
|
||||
logging.Warn().Err(err).Msg("lua balancer failed, using fallback")
|
||||
// Lua 失败,使用 fallback 算法
|
||||
return p.selectByFallback(ctx, targets)
|
||||
}
|
||||
if target != nil {
|
||||
return target
|
||||
}
|
||||
// Lua 未调用 set_current_peer,使用 fallback
|
||||
logging.Debug().Msg("lua balancer did not select target, using fallback")
|
||||
return p.selectByFallback(ctx, targets)
|
||||
}
|
||||
|
||||
// 使用传统负载均衡算法
|
||||
return p.selectByBalancer(ctx, targets)
|
||||
}
|
||||
|
||||
// selectByLua 使用 Lua 脚本选择目标
|
||||
func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
|
||||
bctx := &lua.BalancerContext{
|
||||
Targets: targets,
|
||||
ClientIP: clientIP,
|
||||
Retries: p.config.NextUpstream.Tries,
|
||||
}
|
||||
|
||||
// 创建 Lua 协程
|
||||
coro, err := p.luaEngine.NewCoroutine(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create lua coroutine: %w", err)
|
||||
}
|
||||
defer coro.Close()
|
||||
|
||||
// 注册 balancer API
|
||||
L := coro.Co
|
||||
ngx, ok := L.GetGlobal("ngx").(*glua.LTable)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("global 'ngx' is not an LTable")
|
||||
}
|
||||
lua.RegisterBalancerAPI(L, bctx, ngx)
|
||||
|
||||
// 设置超时
|
||||
timeout := p.config.BalancerByLua.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 100 * time.Millisecond
|
||||
}
|
||||
|
||||
// 执行脚本(带超时)
|
||||
execCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
coro.ExecutionContext = execCtx
|
||||
|
||||
err = coro.ExecuteFile(p.config.BalancerByLua.Script)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execute lua script: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否调用了 set_current_peer
|
||||
if !bctx.IsSelected() {
|
||||
return nil, nil // 未选择,返回 nil 表示需使用 fallback
|
||||
}
|
||||
|
||||
return bctx.Selected, nil
|
||||
}
|
||||
|
||||
// selectByFallback 使用 fallback 算法选择目标
|
||||
func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.fallbackBalancer
|
||||
p.mu.RUnlock()
|
||||
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectByIP(targets, clientIP)
|
||||
}
|
||||
|
||||
return balancer.Select(targets)
|
||||
}
|
||||
|
||||
// selectByBalancer 使用主负载均衡器选择目标
|
||||
func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.balancer
|
||||
p.mu.RUnlock()
|
||||
|
||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"github.com/valyala/fasthttp/fasthttputil"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
@ -140,7 +139,7 @@ func TestNewProxy(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := NewProxy(tt.cfg, tt.targets, nil)
|
||||
p, err := NewProxy(tt.cfg, tt.targets, nil, nil)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains)
|
||||
@ -185,7 +184,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) {
|
||||
targets[0].Healthy.Store(false)
|
||||
targets[1].Healthy.Store(false)
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -233,7 +232,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -323,7 +322,7 @@ func TestSelectTarget(t *testing.T) {
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, tt.targets, nil)
|
||||
p, err := NewProxy(cfg, tt.targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -437,7 +436,7 @@ func TestModifyRequestHeaders(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -528,7 +527,7 @@ func TestModifyResponseHeaders(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -612,7 +611,7 @@ func TestUpdateTargets(t *testing.T) {
|
||||
{URL: "http://old2:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, initialTargets, nil)
|
||||
p, err := NewProxy(cfg, initialTargets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -661,7 +660,7 @@ func TestGetTargets(t *testing.T) {
|
||||
{URL: "http://backend2:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -690,7 +689,7 @@ func TestGetConfig(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -882,7 +881,7 @@ func TestHandleWebSocket(t *testing.T) {
|
||||
{URL: "http://localhost:8080"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -916,7 +915,7 @@ func TestSetHealthChecker(t *testing.T) {
|
||||
{URL: "http://localhost:8081"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -954,7 +953,7 @@ func TestGetClient(t *testing.T) {
|
||||
{URL: "http://localhost:8082"},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -1018,7 +1017,7 @@ func TestProxyCache(t *testing.T) {
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -1071,7 +1070,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
|
||||
}
|
||||
targets[0].Healthy.Store(true)
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
@ -1156,7 +1155,7 @@ func TestUpstreamVariablesCapture(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create proxy: %v", err)
|
||||
}
|
||||
@ -1239,7 +1238,7 @@ func TestUpstreamVariablesErrorPaths(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p, err := NewProxy(cfg, targets, nil)
|
||||
p, err := NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create proxy: %v", err)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user