diff --git a/internal/middleware/accesslog/accesslog.go b/internal/middleware/accesslog/accesslog.go index f32a4c3..06f51fe 100644 --- a/internal/middleware/accesslog/accesslog.go +++ b/internal/middleware/accesslog/accesslog.go @@ -1,4 +1,16 @@ // Package accesslog 提供访问日志中间件,记录每个请求的详细信息。 +// +// 该文件包含访问日志相关的核心逻辑,包括: +// - 请求方法和路径记录 +// - 响应状态码和大小记录 +// - 请求处理耗时记录 +// +// 使用示例: +// +// accessLog := accesslog.New(cfg.Logging) +// chain := middleware.NewChain(accessLog) +// +// 作者:xfy package accesslog import ( diff --git a/internal/middleware/compression/compression.go b/internal/middleware/compression/compression.go index 0fcaed3..cf29f5b 100644 --- a/internal/middleware/compression/compression.go +++ b/internal/middleware/compression/compression.go @@ -1,4 +1,20 @@ // Package compression 提供 HTTP 响应压缩中间件,支持 gzip 和 brotli 算法。 +// +// 该文件包含压缩相关的核心逻辑,包括: +// - gzip 压缩(兼容性好,所有浏览器支持) +// - brotli 压缩(压缩率更高,适合现代浏览器) +// - MIME 类型过滤 +// - 最小压缩大小控制 +// +// 主要用途: +// +// 用于压缩 HTTP 响应内容,减少传输数据量,提升页面加载速度。 +// +// 注意事项: +// - 使用缓冲池复用压缩对象,减少内存分配 +// - 小于 MinSize 的响应不压缩 +// +// 作者:xfy package compression import ( @@ -24,13 +40,18 @@ const ( // CompressionMiddleware 响应压缩中间件。 type CompressionMiddleware struct { - types []string // 可压缩的 MIME 类型 - level int // 压缩级别 - minSize int // 最小压缩大小 - algorithm Algorithm // 压缩算法 + // types 可压缩的 MIME 类型列表 + types []string + // level 压缩级别(1-9) + level int + // minSize 最小压缩大小(字节) + minSize int + // algorithm 压缩算法 + algorithm Algorithm - // 缓冲池 - gzipPool sync.Pool + // gzipPool gzip.Writer 缓冲池 + gzipPool sync.Pool + // brotliPool brotli.Writer 缓冲池 brotliPool sync.Pool } diff --git a/internal/middleware/rewrite/rewrite.go b/internal/middleware/rewrite/rewrite.go index 3eec6fd..c7e161c 100644 --- a/internal/middleware/rewrite/rewrite.go +++ b/internal/middleware/rewrite/rewrite.go @@ -2,6 +2,7 @@ package rewrite import ( + "fmt" "regexp" "strings" @@ -39,13 +40,17 @@ func parseFlag(s string) Flag { // Rule 编译后的重写规则。 type Rule struct { - pattern *regexp.Regexp + // pattern 正则匹配模式 + pattern *regexp.Regexp + // replacement 替换字符串,支持 $1、$2 等捕获组 replacement string - flag Flag + // flag 执行标志,控制重写后行为 + flag Flag } // RewriteMiddleware URL 重写中间件。 type RewriteMiddleware struct { + // rules 编译后的规则列表,按配置顺序执行 rules []Rule } @@ -53,6 +58,11 @@ type RewriteMiddleware struct { func New(rules []config.RewriteRule) (*RewriteMiddleware, error) { compiled := make([]Rule, 0, len(rules)) for _, r := range rules { + // 验证正则表达式安全性,防止 ReDoS + if err := validateRegexSafety(r.Pattern); err != nil { + return nil, fmt.Errorf("unsafe regex pattern %q: %w", r.Pattern, err) + } + re, err := regexp.Compile(r.Pattern) if err != nil { return nil, err @@ -66,6 +76,32 @@ func New(rules []config.RewriteRule) (*RewriteMiddleware, error) { return &RewriteMiddleware{rules: compiled}, nil } +// validateRegexSafety 验证正则表达式的安全性,防止 ReDoS 攻击。 +// +// 检测可能导致灾难性回溯的危险模式,如嵌套量词。 +func validateRegexSafety(pattern string) error { + // 限制模式长度 + if len(pattern) > 1000 { + return fmt.Errorf("pattern too long (max 1000 chars)") + } + + // 检测危险模式:嵌套量词 + // 例如:(\w+)+, (\d+)+, (a+)+, (.+)+ + dangerousPatterns := []string{ + `(\w+)+`, `(\d+)+`, `(a+)+`, `(.+)+`, + `(\w*)*`, `(\d*)*`, `(a*)*`, `(.*)*`, + `(\w+)?+`, `(\d+)?+`, + } + + for _, dangerous := range dangerousPatterns { + if strings.Contains(pattern, dangerous) { + return fmt.Errorf("potential catastrophic backtracking pattern detected") + } + } + + return nil +} + // Name 返回中间件名称。 func (m *RewriteMiddleware) Name() string { return "rewrite" diff --git a/internal/middleware/rewrite/rewrite_test.go b/internal/middleware/rewrite/rewrite_test.go index bb0f296..f5c8af1 100644 --- a/internal/middleware/rewrite/rewrite_test.go +++ b/internal/middleware/rewrite/rewrite_test.go @@ -2,6 +2,7 @@ package rewrite import ( "bytes" + "strings" "testing" "github.com/valyala/fasthttp" @@ -285,3 +286,49 @@ func TestRewriteMiddlewareRules(t *testing.T) { t.Errorf("Expected 2 rules, got %d", len(compiled)) } } + +func TestReDoSProtection(t *testing.T) { + tests := []struct { + name string + pattern string + wantErr bool + }{ + { + name: "safe pattern", + pattern: "^/api/v1/(.*)$", + wantErr: false, + }, + { + name: "nested quantifier (\\w+)+", + pattern: `(\w+)+`, + wantErr: true, + }, + { + name: "nested quantifier (.+)+", + pattern: `(.+)+`, + wantErr: true, + }, + { + name: "nested quantifier (\\d+)+", + pattern: `(\d+)+`, + wantErr: true, + }, + { + name: "pattern too long", + pattern: strings.Repeat("a", 1001), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rules := []config.RewriteRule{ + {Pattern: tt.pattern, Replacement: "/new", Flag: "last"}, + } + _, err := New(rules) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}