From 470c82d940827e862c850bb31583cd230cbb4684 Mon Sep 17 00:00:00 2001 From: xfy Date: Thu, 16 Apr 2026 09:54:09 +0800 Subject: [PATCH] =?UTF-8?q?style(proxy,server):=20=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E9=A3=8E=E6=A0=BC=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - headers.go: 添加协议常量 protoHTTP/protoHTTPS - redirect_rewrite.go: 添加模式常量,修正缩进 - proxy_ssl_test.go: 表格测试字段对齐 - server.go: 添加 ServerModeAuto 分支防御性处理 Co-Authored-By: Claude Opus 4.6 --- internal/proxy/headers.go | 10 +++- internal/proxy/proxy.go | 4 +- internal/proxy/proxy_ssl.go | 2 +- internal/proxy/proxy_ssl_test.go | 72 ++++++++++++------------- internal/proxy/redirect_rewrite.go | 40 ++++++++------ internal/proxy/redirect_rewrite_test.go | 5 +- internal/server/server.go | 21 ++++---- 7 files changed, 85 insertions(+), 69 deletions(-) diff --git a/internal/proxy/headers.go b/internal/proxy/headers.go index 4aff16f..ffd1357 100644 --- a/internal/proxy/headers.go +++ b/internal/proxy/headers.go @@ -10,6 +10,12 @@ import ( "rua.plus/lolly/internal/netutil" ) +// 协议常量 +const ( + protoHTTP = "http" + protoHTTPS = "https" +) + // ForwardedHeaders 包含 X-Forwarded 系列头信息。 type ForwardedHeaders struct { ClientIP string // 客户端 IP @@ -28,9 +34,9 @@ func ExtractForwardedHeaders(ctx *fasthttp.RequestCtx) ForwardedHeaders { clientIP := netutil.ExtractClientIP(ctx) host := string(ctx.Host()) - proto := "http" + proto := protoHTTP if ctx.IsTLS() { - proto = "https" + proto = protoHTTPS } return ForwardedHeaders{ diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index c1bad27..f5f5661 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -84,8 +84,8 @@ type Proxy struct { config *config.ProxyConfig cache *cache.ProxyCache healthChecker *HealthChecker - luaEngine *lua.LuaEngine // Lua 引擎引用 - redirectRewriter *RedirectRewriter // 重定向改写器 + luaEngine *lua.LuaEngine // Lua 引擎引用 + redirectRewriter *RedirectRewriter // 重定向改写器 stopCh chan struct{} targets []*loadbalance.Target mu sync.RWMutex diff --git a/internal/proxy/proxy_ssl.go b/internal/proxy/proxy_ssl.go index 7f23d00..19f0143 100644 --- a/internal/proxy/proxy_ssl.go +++ b/internal/proxy/proxy_ssl.go @@ -98,4 +98,4 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls } return tlsCfg, nil -} \ No newline at end of file +} diff --git a/internal/proxy/proxy_ssl_test.go b/internal/proxy/proxy_ssl_test.go index 6610f2f..9a7eef4 100644 --- a/internal/proxy/proxy_ssl_test.go +++ b/internal/proxy/proxy_ssl_test.go @@ -33,28 +33,28 @@ func TestCreateTLSConfig_Disabled(t *testing.T) { func TestCreateTLSConfig_ServerName(t *testing.T) { tests := []struct { - name string - cfg *config.ProxySSLConfig + name string + cfg *config.ProxySSLConfig defaultServerName string - wantServerName string + wantServerName string }{ { - name: "custom server name", - cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"}, + name: "custom server name", + cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"}, defaultServerName: "default.example.com", - wantServerName: "custom.example.com", + wantServerName: "custom.example.com", }, { - name: "default server name", - cfg: &config.ProxySSLConfig{Enabled: true}, + name: "default server name", + cfg: &config.ProxySSLConfig{Enabled: true}, defaultServerName: "default.example.com", - wantServerName: "default.example.com", + wantServerName: "default.example.com", }, { - name: "empty default", - cfg: &config.ProxySSLConfig{Enabled: true}, + name: "empty default", + cfg: &config.ProxySSLConfig{Enabled: true}, defaultServerName: "", - wantServerName: "", + wantServerName: "", }, } @@ -97,11 +97,11 @@ func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) { func TestCreateTLSConfig_TLSVersions(t *testing.T) { tests := []struct { - name string - minVersion string - maxVersion string - wantMin uint16 - wantMax uint16 + name string + minVersion string + maxVersion string + wantMin uint16 + wantMax uint16 }{ { name: "TLSV1.2 min", @@ -119,11 +119,11 @@ func TestCreateTLSConfig_TLSVersions(t *testing.T) { wantMax: tls.VersionTLS12, }, { - name: "both versions", - minVersion: "TLSV1.2", - maxVersion: "TLSV1.3", - wantMin: tls.VersionTLS12, - wantMax: tls.VersionTLS13, + name: "both versions", + minVersion: "TLSV1.2", + maxVersion: "TLSV1.3", + wantMin: tls.VersionTLS12, + wantMax: tls.VersionTLS13, }, { name: "mixed case TLSv1.2", @@ -135,9 +135,9 @@ func TestCreateTLSConfig_TLSVersions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &config.ProxySSLConfig{ - Enabled: true, - MinVersion: tt.minVersion, - MaxVersion: tt.maxVersion, + Enabled: true, + MinVersion: tt.minVersion, + MaxVersion: tt.maxVersion, } tlsCfg, err := CreateTLSConfig(cfg, "example.com") if err != nil { @@ -263,11 +263,11 @@ func TestGetCacheDuration_StatusCodeMapping(t *testing.T) { MaxAge: 1 * time.Minute, }, CacheValid: &config.ProxyCacheValidConfig{ - OK: 10 * time.Minute, - Redirect: 1 * time.Hour, - NotFound: 1 * time.Minute, - ClientError: 30 * time.Second, - ServerError: 0, // 不缓存 + OK: 10 * time.Minute, + Redirect: 1 * time.Hour, + NotFound: 1 * time.Minute, + ClientError: 30 * time.Second, + ServerError: 0, // 不缓存 }, } targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} @@ -310,11 +310,11 @@ func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) { MaxAge: 5 * time.Minute, }, CacheValid: &config.ProxyCacheValidConfig{ - OK: 10 * time.Minute, // OK 有值 - Redirect: 0, // 不缓存 - NotFound: 0, // 不缓存 - ClientError: 0, // 不缓存 - ServerError: 0, // 不缓存 + OK: 10 * time.Minute, // OK 有值 + Redirect: 0, // 不缓存 + NotFound: 0, // 不缓存 + ClientError: 0, // 不缓存 + ServerError: 0, // 不缓存 }, } targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} @@ -339,4 +339,4 @@ func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) { t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want) } } -} \ No newline at end of file +} diff --git a/internal/proxy/redirect_rewrite.go b/internal/proxy/redirect_rewrite.go index 764af71..09ec25b 100644 --- a/internal/proxy/redirect_rewrite.go +++ b/internal/proxy/redirect_rewrite.go @@ -10,19 +10,26 @@ import ( "rua.plus/lolly/internal/variable" ) +// RedirectRewrite 模式常量 +const ( + redirectModeDefault = "default" + redirectModeOff = "off" + redirectModeCustom = "custom" +) + // compiledRule 预编译的改写规则 type compiledRule struct { - pattern *regexp.Regexp // 正则模式,nil 表示非正则匹配 - exactMatch string // 精确匹配前缀(用于 prefix 匹配) - replacement string // 替换模板(含变量) - caseInsensitive bool // 正则大小写不敏感(~* 前缀) + pattern *regexp.Regexp // 正则模式,nil 表示非正则匹配 + replacement string // 替换模板(含变量) + exactMatch string // 精确匹配前缀(用于 prefix 匹配) + caseInsensitive bool // 正则大小写不敏感(~* 前缀) } // RedirectRewriter Location/Refresh 头改写器 type RedirectRewriter struct { + proxyPath string // 用于 default 模式(当前代理路径) mode string // "default" | "off" | "custom"(空字符串视为 default) rules []compiledRule // 仅 custom 模式预编译 - proxyPath string // 用于 default 模式(当前代理路径) } // NewRedirectRewriter 创建改写器 @@ -32,7 +39,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (* if cfg == nil { // 未配置时默认启用 default 模式 return &RedirectRewriter{ - mode: "default", + mode: redirectModeDefault, proxyPath: proxyPath, }, nil } @@ -43,7 +50,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (* } // custom 模式:预编译规则 - if cfg.Mode == "custom" { + if cfg.Mode == redirectModeCustom { rules := make([]compiledRule, 0, len(cfg.Rules)) for _, rule := range cfg.Rules { cr := compiledRule{ @@ -52,7 +59,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (* if strings.HasPrefix(rule.Pattern, "~") { // 正则模式 - patternStr := rule.Pattern + var patternStr string if strings.HasPrefix(rule.Pattern, "~*") { cr.caseInsensitive = true patternStr = rule.Pattern[2:] @@ -80,7 +87,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (* // Mode 返回当前模式(处理空字符串默认值) func (r *RedirectRewriter) Mode() string { if r.mode == "" { - return "default" + return redirectModeDefault } return r.mode } @@ -136,13 +143,13 @@ func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestC } switch r.Mode() { - case "off": + case redirectModeOff: return headerValue - case "custom": + case redirectModeCustom: return r.rewriteCustom(headerValue, ctx) - case "default", "": + case redirectModeDefault, "": return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost) default: @@ -154,7 +161,8 @@ func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestC // 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀 // replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/" // 例如:targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users" -// → 替换为 "$scheme://originalClientHost/api/v2/users" +// +// → 替换为 "$scheme://originalClientHost/api/v2/users" func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string { if targetURL == "" { return headerValue @@ -167,9 +175,9 @@ func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.Requ // 检查剩余部分是否以合法分隔符开头 if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' { // 使用客户端原始 host 构建 replacement - scheme := "http" + scheme := protoHTTP if ctx.IsTLS() { - scheme = "https" + scheme = protoHTTPS } replacement := scheme + "://" + originalClientHost return replacement + remaining @@ -198,7 +206,7 @@ func (r *RedirectRewriter) rewriteCustom(headerValue string, ctx *fasthttp.Reque return result } } else { - loc := rule.pattern.FindStringIndex(headerValue) + loc := rule.pattern.FindStringIndex(headerValue) if loc != nil { expanded := vc.Expand(rule.replacement) result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:] diff --git a/internal/proxy/redirect_rewrite_test.go b/internal/proxy/redirect_rewrite_test.go index 8fa451f..e8ac78f 100644 --- a/internal/proxy/redirect_rewrite_test.go +++ b/internal/proxy/redirect_rewrite_test.go @@ -3,10 +3,9 @@ package proxy import ( "testing" + "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/testutil" - - "github.com/valyala/fasthttp" ) // TestRedirectRewrite_ExactMatch 测试精确匹配改写 @@ -273,4 +272,4 @@ func TestRedirectRewrite_EmptyMode(t *testing.T) { if rw.Mode() != "default" { t.Errorf("Mode() = %q, want %q", rw.Mode(), "default") } -} \ No newline at end of file +} diff --git a/internal/server/server.go b/internal/server/server.go index ed1b2b1..e3f6f47 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -438,6 +438,9 @@ func (s *Server) Start() error { return s.startVHostMode() case config.ServerModeMultiServer: return s.startMultiServerMode() + case config.ServerModeAuto: + // auto 模式下 GetMode() 会自动推断,此处为防御性处理 + return s.startSingleMode() default: // 默认使用单服务器模式 return s.startSingleMode() @@ -513,10 +516,10 @@ func (s *Server) startSingleMode() error { MaxRequestsPerConn: serverCfg.MaxRequestsPerConn, CloseOnShutdown: true, // 高并发优化配置 - Concurrency: serverCfg.Concurrency, - ReadBufferSize: serverCfg.ReadBufferSize, - WriteBufferSize: serverCfg.WriteBufferSize, - ReduceMemoryUsage: serverCfg.ReduceMemoryUsage, + Concurrency: serverCfg.Concurrency, + ReadBufferSize: serverCfg.ReadBufferSize, + WriteBufferSize: serverCfg.WriteBufferSize, + ReduceMemoryUsage: serverCfg.ReduceMemoryUsage, } s.running = true @@ -638,10 +641,10 @@ func (s *Server) startVHostMode() error { MaxRequestsPerConn: serverCfg.MaxRequestsPerConn, CloseOnShutdown: true, // 高并发优化配置 - Concurrency: serverCfg.Concurrency, - ReadBufferSize: serverCfg.ReadBufferSize, - WriteBufferSize: serverCfg.WriteBufferSize, - ReduceMemoryUsage: serverCfg.ReduceMemoryUsage, + Concurrency: serverCfg.Concurrency, + ReadBufferSize: serverCfg.ReadBufferSize, + WriteBufferSize: serverCfg.WriteBufferSize, + ReduceMemoryUsage: serverCfg.ReduceMemoryUsage, } s.running = true @@ -861,7 +864,7 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S routePath := proxyCfg.Path // 确保通配符路由格式正确 if !strings.HasSuffix(routePath, "/") && routePath != "/" { - routePath += "/" + routePath += "/" } wildcardPath := routePath + "{path:*}" router.GET(wildcardPath, p.ServeHTTP)