feat(middleware,bodylimit): 新增请求体大小限制中间件

- 支持 client_max_body_size 配置,单位支持 b/kb/mb/gb
- 支持全局配置和路径级别覆盖
- 超限返回 413 Request Entity Too Large
- 使用 io.LimitReader 强制限制实际读取字节数

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-07 17:49:47 +08:00
parent f6245c19e0
commit deb9b3139a
3 changed files with 615 additions and 0 deletions

View File

@ -0,0 +1,40 @@
<!-- Parent: ../AGENTS.md -->
<!-- Generated: 2026-04-07 | Updated: 2026-04-07 -->
# bodylimit
## Purpose
HTTP 请求体大小限制中间件,防止客户端通过发送超大请求体导致服务器资源耗尽,支持全局配置和路径级别覆盖。
## Key Files
| File | Description |
|------|-------------|
| `bodylimit.go` | 请求体限制中间件实现,支持大小解析、路径级别配置 |
| `bodylimit_test.go` | 中间件单元测试 |
## For AI Agents
### Working In This Directory
- 使用 `io.LimitReader` 强制限制实际读取的字节数
- 支持路径级别配置覆盖全局配置(最长匹配优先)
- 大小字符串解析支持 b、kb、mb、gb 单位(不区分大小写)
- 超限返回 413 Request Entity Too Large
### Testing Requirements
- 运行测试:`go test ./internal/middleware/bodylimit/...`
- 测试覆盖:大小解析、路径匹配、超限处理
### Common Patterns
- 创建中间件:`bodylimit.New("10mb")`
- 添加路径配置:`bl.AddPathLimit("/upload", "100mb")`
- 获取路径限制:`bl.GetLimit(path)`
- 解析大小:`ParseSize("1kb")` → 1024
- 格式化大小:`FormatSize(1024)` → "1kb"
## Dependencies
### External
- `github.com/valyala/fasthttp` - HTTP 框架
<!-- MANUAL: -->

View File

@ -0,0 +1,299 @@
// Package bodylimit 提供 HTTP 请求体大小限制的中间件。
//
// 该文件包含请求体大小限制相关的核心功能,包括:
// - BodyLimit 中间件:限制请求体大小
// - 解析大小字符串:支持 b, kb, mb, gb 等单位
// - 路径级别的覆盖配置
//
// 主要用途:
//
// 防止客户端通过发送超大请求体或 chunked 传输绕过限制导致服务器资源耗尽。
//
// 注意事项:
// - 使用 io.LimitReader 强制限制实际读取的字节数
// - 支持路径级别配置覆盖全局配置
// - 超限返回 413 Request Entity Too Large
//
// 作者xfy
package bodylimit
import (
"fmt"
"strconv"
"strings"
"sync"
"github.com/valyala/fasthttp"
)
// DefaultMaxBodySize 默认请求体大小限制为 1MB。
const DefaultMaxBodySize = 1 << 20 // 1MB
// BodyLimit 请求体大小限制中间件。
//
// 限制请求体的最大字节数,超过限制的请求将被拒绝并返回 413 错误。
// 支持全局配置和路径级别的覆盖配置。
type BodyLimit struct {
// maxBodySize 全局请求体大小限制(字节)
maxBodySize int64
// pathLimits 路径级别的限制配置
// key 为路径前缀value 为该路径的限制大小
pathLimits map[string]int64
// pathLimitsMu 保护 pathLimits 的互斥锁
pathLimitsMu sync.RWMutex
}
// New 创建请求体大小限制中间件。
//
// 参数:
// - maxBodySize: 最大请求体大小字符串,如 "1mb", "10kb" 等
//
// 返回值:
// - *BodyLimit: 创建的中间件实例
// - error: 解析大小字符串失败时的错误
func New(maxBodySize string) (*BodyLimit, error) {
size, err := ParseSize(maxBodySize)
if err != nil {
return nil, fmt.Errorf("解析 client_max_body_size 失败: %w", err)
}
return &BodyLimit{
maxBodySize: size,
pathLimits: make(map[string]int64),
}, nil
}
// NewWithDefault 使用默认限制1MB创建中间件。
//
// 返回值:
// - *BodyLimit: 创建的中间件实例
func NewWithDefault() *BodyLimit {
return &BodyLimit{
maxBodySize: DefaultMaxBodySize,
pathLimits: make(map[string]int64),
}
}
// Name 返回中间件名称。
//
// 返回值:
// - string: 中间件名称
func (bl *BodyLimit) Name() string {
return "BodyLimit"
}
// AddPathLimit 添加路径级别的限制配置。
//
// 参数:
// - path: 路径前缀
// - sizeStr: 大小字符串,如 "1mb", "10kb" 等
//
// 返回值:
// - error: 解析大小字符串失败时的错误
func (bl *BodyLimit) AddPathLimit(path, sizeStr string) error {
size, err := ParseSize(sizeStr)
if err != nil {
return fmt.Errorf("解析路径 %s 的 client_max_body_size 失败: %w", path, err)
}
bl.pathLimitsMu.Lock()
bl.pathLimits[path] = size
bl.pathLimitsMu.Unlock()
return nil
}
// GetLimit 获取指定路径的请求体限制。
//
// 优先使用路径级别配置,如无则使用全局配置。
//
// 参数:
// - path: 请求路径
//
// 返回值:
// - int64: 该路径的最大请求体大小(字节)
func (bl *BodyLimit) GetLimit(path string) int64 {
bl.pathLimitsMu.RLock()
defer bl.pathLimitsMu.RUnlock()
// 查找匹配的路径配置(最长匹配优先)
var matchedLimit int64
var matchedPath string
var matched bool
for prefix, limit := range bl.pathLimits {
if strings.HasPrefix(path, prefix) {
// 选择最长的匹配路径
if !matched || len(prefix) > len(matchedPath) {
matchedLimit = limit
matchedPath = prefix
matched = true
}
}
}
if matched {
return matchedLimit
}
return bl.maxBodySize
}
// Process 实现中间件接口。
//
// 检查请求体大小是否超过限制,超限返回 413 错误。
//
// 参数:
// - next: 下一个请求处理器
//
// 返回值:
// - fasthttp.RequestHandler: 包装后的请求处理器
func (bl *BodyLimit) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
limit := bl.GetLimit(path)
// 检查 Content-Length 头
contentLength := ctx.Request.Header.ContentLength()
if contentLength > 0 && int64(contentLength) > limit {
ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge)
return
}
// 对于 chunked 传输或没有 Content-Length 的请求
// 设置最大读取限制
ctx.Request.SetBodyStream(ctx.Request.BodyStream(), int(limit))
// 包装请求体读取以检测超限
limitedReader := &limitedBodyReader{
ctx: ctx,
limit: limit,
original: ctx.Request.BodyStream(),
}
ctx.Request.SetBodyStream(limitedReader, -1)
next(ctx)
}
}
// limitedBodyReader 包装请求体读取器以限制最大读取字节数。
type limitedBodyReader struct {
ctx *fasthttp.RequestCtx
limit int64
original interface {
Read(p []byte) (n int, err error)
}
read int64
done bool
}
// Read 实现读取接口,在超过限制时返回错误。
func (l *limitedBodyReader) Read(p []byte) (n int, err error) {
if l.done {
return 0, fmt.Errorf("request body too large")
}
// 计算还能读取多少字节
remaining := l.limit - l.read
if remaining <= 0 {
l.done = true
// 返回 413 错误
l.ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge)
return 0, fmt.Errorf("request body exceeds limit of %d bytes", l.limit)
}
// 限制读取长度
if int64(len(p)) > remaining {
p = p[:remaining]
}
n, err = l.original.Read(p)
l.read += int64(n)
return n, err
}
// ParseSize 解析大小字符串为字节数。
//
// 支持的单位b, kb, mb, gb不区分大小写
// 无单位时默认为字节。
//
// 参数:
// - sizeStr: 大小字符串,如 "1mb", "10kb", "1024" 等
//
// 返回值:
// - int64: 字节数
// - error: 解析失败时的错误
func ParseSize(sizeStr string) (int64, error) {
if sizeStr == "" {
return DefaultMaxBodySize, nil
}
sizeStr = strings.TrimSpace(strings.ToLower(sizeStr))
// 解析数值和单位
var numStr string
var unit string
for i, c := range sizeStr {
if c >= '0' && c <= '9' || c == '.' {
numStr = sizeStr[:i+1]
} else {
unit = sizeStr[i:]
break
}
}
if numStr == "" {
return 0, fmt.Errorf("无效的大小格式: %s", sizeStr)
}
value, err := strconv.ParseFloat(numStr, 64)
if err != nil {
return 0, fmt.Errorf("解析数值失败: %w", err)
}
var multiplier float64
switch unit {
case "", "b":
multiplier = 1
case "kb":
multiplier = 1024
case "mb":
multiplier = 1024 * 1024
case "gb":
multiplier = 1024 * 1024 * 1024
default:
return 0, fmt.Errorf("不支持的大小单位: %s", unit)
}
return int64(value * multiplier), nil
}
// FormatSize 将字节数格式化为人类可读的字符串。
//
// 参数:
// - size: 字节数
//
// 返回值:
// - string: 格式化后的字符串,如 "1mb", "10kb" 等
func FormatSize(size int64) string {
const (
KB = 1024
MB = 1024 * KB
GB = 1024 * MB
)
switch {
case size >= GB:
return fmt.Sprintf("%.2fgb", float64(size)/GB)
case size >= MB:
return fmt.Sprintf("%.2fmb", float64(size)/MB)
case size >= KB:
return fmt.Sprintf("%.2fkb", float64(size)/KB)
default:
return fmt.Sprintf("%db", size)
}
}

View File

@ -0,0 +1,276 @@
// Package bodylimit 提供请求体大小限制中间件的测试。
//
// 作者xfy
package bodylimit
import (
"bytes"
"io"
"strings"
"testing"
"github.com/valyala/fasthttp"
)
// TestParseSize 测试大小字符串解析。
func TestParseSize(t *testing.T) {
tests := []struct {
name string
input string
expected int64
wantErr bool
}{
{"empty string", "", 1 << 20, false}, // 默认 1MB
{"plain bytes", "1024", 1024, false}, // 纯数字
{"with b", "2048b", 2048, false}, // 带 b 单位
{"kilobytes", "10kb", 10 * 1024, false}, // KB
{"megabytes", "1mb", 1024 * 1024, false}, // MB
{"gigabytes", "1gb", 1024 * 1024 * 1024, false}, // GB
{"uppercase", "1MB", 1024 * 1024, false}, // 大写
{"with spaces", " 1mb ", 1024 * 1024, false}, // 带空格
{"decimal", "1.5mb", int64(1.5 * 1024 * 1024), false}, // 小数
{"invalid unit", "1xx", 0, true}, // 无效单位
{"no number", "mb", 0, true}, // 无数字
{"negative", "-1mb", 0, true}, // 负数ParseFloat 可能成功但结果为负)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseSize(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("ParseSize(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if !tt.wantErr && got != tt.expected {
t.Errorf("ParseSize(%q) = %d, want %d", tt.input, got, tt.expected)
}
})
}
}
// TestFormatSize 测试字节数格式化。
func TestFormatSize(t *testing.T) {
tests := []struct {
input int64
expected string
}{
{512, "512b"},
{1024, "1.00kb"},
{1024 * 1024, "1.00mb"},
{1024 * 1024 * 1024, "1.00gb"},
{1536, "1.50kb"},
}
for _, tt := range tests {
t.Run(FormatSize(tt.input), func(t *testing.T) {
got := FormatSize(tt.input)
if got != tt.expected {
t.Errorf("FormatSize(%d) = %s, want %s", tt.input, got, tt.expected)
}
})
}
}
// TestNew 测试创建中间件。
func TestNew(t *testing.T) {
tests := []struct {
name string
maxSize string
wantErr bool
expected int64
}{
{"valid 1mb", "1mb", false, 1024 * 1024},
{"valid 10kb", "10kb", false, 10 * 1024},
{"invalid", "invalid", true, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bl, err := New(tt.maxSize)
if (err != nil) != tt.wantErr {
t.Errorf("New(%q) error = %v, wantErr %v", tt.maxSize, err, tt.wantErr)
return
}
if !tt.wantErr && bl.maxBodySize != tt.expected {
t.Errorf("New(%q).maxBodySize = %d, want %d", tt.maxSize, bl.maxBodySize, tt.expected)
}
})
}
}
// TestBodyLimit_Process 测试中间件处理。
func TestBodyLimit_Process(t *testing.T) {
bl, err := New("100b")
if err != nil {
t.Fatalf("创建中间件失败: %v", err)
}
tests := []struct {
name string
body string
contentLength int
expectedStatus int
}{
{
name: "small body within limit",
body: "small body",
contentLength: 10,
expectedStatus: fasthttp.StatusOK,
},
{
name: "body exactly at limit",
body: strings.Repeat("a", 100),
contentLength: 100,
expectedStatus: fasthttp.StatusOK,
},
{
name: "body exceeds limit via content-length",
body: strings.Repeat("a", 200),
contentLength: 200,
expectedStatus: fasthttp.StatusRequestEntityTooLarge,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nextHandler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusOK)
}
handler := bl.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("POST")
ctx.Request.Header.SetContentLength(tt.contentLength)
ctx.Request.SetBodyStream(bytes.NewReader([]byte(tt.body)), tt.contentLength)
handler(ctx)
if ctx.Response.StatusCode() != tt.expectedStatus {
t.Errorf("status code = %d, want %d", ctx.Response.StatusCode(), tt.expectedStatus)
}
})
}
}
// TestBodyLimit_ProcessChunked 测试 chunked 传输编码。
func TestBodyLimit_ProcessChunked(t *testing.T) {
bl, err := New("100b")
if err != nil {
t.Fatalf("创建中间件失败: %v", err)
}
// 测试 chunked 传输无法绕过限制
body := strings.Repeat("a", 150) // 150 字节超过 100 字节限制
reader := &slowReader{data: []byte(body), chunkSize: 10}
nextHandler := func(ctx *fasthttp.RequestCtx) {
// 尝试读取完整请求体
body, err := io.ReadAll(ctx.Request.BodyStream())
if err != nil {
ctx.SetStatusCode(fasthttp.StatusRequestEntityTooLarge)
return
}
if len(body) > 100 {
ctx.SetStatusCode(fasthttp.StatusRequestEntityTooLarge)
return
}
ctx.SetStatusCode(fasthttp.StatusOK)
}
handler := bl.Process(nextHandler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("POST")
// 不设置 Content-Length 模拟 chunked 传输
ctx.Request.SetBodyStream(reader, -1)
handler(ctx)
// 应该返回 413因为 body 超过限制
if ctx.Response.StatusCode() != fasthttp.StatusRequestEntityTooLarge {
t.Errorf("chunked 传输绕过测试失败: status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusRequestEntityTooLarge)
}
}
// TestBodyLimit_PathLimits 测试路径级别配置。
func TestBodyLimit_PathLimits(t *testing.T) {
bl, err := New("1mb")
if err != nil {
t.Fatalf("创建中间件失败: %v", err)
}
// 添加路径级别配置
if err := bl.AddPathLimit("/api/upload", "10mb"); err != nil {
t.Fatalf("添加路径限制失败: %v", err)
}
if err := bl.AddPathLimit("/api", "2mb"); err != nil {
t.Fatalf("添加路径限制失败: %v", err)
}
tests := []struct {
path string
expected int64
}{
{"/api/upload/file", 10 * 1024 * 1024}, // 最长匹配 /api/upload
{"/api/users", 2 * 1024 * 1024}, // 匹配 /api
{"/other/path", 1 * 1024 * 1024}, // 默认限制
{"/api", 2 * 1024 * 1024}, // 匹配 /api
{"/apix", 2 * 1024 * 1024}, // 匹配 /api前缀匹配
{"/apiupload", 2 * 1024 * 1024}, // 匹配 /api前缀匹配
{"/notapi", 1 * 1024 * 1024}, // 不匹配 /api
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
got := bl.GetLimit(tt.path)
if got != tt.expected {
t.Errorf("GetLimit(%q) = %d, want %d", tt.path, got, tt.expected)
}
})
}
}
// TestBodyLimit_Name 测试中间件名称。
func TestBodyLimit_Name(t *testing.T) {
bl := NewWithDefault()
if bl.Name() != "BodyLimit" {
t.Errorf("Name() = %s, want BodyLimit", bl.Name())
}
}
// TestBodyLimit_DefaultMaxBodySize 测试默认大小。
func TestBodyLimit_DefaultMaxBodySize(t *testing.T) {
if DefaultMaxBodySize != 1<<20 {
t.Errorf("DefaultMaxBodySize = %d, want %d", DefaultMaxBodySize, 1<<20)
}
}
// slowReader 模拟慢速读取,用于测试 chunked 传输。
type slowReader struct {
data []byte
pos int
chunkSize int
}
func (r *slowReader) Read(p []byte) (n int, err error) {
if r.pos >= len(r.data) {
return 0, io.EOF
}
// 每次只读取 chunkSize 字节
remaining := len(r.data) - r.pos
toRead := r.chunkSize
if toRead > remaining {
toRead = remaining
}
if toRead > len(p) {
toRead = len(p)
}
n = toRead
copy(p, r.data[r.pos:r.pos+toRead])
r.pos += toRead
return n, nil
}