lolly/internal/middleware/security/ratelimit_test.go
xfy e646cc5d05 refactor(test): 提取 testutil 包统一测试辅助函数
- 新增 NewRequestCtx 和 NewRequestCtxWithHeader 辅助函数
- 简化各测试文件中 RequestCtx 创建代码
- 减少测试代码重复,提高可维护性

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-13 13:15:20 +08:00

463 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package security 提供速率限制功能的测试。
//
// 该文件测试速率限制模块的各项功能,包括:
// - 速率限制器创建
// - 令牌桶算法
// - 令牌补充机制
// - 计数器重置
// - 连接数限制
// - 统计信息获取
//
// 作者xfy
package security
import (
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/testutil"
)
func TestNewRateLimiter(t *testing.T) {
tests := []struct {
cfg *config.RateLimitConfig
name string
wantErr bool
}{
{
name: "nil config",
cfg: nil,
wantErr: true,
},
{
name: "valid config",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
},
{
name: "zero rate",
cfg: &config.RateLimitConfig{
RequestRate: 0,
Burst: 100,
},
wantErr: true,
},
{
name: "burst less than rate",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 50,
},
wantErr: true,
},
{
name: "key by IP",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "ip",
},
},
{
name: "key by header",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "header",
},
},
{
name: "unknown key type",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "unknown",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rl, err := NewRateLimiter(tt.cfg)
if (err != nil) != tt.wantErr {
t.Errorf("NewRateLimiter() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && rl == nil {
t.Error("Expected non-nil RateLimiter")
}
})
}
}
func TestRateLimiterAllow(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 10,
Burst: 10,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// Test burst allowance
key := "test-key"
// Should allow burst requests
for i := 0; i < 10; i++ {
if !rl.Allow(key) {
t.Errorf("Expected request %d to be allowed", i+1)
}
}
// Next request should be denied (burst exhausted)
if rl.Allow(key) {
t.Error("Expected request to be denied after burst exhausted")
}
}
func TestRateLimiterTokenRefill(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100, // 100 tokens per second
Burst: 100,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
key := "refill-test"
// Exhaust the burst
for i := 0; i < 100; i++ {
rl.Allow(key)
}
// Should be denied
if rl.Allow(key) {
t.Error("Expected request to be denied")
}
// Wait for token refill (10ms should give us 1 token at 100/s)
time.Sleep(15 * time.Millisecond)
// Should be allowed now
if !rl.Allow(key) {
t.Error("Expected request to be allowed after refill")
}
}
func TestRateLimiterReset(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 1,
Burst: 1,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
key := "reset-test"
// Exhaust
rl.Allow(key)
if rl.Allow(key) {
t.Error("Expected denial")
}
// Reset
rl.Reset(key)
// Should be allowed again
if !rl.Allow(key) {
t.Error("Expected request to be allowed after reset")
}
}
func TestRateLimiterResetAll(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 1,
Burst: 1,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// Create multiple buckets
rl.Allow("key1")
rl.Allow("key2")
// Reset all
rl.ResetAll()
stats := rl.GetStats()
if stats.BucketCount != 0 {
t.Errorf("Expected 0 buckets after reset, got %d", stats.BucketCount)
}
}
func TestRateLimiterCleanup(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// Create some buckets
rl.Allow("key1")
rl.Allow("key2")
// Cleanup with very short max age
rl.Cleanup(1 * time.Nanosecond)
stats := rl.GetStats()
if stats.BucketCount != 0 {
t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount)
}
}
func TestRateLimiterProcess(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
nextHandler := func(ctx *fasthttp.RequestCtx) {
_, _ = ctx.WriteString("OK")
}
handler := mw.Process(nextHandler)
if handler == nil {
t.Error("Process() returned nil handler")
}
}
func TestRateLimiterGetStats(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
rl.Allow("key1")
rl.Allow("key2")
stats := rl.GetStats()
if stats.BucketCount != 2 {
t.Errorf("Expected BucketCount 2, got %d", stats.BucketCount)
}
if stats.Rate != 100 {
t.Errorf("Expected Rate 100, got %f", stats.Rate)
}
if stats.Burst != 200 {
t.Errorf("Expected Burst 200, got %f", stats.Burst)
}
// 测试优雅关闭
rl.StopCleanup()
}
func TestRateLimiterAutoCleanup(t *testing.T) {
// 使用自定义的创建方式,方便测试
cfg := &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "ip",
}
mw, err := NewRateLimiter(cfg)
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// 创建一些桶
rl.Allow("key1")
rl.Allow("key2")
rl.Allow("key3")
// 验证桶已创建
stats := rl.GetStats()
if stats.BucketCount != 3 {
t.Errorf("Expected 3 buckets, got %d", stats.BucketCount)
}
// 手动调用 Cleanup 模拟过期清理(使用很短的过期时间)
rl.Cleanup(1 * time.Nanosecond)
// 验证所有桶已被清理
stats = rl.GetStats()
if stats.BucketCount != 0 {
t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount)
}
// 测试优雅关闭
rl.StopCleanup()
}
func TestRateLimiterStopCleanup(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// 验证可以正常关闭
rl.StopCleanup()
// 再次调用不应 panic
rl.StopCleanup()
}
func TestNewConnLimiter(t *testing.T) {
tests := []struct {
name string
keyType string
max int
perKey bool
wantErr bool
}{
{
name: "global limit",
max: 100,
perKey: false,
},
{
name: "per-key by IP",
max: 10,
perKey: true,
keyType: "ip",
},
{
name: "per-key by header",
max: 10,
perKey: true,
keyType: "header",
},
{
name: "zero max",
max: 0,
wantErr: true,
},
{
name: "negative max",
max: -1,
wantErr: true,
},
{
name: "invalid key type",
max: 10,
perKey: true,
keyType: "invalid",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cl, err := NewConnLimiter(tt.max, tt.perKey, tt.keyType)
if (err != nil) != tt.wantErr {
t.Errorf("NewConnLimiter() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && cl == nil {
t.Error("Expected non-nil ConnLimiter")
}
})
}
}
func TestConnLimiterGlobal(t *testing.T) {
cl, err := NewConnLimiter(2, false, "")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/")
// First two should succeed
if !cl.Acquire(ctx) {
t.Error("Expected first acquire to succeed")
}
if !cl.Acquire(ctx) {
t.Error("Expected second acquire to succeed")
}
// Third should fail
if cl.Acquire(ctx) {
t.Error("Expected third acquire to fail")
}
// Release one
cl.Release(ctx)
// Should succeed now
if !cl.Acquire(ctx) {
t.Error("Expected acquire after release to succeed")
}
}
func TestConnLimiterMiddleware(t *testing.T) {
cl, err := NewConnLimiter(1, false, "")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
middleware := cl.Middleware()
if middleware == nil {
t.Error("Expected non-nil middleware")
}
if middleware.Name() != "conn_limiter" {
t.Errorf("Expected name 'conn_limiter', got %s", middleware.Name())
}
}