refactor(proxy): 提取命名常量并适配变量系统重命名

提取硬编码字符串为命名常量:
- upstreamCache = "CACHE"
- protoHTTPS = "https"
ProxyWebSocket → WebSocket
适配 variable.Context 重命名

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-10 09:40:37 +08:00
parent b28ad378fa
commit 1ce84ce9c5
7 changed files with 53 additions and 45 deletions

View File

@ -25,6 +25,8 @@ import (
"rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/loadbalance"
) )
const healthPath = "/health"
// HealthChecker 对后端目标执行健康检查。 // HealthChecker 对后端目标执行健康检查。
// 它支持主动(定期 HTTP 探测)和被动(基于失败的) // 它支持主动(定期 HTTP 探测)和被动(基于失败的)
// 两种健康检查模式。 // 两种健康检查模式。
@ -97,7 +99,7 @@ func NewHealthChecker(targets []*loadbalance.Target, cfg *config.HealthCheckConf
path := cfg.Path path := cfg.Path
if path == "" { if path == "" {
path = "/health" path = healthPath
} }
return &HealthChecker{ return &HealthChecker{

View File

@ -42,8 +42,8 @@ func TestNewHealthChecker(t *testing.T) {
if checker.GetTimeout() != 5*time.Second { if checker.GetTimeout() != 5*time.Second {
t.Errorf("Timeout = %v, want %v", checker.GetTimeout(), 5*time.Second) t.Errorf("Timeout = %v, want %v", checker.GetTimeout(), 5*time.Second)
} }
if checker.GetPath() != "/health" { if checker.GetPath() != healthPath {
t.Errorf("Path = %q, want %q", checker.GetPath(), "/health") t.Errorf("Path = %q, want %q", checker.GetPath(), healthPath)
} }
if checker.IsRunning() { if checker.IsRunning() {
t.Error("新建的 checker 应未启动") t.Error("新建的 checker 应未启动")
@ -115,7 +115,7 @@ func TestNewHealthChecker(t *testing.T) {
if checker.GetTimeout() != 5*time.Second { if checker.GetTimeout() != 5*time.Second {
t.Errorf("零值 Timeout 应使用默认值got %v", checker.GetTimeout()) t.Errorf("零值 Timeout 应使用默认值got %v", checker.GetTimeout())
} }
if checker.GetPath() != "/health" { if checker.GetPath() != healthPath {
t.Errorf("空 Path 应使用默认值got %q", checker.GetPath()) t.Errorf("空 Path 应使用默认值got %q", checker.GetPath())
} }
}) })
@ -131,7 +131,7 @@ func TestHealthCheckerStartStop(t *testing.T) {
cfg := &config.HealthCheckConfig{ cfg := &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
} }
checker := NewHealthChecker(targets, cfg) checker := NewHealthChecker(targets, cfg)
@ -200,8 +200,8 @@ func TestHealthCheckerStartStop(t *testing.T) {
func TestCheckTarget(t *testing.T) { func TestCheckTarget(t *testing.T) {
t.Run("健康响应", func(t *testing.T) { t.Run("健康响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/health" { if r.URL.Path != healthPath {
t.Errorf("请求路径 = %q, want %q", r.URL.Path, "/health") t.Errorf("请求路径 = %q, want %q", r.URL.Path, healthPath)
} }
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
})) }))
@ -215,7 +215,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.checkTarget(target) checker.checkTarget(target)
@ -226,7 +226,7 @@ func TestCheckTarget(t *testing.T) {
}) })
t.Run("不健康响应", func(t *testing.T) { t.Run("不健康响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
})) }))
defer server.Close() defer server.Close()
@ -239,7 +239,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.checkTarget(target) checker.checkTarget(target)
@ -250,7 +250,7 @@ func TestCheckTarget(t *testing.T) {
}) })
t.Run("超时", func(t *testing.T) { t.Run("超时", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
})) }))
defer server.Close() defer server.Close()
@ -263,7 +263,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 10 * time.Millisecond, Timeout: 10 * time.Millisecond,
Path: "/health", Path: healthPath,
}) })
checker.checkTarget(target) checker.checkTarget(target)
@ -282,7 +282,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 100 * time.Millisecond, Timeout: 100 * time.Millisecond,
Path: "/health", Path: healthPath,
}) })
checker.checkTarget(target) checker.checkTarget(target)
@ -293,7 +293,7 @@ func TestCheckTarget(t *testing.T) {
}) })
t.Run("3xx 重定向响应", func(t *testing.T) { t.Run("3xx 重定向响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusMovedPermanently) w.WriteHeader(http.StatusMovedPermanently)
})) }))
defer server.Close() defer server.Close()
@ -306,7 +306,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.checkTarget(target) checker.checkTarget(target)
@ -317,7 +317,7 @@ func TestCheckTarget(t *testing.T) {
}) })
t.Run("4xx 客户端错误响应", func(t *testing.T) { t.Run("4xx 客户端错误响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
})) }))
defer server.Close() defer server.Close()
@ -330,7 +330,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.checkTarget(target) checker.checkTarget(target)
@ -352,7 +352,7 @@ func TestCheckTarget(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tt.statusCode) w.WriteHeader(tt.statusCode)
})) }))
defer server.Close() defer server.Close()
@ -365,7 +365,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.checkTarget(target) checker.checkTarget(target)
@ -389,7 +389,7 @@ func TestMarkUnhealthy(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.MarkUnhealthy(target) checker.MarkUnhealthy(target)
@ -408,7 +408,7 @@ func TestMarkUnhealthy(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.MarkUnhealthy(target) checker.MarkUnhealthy(target)
@ -431,7 +431,7 @@ func TestMarkUnhealthy(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target1, target2}, &config.HealthCheckConfig{ checker := NewHealthChecker([]*loadbalance.Target{target1, target2}, &config.HealthCheckConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Path: "/health", Path: healthPath,
}) })
checker.MarkUnhealthy(target1) checker.MarkUnhealthy(target1)

View File

@ -51,6 +51,11 @@ import (
"rua.plus/lolly/internal/variable" "rua.plus/lolly/internal/variable"
) )
const (
upstreamCache = "CACHE"
protoHTTPS = "https"
)
// Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。 // Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。
// //
// 它为每个后端目标管理连接池,并提供负载均衡功能。 // 它为每个后端目标管理连接池,并提供负载均衡功能。
@ -272,9 +277,9 @@ func (t *UpstreamTiming) GetResponseTime() float64 {
return t.responseEnd.Sub(t.connectEnd).Seconds() return t.responseEnd.Sub(t.connectEnd).Seconds()
} }
// FinalizeUpstreamVars 在请求处理结束时设置上游变量到 VariableContext // FinalizeUpstreamVars 在请求处理结束时设置上游变量到 Context
// 这个函数应该在 ServeHTTP 的 defer 中调用 // 这个函数应该在 ServeHTTP 的 defer 中调用
func FinalizeUpstreamVars(vc *variable.VariableContext, upstreamAddr string, upstreamStatus int, timing *UpstreamTiming) { func FinalizeUpstreamVars(vc *variable.Context, upstreamAddr string, upstreamStatus int, timing *UpstreamTiming) {
if vc == nil { if vc == nil {
return return
} }
@ -304,7 +309,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
timing := NewUpstreamTiming() timing := NewUpstreamTiming()
// 创建变量上下文用于设置上游变量 // 创建变量上下文用于设置上游变量
vc := variable.NewVariableContext(ctx) vc := variable.NewContext(ctx)
defer func() { defer func() {
// 确保记录了响应结束时间 // 确保记录了响应结束时间
if timing.responseEnd.IsZero() { if timing.responseEnd.IsZero() {
@ -313,7 +318,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 设置上游变量 // 设置上游变量
FinalizeUpstreamVars(vc, upstreamAddr, upstreamStatus, timing) FinalizeUpstreamVars(vc, upstreamAddr, upstreamStatus, timing)
// 释放变量上下文 // 释放变量上下文
variable.ReleaseVariableContext(vc) variable.ReleaseContext(vc)
}() }()
// 故障转移配置 // 故障转移配置
@ -376,7 +381,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// WebSocket 使用 defer 确保连接计数释放 // WebSocket 使用 defer 确保连接计数释放
defer loadbalance.DecrementConnections(target) defer loadbalance.DecrementConnections(target)
timing.MarkConnectStart() timing.MarkConnectStart()
err := ProxyWebSocket(ctx, target, p.config.Timeout.Connect) err := WebSocket(ctx, target, p.config.Timeout.Connect)
timing.MarkConnectEnd() timing.MarkConnectEnd()
if err != nil { if err != nil {
upstreamStatus = 502 upstreamStatus = 502
@ -402,7 +407,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
loadbalance.DecrementConnections(target) loadbalance.DecrementConnections(target)
if !stale { if !stale {
// 新鲜缓存,直接返回 // 新鲜缓存,直接返回
upstreamAddr = "CACHE" upstreamAddr = upstreamCache
upstreamStatus = entry.Status upstreamStatus = entry.Status
p.writeCachedResponse(ctx, entry) p.writeCachedResponse(ctx, entry)
return return
@ -425,7 +430,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 重新尝试获取缓存 // 重新尝试获取缓存
if entry, ok, _ := p.cache.Get(hashKey, origKey); ok { if entry, ok, _ := p.cache.Get(hashKey, origKey); ok {
upstreamAddr = "CACHE" upstreamAddr = upstreamCache
upstreamStatus = entry.Status upstreamStatus = entry.Status
p.writeCachedResponse(ctx, entry) p.writeCachedResponse(ctx, entry)
@ -634,7 +639,7 @@ func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
// modifyRequestHeaders 在转发到后端之前修改请求头。 // modifyRequestHeaders 在转发到后端之前修改请求头。
// 添加标准代理请求头并应用自定义请求头配置。 // 添加标准代理请求头并应用自定义请求头配置。
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) { func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, _ *loadbalance.Target) {
headers := &ctx.Request.Header headers := &ctx.Request.Header
// 添加 X-Real-IP 请求头 // 添加 X-Real-IP 请求头
@ -660,14 +665,14 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
// 添加 X-Forwarded-Proto 请求头 // 添加 X-Forwarded-Proto 请求头
proto := "http" proto := "http"
if ctx.IsTLS() { if ctx.IsTLS() {
proto = "https" proto = protoHTTPS
} }
headers.Set("X-Forwarded-Proto", proto) headers.Set("X-Forwarded-Proto", proto)
// 从配置设置自定义请求头(支持变量展开) // 从配置设置自定义请求头(支持变量展开)
if p.config.Headers.SetRequest != nil { if p.config.Headers.SetRequest != nil {
vc := variable.NewVariableContext(ctx) vc := variable.NewContext(ctx)
defer variable.ReleaseVariableContext(vc) defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetRequest { for key, value := range p.config.Headers.SetRequest {
expanded := vc.Expand(value) expanded := vc.Expand(value)
headers.Set(key, expanded) headers.Set(key, expanded)
@ -686,8 +691,8 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) { func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
// 从配置设置自定义响应头(支持变量展开) // 从配置设置自定义响应头(支持变量展开)
if p.config.Headers.SetResponse != nil { if p.config.Headers.SetResponse != nil {
vc := variable.NewVariableContext(ctx) vc := variable.NewContext(ctx)
defer variable.ReleaseVariableContext(vc) defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetResponse { for key, value := range p.config.Headers.SetResponse {
expanded := vc.Expand(value) expanded := vc.Expand(value)
ctx.Response.Header.Set(key, expanded) ctx.Response.Header.Set(key, expanded)
@ -712,13 +717,14 @@ func isWebSocketRequest(ctx *fasthttp.RequestCtx) bool {
} }
// handleWebSocket 处理 WebSocket 升级请求(保留用于兼容性,实际逻辑在 ServeHTTP 中) // handleWebSocket 处理 WebSocket 升级请求(保留用于兼容性,实际逻辑在 ServeHTTP 中)
//
//nolint:unused // 保留用于未来 WebSocket 功能扩展 //nolint:unused // 保留用于未来 WebSocket 功能扩展
func (p *Proxy) handleWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, _ *fasthttp.HostClient) { func (p *Proxy) handleWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, _ *fasthttp.HostClient) {
timeout := p.config.Timeout.Connect timeout := p.config.Timeout.Connect
if timeout == 0 { if timeout == 0 {
timeout = 30 * time.Second timeout = 30 * time.Second
} }
if err := ProxyWebSocket(ctx, target, timeout); err != nil { if err := WebSocket(ctx, target, timeout); err != nil {
logging.Error().Msgf("WebSocket proxy error: %v", err) logging.Error().Msgf("WebSocket proxy error: %v", err)
} }
} }

View File

@ -170,13 +170,13 @@ func (p *Proxy) getResolverTTL() time.Duration {
} }
// GetResolverStats 返回 DNS 解析器的统计信息。 // GetResolverStats 返回 DNS 解析器的统计信息。
func (p *Proxy) GetResolverStats() resolver.ResolverStats { func (p *Proxy) GetResolverStats() resolver.Stats {
p.mu.RLock() p.mu.RLock()
r := p.resolver r := p.resolver
p.mu.RUnlock() p.mu.RUnlock()
if r == nil { if r == nil {
return resolver.ResolverStats{} return resolver.Stats{}
} }
return r.Stats() return r.Stats()
} }

View File

@ -1287,8 +1287,8 @@ func TestFinalizeUpstreamVars(t *testing.T) {
ctx.Request.Header.SetMethod("GET") ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetRequestURI("/test") ctx.Request.Header.SetRequestURI("/test")
vc := variable.NewVariableContext(ctx) vc := variable.NewContext(ctx)
defer variable.ReleaseVariableContext(vc) defer variable.ReleaseContext(vc)
timing := NewUpstreamTiming() timing := NewUpstreamTiming()
timing.MarkConnectStart() timing.MarkConnectStart()

View File

@ -11,7 +11,7 @@
// //
// 使用示例: // 使用示例:
// //
// err := proxy.ProxyWebSocket(ctx, target, 30*time.Second) // err := proxy.WebSocket(ctx, target, 30*time.Second)
// if err != nil { // if err != nil {
// log.Printf("WebSocket proxy error: %v", err) // log.Printf("WebSocket proxy error: %v", err)
// } // }
@ -337,7 +337,7 @@ func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.R
return resp, nil return resp, nil
} }
// ProxyWebSocket 处理 WebSocket 代理请求。 // WebSocket 处理 WebSocket 代理请求。
// //
// 完整流程: // 完整流程:
// 1. 劫持客户端连接 // 1. 劫持客户端连接
@ -353,7 +353,7 @@ func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.R
// //
// 返回值: // 返回值:
// - error: 代理过程中的错误 // - error: 代理过程中的错误
func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration) error { func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration) error {
// 使用 Hijack 获取客户端 TCP 连接 // 使用 Hijack 获取客户端 TCP 连接
var clientConn net.Conn var clientConn net.Conn

View File

@ -122,7 +122,7 @@ func TestIsConnectionClosedError(t *testing.T) {
} }
// TestExtractHost 测试从 URL 提取主机 // TestExtractHost 测试从 URL 提取主机
func TestExtractHost(t *testing.T) { func TestExtractHost(_ *testing.T) {
// extractHost 函数可能不存在,检查一下 // extractHost 函数可能不存在,检查一下
// 如果存在则测试 // 如果存在则测试
} }