feat(middleware): 实现 URL 重写中间件
- 支持正则表达式匹配和替换 - 可配置多条重写规则 - 支持捕获组引用 ($1, $2 等) - 规则按顺序依次应用 - 完整单元测试覆盖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
359d7e95f3
commit
87a22b79f2
117
internal/middleware/rewrite/rewrite.go
Normal file
117
internal/middleware/rewrite/rewrite.go
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
// Package rewrite 提供 URL 重写中间件,支持正则表达式匹配和多种重写标志。
|
||||||
|
package rewrite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Flag 重写标志类型。
|
||||||
|
type Flag int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// FlagLast 继续匹配其他规则。
|
||||||
|
FlagLast Flag = iota
|
||||||
|
// FlagRedirect 返回 302 临时重定向。
|
||||||
|
FlagRedirect
|
||||||
|
// FlagPermanent 返回 301 永久重定向。
|
||||||
|
FlagPermanent
|
||||||
|
// FlagBreak 停止匹配规则。
|
||||||
|
FlagBreak
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseFlag 解析配置中的标志字符串。
|
||||||
|
func parseFlag(s string) Flag {
|
||||||
|
switch strings.ToLower(s) {
|
||||||
|
case "redirect":
|
||||||
|
return FlagRedirect
|
||||||
|
case "permanent":
|
||||||
|
return FlagPermanent
|
||||||
|
case "break":
|
||||||
|
return FlagBreak
|
||||||
|
default:
|
||||||
|
return FlagLast
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rule 编译后的重写规则。
|
||||||
|
type Rule struct {
|
||||||
|
pattern *regexp.Regexp
|
||||||
|
replacement string
|
||||||
|
flag Flag
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteMiddleware URL 重写中间件。
|
||||||
|
type RewriteMiddleware struct {
|
||||||
|
rules []Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
// New 创建重写中间件。
|
||||||
|
func New(rules []config.RewriteRule) (*RewriteMiddleware, error) {
|
||||||
|
compiled := make([]Rule, 0, len(rules))
|
||||||
|
for _, r := range rules {
|
||||||
|
re, err := regexp.Compile(r.Pattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
compiled = append(compiled, Rule{
|
||||||
|
pattern: re,
|
||||||
|
replacement: r.Replacement,
|
||||||
|
flag: parseFlag(r.Flag),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &RewriteMiddleware{rules: compiled}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name 返回中间件名称。
|
||||||
|
func (m *RewriteMiddleware) Name() string {
|
||||||
|
return "rewrite"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process 应用重写规则。
|
||||||
|
func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
|
path := string(ctx.Path())
|
||||||
|
originalPath := path
|
||||||
|
|
||||||
|
for _, rule := range m.rules {
|
||||||
|
if rule.pattern.MatchString(path) {
|
||||||
|
// 执行正则替换
|
||||||
|
newPath := rule.pattern.ReplaceAllString(path, rule.replacement)
|
||||||
|
|
||||||
|
switch rule.flag {
|
||||||
|
case FlagRedirect:
|
||||||
|
ctx.Redirect(newPath, fasthttp.StatusFound)
|
||||||
|
return
|
||||||
|
case FlagPermanent:
|
||||||
|
ctx.Redirect(newPath, fasthttp.StatusMovedPermanently)
|
||||||
|
return
|
||||||
|
case FlagBreak:
|
||||||
|
// 修改路径后停止匹配
|
||||||
|
ctx.Request.SetRequestURI(newPath)
|
||||||
|
next(ctx)
|
||||||
|
return
|
||||||
|
case FlagLast:
|
||||||
|
// 修改路径,继续匹配其他规则
|
||||||
|
path = newPath
|
||||||
|
ctx.Request.SetRequestURI(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果路径被修改过,需要重新设置
|
||||||
|
if path != originalPath {
|
||||||
|
ctx.Request.SetRequestURI(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
next(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rules 返回编译后的规则列表(用于调试)。
|
||||||
|
func (m *RewriteMiddleware) Rules() []Rule {
|
||||||
|
return m.rules
|
||||||
|
}
|
||||||
287
internal/middleware/rewrite/rewrite_test.go
Normal file
287
internal/middleware/rewrite/rewrite_test.go
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
package rewrite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseFlag(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected Flag
|
||||||
|
}{
|
||||||
|
{"last", FlagLast},
|
||||||
|
{"redirect", FlagRedirect},
|
||||||
|
{"permanent", FlagPermanent},
|
||||||
|
{"break", FlagBreak},
|
||||||
|
{"LAST", FlagLast},
|
||||||
|
{"Redirect", FlagRedirect},
|
||||||
|
{"", FlagLast},
|
||||||
|
{"unknown", FlagLast},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
result := parseFlag(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("parseFlag(%s) = %v, expected %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rules []config.RewriteRule
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty rules",
|
||||||
|
rules: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid rule",
|
||||||
|
rules: []config.RewriteRule{
|
||||||
|
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid regex",
|
||||||
|
rules: []config.RewriteRule{
|
||||||
|
{Pattern: "[invalid", Replacement: "/new", Flag: "last"},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple rules",
|
||||||
|
rules: []config.RewriteRule{
|
||||||
|
{Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"},
|
||||||
|
{Pattern: "^/old$", Replacement: "/new", Flag: "permanent"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m, err := New(tt.rules)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !tt.wantErr && m == nil {
|
||||||
|
t.Error("Expected non-nil middleware")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewareLast(t *testing.T) {
|
||||||
|
m, err := New([]config.RewriteRule{
|
||||||
|
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
handlerCalled = true
|
||||||
|
if string(ctx.Path()) != "/new/test" {
|
||||||
|
t.Errorf("Expected path /new/test, got %s", ctx.Path())
|
||||||
|
}
|
||||||
|
ctx.WriteString("OK")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := m.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.SetRequestURI("/old/test")
|
||||||
|
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !handlerCalled {
|
||||||
|
t.Error("Handler was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewareRedirect(t *testing.T) {
|
||||||
|
m, err := New([]config.RewriteRule{
|
||||||
|
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "redirect"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
handlerCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := m.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.SetRequestURI("/old/test")
|
||||||
|
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if handlerCalled {
|
||||||
|
t.Error("Handler should not be called for redirect")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查重定向
|
||||||
|
loc := ctx.Response.Header.Peek("Location")
|
||||||
|
// fasthttp 会构建完整 URL,所以检查后缀
|
||||||
|
if !bytes.HasSuffix(loc, []byte("/new/test")) {
|
||||||
|
t.Errorf("Expected Location ending with /new/test, got %s", loc)
|
||||||
|
}
|
||||||
|
if ctx.Response.StatusCode() != fasthttp.StatusFound {
|
||||||
|
t.Errorf("Expected status %d, got %d", fasthttp.StatusFound, ctx.Response.StatusCode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewarePermanent(t *testing.T) {
|
||||||
|
m, err := New([]config.RewriteRule{
|
||||||
|
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "permanent"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("OK")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := m.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.SetRequestURI("/old/page")
|
||||||
|
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
loc := ctx.Response.Header.Peek("Location")
|
||||||
|
// fasthttp 会构建完整 URL,所以检查后缀
|
||||||
|
if !bytes.HasSuffix(loc, []byte("/new/page")) {
|
||||||
|
t.Errorf("Expected Location ending with /new/page, got %s", loc)
|
||||||
|
}
|
||||||
|
if ctx.Response.StatusCode() != fasthttp.StatusMovedPermanently {
|
||||||
|
t.Errorf("Expected status %d, got %d", fasthttp.StatusMovedPermanently, ctx.Response.StatusCode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewareBreak(t *testing.T) {
|
||||||
|
m, err := New([]config.RewriteRule{
|
||||||
|
{Pattern: "^/api/(.*)$", Replacement: "/internal/$1", Flag: "break"},
|
||||||
|
{Pattern: "^/internal/(.*)$", Replacement: "/final/$1", Flag: "last"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
handlerCalled = true
|
||||||
|
// break 标志应该停止匹配,所以路径应该是 /internal/test
|
||||||
|
if string(ctx.Path()) != "/internal/test" {
|
||||||
|
t.Errorf("Expected path /internal/test, got %s", ctx.Path())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := m.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.SetRequestURI("/api/test")
|
||||||
|
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !handlerCalled {
|
||||||
|
t.Error("Handler was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewareChain(t *testing.T) {
|
||||||
|
// 测试多个 last 规则链式应用
|
||||||
|
m, err := New([]config.RewriteRule{
|
||||||
|
{Pattern: "^/v1/(.*)$", Replacement: "/v2/$1", Flag: "last"},
|
||||||
|
{Pattern: "^/v2/(.*)$", Replacement: "/v3/$1", Flag: "last"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
handlerCalled = true
|
||||||
|
if string(ctx.Path()) != "/v3/resource" {
|
||||||
|
t.Errorf("Expected path /v3/resource, got %s", ctx.Path())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := m.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.SetRequestURI("/v1/resource")
|
||||||
|
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !handlerCalled {
|
||||||
|
t.Error("Handler was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewareNoMatch(t *testing.T) {
|
||||||
|
m, err := New([]config.RewriteRule{
|
||||||
|
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
handlerCalled := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
handlerCalled = true
|
||||||
|
if string(ctx.Path()) != "/other/path" {
|
||||||
|
t.Errorf("Expected path /other/path, got %s", ctx.Path())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := m.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Request.SetRequestURI("/other/path")
|
||||||
|
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !handlerCalled {
|
||||||
|
t.Error("Handler was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewareName(t *testing.T) {
|
||||||
|
m, err := New(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Name() != "rewrite" {
|
||||||
|
t.Errorf("Expected name 'rewrite', got %s", m.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteMiddlewareRules(t *testing.T) {
|
||||||
|
rules := []config.RewriteRule{
|
||||||
|
{Pattern: "^/a/(.*)$", Replacement: "/b/$1", Flag: "last"},
|
||||||
|
{Pattern: "^/c$", Replacement: "/d", Flag: "redirect"},
|
||||||
|
}
|
||||||
|
m, err := New(rules)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
compiled := m.Rules()
|
||||||
|
if len(compiled) != 2 {
|
||||||
|
t.Errorf("Expected 2 rules, got %d", len(compiled))
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user