refactor(proxy): 提取命名常量并适配变量系统重命名

提取硬编码字符串为命名常量:
- upstreamCache = "CACHE"
- protoHTTPS = "https"
ProxyWebSocket → WebSocket
适配 variable.Context 重命名

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-10 09:40:37 +08:00
parent b28ad378fa
commit 1ce84ce9c5
7 changed files with 53 additions and 45 deletions

View File

@ -25,6 +25,8 @@ import (
"rua.plus/lolly/internal/loadbalance"
)
const healthPath = "/health"
// HealthChecker 对后端目标执行健康检查。
// 它支持主动(定期 HTTP 探测)和被动(基于失败的)
// 两种健康检查模式。
@ -97,7 +99,7 @@ func NewHealthChecker(targets []*loadbalance.Target, cfg *config.HealthCheckConf
path := cfg.Path
if path == "" {
path = "/health"
path = healthPath
}
return &HealthChecker{

View File

@ -42,8 +42,8 @@ func TestNewHealthChecker(t *testing.T) {
if checker.GetTimeout() != 5*time.Second {
t.Errorf("Timeout = %v, want %v", checker.GetTimeout(), 5*time.Second)
}
if checker.GetPath() != "/health" {
t.Errorf("Path = %q, want %q", checker.GetPath(), "/health")
if checker.GetPath() != healthPath {
t.Errorf("Path = %q, want %q", checker.GetPath(), healthPath)
}
if checker.IsRunning() {
t.Error("新建的 checker 应未启动")
@ -115,7 +115,7 @@ func TestNewHealthChecker(t *testing.T) {
if checker.GetTimeout() != 5*time.Second {
t.Errorf("零值 Timeout 应使用默认值got %v", checker.GetTimeout())
}
if checker.GetPath() != "/health" {
if checker.GetPath() != healthPath {
t.Errorf("空 Path 应使用默认值got %q", checker.GetPath())
}
})
@ -131,7 +131,7 @@ func TestHealthCheckerStartStop(t *testing.T) {
cfg := &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
}
checker := NewHealthChecker(targets, cfg)
@ -200,8 +200,8 @@ func TestHealthCheckerStartStop(t *testing.T) {
func TestCheckTarget(t *testing.T) {
t.Run("健康响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/health" {
t.Errorf("请求路径 = %q, want %q", r.URL.Path, "/health")
if r.URL.Path != healthPath {
t.Errorf("请求路径 = %q, want %q", r.URL.Path, healthPath)
}
w.WriteHeader(http.StatusOK)
}))
@ -215,7 +215,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.checkTarget(target)
@ -226,7 +226,7 @@ func TestCheckTarget(t *testing.T) {
})
t.Run("不健康响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
@ -239,7 +239,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.checkTarget(target)
@ -250,7 +250,7 @@ func TestCheckTarget(t *testing.T) {
})
t.Run("超时", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
time.Sleep(100 * time.Millisecond)
}))
defer server.Close()
@ -263,7 +263,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 10 * time.Millisecond,
Path: "/health",
Path: healthPath,
})
checker.checkTarget(target)
@ -282,7 +282,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 100 * time.Millisecond,
Path: "/health",
Path: healthPath,
})
checker.checkTarget(target)
@ -293,7 +293,7 @@ func TestCheckTarget(t *testing.T) {
})
t.Run("3xx 重定向响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusMovedPermanently)
}))
defer server.Close()
@ -306,7 +306,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.checkTarget(target)
@ -317,7 +317,7 @@ func TestCheckTarget(t *testing.T) {
})
t.Run("4xx 客户端错误响应", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer server.Close()
@ -330,7 +330,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.checkTarget(target)
@ -352,7 +352,7 @@ func TestCheckTarget(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tt.statusCode)
}))
defer server.Close()
@ -365,7 +365,7 @@ func TestCheckTarget(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.checkTarget(target)
@ -389,7 +389,7 @@ func TestMarkUnhealthy(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.MarkUnhealthy(target)
@ -408,7 +408,7 @@ func TestMarkUnhealthy(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.MarkUnhealthy(target)
@ -431,7 +431,7 @@ func TestMarkUnhealthy(t *testing.T) {
checker := NewHealthChecker([]*loadbalance.Target{target1, target2}, &config.HealthCheckConfig{
Interval: 1 * time.Hour,
Timeout: 5 * time.Second,
Path: "/health",
Path: healthPath,
})
checker.MarkUnhealthy(target1)

View File

@ -51,6 +51,11 @@ import (
"rua.plus/lolly/internal/variable"
)
const (
upstreamCache = "CACHE"
protoHTTPS = "https"
)
// Proxy 表示反向代理实例,负责将 HTTP 请求转发到后端目标。
//
// 它为每个后端目标管理连接池,并提供负载均衡功能。
@ -272,9 +277,9 @@ func (t *UpstreamTiming) GetResponseTime() float64 {
return t.responseEnd.Sub(t.connectEnd).Seconds()
}
// FinalizeUpstreamVars 在请求处理结束时设置上游变量到 VariableContext
// FinalizeUpstreamVars 在请求处理结束时设置上游变量到 Context
// 这个函数应该在 ServeHTTP 的 defer 中调用
func FinalizeUpstreamVars(vc *variable.VariableContext, upstreamAddr string, upstreamStatus int, timing *UpstreamTiming) {
func FinalizeUpstreamVars(vc *variable.Context, upstreamAddr string, upstreamStatus int, timing *UpstreamTiming) {
if vc == nil {
return
}
@ -304,7 +309,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
timing := NewUpstreamTiming()
// 创建变量上下文用于设置上游变量
vc := variable.NewVariableContext(ctx)
vc := variable.NewContext(ctx)
defer func() {
// 确保记录了响应结束时间
if timing.responseEnd.IsZero() {
@ -313,7 +318,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 设置上游变量
FinalizeUpstreamVars(vc, upstreamAddr, upstreamStatus, timing)
// 释放变量上下文
variable.ReleaseVariableContext(vc)
variable.ReleaseContext(vc)
}()
// 故障转移配置
@ -376,7 +381,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// WebSocket 使用 defer 确保连接计数释放
defer loadbalance.DecrementConnections(target)
timing.MarkConnectStart()
err := ProxyWebSocket(ctx, target, p.config.Timeout.Connect)
err := WebSocket(ctx, target, p.config.Timeout.Connect)
timing.MarkConnectEnd()
if err != nil {
upstreamStatus = 502
@ -402,7 +407,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
loadbalance.DecrementConnections(target)
if !stale {
// 新鲜缓存,直接返回
upstreamAddr = "CACHE"
upstreamAddr = upstreamCache
upstreamStatus = entry.Status
p.writeCachedResponse(ctx, entry)
return
@ -425,7 +430,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 重新尝试获取缓存
if entry, ok, _ := p.cache.Get(hashKey, origKey); ok {
upstreamAddr = "CACHE"
upstreamAddr = upstreamCache
upstreamStatus = entry.Status
p.writeCachedResponse(ctx, entry)
@ -634,7 +639,7 @@ func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
// modifyRequestHeaders 在转发到后端之前修改请求头。
// 添加标准代理请求头并应用自定义请求头配置。
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) {
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, _ *loadbalance.Target) {
headers := &ctx.Request.Header
// 添加 X-Real-IP 请求头
@ -660,14 +665,14 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
// 添加 X-Forwarded-Proto 请求头
proto := "http"
if ctx.IsTLS() {
proto = "https"
proto = protoHTTPS
}
headers.Set("X-Forwarded-Proto", proto)
// 从配置设置自定义请求头(支持变量展开)
if p.config.Headers.SetRequest != nil {
vc := variable.NewVariableContext(ctx)
defer variable.ReleaseVariableContext(vc)
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetRequest {
expanded := vc.Expand(value)
headers.Set(key, expanded)
@ -686,8 +691,8 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
// 从配置设置自定义响应头(支持变量展开)
if p.config.Headers.SetResponse != nil {
vc := variable.NewVariableContext(ctx)
defer variable.ReleaseVariableContext(vc)
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetResponse {
expanded := vc.Expand(value)
ctx.Response.Header.Set(key, expanded)
@ -712,13 +717,14 @@ func isWebSocketRequest(ctx *fasthttp.RequestCtx) bool {
}
// handleWebSocket 处理 WebSocket 升级请求(保留用于兼容性,实际逻辑在 ServeHTTP 中)
// nolint:unused // 保留用于未来 WebSocket 功能扩展
//
//nolint:unused // 保留用于未来 WebSocket 功能扩展
func (p *Proxy) handleWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, _ *fasthttp.HostClient) {
timeout := p.config.Timeout.Connect
if timeout == 0 {
timeout = 30 * time.Second
}
if err := ProxyWebSocket(ctx, target, timeout); err != nil {
if err := WebSocket(ctx, target, timeout); err != nil {
logging.Error().Msgf("WebSocket proxy error: %v", err)
}
}

View File

@ -170,13 +170,13 @@ func (p *Proxy) getResolverTTL() time.Duration {
}
// GetResolverStats 返回 DNS 解析器的统计信息。
func (p *Proxy) GetResolverStats() resolver.ResolverStats {
func (p *Proxy) GetResolverStats() resolver.Stats {
p.mu.RLock()
r := p.resolver
p.mu.RUnlock()
if r == nil {
return resolver.ResolverStats{}
return resolver.Stats{}
}
return r.Stats()
}

View File

@ -1287,8 +1287,8 @@ func TestFinalizeUpstreamVars(t *testing.T) {
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetRequestURI("/test")
vc := variable.NewVariableContext(ctx)
defer variable.ReleaseVariableContext(vc)
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
timing := NewUpstreamTiming()
timing.MarkConnectStart()

View File

@ -11,7 +11,7 @@
//
// 使用示例:
//
// err := proxy.ProxyWebSocket(ctx, target, 30*time.Second)
// err := proxy.WebSocket(ctx, target, 30*time.Second)
// if err != nil {
// log.Printf("WebSocket proxy error: %v", err)
// }
@ -337,7 +337,7 @@ func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.R
return resp, nil
}
// ProxyWebSocket 处理 WebSocket 代理请求。
// WebSocket 处理 WebSocket 代理请求。
//
// 完整流程:
// 1. 劫持客户端连接
@ -353,7 +353,7 @@ func readWebSocketUpgradeResponse(conn net.Conn, timeout time.Duration) (*http.R
//
// 返回值:
// - error: 代理过程中的错误
func ProxyWebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration) error {
func WebSocket(ctx *fasthttp.RequestCtx, target *loadbalance.Target, timeout time.Duration) error {
// 使用 Hijack 获取客户端 TCP 连接
var clientConn net.Conn

View File

@ -122,7 +122,7 @@ func TestIsConnectionClosedError(t *testing.T) {
}
// TestExtractHost 测试从 URL 提取主机
func TestExtractHost(t *testing.T) {
func TestExtractHost(_ *testing.T) {
// extractHost 函数可能不存在,检查一下
// 如果存在则测试
}