248 lines
6.8 KiB
Go
248 lines
6.8 KiB
Go
// Package bodylimit 提供请求体大小限制中间件的测试。
|
||
//
|
||
// 作者:xfy
|
||
package bodylimit
|
||
|
||
import (
|
||
"bytes"
|
||
"io"
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/valyala/fasthttp"
|
||
)
|
||
|
||
// TestParseSize 测试大小字符串解析。
|
||
func TestParseSize(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
input string
|
||
expected int64
|
||
wantErr bool
|
||
}{
|
||
{"empty string", "", 1 << 20, false}, // 默认 1MB
|
||
{"plain bytes", "1024", 1024, false}, // 纯数字
|
||
{"with b", "2048b", 2048, false}, // 带 b 单位
|
||
{"kilobytes", "10kb", 10 * 1024, false}, // KB
|
||
{"megabytes", "1mb", 1024 * 1024, false}, // MB
|
||
{"gigabytes", "1gb", 1024 * 1024 * 1024, false}, // GB
|
||
{"uppercase", "1MB", 1024 * 1024, false}, // 大写
|
||
{"with spaces", " 1mb ", 1024 * 1024, false}, // 带空格
|
||
{"decimal", "1.5mb", int64(1.5 * 1024 * 1024), false}, // 小数
|
||
{"invalid unit", "1xx", 0, true}, // 无效单位
|
||
{"no number", "mb", 0, true}, // 无数字
|
||
{"negative", "-1mb", 0, true}, // 负数(ParseFloat 可能成功但结果为负)
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
got, err := ParseSize(tt.input)
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("ParseSize(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
|
||
return
|
||
}
|
||
if !tt.wantErr && got != tt.expected {
|
||
t.Errorf("ParseSize(%q) = %d, want %d", tt.input, got, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestNew 测试创建中间件。
|
||
func TestNew(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
maxSize string
|
||
wantErr bool
|
||
expected int64
|
||
}{
|
||
{"valid 1mb", "1mb", false, 1024 * 1024},
|
||
{"valid 10kb", "10kb", false, 10 * 1024},
|
||
{"invalid", "invalid", true, 0},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
bl, err := New(tt.maxSize)
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("New(%q) error = %v, wantErr %v", tt.maxSize, err, tt.wantErr)
|
||
return
|
||
}
|
||
if !tt.wantErr && bl.maxBodySize != tt.expected {
|
||
t.Errorf("New(%q).maxBodySize = %d, want %d", tt.maxSize, bl.maxBodySize, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestBodyLimit_Process 测试中间件处理。
|
||
func TestBodyLimit_Process(t *testing.T) {
|
||
bl, err := New("100b")
|
||
if err != nil {
|
||
t.Fatalf("创建中间件失败: %v", err)
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
body string
|
||
contentLength int
|
||
expectedStatus int
|
||
}{
|
||
{
|
||
name: "small body within limit",
|
||
body: "small body",
|
||
contentLength: 10,
|
||
expectedStatus: fasthttp.StatusOK,
|
||
},
|
||
{
|
||
name: "body exactly at limit",
|
||
body: strings.Repeat("a", 100),
|
||
contentLength: 100,
|
||
expectedStatus: fasthttp.StatusOK,
|
||
},
|
||
{
|
||
name: "body exceeds limit via content-length",
|
||
body: strings.Repeat("a", 200),
|
||
contentLength: 200,
|
||
expectedStatus: fasthttp.StatusRequestEntityTooLarge,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||
}
|
||
|
||
handler := bl.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.Header.SetMethod("POST")
|
||
ctx.Request.Header.SetContentLength(tt.contentLength)
|
||
ctx.Request.SetBodyStream(bytes.NewReader([]byte(tt.body)), tt.contentLength)
|
||
|
||
handler(ctx)
|
||
|
||
if ctx.Response.StatusCode() != tt.expectedStatus {
|
||
t.Errorf("status code = %d, want %d", ctx.Response.StatusCode(), tt.expectedStatus)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestBodyLimit_ProcessChunked 测试 chunked 传输编码。
|
||
func TestBodyLimit_ProcessChunked(t *testing.T) {
|
||
bl, err := New("100b")
|
||
if err != nil {
|
||
t.Fatalf("创建中间件失败: %v", err)
|
||
}
|
||
|
||
// 测试 chunked 传输无法绕过限制
|
||
body := strings.Repeat("a", 150) // 150 字节超过 100 字节限制
|
||
reader := &slowReader{data: []byte(body), chunkSize: 10}
|
||
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
// 尝试读取完整请求体
|
||
body, err := io.ReadAll(ctx.Request.BodyStream())
|
||
if err != nil {
|
||
ctx.SetStatusCode(fasthttp.StatusRequestEntityTooLarge)
|
||
return
|
||
}
|
||
if len(body) > 100 {
|
||
ctx.SetStatusCode(fasthttp.StatusRequestEntityTooLarge)
|
||
return
|
||
}
|
||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||
}
|
||
|
||
handler := bl.Process(nextHandler)
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.Header.SetMethod("POST")
|
||
// 不设置 Content-Length 模拟 chunked 传输
|
||
ctx.Request.SetBodyStream(reader, -1)
|
||
|
||
handler(ctx)
|
||
|
||
// 应该返回 413,因为 body 超过限制
|
||
if ctx.Response.StatusCode() != fasthttp.StatusRequestEntityTooLarge {
|
||
t.Errorf("chunked 传输绕过测试失败: status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusRequestEntityTooLarge)
|
||
}
|
||
}
|
||
|
||
// TestBodyLimit_PathLimits 测试路径级别配置。
|
||
func TestBodyLimit_PathLimits(t *testing.T) {
|
||
bl, err := New("1mb")
|
||
if err != nil {
|
||
t.Fatalf("创建中间件失败: %v", err)
|
||
}
|
||
|
||
// 添加路径级别配置
|
||
if err := bl.AddPathLimit("/api/upload", "10mb"); err != nil {
|
||
t.Fatalf("添加路径限制失败: %v", err)
|
||
}
|
||
if err := bl.AddPathLimit("/api", "2mb"); err != nil {
|
||
t.Fatalf("添加路径限制失败: %v", err)
|
||
}
|
||
|
||
tests := []struct {
|
||
path string
|
||
expected int64
|
||
}{
|
||
{"/api/upload/file", 10 * 1024 * 1024}, // 最长匹配 /api/upload
|
||
{"/api/users", 2 * 1024 * 1024}, // 匹配 /api
|
||
{"/other/path", 1 * 1024 * 1024}, // 默认限制
|
||
{"/api", 2 * 1024 * 1024}, // 匹配 /api
|
||
{"/apix", 2 * 1024 * 1024}, // 匹配 /api(前缀匹配)
|
||
{"/apiupload", 2 * 1024 * 1024}, // 匹配 /api(前缀匹配)
|
||
{"/notapi", 1 * 1024 * 1024}, // 不匹配 /api
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.path, func(t *testing.T) {
|
||
got := bl.GetLimit(tt.path)
|
||
if got != tt.expected {
|
||
t.Errorf("GetLimit(%q) = %d, want %d", tt.path, got, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestBodyLimit_Name 测试中间件名称。
|
||
func TestBodyLimit_Name(t *testing.T) {
|
||
bl := NewWithDefault()
|
||
if bl.Name() != "BodyLimit" {
|
||
t.Errorf("Name() = %s, want BodyLimit", bl.Name())
|
||
}
|
||
}
|
||
|
||
// TestBodyLimit_DefaultMaxBodySize 测试默认大小。
|
||
func TestBodyLimit_DefaultMaxBodySize(t *testing.T) {
|
||
if DefaultMaxBodySize != 1<<20 {
|
||
t.Errorf("DefaultMaxBodySize = %d, want %d", DefaultMaxBodySize, 1<<20)
|
||
}
|
||
}
|
||
|
||
// slowReader 模拟慢速读取,用于测试 chunked 传输。
|
||
type slowReader struct {
|
||
data []byte
|
||
pos int
|
||
chunkSize int
|
||
}
|
||
|
||
func (r *slowReader) Read(p []byte) (n int, err error) {
|
||
if r.pos >= len(r.data) {
|
||
return 0, io.EOF
|
||
}
|
||
|
||
// 每次只读取 chunkSize 字节
|
||
remaining := len(r.data) - r.pos
|
||
toRead := min(min(r.chunkSize, remaining), len(p))
|
||
|
||
n = toRead
|
||
copy(p, r.data[r.pos:r.pos+toRead])
|
||
r.pos += toRead
|
||
|
||
return n, nil
|
||
}
|