lolly/internal/middleware/rewrite/rewrite_test.go
xfy 9d24263918 feat(stream,server,handler): 实现 Phase 6 性能优化和热升级
新增功能:
- stream 模块: 流式传输支持,优化大文件和实时数据传输
- Goroutine 池: 限制并发数量,减少调度开销
- 优雅升级: 零停机热升级,继承父进程监听器
- sendfile: 零拷贝文件传输,大文件直接从内核传输

重构改进:
- App 结构体封装,支持热升级和信号处理
- 配置结构字段对齐和代码清理
- 完善错误处理和日志记录

Co-Authored-By: Claude <noreply@anthropic.com>
2026-04-03 10:39:22 +08:00

288 lines
6.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package rewrite
import (
"bytes"
"testing"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
func TestParseFlag(t *testing.T) {
tests := []struct {
input string
expected Flag
}{
{"last", FlagLast},
{"redirect", FlagRedirect},
{"permanent", FlagPermanent},
{"break", FlagBreak},
{"LAST", FlagLast},
{"Redirect", FlagRedirect},
{"", FlagLast},
{"unknown", FlagLast},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := parseFlag(tt.input)
if result != tt.expected {
t.Errorf("parseFlag(%s) = %v, expected %v", tt.input, result, tt.expected)
}
})
}
}
func TestNew(t *testing.T) {
tests := []struct {
name string
rules []config.RewriteRule
wantErr bool
}{
{
name: "empty rules",
rules: nil,
},
{
name: "valid rule",
rules: []config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
},
},
{
name: "invalid regex",
rules: []config.RewriteRule{
{Pattern: "[invalid", Replacement: "/new", Flag: "last"},
},
wantErr: true,
},
{
name: "multiple rules",
rules: []config.RewriteRule{
{Pattern: "^/api/v1/(.*)$", Replacement: "/api/v2/$1", Flag: "last"},
{Pattern: "^/old$", Replacement: "/new", Flag: "permanent"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m, err := New(tt.rules)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && m == nil {
t.Error("Expected non-nil middleware")
}
})
}
}
func TestRewriteMiddlewareLast(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
if string(ctx.Path()) != "/new/test" {
t.Errorf("Expected path /new/test, got %s", ctx.Path())
}
ctx.WriteString("OK")
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/old/test")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
}
func TestRewriteMiddlewareRedirect(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "redirect"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/old/test")
handler(ctx)
if handlerCalled {
t.Error("Handler should not be called for redirect")
}
// 检查重定向
loc := ctx.Response.Header.Peek("Location")
// fasthttp 会构建完整 URL所以检查后缀
if !bytes.HasSuffix(loc, []byte("/new/test")) {
t.Errorf("Expected Location ending with /new/test, got %s", loc)
}
if ctx.Response.StatusCode() != fasthttp.StatusFound {
t.Errorf("Expected status %d, got %d", fasthttp.StatusFound, ctx.Response.StatusCode())
}
}
func TestRewriteMiddlewarePermanent(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "permanent"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
nextHandler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("OK")
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/old/page")
handler(ctx)
loc := ctx.Response.Header.Peek("Location")
// fasthttp 会构建完整 URL所以检查后缀
if !bytes.HasSuffix(loc, []byte("/new/page")) {
t.Errorf("Expected Location ending with /new/page, got %s", loc)
}
if ctx.Response.StatusCode() != fasthttp.StatusMovedPermanently {
t.Errorf("Expected status %d, got %d", fasthttp.StatusMovedPermanently, ctx.Response.StatusCode())
}
}
func TestRewriteMiddlewareBreak(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/api/(.*)$", Replacement: "/internal/$1", Flag: "break"},
{Pattern: "^/internal/(.*)$", Replacement: "/final/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
// break 标志应该停止匹配,所以路径应该是 /internal/test
if string(ctx.Path()) != "/internal/test" {
t.Errorf("Expected path /internal/test, got %s", ctx.Path())
}
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/api/test")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
}
func TestRewriteMiddlewareChain(t *testing.T) {
// 测试多个 last 规则链式应用
m, err := New([]config.RewriteRule{
{Pattern: "^/v1/(.*)$", Replacement: "/v2/$1", Flag: "last"},
{Pattern: "^/v2/(.*)$", Replacement: "/v3/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
if string(ctx.Path()) != "/v3/resource" {
t.Errorf("Expected path /v3/resource, got %s", ctx.Path())
}
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/v1/resource")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
}
func TestRewriteMiddlewareNoMatch(t *testing.T) {
m, err := New([]config.RewriteRule{
{Pattern: "^/old/(.*)$", Replacement: "/new/$1", Flag: "last"},
})
if err != nil {
t.Fatalf("New() error: %v", err)
}
handlerCalled := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
handlerCalled = true
if string(ctx.Path()) != "/other/path" {
t.Errorf("Expected path /other/path, got %s", ctx.Path())
}
}
handler := m.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/other/path")
handler(ctx)
if !handlerCalled {
t.Error("Handler was not called")
}
}
func TestRewriteMiddlewareName(t *testing.T) {
m, err := New(nil)
if err != nil {
t.Fatalf("New() error: %v", err)
}
if m.Name() != "rewrite" {
t.Errorf("Expected name 'rewrite', got %s", m.Name())
}
}
func TestRewriteMiddlewareRules(t *testing.T) {
rules := []config.RewriteRule{
{Pattern: "^/a/(.*)$", Replacement: "/b/$1", Flag: "last"},
{Pattern: "^/c$", Replacement: "/d", Flag: "redirect"},
}
m, err := New(rules)
if err != nil {
t.Fatalf("New() error: %v", err)
}
compiled := m.Rules()
if len(compiled) != 2 {
t.Errorf("Expected 2 rules, got %d", len(compiled))
}
}