Conflict: sendfile.go (!linux build tag) was incorrectly modified to include linuxSendfile and getSocketFd functions which already exist in sendfile_linux.go. Resolution: Keep HEAD version (simple fallback returning ENOTSUP) as Linux implementation is properly separated in sendfile_linux.go. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
461 lines
8.8 KiB
Go
461 lines
8.8 KiB
Go
// Package security 提供速率限制功能的测试。
//
// 该文件测试速率限制模块的各项功能,包括:
//   - 速率限制器创建
//   - 令牌桶算法
//   - 令牌补充机制
//   - 计数器重置
//   - 连接数限制
//   - 统计信息获取
//
// 作者:xfy
|
||
package security
|
||
|
||
import (
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/valyala/fasthttp"
|
||
"rua.plus/lolly/internal/config"
|
||
)
|
||
|
||
func TestNewRateLimiter(t *testing.T) {
|
||
tests := []struct {
|
||
cfg *config.RateLimitConfig
|
||
name string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "nil config",
|
||
cfg: nil,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "valid config",
|
||
cfg: &config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 200,
|
||
},
|
||
},
|
||
{
|
||
name: "zero rate",
|
||
cfg: &config.RateLimitConfig{
|
||
RequestRate: 0,
|
||
Burst: 100,
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "burst less than rate",
|
||
cfg: &config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 50,
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "key by IP",
|
||
cfg: &config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 200,
|
||
Key: "ip",
|
||
},
|
||
},
|
||
{
|
||
name: "key by header",
|
||
cfg: &config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 200,
|
||
Key: "header",
|
||
},
|
||
},
|
||
{
|
||
name: "unknown key type",
|
||
cfg: &config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 200,
|
||
Key: "unknown",
|
||
},
|
||
wantErr: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
rl, err := NewRateLimiter(tt.cfg)
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("NewRateLimiter() error = %v, wantErr %v", err, tt.wantErr)
|
||
}
|
||
if !tt.wantErr && rl == nil {
|
||
t.Error("Expected non-nil RateLimiter")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestRateLimiterAllow(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 10,
|
||
Burst: 10,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
// Test burst allowance
|
||
key := "test-key"
|
||
|
||
// Should allow burst requests
|
||
for i := 0; i < 10; i++ {
|
||
if !rl.Allow(key) {
|
||
t.Errorf("Expected request %d to be allowed", i+1)
|
||
}
|
||
}
|
||
|
||
// Next request should be denied (burst exhausted)
|
||
if rl.Allow(key) {
|
||
t.Error("Expected request to be denied after burst exhausted")
|
||
}
|
||
}
|
||
|
||
func TestRateLimiterTokenRefill(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 100, // 100 tokens per second
|
||
Burst: 100,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
key := "refill-test"
|
||
|
||
// Exhaust the burst
|
||
for i := 0; i < 100; i++ {
|
||
rl.Allow(key)
|
||
}
|
||
|
||
// Should be denied
|
||
if rl.Allow(key) {
|
||
t.Error("Expected request to be denied")
|
||
}
|
||
|
||
// Wait for token refill (10ms should give us 1 token at 100/s)
|
||
time.Sleep(15 * time.Millisecond)
|
||
|
||
// Should be allowed now
|
||
if !rl.Allow(key) {
|
||
t.Error("Expected request to be allowed after refill")
|
||
}
|
||
}
|
||
|
||
func TestRateLimiterReset(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 1,
|
||
Burst: 1,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
key := "reset-test"
|
||
|
||
// Exhaust
|
||
rl.Allow(key)
|
||
if rl.Allow(key) {
|
||
t.Error("Expected denial")
|
||
}
|
||
|
||
// Reset
|
||
rl.Reset(key)
|
||
|
||
// Should be allowed again
|
||
if !rl.Allow(key) {
|
||
t.Error("Expected request to be allowed after reset")
|
||
}
|
||
}
|
||
|
||
func TestRateLimiterResetAll(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 1,
|
||
Burst: 1,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
// Create multiple buckets
|
||
rl.Allow("key1")
|
||
rl.Allow("key2")
|
||
|
||
// Reset all
|
||
rl.ResetAll()
|
||
|
||
stats := rl.GetStats()
|
||
if stats.BucketCount != 0 {
|
||
t.Errorf("Expected 0 buckets after reset, got %d", stats.BucketCount)
|
||
}
|
||
}
|
||
|
||
func TestRateLimiterCleanup(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 100,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
// Create some buckets
|
||
rl.Allow("key1")
|
||
rl.Allow("key2")
|
||
|
||
// Cleanup with very short max age
|
||
rl.Cleanup(1 * time.Nanosecond)
|
||
|
||
stats := rl.GetStats()
|
||
if stats.BucketCount != 0 {
|
||
t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount)
|
||
}
|
||
}
|
||
|
||
func TestRateLimiterProcess(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 100,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||
_, _ = ctx.WriteString("OK")
|
||
}
|
||
|
||
handler := mw.Process(nextHandler)
|
||
if handler == nil {
|
||
t.Error("Process() returned nil handler")
|
||
}
|
||
}
|
||
|
||
func TestRateLimiterGetStats(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 200,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
rl.Allow("key1")
|
||
rl.Allow("key2")
|
||
|
||
stats := rl.GetStats()
|
||
if stats.BucketCount != 2 {
|
||
t.Errorf("Expected BucketCount 2, got %d", stats.BucketCount)
|
||
}
|
||
if stats.Rate != 100 {
|
||
t.Errorf("Expected Rate 100, got %f", stats.Rate)
|
||
}
|
||
if stats.Burst != 200 {
|
||
t.Errorf("Expected Burst 200, got %f", stats.Burst)
|
||
}
|
||
|
||
// 测试优雅关闭
|
||
rl.StopCleanup()
|
||
}
|
||
|
||
func TestRateLimiterAutoCleanup(t *testing.T) {
|
||
// 使用自定义的创建方式,方便测试
|
||
cfg := &config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 200,
|
||
Key: "ip",
|
||
}
|
||
|
||
mw, err := NewRateLimiter(cfg)
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
// 创建一些桶
|
||
rl.Allow("key1")
|
||
rl.Allow("key2")
|
||
rl.Allow("key3")
|
||
|
||
// 验证桶已创建
|
||
stats := rl.GetStats()
|
||
if stats.BucketCount != 3 {
|
||
t.Errorf("Expected 3 buckets, got %d", stats.BucketCount)
|
||
}
|
||
|
||
// 手动调用 Cleanup 模拟过期清理(使用很短的过期时间)
|
||
rl.Cleanup(1 * time.Nanosecond)
|
||
|
||
// 验证所有桶已被清理
|
||
stats = rl.GetStats()
|
||
if stats.BucketCount != 0 {
|
||
t.Errorf("Expected 0 buckets after cleanup, got %d", stats.BucketCount)
|
||
}
|
||
|
||
// 测试优雅关闭
|
||
rl.StopCleanup()
|
||
}
|
||
|
||
func TestRateLimiterStopCleanup(t *testing.T) {
|
||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||
RequestRate: 100,
|
||
Burst: 200,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||
}
|
||
|
||
rl, ok := mw.(*RateLimiter)
|
||
if !ok {
|
||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||
}
|
||
|
||
// 验证可以正常关闭
|
||
rl.StopCleanup()
|
||
|
||
// 再次调用不应 panic
|
||
rl.StopCleanup()
|
||
}
|
||
|
||
func TestNewConnLimiter(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
keyType string
|
||
max int
|
||
perKey bool
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "global limit",
|
||
max: 100,
|
||
perKey: false,
|
||
},
|
||
{
|
||
name: "per-key by IP",
|
||
max: 10,
|
||
perKey: true,
|
||
keyType: "ip",
|
||
},
|
||
{
|
||
name: "per-key by header",
|
||
max: 10,
|
||
perKey: true,
|
||
keyType: "header",
|
||
},
|
||
{
|
||
name: "zero max",
|
||
max: 0,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "negative max",
|
||
max: -1,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "invalid key type",
|
||
max: 10,
|
||
perKey: true,
|
||
keyType: "invalid",
|
||
wantErr: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
cl, err := NewConnLimiter(tt.max, tt.perKey, tt.keyType)
|
||
if (err != nil) != tt.wantErr {
|
||
t.Errorf("NewConnLimiter() error = %v, wantErr %v", err, tt.wantErr)
|
||
}
|
||
if !tt.wantErr && cl == nil {
|
||
t.Error("Expected non-nil ConnLimiter")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestConnLimiterGlobal(t *testing.T) {
|
||
cl, err := NewConnLimiter(2, false, "")
|
||
if err != nil {
|
||
t.Fatalf("NewConnLimiter() error: %v", err)
|
||
}
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
|
||
// First two should succeed
|
||
if !cl.Acquire(ctx) {
|
||
t.Error("Expected first acquire to succeed")
|
||
}
|
||
if !cl.Acquire(ctx) {
|
||
t.Error("Expected second acquire to succeed")
|
||
}
|
||
|
||
// Third should fail
|
||
if cl.Acquire(ctx) {
|
||
t.Error("Expected third acquire to fail")
|
||
}
|
||
|
||
// Release one
|
||
cl.Release(ctx)
|
||
|
||
// Should succeed now
|
||
if !cl.Acquire(ctx) {
|
||
t.Error("Expected acquire after release to succeed")
|
||
}
|
||
}
|
||
|
||
func TestConnLimiterMiddleware(t *testing.T) {
|
||
cl, err := NewConnLimiter(1, false, "")
|
||
if err != nil {
|
||
t.Fatalf("NewConnLimiter() error: %v", err)
|
||
}
|
||
|
||
middleware := cl.Middleware()
|
||
if middleware == nil {
|
||
t.Error("Expected non-nil middleware")
|
||
}
|
||
if middleware.Name() != "conn_limiter" {
|
||
t.Errorf("Expected name 'conn_limiter', got %s", middleware.Name())
|
||
}
|
||
}
|