style(proxy,server): 代码风格优化
- headers.go: 添加协议常量 protoHTTP/protoHTTPS - redirect_rewrite.go: 添加模式常量,修正缩进 - proxy_ssl_test.go: 表格测试字段对齐 - server.go: 添加 ServerModeAuto 分支防御性处理 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d874f97765
commit
470c82d940
@ -10,6 +10,12 @@ import (
|
|||||||
"rua.plus/lolly/internal/netutil"
|
"rua.plus/lolly/internal/netutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 协议常量
|
||||||
|
const (
|
||||||
|
protoHTTP = "http"
|
||||||
|
protoHTTPS = "https"
|
||||||
|
)
|
||||||
|
|
||||||
// ForwardedHeaders 包含 X-Forwarded 系列头信息。
|
// ForwardedHeaders 包含 X-Forwarded 系列头信息。
|
||||||
type ForwardedHeaders struct {
|
type ForwardedHeaders struct {
|
||||||
ClientIP string // 客户端 IP
|
ClientIP string // 客户端 IP
|
||||||
@ -28,9 +34,9 @@ func ExtractForwardedHeaders(ctx *fasthttp.RequestCtx) ForwardedHeaders {
|
|||||||
clientIP := netutil.ExtractClientIP(ctx)
|
clientIP := netutil.ExtractClientIP(ctx)
|
||||||
host := string(ctx.Host())
|
host := string(ctx.Host())
|
||||||
|
|
||||||
proto := "http"
|
proto := protoHTTP
|
||||||
if ctx.IsTLS() {
|
if ctx.IsTLS() {
|
||||||
proto = "https"
|
proto = protoHTTPS
|
||||||
}
|
}
|
||||||
|
|
||||||
return ForwardedHeaders{
|
return ForwardedHeaders{
|
||||||
|
|||||||
@ -84,8 +84,8 @@ type Proxy struct {
|
|||||||
config *config.ProxyConfig
|
config *config.ProxyConfig
|
||||||
cache *cache.ProxyCache
|
cache *cache.ProxyCache
|
||||||
healthChecker *HealthChecker
|
healthChecker *HealthChecker
|
||||||
luaEngine *lua.LuaEngine // Lua 引擎引用
|
luaEngine *lua.LuaEngine // Lua 引擎引用
|
||||||
redirectRewriter *RedirectRewriter // 重定向改写器
|
redirectRewriter *RedirectRewriter // 重定向改写器
|
||||||
stopCh chan struct{}
|
stopCh chan struct{}
|
||||||
targets []*loadbalance.Target
|
targets []*loadbalance.Target
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|||||||
@ -98,4 +98,4 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls
|
|||||||
}
|
}
|
||||||
|
|
||||||
return tlsCfg, nil
|
return tlsCfg, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -33,28 +33,28 @@ func TestCreateTLSConfig_Disabled(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateTLSConfig_ServerName(t *testing.T) {
|
func TestCreateTLSConfig_ServerName(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cfg *config.ProxySSLConfig
|
cfg *config.ProxySSLConfig
|
||||||
defaultServerName string
|
defaultServerName string
|
||||||
wantServerName string
|
wantServerName string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "custom server name",
|
name: "custom server name",
|
||||||
cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
|
cfg: &config.ProxySSLConfig{Enabled: true, ServerName: "custom.example.com"},
|
||||||
defaultServerName: "default.example.com",
|
defaultServerName: "default.example.com",
|
||||||
wantServerName: "custom.example.com",
|
wantServerName: "custom.example.com",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "default server name",
|
name: "default server name",
|
||||||
cfg: &config.ProxySSLConfig{Enabled: true},
|
cfg: &config.ProxySSLConfig{Enabled: true},
|
||||||
defaultServerName: "default.example.com",
|
defaultServerName: "default.example.com",
|
||||||
wantServerName: "default.example.com",
|
wantServerName: "default.example.com",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty default",
|
name: "empty default",
|
||||||
cfg: &config.ProxySSLConfig{Enabled: true},
|
cfg: &config.ProxySSLConfig{Enabled: true},
|
||||||
defaultServerName: "",
|
defaultServerName: "",
|
||||||
wantServerName: "",
|
wantServerName: "",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,11 +97,11 @@ func TestCreateTLSConfig_InsecureSkipVerify(t *testing.T) {
|
|||||||
|
|
||||||
func TestCreateTLSConfig_TLSVersions(t *testing.T) {
|
func TestCreateTLSConfig_TLSVersions(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
minVersion string
|
minVersion string
|
||||||
maxVersion string
|
maxVersion string
|
||||||
wantMin uint16
|
wantMin uint16
|
||||||
wantMax uint16
|
wantMax uint16
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "TLSV1.2 min",
|
name: "TLSV1.2 min",
|
||||||
@ -119,11 +119,11 @@ func TestCreateTLSConfig_TLSVersions(t *testing.T) {
|
|||||||
wantMax: tls.VersionTLS12,
|
wantMax: tls.VersionTLS12,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "both versions",
|
name: "both versions",
|
||||||
minVersion: "TLSV1.2",
|
minVersion: "TLSV1.2",
|
||||||
maxVersion: "TLSV1.3",
|
maxVersion: "TLSV1.3",
|
||||||
wantMin: tls.VersionTLS12,
|
wantMin: tls.VersionTLS12,
|
||||||
wantMax: tls.VersionTLS13,
|
wantMax: tls.VersionTLS13,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed case TLSv1.2",
|
name: "mixed case TLSv1.2",
|
||||||
@ -135,9 +135,9 @@ func TestCreateTLSConfig_TLSVersions(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cfg := &config.ProxySSLConfig{
|
cfg := &config.ProxySSLConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
MinVersion: tt.minVersion,
|
MinVersion: tt.minVersion,
|
||||||
MaxVersion: tt.maxVersion,
|
MaxVersion: tt.maxVersion,
|
||||||
}
|
}
|
||||||
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
tlsCfg, err := CreateTLSConfig(cfg, "example.com")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -263,11 +263,11 @@ func TestGetCacheDuration_StatusCodeMapping(t *testing.T) {
|
|||||||
MaxAge: 1 * time.Minute,
|
MaxAge: 1 * time.Minute,
|
||||||
},
|
},
|
||||||
CacheValid: &config.ProxyCacheValidConfig{
|
CacheValid: &config.ProxyCacheValidConfig{
|
||||||
OK: 10 * time.Minute,
|
OK: 10 * time.Minute,
|
||||||
Redirect: 1 * time.Hour,
|
Redirect: 1 * time.Hour,
|
||||||
NotFound: 1 * time.Minute,
|
NotFound: 1 * time.Minute,
|
||||||
ClientError: 30 * time.Second,
|
ClientError: 30 * time.Second,
|
||||||
ServerError: 0, // 不缓存
|
ServerError: 0, // 不缓存
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
@ -310,11 +310,11 @@ func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
|
|||||||
MaxAge: 5 * time.Minute,
|
MaxAge: 5 * time.Minute,
|
||||||
},
|
},
|
||||||
CacheValid: &config.ProxyCacheValidConfig{
|
CacheValid: &config.ProxyCacheValidConfig{
|
||||||
OK: 10 * time.Minute, // OK 有值
|
OK: 10 * time.Minute, // OK 有值
|
||||||
Redirect: 0, // 不缓存
|
Redirect: 0, // 不缓存
|
||||||
NotFound: 0, // 不缓存
|
NotFound: 0, // 不缓存
|
||||||
ClientError: 0, // 不缓存
|
ClientError: 0, // 不缓存
|
||||||
ServerError: 0, // 不缓存
|
ServerError: 0, // 不缓存
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
@ -339,4 +339,4 @@ func TestGetCacheDuration_ZeroValuesNoCache(t *testing.T) {
|
|||||||
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, got, tt.want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,19 +10,26 @@ import (
|
|||||||
"rua.plus/lolly/internal/variable"
|
"rua.plus/lolly/internal/variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// RedirectRewrite 模式常量
|
||||||
|
const (
|
||||||
|
redirectModeDefault = "default"
|
||||||
|
redirectModeOff = "off"
|
||||||
|
redirectModeCustom = "custom"
|
||||||
|
)
|
||||||
|
|
||||||
// compiledRule 预编译的改写规则
|
// compiledRule 预编译的改写规则
|
||||||
type compiledRule struct {
|
type compiledRule struct {
|
||||||
pattern *regexp.Regexp // 正则模式,nil 表示非正则匹配
|
pattern *regexp.Regexp // 正则模式,nil 表示非正则匹配
|
||||||
exactMatch string // 精确匹配前缀(用于 prefix 匹配)
|
replacement string // 替换模板(含变量)
|
||||||
replacement string // 替换模板(含变量)
|
exactMatch string // 精确匹配前缀(用于 prefix 匹配)
|
||||||
caseInsensitive bool // 正则大小写不敏感(~* 前缀)
|
caseInsensitive bool // 正则大小写不敏感(~* 前缀)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RedirectRewriter Location/Refresh 头改写器
|
// RedirectRewriter Location/Refresh 头改写器
|
||||||
type RedirectRewriter struct {
|
type RedirectRewriter struct {
|
||||||
|
proxyPath string // 用于 default 模式(当前代理路径)
|
||||||
mode string // "default" | "off" | "custom"(空字符串视为 default)
|
mode string // "default" | "off" | "custom"(空字符串视为 default)
|
||||||
rules []compiledRule // 仅 custom 模式预编译
|
rules []compiledRule // 仅 custom 模式预编译
|
||||||
proxyPath string // 用于 default 模式(当前代理路径)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRedirectRewriter 创建改写器
|
// NewRedirectRewriter 创建改写器
|
||||||
@ -32,7 +39,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
|
|||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
// 未配置时默认启用 default 模式
|
// 未配置时默认启用 default 模式
|
||||||
return &RedirectRewriter{
|
return &RedirectRewriter{
|
||||||
mode: "default",
|
mode: redirectModeDefault,
|
||||||
proxyPath: proxyPath,
|
proxyPath: proxyPath,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@ -43,7 +50,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// custom 模式:预编译规则
|
// custom 模式:预编译规则
|
||||||
if cfg.Mode == "custom" {
|
if cfg.Mode == redirectModeCustom {
|
||||||
rules := make([]compiledRule, 0, len(cfg.Rules))
|
rules := make([]compiledRule, 0, len(cfg.Rules))
|
||||||
for _, rule := range cfg.Rules {
|
for _, rule := range cfg.Rules {
|
||||||
cr := compiledRule{
|
cr := compiledRule{
|
||||||
@ -52,7 +59,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
|
|||||||
|
|
||||||
if strings.HasPrefix(rule.Pattern, "~") {
|
if strings.HasPrefix(rule.Pattern, "~") {
|
||||||
// 正则模式
|
// 正则模式
|
||||||
patternStr := rule.Pattern
|
var patternStr string
|
||||||
if strings.HasPrefix(rule.Pattern, "~*") {
|
if strings.HasPrefix(rule.Pattern, "~*") {
|
||||||
cr.caseInsensitive = true
|
cr.caseInsensitive = true
|
||||||
patternStr = rule.Pattern[2:]
|
patternStr = rule.Pattern[2:]
|
||||||
@ -80,7 +87,7 @@ func NewRedirectRewriter(cfg *config.RedirectRewriteConfig, proxyPath string) (*
|
|||||||
// Mode 返回当前模式(处理空字符串默认值)
|
// Mode 返回当前模式(处理空字符串默认值)
|
||||||
func (r *RedirectRewriter) Mode() string {
|
func (r *RedirectRewriter) Mode() string {
|
||||||
if r.mode == "" {
|
if r.mode == "" {
|
||||||
return "default"
|
return redirectModeDefault
|
||||||
}
|
}
|
||||||
return r.mode
|
return r.mode
|
||||||
}
|
}
|
||||||
@ -136,13 +143,13 @@ func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestC
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch r.Mode() {
|
switch r.Mode() {
|
||||||
case "off":
|
case redirectModeOff:
|
||||||
return headerValue
|
return headerValue
|
||||||
|
|
||||||
case "custom":
|
case redirectModeCustom:
|
||||||
return r.rewriteCustom(headerValue, ctx)
|
return r.rewriteCustom(headerValue, ctx)
|
||||||
|
|
||||||
case "default", "":
|
case redirectModeDefault, "":
|
||||||
return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost)
|
return r.rewriteDefault(headerValue, ctx, targetURL, originalClientHost)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -154,7 +161,8 @@ func (r *RedirectRewriter) rewriteURL(headerValue string, ctx *fasthttp.RequestC
|
|||||||
// 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀
|
// 使用前缀匹配:如果 headerValue 以 targetURL 开头,替换为 replacement + 原路径后缀
|
||||||
// replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/"
|
// replacement 使用 originalClientHost 构建:"$scheme://originalClientHost/"
|
||||||
// 例如:targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users"
|
// 例如:targetURL="http://backend:8000", headerValue="http://backend:8000/api/v2/users"
|
||||||
// → 替换为 "$scheme://originalClientHost/api/v2/users"
|
//
|
||||||
|
// → 替换为 "$scheme://originalClientHost/api/v2/users"
|
||||||
func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.RequestCtx, targetURL string, originalClientHost string) string {
|
||||||
if targetURL == "" {
|
if targetURL == "" {
|
||||||
return headerValue
|
return headerValue
|
||||||
@ -167,9 +175,9 @@ func (r *RedirectRewriter) rewriteDefault(headerValue string, ctx *fasthttp.Requ
|
|||||||
// 检查剩余部分是否以合法分隔符开头
|
// 检查剩余部分是否以合法分隔符开头
|
||||||
if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' {
|
if len(remaining) == 0 || remaining[0] == '/' || remaining[0] == '?' || remaining[0] == '#' {
|
||||||
// 使用客户端原始 host 构建 replacement
|
// 使用客户端原始 host 构建 replacement
|
||||||
scheme := "http"
|
scheme := protoHTTP
|
||||||
if ctx.IsTLS() {
|
if ctx.IsTLS() {
|
||||||
scheme = "https"
|
scheme = protoHTTPS
|
||||||
}
|
}
|
||||||
replacement := scheme + "://" + originalClientHost
|
replacement := scheme + "://" + originalClientHost
|
||||||
return replacement + remaining
|
return replacement + remaining
|
||||||
@ -198,7 +206,7 @@ func (r *RedirectRewriter) rewriteCustom(headerValue string, ctx *fasthttp.Reque
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
loc := rule.pattern.FindStringIndex(headerValue)
|
loc := rule.pattern.FindStringIndex(headerValue)
|
||||||
if loc != nil {
|
if loc != nil {
|
||||||
expanded := vc.Expand(rule.replacement)
|
expanded := vc.Expand(rule.replacement)
|
||||||
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
|
result := headerValue[:loc[0]] + expanded + headerValue[loc[1]:]
|
||||||
|
|||||||
@ -3,10 +3,9 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
"rua.plus/lolly/internal/testutil"
|
"rua.plus/lolly/internal/testutil"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestRedirectRewrite_ExactMatch 测试精确匹配改写
|
// TestRedirectRewrite_ExactMatch 测试精确匹配改写
|
||||||
@ -273,4 +272,4 @@ func TestRedirectRewrite_EmptyMode(t *testing.T) {
|
|||||||
if rw.Mode() != "default" {
|
if rw.Mode() != "default" {
|
||||||
t.Errorf("Mode() = %q, want %q", rw.Mode(), "default")
|
t.Errorf("Mode() = %q, want %q", rw.Mode(), "default")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -438,6 +438,9 @@ func (s *Server) Start() error {
|
|||||||
return s.startVHostMode()
|
return s.startVHostMode()
|
||||||
case config.ServerModeMultiServer:
|
case config.ServerModeMultiServer:
|
||||||
return s.startMultiServerMode()
|
return s.startMultiServerMode()
|
||||||
|
case config.ServerModeAuto:
|
||||||
|
// auto 模式下 GetMode() 会自动推断,此处为防御性处理
|
||||||
|
return s.startSingleMode()
|
||||||
default:
|
default:
|
||||||
// 默认使用单服务器模式
|
// 默认使用单服务器模式
|
||||||
return s.startSingleMode()
|
return s.startSingleMode()
|
||||||
@ -513,10 +516,10 @@ func (s *Server) startSingleMode() error {
|
|||||||
MaxRequestsPerConn: serverCfg.MaxRequestsPerConn,
|
MaxRequestsPerConn: serverCfg.MaxRequestsPerConn,
|
||||||
CloseOnShutdown: true,
|
CloseOnShutdown: true,
|
||||||
// 高并发优化配置
|
// 高并发优化配置
|
||||||
Concurrency: serverCfg.Concurrency,
|
Concurrency: serverCfg.Concurrency,
|
||||||
ReadBufferSize: serverCfg.ReadBufferSize,
|
ReadBufferSize: serverCfg.ReadBufferSize,
|
||||||
WriteBufferSize: serverCfg.WriteBufferSize,
|
WriteBufferSize: serverCfg.WriteBufferSize,
|
||||||
ReduceMemoryUsage: serverCfg.ReduceMemoryUsage,
|
ReduceMemoryUsage: serverCfg.ReduceMemoryUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.running = true
|
s.running = true
|
||||||
@ -638,10 +641,10 @@ func (s *Server) startVHostMode() error {
|
|||||||
MaxRequestsPerConn: serverCfg.MaxRequestsPerConn,
|
MaxRequestsPerConn: serverCfg.MaxRequestsPerConn,
|
||||||
CloseOnShutdown: true,
|
CloseOnShutdown: true,
|
||||||
// 高并发优化配置
|
// 高并发优化配置
|
||||||
Concurrency: serverCfg.Concurrency,
|
Concurrency: serverCfg.Concurrency,
|
||||||
ReadBufferSize: serverCfg.ReadBufferSize,
|
ReadBufferSize: serverCfg.ReadBufferSize,
|
||||||
WriteBufferSize: serverCfg.WriteBufferSize,
|
WriteBufferSize: serverCfg.WriteBufferSize,
|
||||||
ReduceMemoryUsage: serverCfg.ReduceMemoryUsage,
|
ReduceMemoryUsage: serverCfg.ReduceMemoryUsage,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.running = true
|
s.running = true
|
||||||
@ -861,7 +864,7 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S
|
|||||||
routePath := proxyCfg.Path
|
routePath := proxyCfg.Path
|
||||||
// 确保通配符路由格式正确
|
// 确保通配符路由格式正确
|
||||||
if !strings.HasSuffix(routePath, "/") && routePath != "/" {
|
if !strings.HasSuffix(routePath, "/") && routePath != "/" {
|
||||||
routePath += "/"
|
routePath += "/"
|
||||||
}
|
}
|
||||||
wildcardPath := routePath + "{path:*}"
|
wildcardPath := routePath + "{path:*}"
|
||||||
router.GET(wildcardPath, p.ServeHTTP)
|
router.GET(wildcardPath, p.ServeHTTP)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user