diff --git a/internal/middleware/bodylimit/AGENTS.md b/internal/middleware/bodylimit/AGENTS.md new file mode 100644 index 0000000..0c275f1 --- /dev/null +++ b/internal/middleware/bodylimit/AGENTS.md @@ -0,0 +1,40 @@ + + + +# bodylimit + +## Purpose +HTTP 请求体大小限制中间件,防止客户端通过发送超大请求体导致服务器资源耗尽,支持全局配置和路径级别覆盖。 + +## Key Files + +| File | Description | +|------|-------------| +| `bodylimit.go` | 请求体限制中间件实现,支持大小解析、路径级别配置 | +| `bodylimit_test.go` | 中间件单元测试 | + +## For AI Agents + +### Working In This Directory +- 使用 `io.LimitReader` 强制限制实际读取的字节数 +- 支持路径级别配置覆盖全局配置(最长匹配优先) +- 大小字符串解析支持 b、kb、mb、gb 单位(不区分大小写) +- 超限返回 413 Request Entity Too Large + +### Testing Requirements +- 运行测试:`go test ./internal/middleware/bodylimit/...` +- 测试覆盖:大小解析、路径匹配、超限处理 + +### Common Patterns +- 创建中间件:`bodylimit.New("10mb")` +- 添加路径配置:`bl.AddPathLimit("/upload", "100mb")` +- 获取路径限制:`bl.GetLimit(path)` +- 解析大小:`ParseSize("1kb")` → 1024 +- 格式化大小:`FormatSize(1024)` → "1kb" + +## Dependencies + +### External +- `github.com/valyala/fasthttp` - HTTP 框架 + + \ No newline at end of file diff --git a/internal/middleware/bodylimit/bodylimit.go b/internal/middleware/bodylimit/bodylimit.go new file mode 100644 index 0000000..97b433c --- /dev/null +++ b/internal/middleware/bodylimit/bodylimit.go @@ -0,0 +1,299 @@ +// Package bodylimit 提供 HTTP 请求体大小限制的中间件。 +// +// 该文件包含请求体大小限制相关的核心功能,包括: +// - BodyLimit 中间件:限制请求体大小 +// - 解析大小字符串:支持 b, kb, mb, gb 等单位 +// - 路径级别的覆盖配置 +// +// 主要用途: +// +// 防止客户端通过发送超大请求体或 chunked 传输绕过限制导致服务器资源耗尽。 +// +// 注意事项: +// - 使用 io.LimitReader 强制限制实际读取的字节数 +// - 支持路径级别配置覆盖全局配置 +// - 超限返回 413 Request Entity Too Large +// +// 作者:xfy +package bodylimit + +import ( + "fmt" + "strconv" + "strings" + "sync" + + "github.com/valyala/fasthttp" +) + +// DefaultMaxBodySize 默认请求体大小限制为 1MB。 +const DefaultMaxBodySize = 1 << 20 // 1MB + +// BodyLimit 请求体大小限制中间件。 +// +// 限制请求体的最大字节数,超过限制的请求将被拒绝并返回 413 错误。 +// 支持全局配置和路径级别的覆盖配置。 +type BodyLimit struct { + // maxBodySize 全局请求体大小限制(字节) + maxBodySize int64 + + // pathLimits 路径级别的限制配置 + // key 为路径前缀,value 为该路径的限制大小 + pathLimits map[string]int64 + + // pathLimitsMu 保护 pathLimits 的互斥锁 + pathLimitsMu sync.RWMutex +} + +// New 创建请求体大小限制中间件。 +// +// 参数: +// - maxBodySize: 最大请求体大小字符串,如 "1mb", "10kb" 等 +// +// 返回值: +// - *BodyLimit: 创建的中间件实例 +// - error: 解析大小字符串失败时的错误 +func New(maxBodySize string) (*BodyLimit, error) { + size, err := ParseSize(maxBodySize) + if err != nil { + return nil, fmt.Errorf("解析 client_max_body_size 失败: %w", err) + } + + return &BodyLimit{ + maxBodySize: size, + pathLimits: make(map[string]int64), + }, nil +} + +// NewWithDefault 使用默认限制(1MB)创建中间件。 +// +// 返回值: +// - *BodyLimit: 创建的中间件实例 +func NewWithDefault() *BodyLimit { + return &BodyLimit{ + maxBodySize: DefaultMaxBodySize, + pathLimits: make(map[string]int64), + } +} + +// Name 返回中间件名称。 +// +// 返回值: +// - string: 中间件名称 +func (bl *BodyLimit) Name() string { + return "BodyLimit" +} + +// AddPathLimit 添加路径级别的限制配置。 +// +// 参数: +// - path: 路径前缀 +// - sizeStr: 大小字符串,如 "1mb", "10kb" 等 +// +// 返回值: +// - error: 解析大小字符串失败时的错误 +func (bl *BodyLimit) AddPathLimit(path, sizeStr string) error { + size, err := ParseSize(sizeStr) + if err != nil { + return fmt.Errorf("解析路径 %s 的 client_max_body_size 失败: %w", path, err) + } + + bl.pathLimitsMu.Lock() + bl.pathLimits[path] = size + bl.pathLimitsMu.Unlock() + + return nil +} + +// GetLimit 获取指定路径的请求体限制。 +// +// 优先使用路径级别配置,如无则使用全局配置。 +// +// 参数: +// - path: 请求路径 +// +// 返回值: +// - int64: 该路径的最大请求体大小(字节) +func (bl *BodyLimit) GetLimit(path string) int64 { + bl.pathLimitsMu.RLock() + defer bl.pathLimitsMu.RUnlock() + + // 查找匹配的路径配置(最长匹配优先) + var matchedLimit int64 + var matchedPath string + var matched bool + + for prefix, limit := range bl.pathLimits { + if strings.HasPrefix(path, prefix) { + // 选择最长的匹配路径 + if !matched || len(prefix) > len(matchedPath) { + matchedLimit = limit + matchedPath = prefix + matched = true + } + } + } + + if matched { + return matchedLimit + } + + return bl.maxBodySize +} + +// Process 实现中间件接口。 +// +// 检查请求体大小是否超过限制,超限返回 413 错误。 +// +// 参数: +// - next: 下一个请求处理器 +// +// 返回值: +// - fasthttp.RequestHandler: 包装后的请求处理器 +func (bl *BodyLimit) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + limit := bl.GetLimit(path) + + // 检查 Content-Length 头 + contentLength := ctx.Request.Header.ContentLength() + if contentLength > 0 && int64(contentLength) > limit { + ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge) + return + } + + // 对于 chunked 传输或没有 Content-Length 的请求 + // 设置最大读取限制 + ctx.Request.SetBodyStream(ctx.Request.BodyStream(), int(limit)) + + // 包装请求体读取以检测超限 + limitedReader := &limitedBodyReader{ + ctx: ctx, + limit: limit, + original: ctx.Request.BodyStream(), + } + ctx.Request.SetBodyStream(limitedReader, -1) + + next(ctx) + } +} + +// limitedBodyReader 包装请求体读取器以限制最大读取字节数。 +type limitedBodyReader struct { + ctx *fasthttp.RequestCtx + limit int64 + original interface { + Read(p []byte) (n int, err error) + } + read int64 + done bool +} + +// Read 实现读取接口,在超过限制时返回错误。 +func (l *limitedBodyReader) Read(p []byte) (n int, err error) { + if l.done { + return 0, fmt.Errorf("request body too large") + } + + // 计算还能读取多少字节 + remaining := l.limit - l.read + if remaining <= 0 { + l.done = true + // 返回 413 错误 + l.ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge) + return 0, fmt.Errorf("request body exceeds limit of %d bytes", l.limit) + } + + // 限制读取长度 + if int64(len(p)) > remaining { + p = p[:remaining] + } + + n, err = l.original.Read(p) + l.read += int64(n) + + return n, err +} + +// ParseSize 解析大小字符串为字节数。 +// +// 支持的单位:b, kb, mb, gb(不区分大小写) +// 无单位时默认为字节。 +// +// 参数: +// - sizeStr: 大小字符串,如 "1mb", "10kb", "1024" 等 +// +// 返回值: +// - int64: 字节数 +// - error: 解析失败时的错误 +func ParseSize(sizeStr string) (int64, error) { + if sizeStr == "" { + return DefaultMaxBodySize, nil + } + + sizeStr = strings.TrimSpace(strings.ToLower(sizeStr)) + + // 解析数值和单位 + var numStr string + var unit string + + for i, c := range sizeStr { + if c >= '0' && c <= '9' || c == '.' { + numStr = sizeStr[:i+1] + } else { + unit = sizeStr[i:] + break + } + } + + if numStr == "" { + return 0, fmt.Errorf("无效的大小格式: %s", sizeStr) + } + + value, err := strconv.ParseFloat(numStr, 64) + if err != nil { + return 0, fmt.Errorf("解析数值失败: %w", err) + } + + var multiplier float64 + switch unit { + case "", "b": + multiplier = 1 + case "kb": + multiplier = 1024 + case "mb": + multiplier = 1024 * 1024 + case "gb": + multiplier = 1024 * 1024 * 1024 + default: + return 0, fmt.Errorf("不支持的大小单位: %s", unit) + } + + return int64(value * multiplier), nil +} + +// FormatSize 将字节数格式化为人类可读的字符串。 +// +// 参数: +// - size: 字节数 +// +// 返回值: +// - string: 格式化后的字符串,如 "1mb", "10kb" 等 +func FormatSize(size int64) string { + const ( + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + ) + + switch { + case size >= GB: + return fmt.Sprintf("%.2fgb", float64(size)/GB) + case size >= MB: + return fmt.Sprintf("%.2fmb", float64(size)/MB) + case size >= KB: + return fmt.Sprintf("%.2fkb", float64(size)/KB) + default: + return fmt.Sprintf("%db", size) + } +} diff --git a/internal/middleware/bodylimit/bodylimit_test.go b/internal/middleware/bodylimit/bodylimit_test.go new file mode 100644 index 0000000..354365d --- /dev/null +++ b/internal/middleware/bodylimit/bodylimit_test.go @@ -0,0 +1,276 @@ +// Package bodylimit 提供请求体大小限制中间件的测试。 +// +// 作者:xfy +package bodylimit + +import ( + "bytes" + "io" + "strings" + "testing" + + "github.com/valyala/fasthttp" +) + +// TestParseSize 测试大小字符串解析。 +func TestParseSize(t *testing.T) { + tests := []struct { + name string + input string + expected int64 + wantErr bool + }{ + {"empty string", "", 1 << 20, false}, // 默认 1MB + {"plain bytes", "1024", 1024, false}, // 纯数字 + {"with b", "2048b", 2048, false}, // 带 b 单位 + {"kilobytes", "10kb", 10 * 1024, false}, // KB + {"megabytes", "1mb", 1024 * 1024, false}, // MB + {"gigabytes", "1gb", 1024 * 1024 * 1024, false}, // GB + {"uppercase", "1MB", 1024 * 1024, false}, // 大写 + {"with spaces", " 1mb ", 1024 * 1024, false}, // 带空格 + {"decimal", "1.5mb", int64(1.5 * 1024 * 1024), false}, // 小数 + {"invalid unit", "1xx", 0, true}, // 无效单位 + {"no number", "mb", 0, true}, // 无数字 + {"negative", "-1mb", 0, true}, // 负数(ParseFloat 可能成功但结果为负) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseSize(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseSize(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.expected { + t.Errorf("ParseSize(%q) = %d, want %d", tt.input, got, tt.expected) + } + }) + } +} + +// TestFormatSize 测试字节数格式化。 +func TestFormatSize(t *testing.T) { + tests := []struct { + input int64 + expected string + }{ + {512, "512b"}, + {1024, "1.00kb"}, + {1024 * 1024, "1.00mb"}, + {1024 * 1024 * 1024, "1.00gb"}, + {1536, "1.50kb"}, + } + + for _, tt := range tests { + t.Run(FormatSize(tt.input), func(t *testing.T) { + got := FormatSize(tt.input) + if got != tt.expected { + t.Errorf("FormatSize(%d) = %s, want %s", tt.input, got, tt.expected) + } + }) + } +} + +// TestNew 测试创建中间件。 +func TestNew(t *testing.T) { + tests := []struct { + name string + maxSize string + wantErr bool + expected int64 + }{ + {"valid 1mb", "1mb", false, 1024 * 1024}, + {"valid 10kb", "10kb", false, 10 * 1024}, + {"invalid", "invalid", true, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bl, err := New(tt.maxSize) + if (err != nil) != tt.wantErr { + t.Errorf("New(%q) error = %v, wantErr %v", tt.maxSize, err, tt.wantErr) + return + } + if !tt.wantErr && bl.maxBodySize != tt.expected { + t.Errorf("New(%q).maxBodySize = %d, want %d", tt.maxSize, bl.maxBodySize, tt.expected) + } + }) + } +} + +// TestBodyLimit_Process 测试中间件处理。 +func TestBodyLimit_Process(t *testing.T) { + bl, err := New("100b") + if err != nil { + t.Fatalf("创建中间件失败: %v", err) + } + + tests := []struct { + name string + body string + contentLength int + expectedStatus int + }{ + { + name: "small body within limit", + body: "small body", + contentLength: 10, + expectedStatus: fasthttp.StatusOK, + }, + { + name: "body exactly at limit", + body: strings.Repeat("a", 100), + contentLength: 100, + expectedStatus: fasthttp.StatusOK, + }, + { + name: "body exceeds limit via content-length", + body: strings.Repeat("a", 200), + contentLength: 200, + expectedStatus: fasthttp.StatusRequestEntityTooLarge, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nextHandler := func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := bl.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod("POST") + ctx.Request.Header.SetContentLength(tt.contentLength) + ctx.Request.SetBodyStream(bytes.NewReader([]byte(tt.body)), tt.contentLength) + + handler(ctx) + + if ctx.Response.StatusCode() != tt.expectedStatus { + t.Errorf("status code = %d, want %d", ctx.Response.StatusCode(), tt.expectedStatus) + } + }) + } +} + +// TestBodyLimit_ProcessChunked 测试 chunked 传输编码。 +func TestBodyLimit_ProcessChunked(t *testing.T) { + bl, err := New("100b") + if err != nil { + t.Fatalf("创建中间件失败: %v", err) + } + + // 测试 chunked 传输无法绕过限制 + body := strings.Repeat("a", 150) // 150 字节超过 100 字节限制 + reader := &slowReader{data: []byte(body), chunkSize: 10} + + nextHandler := func(ctx *fasthttp.RequestCtx) { + // 尝试读取完整请求体 + body, err := io.ReadAll(ctx.Request.BodyStream()) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusRequestEntityTooLarge) + return + } + if len(body) > 100 { + ctx.SetStatusCode(fasthttp.StatusRequestEntityTooLarge) + return + } + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := bl.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod("POST") + // 不设置 Content-Length 模拟 chunked 传输 + ctx.Request.SetBodyStream(reader, -1) + + handler(ctx) + + // 应该返回 413,因为 body 超过限制 + if ctx.Response.StatusCode() != fasthttp.StatusRequestEntityTooLarge { + t.Errorf("chunked 传输绕过测试失败: status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusRequestEntityTooLarge) + } +} + +// TestBodyLimit_PathLimits 测试路径级别配置。 +func TestBodyLimit_PathLimits(t *testing.T) { + bl, err := New("1mb") + if err != nil { + t.Fatalf("创建中间件失败: %v", err) + } + + // 添加路径级别配置 + if err := bl.AddPathLimit("/api/upload", "10mb"); err != nil { + t.Fatalf("添加路径限制失败: %v", err) + } + if err := bl.AddPathLimit("/api", "2mb"); err != nil { + t.Fatalf("添加路径限制失败: %v", err) + } + + tests := []struct { + path string + expected int64 + }{ + {"/api/upload/file", 10 * 1024 * 1024}, // 最长匹配 /api/upload + {"/api/users", 2 * 1024 * 1024}, // 匹配 /api + {"/other/path", 1 * 1024 * 1024}, // 默认限制 + {"/api", 2 * 1024 * 1024}, // 匹配 /api + {"/apix", 2 * 1024 * 1024}, // 匹配 /api(前缀匹配) + {"/apiupload", 2 * 1024 * 1024}, // 匹配 /api(前缀匹配) + {"/notapi", 1 * 1024 * 1024}, // 不匹配 /api + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := bl.GetLimit(tt.path) + if got != tt.expected { + t.Errorf("GetLimit(%q) = %d, want %d", tt.path, got, tt.expected) + } + }) + } +} + +// TestBodyLimit_Name 测试中间件名称。 +func TestBodyLimit_Name(t *testing.T) { + bl := NewWithDefault() + if bl.Name() != "BodyLimit" { + t.Errorf("Name() = %s, want BodyLimit", bl.Name()) + } +} + +// TestBodyLimit_DefaultMaxBodySize 测试默认大小。 +func TestBodyLimit_DefaultMaxBodySize(t *testing.T) { + if DefaultMaxBodySize != 1<<20 { + t.Errorf("DefaultMaxBodySize = %d, want %d", DefaultMaxBodySize, 1<<20) + } +} + +// slowReader 模拟慢速读取,用于测试 chunked 传输。 +type slowReader struct { + data []byte + pos int + chunkSize int +} + +func (r *slowReader) Read(p []byte) (n int, err error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + + // 每次只读取 chunkSize 字节 + remaining := len(r.data) - r.pos + toRead := r.chunkSize + if toRead > remaining { + toRead = remaining + } + if toRead > len(p) { + toRead = len(p) + } + + n = toRead + copy(p, r.data[r.pos:r.pos+toRead]) + r.pos += toRead + + return n, nil +}