- FlagLast 重新从第一条规则开始匹配(nginx 兼容行为) - 新增全局迭代计数器检测循环,限制最多 10 次迭代 - FlagBreak 不触发循环检测,直接停止 - 新增测试覆盖跨规则循环、迭代限制、nginx 兼容语义 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
529 lines
13 KiB
Go
529 lines
13 KiB
Go
// Package rewrite provides tests for the URL rewrite feature.
//
// This file tests the URL rewrite module's functionality, including:
//   - rewrite rule parsing
//   - regular expression matching
//   - redirects and rewrites
//   - rule chain execution
//   - ReDoS protection
//
// Author: xfy
package rewrite
|
||
|
||
import (
|
||
"bytes"
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/valyala/fasthttp"
|
||
"rua.plus/lolly/internal/config"
|
||
)
|
||
|
||
func TestParseFlag(t *testing.T) {
|
||
tests := []struct {
|
||
input string
|
||
expected Flag
|
||
}{
|
||
{"last", FlagLast},
|
||
{"redirect", FlagRedirect},
|
||
{"permanent", FlagPermanent},
|
||
{"break", FlagBreak},
|
||
{"LAST", FlagLast},
|
||
{"Redirect", FlagRedirect},
|
||
{"", FlagLast},
|
||
{"unknown", FlagLast},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.input, func(t *testing.T) {
|
||
result := parseFlag(tt.input)
|
||
if result != tt.expected {
|
||
t.Errorf("parseFlag(%s) = %v, expected %v", tt.input, result, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestNew(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
rules []config.RewriteRule
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "empty rules",
|
||
rules: nil,
|
||
},
|
||
{
|
||
name: "valid rule",
|
||
rules: []config.RewriteRule{
|
||
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
|
||
},
|
||
},
|
||
{
|
||
name: "invalid regex",
|
||
rules: []config.RewriteRule{
|
||
{Pattern: "[invalid", Replacement: "/new", Flag: "last"},
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "multiple rules",
|
||
rules: []config.RewriteRule{
|
||
{Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"},
|
||
{Pattern: "^/old$", Replacement: "/new", Flag: "permanent"},
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
m, err := New(tt.rules)
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||
}
|
||
if !tt.wantErr && m == nil {
|
||
t.Error("Expected non-nil middleware")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewareLast(t *testing.T) {
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
if string(ctx.Path()) != "/new/test" {
|
||
t.Errorf("Expected path /new/test, got %s", ctx.Path())
|
||
}
|
||
_, _ = ctx.WriteString("OK")
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/old/test")
|
||
|
||
handler(ctx)
|
||
|
||
if !handlerCalled {
|
||
t.Error("Handler was not called")
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewareRedirect(t *testing.T) {
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "redirect"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/old/test")
|
||
|
||
handler(ctx)
|
||
|
||
if handlerCalled {
|
||
t.Error("Handler should not be called for redirect")
|
||
}
|
||
|
||
// 检查重定向
|
||
loc := ctx.Response.Header.Peek("Location")
|
||
// fasthttp 会构建完整 URL,所以检查后缀
|
||
if !bytes.HasSuffix(loc, []byte("/new/test")) {
|
||
t.Errorf("Expected Location ending with /new/test, got %s", loc)
|
||
}
|
||
if ctx.Response.StatusCode() != fasthttp.StatusFound {
|
||
t.Errorf("Expected status %d, got %d", fasthttp.StatusFound, ctx.Response.StatusCode())
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewarePermanent(t *testing.T) {
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "permanent"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
_, _ = ctx.WriteString("OK")
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/old/page")
|
||
|
||
handler(ctx)
|
||
|
||
loc := ctx.Response.Header.Peek("Location")
|
||
// fasthttp 会构建完整 URL,所以检查后缀
|
||
if !bytes.HasSuffix(loc, []byte("/new/page")) {
|
||
t.Errorf("Expected Location ending with /new/page, got %s", loc)
|
||
}
|
||
if ctx.Response.StatusCode() != fasthttp.StatusMovedPermanently {
|
||
t.Errorf("Expected status %d, got %d", fasthttp.StatusMovedPermanently, ctx.Response.StatusCode())
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewareBreak(t *testing.T) {
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/api/(.*)$", Replacement: "/internal/$1", Flag: "break"},
|
||
{Pattern: "^/internal/(.*)$", Replacement: "/final/$1", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
// break 标志应该停止匹配,所以路径应该是 /internal/test
|
||
if string(ctx.Path()) != "/internal/test" {
|
||
t.Errorf("Expected path /internal/test, got %s", ctx.Path())
|
||
}
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/api/test")
|
||
|
||
handler(ctx)
|
||
|
||
if !handlerCalled {
|
||
t.Error("Handler was not called")
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewareChain(t *testing.T) {
|
||
// 测试多个 last 规则链式应用
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/v1/(.*)$", Replacement: "/v2/$1", Flag: "last"},
|
||
{Pattern: "^/v2/(.*)$", Replacement: "/v3/$1", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
if string(ctx.Path()) != "/v3/resource" {
|
||
t.Errorf("Expected path /v3/resource, got %s", ctx.Path())
|
||
}
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/v1/resource")
|
||
|
||
handler(ctx)
|
||
|
||
if !handlerCalled {
|
||
t.Error("Handler was not called")
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewareNoMatch(t *testing.T) {
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
if string(ctx.Path()) != "/other/path" {
|
||
t.Errorf("Expected path /other/path, got %s", ctx.Path())
|
||
}
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/other/path")
|
||
|
||
handler(ctx)
|
||
|
||
if !handlerCalled {
|
||
t.Error("Handler was not called")
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewareName(t *testing.T) {
|
||
m, err := New(nil)
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
if m.Name() != "rewrite" {
|
||
t.Errorf("Expected name 'rewrite', got %s", m.Name())
|
||
}
|
||
}
|
||
|
||
func TestRewriteMiddlewareRules(t *testing.T) {
|
||
rules := []config.RewriteRule{
|
||
{Pattern: "^/a/(.*)$", Replacement: "/b/$1", Flag: "last"},
|
||
{Pattern: "^/c$", Replacement: "/d", Flag: "redirect"},
|
||
}
|
||
m, err := New(rules)
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
compiled := m.Rules()
|
||
if len(compiled) != 2 {
|
||
t.Errorf("Expected 2 rules, got %d", len(compiled))
|
||
}
|
||
}
|
||
|
||
func TestReDoSProtection(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
pattern string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "safe pattern",
|
||
pattern: "^/api/v1/(.*)$",
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "nested quantifier (\\w+)+",
|
||
pattern: `(\w+)+`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "nested quantifier (.+)+",
|
||
pattern: `(.+)+`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "nested quantifier (\\d+)+",
|
||
pattern: `(\d+)+`,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "pattern too long",
|
||
pattern: strings.Repeat("a", 1001),
|
||
wantErr: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
rules := []config.RewriteRule{
|
||
{Pattern: tt.pattern, Replacement: "/new", Flag: "last"},
|
||
}
|
||
_, err := New(rules)
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestCrossRuleCycle 测试跨规则循环检测
|
||
// 规则 A → B → A 应该被检测为循环
|
||
func TestCrossRuleCycle(t *testing.T) {
|
||
// 模拟规则 A: /a 重写为 /b
|
||
// 规则 B: /b 重写为 /a
|
||
// 这将形成 A → B → A 的循环
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/a$", Replacement: "/b", Flag: "last"},
|
||
{Pattern: "^/b$", Replacement: "/a", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
t.Error("Next handler should not be called in a loop scenario")
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/a")
|
||
|
||
handler(ctx)
|
||
|
||
// 应该返回 500 内部服务器错误
|
||
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
|
||
t.Errorf("Expected status %d for infinite loop, got %d",
|
||
fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
|
||
}
|
||
|
||
// 检查错误消息
|
||
body := string(ctx.Response.Body())
|
||
if body != "Internal Server Error" {
|
||
t.Errorf("Expected body 'Internal Server Error', got %q", body)
|
||
}
|
||
}
|
||
|
||
// TestFlagLastRescan 测试 FlagLast 的重新扫描语义(nginx 兼容行为)
|
||
// FlagLast 应该重新从第一条规则开始匹配
|
||
func TestFlagLastRescan(t *testing.T) {
|
||
// 规则1: /old/* → /new/*
|
||
// 规则2: /new/* → /final/*
|
||
// 当请求 /old/resource 时:
|
||
// - 规则1匹配,重写为 /new/resource,FlagLast 重新从规则1开始
|
||
// - 规则2匹配,重写为 /final/resource
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
|
||
{Pattern: "^/new/(.*)$", Replacement: "/final/$1", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
var finalPath string
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
finalPath = string(ctx.Path())
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/old/resource")
|
||
|
||
handler(ctx)
|
||
|
||
if !handlerCalled {
|
||
t.Error("Handler was not called")
|
||
}
|
||
|
||
// 最终路径应该是 /final/resource,因为规则1重写后 FlagLast 会重新扫描
|
||
// 然后规则2匹配
|
||
if finalPath != "/final/resource" {
|
||
t.Errorf("Expected final path /final/resource, got %s", finalPath)
|
||
}
|
||
}
|
||
|
||
// TestFlagBreakNoLoop 测试 FlagBreak 不触发循环检测
|
||
func TestFlagBreakNoLoop(t *testing.T) {
|
||
// 规则1: /a → /b,使用 break
|
||
// 规则2: /b → /a,使用 break
|
||
// break 应该停止匹配,不应该形成循环
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/a$", Replacement: "/b", Flag: "break"},
|
||
{Pattern: "^/b$", Replacement: "/a", Flag: "break"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
var finalPath string
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
finalPath = string(ctx.Path())
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/a")
|
||
|
||
handler(ctx)
|
||
|
||
if !handlerCalled {
|
||
t.Error("Handler was not called")
|
||
}
|
||
|
||
// 规则1匹配后 break,所以路径应该是 /b
|
||
if finalPath != "/b" {
|
||
t.Errorf("Expected final path /b (stop at break), got %s", finalPath)
|
||
}
|
||
}
|
||
|
||
// TestIterationLimitExact 测试精确的迭代限制
|
||
func TestIterationLimitExact(t *testing.T) {
|
||
// 两条规则形成循环:
|
||
// 规则1: /a 重写为 /b
|
||
// 规则2: /b 重写为 /a
|
||
// 从 /a 开始,经过10次迭代应该触发 500 错误
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/a$", Replacement: "/b", Flag: "last"},
|
||
{Pattern: "^/b$", Replacement: "/a", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
t.Error("Next handler should not be called when iteration limit exceeded")
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
// 从 /a 开始,应该触发循环检测
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/a")
|
||
|
||
handler(ctx)
|
||
|
||
// 应该返回 500,因为每次匹配都会触发迭代计数
|
||
// 迭代过程: /a -> /b -> /a -> /b -> ... 直到超过10次
|
||
if ctx.Response.StatusCode() != fasthttp.StatusInternalServerError {
|
||
t.Errorf("Expected status %d for exceeding iteration limit, got %d",
|
||
fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
|
||
}
|
||
}
|
||
|
||
// TestNormalRewriteNotAffected 测试正常重写不受影响
|
||
func TestNormalRewriteNotAffected(t *testing.T) {
|
||
m, err := New([]config.RewriteRule{
|
||
{Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"},
|
||
{Pattern: "^/static/(.*)$", Replacement: "/assets/$1", Flag: "last"},
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("New() error: %v", err)
|
||
}
|
||
|
||
handlerCalled := false
|
||
var finalPath string
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
handlerCalled = true
|
||
finalPath = string(ctx.Path())
|
||
}
|
||
|
||
handler := m.Process(nextHandler)
|
||
|
||
// 测试 /api/v1/users 重写为 /api/v2/users
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/api/v1/users")
|
||
|
||
handler(ctx)
|
||
|
||
if !handlerCalled {
|
||
t.Error("Handler was not called")
|
||
}
|
||
|
||
if finalPath != "/api/v2/users" {
|
||
t.Errorf("Expected final path /api/v2/users, got %s", finalPath)
|
||
}
|
||
}
|