lolly/internal/middleware/bodylimit/bodylimit_test.go

248 lines
6.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package bodylimit 提供请求体大小限制中间件的测试。
//
// 作者: xfy
package bodylimit
import (
"bytes"
"io"
"strings"
"testing"
"github.com/valyala/fasthttp"
)
// TestParseSize checks that ParseSize turns human-readable size strings into
// byte counts. It covers plain numbers, b/kb/mb/gb suffixes in either case,
// surrounding whitespace and decimal values. An empty string must yield the
// 1 MiB default. Unknown units, missing digits and negative values must be
// rejected with an error.
func TestParseSize(t *testing.T) {
	const (
		kib = int64(1) << 10
		mib = kib << 10
		gib = mib << 10
	)

	type sizeCase struct {
		name     string
		input    string
		expected int64
		wantErr  bool
	}

	cases := []sizeCase{
		{name: "empty string", input: "", expected: 1 << 20}, // empty input falls back to the 1 MiB default
		{name: "plain bytes", input: "1024", expected: 1024},  // bare number means bytes
		{name: "with b", input: "2048b", expected: 2048},      // explicit "b" suffix
		{name: "kilobytes", input: "10kb", expected: 10 * kib},
		{name: "megabytes", input: "1mb", expected: mib},
		{name: "gigabytes", input: "1gb", expected: gib},
		{name: "uppercase", input: "1MB", expected: mib},        // unit is case-insensitive
		{name: "with spaces", input: " 1mb ", expected: mib},    // surrounding whitespace is ignored
		{name: "decimal", input: "1.5mb", expected: 3 * mib / 2}, // fractional amounts are allowed
		{name: "invalid unit", input: "1xx", wantErr: true},
		{name: "no number", input: "mb", wantErr: true},
		{name: "negative", input: "-1mb", wantErr: true}, // the number itself parses, but a negative size is invalid
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := ParseSize(tc.input)
			switch {
			case (err != nil) != tc.wantErr:
				t.Errorf("ParseSize(%q) error = %v, wantErr %v", tc.input, err, tc.wantErr)
			case !tc.wantErr && got != tc.expected:
				t.Errorf("ParseSize(%q) = %d, want %d", tc.input, got, tc.expected)
			}
		})
	}
}
// TestNew checks that New stores the parsed size argument in maxBodySize and
// returns an error when the size string cannot be parsed.
func TestNew(t *testing.T) {
	for _, tc := range []struct {
		name     string
		maxSize  string
		wantErr  bool
		expected int64
	}{
		{name: "valid 1mb", maxSize: "1mb", expected: 1 << 20},
		{name: "valid 10kb", maxSize: "10kb", expected: 10 << 10},
		{name: "invalid", maxSize: "invalid", wantErr: true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			bl, err := New(tc.maxSize)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("New(%q) error = %v, wantErr %v", tc.maxSize, err, tc.wantErr)
				return
			}
			if tc.wantErr {
				// The expected failure happened; there is no middleware to inspect.
				return
			}
			if bl.maxBodySize != tc.expected {
				t.Errorf("New(%q).maxBodySize = %d, want %d", tc.maxSize, bl.maxBodySize, tc.expected)
			}
		})
	}
}
// TestBodyLimit_Process checks how Process handles a request with a declared
// Content-Length. Requests within the limit, including one exactly at it, must
// reach the next handler. Larger requests must be answered with 413 Request
// Entity Too Large without ever invoking the next handler.
func TestBodyLimit_Process(t *testing.T) {
	bl, err := New("100b")
	if err != nil {
		t.Fatalf("创建中间件失败: %v", err)
	}

	tests := []struct {
		name           string
		body           string
		contentLength  int
		expectedStatus int
		wantNext       bool // whether the request must be forwarded to the next handler
	}{
		{
			name:           "small body within limit",
			body:           "small body",
			contentLength:  10,
			expectedStatus: fasthttp.StatusOK,
			wantNext:       true,
		},
		{
			name:           "body exactly at limit",
			body:           strings.Repeat("a", 100),
			contentLength:  100,
			expectedStatus: fasthttp.StatusOK,
			wantNext:       true,
		},
		{
			name:           "body exceeds limit via content-length",
			body:           strings.Repeat("a", 200),
			contentLength:  200,
			expectedStatus: fasthttp.StatusRequestEntityTooLarge,
			wantNext:       false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// fasthttp reports 200 for a response whose status was never set.
			// The status code alone therefore cannot show whether the request
			// reached the next handler, so that is tracked explicitly.
			nextCalled := false
			nextHandler := func(ctx *fasthttp.RequestCtx) {
				nextCalled = true
				ctx.SetStatusCode(fasthttp.StatusOK)
			}
			handler := bl.Process(nextHandler)

			ctx := &fasthttp.RequestCtx{}
			ctx.Request.Header.SetMethod("POST")
			ctx.Request.Header.SetContentLength(tt.contentLength)
			ctx.Request.SetBodyStream(bytes.NewReader([]byte(tt.body)), tt.contentLength)

			handler(ctx)

			if got := ctx.Response.StatusCode(); got != tt.expectedStatus {
				t.Errorf("status code = %d, want %d", got, tt.expectedStatus)
			}
			if nextCalled != tt.wantNext {
				t.Errorf("next handler called = %v, want %v", nextCalled, tt.wantNext)
			}
		})
	}
}
// TestBodyLimit_ProcessChunked checks that a body streamed without a
// Content-Length, as with chunked transfer encoding, cannot bypass the limit.
//
// The next handler deliberately performs no size check of its own. It answers
// 200 whenever it can read the entire stream. The test therefore passes only if
// the middleware itself stops the oversized body, either by rejecting the
// request before calling the handler or by making the stream fail once the
// limit is exceeded (the handler surfaces such a read error as 413).
func TestBodyLimit_ProcessChunked(t *testing.T) {
	bl, err := New("100b")
	if err != nil {
		t.Fatalf("创建中间件失败: %v", err)
	}

	// 150 bytes, delivered at most 10 at a time, exceeding the 100-byte limit.
	payload := strings.Repeat("a", 150)
	reader := &slowReader{data: []byte(payload), chunkSize: 10}

	nextHandler := func(ctx *fasthttp.RequestCtx) {
		if _, readErr := io.ReadAll(ctx.Request.BodyStream()); readErr != nil {
			// The middleware cut the stream off; report it as too large.
			ctx.SetStatusCode(fasthttp.StatusRequestEntityTooLarge)
			return
		}
		// The whole body was readable: nothing enforced the limit.
		ctx.SetStatusCode(fasthttp.StatusOK)
	}
	handler := bl.Process(nextHandler)

	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod("POST")
	// A body size of -1 leaves Content-Length unknown, mimicking chunked transfer.
	ctx.Request.SetBodyStream(reader, -1)

	handler(ctx)

	// The oversized body must end in 413.
	if ctx.Response.StatusCode() != fasthttp.StatusRequestEntityTooLarge {
		t.Errorf("chunked 传输绕过测试失败: status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusRequestEntityTooLarge)
	}
}
// TestBodyLimit_PathLimits checks how GetLimit resolves per-path limits. The
// longest registered prefix wins, and paths matching no prefix get the default
// limit. Matching is a plain string-prefix test rather than a path-segment
// test, so "/apix" still resolves to the "/api" limit.
func TestBodyLimit_PathLimits(t *testing.T) {
	bl, err := New("1mb")
	if err != nil {
		t.Fatalf("创建中间件失败: %v", err)
	}

	// Register a specific prefix and a broader one that contains it.
	for _, pl := range []struct{ prefix, size string }{
		{"/api/upload", "10mb"},
		{"/api", "2mb"},
	} {
		if err := bl.AddPathLimit(pl.prefix, pl.size); err != nil {
			t.Fatalf("添加路径限制失败: %v", err)
		}
	}

	const mib = int64(1) << 20
	cases := []struct {
		path     string
		expected int64
	}{
		{"/api/upload/file", 10 * mib}, // longest match: /api/upload
		{"/api/users", 2 * mib},        // falls under /api
		{"/other/path", mib},           // no prefix matches: default limit
		{"/api", 2 * mib},              // exact match on /api
		{"/apix", 2 * mib},             // raw prefix match on /api
		{"/apiupload", 2 * mib},        // raw prefix match on /api
		{"/notapi", mib},               // "/api" appears, but not as a prefix
	}

	for _, tc := range cases {
		t.Run(tc.path, func(t *testing.T) {
			if got := bl.GetLimit(tc.path); got != tc.expected {
				t.Errorf("GetLimit(%q) = %d, want %d", tc.path, got, tc.expected)
			}
		})
	}
}
// TestBodyLimit_Name checks that the default middleware reports its name as
// "BodyLimit".
func TestBodyLimit_Name(t *testing.T) {
	const want = "BodyLimit"
	bl := NewWithDefault()
	if got := bl.Name(); got != want {
		t.Errorf("Name() = %s, want BodyLimit", got)
	}
}
// TestBodyLimit_DefaultMaxBodySize pins the package-wide default limit at 1 MiB.
func TestBodyLimit_DefaultMaxBodySize(t *testing.T) {
	const want = 1 << 20
	if got := DefaultMaxBodySize; got != want {
		t.Errorf("DefaultMaxBodySize = %d, want %d", got, want)
	}
}
// slowReader is an io.Reader that hands out at most chunkSize bytes per Read
// call, so a consumer has to loop to drain it. The tests use it to emulate a
// request body that arrives incrementally, as with chunked transfer encoding.
type slowReader struct {
	data      []byte // full payload to serve
	pos       int    // offset of the next unread byte in data
	chunkSize int    // upper bound on the bytes returned by one Read
}

// Read copies min(chunkSize, len(p), bytes remaining) bytes into p. It returns
// io.EOF only once the payload has been fully consumed; every read that
// returns data reports a nil error.
func (r *slowReader) Read(p []byte) (int, error) {
	if r.pos >= len(r.data) {
		return 0, io.EOF
	}
	// copy stops at the shorter of the capped destination and the unread data.
	n := copy(p[:min(len(p), r.chunkSize)], r.data[r.pos:])
	r.pos += n
	return n, nil
}