lolly/internal/middleware/security/ratelimit_test.go
xfy f2352ab9cc docs(config,stream,logging,handler,proxy,cache,server,ssl,middleware): 为核心模块添加详细 GoDoc 文档注释
- config: 为 Config 和所有子配置结构添加完整文档,包含使用示例和注意事项
- stream: 为负载均衡器和服务器添加详细的参数、返回值和功能说明
- logging: 为日志格式化和输出函数添加文档,说明支持的变量替换
- handler: 为路由器、静态文件和 sendfile 处理器添加文档
- proxy: 为健康检查器和代理功能添加完整文档
- cache/server/ssl/middleware: 补充相关模块的文档注释
- config.example.yaml: 添加可信代理配置、加密套件示例,更新压缩级别说明

Co-Authored-By: Claude <noreply@anthropic.com>
2026-04-07 15:36:09 +08:00

395 lines
7.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package security 提供速率限制功能的测试。
//
// 该文件测试速率限制模块的各项功能,包括:
// - 速率限制器创建
// - 令牌桶算法
// - 令牌补充机制
// - 计数器重置
// - 连接数限制
// - 统计信息获取
//
// 作者: xfy
package security
import (
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
// TestNewRateLimiter checks constructor validation. A nil config, a zero
// request rate, a burst smaller than the rate and an unrecognised key type are
// expected to be rejected; every other config must yield a non-nil limiter.
func TestNewRateLimiter(t *testing.T) {
	cases := []struct {
		name    string
		cfg     *config.RateLimitConfig
		wantErr bool
	}{
		{name: "nil config", cfg: nil, wantErr: true},
		{name: "valid config", cfg: &config.RateLimitConfig{RequestRate: 100, Burst: 200}},
		{name: "zero rate", cfg: &config.RateLimitConfig{RequestRate: 0, Burst: 100}, wantErr: true},
		{name: "burst less than rate", cfg: &config.RateLimitConfig{RequestRate: 100, Burst: 50}, wantErr: true},
		{name: "key by IP", cfg: &config.RateLimitConfig{RequestRate: 100, Burst: 200, Key: "ip"}},
		{name: "key by header", cfg: &config.RateLimitConfig{RequestRate: 100, Burst: 200, Key: "header"}},
		{name: "unknown key type", cfg: &config.RateLimitConfig{RequestRate: 100, Burst: 200, Key: "unknown"}, wantErr: true},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got, err := NewRateLimiter(tc.cfg)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("NewRateLimiter() error = %v, wantErr %v", err, tc.wantErr)
			}
			// Rejected configs have nothing further to check.
			if tc.wantErr {
				return
			}
			if got == nil {
				t.Error("Expected non-nil RateLimiter")
			}
		})
	}
}
// TestRateLimiterAllow verifies that a fresh bucket admits a full burst of
// back-to-back requests for one key and rejects the request that follows.
func TestRateLimiterAllow(t *testing.T) {
	const burst = 10

	mw, err := NewRateLimiter(&config.RateLimitConfig{
		RequestRate: burst,
		Burst:       burst,
	})
	if err != nil {
		t.Fatalf("NewRateLimiter() error: %v", err)
	}
	limiter, isRL := mw.(*RateLimiter)
	if !isRL {
		t.Fatalf("Expected *RateLimiter, got %T", mw)
	}

	const key = "test-key"

	// Every request inside the burst must be admitted (numbered from 1).
	for n := 1; n <= burst; n++ {
		if !limiter.Allow(key) {
			t.Errorf("Expected request %d to be allowed", n)
		}
	}

	// The burst is spent, so the very next request must be rejected.
	if limiter.Allow(key) {
		t.Error("Expected request to be denied after burst exhausted")
	}
}
// TestRateLimiterTokenRefill verifies that an exhausted bucket is replenished
// over time at the configured rate (100 tokens/s here).
//
// The bucket is drained until the first denial rather than by issuing exactly
// Burst requests. With a fixed request count, a refill that lands while the
// loop runs, or before the denial check, on a slow or loaded machine leaves a
// token behind and makes the "should be denied" expectation fail spuriously.
// Draining until a denial is actually observed removes that timing window.
func TestRateLimiterTokenRefill(t *testing.T) {
	mw, err := NewRateLimiter(&config.RateLimitConfig{
		RequestRate: 100, // 100 tokens per second
		Burst:       100,
	})
	if err != nil {
		t.Fatalf("NewRateLimiter() error: %v", err)
	}
	rl, ok := mw.(*RateLimiter)
	if !ok {
		t.Fatalf("Expected *RateLimiter, got %T", mw)
	}

	const key = "refill-test"
	const burst = 100
	// maxDrain bounds the drain loop so that a limiter which never denies
	// fails the test instead of spinning forever.
	const maxDrain = 10 * burst

	// Drain the bucket until the first denial.
	allowed := 0
	for allowed < maxDrain && rl.Allow(key) {
		allowed++
	}
	if allowed == maxDrain {
		t.Fatalf("Expected a denial within %d requests, none occurred", maxDrain)
	}
	// Refill can only add tokens, so at least the full burst must have passed.
	if allowed < burst {
		t.Errorf("Expected at least %d requests to be allowed, got %d", burst, allowed)
	}

	// One token takes 10ms at 100/s; sleep with margin. time.Sleep pauses for
	// at least the requested duration, so this side is not timing-sensitive.
	time.Sleep(15 * time.Millisecond)

	// Should be allowed now
	if !rl.Allow(key) {
		t.Error("Expected request to be allowed after refill")
	}
}
// TestRateLimiterReset verifies that Reset restores a drained bucket, so the
// key is admitted again immediately without waiting for a refill.
func TestRateLimiterReset(t *testing.T) {
	mw, err := NewRateLimiter(&config.RateLimitConfig{RequestRate: 1, Burst: 1})
	if err != nil {
		t.Fatalf("NewRateLimiter() error: %v", err)
	}
	limiter, isRL := mw.(*RateLimiter)
	if !isRL {
		t.Fatalf("Expected *RateLimiter, got %T", mw)
	}

	const key = "reset-test"

	// Spend the only token; at 1 token/s the second request is over the limit.
	limiter.Allow(key)
	if stillAllowed := limiter.Allow(key); stillAllowed {
		t.Error("Expected denial")
	}

	limiter.Reset(key)

	if readmitted := limiter.Allow(key); !readmitted {
		t.Error("Expected request to be allowed after reset")
	}
}
// TestRateLimiterResetAll verifies that ResetAll discards every per-key bucket
// and that a previously drained key is admitted again afterwards.
//
// The bucket count is checked before the reset as well. Without that
// precondition the test would also pass if Allow never created buckets.
func TestRateLimiterResetAll(t *testing.T) {
	mw, err := NewRateLimiter(&config.RateLimitConfig{
		RequestRate: 1,
		Burst:       1,
	})
	if err != nil {
		t.Fatalf("NewRateLimiter() error: %v", err)
	}
	rl, ok := mw.(*RateLimiter)
	if !ok {
		t.Fatalf("Expected *RateLimiter, got %T", mw)
	}

	// Create multiple buckets; with Burst 1 each key is now drained.
	rl.Allow("key1")
	rl.Allow("key2")
	if got := rl.GetStats().BucketCount; got != 2 {
		t.Fatalf("Expected 2 buckets before reset, got %d", got)
	}

	// Reset all
	rl.ResetAll()

	stats := rl.GetStats()
	if stats.BucketCount != 0 {
		t.Errorf("Expected 0 buckets after reset, got %d", stats.BucketCount)
	}
	// At 1 token/s no refill can explain an admission here; only the reset can.
	if !rl.Allow("key1") {
		t.Error("Expected drained key to be allowed after ResetAll")
	}
}
// TestRateLimiterCleanup verifies that Cleanup evicts buckets older than the
// given maximum age. A 1ns max age is used, so both freshly created buckets
// are expected to be evicted.
//
// The bucket count is checked before Cleanup as well. Without that
// precondition the test would also pass if Allow never created buckets.
func TestRateLimiterCleanup(t *testing.T) {
	mw, err := NewRateLimiter(&config.RateLimitConfig{
		RequestRate: 100,
		Burst:       100,
	})
	if err != nil {
		t.Fatalf("NewRateLimiter() error: %v", err)
	}
	rl, ok := mw.(*RateLimiter)
	if !ok {
		t.Fatalf("Expected *RateLimiter, got %T", mw)
	}

	// Create some buckets
	rl.Allow("key1")
	rl.Allow("key2")
	if got := rl.GetStats().BucketCount; got != 2 {
		t.Fatalf("Expected 2 buckets before cleanup, got %d", got)
	}

	// Cleanup with very short max age
	rl.Cleanup(1 * time.Nanosecond)

	stats := rl.GetStats()
	if stats.BucketCount != 0 {
		t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount)
	}
}
// TestRateLimiterProcess verifies that Process wraps the next handler and that
// a request within the burst is forwarded to it unchanged.
//
// The original check only asserted that the returned handler was non-nil and
// never executed it, so the middleware path itself went untested.
func TestRateLimiterProcess(t *testing.T) {
	mw, err := NewRateLimiter(&config.RateLimitConfig{
		RequestRate: 100,
		Burst:       100,
	})
	if err != nil {
		t.Fatalf("NewRateLimiter() error: %v", err)
	}

	called := false
	nextHandler := func(ctx *fasthttp.RequestCtx) {
		called = true
		_, _ = ctx.WriteString("OK")
	}

	handler := mw.Process(nextHandler)
	// Fatal: the handler is invoked below and must not be nil.
	if handler == nil {
		t.Fatal("Process() returned nil handler")
	}

	// A zero RequestCtx is also used by the ConnLimiter tests in this file.
	// The first request of a fresh limiter is within the burst, so it must
	// reach the next handler.
	ctx := &fasthttp.RequestCtx{}
	handler(ctx)

	if !called {
		t.Error("Expected next handler to be called for a request within the burst")
	}
	if got := string(ctx.Response.Body()); got != "OK" {
		t.Errorf("Expected body %q, got %q", "OK", got)
	}
}
// TestRateLimiterGetStats verifies that GetStats reports one bucket per
// distinct key seen so far, plus the configured rate and burst.
func TestRateLimiterGetStats(t *testing.T) {
	mw, err := NewRateLimiter(&config.RateLimitConfig{RequestRate: 100, Burst: 200})
	if err != nil {
		t.Fatalf("NewRateLimiter() error: %v", err)
	}
	limiter, isRL := mw.(*RateLimiter)
	if !isRL {
		t.Fatalf("Expected *RateLimiter, got %T", mw)
	}

	// Touch two distinct keys so that two buckets exist.
	for _, k := range []string{"key1", "key2"} {
		limiter.Allow(k)
	}

	s := limiter.GetStats()
	if got := s.BucketCount; got != 2 {
		t.Errorf("Expected BucketCount 2, got %d", got)
	}
	if got := s.Rate; got != 100 {
		t.Errorf("Expected Rate 100, got %f", got)
	}
	if got := s.Burst; got != 200 {
		t.Errorf("Expected Burst 200, got %f", got)
	}
}
// TestNewConnLimiter checks constructor validation for the connection limiter.
// Global and per-key ("ip"/"header") limiters with a positive maximum are
// accepted. A zero or negative maximum, or an unknown per-key type, must be
// rejected.
func TestNewConnLimiter(t *testing.T) {
	type testCase struct {
		name    string
		limit   int
		perKey  bool
		keyType string
		wantErr bool
	}

	cases := []testCase{
		{name: "global limit", limit: 100, perKey: false},
		{name: "per-key by IP", limit: 10, perKey: true, keyType: "ip"},
		{name: "per-key by header", limit: 10, perKey: true, keyType: "header"},
		{name: "zero max", limit: 0, wantErr: true},
		{name: "negative max", limit: -1, wantErr: true},
		{name: "invalid key type", limit: 10, perKey: true, keyType: "invalid", wantErr: true},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			limiter, err := NewConnLimiter(tc.limit, tc.perKey, tc.keyType)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("NewConnLimiter() error = %v, wantErr %v", err, tc.wantErr)
			}
			// Rejected configurations have nothing further to check.
			if tc.wantErr {
				return
			}
			if limiter == nil {
				t.Error("Expected non-nil ConnLimiter")
			}
		})
	}
}
// TestConnLimiterGlobal verifies the global (not per-key) connection limit.
// With a capacity of 2, a third slot is refused until one held slot is
// released, after which acquiring succeeds again.
func TestConnLimiterGlobal(t *testing.T) {
	limiter, err := NewConnLimiter(2, false, "")
	if err != nil {
		t.Fatalf("NewConnLimiter() error: %v", err)
	}
	reqCtx := &fasthttp.RequestCtx{}

	// Fill both available slots.
	if got := limiter.Acquire(reqCtx); !got {
		t.Error("Expected first acquire to succeed")
	}
	if got := limiter.Acquire(reqCtx); !got {
		t.Error("Expected second acquire to succeed")
	}

	// Capacity reached: another slot must be refused.
	if got := limiter.Acquire(reqCtx); got {
		t.Error("Expected third acquire to fail")
	}

	// Freeing one slot makes room for exactly one more.
	limiter.Release(reqCtx)
	if got := limiter.Acquire(reqCtx); !got {
		t.Error("Expected acquire after release to succeed")
	}
}
// TestConnLimiterMiddleware verifies that a ConnLimiter can be exposed as a
// middleware and that the middleware reports the name "conn_limiter".
func TestConnLimiterMiddleware(t *testing.T) {
	cl, err := NewConnLimiter(1, false, "")
	if err != nil {
		t.Fatalf("NewConnLimiter() error: %v", err)
	}

	middleware := cl.Middleware()
	// Fatal rather than Error: execution would otherwise continue to Name()
	// on a nil value, which typically panics and hides the real failure.
	if middleware == nil {
		t.Fatal("Expected non-nil middleware")
	}
	if got := middleware.Name(); got != "conn_limiter" {
		t.Errorf("Expected name 'conn_limiter', got %s", got)
	}
}