278 lines
6.7 KiB
Go
278 lines
6.7 KiB
Go
// Package bodylimit 提供 HTTP 请求体大小限制的中间件。
|
||
//
|
||
// 该文件包含请求体大小限制相关的核心功能,包括:
|
||
// - BodyLimit 中间件:限制请求体大小
|
||
// - 解析大小字符串:支持 b, kb, mb, gb 等单位
|
||
// - 路径级别的覆盖配置
|
||
//
|
||
// 主要用途:
|
||
//
|
||
// 防止客户端通过发送超大请求体或 chunked 传输绕过限制导致服务器资源耗尽。
|
||
//
|
||
// 注意事项:
|
||
// - 使用 io.LimitReader 强制限制实际读取的字节数
|
||
// - 支持路径级别配置覆盖全局配置
|
||
// - 超限返回 413 Request Entity Too Large
|
||
//
|
||
// 作者:xfy
|
||
package bodylimit
|
||
|
||
import (
|
||
"fmt"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
|
||
"github.com/valyala/fasthttp"
|
||
)
|
||
|
||
// DefaultMaxBodySize 默认请求体大小限制为 1MB。
|
||
const DefaultMaxBodySize = 1 << 20 // 1MB
|
||
|
||
// BodyLimit 请求体大小限制中间件。
|
||
//
|
||
// 限制请求体的最大字节数,超过限制的请求将被拒绝并返回 413 错误。
|
||
// 支持全局配置和路径级别的覆盖配置。
|
||
type BodyLimit struct {
|
||
pathLimits map[string]int64
|
||
maxBodySize int64
|
||
pathLimitsMu sync.RWMutex
|
||
}
|
||
|
||
// New 创建请求体大小限制中间件。
|
||
//
|
||
// 参数:
|
||
// - maxBodySize: 最大请求体大小字符串,如 "1mb", "10kb" 等
|
||
//
|
||
// 返回值:
|
||
// - *BodyLimit: 创建的中间件实例
|
||
// - error: 解析大小字符串失败时的错误
|
||
func New(maxBodySize string) (*BodyLimit, error) {
|
||
size, err := ParseSize(maxBodySize)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("解析 client_max_body_size 失败: %w", err)
|
||
}
|
||
|
||
return &BodyLimit{
|
||
maxBodySize: size,
|
||
pathLimits: make(map[string]int64),
|
||
}, nil
|
||
}
|
||
|
||
// NewWithDefault 使用默认限制(1MB)创建中间件。
|
||
//
|
||
// 返回值:
|
||
// - *BodyLimit: 创建的中间件实例
|
||
func NewWithDefault() *BodyLimit {
|
||
return &BodyLimit{
|
||
maxBodySize: DefaultMaxBodySize,
|
||
pathLimits: make(map[string]int64),
|
||
}
|
||
}
|
||
|
||
// Name 返回中间件名称。
|
||
//
|
||
// 返回值:
|
||
// - string: 中间件名称
|
||
func (bl *BodyLimit) Name() string {
|
||
return "BodyLimit"
|
||
}
|
||
|
||
// AddPathLimit 添加路径级别的限制配置。
|
||
//
|
||
// 参数:
|
||
// - path: 路径前缀
|
||
// - sizeStr: 大小字符串,如 "1mb", "10kb" 等
|
||
//
|
||
// 返回值:
|
||
// - error: 解析大小字符串失败时的错误
|
||
func (bl *BodyLimit) AddPathLimit(path, sizeStr string) error {
|
||
size, err := ParseSize(sizeStr)
|
||
if err != nil {
|
||
return fmt.Errorf("解析路径 %s 的 client_max_body_size 失败: %w", path, err)
|
||
}
|
||
|
||
bl.pathLimitsMu.Lock()
|
||
bl.pathLimits[path] = size
|
||
bl.pathLimitsMu.Unlock()
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetLimit 获取指定路径的请求体限制。
|
||
//
|
||
// 优先使用路径级别配置,如无则使用全局配置。
|
||
//
|
||
// 参数:
|
||
// - path: 请求路径
|
||
//
|
||
// 返回值:
|
||
// - int64: 该路径的最大请求体大小(字节)
|
||
func (bl *BodyLimit) GetLimit(path string) int64 {
|
||
bl.pathLimitsMu.RLock()
|
||
defer bl.pathLimitsMu.RUnlock()
|
||
|
||
// 查找匹配的路径配置(最长匹配优先)
|
||
var matchedLimit int64
|
||
var matchedPath string
|
||
var matched bool
|
||
|
||
for prefix, limit := range bl.pathLimits {
|
||
if strings.HasPrefix(path, prefix) {
|
||
// 选择最长的匹配路径
|
||
if !matched || len(prefix) > len(matchedPath) {
|
||
matchedLimit = limit
|
||
matchedPath = prefix
|
||
matched = true
|
||
}
|
||
}
|
||
}
|
||
|
||
if matched {
|
||
return matchedLimit
|
||
}
|
||
|
||
return bl.maxBodySize
|
||
}
|
||
|
||
// Process 实现中间件接口。
|
||
//
|
||
// 检查请求体大小是否超过限制,超限返回 413 错误。
|
||
//
|
||
// 参数:
|
||
// - next: 下一个请求处理器
|
||
//
|
||
// 返回值:
|
||
// - fasthttp.RequestHandler: 包装后的请求处理器
|
||
func (bl *BodyLimit) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||
return func(ctx *fasthttp.RequestCtx) {
|
||
path := string(ctx.Path())
|
||
limit := bl.GetLimit(path)
|
||
|
||
// 检查 Content-Length 头
|
||
contentLength := ctx.Request.Header.ContentLength()
|
||
if contentLength > 0 && int64(contentLength) > limit {
|
||
ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge)
|
||
return
|
||
}
|
||
|
||
// 对于 chunked 传输或没有 Content-Length 的请求
|
||
// 只有当 BodyStream 存在时才设置限制
|
||
bodyStream := ctx.Request.BodyStream()
|
||
if bodyStream != nil {
|
||
ctx.Request.SetBodyStream(bodyStream, int(limit))
|
||
|
||
// 包装请求体读取以检测超限
|
||
limitedReader := &limitedBodyReader{
|
||
ctx: ctx,
|
||
limit: limit,
|
||
original: bodyStream,
|
||
}
|
||
ctx.Request.SetBodyStream(limitedReader, -1)
|
||
}
|
||
|
||
next(ctx)
|
||
}
|
||
}
|
||
|
||
// limitedBodyReader 包装请求体读取器以限制最大读取字节数。
|
||
//
|
||
// 当读取的字节数超过限制时,返回错误并在上下文中设置 413 状态码。
|
||
type limitedBodyReader struct {
|
||
// original 原始读取器
|
||
original interface {
|
||
Read(p []byte) (n int, err error)
|
||
}
|
||
// ctx 请求上下文,用于设置错误响应
|
||
ctx *fasthttp.RequestCtx
|
||
// limit 最大允许读取的字节数
|
||
limit int64
|
||
// read 已读取的字节数
|
||
read int64
|
||
// done 是否已达到限制
|
||
done bool
|
||
}
|
||
|
||
// Read 实现读取接口,在超过限制时返回错误。
|
||
func (l *limitedBodyReader) Read(p []byte) (n int, err error) {
|
||
if l.done {
|
||
return 0, fmt.Errorf("request body too large")
|
||
}
|
||
|
||
// 计算还能读取多少字节
|
||
remaining := l.limit - l.read
|
||
if remaining <= 0 {
|
||
l.done = true
|
||
// 返回 413 错误
|
||
l.ctx.Error("Request Entity Too Large", fasthttp.StatusRequestEntityTooLarge)
|
||
return 0, fmt.Errorf("request body exceeds limit of %d bytes", l.limit)
|
||
}
|
||
|
||
// 限制读取长度
|
||
if int64(len(p)) > remaining {
|
||
p = p[:remaining]
|
||
}
|
||
|
||
n, err = l.original.Read(p)
|
||
l.read += int64(n)
|
||
|
||
return n, err
|
||
}
|
||
|
||
// ParseSize 解析大小字符串为字节数。
|
||
//
|
||
// 支持的单位:b, kb, mb, gb(不区分大小写)
|
||
// 无单位时默认为字节。
|
||
//
|
||
// 参数:
|
||
// - sizeStr: 大小字符串,如 "1mb", "10kb", "1024" 等
|
||
//
|
||
// 返回值:
|
||
// - int64: 字节数
|
||
// - error: 解析失败时的错误
|
||
func ParseSize(sizeStr string) (int64, error) {
|
||
if sizeStr == "" {
|
||
return DefaultMaxBodySize, nil
|
||
}
|
||
|
||
sizeStr = strings.TrimSpace(strings.ToLower(sizeStr))
|
||
|
||
// 解析数值和单位
|
||
var numStr string
|
||
var unit string
|
||
|
||
for i, c := range sizeStr {
|
||
if c >= '0' && c <= '9' || c == '.' {
|
||
numStr = sizeStr[:i+1]
|
||
} else {
|
||
unit = sizeStr[i:]
|
||
break
|
||
}
|
||
}
|
||
|
||
if numStr == "" {
|
||
return 0, fmt.Errorf("无效的大小格式: %s", sizeStr)
|
||
}
|
||
|
||
value, err := strconv.ParseFloat(numStr, 64)
|
||
if err != nil {
|
||
return 0, fmt.Errorf("解析数值失败: %w", err)
|
||
}
|
||
|
||
var multiplier float64
|
||
switch unit {
|
||
case "", "b":
|
||
multiplier = 1
|
||
case "kb":
|
||
multiplier = 1024
|
||
case "mb":
|
||
multiplier = 1024 * 1024
|
||
case "gb":
|
||
multiplier = 1024 * 1024 * 1024
|
||
default:
|
||
return 0, fmt.Errorf("不支持的大小单位: %s", unit)
|
||
}
|
||
|
||
return int64(value * multiplier), nil
|
||
}
|