style(proxy,server): 代码风格优化

- headers.go: 添加协议常量 protoHTTP/protoHTTPS
- redirect_rewrite.go: 添加模式常量,修正缩进
- proxy_ssl_test.go: 表格测试字段对齐
- server.go: 添加 ServerModeAuto 分支防御性处理

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-16 09:54:09 +08:00
parent d874f97765
commit 470c82d940
7 changed files with 85 additions and 69 deletions

View File

@ -10,6 +10,12 @@ import (
"rua.plus/lolly/internal/netutil" "rua.plus/lolly/internal/netutil"
) )
// 协议常量
const (
protoHTTP = "http"
protoHTTPS = "https"
)
// ForwardedHeaders 包含 X-Forwarded 系列头信息。 // ForwardedHeaders 包含 X-Forwarded 系列头信息。
type ForwardedHeaders struct { type ForwardedHeaders struct {
ClientIP string // 客户端 IP ClientIP string // 客户端 IP
@ -28,9 +34,9 @@ func ExtractForwardedHeaders(ctx *fasthttp.RequestCtx) ForwardedHeaders {
clientIP := netutil.ExtractClientIP(ctx) clientIP := netutil.ExtractClientIP(ctx)
host := string(ctx.Host()) host := string(ctx.Host())
proto := "http" proto := protoHTTP
if ctx.IsTLS() { if ctx.IsTLS() {
proto = "https" proto = protoHTTPS
} }
return ForwardedHeaders{ return ForwardedHeaders{

View File

@ -84,8 +84,8 @@ type Proxy struct {
config *config.ProxyConfig config *config.ProxyConfig
cache *cache.ProxyCache cache *cache.ProxyCache
healthChecker *HealthChecker healthChecker *HealthChecker
luaEngine *lua.LuaEngine // Lua 引擎引用 luaEngine *lua.LuaEngine // Lua 引擎引用
redirectRewriter *RedirectRewriter // 重定向改写器 redirectRewriter *RedirectRewriter // 重定向改写器
stopCh chan struct{} stopCh chan struct{}
targets []*loadbalance.Target targets []*loadbalance.Target
mu sync.RWMutex mu sync.RWMutex

View File

@ -98,4 +98,4 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls
} }
return tlsCfg, nil return tlsCfg, nil
} }

View File

@ -33,28 +33,28 @@ func TestCreateTLSConfig_Disabled(t *testing.T) {
func TestCreateTLSConfig_ServerName(t *testing.T) { func TestCreateTLSConfig_ServerName(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cfg *config.ProxySSLConfig cfg *config.ProxySSLConfig
defaultServerName string defaultServerName string
wantServerName string wantServerName string
}{ }{
{ {
name: "custom server name", name: "custom server name",
cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"}, cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
defaultServerName: "default.example.com", defaultServerName: "default.example.com",
wantServerName: "custom.example.com", wantServerName: "custom.example.com",
}, },
{ {
name: "default server name", name: "default server name",
cfg: &config.ProxySSLConfig{Enabled: true}, cfg: &config.ProxySSLConfig{Enabled: true},
defaultServerName: "default.example.com", defaultServerName: "default.example.com",
wantServerName: "default.example.com", wantServerName: "default.example.com",
}, },
{ {
name: "empty default", name: "empty default",
cfg: &config.ProxySSLConfig{Enabled: true}, cfg: &config.ProxySSLConfig{Enabled: true},
defaultServerName: "", defaultServerName: "",
wantServerName: "", wantServerName: "",
}, },
} }
@ -97,11 +97,11 @@ func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) {
func TestCreateTLSConfig_TLSVersions(t *testing.T) { func TestCreateTLSConfig_TLSVersions(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
minVersion string minVersion string
maxVersion string maxVersion string
wantMin uint16 wantMin uint16
wantMax uint16 wantMax uint16
}{ }{
{ {
name: "TLSV1.2 min", name: "TLSV1.2 min",
@ -119,11 +119,11 @@ func TestCreateTLSConfig_TLSVersions(t *testing.T) {
wantMax: tls.VersionTLS12, wantMax: tls.VersionTLS12,
}, },
{ {
name: "both versions", name: "both versions",
minVersion: "TLSV1.2", minVersion: "TLSV1.2",
maxVersion: "TLSV1.3", maxVersion: "TLSV1.3",
wantMin: tls.VersionTLS12, wantMin: tls.VersionTLS12,
wantMax: tls.VersionTLS13, wantMax: tls.VersionTLS13,
}, },
{ {
name: "mixed case TLSv1.2", name: "mixed case TLSv1.2",
@ -135,9 +135,9 @@ func TestCreateTLSConfig_TLSVersions(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxySSLConfig{ cfg := &config.ProxySSLConfig{
Enabled: true, Enabled: true,
MinVersion: tt.minVersion, MinVersion: tt.minVersion,
MaxVersion: tt.maxVersion, MaxVersion: tt.maxVersion,
} }
tlsCfg, err := CreateTLSConfig(cfg, "example.com") tlsCfg, err := CreateTLSConfig(cfg, "example.com")
if err != nil { if err != nil {
@ -263,11 +263,11 @@ func TestGetCacheDuration_StatusCodeMapping(t *testing.T) {
MaxAge: 1 * time.Minute, MaxAge: 1 * time.Minute,
}, },
CacheValid: &config.ProxyCacheValidConfig{ CacheValid: &config.ProxyCacheValidConfig{
OK: 10 * time.Minute, OK: 10 * time.Minute,
Redirect: 1 * time.Hour, Redirect: 1 * time.Hour,
NotFound: 1 * time.Minute, NotFound: 1 * time.Minute,
ClientError: 30 * time.Second, ClientError: 30 * time.Second,
ServerError: 0, // 不缓存 ServerError: 0, // 不缓存
}, },
} }
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
@ -310,11 +310,11 @@ func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
MaxAge: 5 * time.Minute, MaxAge: 5 * time.Minute,
}, },
CacheValid: &config.ProxyCacheValidConfig{ CacheValid: &config.ProxyCacheValidConfig{
OK: 10 * time.Minute, // OK 有值 OK: 10 * time.Minute, // OK 有值
Redirect: 0, // 不缓存 Redirect: 0, // 不缓存
NotFound: 0, // 不缓存 NotFound: 0, // 不缓存
ClientError: 0, // 不缓存 ClientError: 0, // 不缓存
ServerError: 0, // 不缓存 ServerError: 0, // 不缓存
}, },
} }
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
@ -339,4 +339,4 @@ func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want) t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
} }
} }
} }

View File

@ -10,19 +10,26 @@ import (
"rua.plus/lolly/internal/variable" "rua.plus/lolly/internal/variable"
) )
// RedirectRewrite 模式常量
const (
redirectModeDefault = "default"
redirectModeOff = "off"
redirectModeCustom = "custom"
)
// compiledRule 预编译的改写规则 // compiledRule 预编译的改写规则
type compiledRule struct { type compiledRule struct {
pattern *regexp.Regexp // 正则模式nil 表示非正则匹配 pattern *regexp.Regexp // 正则模式nil 表示非正则匹配
exactMatch string // 精确匹配前缀(用于 prefix 匹配) replacement string // 替换模板(含变量
replacement string // 替换模板(含变量) exactMatch string // 精确匹配前缀(用于 prefix 匹配
caseInsensitive bool // 正则大小写不敏感(~* 前缀) caseInsensitive bool // 正则大小写不敏感(~* 前缀)
} }
// RedirectRewriter Location/Refresh 头改写器 // RedirectRewriter Location/Refresh 头改写器
type RedirectRewriter struct { type RedirectRewriter struct {
proxyPath string // 用于 default 模式(当前代理路径)
mode string // "default" | "off" | "custom"(空字符串视为 default mode string // "default" | "off" | "custom"(空字符串视为 default
rules []compiledRule // 仅 custom 模式预编译 rules []compiledRule // 仅 custom 模式预编译
proxyPath string // 用于 default 模式(当前代理路径)
} }
// NewRedirectRewriter 创建改写器 // NewRedirectRewriter 创建改写器
@ -32,7 +39,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
if cfg == nil { if cfg == nil {
// 未配置时默认启用 default 模式 // 未配置时默认启用 default 模式
return &RedirectRewriter{ return &RedirectRewriter{
mode: "default", mode: redirectModeDefault,
proxyPath: proxyPath, proxyPath: proxyPath,
}, nil }, nil
} }
@ -43,7 +50,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
} }
// custom 模式:预编译规则 // custom 模式:预编译规则
if cfg.Mode == "custom" { if cfg.Mode == redirectModeCustom {
rules := make([]compiledRule, 0, len(cfg.Rules)) rules := make([]compiledRule, 0, len(cfg.Rules))
for _, rule := range cfg.Rules { for _, rule := range cfg.Rules {
cr := compiledRule{ cr := compiledRule{
@ -52,7 +59,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
if strings.HasPrefix(rule.Pattern, "~") { if strings.HasPrefix(rule.Pattern, "~") {
// 正则模式 // 正则模式
patternStr := rule.Pattern var patternStr string
if strings.HasPrefix(rule.Pattern, "~*") { if strings.HasPrefix(rule.Pattern, "~*") {
cr.caseInsensitive = true cr.caseInsensitive = true
patternStr = rule.Pattern[2:] patternStr = rule.Pattern[2:]
@ -80,7 +87,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
// Mode 返回当前模式(处理空字符串默认值) // Mode 返回当前模式(处理空字符串默认值)
func (r *RedirectRewriter) Mode() string { func (r *RedirectRewriter) Mode() string {
if r.mode == "" { if r.mode == "" {
return "default" return redirectModeDefault
} }
return r.mode return r.mode
} }
@ -136,13 +143,13 @@ func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestC
} }
switch r.Mode() { switch r.Mode() {
case "off": case redirectModeOff:
return headerValue return headerValue
case "custom": case redirectModeCustom:
return r.rewriteCustom(headerValue, ctx) return r.rewriteCustom(headerValue, ctx)
case "default", "": case redirectModeDefault, "":
return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost) return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost)
default: default:
@ -154,7 +161,8 @@ func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestC
// 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀 // 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀
// replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/" // replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/"
// 例如targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users" // 例如targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users"
// → 替换为 "$scheme://originalClientHost/api/v2/users" //
// → 替换为 "$scheme://originalClientHost/api/v2/users"
func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string { func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
if targetURL == "" { if targetURL == "" {
return headerValue return headerValue
@ -167,9 +175,9 @@ func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.Requ
// 检查剩余部分是否以合法分隔符开头 // 检查剩余部分是否以合法分隔符开头
if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' { if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' {
// 使用客户端原始 host 构建 replacement // 使用客户端原始 host 构建 replacement
scheme := "http" scheme := protoHTTP
if ctx.IsTLS() { if ctx.IsTLS() {
scheme = "https" scheme = protoHTTPS
} }
replacement := scheme + "://" + originalClientHost replacement := scheme + "://" + originalClientHost
return replacement + remaining return replacement + remaining
@ -198,7 +206,7 @@ func (r *RedirectRewriter) rewriteCustom(headerValue string, ctx *fasthttp.Reque
return result return result
} }
} else { } else {
loc := rule.pattern.FindStringIndex(headerValue) loc := rule.pattern.FindStringIndex(headerValue)
if loc != nil { if loc != nil {
expanded := vc.Expand(rule.replacement) expanded := vc.Expand(rule.replacement)
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:] result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]

View File

@ -3,10 +3,9 @@ package proxy
import ( import (
"testing" "testing"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config" "rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/testutil" "rua.plus/lolly/internal/testutil"
"github.com/valyala/fasthttp"
) )
// TestRedirectRewrite_ExactMatch 测试精确匹配改写 // TestRedirectRewrite_ExactMatch 测试精确匹配改写
@ -273,4 +272,4 @@ func TestRedirectRewrite_EmptyMode(t *testing.T) {
if rw.Mode() != "default" { if rw.Mode() != "default" {
t.Errorf("Mode() = %q, want %q", rw.Mode(), "default") t.Errorf("Mode() = %q, want %q", rw.Mode(), "default")
} }
} }

View File

@ -438,6 +438,9 @@ func (s *Server) Start() error {
return s.startVHostMode() return s.startVHostMode()
case config.ServerModeMultiServer: case config.ServerModeMultiServer:
return s.startMultiServerMode() return s.startMultiServerMode()
case config.ServerModeAuto:
// auto 模式下 GetMode() 会自动推断,此处为防御性处理
return s.startSingleMode()
default: default:
// 默认使用单服务器模式 // 默认使用单服务器模式
return s.startSingleMode() return s.startSingleMode()
@ -513,10 +516,10 @@ func (s *Server) startSingleMode() error {
MaxRequestsPerConn: serverCfg.MaxRequestsPerConn, MaxRequestsPerConn: serverCfg.MaxRequestsPerConn,
CloseOnShutdown: true, CloseOnShutdown: true,
// 高并发优化配置 // 高并发优化配置
Concurrency: serverCfg.Concurrency, Concurrency: serverCfg.Concurrency,
ReadBufferSize: serverCfg.ReadBufferSize, ReadBufferSize: serverCfg.ReadBufferSize,
WriteBufferSize: serverCfg.WriteBufferSize, WriteBufferSize: serverCfg.WriteBufferSize,
ReduceMemoryUsage: serverCfg.ReduceMemoryUsage, ReduceMemoryUsage: serverCfg.ReduceMemoryUsage,
} }
s.running = true s.running = true
@ -638,10 +641,10 @@ func (s *Server) startVHostMode() error {
MaxRequestsPerConn: serverCfg.MaxRequestsPerConn, MaxRequestsPerConn: serverCfg.MaxRequestsPerConn,
CloseOnShutdown: true, CloseOnShutdown: true,
// 高并发优化配置 // 高并发优化配置
Concurrency: serverCfg.Concurrency, Concurrency: serverCfg.Concurrency,
ReadBufferSize: serverCfg.ReadBufferSize, ReadBufferSize: serverCfg.ReadBufferSize,
WriteBufferSize: serverCfg.WriteBufferSize, WriteBufferSize: serverCfg.WriteBufferSize,
ReduceMemoryUsage: serverCfg.ReduceMemoryUsage, ReduceMemoryUsage: serverCfg.ReduceMemoryUsage,
} }
s.running = true s.running = true
@ -861,7 +864,7 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S
routePath := proxyCfg.Path routePath := proxyCfg.Path
// 确保通配符路由格式正确 // 确保通配符路由格式正确
if !strings.HasSuffix(routePath, "/") && routePath != "/" { if !strings.HasSuffix(routePath, "/") && routePath != "/" {
routePath += "/" routePath += "/"
} }
wildcardPath := routePath + "{path:*}" wildcardPath := routePath + "{path:*}"
router.GET(wildcardPath, p.ServeHTTP) router.GET(wildcardPath, p.ServeHTTP)