lolly/internal/middleware/rewrite/rewrite_test.go
xfy 351f477822 feat(middleware): 增强 rewrite 安全性,添加 ReDoS 保护
- 新增 validateRegexSafety 检测危险正则模式
- 防止嵌套量词导致的灾难性回溯攻击
- 限制正则模式长度 (max 1000 chars)
- 补充 compression 和 accesslog 文档注释

Co-Authored-By: Claude <noreply@anthropic.com>
2026-04-03 16:57:40 +08:00

335 lines
7.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package rewrite
import (
"bytes"
"strings"
"testing"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
// TestParseFlag verifies that parseFlag maps flag strings to Flag values
// case-insensitively and falls back to FlagLast for empty or unrecognized
// input.
func TestParseFlag(t *testing.T) {
	tests := []struct {
		input    string
		expected Flag
	}{
		{"last", FlagLast},
		{"redirect", FlagRedirect},
		{"permanent", FlagPermanent},
		{"break", FlagBreak},
		{"LAST", FlagLast},
		{"Redirect", FlagRedirect},
		{"", FlagLast},
		{"unknown", FlagLast},
	}
	for _, tt := range tests {
		// An empty subtest name is reported by `go test` as "#00", which is
		// hard to map back to the table row; use a readable name instead.
		name := tt.input
		if name == "" {
			name = "empty"
		}
		t.Run(name, func(t *testing.T) {
			result := parseFlag(tt.input)
			if result != tt.expected {
				// %q keeps an empty input visible in the failure output.
				t.Errorf("parseFlag(%q) = %v, expected %v", tt.input, result, tt.expected)
			}
		})
	}
}
// TestNew checks that New accepts nil and valid rule sets, and returns an
// error when a rule's pattern is not a valid regular expression.
func TestNew(t *testing.T) {
	tests := []struct {
		name    string
		rules   []config.RewriteRule
		wantErr bool
	}{
		{
			name:  "empty rules",
			rules: nil,
		},
		{
			name: "valid rule",
			rules: []config.RewriteRule{
				{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
			},
		},
		{
			name: "invalid regex",
			rules: []config.RewriteRule{
				{Pattern: "[invalid", Replacement: "/new", Flag: "last"},
			},
			wantErr: true,
		},
		{
			name: "multiple rules",
			rules: []config.RewriteRule{
				{Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"},
				{Pattern: "^/old$", Replacement: "/new", Flag: "permanent"},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m, err := New(tt.rules)
			if (err != nil) != tt.wantErr {
				// Stop this subtest: inspecting m after an unexpected error
				// would only add a second, misleading failure.
				t.Fatalf("New() error = %v, wantErr %v", err, tt.wantErr)
			}
			if !tt.wantErr && m == nil {
				t.Error("Expected non-nil middleware")
			}
		})
	}
}
func TestRewriteMiddlewareLast(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
if string(ctx.Path()) != "/new/test" {
t.Errorf("Expected path /new/test, got %s", ctx.Path())
}
ctx.WriteString("OK")
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/old/test")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
}
// TestRewriteMiddlewareRedirect verifies that a "redirect" rule answers with
// 302 Found pointing at the rewritten path and never reaches the next handler.
func TestRewriteMiddlewareRedirect(t *testing.T) {
	mw, err := New([]config.RewriteRule{
		{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "redirect"},
	})
	if err != nil {
		t.Fatalf("New() error: %v", err)
	}

	reached := false
	req := &fasthttp.RequestCtx{}
	req.Request.SetRequestURI("/old/test")
	mw.Process(func(*fasthttp.RequestCtx) { reached = true })(req)

	if reached {
		t.Error("Handler should not be called for redirect")
	}

	// Inspect the redirect. fasthttp builds a full URL for Location, so only
	// the path suffix is compared.
	if location := req.Response.Header.Peek("Location"); !bytes.HasSuffix(location, []byte("/new/test")) {
		t.Errorf("Expected Location ending with /new/test, got %s", location)
	}
	if status := req.Response.StatusCode(); status != fasthttp.StatusFound {
		t.Errorf("Expected status %d, got %d", fasthttp.StatusFound, status)
	}
}
// TestRewriteMiddlewarePermanent verifies that a "permanent" rule answers with
// 301 Moved Permanently pointing at the rewritten path and does not invoke the
// next handler.
func TestRewriteMiddlewarePermanent(t *testing.T) {
	m, err := New([]config.RewriteRule{
		{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "permanent"},
	})
	if err != nil {
		t.Fatalf("New() error: %v", err)
	}
	handlerCalled := false
	nextHandler := func(ctx *fasthttp.RequestCtx) {
		handlerCalled = true
		ctx.WriteString("OK")
	}
	handler := m.Process(nextHandler)
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/old/page")
	handler(ctx)
	// Like the plain "redirect" case, a permanent redirect must short-circuit
	// the chain; previously this was not asserted at all.
	if handlerCalled {
		t.Error("Handler should not be called for permanent redirect")
	}
	loc := ctx.Response.Header.Peek("Location")
	// fasthttp builds a full URL for Location, so only the path suffix is checked.
	if !bytes.HasSuffix(loc, []byte("/new/page")) {
		t.Errorf("Expected Location ending with /new/page, got %s", loc)
	}
	if ctx.Response.StatusCode() != fasthttp.StatusMovedPermanently {
		t.Errorf("Expected status %d, got %d", fasthttp.StatusMovedPermanently, ctx.Response.StatusCode())
	}
}
func TestRewriteMiddlewareBreak(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/api/(.*)$", Replacement: "/internal/$1", Flag: "break"},
{Pattern: "^/internal/(.*)$", Replacement: "/final/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
// break 标志应该停止匹配,所以路径应该是 /internal/test
if string(ctx.Path()) != "/internal/test" {
t.Errorf("Expected path /internal/test, got %s", ctx.Path())
}
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/api/test")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
}
func TestRewriteMiddlewareChain(t *testing.T) {
// 测试多个 last 规则链式应用
m, err := New([]config.RewriteRule{
{Pattern: "^/v1/(.*)$", Replacement: "/v2/$1", Flag: "last"},
{Pattern: "^/v2/(.*)$", Replacement: "/v3/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
if string(ctx.Path()) != "/v3/resource" {
t.Errorf("Expected path /v3/resource, got %s", ctx.Path())
}
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/v1/resource")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
}
// TestRewriteMiddlewareNoMatch verifies that a request whose path matches no
// rule reaches the next handler with its path unchanged.
func TestRewriteMiddlewareNoMatch(t *testing.T) {
	const path = "/other/path"

	mw, err := New([]config.RewriteRule{
		{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
	})
	if err != nil {
		t.Fatalf("New() error: %v", err)
	}

	var reached bool
	next := func(c *fasthttp.RequestCtx) {
		reached = true
		if got := string(c.Path()); got != path {
			t.Errorf("Expected path /other/path, got %s", got)
		}
	}

	req := &fasthttp.RequestCtx{}
	req.Request.SetRequestURI(path)
	mw.Process(next)(req)

	if !reached {
		t.Error("Handler was not called")
	}
}
// TestRewriteMiddlewareName verifies that the middleware built from an empty
// rule set identifies itself as "rewrite".
func TestRewriteMiddlewareName(t *testing.T) {
	const want = "rewrite"

	mw, err := New(nil)
	if err != nil {
		t.Fatalf("New() error: %v", err)
	}
	if got := mw.Name(); got != want {
		t.Errorf("Expected name 'rewrite', got %s", got)
	}
}
func TestRewriteMiddlewareRules(t *testing.T) {
rules := []config.RewriteRule{
{Pattern: "^/a/(.*)$", Replacement: "/b/$1", Flag: "last"},
{Pattern: "^/c$", Replacement: "/d", Flag: "redirect"},
}
m, err := New(rules)
if err != nil {
t.Fatalf("New() error: %v", err)
}
compiled := m.Rules()
if len(compiled) != 2 {
t.Errorf("Expected 2 rules, got %d", len(compiled))
}
}
// TestReDoSProtection verifies that New rejects patterns with nested
// quantifiers such as (\w+)+, and rejects patterns longer than the 1000-character
// limit, while accepting ordinary patterns and patterns of exactly the limit.
func TestReDoSProtection(t *testing.T) {
	tests := []struct {
		name    string
		pattern string
		wantErr bool
	}{
		{
			name:    "safe pattern",
			pattern: "^/api/v1/(.*)$",
			wantErr: false,
		},
		{
			name:    "nested quantifier (\\w+)+",
			pattern: `(\w+)+`,
			wantErr: true,
		},
		{
			name:    "nested quantifier (.+)+",
			pattern: `(.+)+`,
			wantErr: true,
		},
		{
			name:    "nested quantifier (\\d+)+",
			pattern: `(\d+)+`,
			wantErr: true,
		},
		{
			// Boundary: the limit is "max 1000 chars", so exactly 1000 must
			// still be accepted. Guards against an off-by-one (>= vs >).
			name:    "pattern at max length",
			pattern: strings.Repeat("a", 1000),
			wantErr: false,
		},
		{
			name:    "pattern too long",
			pattern: strings.Repeat("a", 1001),
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rules := []config.RewriteRule{
				{Pattern: tt.pattern, Replacement: "/new", Flag: "last"},
			}
			_, err := New(rules)
			if (err != nil) != tt.wantErr {
				t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}