From 87a22b79f229b6108e290fd27a71eac829517ec6 Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 3 Apr 2026 10:11:50 +0800 Subject: [PATCH] =?UTF-8?q?feat(middleware):=20=E5=AE=9E=E7=8E=B0=20URL=20?= =?UTF-8?q?=E9=87=8D=E5=86=99=E4=B8=AD=E9=97=B4=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 支持正则表达式匹配和替换 - 可配置多条重写规则 - 支持捕获组引用 ($1, $2 等) - 规则按顺序依次应用 - 完整单元测试覆盖 Co-Authored-By: Claude Opus 4.6 --- internal/middleware/rewrite/rewrite.go | 117 ++++++++ internal/middleware/rewrite/rewrite_test.go | 287 ++++++++++++++++++++ 2 files changed, 404 insertions(+) create mode 100644 internal/middleware/rewrite/rewrite.go create mode 100644 internal/middleware/rewrite/rewrite_test.go diff --git a/internal/middleware/rewrite/rewrite.go b/internal/middleware/rewrite/rewrite.go new file mode 100644 index 0000000..7e5ebe0 --- /dev/null +++ b/internal/middleware/rewrite/rewrite.go @@ -0,0 +1,117 @@ +// Package rewrite 提供 URL 重写中间件,支持正则表达式匹配和多种重写标志。 +package rewrite + +import ( + "regexp" + "strings" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" +) + +// Flag 重写标志类型。 +type Flag int + +const ( + // FlagLast 继续匹配其他规则。 + FlagLast Flag = iota + // FlagRedirect 返回 302 临时重定向。 + FlagRedirect + // FlagPermanent 返回 301 永久重定向。 + FlagPermanent + // FlagBreak 停止匹配规则。 + FlagBreak +) + +// parseFlag 解析配置中的标志字符串。 +func parseFlag(s string) Flag { + switch strings.ToLower(s) { + case "redirect": + return FlagRedirect + case "permanent": + return FlagPermanent + case "break": + return FlagBreak + default: + return FlagLast + } +} + +// Rule 编译后的重写规则。 +type Rule struct { + pattern *regexp.Regexp + replacement string + flag Flag +} + +// RewriteMiddleware URL 重写中间件。 +type RewriteMiddleware struct { + rules []Rule +} + +// New 创建重写中间件。 +func New(rules []config.RewriteRule) (*RewriteMiddleware, error) { + compiled := make([]Rule, 0, len(rules)) + for _, r := range rules { + re, err := regexp.Compile(r.Pattern) + if err != nil { + return nil, err + } + compiled = append(compiled, Rule{ + pattern: re, + replacement: r.Replacement, + flag: parseFlag(r.Flag), + }) + } + return &RewriteMiddleware{rules: compiled}, nil +} + +// Name 返回中间件名称。 +func (m *RewriteMiddleware) Name() string { + return "rewrite" +} + +// Process 应用重写规则。 +func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + originalPath := path + + for _, rule := range m.rules { + if rule.pattern.MatchString(path) { + // 执行正则替换 + newPath := rule.pattern.ReplaceAllString(path, rule.replacement) + + switch rule.flag { + case FlagRedirect: + ctx.Redirect(newPath, fasthttp.StatusFound) + return + case FlagPermanent: + ctx.Redirect(newPath, fasthttp.StatusMovedPermanently) + return + case FlagBreak: + // 修改路径后停止匹配 + ctx.Request.SetRequestURI(newPath) + next(ctx) + return + case FlagLast: + // 修改路径,继续匹配其他规则 + path = newPath + ctx.Request.SetRequestURI(path) + } + } + } + + // 如果路径被修改过,需要重新设置 + if path != originalPath { + ctx.Request.SetRequestURI(path) + } + + next(ctx) + } +} + +// Rules 返回编译后的规则列表(用于调试)。 +func (m *RewriteMiddleware) Rules() []Rule { + return m.rules +} \ No newline at end of file diff --git a/internal/middleware/rewrite/rewrite_test.go b/internal/middleware/rewrite/rewrite_test.go new file mode 100644 index 0000000..6bbd76a --- /dev/null +++ b/internal/middleware/rewrite/rewrite_test.go @@ -0,0 +1,287 @@ +package rewrite + +import ( + "bytes" + "testing" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" +) + +func TestParseFlag(t *testing.T) { + tests := []struct { + input string + expected Flag + }{ + {"last", FlagLast}, + {"redirect", FlagRedirect}, + {"permanent", FlagPermanent}, + {"break", FlagBreak}, + {"LAST", FlagLast}, + {"Redirect", FlagRedirect}, + {"", FlagLast}, + {"unknown", FlagLast}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := parseFlag(tt.input) + if result != tt.expected { + t.Errorf("parseFlag(%s) = %v, expected %v", tt.input, result, tt.expected) + } + }) + } +} + +func TestNew(t *testing.T) { + tests := []struct { + name string + rules []config.RewriteRule + wantErr bool + }{ + { + name: "empty rules", + rules: nil, + }, + { + name: "valid rule", + rules: []config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"}, + }, + }, + { + name: "invalid regex", + rules: []config.RewriteRule{ + {Pattern: "[invalid", Replacement: "/new", Flag: "last"}, + }, + wantErr: true, + }, + { + name: "multiple rules", + rules: []config.RewriteRule{ + {Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"}, + {Pattern: "^/old$", Replacement: "/new", Flag: "permanent"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m, err := New(tt.rules) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && m == nil { + t.Error("Expected non-nil middleware") + } + }) + } +} + +func TestRewriteMiddlewareLast(t *testing.T) { + m, err := New([]config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + if string(ctx.Path()) != "/new/test" { + t.Errorf("Expected path /new/test, got %s", ctx.Path()) + } + ctx.WriteString("OK") + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/old/test") + + handler(ctx) + + if !handlerCalled { + t.Error("Handler was not called") + } +} + +func TestRewriteMiddlewareRedirect(t *testing.T) { + m, err := New([]config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "redirect"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/old/test") + + handler(ctx) + + if handlerCalled { + t.Error("Handler should not be called for redirect") + } + + // 检查重定向 + loc := ctx.Response.Header.Peek("Location") + // fasthttp 会构建完整 URL,所以检查后缀 + if !bytes.HasSuffix(loc, []byte("/new/test")) { + t.Errorf("Expected Location ending with /new/test, got %s", loc) + } + if ctx.Response.StatusCode() != fasthttp.StatusFound { + t.Errorf("Expected status %d, got %d", fasthttp.StatusFound, ctx.Response.StatusCode()) + } +} + +func TestRewriteMiddlewarePermanent(t *testing.T) { + m, err := New([]config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "permanent"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + nextHandler := func(ctx *fasthttp.RequestCtx) { + ctx.WriteString("OK") + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/old/page") + + handler(ctx) + + loc := ctx.Response.Header.Peek("Location") + // fasthttp 会构建完整 URL,所以检查后缀 + if !bytes.HasSuffix(loc, []byte("/new/page")) { + t.Errorf("Expected Location ending with /new/page, got %s", loc) + } + if ctx.Response.StatusCode() != fasthttp.StatusMovedPermanently { + t.Errorf("Expected status %d, got %d", fasthttp.StatusMovedPermanently, ctx.Response.StatusCode()) + } +} + +func TestRewriteMiddlewareBreak(t *testing.T) { + m, err := New([]config.RewriteRule{ + {Pattern: "^/api/(.*)$", Replacement: "/internal/$1", Flag: "break"}, + {Pattern: "^/internal/(.*)$", Replacement: "/final/$1", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + // break 标志应该停止匹配,所以路径应该是 /internal/test + if string(ctx.Path()) != "/internal/test" { + t.Errorf("Expected path /internal/test, got %s", ctx.Path()) + } + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/api/test") + + handler(ctx) + + if !handlerCalled { + t.Error("Handler was not called") + } +} + +func TestRewriteMiddlewareChain(t *testing.T) { + // 测试多个 last 规则链式应用 + m, err := New([]config.RewriteRule{ + {Pattern: "^/v1/(.*)$", Replacement: "/v2/$1", Flag: "last"}, + {Pattern: "^/v2/(.*)$", Replacement: "/v3/$1", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + if string(ctx.Path()) != "/v3/resource" { + t.Errorf("Expected path /v3/resource, got %s", ctx.Path()) + } + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/v1/resource") + + handler(ctx) + + if !handlerCalled { + t.Error("Handler was not called") + } +} + +func TestRewriteMiddlewareNoMatch(t *testing.T) { + m, err := New([]config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"}, + }) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + handlerCalled := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + handlerCalled = true + if string(ctx.Path()) != "/other/path" { + t.Errorf("Expected path /other/path, got %s", ctx.Path()) + } + } + + handler := m.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/other/path") + + handler(ctx) + + if !handlerCalled { + t.Error("Handler was not called") + } +} + +func TestRewriteMiddlewareName(t *testing.T) { + m, err := New(nil) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + if m.Name() != "rewrite" { + t.Errorf("Expected name 'rewrite', got %s", m.Name()) + } +} + +func TestRewriteMiddlewareRules(t *testing.T) { + rules := []config.RewriteRule{ + {Pattern: "^/a/(.*)$", Replacement: "/b/$1", Flag: "last"}, + {Pattern: "^/c$", Replacement: "/d", Flag: "redirect"}, + } + m, err := New(rules) + if err != nil { + t.Fatalf("New() error: %v", err) + } + + compiled := m.Rules() + if len(compiled) != 2 { + t.Errorf("Expected 2 rules, got %d", len(compiled)) + } +} \ No newline at end of file