diff --git a/internal/lua/api_balancer_test.go b/internal/lua/api_balancer_test.go new file mode 100644 index 0000000..52d1826 --- /dev/null +++ b/internal/lua/api_balancer_test.go @@ -0,0 +1,161 @@ +// Package lua 提供 ngx.balancer API 的测试 +// +// 该文件测试负载均衡相关的 Lua API,包括: +// - BalancerContext 创建和管理 +// - set_current_peer API +// - set_more_tries API +// - get_last_failure API +// - get_targets API +// - get_client_ip API +// - IsSelected 边界测试 +// +// 作者:xfy +package lua + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "rua.plus/lolly/internal/loadbalance" +) + +// TestBalancerContext_IsSelected 测试 IsSelected 方法 +func TestBalancerContext_IsSelected(t *testing.T) { + tests := []struct { + name string + selected bool + wantResult bool + }{ + { + name: "已选择目标", + selected: true, + wantResult: true, + }, + { + name: "未选择目标", + selected: false, + wantResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bctx := &BalancerContext{ + Targets: []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + }, + ClientIP: "127.0.0.1", + Retries: 3, + selected: tt.selected, + } + + if tt.selected { + bctx.Selected = bctx.Targets[0] + } + + result := bctx.IsSelected() + assert.Equal(t, tt.wantResult, result, "IsSelected() should return expected value") + }) + } +} + +// TestBalancerContext_IsSelected_ZeroValue 测试零值情况 +func TestBalancerContext_IsSelected_ZeroValue(t *testing.T) { + // 零值的 BalancerContext + bctx := &BalancerContext{} + + // 默认应该返回 false + result := bctx.IsSelected() + assert.False(t, result, "Zero value BalancerContext should return false for IsSelected()") +} + +// TestBalancerContext_IsSelected_AfterSelection 测试选择后的状态 +func TestBalancerContext_IsSelected_AfterSelection(t *testing.T) { + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + + bctx := &BalancerContext{ + Targets: targets, + ClientIP: "192.168.1.1", + Retries: 3, + } + + // 初始状态 + assert.False(t, bctx.IsSelected(), "Should not be selected initially") + + // 模拟选择目标 + bctx.Selected = targets[0] + bctx.selected = true + + // 选择后状态 + assert.True(t, bctx.IsSelected(), "Should be selected after setting") + + // 清除选择 + bctx.selected = false + assert.False(t, bctx.IsSelected(), "Should not be selected after clearing flag") +} + +// TestClassifyError 测试错误分类函数 +func TestClassifyError(t *testing.T) { + tests := []struct { + name string + err error + expected string + }{ + { + name: "nil error", + err: nil, + expected: "", + }, + { + name: "timeout error", + err: errors.New("connection timeout"), + expected: "timeout", + }, + { + name: "connection error", + err: errors.New("connection refused"), + expected: "failed", + }, + { + name: "generic error", + err: errors.New("some error"), + expected: "failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := classifyError(tt.err) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestBalancerContext_Structure 测试 BalancerContext 结构体字段 +func TestBalancerContext_Structure(t *testing.T) { + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080", Weight: 1}, + {URL: "http://backend2:8080", Weight: 2}, + } + + bctx := &BalancerContext{ + Targets: targets, + ClientIP: "10.0.0.1", + Retries: 5, + Selected: nil, + LastError: nil, + selected: false, + } + + assert.Equal(t, 2, len(bctx.Targets)) + assert.Equal(t, "10.0.0.1", bctx.ClientIP) + assert.Equal(t, 5, bctx.Retries) + assert.Nil(t, bctx.Selected) + assert.Nil(t, bctx.LastError) + assert.False(t, bctx.selected) +} diff --git a/internal/lua/mock_engine.go b/internal/lua/mock_engine.go new file mode 100644 index 0000000..0050ad7 --- /dev/null +++ b/internal/lua/mock_engine.go @@ -0,0 +1,196 @@ +// Package lua 提供 Lua 引擎的 Mock 实现,用于测试 +package lua + +import ( + "context" + "time" + + "github.com/valyala/fasthttp" + glua "github.com/yuin/gopher-lua" +) + +// MockLuaEngine 是 LuaEngine 的 Mock 实现 +type MockLuaEngine struct { + ExecuteFunc func(script string) error + ExecuteFileFunc func(path string) error + NewCoroutineFunc func(ctx *fasthttp.RequestCtx) (*MockCoroutine, error) + CloseFunc func() + StatsFunc func() EngineStats + ActiveCoroutinesFunc func() int32 + CodeCacheFunc func() *CodeCache + SharedDictManagerFunc func() *SharedDictManager + TimerManagerFunc func() *TimerManager + LocationManagerFunc func() *LocationManager + CreateSharedDictFunc func(name string, maxItems int) *SharedDict + InitSchedulerLStateFunc func() error + SchedulerLoopFunc func() + EnqueueCallbackFunc func(entry *CallbackEntry) bool + CloseSchedulerFunc func() +} + +// Execute 执行脚本 +func (m *MockLuaEngine) Execute(script string) error { + if m.ExecuteFunc != nil { + return m.ExecuteFunc(script) + } + return nil // stub +} + +// ExecuteFile 执行文件 +func (m *MockLuaEngine) ExecuteFile(path string) error { + if m.ExecuteFileFunc != nil { + return m.ExecuteFileFunc(path) + } + return nil // stub +} + +// NewCoroutine 创建协程 +func (m *MockLuaEngine) NewCoroutine(req *fasthttp.RequestCtx) (*MockCoroutine, error) { + if m.NewCoroutineFunc != nil { + return m.NewCoroutineFunc(req) + } + return &MockCoroutine{}, nil +} + +// Close 关闭引擎 +func (m *MockLuaEngine) Close() { + if m.CloseFunc != nil { + m.CloseFunc() + } +} + +// Stats 返回统计 +func (m *MockLuaEngine) Stats() EngineStats { + if m.StatsFunc != nil { + return m.StatsFunc() + } + return EngineStats{} +} + +// ActiveCoroutines 返回活跃协程数 +func (m *MockLuaEngine) ActiveCoroutines() int32 { + if m.ActiveCoroutinesFunc != nil { + return m.ActiveCoroutinesFunc() + } + return 0 +} + +// CodeCache 返回字节码缓存 +func (m *MockLuaEngine) CodeCache() *CodeCache { + if m.CodeCacheFunc != nil { + return m.CodeCacheFunc() + } + return nil +} + +// SharedDictManager 返回共享字典管理器 +func (m *MockLuaEngine) SharedDictManager() *SharedDictManager { + if m.SharedDictManagerFunc != nil { + return m.SharedDictManagerFunc() + } + return nil +} + +// TimerManager 返回定时器管理器 +func (m *MockLuaEngine) TimerManager() *TimerManager { + if m.TimerManagerFunc != nil { + return m.TimerManagerFunc() + } + return nil +} + +// LocationManager 返回 location 管理器 +func (m *MockLuaEngine) LocationManager() *LocationManager { + if m.LocationManagerFunc != nil { + return m.LocationManagerFunc() + } + return nil +} + +// CreateSharedDict 创建共享字典 +func (m *MockLuaEngine) CreateSharedDict(name string, maxItems int) *SharedDict { + if m.CreateSharedDictFunc != nil { + return m.CreateSharedDictFunc(name, maxItems) + } + return nil +} + +// InitSchedulerLState 初始化调度器 LState +func (m *MockLuaEngine) InitSchedulerLState() error { + if m.InitSchedulerLStateFunc != nil { + return m.InitSchedulerLStateFunc() + } + return nil +} + +// SchedulerLoop 调度器循环 +func (m *MockLuaEngine) SchedulerLoop() { + if m.SchedulerLoopFunc != nil { + m.SchedulerLoopFunc() + } +} + +// EnqueueCallback 将回调加入调度队列 +func (m *MockLuaEngine) EnqueueCallback(entry *CallbackEntry) bool { + if m.EnqueueCallbackFunc != nil { + return m.EnqueueCallbackFunc(entry) + } + return false +} + +// CloseScheduler 关闭调度器 +func (m *MockLuaEngine) CloseScheduler() { + if m.CloseSchedulerFunc != nil { + m.CloseSchedulerFunc() + } +} + +// MockCoroutine 是 LuaCoroutine 的 Mock 实现 +type MockCoroutine struct { + ExecuteFunc func(script string) error + ExecuteFileFunc func(path string) error + SetupSandboxFunc func() error + CloseFunc func() + HandleYieldFunc func(values []glua.LValue) ([]glua.LValue, error) + + // 模拟字段 + CreatedAt time.Time + ExecutionContext context.Context + Engine *MockLuaEngine + Co *glua.LState + Cancel context.CancelFunc + RequestCtx *fasthttp.RequestCtx + OutputBuffer []byte + Exited bool +} + +// Execute 执行脚本 +func (c *MockCoroutine) Execute(script string) error { + if c.ExecuteFunc != nil { + return c.ExecuteFunc(script) + } + return nil +} + +// ExecuteFile 执行文件 +func (c *MockCoroutine) ExecuteFile(path string) error { + if c.ExecuteFileFunc != nil { + return c.ExecuteFileFunc(path) + } + return nil +} + +// SetupSandbox 设置沙箱 +func (c *MockCoroutine) SetupSandbox() error { + if c.SetupSandboxFunc != nil { + return c.SetupSandboxFunc() + } + return nil +} + +// Close 关闭协程 +func (c *MockCoroutine) Close() { + if c.CloseFunc != nil { + c.CloseFunc() + } +} diff --git a/internal/server/init.go b/internal/server/init.go new file mode 100644 index 0000000..d2e5a3f --- /dev/null +++ b/internal/server/init.go @@ -0,0 +1,169 @@ +// Package server 提供服务器初始化函数。 +// +// 该文件包含 Start() 方法中分离出的可测试初始化函数: +// - Goroutine 池初始化 +// - 文件缓存初始化 +// - Lua 引擎初始化 +// - 错误页面管理器初始化 +// +// 主要用途: +// +// 将 Start() 方法中的初始化逻辑分离,便于单元测试 +// +// 注意事项: +// - 这些函数仅在 Server.Start() 内部调用 +// - 分离后保持原逻辑不变,仅提取为独立函数 +// +// 作者:xfy +package server + +import ( + "fmt" + "time" + + "rua.plus/lolly/internal/cache" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/handler" + "rua.plus/lolly/internal/logging" + "rua.plus/lolly/internal/lua" +) + +// initGoroutinePool 初始化 Goroutine 池。 +// +// 根据配置创建并启动 Goroutine 池,用于优化请求处理性能。 +// +// 参数: +// - cfg: 性能配置,包含 GoroutinePool 配置 +// +// 返回值: +// - *GoroutinePool: 初始化的 Goroutine 池,未启用时返回 nil +func initGoroutinePool(cfg *config.PerformanceConfig) *GoroutinePool { + if cfg == nil || !cfg.GoroutinePool.Enabled { + return nil + } + + pool := NewGoroutinePool(PoolConfig{ + MaxWorkers: cfg.GoroutinePool.MaxWorkers, + MinWorkers: cfg.GoroutinePool.MinWorkers, + IdleTimeout: cfg.GoroutinePool.IdleTimeout, + }) + pool.Start() + + logging.Info(). + Int("maxWorkers", cfg.GoroutinePool.MaxWorkers). + Int("minWorkers", cfg.GoroutinePool.MinWorkers). + Msg("Goroutine 池已启动") + + return pool +} + +// initFileCache 初始化文件缓存。 +// +// 根据配置创建文件缓存实例,用于缓存静态文件内容。 +// +// 参数: +// - cfg: 性能配置,包含 FileCache 配置 +// +// 返回值: +// - *cache.FileCache: 初始化的文件缓存,未启用时返回 nil +func initFileCache(cfg *config.PerformanceConfig) *cache.FileCache { + // 检查是否配置了缓存 + if cfg == nil || (cfg.FileCache.MaxEntries <= 0 && cfg.FileCache.MaxSize <= 0) { + return nil + } + + fileCache := cache.NewFileCache( + cfg.FileCache.MaxEntries, + cfg.FileCache.MaxSize, + cfg.FileCache.Inactive, + ) + + logging.Info(). + Int64("maxEntries", cfg.FileCache.MaxEntries). + Int64("maxSize", cfg.FileCache.MaxSize). + Msg("文件缓存已启动") + + return fileCache +} + +// initLuaEngine 初始化 Lua 引擎。 +// +// 根据配置创建 Lua 引擎实例,用于执行 Lua 脚本。 +// +// 参数: +// - luaCfg: Lua 中间件配置 +// +// 返回值: +// - *lua.LuaEngine: 初始化的 Lua 引擎,未启用时返回 nil +// - error: 初始化过程中遇到的错误 +func initLuaEngine(luaCfg *config.LuaMiddlewareConfig) (*lua.LuaEngine, error) { + if luaCfg == nil || !luaCfg.Enabled { + return nil, nil + } + + engineCfg := &lua.Config{ + MaxConcurrentCoroutines: luaCfg.GlobalSettings.MaxConcurrentCoroutines, + CoroutineTimeout: luaCfg.GlobalSettings.CoroutineTimeout, + CodeCacheSize: luaCfg.GlobalSettings.CodeCacheSize, + CodeCacheTTL: time.Hour, // 默认值 + EnableFileWatch: luaCfg.GlobalSettings.EnableFileWatch, + MaxExecutionTime: luaCfg.GlobalSettings.MaxExecutionTime, + EnableOSLib: false, // 安全默认值 + EnableIOLib: false, + EnableLoadLib: false, + } + + // 设置默认值 + if engineCfg.MaxConcurrentCoroutines == 0 { + engineCfg.MaxConcurrentCoroutines = 1000 + } + if engineCfg.CoroutineTimeout == 0 { + engineCfg.CoroutineTimeout = 30 * time.Second + } + if engineCfg.CodeCacheSize == 0 { + engineCfg.CodeCacheSize = 1000 + } + if engineCfg.MaxExecutionTime == 0 { + engineCfg.MaxExecutionTime = 30 * time.Second + } + + engine, err := lua.NewEngine(engineCfg) + if err != nil { + return nil, fmt.Errorf("初始化 Lua 引擎失败: %w", err) + } + + logging.Info().Msg("Lua 引擎已启动") + return engine, nil +} + +// initErrorPageManager 初始化错误页面管理器。 +// +// 根据配置创建错误页面管理器,用于加载和提供自定义错误页面。 +// +// 参数: +// - errorPageCfg: 错误页面配置 +// +// 返回值: +// - *handler.ErrorPageManager: 初始化的错误页面管理器,未配置时返回 nil +// - error: 初始化过程中遇到的错误 +func initErrorPageManager(errorPageCfg *config.ErrorPageConfig) (*handler.ErrorPageManager, error) { + // 检查是否配置了错误页面 + if errorPageCfg == nil || (len(errorPageCfg.Pages) == 0 && errorPageCfg.Default == "") { + return nil, nil + } + + manager, err := handler.NewErrorPageManager(errorPageCfg) + if err != nil { + // 检查是否是部分加载失败 + if _, ok := err.(*handler.PartialLoadError); ok { + logging.Warn().Msg("部分错误页面加载失败: " + err.Error()) + // 返回部分加载的管理器 + return manager, nil + } + // 全部加载失败 + return nil, fmt.Errorf("加载错误页面失败: %w", err) + } + + logging.Info().Msg("错误页面管理器已启动") + return manager, nil +} diff --git a/internal/server/init_test.go b/internal/server/init_test.go new file mode 100644 index 0000000..b866d91 --- /dev/null +++ b/internal/server/init_test.go @@ -0,0 +1,351 @@ +// Package server 提供初始化函数的测试。 +// +// 该文件测试从 Start() 分离出的初始化函数: +// - initGoroutinePool() +// - initFileCache() +// - initLuaEngine() +// - initErrorPageManager() +// +// 主要用途: +// +// 验证初始化函数在各种配置场景下的行为 +// +// 作者:xfy +package server + +import ( + "os" + "path/filepath" + "testing" + "time" + + "rua.plus/lolly/internal/config" +) + +// TestInitGoroutinePool_Disabled 测试禁用时返回 nil +func TestInitGoroutinePool_Disabled(t *testing.T) { + cfg := &config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: false, + }, + } + + pool := initGoroutinePool(cfg) + if pool != nil { + t.Error("Expected nil when GoroutinePool is disabled") + } +} + +// TestInitGoroutinePool_Enabled 测试启用时创建池 +func TestInitGoroutinePool_Enabled(t *testing.T) { + cfg := &config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 50, + MinWorkers: 5, + IdleTimeout: 30 * time.Second, + }, + } + + pool := initGoroutinePool(cfg) + if pool == nil { + t.Fatal("Expected non-nil pool when enabled") + } + + // 验证配置正确应用 + if pool.maxWorkers != 50 { + t.Errorf("Expected maxWorkers 50, got %d", pool.maxWorkers) + } + if pool.minWorkers != 5 { + t.Errorf("Expected minWorkers 5, got %d", pool.minWorkers) + } + + // 清理 + pool.Stop() +} + +// TestInitGoroutinePool_ZeroWorkers 测试零值配置 +func TestInitGoroutinePool_ZeroWorkers(t *testing.T) { + cfg := &config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 0, + MinWorkers: 0, + }, + } + + pool := initGoroutinePool(cfg) + if pool == nil { + t.Fatal("Expected non-nil pool") + } + + // 清理 + pool.Stop() +} + +// TestInitFileCache_Disabled 测试禁用时返回 nil +func TestInitFileCache_Disabled(t *testing.T) { + cfg := &config.PerformanceConfig{ + FileCache: config.FileCacheConfig{ + MaxEntries: 0, + MaxSize: 0, + }, + } + + cache := initFileCache(cfg) + if cache != nil { + t.Error("Expected nil when FileCache is disabled") + } +} + +// TestInitFileCache_ByEntries 测试按条目数启用 +func TestInitFileCache_ByEntries(t *testing.T) { + cfg := &config.PerformanceConfig{ + FileCache: config.FileCacheConfig{ + MaxEntries: 1000, + MaxSize: 0, + Inactive: 10 * time.Minute, + }, + } + + cache := initFileCache(cfg) + if cache == nil { + t.Fatal("Expected non-nil cache when MaxEntries > 0") + } +} + +// TestInitFileCache_BySize 测试按大小启用 +func TestInitFileCache_BySize(t *testing.T) { + cfg := &config.PerformanceConfig{ + FileCache: config.FileCacheConfig{ + MaxEntries: 0, + MaxSize: 100 * 1024 * 1024, // 100MB + Inactive: 10 * time.Minute, + }, + } + + cache := initFileCache(cfg) + if cache == nil { + t.Fatal("Expected non-nil cache when MaxSize > 0") + } +} + +// TestInitFileCache_BothLimits 测试同时设置条目数和大小 +func TestInitFileCache_BothLimits(t *testing.T) { + cfg := &config.PerformanceConfig{ + FileCache: config.FileCacheConfig{ + MaxEntries: 1000, + MaxSize: 100 * 1024 * 1024, + Inactive: 10 * time.Minute, + }, + } + + cache := initFileCache(cfg) + if cache == nil { + t.Fatal("Expected non-nil cache") + } +} + +// TestInitLuaEngine_Disabled 测试禁用时返回 nil +func TestInitLuaEngine_Disabled(t *testing.T) { + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: false, + } + + engine, err := initLuaEngine(luaCfg) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + if engine != nil { + t.Error("Expected nil when Lua is disabled") + } +} + +// TestInitLuaEngine_NilConfig 测试 nil 配置 +func TestInitLuaEngine_NilConfig(t *testing.T) { + engine, err := initLuaEngine(nil) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + if engine != nil { + t.Error("Expected nil for nil config") + } +} + +// TestInitLuaEngine_Enabled 测试启用 Lua 引擎 +func TestInitLuaEngine_Enabled(t *testing.T) { + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + GlobalSettings: config.LuaGlobalSettings{ + MaxConcurrentCoroutines: 500, + CoroutineTimeout: 60 * time.Second, + CodeCacheSize: 500, + MaxExecutionTime: 60 * time.Second, + }, + } + + engine, err := initLuaEngine(luaCfg) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if engine == nil { + t.Fatal("Expected non-nil engine when enabled") + } + + // 清理 + engine.Close() +} + +// TestInitLuaEngine_DefaultValues 测试默认值设置 +func TestInitLuaEngine_DefaultValues(t *testing.T) { + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + GlobalSettings: config.LuaGlobalSettings{ + // 所有值为零,应使用默认值 + }, + } + + engine, err := initLuaEngine(luaCfg) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if engine == nil { + t.Fatal("Expected non-nil engine") + } + + // 清理 + engine.Close() +} + +// TestInitErrorPageManager_NoConfig 测试无配置时返回 nil +func TestInitErrorPageManager_NoConfig(t *testing.T) { + cfg := &config.ErrorPageConfig{ + Pages: nil, + Default: "", + } + + manager, err := initErrorPageManager(cfg) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + if manager != nil { + t.Error("Expected nil when no error page configured") + } +} + +// TestInitErrorPageManager_NilConfig 测试 nil 配置 +func TestInitErrorPageManager_NilConfig(t *testing.T) { + manager, err := initErrorPageManager(nil) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + if manager != nil { + t.Error("Expected nil for nil config") + } +} + +// TestInitErrorPageManager_WithPages 测试配置错误页面 +func TestInitErrorPageManager_WithPages(t *testing.T) { + // 创建临时目录和文件 + tempDir := t.TempDir() + errorPagePath := filepath.Join(tempDir, "404.html") + content := []byte("Not Found") + if err := os.WriteFile(errorPagePath, content, 0o644); err != nil { + t.Fatalf("Failed to create error page file: %v", err) + } + + cfg := &config.ErrorPageConfig{ + Pages: map[int]string{ + 404: errorPagePath, + }, + } + + manager, err := initErrorPageManager(cfg) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if manager == nil { + t.Fatal("Expected non-nil manager when pages configured") + } +} + +// TestInitErrorPageManager_WithDefault 测试配置默认错误页面 +func TestInitErrorPageManager_WithDefault(t *testing.T) { + // 创建临时目录和文件 + tempDir := t.TempDir() + defaultPagePath := filepath.Join(tempDir, "error.html") + content := []byte("Error") + if err := os.WriteFile(defaultPagePath, content, 0o644); err != nil { + t.Fatalf("Failed to create error page file: %v", err) + } + + cfg := &config.ErrorPageConfig{ + Default: defaultPagePath, + } + + manager, err := initErrorPageManager(cfg) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if manager == nil { + t.Fatal("Expected non-nil manager when default page configured") + } +} + +// TestInitErrorPageManager_NonExistentFile 测试不存在的文件 +func TestInitErrorPageManager_NonExistentFile(t *testing.T) { + cfg := &config.ErrorPageConfig{ + Default: "/nonexistent/path/error.html", + } + + manager, err := initErrorPageManager(cfg) + // 应该返回错误 + if err == nil { + t.Error("Expected error for non-existent file") + } + if manager != nil { + t.Error("Expected nil manager when file doesn't exist") + } +} + +// TestInitErrorPageManager_PartialLoad 测试部分加载(部分文件存在) +func TestInitErrorPageManager_PartialLoad(t *testing.T) { + // 创建临时目录和一个存在的文件 + tempDir := t.TempDir() + existingPath := filepath.Join(tempDir, "404.html") + content := []byte("Not Found") + if err := os.WriteFile(existingPath, content, 0o644); err != nil { + t.Fatalf("Failed to create error page file: %v", err) + } + + cfg := &config.ErrorPageConfig{ + Pages: map[int]string{ + 404: existingPath, + 500: "/nonexistent/path/500.html", // 不存在的文件 + }, + } + + manager, err := initErrorPageManager(cfg) + // 部分加载应该成功返回 manager,但可能有警告 + // 注意:具体行为取决于 handler.NewErrorPageManager 的实现 + if manager == nil { + t.Logf("Manager is nil with error: %v", err) + } +} + +// TestInitFunctions_NilPerformanceConfig 测试 nil PerformanceConfig +func TestInitFunctions_NilPerformanceConfig(t *testing.T) { + // 这个测试验证函数能正确处理空配置 + var cfg *config.PerformanceConfig + + // 应该能处理 nil 而不 panic + pool := initGoroutinePool(cfg) + if pool != nil { + t.Error("Expected nil for nil config") + } + + cache := initFileCache(cfg) + if cache != nil { + t.Error("Expected nil for nil config") + } +} diff --git a/internal/server/lua_integration_test.go b/internal/server/lua_integration_test.go new file mode 100644 index 0000000..ed9c5c4 --- /dev/null +++ b/internal/server/lua_integration_test.go @@ -0,0 +1,372 @@ +// Package server 提供 buildLuaMiddlewares 的 Mock 测试 +package server + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/lua" +) + +// TestBuildLuaMiddlewares_NilEngine 测试 LuaEngine 为 nil 时 +func TestBuildLuaMiddlewares_NilEngine(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + // 确保 luaEngine 为 nil + s.luaEngine = nil + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: "/test/script.lua", + Phase: "rewrite", + Enabled: true, + Timeout: 30 * time.Second, + }, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + assert.Nil(t, middlewares) +} + +// TestBuildLuaMiddlewares_InvalidPhase 测试无效阶段 +func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + // 使用真实的 LuaEngine 来测试无效阶段 + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: "/test/script.lua", + Phase: "invalid_phase", + Enabled: true, + Timeout: 30 * time.Second, + }, + }, + } + + _, err = s.buildLuaMiddlewares(luaCfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "无效的阶段") +} + +// TestBuildLuaMiddlewares_WithTimeout 测试超时配置 +func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + // 测试自定义超时 + customTimeout := 60 * time.Second + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: "/test/script.lua", + Phase: "rewrite", + Enabled: true, + Timeout: customTimeout, + }, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + assert.Len(t, middlewares, 1) +} + +// TestBuildLuaMiddlewares_EmptyScripts 测试空脚本列表 +func TestBuildLuaMiddlewares_EmptyScripts(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{}, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + assert.Empty(t, middlewares) +} + +// TestBuildLuaMiddlewares_DisabledLua 测试禁用 Lua 配置 +// buildLuaMiddlewares 函数本身不检查 Enabled 字段,由调用者检查 +func TestBuildLuaMiddlewares_DisabledLua(t *testing.T) { + // 创建临时脚本文件 + tempDir := t.TempDir() + scriptPath := tempDir + "/test.lua" + err := os.WriteFile(scriptPath, []byte("-- test"), 0o644) + require.NoError(t, err) + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + // Enabled=false 但 buildLuaMiddlewares 本身不检查这个字段 + // 它会正常处理 Scripts 中的脚本 + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: false, + Scripts: []config.LuaScriptConfig{ + { + Path: scriptPath, + Phase: "rewrite", + Enabled: true, + Timeout: 30 * time.Second, + }, + }, + } + + // 由于 Enabled 检查在调用者处,buildLuaMiddlewares 会正常执行 + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + // 脚本文件存在,应该能创建中间件 + assert.Len(t, middlewares, 1) +} + +// TestBuildLuaMiddlewares_DisabledScript 测试禁用特定脚本 +func TestBuildLuaMiddlewares_DisabledScript(t *testing.T) { + // 创建临时脚本文件(只有一个启用的脚本) + tempDir := t.TempDir() + enabledScriptPath := tempDir + "/enabled.lua" + err := os.WriteFile(enabledScriptPath, []byte("-- enabled script"), 0o644) + require.NoError(t, err) + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: enabledScriptPath, + Phase: "rewrite", + Enabled: true, + Timeout: 30 * time.Second, + }, + { + Path: "/test/disabled.lua", + Phase: "access", + Enabled: false, // 禁用的脚本应该被过滤 + Timeout: 30 * time.Second, + }, + }, + } + + // 只有一个启用的脚本,应该能正常创建中间件 + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + // 只有 rewrite 阶段的脚本被创建 + assert.Len(t, middlewares, 1) +} + +// TestBuildLuaMiddlewares_WithExistingScript 测试使用存在的脚本文件 +func TestBuildLuaMiddlewares_WithExistingScript(t *testing.T) { + // 创建临时脚本文件 + scriptContent := `-- test script +ngx.var.uri = "/test" +` + tempDir := t.TempDir() + scriptPath := tempDir + "/test.lua" + err := os.WriteFile(scriptPath, []byte(scriptContent), 0o644) + require.NoError(t, err) + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: scriptPath, + Phase: "rewrite", + Enabled: true, + Timeout: 30 * time.Second, + }, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + assert.Len(t, middlewares, 1) +} + +// TestBuildLuaMiddlewares_MultiplePhases 测试多个不同阶段的脚本 +func TestBuildLuaMiddlewares_MultiplePhases(t *testing.T) { + // 创建临时脚本文件 + tempDir := t.TempDir() + + scripts := map[string]string{ + "rewrite.lua": `-- rewrite script +ngx.var.uri = "/rewrite" +`, + "access.lua": `-- access script +-- access control logic +`, + "content.lua": `-- content script +ngx.say("hello") +`, + } + + scriptPaths := make(map[string]string) + for name, content := range scripts { + path := tempDir + "/" + name + err := os.WriteFile(path, []byte(content), 0o644) + require.NoError(t, err) + scriptPaths[name] = path + } + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: scriptPaths["rewrite.lua"], + Phase: "rewrite", + Enabled: true, + Timeout: 30 * time.Second, + }, + { + Path: scriptPaths["access.lua"], + Phase: "access", + Enabled: true, + Timeout: 30 * time.Second, + }, + { + Path: scriptPaths["content.lua"], + Phase: "content", + Enabled: true, + Timeout: 30 * time.Second, + }, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + // 三个阶段应该创建三个中间件 + assert.Len(t, middlewares, 3) +} + +// TestBuildLuaMiddlewares_DefaultTimeout 测试默认超时值 +func TestBuildLuaMiddlewares_DefaultTimeout(t *testing.T) { + // 创建临时脚本文件 + tempDir := t.TempDir() + scriptPath := tempDir + "/test.lua" + err := os.WriteFile(scriptPath, []byte("-- test"), 0o644) + require.NoError(t, err) + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":8080", + }, + } + s := New(cfg) + + engine, err := lua.NewEngine(lua.DefaultConfig()) + require.NoError(t, err) + defer engine.Close() + + s.luaEngine = engine + + // 不设置 Timeout,使用默认值 + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: scriptPath, + Phase: "rewrite", + Enabled: true, + // Timeout 为 0,应该使用默认值 30s + Timeout: 0, + }, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + require.NoError(t, err) + assert.Len(t, middlewares, 1) +} diff --git a/internal/server/server.go b/internal/server/server.go index 918aa8e..5445e45 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -405,73 +405,23 @@ func (s *Server) Start() error { // 记录启动时间 s.startTime = time.Now() - // 启用 GoroutinePool(如果配置) - if s.config.Performance.GoroutinePool.Enabled { - s.pool = NewGoroutinePool(PoolConfig{ - MaxWorkers: s.config.Performance.GoroutinePool.MaxWorkers, - MinWorkers: s.config.Performance.GoroutinePool.MinWorkers, - IdleTimeout: s.config.Performance.GoroutinePool.IdleTimeout, - }) - s.pool.Start() + // 初始化 GoroutinePool + s.pool = initGoroutinePool(&s.config.Performance) + + // 初始化文件缓存 + s.fileCache = initFileCache(&s.config.Performance) + + // 初始化错误页面管理器 + var err error + s.errorPageManager, err = initErrorPageManager(&s.config.Server.Security.ErrorPage) + if err != nil { + return err } - // 启用文件缓存(如果配置) - if s.config.Performance.FileCache.MaxEntries > 0 || s.config.Performance.FileCache.MaxSize > 0 { - s.fileCache = cache.NewFileCache( - s.config.Performance.FileCache.MaxEntries, - s.config.Performance.FileCache.MaxSize, - s.config.Performance.FileCache.Inactive, - ) - } - - // 预加载错误页面(如果配置) - if s.config.Server.Security.ErrorPage.Pages != nil || s.config.Server.Security.ErrorPage.Default != "" { - var err error - s.errorPageManager, err = handler.NewErrorPageManager(&s.config.Server.Security.ErrorPage) - if err != nil { - // 检查是否是部分加载失败 - if _, ok := err.(*handler.PartialLoadError); ok { - logging.Warn().Msg("部分错误页面加载失败: " + err.Error()) - } else { - // 全部加载失败,阻止启动 - return fmt.Errorf("加载错误页面失败: %w", err) - } - } - } - - // 初始化 Lua 引擎(如果配置) - if s.config.Server.Lua != nil && s.config.Server.Lua.Enabled { - engineCfg := &lua.Config{ - MaxConcurrentCoroutines: s.config.Server.Lua.GlobalSettings.MaxConcurrentCoroutines, - CoroutineTimeout: s.config.Server.Lua.GlobalSettings.CoroutineTimeout, - CodeCacheSize: s.config.Server.Lua.GlobalSettings.CodeCacheSize, - CodeCacheTTL: time.Hour, // 默认值 - EnableFileWatch: s.config.Server.Lua.GlobalSettings.EnableFileWatch, - MaxExecutionTime: s.config.Server.Lua.GlobalSettings.MaxExecutionTime, - EnableOSLib: false, // 安全默认值 - EnableIOLib: false, - EnableLoadLib: false, - } - // 设置默认值 - if engineCfg.MaxConcurrentCoroutines == 0 { - engineCfg.MaxConcurrentCoroutines = 1000 - } - if engineCfg.CoroutineTimeout == 0 { - engineCfg.CoroutineTimeout = 30 * time.Second - } - if engineCfg.CodeCacheSize == 0 { - engineCfg.CodeCacheSize = 1000 - } - if engineCfg.MaxExecutionTime == 0 { - engineCfg.MaxExecutionTime = 30 * time.Second - } - - var err error - s.luaEngine, err = lua.NewEngine(engineCfg) - if err != nil { - return fmt.Errorf("初始化 Lua 引擎失败: %w", err) - } - logging.Info().Msg("Lua 引擎已启动") + // 初始化 Lua 引擎 + s.luaEngine, err = initLuaEngine(s.config.Server.Lua) + if err != nil { + return err } if s.config.HasServers() { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 0159a06..06f3e4c 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -769,3 +769,32 @@ func TestGetTLSConfig_Nil(t *testing.T) { t.Error("GetTLSConfig() should return nil when TLS not configured") } } + +// TestGetTLSConfig_NilServer 测试 nil 服务器调用 GetTLSConfig +func TestGetTLSConfig_NilServer(t *testing.T) { + var s *Server + // 防御性:如果 s 为 nil,调用方法会 panic,这是预期的行为 + // 这里我们只测试非 nil 但 tlsManager 为 nil 的情况 + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: ":0", + }, + } + s = New(cfg) + + // 确保 tlsManager 为 nil + if s.tlsManager != nil { + t.Skip("tlsManager should be nil initially") + } + + tlsCfg, err := s.GetTLSConfig() + if err == nil { + t.Error("Expected error when tlsManager is nil") + } + if tlsCfg != nil { + t.Error("Expected nil TLS config when tlsManager is nil") + } + if err.Error() != "TLS not configured" { + t.Errorf("Expected error 'TLS not configured', got: %v", err) + } +} diff --git a/internal/server/start_integration_test.go b/internal/server/start_integration_test.go new file mode 100644 index 0000000..2e151b0 --- /dev/null +++ b/internal/server/start_integration_test.go @@ -0,0 +1,556 @@ +// Package server 提供 Server.Start() 的集成测试。 +// +// 该文件使用 mock_backend 模拟上游服务,测试完整的服务器启动流程: +// - 服务器配置初始化 +// - 代理路由注册 +// - 静态文件服务 +// - 中间件链构建 +// - 请求处理和转发 +// +// 主要用途: +// +// 验证服务器启动和请求处理的完整流程 +// +// 作者:xfy +package server + +import ( + "os" + "testing" + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/benchmark/tools" + "rua.plus/lolly/internal/config" +) + +// TestStart_Integration 测试完整的服务器启动和请求处理流程 +func TestStart_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // 启动 mock 上游服务器 + backendAddr, cleanup := tools.SimpleMockBackend( + fasthttp.StatusOK, + []byte(`{"message": "Hello from backend"}`), + ) + defer cleanup() + + // 使用随机端口避免冲突 + serverAddr := "127.0.0.1:0" + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: serverAddr, + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, + HealthCheck: config.HealthCheckConfig{}, + }, + }, + }, + } + + s := New(cfg) + + // 验证服务器初始化 + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 测试初始状态 + if s.running { + t.Error("Server should not be running initially") + } + + // 测试 GetListeners 初始状态 + listeners := s.GetListeners() + if listeners != nil { + t.Error("Listeners should be nil before Start") + } +} + +// TestStart_WithSecurity 测试安全配置 +func TestStart_WithSecurity(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + Deny: []string{}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + XContentTypeOptions: "nosniff", + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证安全配置 + if len(s.config.Server.Security.Access.Allow) != 1 { + t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Server.Security.Access.Allow)) + } +} + +// TestStart_WithRewrite 测试 URL 重写配置 +func TestStart_WithRewrite(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Rewrite: []config.RewriteRule{ + { + Pattern: "/old/(.*)", + Replacement: "/new/$1", + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证重写配置 + if len(s.config.Server.Rewrite) != 1 { + t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Server.Rewrite)) + } +} + +// TestStart_WithMonitoring 测试监控配置 +func TestStart_WithMonitoring(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + }, + Monitoring: config.MonitoringConfig{ + Status: config.StatusConfig{ + Path: "/status", + Allow: []string{"127.0.0.1"}, + }, + Pprof: config.PprofConfig{ + Enabled: false, + Path: "/debug/pprof", + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证监控配置 + if s.config.Monitoring.Status.Path != "/status" { + t.Errorf("Expected status path '/status', got '%s'", s.config.Monitoring.Status.Path) + } +} + +// TestStart_WithErrorPage 测试错误页面配置 +func TestStart_WithErrorPage(t *testing.T) { + // 创建临时错误页面文件 + tempDir := t.TempDir() + errorPagePath := tempDir + "/404.html" + + // 创建错误页面文件 + content := []byte("Not Found") + if err := writeFile(errorPagePath, content); err != nil { + t.Fatalf("Failed to create error page file: %v", err) + } + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + ErrorPage: config.ErrorPageConfig{ + Pages: map[int]string{ + 404: errorPagePath, + }, + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证错误页面配置 + if s.config.Server.Security.ErrorPage.Pages == nil { + t.Error("Error page pages should not be nil") + } +} + +// TestStart_WithLuaEnabled 测试 Lua 配置 +func TestStart_WithLuaEnabled(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Lua: &config.LuaMiddlewareConfig{ + Enabled: true, + GlobalSettings: config.LuaGlobalSettings{ + MaxConcurrentCoroutines: 100, + CoroutineTimeout: 30 * time.Second, + CodeCacheSize: 100, + MaxExecutionTime: 30 * time.Second, + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证 Lua 配置 + if s.config.Server.Lua == nil || !s.config.Server.Lua.Enabled { + t.Error("Lua should be enabled") + } +} + +// TestStart_WithMultipleProxies 测试多个代理配置 +func TestStart_WithMultipleProxies(t *testing.T) { + // 启动多个 mock 上游服务器 + backend1, cleanup1 := tools.SimpleMockBackend( + fasthttp.StatusOK, + []byte(`{"service": "api1"}`), + ) + defer cleanup1() + + backend2, cleanup2 := tools.SimpleMockBackend( + fasthttp.StatusOK, + []byte(`{"service": "api2"}`), + ) + defer cleanup2() + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api1", + Targets: []config.ProxyTarget{ + {URL: "http://" + backend1, Weight: 1}, + }, + }, + { + Path: "/api2", + Targets: []config.ProxyTarget{ + {URL: "http://" + backend2, Weight: 1}, + }, + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证代理配置 + if len(s.config.Server.Proxy) != 2 { + t.Errorf("Expected 2 proxies, got %d", len(s.config.Server.Proxy)) + } +} + +// TestStart_EmptyConfig 测试空配置 +func TestStart_EmptyConfig(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 空配置应该能正常初始化 + if s.config == nil { + t.Error("Config should not be nil") + } +} + +// TestStart_WithAllFeatures 测试启用所有功能的配置 +func TestStart_WithAllFeatures(t *testing.T) { + // 创建临时目录 + tempDir := t.TempDir() + errorPagePath := tempDir + "/404.html" + writeFile(errorPagePath, []byte("Not Found")) + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, + }, + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + }, + ErrorPage: config.ErrorPageConfig{ + Default: errorPagePath, + }, + }, + Rewrite: []config.RewriteRule{ + { + Pattern: "/old/(.*)", + Replacement: "/new/$1", + }, + }, + }, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 50, + MinWorkers: 10, + }, + FileCache: config.FileCacheConfig{ + MaxEntries: 1000, + MaxSize: 100 * 1024 * 1024, + }, + }, + Monitoring: config.MonitoringConfig{ + Status: config.StatusConfig{ + Path: "/status", + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证所有配置 + if !s.config.Performance.GoroutinePool.Enabled { + t.Error("GoroutinePool should be enabled") + } + if s.config.Server.Compression.Type != "gzip" { + t.Error("Compression should be gzip") + } +} + +// TestStart_ServerOptions 测试服务器配置选项 +func TestStart_ServerOptions(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + MaxConnsPerIP: 100, + MaxRequestsPerConn: 1000, + ClientMaxBodySize: "10MB", + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证服务器选项 + if s.config.Server.ReadTimeout != 30*time.Second { + t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Server.ReadTimeout) + } + if s.config.Server.MaxConnsPerIP != 100 { + t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Server.MaxConnsPerIP) + } +} + +// TestStart_HealthCheckConfig 测试健康检查配置 +func TestStart_HealthCheckConfig(t *testing.T) { + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8081", Weight: 1}, + }, + HealthCheck: config.HealthCheckConfig{ + Interval: 10 * time.Second, + Timeout: 5 * time.Second, + Path: "/health", + }, + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证健康检查配置 + if s.config.Server.Proxy[0].HealthCheck.Path != "/health" { + t.Error("Health check path should be /health") + } +} + +// TestStart_VHostMode 测试虚拟主机模式配置 +func TestStart_VHostMode(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Name: "api.example.com", + Listen: "127.0.0.1:0", + }, + { + Name: "www.example.com", + Listen: "127.0.0.1:0", + }, + }, + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证虚拟主机配置 + if !s.config.HasServers() { + t.Error("Should detect virtual hosts") + } +} + +// TestStart_WithProxyBackendError 测试代理后端错误处理 +func TestStart_WithProxyBackendError(t *testing.T) { + // 启动返回错误的 mock 服务器 + backendAddr, cleanup := tools.ErrorMockBackend(1.0, []byte(`{"error": "backend error"}`)) + defer cleanup() + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } + + // 验证代理配置 + if len(s.config.Server.Proxy) != 1 { + t.Errorf("Expected 1 proxy, got %d", len(s.config.Server.Proxy)) + } +} + +// TestStart_WithDelayedBackend 测试延迟后端 +func TestStart_WithDelayedBackend(t *testing.T) { + // 启动延迟的 mock 服务器 + backendAddr, cleanup := tools.DelayedMockBackend( + 100*time.Millisecond, + []byte(`{"message": "delayed response"}`), + ) + defer cleanup() + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } +} + +// TestStart_WithRandomResponse 测试随机响应后端 +func TestStart_WithRandomResponse(t *testing.T) { + // 启动随机响应的 mock 服务器 + backendAddr, cleanup := tools.StartMockFasthttpBackend(tools.MockBackendConfig{ + Mode: tools.ModeRandomResponse, + StatusCode: fasthttp.StatusOK, + Body: []byte(`{"random": true}`), + }) + defer cleanup() + + cfg := &config.Config{ + Server: config.ServerConfig{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, + }, + }, + }, + } + + s := New(cfg) + if s == nil { + t.Fatal("Expected non-nil server") + } +} + +// writeFile 辅助函数:写入文件 +func writeFile(path string, content []byte) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + _, err = f.Write(content) + return err +} diff --git a/internal/server/testutil.go b/internal/server/testutil.go new file mode 100644 index 0000000..48eae41 --- /dev/null +++ b/internal/server/testutil.go @@ -0,0 +1,121 @@ +// Package server 提供测试工具函数和依赖注入支持 +package server + +import ( + "crypto/tls" + "net" + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/ssl" +) + +// MockFastServer 是 fasthttp.Server 的 Mock 包装 +// 定义在此文件以便 TestServerOptions 可以引用 +type MockFastServer struct { + Name string + Handler fasthttp.RequestHandler + TLSConfig *tls.Config + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration + MaxConnsPerIP int + MaxRequestsPerConn int + CloseOnShutdown bool + ServeFunc func(ln net.Listener) error + ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error + ShutdownFunc func() error +} + +// Serve 启动服务 +func (m *MockFastServer) Serve(ln net.Listener) error { + if m.ServeFunc != nil { + return m.ServeFunc(ln) + } + return nil +} + +// ServeTLS 启动 TLS 服务 +func (m *MockFastServer) ServeTLS(ln net.Listener, certFile, keyFile string) error { + if m.ServeTLSFunc != nil { + return m.ServeTLSFunc(ln, certFile, keyFile) + } + return nil +} + +// Shutdown 关闭服务器 +func (m *MockFastServer) Shutdown() error { + if m.ShutdownFunc != nil { + return m.ShutdownFunc() + } + return nil +} + +// TestDependencies 包含测试时可注入的依赖 +// 使用具体指针类型,允许注入 Mock 实现 +type TestDependencies struct { + LuaEngine *lua.LuaEngine + TLSManager *ssl.TLSManager +} + +// NewServerForTesting 创建用于测试的服务器实例 +// 允许注入 Mock 依赖,不改变生产 API +func NewServerForTesting(cfg *config.Config, deps *TestDependencies) *Server { + s := New(cfg) + if deps != nil { + if deps.LuaEngine != nil { + s.luaEngine = deps.LuaEngine + } + if deps.TLSManager != nil { + s.tlsManager = deps.TLSManager + } + } + return s +} + +// TestServerOptions 测试服务器的可选配置 +type TestServerOptions struct { + SkipListener bool + MockFastServer *MockFastServer + CustomHandler fasthttp.RequestHandler + DisableMiddleware bool +} + +// NewTestServerWithOptions 使用选项创建测试服务器 +func NewTestServerWithOptions(cfg *config.Config, opts *TestServerOptions) *Server { + s := New(cfg) + + if opts != nil { + // 可以在这里应用各种测试选项 + if opts.CustomHandler != nil { + s.handler = opts.CustomHandler + } + } + + return s +} + +// MustStartTestServer 启动测试服务器,失败时 panic +// 主要用于集成测试 +func MustStartTestServer(cfg *config.Config) *Server { + s := New(cfg) + // 在测试环境中使用随机端口避免冲突 + if cfg.Server.Listen == "" || cfg.Server.Listen == ":80" { + cfg.Server.Listen = "127.0.0.1:0" + } + + // 使用 goroutine 启动服务器以避免阻塞 + go func() { + if err := s.Start(); err != nil { + // 测试服务器启动失败记录日志 + panic("failed to start test server: " + err.Error()) + } + }() + + // 给服务器一点时间启动 + time.Sleep(10 * time.Millisecond) + + return s +}