feat(lua): 添加 balancer_by_lua 动态负载均衡功能

- 新增 BalancerByLuaConfig 配置,支持 Lua 脚本控制后端选择
- 实现 api_balancer.go Lua API,暴露 set_current_peer 等函数
- Proxy 集成 Lua 引擎,fallback 到标准算法确保可靠性
- 添加负载均衡算法常量提取,消除魔法字符串
- 支持超时控制和备用算法配置

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-13 16:14:59 +08:00
parent 103e8ff0cf
commit f31e8afeff
7 changed files with 360 additions and 71 deletions

3
go.mod
View File

@ -6,7 +6,9 @@ require (
github.com/andybalholm/brotli v1.2.1
github.com/fasthttp/router v1.5.4
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/klauspost/compress v1.18.5
github.com/oschwald/geoip2-golang v1.13.0
github.com/quic-go/quic-go v0.59.0
github.com/rs/zerolog v1.35.0
github.com/stretchr/testify v1.11.1
@ -22,6 +24,7 @@ require (
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.21 // indirect
github.com/oschwald/maxminddb-golang v1.13.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761 // indirect

25
go.sum
View File

@ -1,5 +1,3 @@
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro=
github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
@ -9,8 +7,8 @@ github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8
github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@ -19,10 +17,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/oschwald/geoip2-golang v1.13.0 h1:Q44/Ldc703pasJeP5V9+aFSZFmBN7DKHbNsSFzQATJI=
github.com/oschwald/geoip2-golang v1.13.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo=
github.com/oschwald/maxminddb-golang v1.13.0 h1:R8xBorY71s84yO06NgTmQvqvTvlS/bnYZrrWX1MElnU=
github.com/oschwald/maxminddb-golang v1.13.0/go.mod h1:BU0z8BfFVhi1LQaonTwwGQlsHUEu9pWNdMfmq4ztm0o=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
@ -33,16 +33,12 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI=
github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw=
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55+q0lFuGse6U8lxlg7fYTctlT5Gc=
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg=
github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761 h1:McifyVxygw1d67y6vxUqls2D46J8W9nrki9c8c0eVvE=
github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761/go.mod h1:Vi9gvHvTw4yCUHIznFl5TPULS7aXwgaTByGeBY75Wko=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
github.com/valyala/fasthttp v1.70.0 h1:LAhMGcWk13QZWm85+eg8ZBNbrq5mnkWFGbHMUJHIdXA=
github.com/valyala/fasthttp v1.70.0/go.mod h1:oDZEHHkJ/Buyklg6uURmYs19442zFSnCIfX3j1FY3pE=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
@ -51,21 +47,12 @@ github.com/yuin/gopher-lua v1.1.2 h1:yF/FjE3hD65tBbt0VXLE13HWS9h34fdzJmrWRXwobGA
github.com/yuin/gopher-lua v1.1.2/go.mod h1:7aRmXIWl37SqRf0koeyylBEzJ+aPt8A+mmkQ4f1ntR8=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -116,7 +116,7 @@ func generateConfig(outputPath string) int {
if outputPath == "" {
fmt.Print(string(yamlData))
} else {
if err := os.WriteFile(outputPath, yamlData, 0644); err != nil {
if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "写入文件失败: %v\n", err)
return 1
}
@ -285,7 +285,7 @@ func (a *App) Run() int {
a.upgradeMgr = server.NewUpgradeManager(a.srv)
if a.pidFile != "" {
a.upgradeMgr.SetPidFile(a.pidFile)
_ = a.upgradeMgr.WritePid() //nolint:errcheck
_ = a.upgradeMgr.WritePid()
}
// 启动信号处理
@ -365,7 +365,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v", timeout))
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.GracefulStop(timeout) //nolint:errcheck
_ = a.srv.GracefulStop(timeout)
return false
case syscall.SIGTERM, syscall.SIGINT:
@ -374,11 +374,15 @@ func (a *App) handleSignal(sig os.Signal) bool {
if timeout <= 0 {
timeout = 5 * time.Second // 默认值
}
sigTyped := sig.(syscall.Signal) //nolint:errcheck // 类型断言
a.logger.LogSignal(sigName(sigTyped), "停止服务器")
sigTyped, ok := sig.(syscall.Signal)
if !ok {
a.logger.LogSignal("unknown", "停止服务器")
} else {
a.logger.LogSignal(sigName(sigTyped), "停止服务器")
}
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.StopWithTimeout(timeout) //nolint:errcheck // 使用新方法
_ = a.srv.StopWithTimeout(timeout)
return false
case syscall.SIGHUP:
@ -488,7 +492,7 @@ func (a *App) gracefulUpgrade() {
}
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.GracefulStop(timeout) //nolint:errcheck
_ = a.srv.GracefulStop(timeout)
}
// sigName 返回信号名称(用于日志输出)。

View File

@ -24,7 +24,19 @@ import (
"time"
)
// Target 表示负载均衡的后端服务器目标。
// Target 表示 HTTP 代理L7 层)的负载均衡后端服务器目标。
//
// HTTP Target 特性(区别于 Stream Target
// - URL 解析:支持完整 URL如 http://backend:8080包含协议、路径、查询参数
// - DNS 动态解析resolvedIPs 和 lastResolved 字段支持 DNS TTL 缓存和动态重解析
// - Failover 支持:配合 Balancer.SelectExcluding 实现失败节点排除重试
// - 一致性哈希VirtualHashes 支持一致性哈希算法的虚拟节点
//
// 语义差异说明:
// - HTTP 代理工作在应用层L7需要处理 URL 和 DNS 解析
// - Stream 代理工作在传输层L4只需简单 host:port无需 DNS 缓存
// - 因此 HTTP Target 和 Stream Target 必须保持独立定义,不可合并
//
// 所有字段都设计为使用原子操作进行并发访问(如适用)。
type Target struct {
resolvedIPs atomic.Pointer[[]string]
@ -37,7 +49,17 @@ type Target struct {
Healthy atomic.Bool
}
// Balancer 是负载均衡算法的接口。
// Balancer 是 HTTP 代理L7 层)负载均衡算法的接口。
//
// HTTP Balancer 特性(区别于 Stream Balancer
// - Select(): 标准选择方法,按算法策略选择健康目标
// - SelectExcluding(): 故障转移支持,排除失败节点后选择替代目标
//
// 语义差异说明:
// - HTTP 代理需要 failover 重试能力next_upstream 配置),因此需要 SelectExcluding
// - Stream 代理工作在传输层L4无重试机制仅需要 Select 方法
// - 因此 HTTP Balancer 和 Stream Balancer 接口签名不同,不可合并
//
// 实现必须是并发安全的。
type Balancer interface {
// Select 根据算法策略从提供的列表中选择一个目标。

View File

@ -0,0 +1,140 @@
// Package lua 提供 ngx.balancer API 实现
// 本文件实现负载均衡相关的 Lua API用于在 Lua 脚本中选择后端目标
package lua
import (
"net/url"
"strings"
glua "github.com/yuin/gopher-lua"
"rua.plus/lolly/internal/loadbalance"
)
// BalancerContext Lua Balancer 上下文
type BalancerContext struct {
LastError error
Selected *loadbalance.Target
ClientIP string
Targets []*loadbalance.Target
Retries int
selected bool
}
// RegisterBalancerAPI 注册 ngx.balancer API
func RegisterBalancerAPI(L *glua.LState, bctx *BalancerContext, ngx *glua.LTable) {
balancer := L.NewTable()
// set_current_peer(host, port) 或 set_current_peer(url)
L.SetField(balancer, "set_current_peer", L.NewFunction(func(L *glua.LState) int {
nargs := L.GetTop()
var host, port string
if nargs >= 2 {
// set_current_peer(host, port) 形式
host = L.CheckString(1)
port = L.CheckString(2)
if !strings.HasPrefix(port, ":") {
port = ":" + port
}
} else if nargs == 1 {
// set_current_peer(url) 形式
targetURL := L.CheckString(1)
u, err := url.Parse(targetURL)
if err != nil {
L.Push(glua.LBool(false))
L.Push(glua.LString("invalid url: " + err.Error()))
return 2
}
host = u.Hostname()
port = ":" + u.Port()
if u.Port() == "" {
if u.Scheme == "https" {
port = ":443"
} else {
port = ":80"
}
}
} else {
L.RaiseError("set_current_peer requires 1 or 2 arguments")
return 0
}
// 在 Targets 中查找匹配的目标
targetURL := "http://" + host + port
for _, t := range bctx.Targets {
if t.URL == targetURL || strings.HasPrefix(t.URL, targetURL) {
bctx.Selected = t
bctx.selected = true
L.Push(glua.LBool(true))
return 1
}
}
L.Push(glua.LBool(false))
L.Push(glua.LString("target not found: " + host + port))
return 2
}))
// set_more_tries(count)
L.SetField(balancer, "set_more_tries", L.NewFunction(func(L *glua.LState) int {
count := L.CheckInt(1)
bctx.Retries = count
L.Push(glua.LBool(true))
return 1
}))
// get_last_failure()
L.SetField(balancer, "get_last_failure", L.NewFunction(func(L *glua.LState) int {
if bctx.LastError == nil {
L.Push(glua.LNil)
return 1
}
// 返回失败类型: "failed", "timeout", "next"
failType := classifyError(bctx.LastError)
L.Push(glua.LString(failType))
return 1
}))
// get_targets() - 返回所有可用目标
L.SetField(balancer, "get_targets", L.NewFunction(func(L *glua.LState) int {
targetsTable := L.NewTable()
for i, t := range bctx.Targets {
targetTable := L.NewTable()
L.SetField(targetTable, "url", glua.LString(t.URL))
L.SetField(targetTable, "weight", glua.LNumber(t.Weight))
L.SetField(targetTable, "healthy", glua.LBool(t.Healthy.Load()))
targetsTable.RawSetInt(i+1, targetTable)
}
L.Push(targetsTable)
return 1
}))
// get_client_ip()
L.SetField(balancer, "get_client_ip", L.NewFunction(func(L *glua.LState) int {
L.Push(glua.LString(bctx.ClientIP))
return 1
}))
L.SetField(ngx, "balancer", balancer)
}
// IsSelected 检查是否调用了 set_current_peer
func (bctx *BalancerContext) IsSelected() bool {
return bctx.selected
}
// classifyError 分类错误类型
func classifyError(err error) string {
if err == nil {
return ""
}
// 根据错误类型返回字符串
errStr := err.Error()
if strings.Contains(errStr, "timeout") {
return "timeout"
}
if strings.Contains(errStr, "connection") {
return "failed"
}
return "failed"
}

View File

@ -33,6 +33,7 @@
package proxy
import (
"context"
"errors"
"fmt"
"hash/fnv"
@ -43,10 +44,12 @@ import (
"time"
"github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua"
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/resolver"
"rua.plus/lolly/internal/utils"
@ -57,6 +60,13 @@ const (
// upstreamCache 上游缓存标识
// 用于标记请求可直接使用缓存响应,无需转发到上游
upstreamCache = "CACHE"
// 负载均衡算法名称
lbRoundRobin = "round_robin"
lbWeightedRoundRobin = "weighted_round_robin"
lbLeastConn = "least_conn"
lbIPHash = "ip_hash"
lbConsistentHash = "consistent_hash"
)
// Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。
@ -67,16 +77,18 @@ const (
// - 所有公开方法均为并发安全
// - 使用前需确保 targets 中至少有一个健康目标
type Proxy struct {
balancer loadbalance.Balancer
resolver resolver.Resolver
clients map[string]*fasthttp.HostClient
config *config.ProxyConfig
cache *cache.ProxyCache
healthChecker *HealthChecker
stopCh chan struct{}
targets []*loadbalance.Target
mu sync.RWMutex
started atomic.Bool
balancer loadbalance.Balancer
fallbackBalancer loadbalance.Balancer // Lua 失败时的备用均衡器
resolver resolver.Resolver
clients map[string]*fasthttp.HostClient
config *config.ProxyConfig
cache *cache.ProxyCache
healthChecker *HealthChecker
luaEngine *lua.LuaEngine // Lua 引擎引用
stopCh chan struct{}
targets []*loadbalance.Target
mu sync.RWMutex
started atomic.Bool
}
// NewProxy 使用给定的配置和后台目标创建一个新的反向代理实例。
@ -86,11 +98,12 @@ type Proxy struct {
// - cfg: 代理配置,包括超时时间、请求头和负载均衡策略
// - targets: 要代理请求的后端目标列表
// - transportCfg: 可选的 Transport 连接池配置nil 时使用默认值
// - luaEngine: 可选的 Lua 引擎,用于 balancer_by_lua 功能
//
// 返回值:
// - *Proxy: 配置完成并可处理请求的代理实例
// - error: 初始化失败时非空(无效配置、没有健康目标等)
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig) (*Proxy, error) {
func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportCfg *config.TransportConfig, luaEngine *lua.LuaEngine) (*Proxy, error) {
if cfg == nil {
return nil, errors.New("proxy config is nil")
}
@ -105,12 +118,24 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
return nil, err
}
// 创建 fallback 负载均衡器
fallbackAlgo := cfg.BalancerByLua.Fallback
if fallbackAlgo == "" {
fallbackAlgo = lbRoundRobin
}
fallbackBalancer, err := createBalancerByName(fallbackAlgo, cfg)
if err != nil {
return nil, fmt.Errorf("create fallback balancer: %w", err)
}
p := &Proxy{
targets: targets,
clients: make(map[string]*fasthttp.HostClient),
balancer: balancer,
config: cfg,
stopCh: make(chan struct{}),
targets: targets,
clients: make(map[string]*fasthttp.HostClient),
balancer: balancer,
fallbackBalancer: fallbackBalancer,
config: cfg,
luaEngine: luaEngine,
stopCh: make(chan struct{}),
}
// 为每个后端目标初始化 HostClient
@ -138,6 +163,28 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
return p, nil
}
// createBalancerByName 根据算法名称创建负载均衡器
func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
switch name {
case lbRoundRobin, "":
return loadbalance.NewRoundRobin(), nil
case lbWeightedRoundRobin:
return loadbalance.NewWeightedRoundRobin(), nil
case lbLeastConn:
return loadbalance.NewLeastConnections(), nil
case lbIPHash:
return loadbalance.NewIPHash(), nil
case lbConsistentHash:
virtualNodes := cfg.VirtualNodes
if virtualNodes <= 0 {
virtualNodes = 150
}
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
default:
return nil, errors.New("unsupported load balance algorithm: " + name)
}
}
// SetHealthChecker 设置健康检查器用于被动健康检查。
// 当代理请求失败时,将调用健康检查器的 MarkUnhealthy 方法。
func (p *Proxy) SetHealthChecker(hc *HealthChecker) {
@ -147,15 +194,15 @@ func (p *Proxy) SetHealthChecker(hc *HealthChecker) {
// createBalancer 根据配置的算法创建负载均衡器。
func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
switch cfg.LoadBalance {
case "round_robin", "":
case lbRoundRobin, "":
return loadbalance.NewRoundRobin(), nil
case "weighted_round_robin":
case lbWeightedRoundRobin:
return loadbalance.NewWeightedRoundRobin(), nil
case "least_conn":
case lbLeastConn:
return loadbalance.NewLeastConnections(), nil
case "ip_hash":
case lbIPHash:
return loadbalance.NewIPHash(), nil
case "consistent_hash":
case lbConsistentHash:
virtualNodes := cfg.VirtualNodes
if virtualNodes <= 0 {
virtualNodes = 150
@ -529,13 +576,10 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
}
}
// selectTarget 使用配置的负载均衡器选择后端目标。
// 对于 IP 哈希负载均衡,从请求中提取客户端 IP。
// 对于一致性哈希,根据配置的 hash_key 选择目标。
// 如果没有可用的健康目标则返回 nil。
// selectTarget 使用配置的负载均衡器选择后端目标
// 如果启用 Lua balancer先尝试 Lua 脚本选择
func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
p.mu.RLock()
balancer := p.balancer
targets := p.targets
p.mu.RUnlock()
@ -543,6 +587,96 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
return nil
}
// 检查是否启用 Lua balancer
if p.config.BalancerByLua.Enabled && p.config.BalancerByLua.Script != "" && p.luaEngine != nil {
target, err := p.selectByLua(ctx, targets)
if err != nil {
logging.Warn().Err(err).Msg("lua balancer failed, using fallback")
// Lua 失败,使用 fallback 算法
return p.selectByFallback(ctx, targets)
}
if target != nil {
return target
}
// Lua 未调用 set_current_peer使用 fallback
logging.Debug().Msg("lua balancer did not select target, using fallback")
return p.selectByFallback(ctx, targets)
}
// 使用传统负载均衡算法
return p.selectByBalancer(ctx, targets)
}
// selectByLua 使用 Lua 脚本选择目标
func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) {
clientIP := netutil.ExtractClientIP(ctx)
bctx := &lua.BalancerContext{
Targets: targets,
ClientIP: clientIP,
Retries: p.config.NextUpstream.Tries,
}
// 创建 Lua 协程
coro, err := p.luaEngine.NewCoroutine(ctx)
if err != nil {
return nil, fmt.Errorf("create lua coroutine: %w", err)
}
defer coro.Close()
// 注册 balancer API
L := coro.Co
ngx, ok := L.GetGlobal("ngx").(*glua.LTable)
if !ok {
return nil, fmt.Errorf("global 'ngx' is not an LTable")
}
lua.RegisterBalancerAPI(L, bctx, ngx)
// 设置超时
timeout := p.config.BalancerByLua.Timeout
if timeout <= 0 {
timeout = 100 * time.Millisecond
}
// 执行脚本(带超时)
execCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
coro.ExecutionContext = execCtx
err = coro.ExecuteFile(p.config.BalancerByLua.Script)
if err != nil {
return nil, fmt.Errorf("execute lua script: %w", err)
}
// 检查是否调用了 set_current_peer
if !bctx.IsSelected() {
return nil, nil // 未选择,返回 nil 表示需使用 fallback
}
return bctx.Selected, nil
}
// selectByFallback 使用 fallback 算法选择目标
func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.fallbackBalancer
p.mu.RUnlock()
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectByIP(targets, clientIP)
}
return balancer.Select(targets)
}
// selectByBalancer 使用主负载均衡器选择目标
func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.balancer
p.mu.RUnlock()
// 对于 IPHash 负载均衡器,提取客户端 IP
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)

View File

@ -25,7 +25,6 @@ import (
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/fasthttputil"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/netutil"
@ -140,7 +139,7 @@ func TestNewProxy(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := NewProxy(tt.cfg, tt.targets, nil)
p, err := NewProxy(tt.cfg, tt.targets, nil, nil)
if tt.wantErr {
if err == nil {
t.Errorf("NewProxy() expected error containing %q, got nil", tt.errContains)
@ -185,7 +184,7 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) {
targets[0].Healthy.Store(false)
targets[1].Healthy.Store(false)
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -233,7 +232,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) {
{URL: "http://localhost:8080"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -323,7 +322,7 @@ func TestSelectTarget(t *testing.T) {
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
p, err := NewProxy(cfg, tt.targets, nil)
p, err := NewProxy(cfg, tt.targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -437,7 +436,7 @@ func TestModifyRequestHeaders(t *testing.T) {
{URL: "http://localhost:8080"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -528,7 +527,7 @@ func TestModifyResponseHeaders(t *testing.T) {
{URL: "http://localhost:8080"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -612,7 +611,7 @@ func TestUpdateTargets(t *testing.T) {
{URL: "http://old2:8080"},
}
p, err := NewProxy(cfg, initialTargets, nil)
p, err := NewProxy(cfg, initialTargets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -661,7 +660,7 @@ func TestGetTargets(t *testing.T) {
{URL: "http://backend2:8080"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -690,7 +689,7 @@ func TestGetConfig(t *testing.T) {
{URL: "http://localhost:8080"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -882,7 +881,7 @@ func TestHandleWebSocket(t *testing.T) {
{URL: "http://localhost:8080"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -916,7 +915,7 @@ func TestSetHealthChecker(t *testing.T) {
{URL: "http://localhost:8081"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -954,7 +953,7 @@ func TestGetClient(t *testing.T) {
{URL: "http://localhost:8082"},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -1018,7 +1017,7 @@ func TestProxyCache(t *testing.T) {
}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -1071,7 +1070,7 @@ func TestServeHTTP_WithPassiveHealthCheck(t *testing.T) {
}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
@ -1156,7 +1155,7 @@ func TestUpstreamVariablesCapture(t *testing.T) {
},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("failed to create proxy: %v", err)
}
@ -1239,7 +1238,7 @@ func TestUpstreamVariablesErrorPaths(t *testing.T) {
},
}
p, err := NewProxy(cfg, targets, nil)
p, err := NewProxy(cfg, targets, nil, nil)
if err != nil {
t.Fatalf("failed to create proxy: %v", err)
}