feat(middleware,rewrite): 增强 FlagLast 语义与循环检测

- FlagLast 重新从第一条规则开始匹配(nginx 兼容行为)
- 新增全局迭代计数器检测循环,限制最多 10 次迭代
- FlagBreak 不触发循环检测,直接停止
- 新增测试覆盖跨规则循环、迭代限制、nginx 兼容语义

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-07 17:50:50 +08:00
parent 23f5c08614
commit cb45e824d4
2 changed files with 209 additions and 4 deletions

View File

@ -10,11 +10,14 @@ import (
"rua.plus/lolly/internal/config" "rua.plus/lolly/internal/config"
) )
// MaxRewriteIterations URL重写最大迭代次数防止无限循环
const MaxRewriteIterations = 10
// Flag 重写标志类型。 // Flag 重写标志类型。
type Flag int type Flag int
const ( const (
// FlagLast 继续匹配其他规则 // FlagLast 继续匹配其他规则nginx行为重新从第一条规则开始匹配
FlagLast Flag = iota FlagLast Flag = iota
// FlagRedirect 返回 302 临时重定向。 // FlagRedirect 返回 302 临时重定向。
FlagRedirect FlagRedirect
@ -113,7 +116,20 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque
path := string(ctx.Path()) path := string(ctx.Path())
originalPath := path originalPath := path
for _, rule := range m.rules { // 全局迭代计数器,用于检测循环(每次重写都计入迭代)
iterationCount := 0
// 规则索引支持FlagLast后重新开始匹配
ruleIndex := 0
for ruleIndex < len(m.rules) {
// 检查迭代次数是否超过限制(在任何重写操作之前检查)
if iterationCount >= MaxRewriteIterations {
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
rule := m.rules[ruleIndex]
if rule.pattern.MatchString(path) { if rule.pattern.MatchString(path) {
// 执行正则替换 // 执行正则替换
newPath := rule.pattern.ReplaceAllString(path, rule.replacement) newPath := rule.pattern.ReplaceAllString(path, rule.replacement)
@ -126,16 +142,21 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque
ctx.Redirect(newPath, fasthttp.StatusMovedPermanently) ctx.Redirect(newPath, fasthttp.StatusMovedPermanently)
return return
case FlagBreak: case FlagBreak:
// 修改路径后停止匹配 // 修改路径后停止匹配,不增加迭代计数(不触发循环检测)
ctx.Request.SetRequestURI(newPath) ctx.Request.SetRequestURI(newPath)
next(ctx) next(ctx)
return return
case FlagLast: case FlagLast:
// 修改路径,继续匹配其他规则 // 修改路径,并重新从第一条规则开始匹配nginx兼容行为
path = newPath path = newPath
ctx.Request.SetRequestURI(path) ctx.Request.SetRequestURI(path)
iterationCount++ // 每次FlagLast重写都增加计数
ruleIndex = 0 // 重新从第一条规则开始
continue
} }
} }
ruleIndex++
} }
// 如果路径被修改过,需要重新设置 // 如果路径被修改过,需要重新设置

View File

@ -342,3 +342,187 @@ func TestReDoSProtection(t *testing.T) {
}) })
} }
} }
// TestCrossRuleCycle 测试跨规则循环检测
// 规则 A → B → A 应该被检测为循环
func TestCrossRuleCycle(t *testing.T) {
// 模拟规则 A: /a 重写为 /b
// 规则 B: /b 重写为 /a
// 这将形成 A → B → A 的循环
m, err := New([]config.RewriteRule{
{Pattern: "^/a$", Replacement: "/b", Flag: "last"},
{Pattern: "^/b$", Replacement: "/a", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
nextHandler := func(ctx *fasthttp.RequestCtx) {
t.Error("Next handler should not be called in a loop scenario")
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/a")
handler(ctx)
// 应该返回 500 内部服务器错误
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
t.Errorf("Expected status %d for infinite loop, got %d",
fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
}
// 检查错误消息
body := string(ctx.Response.Body())
if body != "Internal Server Error" {
t.Errorf("Expected body 'Internal Server Error', got %q", body)
}
}
// TestFlagLastRescan 测试 FlagLast 的重新扫描语义nginx 兼容行为)
// FlagLast 应该重新从第一条规则开始匹配
func TestFlagLastRescan(t *testing.T) {
// 规则1: /old/* → /new/*
// 规则2: /new/* → /final/*
// 当请求 /old/resource 时:
// - 规则1匹配重写为 /new/resourceFlagLast 重新从规则1开始
// - 规则2匹配重写为 /final/resource
m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
{Pattern: "^/new/(.*)$", Replacement: "/final/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
var finalPath string
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
finalPath = string(ctx.Path())
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/old/resource")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
// 最终路径应该是 /final/resource因为规则1重写后 FlagLast 会重新扫描
// 然后规则2匹配
if finalPath != "/final/resource" {
t.Errorf("Expected final path /final/resource, got %s", finalPath)
}
}
// TestFlagBreakNoLoop 测试 FlagBreak 不触发循环检测
func TestFlagBreakNoLoop(t *testing.T) {
// 规则1: /a → /b使用 break
// 规则2: /b → /a使用 break
// break 应该停止匹配,不应该形成循环
m, err := New([]config.RewriteRule{
{Pattern: "^/a$", Replacement: "/b", Flag: "break"},
{Pattern: "^/b$", Replacement: "/a", Flag: "break"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
var finalPath string
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
finalPath = string(ctx.Path())
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/a")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
// 规则1匹配后 break所以路径应该是 /b
if finalPath != "/b" {
t.Errorf("Expected final path /b (stop at break), got %s", finalPath)
}
}
// TestIterationLimitExact 测试精确的迭代限制
func TestIterationLimitExact(t *testing.T) {
// 两条规则形成循环:
// 规则1: /a 重写为 /b
// 规则2: /b 重写为 /a
// 从 /a 开始经过10次迭代应该触发 500 错误
m, err := New([]config.RewriteRule{
{Pattern: "^/a$", Replacement: "/b", Flag: "last"},
{Pattern: "^/b$", Replacement: "/a", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
nextHandler := func(ctx *fasthttp.RequestCtx) {
t.Error("Next handler should not be called when iteration limit exceeded")
}
handler := m.Process(nextHandler)
// 从 /a 开始,应该触发循环检测
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/a")
handler(ctx)
// 应该返回 500因为每次匹配都会触发迭代计数
// 迭代过程: /a -> /b -> /a -> /b -> ... 直到超过10次
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
t.Errorf("Expected status %d for exceeding iteration limit, got %d",
fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
}
}
// TestNormalRewriteNotAffected 测试正常重写不受影响
func TestNormalRewriteNotAffected(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"},
{Pattern: "^/static/(.*)$", Replacement: "/assets/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
var finalPath string
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
finalPath = string(ctx.Path())
}
handler := m.Process(nextHandler)
// 测试 /api/v1/users 重写为 /api/v2/users
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/api/v1/users")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
if finalPath != "/api/v2/users" {
t.Errorf("Expected final path /api/v2/users, got %s", finalPath)
}
}