// Package server 提供 HTTP 服务器功能的测试。 // // 该文件测试服务器模块的各项功能,包括: // - 服务器创建和初始化 // - 启动和停止控制 // - 优雅关闭 // - 中间件链构建 // - 请求统计追踪 // - 监听器管理 // - TLS 配置 // - 代理缓存统计 // // 作者:xfy package server import ( "context" "fmt" "net" "os" "strings" "sync" "testing" "time" "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/lua" "rua.plus/lolly/internal/matcher" "rua.plus/lolly/internal/middleware/accesslog" "rua.plus/lolly/internal/middleware/security" "rua.plus/lolly/internal/proxy" "rua.plus/lolly/internal/ssl" "rua.plus/lolly/internal/version" ) // TestNew 测试服务器创建 func TestNew(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", Static: []config.StaticConfig{{ Path: "/", Root: "./static", Index: []string{"index.html"}, }}, }}, } s := New(cfg) if s == nil { t.Fatal("New() returned nil, expected non-nil Server") } if s.config != cfg { t.Error("Server.config not set correctly") } if s.running.Load() { t.Error("Server.running should be false initially") } if s.fastServer != nil { t.Error("Server.fastServer should be nil before Start()") } } // TestStopWithoutServer 测试无服务器时调用 Stop func TestStopWithoutServer(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 在未启动时调用 Stop,应返回 nil err := s.StopWithTimeout(5 * time.Second) if err != nil { t.Errorf("StopWithTimeout() on non-started server returned error: %v", err) } } // TestGracefulStop 测试 GracefulStop 调用 func TestGracefulStop(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 在未启动时调用 GracefulStop,应返回 nil err := s.GracefulStop(5 * time.Second) if err != nil { t.Errorf("GracefulStop() on non-started server returned error: %v", err) } } // TestStopAfterStop 测试多次调用 Stop func TestStopAfterStop(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 多次调用 StopWithTimeout 应该都是安全的 for i := range 3 { err := s.StopWithTimeout(5 * time.Second) if err != nil { t.Errorf("StopWithTimeout() call %d returned error: %v", i+1, err) } } } // TestGracefulStopWithZeroTimeout 测试零超时的 GracefulStop func TestGracefulStopWithZeroTimeout(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) err := s.GracefulStop(0) if err != nil { t.Errorf("GracefulStop(0) returned error: %v", err) } } // TestBuildMiddlewareChain_AccessLog 测试访问日志中间件 func TestBuildMiddlewareChain_AccessLog(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_AccessControl 测试访问控制中间件 func TestBuildMiddlewareChain_AccessControl(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", Security: config.SecurityConfig{ Access: config.AccessConfig{ Allow: []string{"127.0.0.1"}, }, }, }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_RateLimiter 测试限流中间件 func TestBuildMiddlewareChain_RateLimiter(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", Security: config.SecurityConfig{ RateLimit: config.RateLimitConfig{ RequestRate: 100, Burst: 200, }, }, }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_Rewrite 测试重写中间件 func TestBuildMiddlewareChain_Rewrite(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", Rewrite: []config.RewriteRule{ {Pattern: "/old/(.*)", Replacement: "/new/$1"}, }, }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_Compression 测试压缩中间件 func TestBuildMiddlewareChain_Compression(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", Compression: config.CompressionConfig{ Level: 6, }, }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_SecurityHeaders 测试安全头中间件 func TestBuildMiddlewareChain_SecurityHeaders(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", Security: config.SecurityConfig{ Headers: config.SecurityHeaders{ XFrameOptions: "DENY", XContentTypeOptions: "nosniff", }, }, }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_AllMiddlewares 测试所有中间件组合 func TestBuildMiddlewareChain_AllMiddlewares(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", Security: config.SecurityConfig{ Access: config.AccessConfig{ Allow: []string{"127.0.0.1"}, }, RateLimit: config.RateLimitConfig{ RequestRate: 100, Burst: 200, }, Headers: config.SecurityHeaders{ XFrameOptions: "DENY", }, }, Rewrite: []config.RewriteRule{ {Pattern: "/old/(.*)", Replacement: "/new/$1"}, }, Compression: config.CompressionConfig{ Level: 6, }, }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestTrackStats 测试请求统计追踪 func TestTrackStats(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 初始统计应该为 0 if s.requests.Load() != 0 { t.Error("Initial requests should be 0") } if s.bytesSent.Load() != 0 { t.Error("Initial bytesSent should be 0") } if s.bytesReceived.Load() != 0 { t.Error("Initial bytesReceived should be 0") } // 创建测试 handler handler := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("response body") } // 包装 handler wrappedHandler := s.trackStats(handler) // 创建测试请求上下文 ctx := &fasthttp.RequestCtx{} ctx.Init(&fasthttp.Request{}, nil, nil) ctx.Request.SetBody([]byte("request body")) // 执行 wrappedHandler(ctx) // 验证统计 if s.requests.Load() != 1 { t.Errorf("Expected 1 request, got %d", s.requests.Load()) } if s.bytesReceived.Load() != int64(len("request body")) { t.Errorf("Expected bytesReceived %d, got %d", len("request body"), s.bytesReceived.Load()) } if s.bytesSent.Load() != int64(len("response body")) { t.Errorf("Expected bytesSent %d, got %d", len("response body"), s.bytesSent.Load()) } } // TestTrackStats_MultipleRequests 测试多次请求统计 func TestTrackStats_MultipleRequests(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) handler := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("ok") } wrappedHandler := s.trackStats(handler) // 执行多次请求 for range 10 { ctx := &fasthttp.RequestCtx{} ctx.Init(&fasthttp.Request{}, nil, nil) wrappedHandler(ctx) } if s.requests.Load() != 10 { t.Errorf("Expected 10 requests, got %d", s.requests.Load()) } } // TestGetListeners_Empty 测试空监听器列表 func TestGetListeners_Empty(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) listeners := s.GetListeners() if listeners != nil { t.Errorf("Expected nil listeners, got %v", listeners) } } // TestSetListeners 测试设置监听器 func TestSetListeners(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 创建模拟监听器 listener1, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Failed to create listener: %v", err) } defer func() { _ = listener1.Close() }() listener2, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Failed to create listener: %v", err) } defer func() { _ = listener2.Close() }() listeners := []net.Listener{listener1, listener2} s.SetListeners(listeners) // 验证设置成功 got := s.GetListeners() if len(got) != 2 { t.Errorf("Expected 2 listeners, got %d", len(got)) } } // TestGetTLSConfig_NotConfigured 测试未配置 TLS func TestGetTLSConfig_NotConfigured(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) tlsConfig, err := s.GetTLSConfig() if err == nil { t.Error("Expected error for unconfigured TLS") } if tlsConfig != nil { t.Error("Expected nil TLS config") } if err.Error() != "TLS not configured" { t.Errorf("Expected error 'TLS not configured', got: %v", err) } } // TestGetHandler 测试获取 handler func TestGetHandler(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 初始 handler 应该为 nil handler := s.GetHandler() if handler != nil { t.Error("Expected nil handler initially") } // 设置一个 handler testHandler := func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") } s.handler = testHandler // 验证获取成功 got := s.GetHandler() if got == nil { t.Error("Expected non-nil handler after setting") } } // TestServer_Connections 测试连接统计 func TestServer_Connections(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 初始连接数应该为 0 if s.connections.Load() != 0 { t.Error("Initial connections should be 0") } // 增加 s.connections.Add(1) if s.connections.Load() != 1 { t.Errorf("Expected 1 connection, got %d", s.connections.Load()) } // 减少 s.connections.Add(-1) if s.connections.Load() != 0 { t.Errorf("Expected 0 connections, got %d", s.connections.Load()) } } // TestServer_Proxies 测试代理管理 func TestServer_Proxies(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 初始代理列表应为空 if len(s.proxies) != 0 { t.Error("Initial proxies should be empty") } } // TestServer_Running 测试运行状态 func TestServer_Running(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 初始状态应为未运行 if s.running.Load() { t.Error("Initial running state should be false") } } // TestServer_StopWithNilFastServer 测试无 fastServer 时停止 func TestServer_StopWithNilFastServer(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) s.fastServer = nil err := s.StopWithTimeout(5 * time.Second) if err != nil { t.Errorf("StopWithTimeout with nil fastServer should succeed: %v", err) } } // TestServer_GracefulStopWithNilFastServer 测试无 fastServer 时优雅停止 func TestServer_GracefulStopWithNilFastServer(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) s.fastServer = nil err := s.GracefulStop(5 * time.Second) if err != nil { t.Errorf("GracefulStop with nil fastServer should succeed: %v", err) } } // TestServer_GetProxyCacheStats 测试代理缓存统计 func TestServer_GetProxyCacheStats(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 无代理时应返回空统计 stats := s.getProxyCacheStats() if stats.Entries != 0 { t.Errorf("Expected 0 entries, got %d", stats.Entries) } if stats.Pending != 0 { t.Errorf("Expected 0 pending, got %d", stats.Pending) } } // TestServer_BuildMiddlewareChain_EmptyConfig 测试空配置的中间件链 func TestServer_BuildMiddlewareChain_EmptyConfig(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("Unexpected error: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestServer_TrackStats_EmptyBody 测试空响应体的统计 func TestServer_TrackStats_EmptyBody(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) handler := func(_ *fasthttp.RequestCtx) { // 空响应 } wrappedHandler := s.trackStats(handler) ctx := &fasthttp.RequestCtx{} ctx.Init(&fasthttp.Request{}, nil, nil) ctx.Request.SetBody(nil) wrappedHandler(ctx) if s.requests.Load() != 1 { t.Errorf("Expected 1 request, got %d", s.requests.Load()) } if s.bytesSent.Load() != 0 { t.Errorf("Expected 0 bytes sent, got %d", s.bytesSent.Load()) } } // TestStart_Success 测试服务器配置初始化 func TestStart_Success(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 验证服务器正确初始化 if s == nil { t.Fatal("New() returned nil, expected non-nil Server") } if s.config != cfg { t.Error("Server.config not set correctly") } } // TestStart_WithStaticFiles 测试静态文件配置 func TestStart_WithStaticFiles(t *testing.T) { // 创建临时目录 tempDir := t.TempDir() cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", Static: []config.StaticConfig{{ Path: "/static", Root: tempDir, Index: []string{"index.html"}, }}, }}, } s := New(cfg) if s == nil { t.Fatal("New() returned nil") } } // TestStart_WithGoroutinePool 测试 GoroutinePool 配置 func TestStart_WithGoroutinePool(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, Performance: config.PerformanceConfig{ GoroutinePool: config.GoroutinePoolConfig{ Enabled: true, MaxWorkers: 100, MinWorkers: 10, IdleTimeout: 30 * time.Second, }, }, } s := New(cfg) if s == nil { t.Fatal("New() returned nil") } } // TestStart_WithFileCache 测试文件缓存配置 func TestStart_WithFileCache(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, Performance: config.PerformanceConfig{ FileCache: config.FileCacheConfig{ MaxEntries: 1000, MaxSize: 100 * 1024 * 1024, }, }, } s := New(cfg) if s == nil { t.Fatal("New() returned nil") } } // TestStop_Graceful 测试优雅停止(无 race 模式) func TestStop_Graceful(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) // 在未启动时调用 GracefulStop,应返回 nil err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop() on non-started server returned error: %v", err) } } // TestGetTLSConfig_Nil 测试无 TLS 配置 func TestGetTLSConfig_Nil(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) tlsCfg, err := s.GetTLSConfig() if err == nil { t.Error("GetTLSConfig() should return error when TLS not configured") } if tlsCfg != nil { t.Error("GetTLSConfig() should return nil when TLS not configured") } } // TestGetTLSConfig_NilServer 测试 nil 服务器调用 GetTLSConfig func TestGetTLSConfig_NilServer(t *testing.T) { var s *Server // 防御性:如果 s 为 nil,调用方法会 panic,这是预期的行为 // 这里我们只测试非 nil 但 tlsManager 为 nil 的情况 cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s = New(cfg) // 确保 tlsManager 为 nil if s.tlsManager != nil { t.Skip("tlsManager should be nil initially") } tlsCfg, err := s.GetTLSConfig() if err == nil { t.Error("Expected error when tlsManager is nil") } if tlsCfg != nil { t.Error("Expected nil TLS config when tlsManager is nil") } if err.Error() != "TLS not configured" { t.Errorf("Expected error 'TLS not configured', got: %v", err) } } // TestGetServerName 测试服务器名称返回。 func TestGetServerName(t *testing.T) { tests := []struct { name string cfg *config.ServerConfig wantName string }{ { name: "nil config", cfg: nil, wantName: "lolly/" + version.Version, }, { name: "ServerTokens true (default)", cfg: &config.ServerConfig{ ServerTokens: true, }, wantName: "lolly/" + version.Version, }, { name: "ServerTokens false", cfg: &config.ServerConfig{ ServerTokens: false, }, wantName: "lolly", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Server{} got := s.getServerName(tt.cfg) if got != tt.wantName { t.Errorf("getServerName() = %q, want %q", got, tt.wantName) } }) } } // TestApplyTypesConfig 测试 MIME 类型配置应用。 func TestApplyTypesConfig(t *testing.T) { t.Run("nil config", func(t *testing.T) { s := &Server{} // 不应该 panic s.applyTypesConfig(nil) }) t.Run("empty config", func(t *testing.T) { s := &Server{} cfg := &config.ServerConfig{} // 不应该 panic s.applyTypesConfig(cfg) }) t.Run("with types map", func(t *testing.T) { s := &Server{} cfg := &config.ServerConfig{ Types: config.TypesConfig{ Map: map[string]string{ ".custom": "application/x-custom", }, }, } // 不应该 panic s.applyTypesConfig(cfg) }) t.Run("with default type", func(t *testing.T) { s := &Server{} cfg := &config.ServerConfig{ Types: config.TypesConfig{ DefaultType: "application/octet-stream", }, } // 不应该 panic s.applyTypesConfig(cfg) }) } // TestCreateListener_TCP 测试 TCP 监听器创建。 func TestCreateListener_TCP(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", // 随机端口 }}, } s := New(cfg) ln, err := s.createListener(&cfg.Servers[0]) if err != nil { t.Fatalf("createListener() error: %v", err) } if ln == nil { t.Fatal("createListener() returned nil listener") } defer ln.Close() if ln.Addr().Network() != "tcp" { t.Errorf("expected tcp network, got %s", ln.Addr().Network()) } } // TestCreateListener_UnixSocket 测试 Unix Socket 监听器创建。 func TestCreateListener_UnixSocket(t *testing.T) { tempDir := t.TempDir() socketPath := tempDir + "/test.sock" cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "unix:" + socketPath, }}, } s := New(cfg) ln, err := s.createListener(&cfg.Servers[0]) if err != nil { t.Fatalf("createListener() error: %v", err) } if ln == nil { t.Fatal("createListener() returned nil listener") } defer ln.Close() if ln.Addr().Network() != "unix" { t.Errorf("expected unix network, got %s", ln.Addr().Network()) } } // TestCreateListener_UnixSocketWithPermissions 测试带权限的 Unix Socket 创建。 func TestCreateListener_UnixSocketWithPermissions(t *testing.T) { tempDir := t.TempDir() socketPath := tempDir + "/test_perm.sock" cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "unix:" + socketPath, UnixSocket: config.UnixSocketConfig{ Mode: 0o600, User: "nobody", Group: "nobody", }, }}, } s := New(cfg) ln, err := s.createListener(&cfg.Servers[0]) if err != nil { t.Fatalf("createListener() error: %v", err) } if ln == nil { t.Fatal("createListener() returned nil listener") } defer ln.Close() } // TestCreateListener_UnixSocketCleanup 测试 Unix Socket 文件清理。 func TestCreateListener_UnixSocketCleanup(t *testing.T) { tempDir := t.TempDir() socketPath := tempDir + "/cleanup.sock" // 先创建一个已存在的 socket 文件 if err := os.WriteFile(socketPath, []byte{}, 0o666); err != nil { t.Fatalf("failed to create existing socket file: %v", err) } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "unix:" + socketPath, }}, } s := New(cfg) ln, err := s.createListener(&cfg.Servers[0]) if err != nil { t.Fatalf("createListener() error: %v", err) } defer ln.Close() } func TestHandleRegistrationError_ConflictWarning(t *testing.T) { s := &Server{} err := s.handleRegistrationError("proxy", "/api", &matcher.ConflictError{Path: "/api", ExistingType: "exact", NewType: "prefix"}) if err != nil { t.Errorf("conflict should return nil, got: %v", err) } } func TestHandleRegistrationError_FatalError(t *testing.T) { s := &Server{} err := s.handleRegistrationError("proxy", "/api", fmt.Errorf("invalid regex pattern: missing closing parenthesis")) if err == nil { t.Error("fatal error should return non-nil") } if !strings.Contains(err.Error(), "proxy route /api") { t.Errorf("error should wrap context, got: %v", err) } } func TestDupListener_TCP(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) } defer ln.Close() duped, err := DupListener(ln) if err != nil { t.Fatalf("DupListener() error: %v", err) } defer duped.Close() if duped.Addr().Network() != "tcp" { t.Errorf("expected tcp, got %s", duped.Addr().Network()) } if duped.Addr().String() != ln.Addr().String() { t.Errorf("expected same address %s, got %s", ln.Addr().String(), duped.Addr().String()) } } func TestDupListener_Unix(t *testing.T) { dir := t.TempDir() socketPath := dir + "/dup.sock" ln, err := net.Listen("unix", socketPath) if err != nil { t.Fatal(err) } defer ln.Close() duped, err := DupListener(ln) if err != nil { t.Fatalf("DupListener() error: %v", err) } defer duped.Close() } func TestDupListener_Unsupported(t *testing.T) { _, err := DupListener(struct{ net.Listener }{}) if err == nil { t.Error("expected error for unsupported type") } } func TestTcpAddrMatch(t *testing.T) { s := &Server{} tests := []struct { inherited string target string want bool }{ {"127.0.0.1:8080", "127.0.0.1:8080", true}, {"0.0.0.0:8080", ":8080", true}, {"[::]:8080", ":8080", true}, {"0.0.0.0:8080", "0.0.0.0:8080", true}, {"0.0.0.0:8080", "127.0.0.1:8080", true}, {"127.0.0.1:8080", "0.0.0.0:8080", true}, {"127.0.0.1:8080", ":9090", false}, {"127.0.0.1:8080", "192.168.1.1:8080", false}, } for _, tt := range tests { got := s.tcpAddrMatch(tt.inherited, tt.target) if got != tt.want { t.Errorf("tcpAddrMatch(%q, %q) = %v, want %v", tt.inherited, tt.target, got, tt.want) } } } func TestMatchInheritedListener_TCP(t *testing.T) { s := &Server{} ln1, _ := net.Listen("tcp", "127.0.0.1:0") defer ln1.Close() ln2, _ := net.Listen("tcp", "127.0.0.1:0") defer ln2.Close() inherited := []net.Listener{ln1, ln2} result := s.matchInheritedListener(inherited, "0.0.0.0:99999") if result != nil { t.Error("expected nil for non-matching address") } addr1 := ln1.Addr().String() result = s.matchInheritedListener(inherited, addr1) if result != ln1 { t.Errorf("expected ln1 for address %s", addr1) } } func TestMatchInheritedListener_Empty(t *testing.T) { s := &Server{} result := s.matchInheritedListener(nil, ":8080") if result != nil { t.Error("expected nil for empty inherited list") } } func TestMatchInheritedListener_PresetListeners(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}}, } s := New(cfg) ln, err := s.createListener(&cfg.Servers[0]) if err != nil { t.Fatal(err) } defer ln.Close() s.SetListeners([]net.Listener{ln}) addr := ln.Addr().String() cfg.Servers[0].Listen = addr matched, err := s.createListener(&cfg.Servers[0]) if err != nil { t.Fatalf("createListener with preset should reuse: %v", err) } if matched == nil { t.Fatal("expected non-nil listener from preset match") } if matched.Addr().String() != addr { t.Errorf("expected same address %s, got %s", addr, matched.Addr().String()) } } // TestServer_StatsMethods 测试服务器统计方法。 func TestServer_StatsMethods(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 测试 startTime 初始值 if !s.startTime.IsZero() { t.Error("startTime should be zero initially") } // 设置 startTime s.startTime = time.Now() if s.startTime.IsZero() { t.Error("startTime should not be zero after setting") } // 测试统计值 if s.connections.Load() != 0 { t.Error("initial connections should be 0") } if s.requests.Load() != 0 { t.Error("initial requests should be 0") } if s.bytesSent.Load() != 0 { t.Error("initial bytesSent should be 0") } if s.bytesReceived.Load() != 0 { t.Error("initial bytesReceived should be 0") } } // TestServer_SetResolver 测试设置解析器。 func TestServer_SetResolver(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 设置 nil resolver s.SetResolver(nil) if s.resolver != nil { t.Error("resolver should be nil") } } // TestServer_SetUpgradeManager 测试设置升级管理器。 func TestServer_SetUpgradeManager(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 设置 nil upgrade manager s.SetUpgradeManager(nil) if s.upgradeManager != nil { t.Error("upgradeManager should be nil") } // 设置实际的 upgrade manager um := NewUpgradeManager(s) s.SetUpgradeManager(um) if s.upgradeManager == nil { t.Error("upgradeManager should not be nil after setting") } } // TestServer_GetResolver 测试获取解析器。 func TestServer_GetResolver(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 初始 resolver 应为 nil resolver := s.GetResolver() if resolver != nil { t.Error("expected nil resolver initially") } } // TestServer_StopWithTimeout_WithListeners 测试带监听器的停止。 func TestServer_StopWithTimeout_WithListeners(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) // 创建监听器 ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) } s.listeners = []net.Listener{ln} // 调用停止 err = s.StopWithTimeout(1 * time.Second) if err != nil { t.Errorf("StopWithTimeout failed: %v", err) } } // TestServer_GracefulStop_WithListeners 测试带监听器的优雅停止。 func TestServer_GracefulStop_WithListeners(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) // 创建监听器 ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) } s.listeners = []net.Listener{ln} // 调用优雅停止 err = s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestServer_StopWithTimeout_WithFastServer 测试带 fastServer 的停止。 func TestServer_StopWithTimeout_WithFastServer(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建 mock fastServer s.fastServer = &fasthttp.Server{} // 调用停止 err := s.StopWithTimeout(1 * time.Second) if err != nil { t.Errorf("StopWithTimeout failed: %v", err) } } // TestBuildMiddlewareChain_BodyLimit 测试请求体限制中间件。 func TestBuildMiddlewareChain_BodyLimit(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", ClientMaxBodySize: "1MB", }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_ErrorIntercept 测试错误拦截中间件。 func TestBuildMiddlewareChain_ErrorIntercept(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":8080", Security: config.SecurityConfig{ ErrorPage: config.ErrorPageConfig{ Pages: map[int]string{ 404: "/404.html", }, }, }, }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestBuildMiddlewareChain_NilServerConfig 测试 nil 服务器配置。 // 注意:buildMiddlewareChain 不接受 nil,所以这个测试验证空配置。 func TestBuildMiddlewareChain_NilServerConfig(t *testing.T) { cfg := &config.Config{ Logging: config.LoggingConfig{}, Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) chain, err := s.buildMiddlewareChain(&cfg.Servers[0]) if err != nil { t.Errorf("buildMiddlewareChain failed: %v", err) } if chain == nil { t.Error("Expected non-nil chain") } } // TestServer_StatusCode_MethodNotAllowed 测试不支持的 HTTP 方法。 func TestServer_StatusCode_MethodNotAllowed(t *testing.T) { // 简单验证 if fasthttp.StatusMethodNotAllowed != 405 { t.Errorf("StatusMethodNotAllowed should be 405") } } // TestServer_ConnectionTracking 测试连接追踪。 func TestServer_ConnectionTracking(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 测试原子操作 initial := s.connections.Load() s.connections.Add(1) if s.connections.Load() != initial+1 { t.Error("connections should have incremented") } s.connections.Add(-1) if s.connections.Load() != initial { t.Error("connections should be back to initial") } } // TestServer_RequestTracking 测试请求追踪。 func TestServer_RequestTracking(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 测试原子操作 s.requests.Add(5) if s.requests.Load() != 5 { t.Errorf("expected 5 requests, got %d", s.requests.Load()) } } // TestServer_BytesTracking 测试字节追踪。 func TestServer_BytesTracking(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 测试原子操作 s.bytesSent.Add(1024) s.bytesReceived.Add(512) if s.bytesSent.Load() != 1024 { t.Errorf("expected 1024 bytes sent, got %d", s.bytesSent.Load()) } if s.bytesReceived.Load() != 512 { t.Errorf("expected 512 bytes received, got %d", s.bytesReceived.Load()) } } // TestServer_GracefulStop_WithFastServers 测试带多个 fastServer 的优雅停止。 func TestServer_GracefulStop_WithFastServers(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建多个 fastServer s.fastServers = []*fasthttp.Server{ {}, {}, } // 调用优雅停止 err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestServer_StopWithTimeout_WithFastServers 测试带多个 fastServer 的停止。 func TestServer_StopWithTimeout_WithFastServers(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建多个 fastServer s.fastServers = []*fasthttp.Server{ {}, {}, } // 调用停止 err := s.StopWithTimeout(1 * time.Second) if err != nil { t.Errorf("StopWithTimeout failed: %v", err) } } // TestServer_GetProxyCacheStats_WithProxies 测试带代理的缓存统计。 func TestServer_GetProxyCacheStats_WithProxies(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) s.proxies = nil // 确保 proxies 为 nil // 无代理时应返回空统计 stats := s.getProxyCacheStats() if stats.Entries != 0 { t.Errorf("Expected 0 entries, got %d", stats.Entries) } } // TestServer_GetProxyCacheStats_SingleProxyWithCache 测试单个代理带缓存的统计。 func TestServer_GetProxyCacheStats_SingleProxyWithCache(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 创建带缓存的代理 proxyCfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, Cache: config.ProxyCacheConfig{ Enabled: true, MaxAge: 10 * time.Second, }, } targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } s.proxies = []*proxy.Proxy{p} // 获取统计 stats := s.getProxyCacheStats() // 新创建的缓存应该有 0 条目 if stats.Entries < 0 { t.Errorf("Expected non-negative entries, got %d", stats.Entries) } if stats.Pending < 0 { t.Errorf("Expected non-negative pending, got %d", stats.Pending) } } // TestServer_GetProxyCacheStats_SingleProxyNoCache 测试单个代理无缓存的统计。 func TestServer_GetProxyCacheStats_SingleProxyNoCache(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 创建不带缓存的代理 proxyCfg := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, // Cache 未启用 } targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } s.proxies = []*proxy.Proxy{p} // 获取统计 stats := s.getProxyCacheStats() // 无缓存时应返回 0 if stats.Entries != 0 { t.Errorf("Expected 0 entries for proxy without cache, got %d", stats.Entries) } if stats.Pending != 0 { t.Errorf("Expected 0 pending for proxy without cache, got %d", stats.Pending) } } // TestServer_GetProxyCacheStats_MultipleProxies 测试多个代理的缓存统计聚合。 func TestServer_GetProxyCacheStats_MultipleProxies(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) // 创建多个代理:部分带缓存,部分不带 targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} // 代理1:带缓存 proxyCfg1 := &config.ProxyConfig{ Path: "/api", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, Cache: config.ProxyCacheConfig{ Enabled: true, MaxAge: 10 * time.Second, }, } p1, err := proxy.NewProxy(proxyCfg1, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } // 代理2:不带缓存 proxyCfg2 := &config.ProxyConfig{ Path: "/static", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } p2, err := proxy.NewProxy(proxyCfg2, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } // 代理3:带缓存 proxyCfg3 := &config.ProxyConfig{ Path: "/data", LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, Cache: config.ProxyCacheConfig{ Enabled: true, MaxAge: 20 * time.Second, }, } p3, err := proxy.NewProxy(proxyCfg3, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } s.proxies = []*proxy.Proxy{p1, p2, p3} // 获取聚合统计 stats := s.getProxyCacheStats() // 统计应该非负 if stats.Entries < 0 { t.Errorf("Expected non-negative entries, got %d", stats.Entries) } if stats.Pending < 0 { t.Errorf("Expected non-negative pending, got %d", stats.Pending) } } // TestServer_GetProxyCacheStats_AllProxiesWithCache 测试所有代理都有缓存的统计。 func TestServer_GetProxyCacheStats_AllProxiesWithCache(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} // 创建多个带缓存的代理 proxies := make([]*proxy.Proxy, 3) for i := range 3 { proxyCfg := &config.ProxyConfig{ Path: fmt.Sprintf("/api%d", i), LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, Cache: config.ProxyCacheConfig{ Enabled: true, MaxAge: 10 * time.Second, }, } p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } proxies[i] = p } s.proxies = proxies // 获取统计 stats := s.getProxyCacheStats() // 应该聚合所有代理的统计 if stats.Entries < 0 { t.Errorf("Expected non-negative entries, got %d", stats.Entries) } if stats.Pending < 0 { t.Errorf("Expected non-negative pending, got %d", stats.Pending) } } // TestServer_GetProxyCacheStats_AllProxiesNoCache 测试所有代理都没有缓存的统计。 func TestServer_GetProxyCacheStats_AllProxiesNoCache(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} // 创建多个不带缓存的代理 proxies := make([]*proxy.Proxy, 3) for i := range 3 { proxyCfg := &config.ProxyConfig{ Path: fmt.Sprintf("/api%d", i), LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, } p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) if err != nil { t.Fatalf("NewProxy() error: %v", err) } proxies[i] = p } s.proxies = proxies // 获取统计 stats := s.getProxyCacheStats() // 所有代理都没有缓存,应该返回 0 if stats.Entries != 0 { t.Errorf("Expected 0 entries, got %d", stats.Entries) } if stats.Pending != 0 { t.Errorf("Expected 0 pending, got %d", stats.Pending) } } // TestServer_GetProxyCacheStats_EmptyProxiesSlice 测试空代理切片的统计。 func TestServer_GetProxyCacheStats_EmptyProxiesSlice(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) s.proxies = []*proxy.Proxy{} // 空切片 // 获取统计 stats := s.getProxyCacheStats() if stats.Entries != 0 { t.Errorf("Expected 0 entries, got %d", stats.Entries) } if stats.Pending != 0 { t.Errorf("Expected 0 pending, got %d", stats.Pending) } } // TestServer_MultipleListeners 测试多个监听器。 func TestServer_MultipleListeners(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) // 创建多个监听器 ln1, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) } ln2, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) _ = ln1.Close() } s.listeners = []net.Listener{ln1, ln2} // 验证可以获取监听器 got := s.GetListeners() if len(got) != 2 { t.Errorf("expected 2 listeners, got %d", len(got)) } // 清理 _ = s.StopWithTimeout(1 * time.Second) } // TestGracefulStop_RunningState 测试 GracefulStop 设置 running 为 false。 func TestGracefulStop_RunningState(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) if !s.running.Load() { t.Fatal("running should be true before GracefulStop") } err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } if s.running.Load() { t.Error("running should be false after GracefulStop") } } // TestGracefulStop_WithPool 测试 GracefulStop 停止 GoroutinePool。 func TestGracefulStop_WithPool(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, Performance: config.PerformanceConfig{ GoroutinePool: config.GoroutinePoolConfig{ Enabled: true, MaxWorkers: 10, MinWorkers: 2, IdleTimeout: 5 * time.Second, }, }, } s := New(cfg) s.running.Store(true) // 初始化并启动 pool s.pool = initGoroutinePool(&cfg.Performance) if s.pool != nil { s.pool.Start() } err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_WithHealthCheckers 测试 GracefulStop 停止健康检查器。 func TestGracefulStop_WithHealthCheckers(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建 mock healthChecker (使用 nil,因为我们只测试循环不会 panic) s.healthCheckers = nil err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_WithAccessLog 测试 GracefulStop 关闭访问日志。 func TestGracefulStop_WithAccessLog(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建 accessLogMiddleware s.accessLogMiddleware = accesslog.New(&config.LoggingConfig{}) err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_WithTLSManager 测试 GracefulStop 关闭 TLS 管理器。 func TestGracefulStop_WithTLSManager(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建临时证书文件 tempDir := t.TempDir() certFile := tempDir + "/cert.pem" keyFile := tempDir + "/key.pem" // 生成自签名证书用于测试 if err := generateTestCert(certFile, keyFile); err != nil { t.Skipf("failed to generate test cert: %v", err) } tlsMgr, err := ssl.NewTLSManager(&config.SSLConfig{ Cert: certFile, Key: keyFile, }) if err != nil { t.Skipf("failed to create TLS manager: %v", err) } s.tlsManager = tlsMgr err = s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_WithLuaEngine 测试 GracefulStop 关闭 Lua 引擎。 func TestGracefulStop_WithLuaEngine(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建 Lua 引擎 luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine err = s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_Timeout 测试 GracefulStop 超时场景。 func TestGracefulStop_Timeout(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建一个真实的 fastServer,但通过模拟长时间关闭来测试超时 s.fastServer = &fasthttp.Server{ Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }, } // 使用非常短的超时 err := s.GracefulStop(1 * time.Nanosecond) // 超时可能返回 context.DeadlineExceeded 或 nil(取决于关闭速度) if err != nil && err != context.DeadlineExceeded { t.Errorf("unexpected error: %v", err) } } // TestGracefulStop_AllComponents 测试 GracefulStop 关闭所有组件。 func TestGracefulStop_AllComponents(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, Performance: config.PerformanceConfig{ GoroutinePool: config.GoroutinePoolConfig{ Enabled: true, MaxWorkers: 10, IdleTimeout: 5 * time.Second, }, }, } s := New(cfg) s.running.Store(true) // 初始化所有组件 s.pool = initGoroutinePool(&cfg.Performance) if s.pool != nil { s.pool.Start() } s.accessLogMiddleware = accesslog.New(&config.LoggingConfig{}) // 创建监听器 ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) } s.listeners = []net.Listener{ln} err = s.GracefulStop(2 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } // 验证 running 状态 if s.running.Load() { t.Error("running should be false after GracefulStop") } } // generateTestCert 生成测试用的自签名证书。 func generateTestCert(certFile, keyFile string) error { // 简化实现:跳过证书生成 return fmt.Errorf("test cert generation not implemented") } // TestGracefulStop_WithAccessControl 测试 GracefulStop 关闭访问控制。 func TestGracefulStop_WithAccessControl(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", Security: config.SecurityConfig{ Access: config.AccessConfig{ Allow: []string{"127.0.0.1"}, }, }, }}, } s := New(cfg) s.running.Store(true) // 创建 AccessControl ac, err := security.NewAccessControl(&cfg.Servers[0].Security.Access) if err != nil { t.Skipf("failed to create AccessControl: %v", err) } s.accessControl = ac err = s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_ContextCancelled 测试 GracefulStop 上下文取消场景。 func TestGracefulStop_ContextCancelled(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建一个监听中的服务器 ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) } s.listeners = []net.Listener{ln} // 创建 fastServer 并开始服务 s.fastServer = &fasthttp.Server{ Handler: func(ctx *fasthttp.RequestCtx) { time.Sleep(100 * time.Millisecond) // 模拟慢请求 ctx.SetBodyString("ok") }, } // 启动服务器 go func() { _ = s.fastServer.Serve(ln) }() // 等待服务器启动 time.Sleep(10 * time.Millisecond) // 使用非常短的超时测试超时场景 err = s.GracefulStop(1 * time.Nanosecond) // 超时可能返回 context.DeadlineExceeded 或 nil if err != nil && err != context.DeadlineExceeded { t.Errorf("unexpected error: %v", err) } } // TestGracefulStop_MultipleHealthCheckers 测试 GracefulStop 停止多个健康检查器。 func TestGracefulStop_MultipleHealthCheckers(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建多个 mock healthChecker // 注意:这里使用 nil slice 测试空循环不会 panic s.healthCheckers = make([]*proxy.HealthChecker, 0) err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_NilComponents 测试 GracefulStop 所有组件为 nil。 func TestGracefulStop_NilComponents(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 确保所有组件为 nil s.pool = nil s.healthCheckers = nil s.accessLogMiddleware = nil s.tlsManager = nil s.accessControl = nil s.luaEngine = nil s.fastServer = nil s.fastServers = nil err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } if s.running.Load() { t.Error("running should be false after GracefulStop") } } // TestGracefulStop_FastServersWithNil 测试 GracefulStop 处理 fastServers 中的 nil。 func TestGracefulStop_FastServersWithNil(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) // 创建包含 nil 的 fastServers s.fastServers = []*fasthttp.Server{nil, {}, nil} err := s.GracefulStop(1 * time.Second) if err != nil { t.Errorf("GracefulStop failed: %v", err) } } // TestGracefulStop_ZeroTimeout 测试 GracefulStop 零超时。 func TestGracefulStop_ZeroTimeout(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) s.fastServer = &fasthttp.Server{} err := s.GracefulStop(0) // 零超时应该立即返回(可能导致超时错误或成功关闭) if err != nil && err != context.DeadlineExceeded { t.Errorf("unexpected error: %v", err) } } // TestGracefulStop_NegativeTimeout 测试 GracefulStop 负超时。 func TestGracefulStop_NegativeTimeout(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":0", }}, } s := New(cfg) s.running.Store(true) s.fastServer = &fasthttp.Server{} err := s.GracefulStop(-1 * time.Second) // 负超时应该立即返回 if err != nil && err != context.DeadlineExceeded { t.Errorf("unexpected error: %v", err) } } // TestStartSingleMode_StaticFiles 测试 startSingleMode 静态文件配置。 func TestStartSingleMode_StaticFiles(t *testing.T) { tempDir := t.TempDir() cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Static: []config.StaticConfig{ { Path: "/static", Root: tempDir, Index: []string{"index.html"}, }, { Path: "/assets", Root: tempDir, LocationType: "exact", SymlinkCheck: true, Internal: true, TryFiles: []string{"$uri", "/fallback.html"}, TryFilesPass: true, }, }, }}, } s := New(cfg) // 验证静态文件配置已正确设置 if len(s.config.Servers[0].Static) != 2 { t.Errorf("expected 2 static configs, got %d", len(s.config.Servers[0].Static)) } // 验证第一个静态配置 static1 := s.config.Servers[0].Static[0] if static1.Path != "/static" { t.Errorf("expected path /static, got %s", static1.Path) } if static1.Root != tempDir { t.Errorf("expected root %s, got %s", tempDir, static1.Root) } } // TestStartSingleMode_StaticFilesWithGzipStatic 测试静态文件 gzip 预压缩配置。 func TestStartSingleMode_StaticFilesWithGzipStatic(t *testing.T) { tempDir := t.TempDir() cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Static: []config.StaticConfig{ { Path: "/", Root: tempDir, Index: []string{"index.html"}, }, }, Compression: config.CompressionConfig{ Type: "gzip", Level: 6, GzipStatic: true, GzipStaticExtensions: []string{".html", ".css", ".js"}, }, }}, } s := New(cfg) // 验证 gzip 静态配置 if !s.config.Servers[0].Compression.GzipStatic { t.Error("expected GzipStatic to be true") } if len(s.config.Servers[0].Compression.GzipStaticExtensions) != 3 { t.Errorf("expected 3 extensions, got %d", len(s.config.Servers[0].Compression.GzipStaticExtensions)) } } // TestStartSingleMode_ProxyWithLocationTypes 测试代理配置的不同位置类型。 func TestStartSingleMode_ProxyWithLocationTypes(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Proxy: []config.ProxyConfig{ { Path: "/api/exact", LocationType: "exact", Targets: []config.ProxyTarget{ {URL: "http://127.0.0.1:8081", Weight: 1}, }, }, { Path: "/api/priority", LocationType: "prefix_priority", Targets: []config.ProxyTarget{ {URL: "http://127.0.0.1:8082", Weight: 1}, }, }, { Path: "^/api/regex/(.*)$", LocationType: "regex", Targets: []config.ProxyTarget{ {URL: "http://127.0.0.1:8083", Weight: 1}, }, }, { Path: "^/api/caseless/(.*)$", LocationType: "regex_caseless", Targets: []config.ProxyTarget{ {URL: "http://127.0.0.1:8084", Weight: 1}, }, }, { Path: "/api/named", LocationType: "named", LocationName: "@api_named", Targets: []config.ProxyTarget{ {URL: "http://127.0.0.1:8085", Weight: 1}, }, }, { Path: "/api/default", // 默认 prefix 类型 Targets: []config.ProxyTarget{ {URL: "http://127.0.0.1:8086", Weight: 1}, }, Internal: true, }, }, }}, } s := New(cfg) // 验证代理配置数量 if len(s.config.Servers[0].Proxy) != 6 { t.Errorf("expected 6 proxy configs, got %d", len(s.config.Servers[0].Proxy)) } // 验证不同位置类型 proxyTypes := []string{"exact", "prefix_priority", "regex", "regex_caseless", "named", ""} for i, pt := range proxyTypes { if s.config.Servers[0].Proxy[i].LocationType != pt { t.Errorf("proxy[%d]: expected location type %s, got %s", i, pt, s.config.Servers[0].Proxy[i].LocationType) } } } // TestStartSingleMode_ProxyWithHealthCheck 测试代理健康检查配置。 func TestStartSingleMode_ProxyWithHealthCheck(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Proxy: []config.ProxyConfig{ { Path: "/api", Targets: []config.ProxyTarget{ { URL: "http://127.0.0.1:8081", Weight: 3, MaxFails: 3, FailTimeout: 10 * time.Second, MaxConns: 100, Backup: false, Down: false, }, { URL: "http://127.0.0.1:8082", Weight: 1, Backup: true, }, }, LoadBalance: "weighted_round_robin", HealthCheck: config.HealthCheckConfig{ Interval: 10 * time.Second, Timeout: 5 * time.Second, Path: "/health", }, }, }, }}, } s := New(cfg) // 验证健康检查配置 hc := s.config.Servers[0].Proxy[0].HealthCheck if hc.Interval != 10*time.Second { t.Errorf("expected interval 10s, got %v", hc.Interval) } if hc.Path != "/health" { t.Errorf("expected path /health, got %s", hc.Path) } } // TestStartSingleMode_MonitoringEndpoints 测试监控端点配置。 func TestStartSingleMode_MonitoringEndpoints(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", }}, Monitoring: config.MonitoringConfig{ Status: config.StatusConfig{ Enabled: true, Path: "/_status", Format: "json", Allow: []string{"127.0.0.1", "192.168.0.0/16"}, }, Pprof: config.PprofConfig{ Enabled: true, Path: "/debug/pprof", Allow: []string{"127.0.0.1"}, }, }, } s := New(cfg) // 验证状态端点配置 if !s.config.Monitoring.Status.Enabled { t.Error("expected status enabled") } if s.config.Monitoring.Status.Path != "/_status" { t.Errorf("expected status path /_status, got %s", s.config.Monitoring.Status.Path) } if len(s.config.Monitoring.Status.Allow) != 2 { t.Errorf("expected 2 allowed IPs, got %d", len(s.config.Monitoring.Status.Allow)) } // 验证 pprof 配置 if !s.config.Monitoring.Pprof.Enabled { t.Error("expected pprof enabled") } } // TestStartSingleMode_CacheAPI 测试缓存 API 配置。 func TestStartSingleMode_CacheAPI(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", CacheAPI: &config.CacheAPIConfig{ Enabled: true, Path: "/_cache/purge", Allow: []string{"127.0.0.1"}, Auth: config.CacheAPIAuthConfig{Type: "token", Token: "secret-token"}, }, }}, } s := New(cfg) // 验证缓存 API 配置 if s.config.Servers[0].CacheAPI == nil || !s.config.Servers[0].CacheAPI.Enabled { t.Error("expected cache API enabled") } if s.config.Servers[0].CacheAPI.Path != "/_cache/purge" { t.Errorf("expected path /_cache/purge, got %s", s.config.Servers[0].CacheAPI.Path) } } // TestStartSingleMode_TLSConfig 测试 TLS 配置。 func TestStartSingleMode_TLSConfig(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", SSL: config.SSLConfig{ Cert: "/path/to/cert.pem", Key: "/path/to/key.pem", Protocols: []string{"TLSv1.2", "TLSv1.3"}, Ciphers: []string{"TLS_AES_128_GCM_SHA256"}, HSTS: config.HSTSConfig{ MaxAge: 31536000, IncludeSubDomains: true, Preload: true, }, }, }}, } s := New(cfg) // 验证 SSL 配置 if s.config.Servers[0].SSL.Cert != "/path/to/cert.pem" { t.Errorf("expected cert path, got %s", s.config.Servers[0].SSL.Cert) } if s.config.Servers[0].SSL.HSTS.MaxAge != 31536000 { t.Errorf("expected HSTS MaxAge 31536000, got %d", s.config.Servers[0].SSL.HSTS.MaxAge) } } // TestStartSingleMode_MIMETypes 测试 MIME 类型配置。 func TestStartSingleMode_MIMETypes(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Types: config.TypesConfig{ Map: map[string]string{ ".wasm": "application/wasm", ".custom": "application/x-custom", }, DefaultType: "application/octet-stream", }, }}, } s := New(cfg) // 验证 MIME 类型配置 if len(s.config.Servers[0].Types.Map) != 2 { t.Errorf("expected 2 MIME types, got %d", len(s.config.Servers[0].Types.Map)) } if s.config.Servers[0].Types.DefaultType != "application/octet-stream" { t.Errorf("expected default type, got %s", s.config.Servers[0].Types.DefaultType) } } // TestStartSingleMode_ServerOptions 测试服务器选项配置。 func TestStartSingleMode_ServerOptions(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 60 * time.Second, MaxConnsPerIP: 100, MaxRequestsPerConn: 1000, Concurrency: 256 * 1024, ReadBufferSize: 16 * 1024, WriteBufferSize: 16 * 1024, ReduceMemoryUsage: true, ServerTokens: false, }}, } s := New(cfg) // 验证服务器选项 sc := s.config.Servers[0] if sc.ReadTimeout != 30*time.Second { t.Errorf("expected ReadTimeout 30s, got %v", sc.ReadTimeout) } if sc.MaxConnsPerIP != 100 { t.Errorf("expected MaxConnsPerIP 100, got %d", sc.MaxConnsPerIP) } if !sc.ReduceMemoryUsage { t.Error("expected ReduceMemoryUsage true") } if sc.ServerTokens { t.Error("expected ServerTokens false") } } // TestStartSingleMode_WithMiddlewareChain 测试中间件链配置。 func TestStartSingleMode_WithMiddlewareChain(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Security: config.SecurityConfig{ Access: config.AccessConfig{ Allow: []string{"127.0.0.1"}, Deny: []string{"10.0.0.0/8"}, }, RateLimit: config.RateLimitConfig{ RequestRate: 100, Burst: 200, Key: "remote_addr", }, Auth: config.AuthConfig{ Users: []config.User{ {Name: "admin", Password: "secret"}, }, }, Headers: config.SecurityHeaders{ XFrameOptions: "DENY", XContentTypeOptions: "nosniff", ContentSecurityPolicy: "default-src 'self'", ReferrerPolicy: "strict-origin-when-cross-origin", }, }, Compression: config.CompressionConfig{ Type: "gzip", Level: 6, }, Rewrite: []config.RewriteRule{ {Pattern: "^/old/(.*)$", Replacement: "/new/$1"}, }, }}, } s := New(cfg) // 验证中间件配置 security := s.config.Servers[0].Security if len(security.Access.Allow) != 1 { t.Errorf("expected 1 allow rule, got %d", len(security.Access.Allow)) } if security.RateLimit.RequestRate != 100 { t.Errorf("expected request rate 100, got %d", security.RateLimit.RequestRate) } if len(security.Auth.Users) != 1 { t.Errorf("expected 1 auth user, got %d", len(security.Auth.Users)) } } // TestStartSingleMode_PerformanceConfig 测试性能配置。 func TestStartSingleMode_PerformanceConfig(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", }}, Performance: config.PerformanceConfig{ GoroutinePool: config.GoroutinePoolConfig{ Enabled: true, MaxWorkers: 100, MinWorkers: 10, IdleTimeout: 30 * time.Second, }, FileCache: config.FileCacheConfig{ MaxEntries: 10000, MaxSize: 100 * 1024 * 1024, }, }, } s := New(cfg) // 验证性能配置 if !s.config.Performance.GoroutinePool.Enabled { t.Error("expected goroutine pool enabled") } if s.config.Performance.FileCache.MaxEntries != 10000 { t.Errorf("expected 10000 max entries, got %d", s.config.Performance.FileCache.MaxEntries) } } // TestStartSingleMode_WithLuaMiddleware 测试 Lua 中间件配置。 func TestStartSingleMode_WithLuaMiddleware(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Lua: &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ { Path: "/scripts/access.lua", Phase: "access", Timeout: 30 * time.Second, }, { Path: "/scripts/header.lua", Phase: "header_filter", Timeout: 10 * time.Second, }, }, }, }}, } s := New(cfg) // 验证 Lua 配置 if s.config.Servers[0].Lua == nil || !s.config.Servers[0].Lua.Enabled { t.Error("expected Lua enabled") } if len(s.config.Servers[0].Lua.Scripts) != 2 { t.Errorf("expected 2 scripts, got %d", len(s.config.Servers[0].Lua.Scripts)) } } // TestStartSingleMode_WithErrorPage 测试错误页面配置。 func TestStartSingleMode_WithErrorPage(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Security: config.SecurityConfig{ ErrorPage: config.ErrorPageConfig{ Pages: map[int]string{ 404: "/errors/404.html", 500: "/errors/500.html", 502: "/errors/502.html", }, Default: "/errors/default.html", }, }, }}, } s := New(cfg) // 验证错误页面配置 ep := s.config.Servers[0].Security.ErrorPage if len(ep.Pages) != 3 { t.Errorf("expected 3 error pages, got %d", len(ep.Pages)) } if ep.Default != "/errors/default.html" { t.Errorf("expected default error page, got %s", ep.Default) } } // TestStartSingleMode_WithConnLimiter 测试连接限制配置。 func TestStartSingleMode_WithConnLimiter(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Security: config.SecurityConfig{ RateLimit: config.RateLimitConfig{ ConnLimit: 100, Key: "remote_addr", }, }, }}, } s := New(cfg) // 验证连接限制配置 if s.config.Servers[0].Security.RateLimit.ConnLimit != 100 { t.Errorf("expected ConnLimit 100, got %d", s.config.Servers[0].Security.RateLimit.ConnLimit) } } // TestStartSingleMode_WithAuthRequest 测试外部认证配置。 func TestStartSingleMode_WithAuthRequest(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: "127.0.0.1:0", Security: config.SecurityConfig{ AuthRequest: config.AuthRequestConfig{ Enabled: true, URI: "/auth/validate", Timeout: 5 * time.Second, }, }, }}, } s := New(cfg) // 验证外部认证配置 ar := s.config.Servers[0].Security.AuthRequest if !ar.Enabled { t.Error("expected AuthRequest enabled") } if ar.URI != "/auth/validate" { t.Errorf("expected URI /auth/validate, got %s", ar.URI) } } // TestShutdownServers_EmptySlice 测试空服务器列表。 func TestShutdownServers_EmptySlice(t *testing.T) { ctx := context.Background() err := shutdownServers(ctx, []*fasthttp.Server{}) if err != nil { t.Errorf("shutdownServers with empty slice should return nil, got: %v", err) } } // TestShutdownServers_NilSlice 测试 nil 服务器列表。 func TestShutdownServers_NilSlice(t *testing.T) { ctx := context.Background() err := shutdownServers(ctx, nil) if err != nil { t.Errorf("shutdownServers with nil slice should return nil, got: %v", err) } } // TestShutdownServers_NilContext 测试 nil 上下文。 func TestShutdownServers_NilContext(t *testing.T) { // nil ctx 应该使用 context.Background() err := shutdownServers(nil, []*fasthttp.Server{}) if err != nil { t.Errorf("shutdownServers with nil ctx should return nil, got: %v", err) } } // TestShutdownServers_SingleServer 测试单个服务器关闭。 func TestShutdownServers_SingleServer(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() servers := []*fasthttp.Server{ {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, } err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers failed: %v", err) } } // TestShutdownServers_MultipleServers 测试多个服务器关闭。 func TestShutdownServers_MultipleServers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() servers := []*fasthttp.Server{ {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test1") }}, {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test2") }}, {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test3") }}, } err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers failed: %v", err) } } // TestShutdownServers_WithNilServers 测试服务器列表中包含 nil。 func TestShutdownServers_WithNilServers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() servers := []*fasthttp.Server{ nil, {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, nil, } err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers failed: %v", err) } } // TestShutdownServers_AllNilServers 测试所有服务器都是 nil。 func TestShutdownServers_AllNilServers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() servers := []*fasthttp.Server{nil, nil, nil} err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers with all nil servers should return nil, got: %v", err) } } // TestShutdownServers_ContextCancelled 测试上下文取消。 func TestShutdownServers_ContextCancelled(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } // 创建一个已取消的上下文 ctx, cancel := context.WithCancel(context.Background()) cancel() servers := []*fasthttp.Server{ {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, } err := shutdownServers(ctx, servers) // 已取消的上下文可能返回 context.Canceled 或 nil(取决于服务器关闭速度) if err != nil && err != context.Canceled { t.Errorf("unexpected error: %v", err) } } // TestShutdownServers_ContextTimeout 测试上下文超时。 func TestShutdownServers_ContextTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } // 创建一个极短超时的上下文 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() // 等待超时 time.Sleep(1 * time.Millisecond) servers := []*fasthttp.Server{ {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, } err := shutdownServers(ctx, servers) // 超时的上下文可能返回 context.DeadlineExceeded 或 nil if err != nil && err != context.DeadlineExceeded { t.Errorf("unexpected error: %v", err) } } // TestShutdownServers_RunningServers 测试关闭运行中的服务器。 func TestShutdownServers_RunningServers(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // 创建服务器并启动 servers := make([]*fasthttp.Server, 2) listeners := make([]net.Listener, 2) for i := range 2 { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to create listener: %v", err) } listeners[i] = ln srv := &fasthttp.Server{ Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }, } servers[i] = srv go func(s *fasthttp.Server, l net.Listener) { _ = s.Serve(l) }(srv, ln) } // 等待服务器启动 time.Sleep(10 * time.Millisecond) // 关闭服务器 err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers failed: %v", err) } // 关闭监听器(如果服务器没有关闭它们) for _, ln := range listeners { _ = ln.Close() } } // TestShutdownServers_ManyServers 测试关闭大量服务器。 func TestShutdownServers_ManyServers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // 创建大量服务器 count := 50 servers := make([]*fasthttp.Server, count) for i := range count { servers[i] = &fasthttp.Server{ Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }, } } err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers with many servers failed: %v", err) } } // TestShutdownServers_MixedNilAndRealServers 测试混合 nil 和真实服务器。 func TestShutdownServers_MixedNilAndRealServers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() count := 20 servers := make([]*fasthttp.Server, count) for i := range count { if i%2 == 0 { servers[i] = nil } else { servers[i] = &fasthttp.Server{ Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }, } } } err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers failed: %v", err) } } // TestShutdownServers_ConcurrentSafety 测试并发安全性。 func TestShutdownServers_ConcurrentSafety(t *testing.T) { ctx := context.Background() // 并发调用 shutdownServers var wg sync.WaitGroup for range 10 { wg.Go(func() { servers := []*fasthttp.Server{ {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, } _ = shutdownServers(ctx, servers) }) } wg.Wait() } // TestShutdownServers_WithDeadline 测试带截止时间的上下文。 func TestShutdownServers_WithDeadline(t *testing.T) { deadline := time.Now().Add(5 * time.Second) ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() servers := []*fasthttp.Server{ {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, } err := shutdownServers(ctx, servers) if err != nil { t.Errorf("shutdownServers failed: %v", err) } } // TestBuildLuaMiddlewares_SingleScript 测试单个脚本配置。 func TestBuildLuaMiddlewares_SingleScript(t *testing.T) { // 创建临时 Lua 脚本 tempDir := t.TempDir() scriptPath := tempDir + "/test.lua" if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: scriptPath, Phase: "access", Timeout: 10 * time.Second, Enabled: true}, }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } if len(middlewares) != 1 { t.Errorf("expected 1 middleware, got: %d", len(middlewares)) } } // TestBuildLuaMiddlewares_SingleScriptDefaultTimeout 测试单脚本默认超时。 func TestBuildLuaMiddlewares_SingleScriptDefaultTimeout(t *testing.T) { tempDir := t.TempDir() scriptPath := tempDir + "/test.lua" if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: scriptPath, Phase: "content", Timeout: 0}, // 使用默认超时 }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } if len(middlewares) != 1 { t.Errorf("expected 1 middleware, got: %d", len(middlewares)) } } // TestBuildLuaMiddlewares_MultipleScriptsSamePhase 测试多脚本同阶段。 func TestBuildLuaMiddlewares_MultipleScriptsSamePhase(t *testing.T) { tempDir := t.TempDir() script1 := tempDir + "/test1.lua" script2 := tempDir + "/test2.lua" if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: script1, Phase: "access", Timeout: 10 * time.Second, Enabled: true}, {Path: script2, Phase: "access", Timeout: 20 * time.Second, Enabled: true}, }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } if len(middlewares) != 1 { t.Errorf("expected 1 middleware (multi-phase), got: %d", len(middlewares)) } } // TestBuildLuaMiddlewares_MultipleScriptsDifferentPhases 测试多脚本不同阶段。 func TestBuildLuaMiddlewares_MultipleScriptsDifferentPhases(t *testing.T) { tempDir := t.TempDir() script1 := tempDir + "/rewrite.lua" script2 := tempDir + "/access.lua" script3 := tempDir + "/log.lua" for _, p := range []string{script1, script2, script3} { if err := os.WriteFile(p, []byte("ngx.say('hello')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: script1, Phase: "rewrite", Timeout: 10 * time.Second, Enabled: true}, {Path: script2, Phase: "access", Timeout: 15 * time.Second, Enabled: true}, {Path: script3, Phase: "log", Timeout: 20 * time.Second, Enabled: true}, }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } if len(middlewares) != 3 { t.Errorf("expected 3 middlewares, got: %d", len(middlewares)) } } // TestBuildLuaMiddlewares_DefaultEnabled 测试默认启用逻辑。 func TestBuildLuaMiddlewares_DefaultEnabled(t *testing.T) { tempDir := t.TempDir() scriptPath := tempDir + "/test.lua" if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine // Enabled 为 false,但 Timeout=0 且 Path 不为空,应该默认启用 luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: scriptPath, Phase: "access", Timeout: 0, Enabled: false}, }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } // 默认启用逻辑:Enabled=false && Timeout=0 && Path!="" -> enabled=true if len(middlewares) != 1 { t.Errorf("expected 1 middleware (default enabled), got: %d", len(middlewares)) } } // TestBuildLuaMiddlewares_InvalidPhaseInMultiScript 测试多脚本中的无效阶段。 func TestBuildLuaMiddlewares_InvalidPhaseInMultiScript(t *testing.T) { tempDir := t.TempDir() script1 := tempDir + "/test1.lua" script2 := tempDir + "/test2.lua" if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: script1, Phase: "access", Timeout: 10 * time.Second, Enabled: true}, {Path: script2, Phase: "invalid_phase", Timeout: 10 * time.Second, Enabled: true}, }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err == nil { t.Error("expected error for invalid phase in multi-script") } if middlewares != nil { t.Errorf("expected nil middlewares on error, got: %v", middlewares) } } // TestBuildLuaMiddlewares_AllPhases 测试所有阶段。 func TestBuildLuaMiddlewares_AllPhases(t *testing.T) { tempDir := t.TempDir() phases := []string{"rewrite", "access", "content", "log", "header_filter", "body_filter"} cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine scripts := make([]config.LuaScriptConfig, len(phases)) for i, phase := range phases { scriptPath := tempDir + "/" + phase + ".lua" if err := os.WriteFile(scriptPath, []byte("ngx.say('"+phase+"')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } scripts[i] = config.LuaScriptConfig{Path: scriptPath, Phase: phase, Timeout: 10 * time.Second, Enabled: true} } luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: scripts, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } if len(middlewares) != len(phases) { t.Errorf("expected %d middlewares, got: %d", len(phases), len(middlewares)) } } // TestBuildLuaMiddlewares_NonExistentScript 测试不存在的脚本文件。 func TestBuildLuaMiddlewares_NonExistentScript(t *testing.T) { cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: "/non/existent/script.lua", Phase: "access", Timeout: 10 * time.Second}, }, } // NewLuaMiddleware 会在创建时验证脚本文件 middlewares, err := s.buildLuaMiddlewares(luaCfg) // 由于脚本不存在,可能会返回错误或创建失败 // 这取决于 lua.NewLuaMiddleware 的实现 _ = middlewares _ = err } // TestBuildLuaMiddlewares_MixedEnabledDisabled 测试混合启用禁用脚本。 func TestBuildLuaMiddlewares_MixedEnabledDisabled(t *testing.T) { tempDir := t.TempDir() for _, name := range []string{"enabled1", "enabled2", "disabled1", "disabled2"} { scriptPath := tempDir + "/" + name + ".lua" if err := os.WriteFile(scriptPath, []byte("ngx.say('"+name+"')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: tempDir + "/enabled1.lua", Phase: "rewrite", Timeout: 10 * time.Second, Enabled: true}, {Path: tempDir + "/disabled1.lua", Phase: "rewrite", Timeout: 10 * time.Second, Enabled: false}, {Path: tempDir + "/enabled2.lua", Phase: "access", Timeout: 10 * time.Second, Enabled: true}, {Path: tempDir + "/disabled2.lua", Phase: "access", Timeout: 10 * time.Second, Enabled: false}, }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } // 只有启用的脚本应该被处理:rewrite(1) + access(1) = 2 if len(middlewares) != 2 { t.Errorf("expected 2 middlewares, got: %d", len(middlewares)) } } // TestBuildLuaMiddlewares_MultiPhaseDefaultTimeout 测试多脚本阶段默认超时。 func TestBuildLuaMiddlewares_MultiPhaseDefaultTimeout(t *testing.T) { tempDir := t.TempDir() script1 := tempDir + "/test1.lua" script2 := tempDir + "/test2.lua" if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil { t.Fatalf("failed to create script: %v", err) } cfg := &config.Config{ Servers: []config.ServerConfig{{ Listen: ":8080", }}, } s := New(cfg) luaEngine, err := lua.NewEngine(lua.DefaultConfig()) if err != nil { t.Skipf("failed to create Lua engine: %v", err) } s.luaEngine = luaEngine luaCfg := &config.LuaMiddlewareConfig{ Enabled: true, Scripts: []config.LuaScriptConfig{ {Path: script1, Phase: "access", Timeout: 0}, // 默认超时 {Path: script2, Phase: "access", Timeout: 0}, // 默认超时 }, } middlewares, err := s.buildLuaMiddlewares(luaCfg) if err != nil { t.Errorf("expected nil error, got: %v", err) } if len(middlewares) != 1 { t.Errorf("expected 1 middleware (multi-phase), got: %d", len(middlewares)) } }