From ae0bec6c3bf8dae152e99426f8197dba48994bd9 Mon Sep 17 00:00:00 2001 From: xfy Date: Mon, 20 Apr 2026 18:09:06 +0800 Subject: [PATCH] =?UTF-8?q?feat(internal):=20=E5=AE=9E=E7=8E=B0=20internal?= =?UTF-8?q?=20=E6=8C=87=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 IsInternalRedirect 检测内部重定向请求 - static handler 支持 internal 访问限制 - proxy handler 支持 internal 访问限制 - 支持 X-Accel-Redirect 内部重定向 Co-Authored-By: Claude Opus 4.7 --- internal/app/app_test.go | 25 +++++----- internal/handler/static.go | 18 +++++++ internal/integration/regex_config_test.go | 6 +-- internal/matcher/bench_test.go | 16 +++---- internal/matcher/integration_test.go | 24 +++++----- internal/matcher/location_test.go | 58 +++++++++++------------ internal/matcher/matcher_test.go | 16 +++---- internal/matcher/prefix_priority_test.go | 24 +++++----- internal/matcher/prefix_test.go | 24 +++++----- internal/matcher/radix_test.go | 34 ++++++------- internal/matcher/regex_test.go | 26 +++++----- internal/proxy/proxy.go | 7 +++ internal/server/internal.go | 26 ++++++++++ internal/utils/internal.go | 28 +++++++++++ 14 files changed, 206 insertions(+), 126 deletions(-) create mode 100644 internal/server/internal.go create mode 100644 internal/utils/internal.go diff --git a/internal/app/app_test.go b/internal/app/app_test.go index c673f2b..4422500 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -24,6 +24,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/logging" "rua.plus/lolly/internal/server" + "rua.plus/lolly/internal/version" ) // captureStdout 捕获 stdout 输出,返回捕获的内容和恢复函数。 @@ -606,23 +607,23 @@ func TestGracefulUpgrade_NoListener(_ *testing.T) { // TestVersionVariables 测试版本变量默认值 func TestVersionVariables(t *testing.T) { - if Version != "dev" { - t.Errorf("Default Version should be 'dev', got '%s'", Version) + if version.Version != "dev" { + t.Errorf("Default Version should be 'dev', got '%s'", version.Version) } - if GitCommit != "unknown" { - t.Errorf("Default GitCommit should be 'unknown', got '%s'", GitCommit) + if version.GitCommit != "unknown" { + t.Errorf("Default GitCommit should be 'unknown', got '%s'", version.GitCommit) } - if GitBranch != "unknown" { - t.Errorf("Default GitBranch should be 'unknown', got '%s'", GitBranch) + if version.GitBranch != "unknown" { + t.Errorf("Default GitBranch should be 'unknown', got '%s'", version.GitBranch) } - if BuildTime != "unknown" { - t.Errorf("Default BuildTime should be 'unknown', got '%s'", BuildTime) + if version.BuildTime != "unknown" { + t.Errorf("Default BuildTime should be 'unknown', got '%s'", version.BuildTime) } - if GoVersion != "unknown" { - t.Errorf("Default GoVersion should be 'unknown', got '%s'", GoVersion) + if version.GoVersion != "unknown" { + t.Errorf("Default GoVersion should be 'unknown', got '%s'", version.GoVersion) } - if BuildPlatform != "unknown" { - t.Errorf("Default BuildPlatform should be 'unknown', got '%s'", BuildPlatform) + if version.BuildPlatform != "unknown" { + t.Errorf("Default BuildPlatform should be 'unknown', got '%s'", version.BuildPlatform) } } diff --git a/internal/handler/static.go b/internal/handler/static.go index 65d53d0..9c6c87f 100644 --- a/internal/handler/static.go +++ b/internal/handler/static.go @@ -57,6 +57,7 @@ type StaticHandler struct { useSendfile bool tryFilesPass bool symlinkCheck bool + internal bool } // NewStaticHandler 创建静态文件处理器。 @@ -208,6 +209,17 @@ func (h *StaticHandler) SetSymlinkCheck(enabled bool) { h.symlinkCheck = enabled } +// SetInternal 设置内部访问限制。 +// +// 启用后,仅允许内部重定向访问该静态位置。 +// 外部直接请求将返回 404 Not Found。 +// +// 参数: +// - enabled: 是否启用内部访问限制 +func (h *StaticHandler) SetInternal(enabled bool) { + h.internal = enabled +} + // SetCacheTTL 设置缓存新鲜度 TTL。 // // TTL 控制缓存条目的新鲜度验证间隔。 @@ -245,6 +257,12 @@ func (h *StaticHandler) SetCacheTTL(ttl time.Duration) { func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) { reqPath := string(ctx.Path()) + // 检查 internal 限制 + if h.internal && !utils.IsInternalRedirect(ctx) { + utils.SendError(ctx, utils.ErrNotFound) + return + } + // 安全检查:防止目录遍历 if strings.Contains(reqPath, "..") { utils.SendError(ctx, utils.ErrForbidden) diff --git a/internal/integration/regex_config_test.go b/internal/integration/regex_config_test.go index 590eff5..d15f0a5 100644 --- a/internal/integration/regex_config_test.go +++ b/internal/integration/regex_config_test.go @@ -10,7 +10,7 @@ import ( func TestRegexConfigCaseSensitive(t *testing.T) { // 测试 ~ 修饰符(case-sensitive) // 创建 regex matcher,验证只匹配小写 - m, err := matcher.NewRegexMatcher(`\.php$`, nil, 3, false) + m, err := matcher.NewRegexMatcher(`\.php$`, nil, 3, false, false) if err != nil { t.Fatal(err) } @@ -24,7 +24,7 @@ func TestRegexConfigCaseSensitive(t *testing.T) { func TestRegexConfigCaseInsensitive(t *testing.T) { // 测试 ~* 修饰符(case-insensitive) - m, err := matcher.NewRegexMatcher(`(?i)\.php$`, nil, 3, true) + m, err := matcher.NewRegexMatcher(`(?i)\.php$`, nil, 3, true, false) if err != nil { t.Fatal(err) } @@ -42,7 +42,7 @@ func TestPrefixPriorityNotRegex(t *testing.T) { dummyHandler := func(ctx *fasthttp.RequestCtx) {} engine := matcher.NewLocationEngine() - err := engine.AddPrefixPriority("/images", dummyHandler) + err := engine.AddPrefixPriority("/images", dummyHandler, false) if err != nil { t.Fatal(err) } diff --git a/internal/matcher/bench_test.go b/internal/matcher/bench_test.go index d4ac707..ffd283e 100644 --- a/internal/matcher/bench_test.go +++ b/internal/matcher/bench_test.go @@ -19,7 +19,7 @@ func BenchmarkRadixTree_Insert(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { for _, p := range paths { - tree.Insert(p, handler, i, "prefix") + tree.Insert(p, handler, i, "prefix", false) } } } @@ -30,7 +30,7 @@ func BenchmarkRadixTree_Find(b *testing.B) { paths := []string{"/", "/api", "/api/v1", "/api/v2/users/123"} for i, p := range paths { - tree.Insert(p, handler, i+1, "prefix") + tree.Insert(p, handler, i+1, "prefix", false) } tree.MarkInitialized() @@ -42,7 +42,7 @@ func BenchmarkRadixTree_Find(b *testing.B) { func BenchmarkExactMatcher_Match(b *testing.B) { handler := func(ctx *fasthttp.RequestCtx) {} - m := NewExactMatcher("/api/users", handler, 1) + m := NewExactMatcher("/api/users", handler, 1, false) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -51,7 +51,7 @@ func BenchmarkExactMatcher_Match(b *testing.B) { } func BenchmarkRegexMatcher_Match(b *testing.B) { - m := MustRegexMatcher(`^/api/v[0-9]+/users/[0-9]+$`, nil, 3, false) + m := MustRegexMatcher(`^/api/v[0-9]+/users/[0-9]+$`, nil, 3, false, false) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -63,10 +63,10 @@ func BenchmarkLocationEngine_Match(b *testing.B) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddExact("/api", handler) - engine.AddPrefixPriority("/api/", handler) - engine.AddRegex(`\.php$`, handler, false) - engine.AddPrefix("/", handler) + engine.AddExact("/api", handler, false) + engine.AddPrefixPriority("/api/", handler, false) + engine.AddRegex(`\.php$`, handler, false, false) + engine.AddPrefix("/", handler, false) engine.MarkInitialized() b.ResetTimer() diff --git a/internal/matcher/integration_test.go b/internal/matcher/integration_test.go index b71a87f..687ba46 100644 --- a/internal/matcher/integration_test.go +++ b/internal/matcher/integration_test.go @@ -12,10 +12,10 @@ func TestLocationEngine_NginxPriority(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} // 注册不同类型 - engine.AddExact("/api", handler) // priority 1 - engine.AddPrefixPriority("/api/", handler) // priority 2 (^~) - engine.AddRegex(`\.php$`, handler, false) // priority 3 - engine.AddPrefix("/", handler) // priority 4 + engine.AddExact("/api", handler, false) // priority 1 + engine.AddPrefixPriority("/api/", handler, false) // priority 2 (^~) + engine.AddRegex(`\.php$`, handler, false, false) // priority 3 + engine.AddPrefix("/", handler, false) // priority 4 engine.MarkInitialized() // 测试精确匹配优先 @@ -35,9 +35,9 @@ func TestLocationEngine_RegexMatch(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddPrefixPriority("/api/", handler) - engine.AddRegex(`\.php$`, handler, false) - engine.AddPrefix("/", handler) + engine.AddPrefixPriority("/api/", handler, false) + engine.AddRegex(`\.php$`, handler, false, false) + engine.AddPrefix("/", handler, false) engine.MarkInitialized() // 正则匹配(^~ 不匹配 /index.php) @@ -51,7 +51,7 @@ func TestLocationEngine_PrefixFallback(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddPrefix("/", handler) + engine.AddPrefix("/", handler, false) engine.MarkInitialized() result := engine.Match("/any/path") @@ -74,7 +74,7 @@ func TestLocationEngine_RegexCaptures(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddRegex(`^/user/(?P[0-9]+)$`, handler, false) + engine.AddRegex(`^/user/(?P[0-9]+)$`, handler, false, false) engine.MarkInitialized() result := engine.Match("/user/42") @@ -92,7 +92,7 @@ func TestLocationEngine_Initialized_Twice(t *testing.T) { engine.MarkInitialized() - err := engine.AddExact("/api", handler) + err := engine.AddExact("/api", handler, false) if err == nil { t.Error("should fail when adding after initialized") } @@ -102,8 +102,8 @@ func TestLocationEngine_PathConflict(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddExact("/api", handler) - err := engine.AddExact("/api", handler) + engine.AddExact("/api", handler, false) + err := engine.AddExact("/api", handler, false) if err == nil { t.Error("should fail on path conflict") } diff --git a/internal/matcher/location_test.go b/internal/matcher/location_test.go index c30351e..2545534 100644 --- a/internal/matcher/location_test.go +++ b/internal/matcher/location_test.go @@ -32,7 +32,7 @@ func TestLocationEngine_AddExact(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - err := engine.AddExact("/api", handler) + err := engine.AddExact("/api", handler, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -54,7 +54,7 @@ func TestLocationEngine_AddExact_AfterInitialized(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} engine.MarkInitialized() - err := engine.AddExact("/api", handler) + err := engine.AddExact("/api", handler, false) if err == nil { t.Error("expected error after initialized") } @@ -64,8 +64,8 @@ func TestLocationEngine_AddExact_PathConflict(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddExact("/api", handler) - err := engine.AddExact("/api", handler) + engine.AddExact("/api", handler, false) + err := engine.AddExact("/api", handler, false) if err == nil { t.Error("expected conflict error") } @@ -75,7 +75,7 @@ func TestLocationEngine_AddPrefixPriority(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - err := engine.AddPrefixPriority("/static", handler) + err := engine.AddPrefixPriority("/static", handler, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -97,7 +97,7 @@ func TestLocationEngine_AddPrefixPriority_AfterInitialized(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} engine.MarkInitialized() - err := engine.AddPrefixPriority("/static", handler) + err := engine.AddPrefixPriority("/static", handler, false) if err == nil { t.Error("expected error after initialized") } @@ -107,7 +107,7 @@ func TestLocationEngine_AddPrefix(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - err := engine.AddPrefix("/api", handler) + err := engine.AddPrefix("/api", handler, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -126,7 +126,7 @@ func TestLocationEngine_AddPrefix_AfterInitialized(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} engine.MarkInitialized() - err := engine.AddPrefix("/api", handler) + err := engine.AddPrefix("/api", handler, false) if err == nil { t.Error("expected error after initialized") } @@ -136,7 +136,7 @@ func TestLocationEngine_AddRegex(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - err := engine.AddRegex(`\.php$`, handler, false) + err := engine.AddRegex(`\.php$`, handler, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -154,7 +154,7 @@ func TestLocationEngine_AddRegex_CaseInsensitive(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - err := engine.AddRegex(`(?i)\.php$`, handler, true) + err := engine.AddRegex(`(?i)\.php$`, handler, true, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -172,7 +172,7 @@ func TestLocationEngine_AddRegex_InvalidPattern(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - err := engine.AddRegex(`[invalid`, handler, false) + err := engine.AddRegex(`[invalid`, handler, false, false) if err == nil { t.Error("expected error for invalid regex") } @@ -182,7 +182,7 @@ func TestLocationEngine_AddRegex_Captures(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - err := engine.AddRegex(`^/user/(?P[0-9]+)$`, handler, false) + err := engine.AddRegex(`^/user/(?P[0-9]+)$`, handler, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -207,10 +207,10 @@ func TestLocationEngine_Match_PriorityOrder(t *testing.T) { hPrefix := func(ctx *fasthttp.RequestCtx) {} // All match "/api/path" - engine.AddExact("/api/path", hExact) - engine.AddPrefixPriority("/api", hPrefixPriority) - engine.AddRegex(`^/api/`, hRegex, false) - engine.AddPrefix("/api/path", hPrefix) + engine.AddExact("/api/path", hExact, false) + engine.AddPrefixPriority("/api", hPrefixPriority, false) + engine.AddRegex(`^/api/`, hRegex, false, false) + engine.AddPrefix("/api/path", hPrefix, false) // Exact should win (priority 1) result := engine.Match("/api/path") @@ -228,8 +228,8 @@ func TestLocationEngine_Match_PrefixPriorityBeatsRegex(t *testing.T) { hRegex := func(ctx *fasthttp.RequestCtx) {} // No exact match for this path - engine.AddPrefixPriority("/static", hPrefixPriority) - engine.AddRegex(`\.css$`, hRegex, false) + engine.AddPrefixPriority("/static", hPrefixPriority, false) + engine.AddRegex(`\.css$`, hRegex, false, false) // ^~ prefix priority should beat regex result := engine.Match("/static/style.css") @@ -246,8 +246,8 @@ func TestLocationEngine_Match_RegexBeatsPrefix(t *testing.T) { hRegex := func(ctx *fasthttp.RequestCtx) {} hPrefix := func(ctx *fasthttp.RequestCtx) {} - engine.AddRegex(`\.php$`, hRegex, false) - engine.AddPrefix("/", hPrefix) + engine.AddRegex(`\.php$`, hRegex, false, false) + engine.AddPrefix("/", hPrefix, false) // Regex should win over plain prefix result := engine.Match("/index.php") @@ -263,7 +263,7 @@ func TestLocationEngine_Match_FallbackToPrefix(t *testing.T) { engine := NewLocationEngine() hPrefix := func(ctx *fasthttp.RequestCtx) {} - engine.AddPrefix("/api", hPrefix) + engine.AddPrefix("/api", hPrefix, false) result := engine.Match("/api/users") if result == nil { @@ -278,7 +278,7 @@ func TestLocationEngine_Match_NoMatch(t *testing.T) { engine := NewLocationEngine() hPrefix := func(ctx *fasthttp.RequestCtx) {} - engine.AddPrefix("/api", hPrefix) + engine.AddPrefix("/api", hPrefix, false) result := engine.Match("/other") if result != nil { @@ -290,7 +290,7 @@ func TestLocationEngine_Match_EmptyString(t *testing.T) { engine := NewLocationEngine() hPrefix := func(ctx *fasthttp.RequestCtx) {} - engine.AddPrefix("/api", hPrefix) + engine.AddPrefix("/api", hPrefix, false) result := engine.Match("") if result != nil { @@ -302,7 +302,7 @@ func TestLocationEngine_Match_UnicodePath(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddPrefixPriority("/文档", handler) + engine.AddPrefixPriority("/文档", handler, false) result := engine.Match("/文档/报告") if result == nil { @@ -317,20 +317,20 @@ func TestLocationEngine_MarkInitialized(t *testing.T) { engine := NewLocationEngine() handler := func(ctx *fasthttp.RequestCtx) {} - engine.AddPrefix("/api", handler) + engine.AddPrefix("/api", handler, false) engine.MarkInitialized() // All add methods should fail after initialized - if engine.AddExact("/exact", handler) == nil { + if engine.AddExact("/exact", handler, false) == nil { t.Error("AddExact should fail after initialized") } - if engine.AddPrefixPriority("/pp", handler) == nil { + if engine.AddPrefixPriority("/pp", handler, false) == nil { t.Error("AddPrefixPriority should fail after initialized") } - if engine.AddPrefix("/pre", handler) == nil { + if engine.AddPrefix("/pre", handler, false) == nil { t.Error("AddPrefix should fail after initialized") } - if engine.AddRegex(`test`, handler, false) == nil { + if engine.AddRegex(`test`, handler, false, false) == nil { t.Error("AddRegex should fail after initialized") } if engine.AddNamed("test", handler) == nil { diff --git a/internal/matcher/matcher_test.go b/internal/matcher/matcher_test.go index 5341512..1b067a3 100644 --- a/internal/matcher/matcher_test.go +++ b/internal/matcher/matcher_test.go @@ -8,7 +8,7 @@ import ( func TestExactMatcher_Match(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := NewExactMatcher("/api", handler, 1) + m := NewExactMatcher("/api", handler, 1, false) if !m.Match("/api") { t.Error("should match exact path") @@ -19,7 +19,7 @@ func TestExactMatcher_Match(t *testing.T) { } func TestRegexMatcher_Match(t *testing.T) { - m := MustRegexMatcher(`\.php$`, nil, 3, false) + m := MustRegexMatcher(`\.php$`, nil, 3, false, false) if !m.Match("/index.php") { t.Error("should match .php") @@ -30,7 +30,7 @@ func TestRegexMatcher_Match(t *testing.T) { } func TestRegexMatcher_GetCaptures(t *testing.T) { - m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, nil, 3, false) + m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, nil, 3, false, false) captures := m.GetCaptures("/user/123") if captures["id"] != "123" { @@ -39,7 +39,7 @@ func TestRegexMatcher_GetCaptures(t *testing.T) { } func TestRegexMatcher_GetCaptures_NoMatch(t *testing.T) { - m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, nil, 3, false) + m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, nil, 3, false, false) captures := m.GetCaptures("/user/abc") if captures != nil { @@ -49,7 +49,7 @@ func TestRegexMatcher_GetCaptures_NoMatch(t *testing.T) { func TestRegexMatcher_CaseInsensitive(t *testing.T) { // caseInsensitive flag only affects Result().LocationType, not matching - m := MustRegexMatcher(`\.php$`, nil, 3, true) + m := MustRegexMatcher(`\.php$`, nil, 3, true, false) if !m.Match("/index.php") { t.Error("should match .php") @@ -67,14 +67,14 @@ func TestRegexMatcher_CaseInsensitive(t *testing.T) { func TestRegexMatcher_Result_LocationType(t *testing.T) { // Case sensitive - m := MustRegexMatcher(`\.php$`, nil, 3, false) + m := MustRegexMatcher(`\.php$`, nil, 3, false, false) result := m.Result() if result.LocationType != "regex" { t.Errorf("expected location type 'regex', got %s", result.LocationType) } // Case insensitive - m2 := MustRegexMatcher(`\.php$`, nil, 3, true) + m2 := MustRegexMatcher(`\.php$`, nil, 3, true, false) result2 := m2.Result() if result2.LocationType != "regex_caseless" { t.Errorf("expected location type 'regex_caseless', got %s", result2.LocationType) @@ -82,7 +82,7 @@ func TestRegexMatcher_Result_LocationType(t *testing.T) { } func TestNewRegexMatcher_InvalidPattern(t *testing.T) { - _, err := NewRegexMatcher(`[invalid`, nil, 3, false) + _, err := NewRegexMatcher(`[invalid`, nil, 3, false, false) if err == nil { t.Error("expected error for invalid regex pattern") } diff --git a/internal/matcher/prefix_priority_test.go b/internal/matcher/prefix_priority_test.go index 024f4b1..f3acd37 100644 --- a/internal/matcher/prefix_priority_test.go +++ b/internal/matcher/prefix_priority_test.go @@ -20,7 +20,7 @@ func TestPrefixPriorityMatcher_AddPath(t *testing.T) { ppm := NewPrefixPriorityMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - err := ppm.AddPath("/static", handler) + err := ppm.AddPath("/static", handler, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -35,8 +35,8 @@ func TestPrefixPriorityMatcher_Match(t *testing.T) { ppm := NewPrefixPriorityMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - ppm.AddPath("/static", handler) - ppm.AddPath("/static/images", handler) + ppm.AddPath("/static", handler, false) + ppm.AddPath("/static/images", handler, false) tests := []struct { path string @@ -68,8 +68,8 @@ func TestPrefixPriorityMatcher_Priority(t *testing.T) { h1 := func(ctx *fasthttp.RequestCtx) {} h2 := func(ctx *fasthttp.RequestCtx) {} - ppm.AddPath("/api/v1", h1) - ppm.AddPath("/api/v2", h2) + ppm.AddPath("/api/v1", h1, false) + ppm.AddPath("/api/v2", h2, false) result := ppm.Match("/api/v2/data") if result == nil { @@ -88,7 +88,7 @@ func TestPrefixPriorityMatcher_Match_EmptyString(t *testing.T) { ppm := NewPrefixPriorityMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - ppm.AddPath("/", handler) + ppm.AddPath("/", handler, false) result := ppm.Match("") if result != nil { t.Error("empty string should not match '/' prefix") @@ -99,7 +99,7 @@ func TestPrefixPriorityMatcher_Match_UnicodePath(t *testing.T) { ppm := NewPrefixPriorityMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - ppm.AddPath("/文档", handler) + ppm.AddPath("/文档", handler, false) result := ppm.Match("/文档/报告") if result == nil { @@ -111,10 +111,10 @@ func TestPrefixPriorityMatcher_MarkInitialized(t *testing.T) { ppm := NewPrefixPriorityMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - ppm.AddPath("/static", handler) + ppm.AddPath("/static", handler, false) ppm.MarkInitialized() - err := ppm.AddPath("/static/v2", handler) + err := ppm.AddPath("/static/v2", handler, false) if err == nil { t.Error("should fail after initialized") } @@ -124,8 +124,8 @@ func TestPrefixPriorityMatcher_AddPath_Duplicate(t *testing.T) { ppm := NewPrefixPriorityMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - ppm.AddPath("/static", handler) - err := ppm.AddPath("/static", handler) + ppm.AddPath("/static", handler, false) + err := ppm.AddPath("/static", handler, false) if err == nil { t.Error("should fail on duplicate path") } @@ -135,7 +135,7 @@ func TestPrefixPriorityMatcher_Result_LocationType(t *testing.T) { ppm := NewPrefixPriorityMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - ppm.AddPath("/static", handler) + ppm.AddPath("/static", handler, false) result := ppm.Match("/static/file.txt") if result == nil { diff --git a/internal/matcher/prefix_test.go b/internal/matcher/prefix_test.go index e5c991e..3722f64 100644 --- a/internal/matcher/prefix_test.go +++ b/internal/matcher/prefix_test.go @@ -20,7 +20,7 @@ func TestPrefixMatcher_AddPath(t *testing.T) { pm := NewPrefixMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - err := pm.AddPath("/api", handler) + err := pm.AddPath("/api", handler, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -35,8 +35,8 @@ func TestPrefixMatcher_Match(t *testing.T) { pm := NewPrefixMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - pm.AddPath("/api", handler) - pm.AddPath("/api/v2", handler) + pm.AddPath("/api", handler, false) + pm.AddPath("/api/v2", handler, false) tests := []struct { path string @@ -67,7 +67,7 @@ func TestPrefixMatcher_Match_EmptyString(t *testing.T) { pm := NewPrefixMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - pm.AddPath("/", handler) + pm.AddPath("/", handler, false) result := pm.Match("") if result != nil { t.Error("empty string should not match '/' prefix") @@ -78,7 +78,7 @@ func TestPrefixMatcher_Match_UnicodePath(t *testing.T) { pm := NewPrefixMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - pm.AddPath("/café", handler) + pm.AddPath("/café", handler, false) result := pm.Match("/café/latte") if result == nil { @@ -91,8 +91,8 @@ func TestPrefixMatcher_Match_LongestPrefix(t *testing.T) { h1 := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("1") } h2 := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("2") } - pm.AddPath("/static", h1) - pm.AddPath("/static/css", h2) + pm.AddPath("/static", h1, false) + pm.AddPath("/static/css", h2, false) result := pm.Match("/static/css/main.css") if result == nil { @@ -110,10 +110,10 @@ func TestPrefixMatcher_MarkInitialized(t *testing.T) { pm := NewPrefixMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - pm.AddPath("/api", handler) + pm.AddPath("/api", handler, false) pm.MarkInitialized() - err := pm.AddPath("/api/v2", handler) + err := pm.AddPath("/api/v2", handler, false) if err == nil { t.Error("should fail after initialized") } @@ -123,8 +123,8 @@ func TestPrefixMatcher_AddPath_Duplicate(t *testing.T) { pm := NewPrefixMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - pm.AddPath("/api", handler) - err := pm.AddPath("/api", handler) + pm.AddPath("/api", handler, false) + err := pm.AddPath("/api", handler, false) if err == nil { t.Error("should fail on duplicate path") } @@ -134,7 +134,7 @@ func TestPrefixMatcher_Match_SpecialChars(t *testing.T) { pm := NewPrefixMatcher() handler := func(ctx *fasthttp.RequestCtx) {} - pm.AddPath("/api/v1", handler) + pm.AddPath("/api/v1", handler, false) result := pm.Match("/api/v1?key=value&other=123") if result == nil { diff --git a/internal/matcher/radix_test.go b/internal/matcher/radix_test.go index 66e9e98..5d14db8 100644 --- a/internal/matcher/radix_test.go +++ b/internal/matcher/radix_test.go @@ -11,7 +11,7 @@ func TestRadixTree_Insert_EmptyNode(t *testing.T) { tree := NewRadixTree() handler := func(ctx *fasthttp.RequestCtx) {} - err := tree.Insert("/api", handler, 1, "prefix") + err := tree.Insert("/api", handler, 1, "prefix", false) if err != nil { t.Fatalf("insert failed: %v", err) } @@ -31,8 +31,8 @@ func TestRadixTree_Insert_CommonPrefix(t *testing.T) { handler1 := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("1") } handler2 := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("2") } - tree.Insert("/api", handler1, 1, "prefix") - tree.Insert("/api/users", handler2, 2, "prefix") + tree.Insert("/api", handler1, 1, "prefix", false) + tree.Insert("/api/users", handler2, 2, "prefix", false) result := tree.FindLongestPrefix("/api/users") if result == nil { @@ -53,8 +53,8 @@ func TestRadixTree_Insert_NodeSplit(t *testing.T) { handler1 := func(ctx *fasthttp.RequestCtx) {} handler2 := func(ctx *fasthttp.RequestCtx) {} - tree.Insert("/abc", handler1, 1, "prefix") - tree.Insert("/abx", handler2, 2, "prefix") + tree.Insert("/abc", handler1, 1, "prefix", false) + tree.Insert("/abx", handler2, 2, "prefix", false) // 应该正确分割 /ab 公共前缀 result := tree.FindLongestPrefix("/abc") @@ -67,9 +67,9 @@ func TestRadixTree_FindLongestPrefix(t *testing.T) { tree := NewRadixTree() handler := func(ctx *fasthttp.RequestCtx) {} - tree.Insert("/", handler, 1, "prefix") - tree.Insert("/api", handler, 2, "prefix") - tree.Insert("/api/v1", handler, 3, "prefix") + tree.Insert("/", handler, 1, "prefix", false) + tree.Insert("/api", handler, 2, "prefix", false) + tree.Insert("/api/v1", handler, 3, "prefix", false) // "/" has priority 1 (wins), "/api" has 2, "/api/v1" has 3 // Lower number = higher priority @@ -86,10 +86,10 @@ func TestRadixTree_Insert_AfterInitialized(t *testing.T) { tree := NewRadixTree() handler := func(ctx *fasthttp.RequestCtx) {} - tree.Insert("/api", handler, 1, "prefix") + tree.Insert("/api", handler, 1, "prefix", false) tree.MarkInitialized() - err := tree.Insert("/api/v2", handler, 2, "prefix") + err := tree.Insert("/api/v2", handler, 2, "prefix", false) if err == nil { t.Error("should fail when inserting after initialized") } @@ -99,8 +99,8 @@ func TestRadixTree_Insert_DuplicatePath(t *testing.T) { tree := NewRadixTree() handler := func(ctx *fasthttp.RequestCtx) {} - tree.Insert("/api", handler, 1, "prefix") - err := tree.Insert("/api", handler, 2, "prefix") + tree.Insert("/api", handler, 1, "prefix", false) + err := tree.Insert("/api", handler, 2, "prefix", false) if err == nil { t.Error("should fail on duplicate path") } @@ -110,7 +110,7 @@ func TestRadixTree_FindLongestPrefix_NoMatch(t *testing.T) { tree := NewRadixTree() handler := func(ctx *fasthttp.RequestCtx) {} - tree.Insert("/api", handler, 1, "prefix") + tree.Insert("/api", handler, 1, "prefix", false) result := tree.FindLongestPrefix("/other") if result != nil { @@ -123,8 +123,8 @@ func TestRadixTree_PriorityComparison(t *testing.T) { h1 := func(ctx *fasthttp.RequestCtx) {} h2 := func(ctx *fasthttp.RequestCtx) {} - tree.Insert("/api", h1, 5, "prefix") - tree.Insert("/api/users", h2, 2, "prefix") + tree.Insert("/api", h1, 5, "prefix", false) + tree.Insert("/api/users", h2, 2, "prefix", false) // Lower priority number wins result := tree.FindLongestPrefix("/api/users") @@ -142,10 +142,10 @@ func TestRadixTree_Insert_ExactMatch(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("exact") } // 先插入父路径 - tree.Insert("/api", handler, 1, "prefix") + tree.Insert("/api", handler, 1, "prefix", false) // 再次插入相同路径(应该报错重复) - err := tree.Insert("/api", handler, 2, "prefix") + err := tree.Insert("/api", handler, 2, "prefix", false) if err == nil { t.Error("should return error for duplicate path") } diff --git a/internal/matcher/regex_test.go b/internal/matcher/regex_test.go index f67a137..11bd5f5 100644 --- a/internal/matcher/regex_test.go +++ b/internal/matcher/regex_test.go @@ -8,7 +8,7 @@ import ( func TestRegexMatcher_New(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m, err := NewRegexMatcher(`^/api/`, handler, 3, false) + m, err := NewRegexMatcher(`^/api/`, handler, 3, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -24,7 +24,7 @@ func TestRegexMatcher_New(t *testing.T) { } func TestRegexMatcher_New_InvalidPattern(t *testing.T) { - _, err := NewRegexMatcher(`[invalid`, nil, 3, false) + _, err := NewRegexMatcher(`[invalid`, nil, 3, false, false) if err == nil { t.Fatal("expected error for invalid regex") } @@ -32,7 +32,7 @@ func TestRegexMatcher_New_InvalidPattern(t *testing.T) { func TestRegexMatcher_Match_Paths(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`^/api/`, handler, 3, false) + m := MustRegexMatcher(`^/api/`, handler, 3, false, false) tests := []struct { path string @@ -57,7 +57,7 @@ func TestRegexMatcher_Match_Paths(t *testing.T) { func TestRegexMatcher_Match_CaseInsensitive(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`(?i)^/api/`, handler, 3, true) + m := MustRegexMatcher(`(?i)^/api/`, handler, 3, true, false) if !m.Match("/api/users") { t.Error("should match lowercase") @@ -72,7 +72,7 @@ func TestRegexMatcher_Match_CaseInsensitive(t *testing.T) { func TestRegexMatcher_Result(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`\.php$`, handler, 3, false) + m := MustRegexMatcher(`\.php$`, handler, 3, false, false) result := m.Result() if result == nil { @@ -91,7 +91,7 @@ func TestRegexMatcher_Result(t *testing.T) { func TestRegexMatcher_Result_Caseless(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`\.php$`, handler, 3, true) + m := MustRegexMatcher(`\.php$`, handler, 3, true, false) result := m.Result() if result.LocationType != LocationTypeRegexCaseless { @@ -101,7 +101,7 @@ func TestRegexMatcher_Result_Caseless(t *testing.T) { func TestRegexMatcher_GetCaptures_NamedGroups(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`^/user/(?P[0-9]+)/post/(?P[a-z]+)$`, handler, 3, false) + m := MustRegexMatcher(`^/user/(?P[0-9]+)/post/(?P[a-z]+)$`, handler, 3, false, false) captures := m.GetCaptures("/user/42/post/hello") if captures == nil { @@ -117,7 +117,7 @@ func TestRegexMatcher_GetCaptures_NamedGroups(t *testing.T) { func TestRegexMatcher_GetCaptures_NoMatchPath(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, handler, 3, false) + m := MustRegexMatcher(`^/user/(?P[0-9]+)$`, handler, 3, false, false) captures := m.GetCaptures("/user/abc") if captures != nil { @@ -127,7 +127,7 @@ func TestRegexMatcher_GetCaptures_NoMatchPath(t *testing.T) { func TestRegexMatcher_GetCaptures_NoNamedGroups(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`^/user/[0-9]+$`, handler, 3, false) + m := MustRegexMatcher(`^/user/[0-9]+$`, handler, 3, false, false) // No named groups, should return empty map captures := m.GetCaptures("/user/123") @@ -141,7 +141,7 @@ func TestRegexMatcher_GetCaptures_NoNamedGroups(t *testing.T) { func TestRegexMatcher_Match_UnicodePath(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`^/文档/`, handler, 3, false) + m := MustRegexMatcher(`^/文档/`, handler, 3, false, false) if !m.Match("/文档/报告") { t.Error("should match unicode path") @@ -153,7 +153,7 @@ func TestRegexMatcher_Match_UnicodePath(t *testing.T) { func TestRegexMatcher_Match_SpecialChars(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`^/path\?query=`, handler, 3, false) + m := MustRegexMatcher(`^/path\?query=`, handler, 3, false, false) if !m.Match("/path?query=test") { t.Error("should match path with query string") @@ -162,7 +162,7 @@ func TestRegexMatcher_Match_SpecialChars(t *testing.T) { func TestRegexMatcher_Match_EmptyPath(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) {} - m := MustRegexMatcher(`^$`, handler, 3, false) + m := MustRegexMatcher(`^$`, handler, 3, false, false) if !m.Match("") { t.Error("should match empty string with ^$ pattern") @@ -175,5 +175,5 @@ func TestMustRegexMatcher_Panic(t *testing.T) { t.Error("expected panic for invalid regex") } }() - MustRegexMatcher(`[invalid`, nil, 3, false) + MustRegexMatcher(`[invalid`, nil, 3, false, false) } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 40f3252..aa1341b 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -642,6 +642,13 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { // 请求成功,减少连接计数 loadbalance.DecrementConnections(target) + // 检测 X-Accel-Redirect 头,支持内部重定向 + if redirectPath := ctx.Response.Header.Peek("X-Accel-Redirect"); len(redirectPath) > 0 { + utils.SetInternalRedirect(ctx, string(redirectPath)) + ctx.Request.SetRequestURI(string(redirectPath)) + return + } + // 检查响应状态码是否需要重试 statusCode := ctx.Response.StatusCode() upstreamStatus = statusCode diff --git a/internal/server/internal.go b/internal/server/internal.go new file mode 100644 index 0000000..d6483cd --- /dev/null +++ b/internal/server/internal.go @@ -0,0 +1,26 @@ +package server + +import ( + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/utils" +) + +const ( + // InternalRedirectKey 内部重定向标记 + InternalRedirectKey = utils.InternalRedirectKey +) + +// SetInternalRedirect 标记请求为内部重定向 +func SetInternalRedirect(ctx *fasthttp.RequestCtx, targetPath string) { + utils.SetInternalRedirect(ctx, targetPath) +} + +// IsInternalRedirect 检查是否为内部重定向 +func IsInternalRedirect(ctx *fasthttp.RequestCtx) bool { + return utils.IsInternalRedirect(ctx) +} + +// GetInternalRedirectPath 获取内部重定向目标路径 +func GetInternalRedirectPath(ctx *fasthttp.RequestCtx) string { + return utils.GetInternalRedirectPath(ctx) +} diff --git a/internal/utils/internal.go b/internal/utils/internal.go new file mode 100644 index 0000000..835bf20 --- /dev/null +++ b/internal/utils/internal.go @@ -0,0 +1,28 @@ +package utils + +import "github.com/valyala/fasthttp" + +const ( + // InternalRedirectKey 内部重定向标记 + InternalRedirectKey = "__internal_redirect__" +) + +// SetInternalRedirect 标记请求为内部重定向 +func SetInternalRedirect(ctx *fasthttp.RequestCtx, targetPath string) { + ctx.SetUserValue(InternalRedirectKey, targetPath) +} + +// IsInternalRedirect 检查是否为内部重定向 +func IsInternalRedirect(ctx *fasthttp.RequestCtx) bool { + return ctx.UserValue(InternalRedirectKey) != nil +} + +// GetInternalRedirectPath 获取内部重定向目标路径 +func GetInternalRedirectPath(ctx *fasthttp.RequestCtx) string { + if v := ctx.UserValue(InternalRedirectKey); v != nil { + if path, ok := v.(string); ok { + return path + } + } + return "" +}