lolly/internal/server/server_test.go

3366 lines
82 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package server 提供 HTTP 服务器功能的测试。
//
// 该文件测试服务器模块的各项功能,包括:
// - 服务器创建和初始化
// - 启动和停止控制
// - 优雅关闭
// - 中间件链构建
// - 请求统计追踪
// - 监听器管理
// - TLS 配置
// - 代理缓存统计
//
// 作者: xfy
package server
import (
"context"
"fmt"
"net"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/matcher"
"rua.plus/lolly/internal/middleware/accesslog"
"rua.plus/lolly/internal/middleware/security"
"rua.plus/lolly/internal/proxy"
"rua.plus/lolly/internal/ssl"
"rua.plus/lolly/internal/testutil"
"rua.plus/lolly/internal/version"
)
// TestNew verifies that New stores the supplied config and returns a Server
// that is not yet running and has no fasthttp server attached.
func TestNew(t *testing.T) {
	staticCfg := config.StaticConfig{
		Path:  "/",
		Root:  "./static",
		Index: []string{"index.html"},
	}
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":8080",
			Static: []config.StaticConfig{staticCfg},
		}},
	}

	srv := New(cfg)
	if srv == nil {
		t.Fatal("New() returned nil, expected non-nil Server")
	}

	// Construction alone must not start anything.
	if srv.config != cfg {
		t.Error("Server.config not set correctly")
	}
	if srv.running.Load() {
		t.Error("Server.running should be false initially")
	}
	if srv.fastServer != nil {
		t.Error("Server.fastServer should be nil before Start()")
	}
}
// TestStopWithoutServer checks that StopWithTimeout returns nil on a server
// that was never started.
func TestStopWithoutServer(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	// Nothing is running yet, so stopping must succeed.
	if err := srv.StopWithTimeout(5 * time.Second); err != nil {
		t.Errorf("StopWithTimeout() on non-started server returned error: %v", err)
	}
}
// TestGracefulStop checks that GracefulStop returns nil on a server that was
// never started.
func TestGracefulStop(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	// Nothing is running yet, so a graceful stop must succeed.
	if err := srv.GracefulStop(5 * time.Second); err != nil {
		t.Errorf("GracefulStop() on non-started server returned error: %v", err)
	}
}
// TestStopAfterStop 测试多次调用 Stop
func TestStopAfterStop(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
// 多次调用 StopWithTimeout 应该都是安全的
for i := range 3 {
err := s.StopWithTimeout(5 * time.Second)
if err != nil {
t.Errorf("StopWithTimeout() call %d returned error: %v", i+1, err)
}
}
}
// TestGracefulStopWithZeroTimeout checks that a zero timeout is accepted by
// GracefulStop on a server that was never started.
func TestGracefulStopWithZeroTimeout(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	if err := srv.GracefulStop(0); err != nil {
		t.Errorf("GracefulStop(0) returned error: %v", err)
	}
}
// TestBuildMiddlewareChain_AccessLog builds a chain with a zero-value logging
// config and expects a non-nil chain without error.
func TestBuildMiddlewareChain_AccessLog(t *testing.T) {
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	}
	srv := New(cfg)
	serverCfg := &cfg.Servers[0]

	mwChain, buildErr := srv.buildMiddlewareChain(serverCfg)
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_AccessControl builds a chain with an IP allow-list
// configured and expects a non-nil chain without error.
func TestBuildMiddlewareChain_AccessControl(t *testing.T) {
	access := config.AccessConfig{Allow: []string{"127.0.0.1"}}
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:   ":8080",
			Security: config.SecurityConfig{Access: access},
		}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_RateLimiter builds a chain with a rate limit
// (100 req rate, burst 200) and expects a non-nil chain without error.
func TestBuildMiddlewareChain_RateLimiter(t *testing.T) {
	limit := config.RateLimitConfig{
		RequestRate: 100,
		Burst:       200,
	}
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:   ":8080",
			Security: config.SecurityConfig{RateLimit: limit},
		}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_Rewrite builds a chain with one regex rewrite rule
// and expects a non-nil chain without error.
func TestBuildMiddlewareChain_Rewrite(t *testing.T) {
	rules := []config.RewriteRule{
		{Pattern: "/old/(.*)", Replacement: "/new/$1"},
	}
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:  ":8080",
			Rewrite: rules,
		}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_Compression builds a chain with compression level 6
// and expects a non-nil chain without error.
func TestBuildMiddlewareChain_Compression(t *testing.T) {
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:      ":8080",
			Compression: config.CompressionConfig{Level: 6},
		}},
	}
	srv := New(cfg)
	serverCfg := &cfg.Servers[0]

	mwChain, buildErr := srv.buildMiddlewareChain(serverCfg)
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_SecurityHeaders builds a chain with
// X-Frame-Options and X-Content-Type-Options configured and expects a
// non-nil chain without error.
func TestBuildMiddlewareChain_SecurityHeaders(t *testing.T) {
	headers := config.SecurityHeaders{
		XFrameOptions:       "DENY",
		XContentTypeOptions: "nosniff",
	}
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:   ":8080",
			Security: config.SecurityConfig{Headers: headers},
		}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_AllMiddlewares enables access control, rate
// limiting, security headers, rewrite and compression together and expects
// a non-nil chain without error.
func TestBuildMiddlewareChain_AllMiddlewares(t *testing.T) {
	secCfg := config.SecurityConfig{
		Access:    config.AccessConfig{Allow: []string{"127.0.0.1"}},
		RateLimit: config.RateLimitConfig{RequestRate: 100, Burst: 200},
		Headers:   config.SecurityHeaders{XFrameOptions: "DENY"},
	}
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:   ":8080",
			Security: secCfg,
			Rewrite: []config.RewriteRule{
				{Pattern: "/old/(.*)", Replacement: "/new/$1"},
			},
			Compression: config.CompressionConfig{Level: 6},
		}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestTrackStats 测试请求统计追踪
func TestTrackStats(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
// 初始统计应该为 0
if s.requests.Load() != 0 {
t.Error("Initial requests should be 0")
}
if s.bytesSent.Load() != 0 {
t.Error("Initial bytesSent should be 0")
}
if s.bytesReceived.Load() != 0 {
t.Error("Initial bytesReceived should be 0")
}
// 创建测试 handler
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetBodyString("response body")
}
// 包装 handler
wrappedHandler := s.trackStats(handler)
// 创建测试请求上下文
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
ctx.Request.SetBody([]byte("request body"))
// 执行
wrappedHandler(ctx)
// 验证统计
if s.requests.Load() != 1 {
t.Errorf("Expected 1 request, got %d", s.requests.Load())
}
if s.bytesReceived.Load() != int64(len("request body")) {
t.Errorf("Expected bytesReceived %d, got %d", len("request body"), s.bytesReceived.Load())
}
if s.bytesSent.Load() != int64(len("response body")) {
t.Errorf("Expected bytesSent %d, got %d", len("response body"), s.bytesSent.Load())
}
}
// TestTrackStats_MultipleRequests sends ten requests through the trackStats
// wrapper and checks that every one is counted.
func TestTrackStats_MultipleRequests(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})
	tracked := srv.trackStats(func(c *fasthttp.RequestCtx) {
		c.SetBodyString("ok")
	})

	// Use a fresh request context for each request.
	for n := 0; n < 10; n++ {
		rc := &fasthttp.RequestCtx{}
		rc.Init(&fasthttp.Request{}, nil, nil)
		tracked(rc)
	}

	if got := srv.requests.Load(); got != 10 {
		t.Errorf("Expected 10 requests, got %d", got)
	}
}
// TestGetListeners_Empty checks that a freshly constructed server reports a
// nil listener slice.
func TestGetListeners_Empty(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	if got := srv.GetListeners(); got != nil {
		t.Errorf("Expected nil listeners, got %v", got)
	}
}
// TestSetListeners stores two real loopback listeners via SetListeners and
// checks that GetListeners returns both.
func TestSetListeners(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	lns := make([]net.Listener, 0, 2)
	for range 2 {
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		if err != nil {
			t.Fatalf("Failed to create listener: %v", err)
		}
		t.Cleanup(func() { _ = ln.Close() })
		lns = append(lns, ln)
	}

	srv.SetListeners(lns)

	if got := srv.GetListeners(); len(got) != 2 {
		t.Errorf("Expected 2 listeners, got %d", len(got))
	}
}
// TestGetTLSConfig_NotConfigured verifies that GetTLSConfig returns a nil
// config and the error "TLS not configured" when the server has no TLS
// settings.
func TestGetTLSConfig_NotConfigured(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":8080",
		}},
	}
	s := New(cfg)

	tlsConfig, err := s.GetTLSConfig()
	if tlsConfig != nil {
		t.Error("Expected nil TLS config")
	}
	// Fatal, not Error: the message check below calls err.Error(), which
	// would panic on a nil error and hide the real failure.
	if err == nil {
		t.Fatal("Expected error for unconfigured TLS")
	}
	if err.Error() != "TLS not configured" {
		t.Errorf("Expected error 'TLS not configured', got: %v", err)
	}
}
// TestGetHandler checks that GetHandler returns nil before a handler is
// assigned and a non-nil value afterwards.
func TestGetHandler(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	if srv.GetHandler() != nil {
		t.Error("Expected nil handler initially")
	}

	// Assign the internal field directly; there is no public setter.
	srv.handler = func(c *fasthttp.RequestCtx) {
		c.SetBodyString("test")
	}

	if srv.GetHandler() == nil {
		t.Error("Expected non-nil handler after setting")
	}
}
// TestServer_Connections exercises the atomic connection counter: it starts
// at zero, increments to one, and decrements back to zero.
func TestServer_Connections(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	if srv.connections.Load() != 0 {
		t.Error("Initial connections should be 0")
	}

	srv.connections.Add(1)
	if got := srv.connections.Load(); got != 1 {
		t.Errorf("Expected 1 connection, got %d", got)
	}

	srv.connections.Add(-1)
	if got := srv.connections.Load(); got != 0 {
		t.Errorf("Expected 0 connections, got %d", got)
	}
}
// TestServer_Proxies checks that a new server holds no proxies.
func TestServer_Proxies(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	if n := len(srv.proxies); n != 0 {
		t.Error("Initial proxies should be empty")
	}
}
// TestServer_Running checks that a new server is not marked as running.
func TestServer_Running(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	if isRunning := srv.running.Load(); isRunning {
		t.Error("Initial running state should be false")
	}
}
// TestServer_StopWithNilFastServer checks that StopWithTimeout succeeds when
// fastServer is explicitly nil.
func TestServer_StopWithNilFastServer(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})
	srv.fastServer = nil

	if err := srv.StopWithTimeout(5 * time.Second); err != nil {
		t.Errorf("StopWithTimeout with nil fastServer should succeed: %v", err)
	}
}
// TestServer_GracefulStopWithNilFastServer checks that GracefulStop succeeds
// when fastServer is explicitly nil.
func TestServer_GracefulStopWithNilFastServer(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})
	srv.fastServer = nil

	if err := srv.GracefulStop(5 * time.Second); err != nil {
		t.Errorf("GracefulStop with nil fastServer should succeed: %v", err)
	}
}
// TestServer_GetProxyCacheStats checks that a server without proxies reports
// zero cache entries and zero pending entries.
func TestServer_GetProxyCacheStats(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	got := srv.getProxyCacheStats()
	if got.Entries != 0 {
		t.Errorf("Expected 0 entries, got %d", got.Entries)
	}
	if got.Pending != 0 {
		t.Errorf("Expected 0 pending, got %d", got.Pending)
	}
}
// TestServer_BuildMiddlewareChain_EmptyConfig builds a chain for a server that
// only sets Listen and expects a non-nil chain without error.
func TestServer_BuildMiddlewareChain_EmptyConfig(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("Unexpected error: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestServer_TrackStats_EmptyBody checks that a request with no body and a
// handler that writes nothing counts one request and zero bytes sent.
func TestServer_TrackStats_EmptyBody(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})
	tracked := srv.trackStats(func(_ *fasthttp.RequestCtx) {
		// Deliberately writes no response body.
	})

	rc := &fasthttp.RequestCtx{}
	rc.Init(&fasthttp.Request{}, nil, nil)
	rc.Request.SetBody(nil)
	tracked(rc)

	if got := srv.requests.Load(); got != 1 {
		t.Errorf("Expected 1 request, got %d", got)
	}
	if got := srv.bytesSent.Load(); got != 0 {
		t.Errorf("Expected 0 bytes sent, got %d", got)
	}
}
// TestStart_Success checks server initialization. Despite its name it does not
// call Start; it only verifies that New returns a Server holding cfg.
func TestStart_Success(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	}

	srv := New(cfg)
	if srv == nil {
		t.Fatal("New() returned nil, expected non-nil Server")
	}
	if srv.config != cfg {
		t.Error("Server.config not set correctly")
	}
}
// TestStart_WithStaticFiles checks that New accepts a static-file location
// rooted at a temporary directory.
func TestStart_WithStaticFiles(t *testing.T) {
	staticCfg := config.StaticConfig{
		Path:  "/static",
		Root:  t.TempDir(),
		Index: []string{"index.html"},
	}
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":8080",
			Static: []config.StaticConfig{staticCfg},
		}},
	}

	if New(cfg) == nil {
		t.Fatal("New() returned nil")
	}
}
// TestStart_WithGoroutinePool checks that New accepts an enabled goroutine
// pool configuration.
func TestStart_WithGoroutinePool(t *testing.T) {
	pool := config.GoroutinePoolConfig{
		Enabled:     true,
		MaxWorkers:  100,
		MinWorkers:  10,
		IdleTimeout: 30 * time.Second,
	}
	cfg := &config.Config{
		Servers:     []config.ServerConfig{{Listen: ":8080"}},
		Performance: config.PerformanceConfig{GoroutinePool: pool},
	}

	if New(cfg) == nil {
		t.Fatal("New() returned nil")
	}
}
// TestStart_WithFileCache 测试文件缓存配置
func TestStart_WithFileCache(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
Performance: config.PerformanceConfig{
FileCache: config.FileCacheConfig{
MaxEntries: 1000,
MaxSize: 100 * 1024 * 1024,
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("New() returned nil")
}
}
// TestStop_Graceful calls GracefulStop on a server that was never started and
// expects a nil error. It is skipped when tests run with -short.
func TestStop_Graceful(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":0"}},
	})

	if err := srv.GracefulStop(1 * time.Second); err != nil {
		t.Errorf("GracefulStop() on non-started server returned error: %v", err)
	}
}
// TestGetTLSConfig_Nil checks that GetTLSConfig returns an error and a nil
// config when TLS is not configured.
func TestGetTLSConfig_Nil(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":0"}},
	})

	tlsCfg, err := srv.GetTLSConfig()
	if err == nil {
		t.Error("GetTLSConfig() should return error when TLS not configured")
	}
	if tlsCfg != nil {
		t.Error("GetTLSConfig() should return nil when TLS not configured")
	}
}
// TestGetTLSConfig_NilServer 测试 nil 服务器调用 GetTLSConfig
func TestGetTLSConfig_NilServer(t *testing.T) {
var s *Server
// 防御性:如果 s 为 nil调用方法会 panic这是预期的行为
// 这里我们只测试非 nil 但 tlsManager 为 nil 的情况
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s = New(cfg)
// 确保 tlsManager 为 nil
if s.tlsManager != nil {
t.Skip("tlsManager should be nil initially")
}
tlsCfg, err := s.GetTLSConfig()
if err == nil {
t.Error("Expected error when tlsManager is nil")
}
if tlsCfg != nil {
t.Error("Expected nil TLS config when tlsManager is nil")
}
if err.Error() != "TLS not configured" {
t.Errorf("Expected error 'TLS not configured', got: %v", err)
}
}
// TestGetServerName checks the Server header value. A nil config and
// ServerTokens=true both include the version; ServerTokens=false yields the
// bare product name.
func TestGetServerName(t *testing.T) {
	versioned := "lolly/" + version.Version

	cases := []struct {
		name string
		in   *config.ServerConfig
		want string
	}{
		{name: "nil config", in: nil, want: versioned},
		{name: "ServerTokens true (default)", in: &config.ServerConfig{ServerTokens: true}, want: versioned},
		{name: "ServerTokens false", in: &config.ServerConfig{ServerTokens: false}, want: "lolly"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := (&Server{}).getServerName(tc.in)
			if got != tc.want {
				t.Errorf("getServerName() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestApplyTypesConfig checks that applyTypesConfig does not panic on a nil
// config, an empty config, a custom extension map, or a default MIME type.
func TestApplyTypesConfig(t *testing.T) {
	cases := []struct {
		name string
		cfg  *config.ServerConfig
	}{
		{name: "nil config", cfg: nil},
		{name: "empty config", cfg: &config.ServerConfig{}},
		{name: "with types map", cfg: &config.ServerConfig{
			Types: config.TypesConfig{
				Map: map[string]string{".custom": "application/x-custom"},
			},
		}},
		{name: "with default type", cfg: &config.ServerConfig{
			Types: config.TypesConfig{DefaultType: "application/octet-stream"},
		}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Passing means no panic.
			(&Server{}).applyTypesConfig(tc.cfg)
		})
	}
}
// TestCreateListener_TCP checks that a host:port Listen value produces a TCP
// listener.
func TestCreateListener_TCP(t *testing.T) {
	cfg := &config.Config{
		// Port 0 lets the OS choose a free port.
		Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}},
	}
	srv := New(cfg)

	ln, err := srv.createListener(&cfg.Servers[0])
	switch {
	case err != nil:
		t.Fatalf("createListener() error: %v", err)
	case ln == nil:
		t.Fatal("createListener() returned nil listener")
	}
	defer ln.Close()

	if network := ln.Addr().Network(); network != "tcp" {
		t.Errorf("expected tcp network, got %s", network)
	}
}
// TestCreateListener_UnixSocket checks that a "unix:<path>" Listen value
// produces a Unix-domain listener.
func TestCreateListener_UnixSocket(t *testing.T) {
	sockPath := t.TempDir() + "/test.sock"
	cfg := &config.Config{
		Servers: []config.ServerConfig{{Listen: "unix:" + sockPath}},
	}
	srv := New(cfg)

	ln, err := srv.createListener(&cfg.Servers[0])
	switch {
	case err != nil:
		t.Fatalf("createListener() error: %v", err)
	case ln == nil:
		t.Fatal("createListener() returned nil listener")
	}
	defer ln.Close()

	if network := ln.Addr().Network(); network != "unix" {
		t.Errorf("expected unix network, got %s", network)
	}
}
// TestCreateListener_UnixSocketWithPermissions checks that createListener
// succeeds for a Unix socket configured with mode 0600 and user/group
// "nobody".
func TestCreateListener_UnixSocketWithPermissions(t *testing.T) {
	sockPath := t.TempDir() + "/test_perm.sock"
	sockOpts := config.UnixSocketConfig{
		Mode:  0o600,
		User:  "nobody",
		Group: "nobody",
	}
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen:     "unix:" + sockPath,
			UnixSocket: sockOpts,
		}},
	}
	srv := New(cfg)

	ln, err := srv.createListener(&cfg.Servers[0])
	switch {
	case err != nil:
		t.Fatalf("createListener() error: %v", err)
	case ln == nil:
		t.Fatal("createListener() returned nil listener")
	}
	defer ln.Close()
}
// TestCreateListener_UnixSocketCleanup places a stale regular file at the
// socket path and expects createListener to still succeed there.
func TestCreateListener_UnixSocketCleanup(t *testing.T) {
	sockPath := t.TempDir() + "/cleanup.sock"

	// Simulate a leftover socket file from a previous run.
	if err := os.WriteFile(sockPath, []byte{}, 0o666); err != nil {
		t.Fatalf("failed to create existing socket file: %v", err)
	}

	cfg := &config.Config{
		Servers: []config.ServerConfig{{Listen: "unix:" + sockPath}},
	}
	ln, err := New(cfg).createListener(&cfg.Servers[0])
	if err != nil {
		t.Fatalf("createListener() error: %v", err)
	}
	defer ln.Close()
}
// TestHandleRegistrationError_ConflictWarning checks that a route
// *matcher.ConflictError is treated as non-fatal (nil returned).
func TestHandleRegistrationError_ConflictWarning(t *testing.T) {
	conflict := &matcher.ConflictError{
		Path:         "/api",
		ExistingType: "exact",
		NewType:      "prefix",
	}

	if err := (&Server{}).handleRegistrationError("proxy", "/api", conflict); err != nil {
		t.Errorf("conflict should return nil, got: %v", err)
	}
}
// TestHandleRegistrationError_FatalError checks that a non-conflict
// registration error is returned, wrapped with the route kind and path.
func TestHandleRegistrationError_FatalError(t *testing.T) {
	s := &Server{}
	err := s.handleRegistrationError("proxy", "/api",
		fmt.Errorf("invalid regex pattern: missing closing parenthesis"))
	// Fatal, not Error: err.Error() below would panic on a nil error.
	if err == nil {
		t.Fatal("fatal error should return non-nil")
	}
	if !strings.Contains(err.Error(), "proxy route /api") {
		t.Errorf("error should wrap context, got: %v", err)
	}
}
// TestDupListener_TCP checks that duplicating a TCP listener yields a TCP
// listener bound to the same address.
func TestDupListener_TCP(t *testing.T) {
	src, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()

	dup, err := DupListener(src)
	if err != nil {
		t.Fatalf("DupListener() error: %v", err)
	}
	defer dup.Close()

	if network := dup.Addr().Network(); network != "tcp" {
		t.Errorf("expected tcp, got %s", network)
	}
	wantAddr, gotAddr := src.Addr().String(), dup.Addr().String()
	if gotAddr != wantAddr {
		t.Errorf("expected same address %s, got %s", wantAddr, gotAddr)
	}
}
// TestDupListener_Unix checks that a Unix-domain listener can be duplicated.
func TestDupListener_Unix(t *testing.T) {
	sockPath := t.TempDir() + "/dup.sock"

	src, err := net.Listen("unix", sockPath)
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()

	dup, err := DupListener(src)
	if err != nil {
		t.Fatalf("DupListener() error: %v", err)
	}
	defer dup.Close()
}
func TestDupListener_Unsupported(t *testing.T) {
_, err := DupListener(struct{ net.Listener }{})
if err == nil {
t.Error("expected error for unsupported type")
}
}
// TestTcpAddrMatch checks inherited-vs-target address matching. It covers
// exact matches, wildcard hosts (0.0.0.0, [::], empty) matching any host on
// the same port, and mismatches on port or specific host.
func TestTcpAddrMatch(t *testing.T) {
	srv := &Server{}

	cases := []struct {
		inherited, target string
		match             bool
	}{
		{"127.0.0.1:8080", "127.0.0.1:8080", true},
		{"0.0.0.0:8080", ":8080", true},
		{"[::]:8080", ":8080", true},
		{"0.0.0.0:8080", "0.0.0.0:8080", true},
		{"0.0.0.0:8080", "127.0.0.1:8080", true},
		{"127.0.0.1:8080", "0.0.0.0:8080", true},
		{"127.0.0.1:8080", ":9090", false},
		{"127.0.0.1:8080", "192.168.1.1:8080", false},
	}

	for _, c := range cases {
		if got := srv.tcpAddrMatch(c.inherited, c.target); got != c.match {
			t.Errorf("tcpAddrMatch(%q, %q) = %v, want %v", c.inherited, c.target, got, c.match)
		}
	}
}
// TestMatchInheritedListener_TCP checks that matchInheritedListener returns
// nil for an address no inherited listener is bound to, and returns the exact
// listener whose address is requested.
func TestMatchInheritedListener_TCP(t *testing.T) {
	s := &Server{}

	// Check Listen errors: on failure the listener is nil, and the deferred
	// Close would panic instead of reporting the real problem.
	ln1, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}
	defer func() { _ = ln1.Close() }()
	ln2, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}
	defer func() { _ = ln2.Close() }()

	inherited := []net.Listener{ln1, ln2}

	// Port 99999 is outside the valid TCP port range, so nothing can match.
	if result := s.matchInheritedListener(inherited, "0.0.0.0:99999"); result != nil {
		t.Error("expected nil for non-matching address")
	}

	addr1 := ln1.Addr().String()
	if result := s.matchInheritedListener(inherited, addr1); result != ln1 {
		t.Errorf("expected ln1 for address %s", addr1)
	}
}
// TestMatchInheritedListener_Empty checks that a nil inherited list never
// matches.
func TestMatchInheritedListener_Empty(t *testing.T) {
	if got := (&Server{}).matchInheritedListener(nil, ":8080"); got != nil {
		t.Error("expected nil for empty inherited list")
	}
}
// TestMatchInheritedListener_PresetListeners checks listener reuse. After
// SetListeners installs a listener, createListener for that listener's exact
// address must return a listener bound to the same address.
func TestMatchInheritedListener_PresetListeners(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}},
	}
	srv := New(cfg)
	serverCfg := &cfg.Servers[0]

	preset, err := srv.createListener(serverCfg)
	if err != nil {
		t.Fatal(err)
	}
	defer preset.Close()
	srv.SetListeners([]net.Listener{preset})

	// Point the config at the concrete address the OS assigned.
	wantAddr := preset.Addr().String()
	serverCfg.Listen = wantAddr

	reused, err := srv.createListener(serverCfg)
	if err != nil {
		t.Fatalf("createListener with preset should reuse: %v", err)
	}
	if reused == nil {
		t.Fatal("expected non-nil listener from preset match")
	}
	if gotAddr := reused.Addr().String(); gotAddr != wantAddr {
		t.Errorf("expected same address %s, got %s", wantAddr, gotAddr)
	}
}
// TestServer_StatsMethods checks the initial statistics state. startTime
// starts zero and can be set; all counters start at zero.
func TestServer_StatsMethods(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	if !srv.startTime.IsZero() {
		t.Error("startTime should be zero initially")
	}
	srv.startTime = time.Now()
	if srv.startTime.IsZero() {
		t.Error("startTime should not be zero after setting")
	}

	// Every counter starts at zero.
	if srv.connections.Load() != 0 {
		t.Error("initial connections should be 0")
	}
	if srv.requests.Load() != 0 {
		t.Error("initial requests should be 0")
	}
	if srv.bytesSent.Load() != 0 {
		t.Error("initial bytesSent should be 0")
	}
	if srv.bytesReceived.Load() != 0 {
		t.Error("initial bytesReceived should be 0")
	}
}
// TestServer_SetResolver checks that SetResolver(nil) leaves the resolver nil.
func TestServer_SetResolver(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	srv.SetResolver(nil)

	if srv.resolver != nil {
		t.Error("resolver should be nil")
	}
}
// TestServer_SetUpgradeManager checks that SetUpgradeManager stores both a nil
// manager and a real one created by NewUpgradeManager.
func TestServer_SetUpgradeManager(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	srv.SetUpgradeManager(nil)
	if srv.upgradeManager != nil {
		t.Error("upgradeManager should be nil")
	}

	srv.SetUpgradeManager(NewUpgradeManager(srv))
	if srv.upgradeManager == nil {
		t.Error("upgradeManager should not be nil after setting")
	}
}
// TestServer_StopWithTimeout_WithListeners checks that StopWithTimeout succeeds
// on a never-started server that holds a pre-set TCP listener.
func TestServer_StopWithTimeout_WithListeners(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":0",
		}},
	}
	s := New(cfg)

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}
	// Always release the socket, whether or not StopWithTimeout closes it;
	// the error from a second Close is deliberately ignored.
	t.Cleanup(func() { _ = ln.Close() })
	s.listeners = []net.Listener{ln}

	if err := s.StopWithTimeout(1 * time.Second); err != nil {
		t.Errorf("StopWithTimeout failed: %v", err)
	}
}
// TestServer_GracefulStop_WithListeners checks that GracefulStop succeeds on a
// never-started server that holds a pre-set TCP listener.
func TestServer_GracefulStop_WithListeners(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":0",
		}},
	}
	s := New(cfg)

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}
	// Always release the socket, whether or not GracefulStop closes it;
	// the error from a second Close is deliberately ignored.
	t.Cleanup(func() { _ = ln.Close() })
	s.listeners = []net.Listener{ln}

	if err := s.GracefulStop(1 * time.Second); err != nil {
		t.Errorf("GracefulStop failed: %v", err)
	}
}
// TestServer_StopWithTimeout_WithFastServer marks the server running with a
// bare, unstarted fasthttp.Server attached and expects StopWithTimeout to
// succeed.
func TestServer_StopWithTimeout_WithFastServer(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":0"}},
	})
	srv.running.Store(true)
	srv.fastServer = new(fasthttp.Server)

	if err := srv.StopWithTimeout(1 * time.Second); err != nil {
		t.Errorf("StopWithTimeout failed: %v", err)
	}
}
// TestBuildMiddlewareChain_BodyLimit builds a chain with ClientMaxBodySize
// "1MB" and expects a non-nil chain without error.
func TestBuildMiddlewareChain_BodyLimit(t *testing.T) {
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:            ":8080",
			ClientMaxBodySize: "1MB",
		}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_ErrorIntercept builds a chain with a custom 404
// error page and expects a non-nil chain without error.
func TestBuildMiddlewareChain_ErrorIntercept(t *testing.T) {
	errPages := config.ErrorPageConfig{
		Pages: map[int]string{404: "/404.html"},
	}
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{
			Listen:   ":8080",
			Security: config.SecurityConfig{ErrorPage: errPages},
		}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestBuildMiddlewareChain_NilServerConfig is named for the nil case, but
// buildMiddlewareChain does not accept a nil server config. It therefore
// exercises a minimal config (only Listen set) instead.
func TestBuildMiddlewareChain_NilServerConfig(t *testing.T) {
	cfg := &config.Config{
		Logging: config.LoggingConfig{},
		Servers: []config.ServerConfig{{Listen: ":0"}},
	}
	srv := New(cfg)

	mwChain, buildErr := srv.buildMiddlewareChain(&cfg.Servers[0])
	if buildErr != nil {
		t.Errorf("buildMiddlewareChain failed: %v", buildErr)
	}
	if mwChain == nil {
		t.Error("Expected non-nil chain")
	}
}
// TestServer_StatusCode_MethodNotAllowed 测试不支持的 HTTP 方法。
func TestServer_StatusCode_MethodNotAllowed(t *testing.T) {
// 简单验证
if fasthttp.StatusMethodNotAllowed != 405 {
t.Errorf("StatusMethodNotAllowed should be 405")
}
}
// TestServer_ConnectionTracking checks that incrementing and then decrementing
// the connection counter returns it to its starting value.
func TestServer_ConnectionTracking(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	before := srv.connections.Load()

	srv.connections.Add(1)
	if srv.connections.Load() != before+1 {
		t.Error("connections should have incremented")
	}

	srv.connections.Add(-1)
	if srv.connections.Load() != before {
		t.Error("connections should be back to initial")
	}
}
// TestServer_RequestTracking checks that adding 5 to the request counter of a
// new server yields 5.
func TestServer_RequestTracking(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	srv.requests.Add(5)

	if got := srv.requests.Load(); got != 5 {
		t.Errorf("expected 5 requests, got %d", got)
	}
}
// TestServer_BytesTracking checks that the sent/received byte counters of a
// new server reflect the amounts added to them.
func TestServer_BytesTracking(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	srv.bytesSent.Add(1024)
	srv.bytesReceived.Add(512)

	if sent := srv.bytesSent.Load(); sent != 1024 {
		t.Errorf("expected 1024 bytes sent, got %d", sent)
	}
	if received := srv.bytesReceived.Load(); received != 512 {
		t.Errorf("expected 512 bytes received, got %d", received)
	}
}
// TestServer_GracefulStop_WithFastServers marks the server running with two
// bare, unstarted fasthttp servers attached and expects GracefulStop to
// succeed.
func TestServer_GracefulStop_WithFastServers(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":0"}},
	})
	srv.running.Store(true)
	srv.fastServers = []*fasthttp.Server{new(fasthttp.Server), new(fasthttp.Server)}

	if err := srv.GracefulStop(1 * time.Second); err != nil {
		t.Errorf("GracefulStop failed: %v", err)
	}
}
// TestServer_StopWithTimeout_WithFastServers marks the server running with two
// bare, unstarted fasthttp servers attached and expects StopWithTimeout to
// succeed.
func TestServer_StopWithTimeout_WithFastServers(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":0"}},
	})
	srv.running.Store(true)
	srv.fastServers = []*fasthttp.Server{new(fasthttp.Server), new(fasthttp.Server)}

	if err := srv.StopWithTimeout(1 * time.Second); err != nil {
		t.Errorf("StopWithTimeout failed: %v", err)
	}
}
// TestServer_GetProxyCacheStats_WithProxies forces the proxy list to nil and
// checks that getProxyCacheStats reports zero entries. Despite the name, no
// proxies are configured here.
func TestServer_GetProxyCacheStats_WithProxies(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})
	srv.proxies = nil

	if got := srv.getProxyCacheStats(); got.Entries != 0 {
		t.Errorf("Expected 0 entries, got %d", got.Entries)
	}
}
// TestServer_GetProxyCacheStats_SingleProxyWithCache attaches one proxy with
// caching enabled and checks that the aggregated cache stats are
// non-negative.
func TestServer_GetProxyCacheStats_SingleProxyWithCache(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	cachedCfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
		Cache: config.ProxyCacheConfig{
			Enabled: true,
			MaxAge:  10 * time.Second,
		},
	}
	upstreams := testutil.NewTestTargets("http://localhost:8080")

	p, err := proxy.NewProxy(cachedCfg, upstreams, nil, nil)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}
	srv.proxies = []*proxy.Proxy{p}

	// No requests have gone through, so only sanity bounds are asserted.
	got := srv.getProxyCacheStats()
	if got.Entries < 0 {
		t.Errorf("Expected non-negative entries, got %d", got.Entries)
	}
	if got.Pending < 0 {
		t.Errorf("Expected non-negative pending, got %d", got.Pending)
	}
}
// TestServer_GetProxyCacheStats_SingleProxyNoCache attaches one proxy built
// from the default test config (no cache) and expects zero entries and zero
// pending.
func TestServer_GetProxyCacheStats_SingleProxyNoCache(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})

	p, err := proxy.NewProxy(
		testutil.NewTestProxyConfig("/api"),
		testutil.NewTestTargets("http://localhost:8080"),
		nil, nil,
	)
	if err != nil {
		t.Fatalf("NewProxy() error: %v", err)
	}
	srv.proxies = []*proxy.Proxy{p}

	got := srv.getProxyCacheStats()
	if got.Entries != 0 {
		t.Errorf("Expected 0 entries for proxy without cache, got %d", got.Entries)
	}
	if got.Pending != 0 {
		t.Errorf("Expected 0 pending for proxy without cache, got %d", got.Pending)
	}
}
// TestServer_GetProxyCacheStats_MultipleProxies mixes cached proxies (/api,
// /data) with an uncached one (/static). It checks that the aggregated cache
// stats are non-negative.
func TestServer_GetProxyCacheStats_MultipleProxies(t *testing.T) {
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{Listen: ":8080"}},
	})
	upstreams := testutil.NewTestTargets("http://localhost:8080")

	// cachedCfg returns a round-robin proxy config with caching enabled.
	cachedCfg := func(path string, maxAge time.Duration) *config.ProxyConfig {
		return &config.ProxyConfig{
			Path:        path,
			LoadBalance: "round_robin",
			Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			Cache: config.ProxyCacheConfig{
				Enabled: true,
				MaxAge:  maxAge,
			},
		}
	}
	// mustProxy builds a proxy from pc or aborts the test.
	mustProxy := func(pc *config.ProxyConfig) *proxy.Proxy {
		p, err := proxy.NewProxy(pc, upstreams, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}
		return p
	}

	// Elements are evaluated left to right, matching creation order.
	srv.proxies = []*proxy.Proxy{
		mustProxy(cachedCfg("/api", 10*time.Second)),
		mustProxy(testutil.NewTestProxyConfig("/static")),
		mustProxy(cachedCfg("/data", 20*time.Second)),
	}

	got := srv.getProxyCacheStats()
	if got.Entries < 0 {
		t.Errorf("Expected non-negative entries, got %d", got.Entries)
	}
	if got.Pending < 0 {
		t.Errorf("Expected non-negative pending, got %d", got.Pending)
	}
}
// TestServer_GetProxyCacheStats_AllProxiesWithCache verifies aggregation when
// every proxy has caching enabled.
//
// No requests are issued, so every cache is empty and the aggregate must be
// exactly zero. A non-negative check would pass vacuously.
func TestServer_GetProxyCacheStats_AllProxiesWithCache(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":8080",
		}},
	}
	s := New(cfg)
	targets := testutil.NewTestTargets("http://localhost:8080")

	// Create three proxies, each with its own cache.
	proxies := make([]*proxy.Proxy, 3)
	for i := range 3 {
		proxyCfg := &config.ProxyConfig{
			Path:        fmt.Sprintf("/api%d", i),
			LoadBalance: "round_robin",
			Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
			Cache: config.ProxyCacheConfig{
				Enabled: true,
				MaxAge:  10 * time.Second,
			},
		}
		p, err := proxy.NewProxy(proxyCfg, targets, nil, nil)
		if err != nil {
			t.Fatalf("NewProxy() error: %v", err)
		}
		proxies[i] = p
	}
	s.proxies = proxies

	stats := s.getProxyCacheStats()

	// The sum over three empty caches is zero.
	if stats.Entries != 0 {
		t.Errorf("Expected 0 entries, got %d", stats.Entries)
	}
	if stats.Pending != 0 {
		t.Errorf("Expected 0 pending, got %d", stats.Pending)
	}
}
// TestServer_GetProxyCacheStats_AllProxiesNoCache 测试所有代理都没有缓存的统计。
func TestServer_GetProxyCacheStats_AllProxiesNoCache(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
targets := testutil.NewTestTargets("http://localhost:8080")
// 创建多个不带缓存的代理
proxies := make([]*proxy.Proxy, 3)
for i := range 3 {
proxyCfg := testutil.NewTestProxyConfig(fmt.Sprintf("/api%d", i))
p, err := proxy.NewProxy(proxyCfg, targets, nil, nil)
if err != nil {
t.Fatalf("NewProxy() error: %v", err)
}
proxies[i] = p
}
s.proxies = proxies
// 获取统计
stats := s.getProxyCacheStats()
// 所有代理都没有缓存,应该返回 0
if stats.Entries != 0 {
t.Errorf("Expected 0 entries, got %d", stats.Entries)
}
if stats.Pending != 0 {
t.Errorf("Expected 0 pending, got %d", stats.Pending)
}
}
// TestServer_GetProxyCacheStats_EmptyProxiesSlice 测试空代理切片的统计。
func TestServer_GetProxyCacheStats_EmptyProxiesSlice(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
s.proxies = []*proxy.Proxy{} // 空切片
// 获取统计
stats := s.getProxyCacheStats()
if stats.Entries != 0 {
t.Errorf("Expected 0 entries, got %d", stats.Entries)
}
if stats.Pending != 0 {
t.Errorf("Expected 0 pending, got %d", stats.Pending)
}
}
// TestServer_MultipleListeners verifies that GetListeners returns every
// listener attached to the server.
func TestServer_MultipleListeners(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":0",
		}},
	}
	s := New(cfg)

	// Open two ephemeral loopback listeners.
	ln1, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}
	ln2, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		// Close ln1 before failing: t.Fatalf ends the test goroutine via
		// runtime.Goexit, so any statement after it would never run.
		_ = ln1.Close()
		t.Fatalf("failed to create listener: %v", err)
	}
	s.listeners = []net.Listener{ln1, ln2}

	got := s.GetListeners()
	if len(got) != 2 {
		t.Errorf("expected 2 listeners, got %d", len(got))
	}

	// Cleanup: stop the server (presumably this also releases the
	// listeners). The error is not what this test checks.
	_ = s.StopWithTimeout(1 * time.Second)
}
// TestGracefulStop_RunningState 测试 GracefulStop 设置 running 为 false。
func TestGracefulStop_RunningState(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
if !s.running.Load() {
t.Fatal("running should be true before GracefulStop")
}
err := s.GracefulStop(1 * time.Second)
if err != nil {
t.Errorf("GracefulStop failed: %v", err)
}
if s.running.Load() {
t.Error("running should be false after GracefulStop")
}
}
// TestGracefulStop_WithPool verifies that GracefulStop succeeds while a
// started goroutine pool is attached to the server.
func TestGracefulStop_WithPool(t *testing.T) {
	poolSettings := config.GoroutinePoolConfig{
		Enabled:     true,
		MaxWorkers:  10,
		MinWorkers:  2,
		IdleTimeout: 5 * time.Second,
	}
	cfg := &config.Config{
		Servers:     []config.ServerConfig{{Listen: ":0"}},
		Performance: config.PerformanceConfig{GoroutinePool: poolSettings},
	}
	srv := New(cfg)
	srv.running.Store(true)

	// Build the pool from the performance section; start it only if one was
	// actually returned.
	srv.pool = initGoroutinePool(&cfg.Performance)
	if pool := srv.pool; pool != nil {
		pool.Start()
	}

	if stopErr := srv.GracefulStop(1 * time.Second); stopErr != nil {
		t.Errorf("GracefulStop failed: %v", stopErr)
	}
}
// TestGracefulStop_WithHealthCheckers 测试 GracefulStop 停止健康检查器。
func TestGracefulStop_WithHealthCheckers(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 创建 mock healthChecker (使用 nil因为我们只测试循环不会 panic)
s.healthCheckers = nil
err := s.GracefulStop(1 * time.Second)
if err != nil {
t.Errorf("GracefulStop failed: %v", err)
}
}
// TestGracefulStop_WithAccessLog 测试 GracefulStop 关闭访问日志。
func TestGracefulStop_WithAccessLog(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 创建 accessLogMiddleware
s.accessLogMiddleware = accesslog.New(&config.LoggingConfig{})
err := s.GracefulStop(1 * time.Second)
if err != nil {
t.Errorf("GracefulStop failed: %v", err)
}
}
// TestGracefulStop_WithTLSManager 测试 GracefulStop 关闭 TLS 管理器。
func TestGracefulStop_WithTLSManager(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 创建临时证书文件
tempDir := t.TempDir()
certFile := tempDir + "/cert.pem"
keyFile := tempDir + "/key.pem"
// 生成自签名证书用于测试
if err := generateTestCert(certFile, keyFile); err != nil {
t.Skipf("failed to generate test cert: %v", err)
}
tlsMgr, err := ssl.NewTLSManager(&config.SSLConfig{
Cert: certFile,
Key: keyFile,
})
if err != nil {
t.Skipf("failed to create TLS manager: %v", err)
}
s.tlsManager = tlsMgr
err = s.GracefulStop(1 * time.Second)
if err != nil {
t.Errorf("GracefulStop failed: %v", err)
}
}
// TestGracefulStop_WithLuaEngine 测试 GracefulStop 关闭 Lua 引擎。
func TestGracefulStop_WithLuaEngine(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 创建 Lua 引擎
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
err = s.GracefulStop(1 * time.Second)
if err != nil {
t.Errorf("GracefulStop failed: %v", err)
}
}
// TestGracefulStop_Timeout 测试 GracefulStop 超时场景。
func TestGracefulStop_Timeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 创建一个真实的 fastServer但通过模拟长时间关闭来测试超时
s.fastServer = &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetBodyString("test")
},
}
// 使用非常短的超时
err := s.GracefulStop(1 * time.Nanosecond)
// 超时可能返回 context.DeadlineExceeded 或 nil取决于关闭速度
if err != nil && err != context.DeadlineExceeded {
t.Errorf("unexpected error: %v", err)
}
}
// TestGracefulStop_AllComponents verifies that GracefulStop succeeds with a
// goroutine pool, an access logger and a listener attached, and that it
// clears the running flag afterwards.
func TestGracefulStop_AllComponents(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{Listen: ":0"}},
		Performance: config.PerformanceConfig{
			GoroutinePool: config.GoroutinePoolConfig{
				Enabled:     true,
				MaxWorkers:  10,
				IdleTimeout: 5 * time.Second,
			},
		},
	}
	srv := New(cfg)
	srv.running.Store(true)

	// Goroutine pool, started only if one was returned.
	srv.pool = initGoroutinePool(&cfg.Performance)
	if pool := srv.pool; pool != nil {
		pool.Start()
	}

	// Access logger built from a zero-value logging config.
	srv.accessLogMiddleware = accesslog.New(&config.LoggingConfig{})

	// A single loopback listener.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}
	srv.listeners = []net.Listener{listener}

	if err = srv.GracefulStop(2 * time.Second); err != nil {
		t.Errorf("GracefulStop failed: %v", err)
	}
	if srv.running.Load() {
		t.Error("running should be false after GracefulStop")
	}
}
// generateTestCert is intended to write a self-signed certificate to certFile
// and its private key to keyFile for TLS tests.
//
// It is currently a stub: it writes nothing and always returns an error.
// Callers treat that error as a reason to skip.
func generateTestCert(certFile, keyFile string) error {
	const reason = "test cert generation not implemented"
	return fmt.Errorf(reason)
}
// TestGracefulStop_WithAccessControl verifies that GracefulStop succeeds with
// an access-control component attached. It is skipped if the access control
// cannot be built.
func TestGracefulStop_WithAccessControl(t *testing.T) {
	accessRules := config.AccessConfig{
		Allow: []string{"127.0.0.1"},
	}
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen:   ":0",
			Security: config.SecurityConfig{Access: accessRules},
		}},
	}
	srv := New(cfg)
	srv.running.Store(true)

	acl, err := security.NewAccessControl(&cfg.Servers[0].Security.Access)
	if err != nil {
		t.Skipf("failed to create AccessControl: %v", err)
	}
	srv.accessControl = acl

	if err = srv.GracefulStop(1 * time.Second); err != nil {
		t.Errorf("GracefulStop failed: %v", err)
	}
}
// TestGracefulStop_ContextCancelled 测试 GracefulStop 上下文取消场景。
func TestGracefulStop_ContextCancelled(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 创建一个监听中的服务器
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
s.listeners = []net.Listener{ln}
// 创建 fastServer 并开始服务
s.fastServer = &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
time.Sleep(100 * time.Millisecond) // 模拟慢请求
ctx.SetBodyString("ok")
},
}
// 启动服务器
go func() {
_ = s.fastServer.Serve(ln)
}()
// 等待服务器启动
time.Sleep(10 * time.Millisecond)
// 使用非常短的超时测试超时场景
err = s.GracefulStop(1 * time.Nanosecond)
// 超时可能返回 context.DeadlineExceeded 或 nil
if err != nil && err != context.DeadlineExceeded {
t.Errorf("unexpected error: %v", err)
}
}
// TestGracefulStop_MultipleHealthCheckers verifies that GracefulStop handles an
// empty but non-nil health-checker slice.
//
// Despite the name, no checkers are constructed; the test only makes sure the
// stop loop over an allocated, zero-length slice does not panic.
func TestGracefulStop_MultipleHealthCheckers(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":0",
		}},
	}
	s := New(cfg)
	s.running.Store(true)

	// An empty, non-nil slice (distinct from the nil case covered by
	// TestGracefulStop_WithHealthCheckers).
	s.healthCheckers = make([]*proxy.HealthChecker, 0)

	err := s.GracefulStop(1 * time.Second)
	if err != nil {
		t.Errorf("GracefulStop failed: %v", err)
	}
}
// TestGracefulStop_NilComponents 测试 GracefulStop 所有组件为 nil。
func TestGracefulStop_NilComponents(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 确保所有组件为 nil
s.pool = nil
s.healthCheckers = nil
s.accessLogMiddleware = nil
s.tlsManager = nil
s.accessControl = nil
s.luaEngine = nil
s.fastServer = nil
s.fastServers = nil
err := s.GracefulStop(1 * time.Second)
if err != nil {
t.Errorf("GracefulStop failed: %v", err)
}
if s.running.Load() {
t.Error("running should be false after GracefulStop")
}
}
// TestGracefulStop_FastServersWithNil 测试 GracefulStop 处理 fastServers 中的 nil。
func TestGracefulStop_FastServersWithNil(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
// 创建包含 nil 的 fastServers
s.fastServers = []*fasthttp.Server{nil, {}, nil}
err := s.GracefulStop(1 * time.Second)
if err != nil {
t.Errorf("GracefulStop failed: %v", err)
}
}
// TestGracefulStop_ZeroTimeout 测试 GracefulStop 零超时。
func TestGracefulStop_ZeroTimeout(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
s.fastServer = &fasthttp.Server{}
err := s.GracefulStop(0)
// 零超时应该立即返回(可能导致超时错误或成功关闭)
if err != nil && err != context.DeadlineExceeded {
t.Errorf("unexpected error: %v", err)
}
}
// TestGracefulStop_NegativeTimeout 测试 GracefulStop 负超时。
func TestGracefulStop_NegativeTimeout(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":0",
}},
}
s := New(cfg)
s.running.Store(true)
s.fastServer = &fasthttp.Server{}
err := s.GracefulStop(-1 * time.Second)
// 负超时应该立即返回
if err != nil && err != context.DeadlineExceeded {
t.Errorf("unexpected error: %v", err)
}
}
// TestStartSingleMode_StaticFiles checks that two static-location configs
// (one simple, one using exact match, symlink checks, internal and try_files)
// are carried through New unchanged.
func TestStartSingleMode_StaticFiles(t *testing.T) {
	root := t.TempDir()

	simple := config.StaticConfig{
		Path:  "/static",
		Root:  root,
		Index: []string{"index.html"},
	}
	advanced := config.StaticConfig{
		Path:         "/assets",
		Root:         root,
		LocationType: "exact",
		SymlinkCheck: true,
		Internal:     true,
		TryFiles:     []string{"$uri", "/fallback.html"},
		TryFilesPass: true,
	}
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			Static: []config.StaticConfig{simple, advanced},
		}},
	})

	statics := srv.config.Servers[0].Static
	if n := len(statics); n != 2 {
		t.Errorf("expected 2 static configs, got %d", n)
	}

	first := statics[0]
	if first.Path != "/static" {
		t.Errorf("expected path /static, got %s", first.Path)
	}
	if first.Root != root {
		t.Errorf("expected root %s, got %s", root, first.Root)
	}
}
// TestStartSingleMode_StaticFilesWithGzipStatic checks that the pre-compressed
// (gzip_static) compression settings are carried through New unchanged.
func TestStartSingleMode_StaticFilesWithGzipStatic(t *testing.T) {
	root := t.TempDir()

	compression := config.CompressionConfig{
		Type:                 "gzip",
		Level:                6,
		GzipStatic:           true,
		GzipStaticExtensions: []string{".html", ".css", ".js"},
	}
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			Static: []config.StaticConfig{{
				Path:  "/",
				Root:  root,
				Index: []string{"index.html"},
			}},
			Compression: compression,
		}},
	})

	got := srv.config.Servers[0].Compression
	if !got.GzipStatic {
		t.Error("expected GzipStatic to be true")
	}
	if n := len(got.GzipStaticExtensions); n != 3 {
		t.Errorf("expected 3 extensions, got %d", n)
	}
}
// TestStartSingleMode_ProxyWithLocationTypes checks that proxy locations of
// every supported LocationType (exact, prefix_priority, regex,
// regex_caseless, named, and the empty default) are preserved in order.
func TestStartSingleMode_ProxyWithLocationTypes(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			Proxy: []config.ProxyConfig{
				{
					Path:         "/api/exact",
					LocationType: "exact",
					Targets: []config.ProxyTarget{
						{URL: "http://127.0.0.1:8081", Weight: 1},
					},
				},
				{
					Path:         "/api/priority",
					LocationType: "prefix_priority",
					Targets: []config.ProxyTarget{
						{URL: "http://127.0.0.1:8082", Weight: 1},
					},
				},
				{
					Path:         "^/api/regex/(.*)$",
					LocationType: "regex",
					Targets: []config.ProxyTarget{
						{URL: "http://127.0.0.1:8083", Weight: 1},
					},
				},
				{
					Path:         "^/api/caseless/(.*)$",
					LocationType: "regex_caseless",
					Targets: []config.ProxyTarget{
						{URL: "http://127.0.0.1:8084", Weight: 1},
					},
				},
				{
					Path:         "/api/named",
					LocationType: "named",
					LocationName: "@api_named",
					Targets: []config.ProxyTarget{
						{URL: "http://127.0.0.1:8085", Weight: 1},
					},
				},
				{
					Path: "/api/default",
					// No LocationType: the default (prefix) is used.
					Targets: []config.ProxyTarget{
						{URL: "http://127.0.0.1:8086", Weight: 1},
					},
					Internal: true,
				},
			},
		}},
	}
	s := New(cfg)

	proxies := s.config.Servers[0].Proxy
	// Fatal rather than Error: the loop below indexes proxies[i] and would
	// panic if fewer entries were present.
	if len(proxies) != 6 {
		t.Fatalf("expected 6 proxy configs, got %d", len(proxies))
	}

	proxyTypes := []string{"exact", "prefix_priority", "regex", "regex_caseless", "named", ""}
	for i, pt := range proxyTypes {
		if proxies[i].LocationType != pt {
			t.Errorf("proxy[%d]: expected location type %s, got %s", i, pt, proxies[i].LocationType)
		}
	}
}
// TestStartSingleMode_ProxyWithHealthCheck checks that a proxy's upstream
// health-check settings are carried through New unchanged.
func TestStartSingleMode_ProxyWithHealthCheck(t *testing.T) {
	primary := config.ProxyTarget{
		URL:         "http://127.0.0.1:8081",
		Weight:      3,
		MaxFails:    3,
		FailTimeout: 10 * time.Second,
		MaxConns:    100,
		Backup:      false,
		Down:        false,
	}
	backup := config.ProxyTarget{
		URL:    "http://127.0.0.1:8082",
		Weight: 1,
		Backup: true,
	}
	probe := config.HealthCheckConfig{
		Interval: 10 * time.Second,
		Timeout:  5 * time.Second,
		Path:     "/health",
	}

	srv := New(&config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			Proxy: []config.ProxyConfig{{
				Path:        "/api",
				Targets:     []config.ProxyTarget{primary, backup},
				LoadBalance: "weighted_round_robin",
				HealthCheck: probe,
			}},
		}},
	})

	got := srv.config.Servers[0].Proxy[0].HealthCheck
	if got.Interval != 10*time.Second {
		t.Errorf("expected interval 10s, got %v", got.Interval)
	}
	if got.Path != "/health" {
		t.Errorf("expected path /health, got %s", got.Path)
	}
}
// TestStartSingleMode_MonitoringEndpoints checks that the status and pprof
// monitoring settings are carried through New unchanged.
func TestStartSingleMode_MonitoringEndpoints(t *testing.T) {
	monitoring := config.MonitoringConfig{
		Status: config.StatusConfig{
			Enabled: true,
			Path:    "/_status",
			Format:  "json",
			Allow:   []string{"127.0.0.1", "192.168.0.0/16"},
		},
		Pprof: config.PprofConfig{
			Enabled: true,
			Path:    "/debug/pprof",
			Allow:   []string{"127.0.0.1"},
		},
	}
	srv := New(&config.Config{
		Servers:    []config.ServerConfig{{Listen: "127.0.0.1:0"}},
		Monitoring: monitoring,
	})

	// Status endpoint.
	status := srv.config.Monitoring.Status
	if !status.Enabled {
		t.Error("expected status enabled")
	}
	if status.Path != "/_status" {
		t.Errorf("expected status path /_status, got %s", status.Path)
	}
	if n := len(status.Allow); n != 2 {
		t.Errorf("expected 2 allowed IPs, got %d", n)
	}

	// pprof endpoint.
	if !srv.config.Monitoring.Pprof.Enabled {
		t.Error("expected pprof enabled")
	}
}
// TestStartSingleMode_CacheAPI checks that the cache-purge API settings are
// carried through New unchanged.
func TestStartSingleMode_CacheAPI(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			CacheAPI: &config.CacheAPIConfig{
				Enabled: true,
				Path:    "/_cache/purge",
				Allow:   []string{"127.0.0.1"},
				Auth:    config.CacheAPIAuthConfig{Type: "token", Token: "secret-token"},
			},
		}},
	}
	s := New(cfg)

	api := s.config.Servers[0].CacheAPI
	// Fatal rather than Error: api is dereferenced below, so a nil value
	// would otherwise panic instead of failing cleanly.
	if api == nil || !api.Enabled {
		t.Fatal("expected cache API enabled")
	}
	if api.Path != "/_cache/purge" {
		t.Errorf("expected path /_cache/purge, got %s", api.Path)
	}
}
// TestStartSingleMode_TLSConfig checks that SSL settings, including HSTS, are
// carried through New unchanged. The certificate paths are placeholders and
// are never read.
func TestStartSingleMode_TLSConfig(t *testing.T) {
	sslCfg := config.SSLConfig{
		Cert:      "/path/to/cert.pem",
		Key:       "/path/to/key.pem",
		Protocols: []string{"TLSv1.2", "TLSv1.3"},
		Ciphers:   []string{"TLS_AES_128_GCM_SHA256"},
		HSTS: config.HSTSConfig{
			MaxAge:            31536000,
			IncludeSubDomains: true,
			Preload:           true,
		},
	}
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			SSL:    sslCfg,
		}},
	})

	got := srv.config.Servers[0].SSL
	if got.Cert != "/path/to/cert.pem" {
		t.Errorf("expected cert path, got %s", got.Cert)
	}
	if got.HSTS.MaxAge != 31536000 {
		t.Errorf("expected HSTS MaxAge 31536000, got %d", got.HSTS.MaxAge)
	}
}
// TestStartSingleMode_MIMETypes checks that the custom MIME type map and the
// default type are carried through New unchanged.
func TestStartSingleMode_MIMETypes(t *testing.T) {
	mime := config.TypesConfig{
		Map: map[string]string{
			".wasm":   "application/wasm",
			".custom": "application/x-custom",
		},
		DefaultType: "application/octet-stream",
	}
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			Types:  mime,
		}},
	})

	got := srv.config.Servers[0].Types
	if n := len(got.Map); n != 2 {
		t.Errorf("expected 2 MIME types, got %d", n)
	}
	if got.DefaultType != "application/octet-stream" {
		t.Errorf("expected default type, got %s", got.DefaultType)
	}
}
// TestStartSingleMode_ServerOptions 测试服务器选项配置。
func TestStartSingleMode_ServerOptions(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: "127.0.0.1:0",
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
MaxConnsPerIP: 100,
MaxRequestsPerConn: 1000,
Concurrency: 256 * 1024,
ReadBufferSize: 16 * 1024,
WriteBufferSize: 16 * 1024,
ReduceMemoryUsage: true,
ServerTokens: false,
}},
}
s := New(cfg)
// 验证服务器选项
sc := s.config.Servers[0]
if sc.ReadTimeout != 30*time.Second {
t.Errorf("expected ReadTimeout 30s, got %v", sc.ReadTimeout)
}
if sc.MaxConnsPerIP != 100 {
t.Errorf("expected MaxConnsPerIP 100, got %d", sc.MaxConnsPerIP)
}
if !sc.ReduceMemoryUsage {
t.Error("expected ReduceMemoryUsage true")
}
if sc.ServerTokens {
t.Error("expected ServerTokens false")
}
}
// TestStartSingleMode_WithMiddlewareChain checks that security (access, rate
// limit, auth, headers), compression and rewrite settings are carried through
// New unchanged.
func TestStartSingleMode_WithMiddlewareChain(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			Security: config.SecurityConfig{
				Access: config.AccessConfig{
					Allow: []string{"127.0.0.1"},
					Deny:  []string{"10.0.0.0/8"},
				},
				RateLimit: config.RateLimitConfig{
					RequestRate: 100,
					Burst:       200,
					Key:         "remote_addr",
				},
				Auth: config.AuthConfig{
					Users: []config.User{
						{Name: "admin", Password: "secret"},
					},
				},
				Headers: config.SecurityHeaders{
					XFrameOptions:         "DENY",
					XContentTypeOptions:   "nosniff",
					ContentSecurityPolicy: "default-src 'self'",
					ReferrerPolicy:        "strict-origin-when-cross-origin",
				},
			},
			Compression: config.CompressionConfig{
				Type:  "gzip",
				Level: 6,
			},
			Rewrite: []config.RewriteRule{
				{Pattern: "^/old/(.*)$", Replacement: "/new/$1"},
			},
		}},
	}
	s := New(cfg)

	// Named sec, not security: the latter would shadow the imported
	// middleware/security package for the rest of this function.
	sec := s.config.Servers[0].Security
	if len(sec.Access.Allow) != 1 {
		t.Errorf("expected 1 allow rule, got %d", len(sec.Access.Allow))
	}
	if sec.RateLimit.RequestRate != 100 {
		t.Errorf("expected request rate 100, got %d", sec.RateLimit.RequestRate)
	}
	if len(sec.Auth.Users) != 1 {
		t.Errorf("expected 1 auth user, got %d", len(sec.Auth.Users))
	}
}
// TestStartSingleMode_PerformanceConfig checks that goroutine-pool and
// file-cache performance settings are carried through New unchanged.
func TestStartSingleMode_PerformanceConfig(t *testing.T) {
	perf := config.PerformanceConfig{
		GoroutinePool: config.GoroutinePoolConfig{
			Enabled:     true,
			MaxWorkers:  100,
			MinWorkers:  10,
			IdleTimeout: 30 * time.Second,
		},
		FileCache: config.FileCacheConfig{
			MaxEntries: 10000,
			MaxSize:    100 * 1024 * 1024,
		},
	}
	srv := New(&config.Config{
		Servers:     []config.ServerConfig{{Listen: "127.0.0.1:0"}},
		Performance: perf,
	})

	got := srv.config.Performance
	if !got.GoroutinePool.Enabled {
		t.Error("expected goroutine pool enabled")
	}
	if got.FileCache.MaxEntries != 10000 {
		t.Errorf("expected 10000 max entries, got %d", got.FileCache.MaxEntries)
	}
}
// TestStartSingleMode_WithLuaMiddleware checks that the Lua middleware
// settings (enabled flag and per-phase scripts) are carried through New
// unchanged.
func TestStartSingleMode_WithLuaMiddleware(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: "127.0.0.1:0",
			Lua: &config.LuaMiddlewareConfig{
				Enabled: true,
				Scripts: []config.LuaScriptConfig{
					{
						Path:    "/scripts/access.lua",
						Phase:   "access",
						Timeout: 30 * time.Second,
					},
					{
						Path:    "/scripts/header.lua",
						Phase:   "header_filter",
						Timeout: 10 * time.Second,
					},
				},
			},
		}},
	}
	s := New(cfg)

	luaCfg := s.config.Servers[0].Lua
	// Fatal rather than Error: luaCfg is dereferenced below, so a nil value
	// would otherwise panic instead of failing cleanly.
	if luaCfg == nil || !luaCfg.Enabled {
		t.Fatal("expected Lua enabled")
	}
	if len(luaCfg.Scripts) != 2 {
		t.Errorf("expected 2 scripts, got %d", len(luaCfg.Scripts))
	}
}
// TestStartSingleMode_WithErrorPage 测试错误页面配置。
func TestStartSingleMode_WithErrorPage(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: "127.0.0.1:0",
Security: config.SecurityConfig{
ErrorPage: config.ErrorPageConfig{
Pages: map[int]string{
404: "/errors/404.html",
500: "/errors/500.html",
502: "/errors/502.html",
},
Default: "/errors/default.html",
},
},
}},
}
s := New(cfg)
// 验证错误页面配置
ep := s.config.Servers[0].Security.ErrorPage
if len(ep.Pages) != 3 {
t.Errorf("expected 3 error pages, got %d", len(ep.Pages))
}
if ep.Default != "/errors/default.html" {
t.Errorf("expected default error page, got %s", ep.Default)
}
}
// TestStartSingleMode_WithConnLimiter checks that the per-key connection limit
// is carried through New unchanged.
func TestStartSingleMode_WithConnLimiter(t *testing.T) {
	limits := config.RateLimitConfig{
		ConnLimit: 100,
		Key:       "remote_addr",
	}
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{
			Listen:   "127.0.0.1:0",
			Security: config.SecurityConfig{RateLimit: limits},
		}},
	})

	if got := srv.config.Servers[0].Security.RateLimit.ConnLimit; got != 100 {
		t.Errorf("expected ConnLimit 100, got %d", got)
	}
}
// TestStartSingleMode_WithAuthRequest checks that the external
// (subrequest-based) auth settings are carried through New unchanged.
func TestStartSingleMode_WithAuthRequest(t *testing.T) {
	authReq := config.AuthRequestConfig{
		Enabled: true,
		URI:     "/auth/validate",
		Timeout: 5 * time.Second,
	}
	srv := New(&config.Config{
		Servers: []config.ServerConfig{{
			Listen:   "127.0.0.1:0",
			Security: config.SecurityConfig{AuthRequest: authReq},
		}},
	})

	got := srv.config.Servers[0].Security.AuthRequest
	if !got.Enabled {
		t.Error("expected AuthRequest enabled")
	}
	if got.URI != "/auth/validate" {
		t.Errorf("expected URI /auth/validate, got %s", got.URI)
	}
}
// TestShutdownServers_EmptySlice 测试空服务器列表。
func TestShutdownServers_EmptySlice(t *testing.T) {
ctx := context.Background()
err := shutdownServers(ctx, []*fasthttp.Server{})
if err != nil {
t.Errorf("shutdownServers with empty slice should return nil, got: %v", err)
}
}
// TestShutdownServers_NilSlice 测试 nil 服务器列表。
func TestShutdownServers_NilSlice(t *testing.T) {
ctx := context.Background()
err := shutdownServers(ctx, nil)
if err != nil {
t.Errorf("shutdownServers with nil slice should return nil, got: %v", err)
}
}
// TestShutdownServers_NilContext verifies that shutdownServers accepts a nil
// context; the intended behavior is to fall back to context.Background().
func TestShutdownServers_NilContext(t *testing.T) {
	// Deliberately nil. Holding it in a variable rather than passing a
	// literal nil is likely to keep linters such as staticcheck SA1012
	// quiet about this intentional misuse.
	var nilCtx context.Context
	if err := shutdownServers(nilCtx, []*fasthttp.Server{}); err != nil {
		t.Errorf("shutdownServers with nil ctx should return nil, got: %v", err)
	}
}
// TestShutdownServers_SingleServer verifies shutdown of one idle server.
func TestShutdownServers_SingleServer(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	only := &fasthttp.Server{
		Handler: func(rc *fasthttp.RequestCtx) { rc.SetBodyString("test") },
	}
	if err := shutdownServers(ctx, []*fasthttp.Server{only}); err != nil {
		t.Errorf("shutdownServers failed: %v", err)
	}
}
// TestShutdownServers_MultipleServers verifies shutdown of several idle
// servers in one call.
func TestShutdownServers_MultipleServers(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	bodies := []string{"test1", "test2", "test3"}
	servers := make([]*fasthttp.Server, 0, len(bodies))
	for _, body := range bodies {
		servers = append(servers, &fasthttp.Server{
			Handler: func(rc *fasthttp.RequestCtx) { rc.SetBodyString(body) },
		})
	}

	if err := shutdownServers(ctx, servers); err != nil {
		t.Errorf("shutdownServers failed: %v", err)
	}
}
// TestShutdownServers_WithNilServers verifies that nil entries around a real
// server are skipped.
func TestShutdownServers_WithNilServers(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	live := &fasthttp.Server{
		Handler: func(rc *fasthttp.RequestCtx) { rc.SetBodyString("test") },
	}
	if err := shutdownServers(ctx, []*fasthttp.Server{nil, live, nil}); err != nil {
		t.Errorf("shutdownServers failed: %v", err)
	}
}
// TestShutdownServers_AllNilServers verifies that a list of only nil servers
// is a no-op that returns nil.
func TestShutdownServers_AllNilServers(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// make zero-fills, so all three entries are nil.
	servers := make([]*fasthttp.Server, 3)
	if err := shutdownServers(ctx, servers); err != nil {
		t.Errorf("shutdownServers with all nil servers should return nil, got: %v", err)
	}
}
// TestShutdownServers_ContextCancelled calls shutdownServers with an
// already-cancelled context. Depending on how fast the shutdown is, either nil
// or context.Canceled is accepted.
func TestShutdownServers_ContextCancelled(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // cancel up front so the context is already done

	servers := []*fasthttp.Server{{
		Handler: func(rc *fasthttp.RequestCtx) { rc.SetBodyString("test") },
	}}
	err := shutdownServers(ctx, servers)
	switch err {
	case nil, context.Canceled:
		// Both outcomes are acceptable.
	default:
		t.Errorf("unexpected error: %v", err)
	}
}
// TestShutdownServers_ContextTimeout calls shutdownServers with a context
// whose deadline has already passed. Either nil or context.DeadlineExceeded
// is accepted, depending on how fast the shutdown completes.
func TestShutdownServers_ContextTimeout(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
	defer cancel()
	// Block until the deadline has actually fired instead of sleeping and
	// assuming it has. This is deterministic and returns as soon as possible.
	<-ctx.Done()

	servers := []*fasthttp.Server{
		{Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }},
	}
	err := shutdownServers(ctx, servers)
	if err != nil && err != context.DeadlineExceeded {
		t.Errorf("unexpected error: %v", err)
	}
}
// TestShutdownServers_RunningServers verifies that shutdownServers stops
// fasthttp servers that are actively serving on real loopback listeners.
func TestShutdownServers_RunningServers(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Start two servers, each on its own ephemeral listener.
	servers := make([]*fasthttp.Server, 2)
	for i := range 2 {
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		if err != nil {
			t.Fatalf("failed to create listener: %v", err)
		}
		// Register the close right away so listeners opened before a later
		// Fatalf are still released, which also ends their Serve goroutines.
		// Closing a listener the server already closed just returns an
		// ignored error.
		t.Cleanup(func() { _ = ln.Close() })

		srv := &fasthttp.Server{
			Handler: func(ctx *fasthttp.RequestCtx) {
				ctx.SetBodyString("test")
			},
		}
		servers[i] = srv
		go func(s *fasthttp.Server, l net.Listener) {
			_ = s.Serve(l)
		}(srv, ln)
	}

	// Give the servers a moment to start accepting.
	time.Sleep(10 * time.Millisecond)

	if err := shutdownServers(ctx, servers); err != nil {
		t.Errorf("shutdownServers failed: %v", err)
	}
}
// TestShutdownServers_ManyServers verifies shutdown of a large batch (50) of
// idle servers within a 10s budget.
func TestShutdownServers_ManyServers(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	const count = 50
	servers := make([]*fasthttp.Server, 0, count)
	for range count {
		servers = append(servers, &fasthttp.Server{
			Handler: func(rc *fasthttp.RequestCtx) { rc.SetBodyString("test") },
		})
	}

	if err := shutdownServers(ctx, servers); err != nil {
		t.Errorf("shutdownServers with many servers failed: %v", err)
	}
}
// TestShutdownServers_MixedNilAndRealServers verifies shutdown of a list that
// alternates nil (even indices) and real servers (odd indices).
func TestShutdownServers_MixedNilAndRealServers(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	count := 20
	// Even slots stay nil from make; only odd slots get a server.
	servers := make([]*fasthttp.Server, count)
	for idx := 1; idx < count; idx += 2 {
		servers[idx] = &fasthttp.Server{
			Handler: func(rc *fasthttp.RequestCtx) { rc.SetBodyString("test") },
		}
	}

	if err := shutdownServers(ctx, servers); err != nil {
		t.Errorf("shutdownServers failed: %v", err)
	}
}
// TestShutdownServers_ConcurrentSafety 测试并发安全性。
func TestShutdownServers_ConcurrentSafety(t *testing.T) {
ctx := context.Background()
// 并发调用 shutdownServers
var wg sync.WaitGroup
for range 10 {
wg.Go(func() {
servers := []*fasthttp.Server{
{Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }},
{Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }},
}
_ = shutdownServers(ctx, servers)
})
}
wg.Wait()
}
// TestShutdownServers_WithDeadline verifies shutdown under a context that
// carries an absolute deadline five seconds out.
func TestShutdownServers_WithDeadline(t *testing.T) {
	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
	defer cancel()

	servers := []*fasthttp.Server{{
		Handler: func(rc *fasthttp.RequestCtx) { rc.SetBodyString("test") },
	}}
	if err := shutdownServers(ctx, servers); err != nil {
		t.Errorf("shutdownServers failed: %v", err)
	}
}
// TestBuildLuaMiddlewares_SingleScript 测试单个脚本配置。
func TestBuildLuaMiddlewares_SingleScript(t *testing.T) {
// 创建临时 Lua 脚本
tempDir := t.TempDir()
scriptPath := tempDir + "/test.lua"
if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: scriptPath, Phase: "access", Timeout: 10 * time.Second, Enabled: true},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if len(middlewares) != 1 {
t.Errorf("expected 1 middleware, got: %d", len(middlewares))
}
}
// TestBuildLuaMiddlewares_SingleScriptDefaultTimeout 测试单脚本默认超时。
func TestBuildLuaMiddlewares_SingleScriptDefaultTimeout(t *testing.T) {
tempDir := t.TempDir()
scriptPath := tempDir + "/test.lua"
if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: scriptPath, Phase: "content", Timeout: 0}, // 使用默认超时
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if len(middlewares) != 1 {
t.Errorf("expected 1 middleware, got: %d", len(middlewares))
}
}
// TestBuildLuaMiddlewares_MultipleScriptsSamePhase 测试多脚本同阶段。
func TestBuildLuaMiddlewares_MultipleScriptsSamePhase(t *testing.T) {
tempDir := t.TempDir()
script1 := tempDir + "/test1.lua"
script2 := tempDir + "/test2.lua"
if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: script1, Phase: "access", Timeout: 10 * time.Second, Enabled: true},
{Path: script2, Phase: "access", Timeout: 20 * time.Second, Enabled: true},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if len(middlewares) != 1 {
t.Errorf("expected 1 middleware (multi-phase), got: %d", len(middlewares))
}
}
// TestBuildLuaMiddlewares_MultipleScriptsDifferentPhases 测试多脚本不同阶段。
func TestBuildLuaMiddlewares_MultipleScriptsDifferentPhases(t *testing.T) {
tempDir := t.TempDir()
script1 := tempDir + "/rewrite.lua"
script2 := tempDir + "/access.lua"
script3 := tempDir + "/log.lua"
for _, p := range []string{script1, script2, script3} {
if err := os.WriteFile(p, []byte("ngx.say('hello')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: script1, Phase: "rewrite", Timeout: 10 * time.Second, Enabled: true},
{Path: script2, Phase: "access", Timeout: 15 * time.Second, Enabled: true},
{Path: script3, Phase: "log", Timeout: 20 * time.Second, Enabled: true},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if len(middlewares) != 3 {
t.Errorf("expected 3 middlewares, got: %d", len(middlewares))
}
}
// TestBuildLuaMiddlewares_DefaultEnabled 测试默认启用逻辑。
func TestBuildLuaMiddlewares_DefaultEnabled(t *testing.T) {
tempDir := t.TempDir()
scriptPath := tempDir + "/test.lua"
if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
// Enabled 为 false但 Timeout=0 且 Path 不为空,应该默认启用
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: scriptPath, Phase: "access", Timeout: 0, Enabled: false},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
// 默认启用逻辑Enabled=false && Timeout=0 && Path!="" -> enabled=true
if len(middlewares) != 1 {
t.Errorf("expected 1 middleware (default enabled), got: %d", len(middlewares))
}
}
// TestBuildLuaMiddlewares_InvalidPhaseInMultiScript 测试多脚本中的无效阶段。
func TestBuildLuaMiddlewares_InvalidPhaseInMultiScript(t *testing.T) {
tempDir := t.TempDir()
script1 := tempDir + "/test1.lua"
script2 := tempDir + "/test2.lua"
if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: script1, Phase: "access", Timeout: 10 * time.Second, Enabled: true},
{Path: script2, Phase: "invalid_phase", Timeout: 10 * time.Second, Enabled: true},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err == nil {
t.Error("expected error for invalid phase in multi-script")
}
if middlewares != nil {
t.Errorf("expected nil middlewares on error, got: %v", middlewares)
}
}
// TestBuildLuaMiddlewares_AllPhases 测试所有阶段。
func TestBuildLuaMiddlewares_AllPhases(t *testing.T) {
tempDir := t.TempDir()
phases := []string{"rewrite", "access", "content", "log", "header_filter", "body_filter"}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
scripts := make([]config.LuaScriptConfig, len(phases))
for i, phase := range phases {
scriptPath := tempDir + "/" + phase + ".lua"
if err := os.WriteFile(scriptPath, []byte("ngx.say('"+phase+"')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
scripts[i] = config.LuaScriptConfig{Path: scriptPath, Phase: phase, Timeout: 10 * time.Second, Enabled: true}
}
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: scripts,
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if len(middlewares) != len(phases) {
t.Errorf("expected %d middlewares, got: %d", len(phases), len(middlewares))
}
}
// TestBuildLuaMiddlewares_NonExistentScript exercises buildLuaMiddlewares with
// an enabled script whose file does not exist.
//
// Whether the missing file is rejected while building (an error) or only later
// depends on lua.NewLuaMiddleware, so both outcomes are accepted. The test
// requires that the call does not panic. It also requires that a successful
// build still yields the single configured middleware.
func TestBuildLuaMiddlewares_NonExistentScript(t *testing.T) {
	cfg := &config.Config{
		Servers: []config.ServerConfig{{
			Listen: ":8080",
		}},
	}
	s := New(cfg)
	luaEngine, err := lua.NewEngine(lua.DefaultConfig())
	if err != nil {
		t.Skipf("failed to create Lua engine: %v", err)
	}
	s.luaEngine = luaEngine
	luaCfg := &config.LuaMiddlewareConfig{
		Enabled: true,
		Scripts: []config.LuaScriptConfig{
			// Enabled must be set explicitly. With a non-zero Timeout the
			// implicit-enable rule (Enabled=false && Timeout=0 && Path!="")
			// does not apply, so the script would be filtered out as disabled
			// and the missing path would never be exercised.
			{Path: "/non/existent/script.lua", Phase: "access", Timeout: 10 * time.Second, Enabled: true},
		},
	}
	middlewares, err := s.buildLuaMiddlewares(luaCfg)
	if err != nil {
		// Build-time validation rejected the missing file: acceptable.
		t.Logf("missing script rejected at build time: %v", err)
		return
	}
	// Validation is deferred; the enabled script must still yield a middleware.
	if len(middlewares) != 1 {
		t.Errorf("expected 1 middleware when script validation is deferred, got: %d", len(middlewares))
	}
}
// TestBuildLuaMiddlewares_MixedEnabledDisabled 测试混合启用禁用脚本。
func TestBuildLuaMiddlewares_MixedEnabledDisabled(t *testing.T) {
tempDir := t.TempDir()
for _, name := range []string{"enabled1", "enabled2", "disabled1", "disabled2"} {
scriptPath := tempDir + "/" + name + ".lua"
if err := os.WriteFile(scriptPath, []byte("ngx.say('"+name+"')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: tempDir + "/enabled1.lua", Phase: "rewrite", Timeout: 10 * time.Second, Enabled: true},
{Path: tempDir + "/disabled1.lua", Phase: "rewrite", Timeout: 10 * time.Second, Enabled: false},
{Path: tempDir + "/enabled2.lua", Phase: "access", Timeout: 10 * time.Second, Enabled: true},
{Path: tempDir + "/disabled2.lua", Phase: "access", Timeout: 10 * time.Second, Enabled: false},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
// 只有启用的脚本应该被处理rewrite(1) + access(1) = 2
if len(middlewares) != 2 {
t.Errorf("expected 2 middlewares, got: %d", len(middlewares))
}
}
// TestBuildLuaMiddlewares_MultiPhaseDefaultTimeout 测试多脚本阶段默认超时。
func TestBuildLuaMiddlewares_MultiPhaseDefaultTimeout(t *testing.T) {
tempDir := t.TempDir()
script1 := tempDir + "/test1.lua"
script2 := tempDir + "/test2.lua"
if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil {
t.Fatalf("failed to create script: %v", err)
}
cfg := &config.Config{
Servers: []config.ServerConfig{{
Listen: ":8080",
}},
}
s := New(cfg)
luaEngine, err := lua.NewEngine(lua.DefaultConfig())
if err != nil {
t.Skipf("failed to create Lua engine: %v", err)
}
s.luaEngine = luaEngine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{Path: script1, Phase: "access", Timeout: 0}, // 默认超时
{Path: script2, Phase: "access", Timeout: 0}, // 默认超时
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
if err != nil {
t.Errorf("expected nil error, got: %v", err)
}
if len(middlewares) != 1 {
t.Errorf("expected 1 middleware (multi-phase), got: %d", len(middlewares))
}
}