refactor(middleware): 适配变量系统和 resolver 重命名

适配 variable.NewContext/ReleaseContext
适配 resolver.DNSCacheEntry

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-10 09:40:57 +08:00
parent ac8e89c492
commit 2af3176507
4 changed files with 40 additions and 38 deletions

View File

@ -36,10 +36,12 @@ const (
AlgorithmGzip Algorithm = iota AlgorithmGzip Algorithm = iota
// AlgorithmBrotli 使用 brotli 压缩。 // AlgorithmBrotli 使用 brotli 压缩。
AlgorithmBrotli AlgorithmBrotli
compressionGZIP = "gzip"
) )
// CompressionMiddleware 响应压缩中间件。 // Middleware 响应压缩中间件。
type CompressionMiddleware struct { type Middleware struct {
// types 可压缩的 MIME 类型列表 // types 可压缩的 MIME 类型列表
types []string types []string
// level 压缩级别1-9 // level 压缩级别1-9
@ -61,9 +63,9 @@ type CompressionMiddleware struct {
// - cfg: 压缩配置,包含算法类型、压缩级别、最小压缩大小等 // - cfg: 压缩配置,包含算法类型、压缩级别、最小压缩大小等
// //
// 返回值: // 返回值:
// - *CompressionMiddleware: 压缩中间件实例 // - *Middleware: 压缩中间件实例
// - error: 配置无效时返回错误 // - error: 配置无效时返回错误
func New(cfg *config.CompressionConfig) (*CompressionMiddleware, error) { func New(cfg *config.CompressionConfig) (*Middleware, error) {
if cfg == nil { if cfg == nil {
cfg = &config.CompressionConfig{ cfg = &config.CompressionConfig{
Type: "gzip", Type: "gzip",
@ -89,7 +91,7 @@ func New(cfg *config.CompressionConfig) (*CompressionMiddleware, error) {
switch strings.ToLower(cfg.Type) { switch strings.ToLower(cfg.Type) {
case "brotli": case "brotli":
algo = AlgorithmBrotli algo = AlgorithmBrotli
case "gzip": case compressionGZIP:
algo = AlgorithmGzip algo = AlgorithmGzip
case "both": case "both":
// both 模式优先使用 brotli如果客户端支持 // both 模式优先使用 brotli如果客户端支持
@ -98,7 +100,7 @@ func New(cfg *config.CompressionConfig) (*CompressionMiddleware, error) {
algo = AlgorithmGzip algo = AlgorithmGzip
} }
m := &CompressionMiddleware{ m := &Middleware{
types: cfg.Types, types: cfg.Types,
level: cfg.Level, level: cfg.Level,
minSize: cfg.MinSize, minSize: cfg.MinSize,
@ -141,7 +143,7 @@ func defaultCompressibleTypes() []string {
} }
// Name 返回中间件名称。 // Name 返回中间件名称。
func (m *CompressionMiddleware) Name() string { func (m *Middleware) Name() string {
return "compression" return "compression"
} }
@ -152,7 +154,7 @@ func (m *CompressionMiddleware) Name() string {
// //
// 返回值: // 返回值:
// - fasthttp.RequestHandler: 包装后的请求处理器 // - fasthttp.RequestHandler: 包装后的请求处理器
func (m *CompressionMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) { return func(ctx *fasthttp.RequestCtx) {
// 检查客户端是否支持压缩 // 检查客户端是否支持压缩
acceptEncoding := string(ctx.Request.Header.Peek("Accept-Encoding")) acceptEncoding := string(ctx.Request.Header.Peek("Accept-Encoding"))
@ -204,7 +206,7 @@ func (m *CompressionMiddleware) Process(next fasthttp.RequestHandler) fasthttp.R
encoding = "br" encoding = "br"
} else if useGzip { } else if useGzip {
compressed = m.compressGzip(body) compressed = m.compressGzip(body)
encoding = "gzip" encoding = compressionGZIP
} }
if len(compressed) > 0 && len(compressed) < bodyLen { if len(compressed) > 0 && len(compressed) < bodyLen {
@ -222,7 +224,7 @@ func (m *CompressionMiddleware) Process(next fasthttp.RequestHandler) fasthttp.R
// //
// 返回值: // 返回值:
// - bool: 是否可压缩 // - bool: 是否可压缩
func (m *CompressionMiddleware) isCompressible(contentType string) bool { func (m *Middleware) isCompressible(contentType string) bool {
// 移除 charset 等参数 // 移除 charset 等参数
ct := contentType ct := contentType
if idx := strings.Index(ct, ";"); idx >= 0 { if idx := strings.Index(ct, ";"); idx >= 0 {
@ -252,7 +254,7 @@ func (m *CompressionMiddleware) isCompressible(contentType string) bool {
// //
// 返回值: // 返回值:
// - []byte: 压缩后的数据 // - []byte: 压缩后的数据
func (m *CompressionMiddleware) compressGzip(data []byte) []byte { func (m *Middleware) compressGzip(data []byte) []byte {
w := m.gzipPool.Get().(*gzip.Writer) w := m.gzipPool.Get().(*gzip.Writer)
defer m.gzipPool.Put(w) defer m.gzipPool.Put(w)
@ -271,7 +273,7 @@ func (m *CompressionMiddleware) compressGzip(data []byte) []byte {
// //
// 返回值: // 返回值:
// - []byte: 压缩后的数据 // - []byte: 压缩后的数据
func (m *CompressionMiddleware) compressBrotli(data []byte) []byte { func (m *Middleware) compressBrotli(data []byte) []byte {
w := m.brotliPool.Get().(*brotli.Writer) w := m.brotliPool.Get().(*brotli.Writer)
defer m.brotliPool.Put(w) defer m.brotliPool.Put(w)
@ -287,7 +289,7 @@ func (m *CompressionMiddleware) compressBrotli(data []byte) []byte {
// //
// 返回值: // 返回值:
// - []string: 可压缩的 MIME 类型列表 // - []string: 可压缩的 MIME 类型列表
func (m *CompressionMiddleware) Types() []string { func (m *Middleware) Types() []string {
return m.types return m.types
} }
@ -295,7 +297,7 @@ func (m *CompressionMiddleware) Types() []string {
// //
// 返回值: // 返回值:
// - int: 压缩级别1-9 // - int: 压缩级别1-9
func (m *CompressionMiddleware) Level() int { func (m *Middleware) Level() int {
return m.level return m.level
} }
@ -303,6 +305,6 @@ func (m *CompressionMiddleware) Level() int {
// //
// 返回值: // 返回值:
// - int: 最小压缩大小(字节) // - int: 最小压缩大小(字节)
func (m *CompressionMiddleware) MinSize() int { func (m *Middleware) MinSize() int {
return m.minSize return m.minSize
} }

View File

@ -39,7 +39,7 @@ func (m *testMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestH
func TestEmptyChain(t *testing.T) { func TestEmptyChain(t *testing.T) {
chain := NewChain() chain := NewChain()
executed := false executed := false
final := func(ctx *fasthttp.RequestCtx) { final := func(_ *fasthttp.RequestCtx) {
executed = true executed = true
} }
@ -64,7 +64,7 @@ func TestSingleMiddleware(t *testing.T) {
mw := &testMiddleware{name: "mw1", order: &order} mw := &testMiddleware{name: "mw1", order: &order}
chain := NewChain(mw) chain := NewChain(mw)
final := func(ctx *fasthttp.RequestCtx) { final := func(_ *fasthttp.RequestCtx) {
order = append(order, "final") order = append(order, "final")
} }
@ -92,7 +92,7 @@ func TestMultipleMiddlewareOrder(t *testing.T) {
// 添加顺序mw1, mw2, mw3 // 添加顺序mw1, mw2, mw3
chain := NewChain(mw1, mw2, mw3) chain := NewChain(mw1, mw2, mw3)
final := func(ctx *fasthttp.RequestCtx) { final := func(_ *fasthttp.RequestCtx) {
order = append(order, "final") order = append(order, "final")
} }

View File

@ -52,14 +52,14 @@ type Rule struct {
flag Flag flag Flag
} }
// RewriteMiddleware URL 重写中间件。 // Middleware URL 重写中间件。
type RewriteMiddleware struct { type Middleware struct {
// rules 编译后的规则列表,按配置顺序执行 // rules 编译后的规则列表,按配置顺序执行
rules []Rule rules []Rule
} }
// New 创建重写中间件。 // New 创建重写中间件。
func New(rules []config.RewriteRule) (*RewriteMiddleware, error) { func New(rules []config.RewriteRule) (*Middleware, error) {
compiled := make([]Rule, 0, len(rules)) compiled := make([]Rule, 0, len(rules))
for _, r := range rules { for _, r := range rules {
// 验证正则表达式安全性,防止 ReDoS // 验证正则表达式安全性,防止 ReDoS
@ -77,7 +77,7 @@ func New(rules []config.RewriteRule) (*RewriteMiddleware, error) {
flag: parseFlag(r.Flag), flag: parseFlag(r.Flag),
}) })
} }
return &RewriteMiddleware{rules: compiled}, nil return &Middleware{rules: compiled}, nil
} }
// validateRegexSafety 验证正则表达式的安全性,防止 ReDoS 攻击。 // validateRegexSafety 验证正则表达式的安全性,防止 ReDoS 攻击。
@ -107,12 +107,12 @@ func validateRegexSafety(pattern string) error {
} }
// Name 返回中间件名称。 // Name 返回中间件名称。
func (m *RewriteMiddleware) Name() string { func (m *Middleware) Name() string {
return "rewrite" return "rewrite"
} }
// Process 应用重写规则。 // Process 应用重写规则。
func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) { return func(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path()) path := string(ctx.Path())
originalPath := path originalPath := path
@ -136,9 +136,9 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque
newPath := rule.pattern.ReplaceAllString(path, rule.replacement) newPath := rule.pattern.ReplaceAllString(path, rule.replacement)
// 对替换结果进行变量展开 // 对替换结果进行变量展开
vc := variable.NewVariableContext(ctx) vc := variable.NewContext(ctx)
newPath = vc.Expand(newPath) newPath = vc.Expand(newPath)
variable.ReleaseVariableContext(vc) variable.ReleaseContext(vc)
switch rule.flag { switch rule.flag {
case FlagRedirect: case FlagRedirect:
@ -175,6 +175,6 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque
} }
// Rules 返回编译后的规则列表(用于调试)。 // Rules 返回编译后的规则列表(用于调试)。
func (m *RewriteMiddleware) Rules() []Rule { func (m *Middleware) Rules() []Rule {
return m.rules return m.rules
} }

View File

@ -89,7 +89,7 @@ func TestNew(t *testing.T) {
} }
} }
func TestRewriteMiddlewareLast(t *testing.T) { func TestMiddlewareLast(t *testing.T) {
m, err := New([]config.RewriteRule{ m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"}, {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
}) })
@ -118,7 +118,7 @@ func TestRewriteMiddlewareLast(t *testing.T) {
} }
} }
func TestRewriteMiddlewareRedirect(t *testing.T) { func TestMiddlewareRedirect(t *testing.T) {
m, err := New([]config.RewriteRule{ m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "redirect"}, {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "redirect"},
}) })
@ -127,7 +127,7 @@ func TestRewriteMiddlewareRedirect(t *testing.T) {
} }
handlerCalled := false handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) { nextHandler := func(_ *fasthttp.RequestCtx) {
handlerCalled = true handlerCalled = true
} }
@ -153,7 +153,7 @@ func TestRewriteMiddlewareRedirect(t *testing.T) {
} }
} }
func TestRewriteMiddlewarePermanent(t *testing.T) { func TestMiddlewarePermanent(t *testing.T) {
m, err := New([]config.RewriteRule{ m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "permanent"}, {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "permanent"},
}) })
@ -182,7 +182,7 @@ func TestRewriteMiddlewarePermanent(t *testing.T) {
} }
} }
func TestRewriteMiddlewareBreak(t *testing.T) { func TestMiddlewareBreak(t *testing.T) {
m, err := New([]config.RewriteRule{ m, err := New([]config.RewriteRule{
{Pattern: "^/api/(.*)$", Replacement: "/internal/$1", Flag: "break"}, {Pattern: "^/api/(.*)$", Replacement: "/internal/$1", Flag: "break"},
{Pattern: "^/internal/(.*)$", Replacement: "/final/$1", Flag: "last"}, {Pattern: "^/internal/(.*)$", Replacement: "/final/$1", Flag: "last"},
@ -212,7 +212,7 @@ func TestRewriteMiddlewareBreak(t *testing.T) {
} }
} }
func TestRewriteMiddlewareChain(t *testing.T) { func TestMiddlewareChain(t *testing.T) {
// 测试多个 last 规则链式应用 // 测试多个 last 规则链式应用
m, err := New([]config.RewriteRule{ m, err := New([]config.RewriteRule{
{Pattern: "^/v1/(.*)$", Replacement: "/v2/$1", Flag: "last"}, {Pattern: "^/v1/(.*)$", Replacement: "/v2/$1", Flag: "last"},
@ -242,7 +242,7 @@ func TestRewriteMiddlewareChain(t *testing.T) {
} }
} }
func TestRewriteMiddlewareNoMatch(t *testing.T) { func TestMiddlewareNoMatch(t *testing.T) {
m, err := New([]config.RewriteRule{ m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"}, {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
}) })
@ -270,7 +270,7 @@ func TestRewriteMiddlewareNoMatch(t *testing.T) {
} }
} }
func TestRewriteMiddlewareName(t *testing.T) { func TestMiddlewareName(t *testing.T) {
m, err := New(nil) m, err := New(nil)
if err != nil { if err != nil {
t.Fatalf("New() error: %v", err) t.Fatalf("New() error: %v", err)
@ -281,7 +281,7 @@ func TestRewriteMiddlewareName(t *testing.T) {
} }
} }
func TestRewriteMiddlewareRules(t *testing.T) { func TestMiddlewareRules(t *testing.T) {
rules := []config.RewriteRule{ rules := []config.RewriteRule{
{Pattern: "^/a/(.*)$", Replacement: "/b/$1", Flag: "last"}, {Pattern: "^/a/(.*)$", Replacement: "/b/$1", Flag: "last"},
{Pattern: "^/c$", Replacement: "/d", Flag: "redirect"}, {Pattern: "^/c$", Replacement: "/d", Flag: "redirect"},
@ -357,7 +357,7 @@ func TestCrossRuleCycle(t *testing.T) {
t.Fatalf("New() error: %v", err) t.Fatalf("New() error: %v", err)
} }
nextHandler := func(ctx *fasthttp.RequestCtx) { nextHandler := func(_ *fasthttp.RequestCtx) {
t.Error("Next handler should not be called in a loop scenario") t.Error("Next handler should not be called in a loop scenario")
} }
@ -473,7 +473,7 @@ func TestIterationLimitExact(t *testing.T) {
t.Fatalf("New() error: %v", err) t.Fatalf("New() error: %v", err)
} }
nextHandler := func(ctx *fasthttp.RequestCtx) { nextHandler := func(_ *fasthttp.RequestCtx) {
t.Error("Next handler should not be called when iteration limit exceeded") t.Error("Next handler should not be called when iteration limit exceeded")
} }