From f31e8afeff33a3a3419470edab40ce23f17d73d8 Mon Sep 17 00:00:00 2001 From: xfy Date: Mon, 13 Apr 2026 16:14:59 +0800 Subject: [PATCH] =?UTF-8?q?feat(lua):=20=E6=B7=BB=E5=8A=A0=20balancer=5Fby?= =?UTF-8?q?=5Flua=20=E5=8A=A8=E6=80=81=E8=B4=9F=E8=BD=BD=E5=9D=87=E8=A1=A1?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 BalancerByLuaConfig 配置,支持 Lua 脚本控制后端选择 - 实现 api_balancer.go Lua API,暴露 set_current_peer 等函数 - Proxy 集成 Lua 引擎,fallback 到标准算法确保可靠性 - 添加负载均衡算法常量提取,消除魔法字符串 - 支持超时控制和备用算法配置 Co-Authored-By: Claude Opus 4.6 --- go.mod | 3 + go.sum | 25 +---- internal/app/app.go | 18 +-- internal/loadbalance/balancer.go | 26 ++++- internal/lua/api_balancer.go | 140 +++++++++++++++++++++++ internal/proxy/proxy.go | 186 ++++++++++++++++++++++++++----- internal/proxy/proxy_test.go | 33 +++--- 7 files changed, 360 insertions(+), 71 deletions(-) create mode 100644 internal/lua/api_balancer.go diff --git a/go.mod b/go.mod index bf0d7ac..f06b88f 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,9 @@ require ( github.com/andybalholm/brotli v1.2.1 github.com/fasthttp/router v1.5.4 github.com/google/uuid v1.6.0 + github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/klauspost/compress v1.18.5 + github.com/oschwald/geoip2-golang v1.13.0 github.com/quic-go/quic-go v0.59.0 github.com/rs/zerolog v1.35.0 github.com/stretchr/testify v1.11.1 @@ -22,6 +24,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.21 // indirect + github.com/oschwald/maxminddb-golang v1.13.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761 // indirect diff --git a/go.sum b/go.sum index 12917a4..b4f6b7c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= -github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro= github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -9,8 +7,8 @@ github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8 github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -19,10 +17,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs= github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4= +github.com/oschwald/geoip2-golang v1.13.0 h1:Q44/Ldc703pasJeP5V9+aFSZFmBN7DKHbNsSFzQATJI= +github.com/oschwald/geoip2-golang v1.13.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo= +github.com/oschwald/maxminddb-golang v1.13.0 h1:R8xBorY71s84yO06NgTmQvqvTvlS/bnYZrrWX1MElnU= +github.com/oschwald/maxminddb-golang v1.13.0/go.mod h1:BU0z8BfFVhi1LQaonTwwGQlsHUEu9pWNdMfmq4ztm0o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= @@ -33,16 +33,12 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI= github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= -github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc= -github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761 h1:McifyVxygw1d67y6vxUqls2D46J8W9nrki9c8c0eVvE= github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761/go.mod h1:Vi9gvHvTw4yCUHIznFl5TPULS7aXwgaTByGeBY75Wko= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI= -github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw= github.com/valyala/fasthttp v1.70.0 h1:LAhMGcWk13QZWm85+eg8ZBNbrq5mnkWFGbHMUJHIdXA= github.com/valyala/fasthttp v1.70.0/go.mod h1:oDZEHHkJ/Buyklg6uURmYs19442zFSnCIfX3j1FY3pE= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= @@ -51,21 +47,12 @@ github.com/yuin/gopher-lua v1.1.2 h1:yF/FjE3hD65tBbt0VXLE13HWS9h34fdzJmrWRXwobGA github.com/yuin/gopher-lua v1.1.2/go.mod h1:7aRmXIWl37SqRf0koeyylBEzJ+aPt8A+mmkQ4f1ntR8= go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= -golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= -golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= -golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= -golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= -golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/app/app.go b/internal/app/app.go index 5e714a0..10657fc 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -116,7 +116,7 @@ func generateConfig(outputPath string) int { if outputPath == "" { fmt.Print(string(yamlData)) } else { - if err := os.WriteFile(outputPath, yamlData, 0644); err != nil { + if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil { fmt.Fprintf(os.Stderr, "写入文件失败: %v\n", err) return 1 } @@ -285,7 +285,7 @@ func (a *App) Run() int { a.upgradeMgr = server.NewUpgradeManager(a.srv) if a.pidFile != "" { a.upgradeMgr.SetPidFile(a.pidFile) - _ = a.upgradeMgr.WritePid() //nolint:errcheck + _ = a.upgradeMgr.WritePid() } // 启动信号处理 @@ -365,7 +365,7 @@ func (a *App) handleSignal(sig os.Signal) bool { a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v)", timeout)) a.shutdownHTTP2() a.shutdownHTTP3() - _ = a.srv.GracefulStop(timeout) //nolint:errcheck + _ = a.srv.GracefulStop(timeout) return false case syscall.SIGTERM, syscall.SIGINT: @@ -374,11 +374,15 @@ func (a *App) handleSignal(sig os.Signal) bool { if timeout <= 0 { timeout = 5 * time.Second // 默认值 } - sigTyped := sig.(syscall.Signal) //nolint:errcheck // 类型断言 - a.logger.LogSignal(sigName(sigTyped), "停止服务器") + sigTyped, ok := sig.(syscall.Signal) + if !ok { + a.logger.LogSignal("unknown", "停止服务器") + } else { + a.logger.LogSignal(sigName(sigTyped), "停止服务器") + } a.shutdownHTTP2() a.shutdownHTTP3() - _ = a.srv.StopWithTimeout(timeout) //nolint:errcheck // 使用新方法 + _ = a.srv.StopWithTimeout(timeout) return false case syscall.SIGHUP: @@ -488,7 +492,7 @@ func (a *App) gracefulUpgrade() { } a.shutdownHTTP2() a.shutdownHTTP3() - _ = a.srv.GracefulStop(timeout) //nolint:errcheck + _ = a.srv.GracefulStop(timeout) } // sigName 返回信号名称(用于日志输出)。 diff --git a/internal/loadbalance/balancer.go b/internal/loadbalance/balancer.go index a2984f1..4cddfa0 100644 --- a/internal/loadbalance/balancer.go +++ b/internal/loadbalance/balancer.go @@ -24,7 +24,19 @@ import ( "time" ) -// Target 表示负载均衡的后端服务器目标。 +// Target 表示 HTTP 代理(L7 层)的负载均衡后端服务器目标。 +// +// HTTP Target 特性(区别于 Stream Target): +// - URL 解析:支持完整 URL(如 http://backend:8080),包含协议、路径、查询参数 +// - DNS 动态解析:resolvedIPs 和 lastResolved 字段支持 DNS TTL 缓存和动态重解析 +// - Failover 支持:配合 Balancer.SelectExcluding 实现失败节点排除重试 +// - 一致性哈希:VirtualHashes 支持一致性哈希算法的虚拟节点 +// +// 语义差异说明: +// - HTTP 代理工作在应用层(L7),需要处理 URL 和 DNS 解析 +// - Stream 代理工作在传输层(L4),只需简单 host:port,无需 DNS 缓存 +// - 因此 HTTP Target 和 Stream Target 必须保持独立定义,不可合并 +// // 所有字段都设计为使用原子操作进行并发访问(如适用)。 type Target struct { resolvedIPs atomic.Pointer[[]string] @@ -37,7 +49,17 @@ type Target struct { Healthy atomic.Bool } -// Balancer 是负载均衡算法的接口。 +// Balancer 是 HTTP 代理(L7 层)负载均衡算法的接口。 +// +// HTTP Balancer 特性(区别于 Stream Balancer): +// - Select(): 标准选择方法,按算法策略选择健康目标 +// - SelectExcluding(): 故障转移支持,排除失败节点后选择替代目标 +// +// 语义差异说明: +// - HTTP 代理需要 failover 重试能力(next_upstream 配置),因此需要 SelectExcluding +// - Stream 代理工作在传输层(L4),无重试机制,仅需要 Select 方法 +// - 因此 HTTP Balancer 和 Stream Balancer 接口签名不同,不可合并 +// // 实现必须是并发安全的。 type Balancer interface { // Select 根据算法策略从提供的列表中选择一个目标。 diff --git a/internal/lua/api_balancer.go b/internal/lua/api_balancer.go new file mode 100644 index 0000000..285ee4f --- /dev/null +++ b/internal/lua/api_balancer.go @@ -0,0 +1,140 @@ +// Package lua 提供 ngx.balancer API 实现 +// 本文件实现负载均衡相关的 Lua API,用于在 Lua 脚本中选择后端目标 +package lua + +import ( + "net/url" + "strings" + + glua "github.com/yuin/gopher-lua" + "rua.plus/lolly/internal/loadbalance" +) + +// BalancerContext Lua Balancer 上下文 +type BalancerContext struct { + LastError error + Selected *loadbalance.Target + ClientIP string + Targets []*loadbalance.Target + Retries int + selected bool +} + +// RegisterBalancerAPI 注册 ngx.balancer API +func RegisterBalancerAPI(L *glua.LState, bctx *BalancerContext, ngx *glua.LTable) { + balancer := L.NewTable() + + // set_current_peer(host, port) 或 set_current_peer(url) + L.SetField(balancer, "set_current_peer", L.NewFunction(func(L *glua.LState) int { + nargs := L.GetTop() + + var host, port string + if nargs >= 2 { + // set_current_peer(host, port) 形式 + host = L.CheckString(1) + port = L.CheckString(2) + if !strings.HasPrefix(port, ":") { + port = ":" + port + } + } else if nargs == 1 { + // set_current_peer(url) 形式 + targetURL := L.CheckString(1) + u, err := url.Parse(targetURL) + if err != nil { + L.Push(glua.LBool(false)) + L.Push(glua.LString("invalid url: " + err.Error())) + return 2 + } + host = u.Hostname() + port = ":" + u.Port() + if u.Port() == "" { + if u.Scheme == "https" { + port = ":443" + } else { + port = ":80" + } + } + } else { + L.RaiseError("set_current_peer requires 1 or 2 arguments") + return 0 + } + + // 在 Targets 中查找匹配的目标 + targetURL := "http://" + host + port + for _, t := range bctx.Targets { + if t.URL == targetURL || strings.HasPrefix(t.URL, targetURL) { + bctx.Selected = t + bctx.selected = true + L.Push(glua.LBool(true)) + return 1 + } + } + + L.Push(glua.LBool(false)) + L.Push(glua.LString("target not found: " + host + port)) + return 2 + })) + + // set_more_tries(count) + L.SetField(balancer, "set_more_tries", L.NewFunction(func(L *glua.LState) int { + count := L.CheckInt(1) + bctx.Retries = count + L.Push(glua.LBool(true)) + return 1 + })) + + // get_last_failure() + L.SetField(balancer, "get_last_failure", L.NewFunction(func(L *glua.LState) int { + if bctx.LastError == nil { + L.Push(glua.LNil) + return 1 + } + // 返回失败类型: "failed", "timeout", "next" + failType := classifyError(bctx.LastError) + L.Push(glua.LString(failType)) + return 1 + })) + + // get_targets() - 返回所有可用目标 + L.SetField(balancer, "get_targets", L.NewFunction(func(L *glua.LState) int { + targetsTable := L.NewTable() + for i, t := range bctx.Targets { + targetTable := L.NewTable() + L.SetField(targetTable, "url", glua.LString(t.URL)) + L.SetField(targetTable, "weight", glua.LNumber(t.Weight)) + L.SetField(targetTable, "healthy", glua.LBool(t.Healthy.Load())) + targetsTable.RawSetInt(i+1, targetTable) + } + L.Push(targetsTable) + return 1 + })) + + // get_client_ip() + L.SetField(balancer, "get_client_ip", L.NewFunction(func(L *glua.LState) int { + L.Push(glua.LString(bctx.ClientIP)) + return 1 + })) + + L.SetField(ngx, "balancer", balancer) +} + +// IsSelected 检查是否调用了 set_current_peer +func (bctx *BalancerContext) IsSelected() bool { + return bctx.selected +} + +// classifyError 分类错误类型 +func classifyError(err error) string { + if err == nil { + return "" + } + // 根据错误类型返回字符串 + errStr := err.Error() + if strings.Contains(errStr, "timeout") { + return "timeout" + } + if strings.Contains(errStr, "connection") { + return "failed" + } + return "failed" +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5193bed..b96de67 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -33,6 +33,7 @@ package proxy import ( + "context" "errors" "fmt" "hash/fnv" @@ -43,10 +44,12 @@ import ( "time" "github.com/valyala/fasthttp" + glua "github.com/yuin/gopher-lua" "rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/logging" + "rua.plus/lolly/internal/lua" "rua.plus/lolly/internal/netutil" "rua.plus/lolly/internal/resolver" "rua.plus/lolly/internal/utils" @@ -57,6 +60,13 @@ const ( // upstreamCache 上游缓存标识 // 用于标记请求可直接使用缓存响应,无需转发到上游 upstreamCache = "CACHE" + + // 负载均衡算法名称 + lbRoundRobin = "round_robin" + lbWeightedRoundRobin = "weighted_round_robin" + lbLeastConn = "least_conn" + lbIPHash = "ip_hash" + lbConsistentHash = "consistent_hash" ) // Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。 @@ -67,16 +77,18 @@ const ( // - 所有公开方法均为并发安全 // - 使用前需确保 targets 中至少有一个健康目标 type Proxy struct { - balancer loadbalance.Balancer - resolver resolver.Resolver - clients map[string]*fasthttp.HostClient - config *config.ProxyConfig - cache *cache.ProxyCache - healthChecker *HealthChecker - stopCh chan struct{} - targets []*loadbalance.Target - mu sync.RWMutex - started atomic.Bool + balancer loadbalance.Balancer + fallbackBalancer loadbalance.Balancer // Lua 失败时的备用均衡器 + resolver resolver.Resolver + clients map[string]*fasthttp.HostClient + config *config.ProxyConfig + cache *cache.ProxyCache + healthChecker *HealthChecker + luaEngine *lua.LuaEngine // Lua 引擎引用 + stopCh chan struct{} + targets []*loadbalance.Target + mu sync.RWMutex + started atomic.Bool } // NewProxy 使用给定的配置和后台目标创建一个新的反向代理实例。 @@ -86,11 +98,12 @@ type Proxy struct { // - cfg: 代理配置,包括超时时间、请求头和负载均衡策略 // - targets: 要代理请求的后端目标列表 // - transportCfg: 可选的 Transport 连接池配置,nil 时使用默认值 +// - luaEngine: 可选的 Lua 引擎,用于 balancer_by_lua 功能 // // 返回值: // - *Proxy: 配置完成并可处理请求的代理实例 // - error: 初始化失败时非空(无效配置、没有健康目标等) -func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig) (*Proxy, error) { +func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig, luaEngine *lua.LuaEngine) (*Proxy, error) { if cfg == nil { return nil, errors.New("proxy config is nil") } @@ -105,12 +118,24 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC return nil, err } + // 创建 fallback 负载均衡器 + fallbackAlgo := cfg.BalancerByLua.Fallback + if fallbackAlgo == "" { + fallbackAlgo = lbRoundRobin + } + fallbackBalancer, err := createBalancerByName(fallbackAlgo, cfg) + if err != nil { + return nil, fmt.Errorf("create fallback balancer: %w", err) + } + p := &Proxy{ - targets: targets, - clients: make(map[string]*fasthttp.HostClient), - balancer: balancer, - config: cfg, - stopCh: make(chan struct{}), + targets: targets, + clients: make(map[string]*fasthttp.HostClient), + balancer: balancer, + fallbackBalancer: fallbackBalancer, + config: cfg, + luaEngine: luaEngine, + stopCh: make(chan struct{}), } // 为每个后端目标初始化 HostClient @@ -138,6 +163,28 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC return p, nil } +// createBalancerByName 根据算法名称创建负载均衡器 +func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Balancer, error) { + switch name { + case lbRoundRobin, "": + return loadbalance.NewRoundRobin(), nil + case lbWeightedRoundRobin: + return loadbalance.NewWeightedRoundRobin(), nil + case lbLeastConn: + return loadbalance.NewLeastConnections(), nil + case lbIPHash: + return loadbalance.NewIPHash(), nil + case lbConsistentHash: + virtualNodes := cfg.VirtualNodes + if virtualNodes <= 0 { + virtualNodes = 150 + } + return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil + default: + return nil, errors.New("unsupported load balance algorithm: " + name) + } +} + // SetHealthChecker 设置健康检查器用于被动健康检查。 // 当代理请求失败时,将调用健康检查器的 MarkUnhealthy 方法。 func (p *Proxy) SetHealthChecker(hc *HealthChecker) { @@ -147,15 +194,15 @@ func (p *Proxy) SetHealthChecker(hc *HealthChecker) { // createBalancer 根据配置的算法创建负载均衡器。 func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) { switch cfg.LoadBalance { - case "round_robin", "": + case lbRoundRobin, "": return loadbalance.NewRoundRobin(), nil - case "weighted_round_robin": + case lbWeightedRoundRobin: return loadbalance.NewWeightedRoundRobin(), nil - case "least_conn": + case lbLeastConn: return loadbalance.NewLeastConnections(), nil - case "ip_hash": + case lbIPHash: return loadbalance.NewIPHash(), nil - case "consistent_hash": + case lbConsistentHash: virtualNodes := cfg.VirtualNodes if virtualNodes <= 0 { virtualNodes = 150 @@ -529,13 +576,10 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { } } -// selectTarget 使用配置的负载均衡器选择后端目标。 -// 对于 IP 哈希负载均衡,从请求中提取客户端 IP。 -// 对于一致性哈希,根据配置的 hash_key 选择目标。 -// 如果没有可用的健康目标则返回 nil。 +// selectTarget 使用配置的负载均衡器选择后端目标 +// 如果启用 Lua balancer,先尝试 Lua 脚本选择 func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target { p.mu.RLock() - balancer := p.balancer targets := p.targets p.mu.RUnlock() @@ -543,6 +587,96 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target { return nil } + // 检查是否启用 Lua balancer + if p.config.BalancerByLua.Enabled && p.config.BalancerByLua.Script != "" && p.luaEngine != nil { + target, err := p.selectByLua(ctx, targets) + if err != nil { + logging.Warn().Err(err).Msg("lua balancer failed, using fallback") + // Lua 失败,使用 fallback 算法 + return p.selectByFallback(ctx, targets) + } + if target != nil { + return target + } + // Lua 未调用 set_current_peer,使用 fallback + logging.Debug().Msg("lua balancer did not select target, using fallback") + return p.selectByFallback(ctx, targets) + } + + // 使用传统负载均衡算法 + return p.selectByBalancer(ctx, targets) +} + +// selectByLua 使用 Lua 脚本选择目标 +func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) { + clientIP := netutil.ExtractClientIP(ctx) + + bctx := &lua.BalancerContext{ + Targets: targets, + ClientIP: clientIP, + Retries: p.config.NextUpstream.Tries, + } + + // 创建 Lua 协程 + coro, err := p.luaEngine.NewCoroutine(ctx) + if err != nil { + return nil, fmt.Errorf("create lua coroutine: %w", err) + } + defer coro.Close() + + // 注册 balancer API + L := coro.Co + ngx, ok := L.GetGlobal("ngx").(*glua.LTable) + if !ok { + return nil, fmt.Errorf("global 'ngx' is not an LTable") + } + lua.RegisterBalancerAPI(L, bctx, ngx) + + // 设置超时 + timeout := p.config.BalancerByLua.Timeout + if timeout <= 0 { + timeout = 100 * time.Millisecond + } + + // 执行脚本(带超时) + execCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + coro.ExecutionContext = execCtx + + err = coro.ExecuteFile(p.config.BalancerByLua.Script) + if err != nil { + return nil, fmt.Errorf("execute lua script: %w", err) + } + + // 检查是否调用了 set_current_peer + if !bctx.IsSelected() { + return nil, nil // 未选择,返回 nil 表示需使用 fallback + } + + return bctx.Selected, nil +} + +// selectByFallback 使用 fallback 算法选择目标 +func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target { + p.mu.RLock() + balancer := p.fallbackBalancer + p.mu.RUnlock() + + if ipHash, ok := balancer.(*loadbalance.IPHash); ok { + clientIP := netutil.ExtractClientIP(ctx) + return ipHash.SelectByIP(targets, clientIP) + } + + return balancer.Select(targets) +} + +// selectByBalancer 使用主负载均衡器选择目标 +func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target { + p.mu.RLock() + balancer := p.balancer + p.mu.RUnlock() + // 对于 IPHash 负载均衡器,提取客户端 IP if ipHash, ok := balancer.(*loadbalance.IPHash); ok { clientIP := netutil.ExtractClientIP(ctx) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 9d7f514..68fcc44 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -25,7 +25,6 @@ import ( "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" - "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/netutil" @@ -140,7 +139,7 @@ func TestNewProxy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := NewProxy(tt.cfg, tt.targets, nil) + p, err := NewProxy(tt.cfg, tt.targets, nil, nil) if tt.wantErr { if err == nil { t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains) @@ -185,7 +184,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) { targets[0].Healthy.Store(false) targets[1].Healthy.Store(false) - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -233,7 +232,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -323,7 +322,7 @@ func TestSelectTarget(t *testing.T) { Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } - p, err := NewProxy(cfg, tt.targets, nil) + p, err := NewProxy(cfg, tt.targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -437,7 +436,7 @@ func TestModifyRequestHeaders(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -528,7 +527,7 @@ func TestModifyResponseHeaders(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -612,7 +611,7 @@ func TestUpdateTargets(t *testing.T) { {URL: "http://old2:8080"}, } - p, err := NewProxy(cfg, initialTargets, nil) + p, err := NewProxy(cfg, initialTargets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -661,7 +660,7 @@ func TestGetTargets(t *testing.T) { {URL: "http://backend2:8080"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -690,7 +689,7 @@ func TestGetConfig(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -882,7 +881,7 @@ func TestHandleWebSocket(t *testing.T) { {URL: "http://localhost:8080"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -916,7 +915,7 @@ func TestSetHealthChecker(t *testing.T) { {URL: "http://localhost:8081"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -954,7 +953,7 @@ func TestGetClient(t *testing.T) { {URL: "http://localhost:8082"}, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -1018,7 +1017,7 @@ func TestProxyCache(t *testing.T) { } targets[0].Healthy.Store(true) - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -1071,7 +1070,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) { } targets[0].Healthy.Store(true) - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } @@ -1156,7 +1155,7 @@ func TestUpstreamVariablesCapture(t *testing.T) { }, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("failed to create proxy: %v", err) } @@ -1239,7 +1238,7 @@ func TestUpstreamVariablesErrorPaths(t *testing.T) { }, } - p, err := NewProxy(cfg, targets, nil) + p, err := NewProxy(cfg, targets, nil, nil) if err != nil { t.Fatalf("failed to create proxy: %v", err) }