From cb45e824d434c7bbe971178d2ae64eb52ada1fa0 Mon Sep 17 00:00:00 2001 From: xfy Date: Tue, 7 Apr 2026 17:50:50 +0800 Subject: [PATCH] =?UTF-8?q?feat(middleware,rewrite):=20=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=20FlagLast=20=E8=AF=AD=E4=B9=89=E4=B8=8E=E5=BE=AA=E7=8E=AF?= =?UTF-8?q?=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - FlagLast 重新从第一条规则开始匹配(nginx 兼容行为) - 新增全局迭代计数器检测循环,限制最多 10 次迭代 - FlagBreak 不触发循环检测,直接停止 - 新增测试覆盖跨规则循环、迭代限制、nginx 兼容语义 Co-Authored-By: Claude Opus 4.6 --- internal/middleware/rewrite/rewrite.go | 29 ++- internal/middleware/rewrite/rewrite_test.go | 184 ++++++++++++++++++++ 2 files changed, 209 insertions(+), 4 deletions(-) diff --git a/internal/middleware/rewrite/rewrite.go b/internal/middleware/rewrite/rewrite.go index c7e161c..d221691 100644 --- a/internal/middleware/rewrite/rewrite.go +++ b/internal/middleware/rewrite/rewrite.go @@ -10,11 +10,14 @@ import ( "rua.plus/lolly/internal/config" ) +// MaxRewriteIterations URL重写最大迭代次数,防止无限循环 +const MaxRewriteIterations = 10 + // Flag 重写标志类型。 type Flag int const ( - // FlagLast 继续匹配其他规则。 + // FlagLast 继续匹配其他规则(nginx行为:重新从第一条规则开始匹配)。 FlagLast Flag = iota // FlagRedirect 返回 302 临时重定向。 FlagRedirect @@ -113,7 +116,20 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque path := string(ctx.Path()) originalPath := path - for _, rule := range m.rules { + // 全局迭代计数器,用于检测循环(每次重写都计入迭代) + iterationCount := 0 + // 规则索引,支持FlagLast后重新开始匹配 + ruleIndex := 0 + + for ruleIndex < len(m.rules) { + // 检查迭代次数是否超过限制(在任何重写操作之前检查) + if iterationCount >= MaxRewriteIterations { + ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) + return + } + + rule := m.rules[ruleIndex] + if rule.pattern.MatchString(path) { // 执行正则替换 newPath := rule.pattern.ReplaceAllString(path, rule.replacement) @@ -126,16 +142,21 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque ctx.Redirect(newPath, fasthttp.StatusMovedPermanently) return case FlagBreak: - // 修改路径后停止匹配 + // 修改路径后停止匹配,不增加迭代计数(不触发循环检测) ctx.Request.SetRequestURI(newPath) next(ctx) return case FlagLast: - // 修改路径,继续匹配其他规则 + // 修改路径,并重新从第一条规则开始匹配(nginx兼容行为) path = newPath ctx.Request.SetRequestURI(path) + iterationCount++ // 每次FlagLast重写都增加计数 + ruleIndex = 0 // 重新从第一条规则开始 + continue } } + + ruleIndex++ } // 如果路径被修改过,需要重新设置 diff --git a/internal/middleware/rewrite/rewrite_test.go b/internal/middleware/rewrite/rewrite_test.go index 3de16bf..ded87e7 100644 --- a/internal/middleware/rewrite/rewrite_test.go +++ b/internal/middleware/rewrite/rewrite_test.go @@ -342,3 +342,187 @@ func TestReDoSProtection(t *testing.T) { }) } } + +// TestCrossRuleCycle 测试跨规则循环检测 +// 规则 A → B → A 应该被检测为循环 +func TestCrossRuleCycle(t *testing.T) { + // 模拟规则 A: /a 重写为 /b + // 规则 B: /b 重写为 /a + // 这将形成 A → B → A 的循环 + m, err := New([]config.RewriteRule{ + {Pattern: "^/a$", Replacement: "/b", Flag: "last"}, + {Pattern: "^/b$", Replacement: "/a", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + nextHandler := func(ctx *fasthttp.RequestCtx) { + t.Error("Next handler should not be called in a loop scenario") + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/a") + + handler(ctx) + + // 应该返回 500 内部服务器错误 + if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError { + t.Errorf("Expected status %d for infinite loop, got %d", + fasthttp.StatusInternalServerError, ctx.Response.StatusCode()) + } + + // 检查错误消息 + body := string(ctx.Response.Body()) + if body != "Internal Server Error" { + t.Errorf("Expected body 'Internal Server Error', got %q", body) + } +} + +// TestFlagLastRescan 测试 FlagLast 的重新扫描语义(nginx 兼容行为) +// FlagLast 应该重新从第一条规则开始匹配 +func TestFlagLastRescan(t *testing.T) { + // 规则1: /old/* → /new/* + // 规则2: /new/* → /final/* + // 当请求 /old/resource 时: + // - 规则1匹配,重写为 /new/resource,FlagLast 重新从规则1开始 + // - 规则2匹配,重写为 /final/resource + m, err := New([]config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"}, + {Pattern: "^/new/(.*)$", Replacement: "/final/$1", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + var finalPath string + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + finalPath = string(ctx.Path()) + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/old/resource") + + handler(ctx) + + if !handlerCalled { + t.Error("Handler was not called") + } + + // 最终路径应该是 /final/resource,因为规则1重写后 FlagLast 会重新扫描 + // 然后规则2匹配 + if finalPath != "/final/resource" { + t.Errorf("Expected final path /final/resource, got %s", finalPath) + } +} + +// TestFlagBreakNoLoop 测试 FlagBreak 不触发循环检测 +func TestFlagBreakNoLoop(t *testing.T) { + // 规则1: /a → /b,使用 break + // 规则2: /b → /a,使用 break + // break 应该停止匹配,不应该形成循环 + m, err := New([]config.RewriteRule{ + {Pattern: "^/a$", Replacement: "/b", Flag: "break"}, + {Pattern: "^/b$", Replacement: "/a", Flag: "break"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + var finalPath string + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + finalPath = string(ctx.Path()) + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/a") + + handler(ctx) + + if !handlerCalled { + t.Error("Handler was not called") + } + + // 规则1匹配后 break,所以路径应该是 /b + if finalPath != "/b" { + t.Errorf("Expected final path /b (stop at break), got %s", finalPath) + } +} + +// TestIterationLimitExact 测试精确的迭代限制 +func TestIterationLimitExact(t *testing.T) { + // 两条规则形成循环: + // 规则1: /a 重写为 /b + // 规则2: /b 重写为 /a + // 从 /a 开始,经过10次迭代应该触发 500 错误 + m, err := New([]config.RewriteRule{ + {Pattern: "^/a$", Replacement: "/b", Flag: "last"}, + {Pattern: "^/b$", Replacement: "/a", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + nextHandler := func(ctx *fasthttp.RequestCtx) { + t.Error("Next handler should not be called when iteration limit exceeded") + } + + handler := m.Process(nextHandler) + + // 从 /a 开始,应该触发循环检测 + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/a") + + handler(ctx) + + // 应该返回 500,因为每次匹配都会触发迭代计数 + // 迭代过程: /a -> /b -> /a -> /b -> ... 直到超过10次 + if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError { + t.Errorf("Expected status %d for exceeding iteration limit, got %d", + fasthttp.StatusInternalServerError, ctx.Response.StatusCode()) + } +} + +// TestNormalRewriteNotAffected 测试正常重写不受影响 +func TestNormalRewriteNotAffected(t *testing.T) { + m, err := New([]config.RewriteRule{ + {Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"}, + {Pattern: "^/static/(.*)$", Replacement: "/assets/$1", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + var finalPath string + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + finalPath = string(ctx.Path()) + } + + handler := m.Process(nextHandler) + + // 测试 /api/v1/users 重写为 /api/v2/users + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/api/v1/users") + + handler(ctx) + + if !handlerCalled { + t.Error("Handler was not called") + } + + if finalPath != "/api/v2/users" { + t.Errorf("Expected final path /api/v2/users, got %s", finalPath) + } +}