lolly/internal/middleware/security/ratelimit_test.go
xfy f145a8770e refactor: modernize code with Go 1.22+ features
Apply modern Go patterns across the codebase:
- Replace `interface{}` with `any` (Go 1.18+)
- Use `for range n` instead of `for i := 0; i < n; i++` (Go 1.22+)
- Replace `sort.Slice` with `slices.Sort` from slices package
- Simplify sync.WaitGroup patterns with errgroup where appropriate
- Add Makefile targets for modernize analyzer

Total: 84 files updated, net reduction of 79 lines

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-30 10:37:45 +08:00

870 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package security 提供速率限制功能的测试。
//
// 该文件测试速率限制模块的各项功能,包括:
// - 速率限制器创建
// - 令牌桶算法
// - 令牌补充机制
// - 计数器重置
// - 连接数限制
// - 统计信息获取
//
// 作者xfy
package security
import (
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/testutil"
)
func TestNewRateLimiter(t *testing.T) {
tests := []struct {
cfg *config.RateLimitConfig
name string
wantErr bool
}{
{
name: "nil config",
cfg: nil,
wantErr: true,
},
{
name: "valid config",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
},
{
name: "zero rate",
cfg: &config.RateLimitConfig{
RequestRate: 0,
Burst: 100,
},
wantErr: true,
},
{
name: "burst less than rate",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 50,
},
wantErr: true,
},
{
name: "key by IP",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "ip",
},
},
{
name: "key by header",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "header",
},
},
{
name: "unknown key type",
cfg: &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "unknown",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rl, err := NewRateLimiter(tt.cfg)
if (err != nil) != tt.wantErr {
t.Errorf("NewRateLimiter() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && rl == nil {
t.Error("Expected non-nil RateLimiter")
}
})
}
}
func TestRateLimiterAllow(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 10,
Burst: 10,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// Test burst allowance
key := "test-key"
// Should allow burst requests
for i := range 10 {
if !rl.Allow(key) {
t.Errorf("Expected request %d to be allowed", i+1)
}
}
// Next request should be denied (burst exhausted)
if rl.Allow(key) {
t.Error("Expected request to be denied after burst exhausted")
}
}
func TestRateLimiterTokenRefill(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100, // 100 tokens per second
Burst: 100,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
key := "refill-test"
// Exhaust the burst
for range 100 {
rl.Allow(key)
}
// Should be denied
if rl.Allow(key) {
t.Error("Expected request to be denied")
}
// Wait for token refill (10ms should give us 1 token at 100/s)
time.Sleep(15 * time.Millisecond)
// Should be allowed now
if !rl.Allow(key) {
t.Error("Expected request to be allowed after refill")
}
}
func TestRateLimiterReset(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 1,
Burst: 1,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
key := "reset-test"
// Exhaust
rl.Allow(key)
if rl.Allow(key) {
t.Error("Expected denial")
}
// Reset
rl.Reset(key)
// Should be allowed again
if !rl.Allow(key) {
t.Error("Expected request to be allowed after reset")
}
}
func TestRateLimiterResetAll(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 1,
Burst: 1,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// Create multiple buckets
rl.Allow("key1")
rl.Allow("key2")
// Reset all
rl.ResetAll()
stats := rl.GetStats()
if stats.BucketCount != 0 {
t.Errorf("Expected 0 buckets after reset, got %d", stats.BucketCount)
}
}
func TestRateLimiterCleanup(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// Create some buckets
rl.Allow("key1")
rl.Allow("key2")
// Cleanup with very short max age
rl.Cleanup(1 * time.Nanosecond)
stats := rl.GetStats()
if stats.BucketCount != 0 {
t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount)
}
}
func TestRateLimiterProcess(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
nextHandler := func(ctx *fasthttp.RequestCtx) {
_, _ = ctx.WriteString("OK")
}
handler := mw.Process(nextHandler)
if handler == nil {
t.Error("Process() returned nil handler")
}
}
func TestRateLimiterGetStats(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
rl.Allow("key1")
rl.Allow("key2")
stats := rl.GetStats()
if stats.BucketCount != 2 {
t.Errorf("Expected BucketCount 2, got %d", stats.BucketCount)
}
if stats.Rate != 100 {
t.Errorf("Expected Rate 100, got %f", stats.Rate)
}
if stats.Burst != 200 {
t.Errorf("Expected Burst 200, got %f", stats.Burst)
}
// 测试优雅关闭
rl.StopCleanup()
}
func TestRateLimiterAutoCleanup(t *testing.T) {
// 使用自定义的创建方式,方便测试
cfg := &config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Key: "ip",
}
mw, err := NewRateLimiter(cfg)
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// 创建一些桶
rl.Allow("key1")
rl.Allow("key2")
rl.Allow("key3")
// 验证桶已创建
stats := rl.GetStats()
if stats.BucketCount != 3 {
t.Errorf("Expected 3 buckets, got %d", stats.BucketCount)
}
// 手动调用 Cleanup 模拟过期清理(使用很短的过期时间)
rl.Cleanup(1 * time.Nanosecond)
// 验证所有桶已被清理
stats = rl.GetStats()
if stats.BucketCount != 0 {
t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount)
}
// 测试优雅关闭
rl.StopCleanup()
}
func TestRateLimiterStopCleanup(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
// 验证可以正常关闭
rl.StopCleanup()
// 再次调用不应 panic
rl.StopCleanup()
}
func TestNewConnLimiter(t *testing.T) {
tests := []struct {
name string
keyType string
max int
perKey bool
wantErr bool
}{
{
name: "global limit",
max: 100,
perKey: false,
},
{
name: "per-key by IP",
max: 10,
perKey: true,
keyType: "ip",
},
{
name: "per-key by header",
max: 10,
perKey: true,
keyType: "header",
},
{
name: "zero max",
max: 0,
wantErr: true,
},
{
name: "negative max",
max: -1,
wantErr: true,
},
{
name: "invalid key type",
max: 10,
perKey: true,
keyType: "invalid",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cl, err := NewConnLimiter(tt.max, tt.perKey, tt.keyType)
if (err != nil) != tt.wantErr {
t.Errorf("NewConnLimiter() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && cl == nil {
t.Error("Expected non-nil ConnLimiter")
}
})
}
}
func TestConnLimiterGlobal(t *testing.T) {
cl, err := NewConnLimiter(2, false, "")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/")
// First two should succeed
if !cl.Acquire(ctx) {
t.Error("Expected first acquire to succeed")
}
if !cl.Acquire(ctx) {
t.Error("Expected second acquire to succeed")
}
// Third should fail
if cl.Acquire(ctx) {
t.Error("Expected third acquire to fail")
}
// Release one
cl.Release(ctx)
// Should succeed now
if !cl.Acquire(ctx) {
t.Error("Expected acquire after release to succeed")
}
}
func TestConnLimiterMiddleware(t *testing.T) {
cl, err := NewConnLimiter(1, false, "")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
middleware := cl.Middleware()
if middleware == nil {
t.Error("Expected non-nil middleware")
}
if middleware.Name() != "conn_limiter" {
t.Errorf("Expected name 'conn_limiter', got %s", middleware.Name())
}
}
// TestNewRateLimiter_SlidingWindow 测试滑动窗口算法
func TestNewRateLimiter_SlidingWindow(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
Algorithm: "sliding_window",
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
if mw == nil {
t.Error("Expected non-nil middleware for sliding_window")
}
if mw.Name() != "sliding_window_limiter" {
t.Errorf("Expected name 'sliding_window_limiter', got %s", mw.Name())
}
}
// TestNewRateLimiter_SlidingWindowDefault 测试滑动窗口默认配置
func TestNewRateLimiter_SlidingWindowDefault(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
Algorithm: "sliding_window",
SlidingWindow: 0, // 使用默认值
SlidingWindowMode: "", // 使用默认值
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
if mw == nil {
t.Error("Expected non-nil middleware")
}
}
// TestNewRateLimiter_SlidingWindowPrecise 测试滑动窗口精确模式
func TestNewRateLimiter_SlidingWindowPrecise(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
Algorithm: "sliding_window",
SlidingWindow: 1,
SlidingWindowMode: "precise",
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
if mw == nil {
t.Error("Expected non-nil middleware")
}
}
// TestNewRateLimiter_UnknownAlgorithm 测试未知算法
func TestNewRateLimiter_UnknownAlgorithm(t *testing.T) {
_, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
Algorithm: "unknown_algo",
})
if err == nil {
t.Error("NewRateLimiter() should return error for unknown algorithm")
}
}
// TestSlidingWindowLimiterWrapper_Process 测试滑动窗口包装器的 Process 方法
func TestSlidingWindowLimiterWrapper_Process(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 100,
Burst: 100,
Algorithm: "sliding_window",
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := mw.Process(nextHandler)
if handler == nil {
t.Fatal("Process() returned nil handler")
}
ctx := testutil.NewRequestCtx("GET", "/test")
handler(ctx)
if !called {
t.Error("Next handler should be called when rate limit allows")
}
}
// TestSlidingWindowLimiterWrapper_ProcessDenied 测试滑动窗口拒绝请求
func TestSlidingWindowLimiterWrapper_ProcessDenied(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 1,
Burst: 1,
Algorithm: "sliding_window",
SlidingWindowMode: "precise",
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
callCount := 0
nextHandler := func(ctx *fasthttp.RequestCtx) {
callCount++
}
handler := mw.Process(nextHandler)
ctx := testutil.NewRequestCtx("GET", "/test")
// 第一个请求应该被允许
handler(ctx)
// 第二个请求应该被拒绝
handler(ctx)
if callCount != 1 {
t.Errorf("Expected next handler to be called once, got %d", callCount)
}
}
// TestRateLimiter_GetRetryAfter 测试 getRetryAfter 方法
func TestRateLimiter_GetRetryAfter(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 10,
Burst: 10,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
defer rl.StopCleanup()
// 测试不存在的键
retryAfter := rl.getRetryAfter("nonexistent")
if retryAfter != 1 {
t.Errorf("getRetryAfter(nonexistent) = %d, want 1", retryAfter)
}
// 创建一个桶并消耗令牌
key := "test-key"
for range 10 {
rl.Allow(key)
}
// 获取重试时间
retryAfter = rl.getRetryAfter(key)
if retryAfter < 1 {
t.Errorf("getRetryAfter() = %d, want at least 1", retryAfter)
}
}
// TestKeyByIP 测试 keyByIP 函数
func TestKeyByIP(t *testing.T) {
ctx := testutil.NewRequestCtx("GET", "/test")
key := keyByIP(ctx)
if key == "" {
t.Error("keyByIP() should return non-empty string")
}
if key == "unknown" {
t.Error("keyByIP() should return IP address, not 'unknown'")
}
}
// TestKeyByHeader 测试 keyByHeader 函数
func TestKeyByHeader(t *testing.T) {
// 测试有头部的情况
ctx := testutil.NewRequestCtx("GET", "/test")
ctx.Request.Header.Set("X-RateLimit-Key", "custom-key")
key := keyByHeader(ctx)
if key != "custom-key" {
t.Errorf("keyByHeader() = %q, want 'custom-key'", key)
}
// 测试没有头部的情况(应该回退到 IP
ctx2 := testutil.NewRequestCtx("GET", "/test")
key2 := keyByHeader(ctx2)
if key2 == "" {
t.Error("keyByHeader() should fallback to IP when header not set")
}
}
// TestConnLimiter_PerKey 测试按键限制
func TestConnLimiter_PerKey(t *testing.T) {
cl, err := NewConnLimiter(2, true, "ip")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/")
// 同一个键的前两个应该成功
if !cl.Acquire(ctx) {
t.Error("Expected first acquire to succeed")
}
if !cl.Acquire(ctx) {
t.Error("Expected second acquire to succeed")
}
// 第三个应该失败
if cl.Acquire(ctx) {
t.Error("Expected third acquire to fail")
}
// 释放一个
cl.Release(ctx)
// 现在应该成功
if !cl.Acquire(ctx) {
t.Error("Expected acquire after release to succeed")
}
}
// TestConnLimiter_ReleaseUnderflow 测试 Release 下溢保护
func TestConnLimiter_ReleaseUnderflow(t *testing.T) {
cl, err := NewConnLimiter(2, true, "ip")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
ctx := testutil.NewRequestCtx("GET", "/")
// 在没有 Acquire 的情况下 Release测试下溢保护
cl.Release(ctx) // 不应该 panic
// 验证计数不会变成负数
cl.Acquire(ctx)
cl.Acquire(ctx)
if cl.Acquire(ctx) {
t.Error("Expected third acquire to fail")
}
}
// TestConnLimiterMiddleware_Process 测试连接限制中间件 Process
func TestConnLimiterMiddleware_Process(t *testing.T) {
cl, err := NewConnLimiter(1, false, "")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
mw := cl.Middleware()
called := false
nextHandler := func(ctx *fasthttp.RequestCtx) {
called = true
ctx.SetStatusCode(200)
}
handler := mw.Process(nextHandler)
ctx := testutil.NewRequestCtx("GET", "/test")
// 第一个请求应该成功
handler(ctx)
if !called {
t.Error("Next handler should be called")
}
if ctx.Response.StatusCode() != 200 {
t.Errorf("Status = %d, want 200", ctx.Response.StatusCode())
}
}
// TestConnLimiterMiddleware_ProcessLimitExceeded 测试连接限制超出
func TestConnLimiterMiddleware_ProcessLimitExceeded(t *testing.T) {
cl, err := NewConnLimiter(1, false, "")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
mw := cl.Middleware()
callCount := 0
nextHandler := func(ctx *fasthttp.RequestCtx) {
callCount++
ctx.SetStatusCode(200)
}
handler := mw.Process(nextHandler)
// 用尽连接限制
ctx1 := testutil.NewRequestCtx("GET", "/test1")
cl.Acquire(ctx1) // 手动占用一个槽位
// 现在应该无法获取新的连接
ctx2 := testutil.NewRequestCtx("GET", "/test2")
handler(ctx2)
if callCount != 0 {
t.Error("Next handler should NOT be called when limit exceeded")
}
if ctx2.Response.StatusCode() != 503 {
t.Errorf("Status = %d, want 503", ctx2.Response.StatusCode())
}
}
// TestNewSlidingWindowLimiterWrapper_Error 测试滑动窗口包装器错误情况
func TestNewSlidingWindowLimiterWrapper_Error(t *testing.T) {
_, err := NewSlidingWindowLimiterWrapper(&config.RateLimitConfig{
RequestRate: 100,
Key: "invalid_key_type",
}, time.Second, false)
if err == nil {
t.Error("NewSlidingWindowLimiterWrapper should return error for invalid key type")
}
}
// TestRateLimiter_Name 测试 RateLimiter Name 方法
func TestRateLimiter_Name(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 10,
Burst: 10,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
defer rl.StopCleanup()
if rl.Name() != "rate_limiter" {
t.Errorf("Name() = %q, want 'rate_limiter'", rl.Name())
}
}
// TestRateLimiter_ProcessDenied 测试限流拒绝
func TestRateLimiter_ProcessDenied(t *testing.T) {
mw, err := NewRateLimiter(&config.RateLimitConfig{
RequestRate: 1,
Burst: 1,
})
if err != nil {
t.Fatalf("NewRateLimiter() error: %v", err)
}
rl, ok := mw.(*RateLimiter)
if !ok {
t.Fatalf("Expected *RateLimiter, got %T", mw)
}
defer rl.StopCleanup()
callCount := 0
nextHandler := func(ctx *fasthttp.RequestCtx) {
callCount++
}
handler := rl.Process(nextHandler)
// 第一个请求应该成功
ctx1 := testutil.NewRequestCtx("GET", "/test")
handler(ctx1)
// 第二个请求应该被限流(使用不同的 context
ctx2 := testutil.NewRequestCtx("GET", "/test")
handler(ctx2)
if callCount != 1 {
t.Errorf("Expected next handler to be called once, got %d", callCount)
}
// 检查第二个请求的状态码
if ctx2.Response.StatusCode() != 429 {
t.Errorf("Status = %d, want 429", ctx2.Response.StatusCode())
}
}
// TestKeyByIP_Unknown 测试无法获取 IP 的情况
func TestKeyByIP_Unknown(t *testing.T) {
// 创建一个没有设置远程地址的上下文
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/test")
key := keyByIP(ctx)
// netutil.ExtractClientIPNet 会返回默认值 0.0.0.0 而不是 nil
// 所以这里验证返回值不是空的即可
if key == "" {
t.Error("keyByIP() should return non-empty string")
}
}
// TestConnLimiter_MiddlewareIdentity 验证 Middleware() 返回相同实例
func TestConnLimiter_MiddlewareIdentity(t *testing.T) {
cl, err := NewConnLimiter(100, false, "")
if err != nil {
t.Fatalf("NewConnLimiter() error: %v", err)
}
// Middleware() 应该返回自身
middleware := cl.Middleware()
if middleware != cl {
t.Error("Middleware() should return the same ConnLimiter instance")
}
// 验证返回的实例实现了 Middleware 接口
if middleware.Name() != "conn_limiter" {
t.Errorf("Name() = %s, want 'conn_limiter'", middleware.Name())
}
}