test(middleware): 添加 limitrate 中间件测试
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
4d66dd562f
commit
a832e48656
609
internal/middleware/limitrate/limitrate_test.go
Normal file
609
internal/middleware/limitrate/limitrate_test.go
Normal file
@ -0,0 +1,609 @@
|
|||||||
|
// Package limitrate 提供响应速率限制中间件的测试。
|
||||||
|
//
|
||||||
|
// 该文件测试速率限制模块的各项功能,包括:
|
||||||
|
// - 中间件创建和名称
|
||||||
|
// - 限速写入器创建
|
||||||
|
// - 令牌桶算法
|
||||||
|
// - 零值和负值边界情况
|
||||||
|
// - 并发安全性
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package limitrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestConstants 测试常量值。
|
||||||
|
func TestConstants(t *testing.T) {
|
||||||
|
if LargeFileStrategySkip != "skip" {
|
||||||
|
t.Errorf("LargeFileStrategySkip = %q, want %q", LargeFileStrategySkip, "skip")
|
||||||
|
}
|
||||||
|
if LargeFileStrategyCoarse != "coarse" {
|
||||||
|
t.Errorf("LargeFileStrategyCoarse = %q, want %q", LargeFileStrategyCoarse, "coarse")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewMiddleware 测试创建中间件。
|
||||||
|
func TestNewMiddleware(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg *config.LimitRateConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil config",
|
||||||
|
cfg: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
cfg: &config.LimitRateConfig{
|
||||||
|
Rate: 1024,
|
||||||
|
Burst: 2048,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero rate",
|
||||||
|
cfg: &config.LimitRateConfig{
|
||||||
|
Rate: 0,
|
||||||
|
Burst: 1024,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative rate",
|
||||||
|
cfg: &config.LimitRateConfig{
|
||||||
|
Rate: -1,
|
||||||
|
Burst: 1024,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero burst",
|
||||||
|
cfg: &config.LimitRateConfig{
|
||||||
|
Rate: 1024,
|
||||||
|
Burst: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with large file config",
|
||||||
|
cfg: &config.LimitRateConfig{
|
||||||
|
Rate: 1024,
|
||||||
|
Burst: 2048,
|
||||||
|
LargeFileThreshold: 10 * 1024 * 1024,
|
||||||
|
LargeFileStrategy: LargeFileStrategySkip,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mw := NewMiddleware(tt.cfg)
|
||||||
|
if mw == nil {
|
||||||
|
t.Error("NewMiddleware() returned nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMiddleware_Name 测试中间件名称。
|
||||||
|
func TestMiddleware_Name(t *testing.T) {
|
||||||
|
mw := NewMiddleware(nil)
|
||||||
|
if mw.Name() != "limit_rate" {
|
||||||
|
t.Errorf("Name() = %q, want %q", mw.Name(), "limit_rate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMiddleware_Process_NilConfig 测试空配置时的处理。
|
||||||
|
func TestMiddleware_Process_NilConfig(t *testing.T) {
|
||||||
|
mw := NewMiddleware(nil)
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := mw.Process(nextHandler)
|
||||||
|
if handler == nil {
|
||||||
|
t.Fatal("Process() returned nil handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("next handler was not called")
|
||||||
|
}
|
||||||
|
if ctx.Response.StatusCode() != fasthttp.StatusOK {
|
||||||
|
t.Errorf("status code = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMiddleware_Process_ZeroRate 测试零速率时的处理。
|
||||||
|
func TestMiddleware_Process_ZeroRate(t *testing.T) {
|
||||||
|
mw := NewMiddleware(&config.LimitRateConfig{
|
||||||
|
Rate: 0,
|
||||||
|
Burst: 1024,
|
||||||
|
})
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := mw.Process(nextHandler)
|
||||||
|
if handler == nil {
|
||||||
|
t.Fatal("Process() returned nil handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("next handler was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMiddleware_Process_NegativeRate 测试负速率时的处理。
|
||||||
|
func TestMiddleware_Process_NegativeRate(t *testing.T) {
|
||||||
|
mw := NewMiddleware(&config.LimitRateConfig{
|
||||||
|
Rate: -100,
|
||||||
|
Burst: 1024,
|
||||||
|
})
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := mw.Process(nextHandler)
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("next handler was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMiddleware_Process_ValidConfig 测试有效配置时的处理。
|
||||||
|
func TestMiddleware_Process_ValidConfig(t *testing.T) {
|
||||||
|
mw := NewMiddleware(&config.LimitRateConfig{
|
||||||
|
Rate: 1024,
|
||||||
|
Burst: 2048,
|
||||||
|
})
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := mw.Process(nextHandler)
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("next handler was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewRateLimitedWriter 测试创建限速写入器。
|
||||||
|
func TestNewRateLimitedWriter(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rate int64
|
||||||
|
burst int64
|
||||||
|
}{
|
||||||
|
{"positive rate and burst", 1024, 2048},
|
||||||
|
{"zero rate", 0, 1024},
|
||||||
|
{"negative rate", -1, 1024},
|
||||||
|
{"zero burst", 1024, 0},
|
||||||
|
{"negative burst", 1024, -1},
|
||||||
|
{"rate equals burst", 1024, 1024},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
w := NewRateLimitedWriter(buf, tt.rate, tt.burst)
|
||||||
|
if w == nil {
|
||||||
|
t.Error("NewRateLimitedWriter() returned nil")
|
||||||
|
}
|
||||||
|
if w.rate != tt.rate {
|
||||||
|
t.Errorf("rate = %d, want %d", w.rate, tt.rate)
|
||||||
|
}
|
||||||
|
if w.maxBucket != tt.burst {
|
||||||
|
t.Errorf("maxBucket = %d, want %d", w.maxBucket, tt.burst)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_Write_ZeroRate 测试零速率时直接写入。
|
||||||
|
func TestRateLimitedWriter_Write_ZeroRate(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w := NewRateLimitedWriter(buf, 0, 1024)
|
||||||
|
|
||||||
|
data := []byte("hello world")
|
||||||
|
n, err := w.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != len(data) {
|
||||||
|
t.Errorf("Write() n = %d, want %d", n, len(data))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf.Bytes(), data) {
|
||||||
|
t.Errorf("buffer = %q, want %q", buf.Bytes(), data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_Write_NegativeRate 测试负速率时直接写入。
|
||||||
|
func TestRateLimitedWriter_Write_NegativeRate(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w := NewRateLimitedWriter(buf, -1, 1024)
|
||||||
|
|
||||||
|
data := []byte("hello world")
|
||||||
|
n, err := w.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != len(data) {
|
||||||
|
t.Errorf("Write() n = %d, want %d", n, len(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_Write_WithinBucket 测试令牌充足时的写入。
|
||||||
|
func TestRateLimitedWriter_Write_WithinBucket(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(1024)
|
||||||
|
burst := int64(2048)
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
data := []byte("hello")
|
||||||
|
n, err := w.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != len(data) {
|
||||||
|
t.Errorf("Write() n = %d, want %d", n, len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证令牌被消耗
|
||||||
|
expectedBucket := burst - int64(len(data))
|
||||||
|
if w.bucket != expectedBucket {
|
||||||
|
t.Errorf("bucket = %d, want %d", w.bucket, expectedBucket)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_Write_ExceedsBucket 测试令牌不足时的分批写入。
|
||||||
|
func TestRateLimitedWriter_Write_ExceedsBucket(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(100) // 100 bytes/sec
|
||||||
|
burst := int64(10) // only 10 tokens initially
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
// 写入超过 burst 的数据
|
||||||
|
data := make([]byte, 50)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = 'a'
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := w.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != len(data) {
|
||||||
|
t.Errorf("Write() n = %d, want %d", n, len(data))
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf.Bytes(), data) {
|
||||||
|
t.Error("buffer content mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_Write_Error 测试底层写入错误。
|
||||||
|
func TestRateLimitedWriter_Write_Error(t *testing.T) {
|
||||||
|
errWriter := &errorWriter{err: errors.New("write error")}
|
||||||
|
|
||||||
|
// 测试零速率时的错误传播(零速率时直接调用底层 writer,错误会传播)
|
||||||
|
w := NewRateLimitedWriter(errWriter, 0, 1024)
|
||||||
|
|
||||||
|
data := []byte("hello")
|
||||||
|
_, err := w.Write(data)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Write() with zero rate should propagate error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试非零速率时的错误传播
|
||||||
|
w2 := NewRateLimitedWriter(errWriter, 1024, 10)
|
||||||
|
_, err = w2.Write(data)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Write() expected error, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorWriter 是一个总是返回错误的写入器。
|
||||||
|
type errorWriter struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *errorWriter) Write(p []byte) (n int, err error) {
|
||||||
|
return 0, w.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_Write_MultipleWrites 测试多次写入。
|
||||||
|
func TestRateLimitedWriter_Write_MultipleWrites(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(1000)
|
||||||
|
burst := int64(100)
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
// 第一次写入消耗令牌
|
||||||
|
data1 := []byte("first")
|
||||||
|
n1, err := w.Write(data1)
|
||||||
|
if err != nil || n1 != len(data1) {
|
||||||
|
t.Errorf("first Write() failed: n=%d, err=%v", n1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第二次写入
|
||||||
|
data2 := []byte("second")
|
||||||
|
n2, err := w.Write(data2)
|
||||||
|
if err != nil || n2 != len(data2) {
|
||||||
|
t.Errorf("second Write() failed: n=%d, err=%v", n2, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := append(data1, data2...)
|
||||||
|
if !bytes.Equal(buf.Bytes(), expected) {
|
||||||
|
t.Errorf("buffer = %q, want %q", buf.Bytes(), expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_BucketRefill 测试令牌桶补充。
|
||||||
|
func TestRateLimitedWriter_BucketRefill(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(1000) // 1000 tokens/sec
|
||||||
|
burst := int64(100)
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
// 消耗所有令牌
|
||||||
|
w.bucket = 0
|
||||||
|
|
||||||
|
// 等待一段时间让令牌补充
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
|
||||||
|
// 写入数据,应该能获得新令牌
|
||||||
|
data := []byte("test")
|
||||||
|
n, err := w.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != len(data) {
|
||||||
|
t.Errorf("Write() n = %d, want %d", n, len(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_BucketMax 测试令牌桶上限。
|
||||||
|
func TestRateLimitedWriter_BucketMax(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(1000)
|
||||||
|
burst := int64(100)
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
// 设置 bucket 超过最大值
|
||||||
|
w.bucket = burst + 500
|
||||||
|
|
||||||
|
// 等待让时间流逝,触发令牌补充逻辑
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// 写入数据
|
||||||
|
data := []byte("test")
|
||||||
|
_, err := w.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// bucket 不应超过 maxBucket
|
||||||
|
if w.bucket > w.maxBucket {
|
||||||
|
t.Errorf("bucket = %d, should not exceed maxBucket = %d", w.bucket, w.maxBucket)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_Concurrent 测试并发写入安全性。
|
||||||
|
func TestRateLimitedWriter_Concurrent(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(10000) // 高速率以减少测试时间
|
||||||
|
burst := int64(1000)
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
goroutines := 5
|
||||||
|
writesPerGoroutine := 10
|
||||||
|
data := []byte("test")
|
||||||
|
|
||||||
|
for i := 0; i < goroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < writesPerGoroutine; j++ {
|
||||||
|
_, _ = w.Write(data)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
expectedLen := goroutines * writesPerGoroutine * len(data)
|
||||||
|
if buf.Len() != expectedLen {
|
||||||
|
t.Errorf("buffer length = %d, want %d", buf.Len(), expectedLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_EmptyWrite 测试空写入。
|
||||||
|
func TestRateLimitedWriter_EmptyWrite(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w := NewRateLimitedWriter(buf, 1024, 100)
|
||||||
|
|
||||||
|
n, err := w.Write([]byte{})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("Write() n = %d, want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_LargeData 测试大数据写入。
|
||||||
|
func TestRateLimitedWriter_LargeData(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(10000) // 10KB/s
|
||||||
|
burst := int64(1000) // 1KB burst
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
// 写入 5KB 数据
|
||||||
|
data := make([]byte, 5*1024)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
n, err := w.Write(data)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != len(data) {
|
||||||
|
t.Errorf("Write() n = %d, want %d", n, len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证数据被正确写入
|
||||||
|
if !bytes.Equal(buf.Bytes(), data) {
|
||||||
|
t.Error("buffer content mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 由于限速,写入时间应该大于零(令牌不足时需要等待)
|
||||||
|
t.Logf("Large data write took %v", elapsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_PartialWriteError 测试部分写入错误。
|
||||||
|
func TestRateLimitedWriter_PartialWriteError(t *testing.T) {
|
||||||
|
// 创建一个写入部分数据后返回错误的 writer
|
||||||
|
partialWriter := &partialErrorWriter{maxWrite: 5}
|
||||||
|
w := NewRateLimitedWriter(partialWriter, 0, 100)
|
||||||
|
|
||||||
|
data := []byte("hello world")
|
||||||
|
n, err := w.Write(data)
|
||||||
|
|
||||||
|
// 零速率时,直接调用底层 writer
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Write() expected error, got nil")
|
||||||
|
}
|
||||||
|
if n > len(data) {
|
||||||
|
t.Errorf("Write() n = %d, should not exceed %d", n, len(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// partialErrorWriter 写入 maxWrite 字节后返回错误。
|
||||||
|
type partialErrorWriter struct {
|
||||||
|
maxWrite int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *partialErrorWriter) Write(p []byte) (n int, err error) {
|
||||||
|
if len(p) <= w.maxWrite {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
return w.maxWrite, errors.New("partial write error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_TimeAdvancement 测试时间推进。
|
||||||
|
func TestRateLimitedWriter_TimeAdvancement(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
rate := int64(1000)
|
||||||
|
burst := int64(100)
|
||||||
|
w := NewRateLimitedWriter(buf, rate, burst)
|
||||||
|
|
||||||
|
// 记录初始时间
|
||||||
|
initialTime := w.lastTime
|
||||||
|
|
||||||
|
// 写入一些数据
|
||||||
|
_, _ = w.Write([]byte("test"))
|
||||||
|
|
||||||
|
// lastTime 应该被更新
|
||||||
|
if !w.lastTime.After(initialTime) && w.lastTime != initialTime {
|
||||||
|
t.Error("lastTime should be updated after write")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRateLimitedWriter_WriteAll 测试完整写入 io.Writer 接口兼容性。
|
||||||
|
func TestRateLimitedWriter_WriteAll(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w := NewRateLimitedWriter(buf, 1000, 100)
|
||||||
|
|
||||||
|
data := []byte("hello world, this is a test of the rate limited writer")
|
||||||
|
|
||||||
|
// 使用 io.Writer 接口
|
||||||
|
var writer io.Writer = w
|
||||||
|
n, err := writer.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Write() error = %v", err)
|
||||||
|
}
|
||||||
|
if n != len(data) {
|
||||||
|
t.Errorf("Write() n = %d, want %d", n, len(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkMiddleware_Process 基准测试中间件处理。
|
||||||
|
func BenchmarkMiddleware_Process(b *testing.B) {
|
||||||
|
mw := NewMiddleware(&config.LimitRateConfig{
|
||||||
|
Rate: 1024 * 1024, // 1MB/s
|
||||||
|
Burst: 2 * 1024 * 1024,
|
||||||
|
})
|
||||||
|
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := mw.Process(nextHandler)
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
handler(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkRateLimitedWriter_Write 基准测试限速写入。
|
||||||
|
func BenchmarkRateLimitedWriter_Write(b *testing.B) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w := NewRateLimitedWriter(buf, 1024*1024, 1024*1024) // 高速率减少等待
|
||||||
|
data := make([]byte, 1024)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
buf.Reset()
|
||||||
|
w.bucket = w.maxBucket // 重置令牌桶
|
||||||
|
_, _ = w.Write(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkRateLimitedWriter_Concurrent 基准测试并发写入。
|
||||||
|
func BenchmarkRateLimitedWriter_Concurrent(b *testing.B) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
w := NewRateLimitedWriter(buf, 10*1024*1024, 1024*1024)
|
||||||
|
data := []byte("test data for benchmarking")
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
_, _ = w.Write(data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user