278 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package bodylimit 提供 HTTP 请求体大小限制的中间件。
//
// 该文件包含请求体大小限制相关的核心功能,包括:
// - BodyLimit 中间件:限制请求体大小
// - 解析大小字符串:支持 b, kb, mb, gb 等单位
// - 路径级别的覆盖配置
//
// 主要用途:
//
// 防止客户端通过发送超大请求体或 chunked 传输绕过限制导致服务器资源耗尽。
//
// 注意事项:
// - 使用 io.LimitReader 强制限制实际读取的字节数
// - 支持路径级别配置覆盖全局配置
// - 超限返回 413 Request Entity Too Large
//
// 作者xfy
package bodylimit
import (
"fmt"
"strconv"
"strings"
"sync"
"github.com/valyala/fasthttp"
)
// DefaultMaxBodySize 默认请求体大小限制为 1MB。
const DefaultMaxBodySize = 1 << 20 // 1MB
// BodyLimit 请求体大小限制中间件。
//
// 限制请求体的最大字节数,超过限制的请求将被拒绝并返回 413 错误。
// 支持全局配置和路径级别的覆盖配置。
type BodyLimit struct {
pathLimits map[string]int64
maxBodySize int64
pathLimitsMu sync.RWMutex
}
// New 创建请求体大小限制中间件。
//
// 参数:
// - maxBodySize: 最大请求体大小字符串,如 "1mb", "10kb" 等
//
// 返回值:
// - *BodyLimit: 创建的中间件实例
// - error: 解析大小字符串失败时的错误
func New(maxBodySize string) (*BodyLimit, error) {
size, err := ParseSize(maxBodySize)
if err != nil {
return nil, fmt.Errorf("解析 client_max_body_size 失败: %w", err)
}
return &BodyLimit{
maxBodySize: size,
pathLimits: make(map[string]int64),
}, nil
}
// NewWithDefault 使用默认限制1MB创建中间件。
//
// 返回值:
// - *BodyLimit: 创建的中间件实例
func NewWithDefault() *BodyLimit {
return &BodyLimit{
maxBodySize: DefaultMaxBodySize,
pathLimits: make(map[string]int64),
}
}
// Name 返回中间件名称。
//
// 返回值:
// - string: 中间件名称
func (bl *BodyLimit) Name() string {
return "BodyLimit"
}
// AddPathLimit 添加路径级别的限制配置。
//
// 参数:
// - path: 路径前缀
// - sizeStr: 大小字符串,如 "1mb", "10kb" 等
//
// 返回值:
// - error: 解析大小字符串失败时的错误
func (bl *BodyLimit) AddPathLimit(path, sizeStr string) error {
size, err := ParseSize(sizeStr)
if err != nil {
return fmt.Errorf("解析路径 %s 的 client_max_body_size 失败: %w", path, err)
}
bl.pathLimitsMu.Lock()
bl.pathLimits[path] = size
bl.pathLimitsMu.Unlock()
return nil
}
// GetLimit 获取指定路径的请求体限制。
//
// 优先使用路径级别配置,如无则使用全局配置。
//
// 参数:
// - path: 请求路径
//
// 返回值:
// - int64: 该路径的最大请求体大小(字节)
func (bl *BodyLimit) GetLimit(path string) int64 {
bl.pathLimitsMu.RLock()
defer bl.pathLimitsMu.RUnlock()
// 查找匹配的路径配置(最长匹配优先)
var matchedLimit int64
var matchedPath string
var matched bool
for prefix, limit := range bl.pathLimits {
if strings.HasPrefix(path, prefix) {
// 选择最长的匹配路径
if !matched || len(prefix) > len(matchedPath) {
matchedLimit = limit
matchedPath = prefix
matched = true
}
}
}
if matched {
return matchedLimit
}
return bl.maxBodySize
}
// Process 实现中间件接口。
//
// 检查请求体大小是否超过限制,超限返回 413 错误。
//
// 参数:
// - next: 下一个请求处理器
//
// 返回值:
// - fasthttp.RequestHandler: 包装后的请求处理器
func (bl *BodyLimit) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
path := string(ctx.Path())
limit := bl.GetLimit(path)
// 检查 Content-Length 头
contentLength := ctx.Request.Header.ContentLength()
if contentLength > 0 && int64(contentLength) > limit {
ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge)
return
}
// 对于 chunked 传输或没有 Content-Length 的请求
// 只有当 BodyStream 存在时才设置限制
bodyStream := ctx.Request.BodyStream()
if bodyStream != nil {
ctx.Request.SetBodyStream(bodyStream, int(limit))
// 包装请求体读取以检测超限
limitedReader := &limitedBodyReader{
ctx: ctx,
limit: limit,
original: bodyStream,
}
ctx.Request.SetBodyStream(limitedReader, -1)
}
next(ctx)
}
}
// limitedBodyReader 包装请求体读取器以限制最大读取字节数。
//
// 当读取的字节数超过限制时,返回错误并在上下文中设置 413 状态码。
type limitedBodyReader struct {
// original 原始读取器
original interface {
Read(p []byte) (n int, err error)
}
// ctx 请求上下文,用于设置错误响应
ctx *fasthttp.RequestCtx
// limit 最大允许读取的字节数
limit int64
// read 已读取的字节数
read int64
// done 是否已达到限制
done bool
}
// Read 实现读取接口,在超过限制时返回错误。
func (l *limitedBodyReader) Read(p []byte) (n int, err error) {
if l.done {
return 0, fmt.Errorf("request body too large")
}
// 计算还能读取多少字节
remaining := l.limit - l.read
if remaining <= 0 {
l.done = true
// 返回 413 错误
l.ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge)
return 0, fmt.Errorf("request body exceeds limit of %d bytes", l.limit)
}
// 限制读取长度
if int64(len(p)) > remaining {
p = p[:remaining]
}
n, err = l.original.Read(p)
l.read += int64(n)
return n, err
}
// ParseSize 解析大小字符串为字节数。
//
// 支持的单位b, kb, mb, gb不区分大小写
// 无单位时默认为字节。
//
// 参数:
// - sizeStr: 大小字符串,如 "1mb", "10kb", "1024" 等
//
// 返回值:
// - int64: 字节数
// - error: 解析失败时的错误
func ParseSize(sizeStr string) (int64, error) {
if sizeStr == "" {
return DefaultMaxBodySize, nil
}
sizeStr = strings.TrimSpace(strings.ToLower(sizeStr))
// 解析数值和单位
var numStr string
var unit string
for i, c := range sizeStr {
if c >= '0' && c <= '9' || c == '.' {
numStr = sizeStr[:i+1]
} else {
unit = sizeStr[i:]
break
}
}
if numStr == "" {
return 0, fmt.Errorf("无效的大小格式: %s", sizeStr)
}
value, err := strconv.ParseFloat(numStr, 64)
if err != nil {
return 0, fmt.Errorf("解析数值失败: %w", err)
}
var multiplier float64
switch unit {
case "", "b":
multiplier = 1
case "kb":
multiplier = 1024
case "mb":
multiplier = 1024 * 1024
case "gb":
multiplier = 1024 * 1024 * 1024
default:
return 0, fmt.Errorf("不支持的大小单位: %s", unit)
}
return int64(value * multiplier), nil
}