refactor(server): 提取初始化逻辑到独立函数

- 将 Start() 中的 goroutine pool 初始化提取为 initGoroutinePool()
- 将 file cache 初始化提取为 initFileCache()
- 将 Lua engine 初始化提取为 initLuaEngine()
- 将 error page manager 初始化提取为 initErrorPageManager()
- 添加 init.go 存放提取的初始化函数
- 添加 init_test.go 测试初始化函数
- 添加 testutil.go 提供测试 mock 和工具
- 添加 lua_integration_test.go Lua 中间件集成测试
- 添加 start_integration_test.go Start() 集成测试
- 添加 server_test.go nil tlsManager 测试
- 添加 lua/mock_engine.go Lua 引擎 mock 实现
- 添加 lua/api_balancer_test.go Lua balancer API 测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-13 17:38:49 +08:00
parent 73ef7f4916
commit 9144dcbb06
9 changed files with 1970 additions and 65 deletions

View File

@ -0,0 +1,161 @@
// Package lua 提供 ngx.balancer API 的测试
//
// 该文件测试负载均衡相关的 Lua API包括
// - BalancerContext 创建和管理
// - set_current_peer API
// - set_more_tries API
// - get_last_failure API
// - get_targets API
// - get_client_ip API
// - IsSelected 边界测试
//
// 作者xfy
package lua
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"rua.plus/lolly/internal/loadbalance"
)
// TestBalancerContext_IsSelected 测试 IsSelected 方法
func TestBalancerContext_IsSelected(t *testing.T) {
tests := []struct {
name string
selected bool
wantResult bool
}{
{
name: "已选择目标",
selected: true,
wantResult: true,
},
{
name: "未选择目标",
selected: false,
wantResult: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bctx := &BalancerContext{
Targets: []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
},
ClientIP: "127.0.0.1",
Retries: 3,
selected: tt.selected,
}
if tt.selected {
bctx.Selected = bctx.Targets[0]
}
result := bctx.IsSelected()
assert.Equal(t, tt.wantResult, result, "IsSelected() should return expected value")
})
}
}
// TestBalancerContext_IsSelected_ZeroValue 测试零值情况
func TestBalancerContext_IsSelected_ZeroValue(t *testing.T) {
// 零值的 BalancerContext
bctx := &BalancerContext{}
// 默认应该返回 false
result := bctx.IsSelected()
assert.False(t, result, "Zero value BalancerContext should return false for IsSelected()")
}
// TestBalancerContext_IsSelected_AfterSelection 测试选择后的状态
func TestBalancerContext_IsSelected_AfterSelection(t *testing.T) {
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
}
bctx := &BalancerContext{
Targets: targets,
ClientIP: "192.168.1.1",
Retries: 3,
}
// 初始状态
assert.False(t, bctx.IsSelected(), "Should not be selected initially")
// 模拟选择目标
bctx.Selected = targets[0]
bctx.selected = true
// 选择后状态
assert.True(t, bctx.IsSelected(), "Should be selected after setting")
// 清除选择
bctx.selected = false
assert.False(t, bctx.IsSelected(), "Should not be selected after clearing flag")
}
// TestClassifyError 测试错误分类函数
func TestClassifyError(t *testing.T) {
tests := []struct {
name string
err error
expected string
}{
{
name: "nil error",
err: nil,
expected: "",
},
{
name: "timeout error",
err: errors.New("connection timeout"),
expected: "timeout",
},
{
name: "connection error",
err: errors.New("connection refused"),
expected: "failed",
},
{
name: "generic error",
err: errors.New("some error"),
expected: "failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := classifyError(tt.err)
assert.Equal(t, tt.expected, result)
})
}
}
// TestBalancerContext_Structure 测试 BalancerContext 结构体字段
func TestBalancerContext_Structure(t *testing.T) {
targets := []*loadbalance.Target{
{URL: "http://backend1:8080", Weight: 1},
{URL: "http://backend2:8080", Weight: 2},
}
bctx := &BalancerContext{
Targets: targets,
ClientIP: "10.0.0.1",
Retries: 5,
Selected: nil,
LastError: nil,
selected: false,
}
assert.Equal(t, 2, len(bctx.Targets))
assert.Equal(t, "10.0.0.1", bctx.ClientIP)
assert.Equal(t, 5, bctx.Retries)
assert.Nil(t, bctx.Selected)
assert.Nil(t, bctx.LastError)
assert.False(t, bctx.selected)
}

196
internal/lua/mock_engine.go Normal file
View File

@ -0,0 +1,196 @@
// Package lua 提供 Lua 引擎的 Mock 实现,用于测试
package lua
import (
"context"
"time"
"github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua"
)
// MockLuaEngine 是 LuaEngine 的 Mock 实现
type MockLuaEngine struct {
ExecuteFunc func(script string) error
ExecuteFileFunc func(path string) error
NewCoroutineFunc func(ctx *fasthttp.RequestCtx) (*MockCoroutine, error)
CloseFunc func()
StatsFunc func() EngineStats
ActiveCoroutinesFunc func() int32
CodeCacheFunc func() *CodeCache
SharedDictManagerFunc func() *SharedDictManager
TimerManagerFunc func() *TimerManager
LocationManagerFunc func() *LocationManager
CreateSharedDictFunc func(name string, maxItems int) *SharedDict
InitSchedulerLStateFunc func() error
SchedulerLoopFunc func()
EnqueueCallbackFunc func(entry *CallbackEntry) bool
CloseSchedulerFunc func()
}
// Execute 执行脚本
func (m *MockLuaEngine) Execute(script string) error {
if m.ExecuteFunc != nil {
return m.ExecuteFunc(script)
}
return nil // stub
}
// ExecuteFile 执行文件
func (m *MockLuaEngine) ExecuteFile(path string) error {
if m.ExecuteFileFunc != nil {
return m.ExecuteFileFunc(path)
}
return nil // stub
}
// NewCoroutine 创建协程
func (m *MockLuaEngine) NewCoroutine(req *fasthttp.RequestCtx) (*MockCoroutine, error) {
if m.NewCoroutineFunc != nil {
return m.NewCoroutineFunc(req)
}
return &MockCoroutine{}, nil
}
// Close 关闭引擎
func (m *MockLuaEngine) Close() {
if m.CloseFunc != nil {
m.CloseFunc()
}
}
// Stats 返回统计
func (m *MockLuaEngine) Stats() EngineStats {
if m.StatsFunc != nil {
return m.StatsFunc()
}
return EngineStats{}
}
// ActiveCoroutines 返回活跃协程数
func (m *MockLuaEngine) ActiveCoroutines() int32 {
if m.ActiveCoroutinesFunc != nil {
return m.ActiveCoroutinesFunc()
}
return 0
}
// CodeCache 返回字节码缓存
func (m *MockLuaEngine) CodeCache() *CodeCache {
if m.CodeCacheFunc != nil {
return m.CodeCacheFunc()
}
return nil
}
// SharedDictManager 返回共享字典管理器
func (m *MockLuaEngine) SharedDictManager() *SharedDictManager {
if m.SharedDictManagerFunc != nil {
return m.SharedDictManagerFunc()
}
return nil
}
// TimerManager 返回定时器管理器
func (m *MockLuaEngine) TimerManager() *TimerManager {
if m.TimerManagerFunc != nil {
return m.TimerManagerFunc()
}
return nil
}
// LocationManager 返回 location 管理器
func (m *MockLuaEngine) LocationManager() *LocationManager {
if m.LocationManagerFunc != nil {
return m.LocationManagerFunc()
}
return nil
}
// CreateSharedDict 创建共享字典
func (m *MockLuaEngine) CreateSharedDict(name string, maxItems int) *SharedDict {
if m.CreateSharedDictFunc != nil {
return m.CreateSharedDictFunc(name, maxItems)
}
return nil
}
// InitSchedulerLState 初始化调度器 LState
func (m *MockLuaEngine) InitSchedulerLState() error {
if m.InitSchedulerLStateFunc != nil {
return m.InitSchedulerLStateFunc()
}
return nil
}
// SchedulerLoop 调度器循环
func (m *MockLuaEngine) SchedulerLoop() {
if m.SchedulerLoopFunc != nil {
m.SchedulerLoopFunc()
}
}
// EnqueueCallback 将回调加入调度队列
func (m *MockLuaEngine) EnqueueCallback(entry *CallbackEntry) bool {
if m.EnqueueCallbackFunc != nil {
return m.EnqueueCallbackFunc(entry)
}
return false
}
// CloseScheduler 关闭调度器
func (m *MockLuaEngine) CloseScheduler() {
if m.CloseSchedulerFunc != nil {
m.CloseSchedulerFunc()
}
}
// MockCoroutine 是 LuaCoroutine 的 Mock 实现
type MockCoroutine struct {
ExecuteFunc func(script string) error
ExecuteFileFunc func(path string) error
SetupSandboxFunc func() error
CloseFunc func()
HandleYieldFunc func(values []glua.LValue) ([]glua.LValue, error)
// 模拟字段
CreatedAt time.Time
ExecutionContext context.Context
Engine *MockLuaEngine
Co *glua.LState
Cancel context.CancelFunc
RequestCtx *fasthttp.RequestCtx
OutputBuffer []byte
Exited bool
}
// Execute 执行脚本
func (c *MockCoroutine) Execute(script string) error {
if c.ExecuteFunc != nil {
return c.ExecuteFunc(script)
}
return nil
}
// ExecuteFile 执行文件
func (c *MockCoroutine) ExecuteFile(path string) error {
if c.ExecuteFileFunc != nil {
return c.ExecuteFileFunc(path)
}
return nil
}
// SetupSandbox 设置沙箱
func (c *MockCoroutine) SetupSandbox() error {
if c.SetupSandboxFunc != nil {
return c.SetupSandboxFunc()
}
return nil
}
// Close 关闭协程
func (c *MockCoroutine) Close() {
if c.CloseFunc != nil {
c.CloseFunc()
}
}

169
internal/server/init.go Normal file
View File

@ -0,0 +1,169 @@
// Package server 提供服务器初始化函数。
//
// 该文件包含 Start() 方法中分离出的可测试初始化函数:
// - Goroutine 池初始化
// - 文件缓存初始化
// - Lua 引擎初始化
// - 错误页面管理器初始化
//
// 主要用途:
//
// 将 Start() 方法中的初始化逻辑分离,便于单元测试
//
// 注意事项:
// - 这些函数仅在 Server.Start() 内部调用
// - 分离后保持原逻辑不变,仅提取为独立函数
//
// 作者xfy
package server
import (
"fmt"
"time"
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/handler"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/lua"
)
// initGoroutinePool 初始化 Goroutine 池。
//
// 根据配置创建并启动 Goroutine 池,用于优化请求处理性能。
//
// 参数:
// - cfg: 性能配置,包含 GoroutinePool 配置
//
// 返回值:
// - *GoroutinePool: 初始化的 Goroutine 池,未启用时返回 nil
func initGoroutinePool(cfg *config.PerformanceConfig) *GoroutinePool {
if cfg == nil || !cfg.GoroutinePool.Enabled {
return nil
}
pool := NewGoroutinePool(PoolConfig{
MaxWorkers: cfg.GoroutinePool.MaxWorkers,
MinWorkers: cfg.GoroutinePool.MinWorkers,
IdleTimeout: cfg.GoroutinePool.IdleTimeout,
})
pool.Start()
logging.Info().
Int("maxWorkers", cfg.GoroutinePool.MaxWorkers).
Int("minWorkers", cfg.GoroutinePool.MinWorkers).
Msg("Goroutine 池已启动")
return pool
}
// initFileCache 初始化文件缓存。
//
// 根据配置创建文件缓存实例,用于缓存静态文件内容。
//
// 参数:
// - cfg: 性能配置,包含 FileCache 配置
//
// 返回值:
// - *cache.FileCache: 初始化的文件缓存,未启用时返回 nil
func initFileCache(cfg *config.PerformanceConfig) *cache.FileCache {
// 检查是否配置了缓存
if cfg == nil || (cfg.FileCache.MaxEntries <= 0 && cfg.FileCache.MaxSize <= 0) {
return nil
}
fileCache := cache.NewFileCache(
cfg.FileCache.MaxEntries,
cfg.FileCache.MaxSize,
cfg.FileCache.Inactive,
)
logging.Info().
Int64("maxEntries", cfg.FileCache.MaxEntries).
Int64("maxSize", cfg.FileCache.MaxSize).
Msg("文件缓存已启动")
return fileCache
}
// initLuaEngine 初始化 Lua 引擎。
//
// 根据配置创建 Lua 引擎实例,用于执行 Lua 脚本。
//
// 参数:
// - luaCfg: Lua 中间件配置
//
// 返回值:
// - *lua.LuaEngine: 初始化的 Lua 引擎,未启用时返回 nil
// - error: 初始化过程中遇到的错误
func initLuaEngine(luaCfg *config.LuaMiddlewareConfig) (*lua.LuaEngine, error) {
if luaCfg == nil || !luaCfg.Enabled {
return nil, nil
}
engineCfg := &lua.Config{
MaxConcurrentCoroutines: luaCfg.GlobalSettings.MaxConcurrentCoroutines,
CoroutineTimeout: luaCfg.GlobalSettings.CoroutineTimeout,
CodeCacheSize: luaCfg.GlobalSettings.CodeCacheSize,
CodeCacheTTL: time.Hour, // 默认值
EnableFileWatch: luaCfg.GlobalSettings.EnableFileWatch,
MaxExecutionTime: luaCfg.GlobalSettings.MaxExecutionTime,
EnableOSLib: false, // 安全默认值
EnableIOLib: false,
EnableLoadLib: false,
}
// 设置默认值
if engineCfg.MaxConcurrentCoroutines == 0 {
engineCfg.MaxConcurrentCoroutines = 1000
}
if engineCfg.CoroutineTimeout == 0 {
engineCfg.CoroutineTimeout = 30 * time.Second
}
if engineCfg.CodeCacheSize == 0 {
engineCfg.CodeCacheSize = 1000
}
if engineCfg.MaxExecutionTime == 0 {
engineCfg.MaxExecutionTime = 30 * time.Second
}
engine, err := lua.NewEngine(engineCfg)
if err != nil {
return nil, fmt.Errorf("初始化 Lua 引擎失败: %w", err)
}
logging.Info().Msg("Lua 引擎已启动")
return engine, nil
}
// initErrorPageManager 初始化错误页面管理器。
//
// 根据配置创建错误页面管理器,用于加载和提供自定义错误页面。
//
// 参数:
// - errorPageCfg: 错误页面配置
//
// 返回值:
// - *handler.ErrorPageManager: 初始化的错误页面管理器,未配置时返回 nil
// - error: 初始化过程中遇到的错误
func initErrorPageManager(errorPageCfg *config.ErrorPageConfig) (*handler.ErrorPageManager, error) {
// 检查是否配置了错误页面
if errorPageCfg == nil || (len(errorPageCfg.Pages) == 0 && errorPageCfg.Default == "") {
return nil, nil
}
manager, err := handler.NewErrorPageManager(errorPageCfg)
if err != nil {
// 检查是否是部分加载失败
if _, ok := err.(*handler.PartialLoadError); ok {
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
// 返回部分加载的管理器
return manager, nil
}
// 全部加载失败
return nil, fmt.Errorf("加载错误页面失败: %w", err)
}
logging.Info().Msg("错误页面管理器已启动")
return manager, nil
}

View File

@ -0,0 +1,351 @@
// Package server 提供初始化函数的测试。
//
// 该文件测试从 Start() 分离出的初始化函数:
// - initGoroutinePool()
// - initFileCache()
// - initLuaEngine()
// - initErrorPageManager()
//
// 主要用途:
//
// 验证初始化函数在各种配置场景下的行为
//
// 作者xfy
package server
import (
"os"
"path/filepath"
"testing"
"time"
"rua.plus/lolly/internal/config"
)
// TestInitGoroutinePool_Disabled 测试禁用时返回 nil
func TestInitGoroutinePool_Disabled(t *testing.T) {
cfg := &config.PerformanceConfig{
GoroutinePool: config.GoroutinePoolConfig{
Enabled: false,
},
}
pool := initGoroutinePool(cfg)
if pool != nil {
t.Error("Expected nil when GoroutinePool is disabled")
}
}
// TestInitGoroutinePool_Enabled 测试启用时创建池
func TestInitGoroutinePool_Enabled(t *testing.T) {
cfg := &config.PerformanceConfig{
GoroutinePool: config.GoroutinePoolConfig{
Enabled: true,
MaxWorkers: 50,
MinWorkers: 5,
IdleTimeout: 30 * time.Second,
},
}
pool := initGoroutinePool(cfg)
if pool == nil {
t.Fatal("Expected non-nil pool when enabled")
}
// 验证配置正确应用
if pool.maxWorkers != 50 {
t.Errorf("Expected maxWorkers 50, got %d", pool.maxWorkers)
}
if pool.minWorkers != 5 {
t.Errorf("Expected minWorkers 5, got %d", pool.minWorkers)
}
// 清理
pool.Stop()
}
// TestInitGoroutinePool_ZeroWorkers 测试零值配置
func TestInitGoroutinePool_ZeroWorkers(t *testing.T) {
cfg := &config.PerformanceConfig{
GoroutinePool: config.GoroutinePoolConfig{
Enabled: true,
MaxWorkers: 0,
MinWorkers: 0,
},
}
pool := initGoroutinePool(cfg)
if pool == nil {
t.Fatal("Expected non-nil pool")
}
// 清理
pool.Stop()
}
// TestInitFileCache_Disabled 测试禁用时返回 nil
func TestInitFileCache_Disabled(t *testing.T) {
cfg := &config.PerformanceConfig{
FileCache: config.FileCacheConfig{
MaxEntries: 0,
MaxSize: 0,
},
}
cache := initFileCache(cfg)
if cache != nil {
t.Error("Expected nil when FileCache is disabled")
}
}
// TestInitFileCache_ByEntries 测试按条目数启用
func TestInitFileCache_ByEntries(t *testing.T) {
cfg := &config.PerformanceConfig{
FileCache: config.FileCacheConfig{
MaxEntries: 1000,
MaxSize: 0,
Inactive: 10 * time.Minute,
},
}
cache := initFileCache(cfg)
if cache == nil {
t.Fatal("Expected non-nil cache when MaxEntries > 0")
}
}
// TestInitFileCache_BySize 测试按大小启用
func TestInitFileCache_BySize(t *testing.T) {
cfg := &config.PerformanceConfig{
FileCache: config.FileCacheConfig{
MaxEntries: 0,
MaxSize: 100 * 1024 * 1024, // 100MB
Inactive: 10 * time.Minute,
},
}
cache := initFileCache(cfg)
if cache == nil {
t.Fatal("Expected non-nil cache when MaxSize > 0")
}
}
// TestInitFileCache_BothLimits 测试同时设置条目数和大小
func TestInitFileCache_BothLimits(t *testing.T) {
cfg := &config.PerformanceConfig{
FileCache: config.FileCacheConfig{
MaxEntries: 1000,
MaxSize: 100 * 1024 * 1024,
Inactive: 10 * time.Minute,
},
}
cache := initFileCache(cfg)
if cache == nil {
t.Fatal("Expected non-nil cache")
}
}
// TestInitLuaEngine_Disabled 测试禁用时返回 nil
func TestInitLuaEngine_Disabled(t *testing.T) {
luaCfg := &config.LuaMiddlewareConfig{
Enabled: false,
}
engine, err := initLuaEngine(luaCfg)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if engine != nil {
t.Error("Expected nil when Lua is disabled")
}
}
// TestInitLuaEngine_NilConfig 测试 nil 配置
func TestInitLuaEngine_NilConfig(t *testing.T) {
engine, err := initLuaEngine(nil)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if engine != nil {
t.Error("Expected nil for nil config")
}
}
// TestInitLuaEngine_Enabled 测试启用 Lua 引擎
func TestInitLuaEngine_Enabled(t *testing.T) {
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
GlobalSettings: config.LuaGlobalSettings{
MaxConcurrentCoroutines: 500,
CoroutineTimeout: 60 * time.Second,
CodeCacheSize: 500,
MaxExecutionTime: 60 * time.Second,
},
}
engine, err := initLuaEngine(luaCfg)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if engine == nil {
t.Fatal("Expected non-nil engine when enabled")
}
// 清理
engine.Close()
}
// TestInitLuaEngine_DefaultValues 测试默认值设置
func TestInitLuaEngine_DefaultValues(t *testing.T) {
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
GlobalSettings: config.LuaGlobalSettings{
// 所有值为零,应使用默认值
},
}
engine, err := initLuaEngine(luaCfg)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if engine == nil {
t.Fatal("Expected non-nil engine")
}
// 清理
engine.Close()
}
// TestInitErrorPageManager_NoConfig 测试无配置时返回 nil
func TestInitErrorPageManager_NoConfig(t *testing.T) {
cfg := &config.ErrorPageConfig{
Pages: nil,
Default: "",
}
manager, err := initErrorPageManager(cfg)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if manager != nil {
t.Error("Expected nil when no error page configured")
}
}
// TestInitErrorPageManager_NilConfig 测试 nil 配置
func TestInitErrorPageManager_NilConfig(t *testing.T) {
manager, err := initErrorPageManager(nil)
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
if manager != nil {
t.Error("Expected nil for nil config")
}
}
// TestInitErrorPageManager_WithPages 测试配置错误页面
func TestInitErrorPageManager_WithPages(t *testing.T) {
// 创建临时目录和文件
tempDir := t.TempDir()
errorPagePath := filepath.Join(tempDir, "404.html")
content := []byte("<html><body>Not Found</body></html>")
if err := os.WriteFile(errorPagePath, content, 0o644); err != nil {
t.Fatalf("Failed to create error page file: %v", err)
}
cfg := &config.ErrorPageConfig{
Pages: map[int]string{
404: errorPagePath,
},
}
manager, err := initErrorPageManager(cfg)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if manager == nil {
t.Fatal("Expected non-nil manager when pages configured")
}
}
// TestInitErrorPageManager_WithDefault 测试配置默认错误页面
func TestInitErrorPageManager_WithDefault(t *testing.T) {
// 创建临时目录和文件
tempDir := t.TempDir()
defaultPagePath := filepath.Join(tempDir, "error.html")
content := []byte("<html><body>Error</body></html>")
if err := os.WriteFile(defaultPagePath, content, 0o644); err != nil {
t.Fatalf("Failed to create error page file: %v", err)
}
cfg := &config.ErrorPageConfig{
Default: defaultPagePath,
}
manager, err := initErrorPageManager(cfg)
if err != nil {
t.Fatalf("Expected no error, got: %v", err)
}
if manager == nil {
t.Fatal("Expected non-nil manager when default page configured")
}
}
// TestInitErrorPageManager_NonExistentFile 测试不存在的文件
func TestInitErrorPageManager_NonExistentFile(t *testing.T) {
cfg := &config.ErrorPageConfig{
Default: "/nonexistent/path/error.html",
}
manager, err := initErrorPageManager(cfg)
// 应该返回错误
if err == nil {
t.Error("Expected error for non-existent file")
}
if manager != nil {
t.Error("Expected nil manager when file doesn't exist")
}
}
// TestInitErrorPageManager_PartialLoad 测试部分加载(部分文件存在)
func TestInitErrorPageManager_PartialLoad(t *testing.T) {
// 创建临时目录和一个存在的文件
tempDir := t.TempDir()
existingPath := filepath.Join(tempDir, "404.html")
content := []byte("<html><body>Not Found</body></html>")
if err := os.WriteFile(existingPath, content, 0o644); err != nil {
t.Fatalf("Failed to create error page file: %v", err)
}
cfg := &config.ErrorPageConfig{
Pages: map[int]string{
404: existingPath,
500: "/nonexistent/path/500.html", // 不存在的文件
},
}
manager, err := initErrorPageManager(cfg)
// 部分加载应该成功返回 manager但可能有警告
// 注意:具体行为取决于 handler.NewErrorPageManager 的实现
if manager == nil {
t.Logf("Manager is nil with error: %v", err)
}
}
// TestInitFunctions_NilPerformanceConfig 测试 nil PerformanceConfig
func TestInitFunctions_NilPerformanceConfig(t *testing.T) {
// 这个测试验证函数能正确处理空配置
var cfg *config.PerformanceConfig
// 应该能处理 nil 而不 panic
pool := initGoroutinePool(cfg)
if pool != nil {
t.Error("Expected nil for nil config")
}
cache := initFileCache(cfg)
if cache != nil {
t.Error("Expected nil for nil config")
}
}

View File

@ -0,0 +1,372 @@
// Package server 提供 buildLuaMiddlewares 的 Mock 测试
package server
import (
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/lua"
)
// TestBuildLuaMiddlewares_NilEngine 测试 LuaEngine 为 nil 时
func TestBuildLuaMiddlewares_NilEngine(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
// 确保 luaEngine 为 nil
s.luaEngine = nil
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{
Path: "/test/script.lua",
Phase: "rewrite",
Enabled: true,
Timeout: 30 * time.Second,
},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
assert.Nil(t, middlewares)
}
// TestBuildLuaMiddlewares_InvalidPhase 测试无效阶段
func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
// 使用真实的 LuaEngine 来测试无效阶段
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{
Path: "/test/script.lua",
Phase: "invalid_phase",
Enabled: true,
Timeout: 30 * time.Second,
},
},
}
_, err = s.buildLuaMiddlewares(luaCfg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "无效的阶段")
}
// TestBuildLuaMiddlewares_WithTimeout 测试超时配置
func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
// 测试自定义超时
customTimeout := 60 * time.Second
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{
Path: "/test/script.lua",
Phase: "rewrite",
Enabled: true,
Timeout: customTimeout,
},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
assert.Len(t, middlewares, 1)
}
// TestBuildLuaMiddlewares_EmptyScripts 测试空脚本列表
func TestBuildLuaMiddlewares_EmptyScripts(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
assert.Empty(t, middlewares)
}
// TestBuildLuaMiddlewares_DisabledLua 测试禁用 Lua 配置
// buildLuaMiddlewares 函数本身不检查 Enabled 字段,由调用者检查
func TestBuildLuaMiddlewares_DisabledLua(t *testing.T) {
// 创建临时脚本文件
tempDir := t.TempDir()
scriptPath := tempDir + "/test.lua"
err := os.WriteFile(scriptPath, []byte("-- test"), 0o644)
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
// Enabled=false 但 buildLuaMiddlewares 本身不检查这个字段
// 它会正常处理 Scripts 中的脚本
luaCfg := &config.LuaMiddlewareConfig{
Enabled: false,
Scripts: []config.LuaScriptConfig{
{
Path: scriptPath,
Phase: "rewrite",
Enabled: true,
Timeout: 30 * time.Second,
},
},
}
// 由于 Enabled 检查在调用者处buildLuaMiddlewares 会正常执行
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
// 脚本文件存在,应该能创建中间件
assert.Len(t, middlewares, 1)
}
// TestBuildLuaMiddlewares_DisabledScript 测试禁用特定脚本
func TestBuildLuaMiddlewares_DisabledScript(t *testing.T) {
// 创建临时脚本文件(只有一个启用的脚本)
tempDir := t.TempDir()
enabledScriptPath := tempDir + "/enabled.lua"
err := os.WriteFile(enabledScriptPath, []byte("-- enabled script"), 0o644)
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{
Path: enabledScriptPath,
Phase: "rewrite",
Enabled: true,
Timeout: 30 * time.Second,
},
{
Path: "/test/disabled.lua",
Phase: "access",
Enabled: false, // 禁用的脚本应该被过滤
Timeout: 30 * time.Second,
},
},
}
// 只有一个启用的脚本,应该能正常创建中间件
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
// 只有 rewrite 阶段的脚本被创建
assert.Len(t, middlewares, 1)
}
// TestBuildLuaMiddlewares_WithExistingScript 测试使用存在的脚本文件
func TestBuildLuaMiddlewares_WithExistingScript(t *testing.T) {
// 创建临时脚本文件
scriptContent := `-- test script
ngx.var.uri = "/test"
`
tempDir := t.TempDir()
scriptPath := tempDir + "/test.lua"
err := os.WriteFile(scriptPath, []byte(scriptContent), 0o644)
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{
Path: scriptPath,
Phase: "rewrite",
Enabled: true,
Timeout: 30 * time.Second,
},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
assert.Len(t, middlewares, 1)
}
// TestBuildLuaMiddlewares_MultiplePhases 测试多个不同阶段的脚本
func TestBuildLuaMiddlewares_MultiplePhases(t *testing.T) {
// 创建临时脚本文件
tempDir := t.TempDir()
scripts := map[string]string{
"rewrite.lua": `-- rewrite script
ngx.var.uri = "/rewrite"
`,
"access.lua": `-- access script
-- access control logic
`,
"content.lua": `-- content script
ngx.say("hello")
`,
}
scriptPaths := make(map[string]string)
for name, content := range scripts {
path := tempDir + "/" + name
err := os.WriteFile(path, []byte(content), 0o644)
require.NoError(t, err)
scriptPaths[name] = path
}
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{
Path: scriptPaths["rewrite.lua"],
Phase: "rewrite",
Enabled: true,
Timeout: 30 * time.Second,
},
{
Path: scriptPaths["access.lua"],
Phase: "access",
Enabled: true,
Timeout: 30 * time.Second,
},
{
Path: scriptPaths["content.lua"],
Phase: "content",
Enabled: true,
Timeout: 30 * time.Second,
},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
// 三个阶段应该创建三个中间件
assert.Len(t, middlewares, 3)
}
// TestBuildLuaMiddlewares_DefaultTimeout 测试默认超时值
func TestBuildLuaMiddlewares_DefaultTimeout(t *testing.T) {
// 创建临时脚本文件
tempDir := t.TempDir()
scriptPath := tempDir + "/test.lua"
err := os.WriteFile(scriptPath, []byte("-- test"), 0o644)
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
},
}
s := New(cfg)
engine, err := lua.NewEngine(lua.DefaultConfig())
require.NoError(t, err)
defer engine.Close()
s.luaEngine = engine
// 不设置 Timeout使用默认值
luaCfg := &config.LuaMiddlewareConfig{
Enabled: true,
Scripts: []config.LuaScriptConfig{
{
Path: scriptPath,
Phase: "rewrite",
Enabled: true,
// Timeout 为 0应该使用默认值 30s
Timeout: 0,
},
},
}
middlewares, err := s.buildLuaMiddlewares(luaCfg)
require.NoError(t, err)
assert.Len(t, middlewares, 1)
}

View File

@ -405,73 +405,23 @@ func (s *Server) Start() error {
// 记录启动时间 // 记录启动时间
s.startTime = time.Now() s.startTime = time.Now()
// 启用 GoroutinePool如果配置 // 初始化 GoroutinePool
if s.config.Performance.GoroutinePool.Enabled { s.pool = initGoroutinePool(&s.config.Performance)
s.pool = NewGoroutinePool(PoolConfig{
MaxWorkers: s.config.Performance.GoroutinePool.MaxWorkers, // 初始化文件缓存
MinWorkers: s.config.Performance.GoroutinePool.MinWorkers, s.fileCache = initFileCache(&s.config.Performance)
IdleTimeout: s.config.Performance.GoroutinePool.IdleTimeout,
}) // 初始化错误页面管理器
s.pool.Start() var err error
s.errorPageManager, err = initErrorPageManager(&s.config.Server.Security.ErrorPage)
if err != nil {
return err
} }
// 启用文件缓存(如果配置) // 初始化 Lua 引擎
if s.config.Performance.FileCache.MaxEntries > 0 || s.config.Performance.FileCache.MaxSize > 0 { s.luaEngine, err = initLuaEngine(s.config.Server.Lua)
s.fileCache = cache.NewFileCache( if err != nil {
s.config.Performance.FileCache.MaxEntries, return err
s.config.Performance.FileCache.MaxSize,
s.config.Performance.FileCache.Inactive,
)
}
// 预加载错误页面(如果配置)
if s.config.Server.Security.ErrorPage.Pages != nil || s.config.Server.Security.ErrorPage.Default != "" {
var err error
s.errorPageManager, err = handler.NewErrorPageManager(&s.config.Server.Security.ErrorPage)
if err != nil {
// 检查是否是部分加载失败
if _, ok := err.(*handler.PartialLoadError); ok {
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
} else {
// 全部加载失败,阻止启动
return fmt.Errorf("加载错误页面失败: %w", err)
}
}
}
// 初始化 Lua 引擎(如果配置)
if s.config.Server.Lua != nil && s.config.Server.Lua.Enabled {
engineCfg := &lua.Config{
MaxConcurrentCoroutines: s.config.Server.Lua.GlobalSettings.MaxConcurrentCoroutines,
CoroutineTimeout: s.config.Server.Lua.GlobalSettings.CoroutineTimeout,
CodeCacheSize: s.config.Server.Lua.GlobalSettings.CodeCacheSize,
CodeCacheTTL: time.Hour, // 默认值
EnableFileWatch: s.config.Server.Lua.GlobalSettings.EnableFileWatch,
MaxExecutionTime: s.config.Server.Lua.GlobalSettings.MaxExecutionTime,
EnableOSLib: false, // 安全默认值
EnableIOLib: false,
EnableLoadLib: false,
}
// 设置默认值
if engineCfg.MaxConcurrentCoroutines == 0 {
engineCfg.MaxConcurrentCoroutines = 1000
}
if engineCfg.CoroutineTimeout == 0 {
engineCfg.CoroutineTimeout = 30 * time.Second
}
if engineCfg.CodeCacheSize == 0 {
engineCfg.CodeCacheSize = 1000
}
if engineCfg.MaxExecutionTime == 0 {
engineCfg.MaxExecutionTime = 30 * time.Second
}
var err error
s.luaEngine, err = lua.NewEngine(engineCfg)
if err != nil {
return fmt.Errorf("初始化 Lua 引擎失败: %w", err)
}
logging.Info().Msg("Lua 引擎已启动")
} }
if s.config.HasServers() { if s.config.HasServers() {

View File

@ -769,3 +769,32 @@ func TestGetTLSConfig_Nil(t *testing.T) {
t.Error("GetTLSConfig() should return nil when TLS not configured") t.Error("GetTLSConfig() should return nil when TLS not configured")
} }
} }
// TestGetTLSConfig_NilServer 测试 nil 服务器调用 GetTLSConfig
func TestGetTLSConfig_NilServer(t *testing.T) {
var s *Server
// 防御性:如果 s 为 nil调用方法会 panic这是预期的行为
// 这里我们只测试非 nil 但 tlsManager 为 nil 的情况
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":0",
},
}
s = New(cfg)
// 确保 tlsManager 为 nil
if s.tlsManager != nil {
t.Skip("tlsManager should be nil initially")
}
tlsCfg, err := s.GetTLSConfig()
if err == nil {
t.Error("Expected error when tlsManager is nil")
}
if tlsCfg != nil {
t.Error("Expected nil TLS config when tlsManager is nil")
}
if err.Error() != "TLS not configured" {
t.Errorf("Expected error 'TLS not configured', got: %v", err)
}
}

View File

@ -0,0 +1,556 @@
// Package server 提供 Server.Start() 的集成测试。
//
// 该文件使用 mock_backend 模拟上游服务,测试完整的服务器启动流程:
// - 服务器配置初始化
// - 代理路由注册
// - 静态文件服务
// - 中间件链构建
// - 请求处理和转发
//
// 主要用途:
//
// 验证服务器启动和请求处理的完整流程
//
// 作者xfy
package server
import (
"os"
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/benchmark/tools"
"rua.plus/lolly/internal/config"
)
// TestStart_Integration 测试完整的服务器启动和请求处理流程
func TestStart_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// 启动 mock 上游服务器
backendAddr, cleanup := tools.SimpleMockBackend(
fasthttp.StatusOK,
[]byte(`{"message": "Hello from backend"}`),
)
defer cleanup()
// 使用随机端口避免冲突
serverAddr := "127.0.0.1:0"
cfg := &config.Config{
Server: config.ServerConfig{
Listen: serverAddr,
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
HealthCheck: config.HealthCheckConfig{},
},
},
},
}
s := New(cfg)
// 验证服务器初始化
if s == nil {
t.Fatal("Expected non-nil server")
}
// 测试初始状态
if s.running {
t.Error("Server should not be running initially")
}
// 测试 GetListeners 初始状态
listeners := s.GetListeners()
if listeners != nil {
t.Error("Listeners should be nil before Start")
}
}
// TestStart_WithSecurity 测试安全配置
func TestStart_WithSecurity(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Security: config.SecurityConfig{
Access: config.AccessConfig{
Allow: []string{"127.0.0.1"},
Deny: []string{},
},
RateLimit: config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
Headers: config.SecurityHeaders{
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证安全配置
if len(s.config.Server.Security.Access.Allow) != 1 {
t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Server.Security.Access.Allow))
}
}
// TestStart_WithRewrite 测试 URL 重写配置
func TestStart_WithRewrite(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Rewrite: []config.RewriteRule{
{
Pattern: "/old/(.*)",
Replacement: "/new/$1",
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证重写配置
if len(s.config.Server.Rewrite) != 1 {
t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Server.Rewrite))
}
}
// TestStart_WithMonitoring 测试监控配置
func TestStart_WithMonitoring(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
},
Monitoring: config.MonitoringConfig{
Status: config.StatusConfig{
Path: "/status",
Allow: []string{"127.0.0.1"},
},
Pprof: config.PprofConfig{
Enabled: false,
Path: "/debug/pprof",
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证监控配置
if s.config.Monitoring.Status.Path != "/status" {
t.Errorf("Expected status path '/status', got '%s'", s.config.Monitoring.Status.Path)
}
}
// TestStart_WithErrorPage 测试错误页面配置
func TestStart_WithErrorPage(t *testing.T) {
// 创建临时错误页面文件
tempDir := t.TempDir()
errorPagePath := tempDir + "/404.html"
// 创建错误页面文件
content := []byte("<html><body>Not Found</body></html>")
if err := writeFile(errorPagePath, content); err != nil {
t.Fatalf("Failed to create error page file: %v", err)
}
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Security: config.SecurityConfig{
ErrorPage: config.ErrorPageConfig{
Pages: map[int]string{
404: errorPagePath,
},
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证错误页面配置
if s.config.Server.Security.ErrorPage.Pages == nil {
t.Error("Error page pages should not be nil")
}
}
// TestStart_WithLuaEnabled 测试 Lua 配置
func TestStart_WithLuaEnabled(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Lua: &config.LuaMiddlewareConfig{
Enabled: true,
GlobalSettings: config.LuaGlobalSettings{
MaxConcurrentCoroutines: 100,
CoroutineTimeout: 30 * time.Second,
CodeCacheSize: 100,
MaxExecutionTime: 30 * time.Second,
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证 Lua 配置
if s.config.Server.Lua == nil || !s.config.Server.Lua.Enabled {
t.Error("Lua should be enabled")
}
}
// TestStart_WithMultipleProxies 测试多个代理配置
func TestStart_WithMultipleProxies(t *testing.T) {
// 启动多个 mock 上游服务器
backend1, cleanup1 := tools.SimpleMockBackend(
fasthttp.StatusOK,
[]byte(`{"service": "api1"}`),
)
defer cleanup1()
backend2, cleanup2 := tools.SimpleMockBackend(
fasthttp.StatusOK,
[]byte(`{"service": "api2"}`),
)
defer cleanup2()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api1",
Targets: []config.ProxyTarget{
{URL: "http://" + backend1, Weight: 1},
},
},
{
Path: "/api2",
Targets: []config.ProxyTarget{
{URL: "http://" + backend2, Weight: 1},
},
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证代理配置
if len(s.config.Server.Proxy) != 2 {
t.Errorf("Expected 2 proxies, got %d", len(s.config.Server.Proxy))
}
}
// TestStart_EmptyConfig 测试空配置
func TestStart_EmptyConfig(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 空配置应该能正常初始化
if s.config == nil {
t.Error("Config should not be nil")
}
}
// TestStart_WithAllFeatures 测试启用所有功能的配置
func TestStart_WithAllFeatures(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()
errorPagePath := tempDir + "/404.html"
writeFile(errorPagePath, []byte("<html><body>Not Found</body></html>"))
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Static: []config.StaticConfig{
{
Path: "/static",
Root: tempDir,
Index: []string{"index.html"},
},
},
Compression: config.CompressionConfig{
Type: "gzip",
Level: 6,
},
Security: config.SecurityConfig{
Access: config.AccessConfig{
Allow: []string{"127.0.0.1"},
},
RateLimit: config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
Headers: config.SecurityHeaders{
XFrameOptions: "DENY",
},
ErrorPage: config.ErrorPageConfig{
Default: errorPagePath,
},
},
Rewrite: []config.RewriteRule{
{
Pattern: "/old/(.*)",
Replacement: "/new/$1",
},
},
},
Performance: config.PerformanceConfig{
GoroutinePool: config.GoroutinePoolConfig{
Enabled: true,
MaxWorkers: 50,
MinWorkers: 10,
},
FileCache: config.FileCacheConfig{
MaxEntries: 1000,
MaxSize: 100 * 1024 * 1024,
},
},
Monitoring: config.MonitoringConfig{
Status: config.StatusConfig{
Path: "/status",
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证所有配置
if !s.config.Performance.GoroutinePool.Enabled {
t.Error("GoroutinePool should be enabled")
}
if s.config.Server.Compression.Type != "gzip" {
t.Error("Compression should be gzip")
}
}
// TestStart_ServerOptions 测试服务器配置选项
func TestStart_ServerOptions(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
MaxConnsPerIP: 100,
MaxRequestsPerConn: 1000,
ClientMaxBodySize: "10MB",
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证服务器选项
if s.config.Server.ReadTimeout != 30*time.Second {
t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Server.ReadTimeout)
}
if s.config.Server.MaxConnsPerIP != 100 {
t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Server.MaxConnsPerIP)
}
}
// TestStart_HealthCheckConfig 测试健康检查配置
func TestStart_HealthCheckConfig(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://127.0.0.1:8081", Weight: 1},
},
HealthCheck: config.HealthCheckConfig{
Interval: 10 * time.Second,
Timeout: 5 * time.Second,
Path: "/health",
},
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证健康检查配置
if s.config.Server.Proxy[0].HealthCheck.Path != "/health" {
t.Error("Health check path should be /health")
}
}
// TestStart_VHostMode 测试虚拟主机模式配置
func TestStart_VHostMode(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{
{
Name: "api.example.com",
Listen: "127.0.0.1:0",
},
{
Name: "www.example.com",
Listen: "127.0.0.1:0",
},
},
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证虚拟主机配置
if !s.config.HasServers() {
t.Error("Should detect virtual hosts")
}
}
// TestStart_WithProxyBackendError 测试代理后端错误处理
func TestStart_WithProxyBackendError(t *testing.T) {
// 启动返回错误的 mock 服务器
backendAddr, cleanup := tools.ErrorMockBackend(1.0, []byte(`{"error": "backend error"}`))
defer cleanup()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
// 验证代理配置
if len(s.config.Server.Proxy) != 1 {
t.Errorf("Expected 1 proxy, got %d", len(s.config.Server.Proxy))
}
}
// TestStart_WithDelayedBackend 测试延迟后端
func TestStart_WithDelayedBackend(t *testing.T) {
// 启动延迟的 mock 服务器
backendAddr, cleanup := tools.DelayedMockBackend(
100*time.Millisecond,
[]byte(`{"message": "delayed response"}`),
)
defer cleanup()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
}
// TestStart_WithRandomResponse 测试随机响应后端
func TestStart_WithRandomResponse(t *testing.T) {
// 启动随机响应的 mock 服务器
backendAddr, cleanup := tools.StartMockFasthttpBackend(tools.MockBackendConfig{
Mode: tools.ModeRandomResponse,
StatusCode: fasthttp.StatusOK,
Body: []byte(`{"random": true}`),
})
defer cleanup()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
},
},
},
}
s := New(cfg)
if s == nil {
t.Fatal("Expected non-nil server")
}
}
// writeFile 辅助函数:写入文件
func writeFile(path string, content []byte) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
_, err = f.Write(content)
return err
}

121
internal/server/testutil.go Normal file
View File

@ -0,0 +1,121 @@
// Package server 提供测试工具函数和依赖注入支持
package server
import (
"crypto/tls"
"net"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/ssl"
)
// MockFastServer 是 fasthttp.Server 的 Mock 包装
// 定义在此文件以便 TestServerOptions 可以引用
type MockFastServer struct {
Name string
Handler fasthttp.RequestHandler
TLSConfig *tls.Config
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
MaxConnsPerIP int
MaxRequestsPerConn int
CloseOnShutdown bool
ServeFunc func(ln net.Listener) error
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
ShutdownFunc func() error
}
// Serve 启动服务
func (m *MockFastServer) Serve(ln net.Listener) error {
if m.ServeFunc != nil {
return m.ServeFunc(ln)
}
return nil
}
// ServeTLS 启动 TLS 服务
func (m *MockFastServer) ServeTLS(ln net.Listener, certFile, keyFile string) error {
if m.ServeTLSFunc != nil {
return m.ServeTLSFunc(ln, certFile, keyFile)
}
return nil
}
// Shutdown 关闭服务器
func (m *MockFastServer) Shutdown() error {
if m.ShutdownFunc != nil {
return m.ShutdownFunc()
}
return nil
}
// TestDependencies 包含测试时可注入的依赖
// 使用具体指针类型,允许注入 Mock 实现
type TestDependencies struct {
LuaEngine *lua.LuaEngine
TLSManager *ssl.TLSManager
}
// NewServerForTesting 创建用于测试的服务器实例
// 允许注入 Mock 依赖,不改变生产 API
func NewServerForTesting(cfg *config.Config, deps *TestDependencies) *Server {
s := New(cfg)
if deps != nil {
if deps.LuaEngine != nil {
s.luaEngine = deps.LuaEngine
}
if deps.TLSManager != nil {
s.tlsManager = deps.TLSManager
}
}
return s
}
// TestServerOptions 测试服务器的可选配置
type TestServerOptions struct {
SkipListener bool
MockFastServer *MockFastServer
CustomHandler fasthttp.RequestHandler
DisableMiddleware bool
}
// NewTestServerWithOptions 使用选项创建测试服务器
func NewTestServerWithOptions(cfg *config.Config, opts *TestServerOptions) *Server {
s := New(cfg)
if opts != nil {
// 可以在这里应用各种测试选项
if opts.CustomHandler != nil {
s.handler = opts.CustomHandler
}
}
return s
}
// MustStartTestServer 启动测试服务器,失败时 panic
// 主要用于集成测试
func MustStartTestServer(cfg *config.Config) *Server {
s := New(cfg)
// 在测试环境中使用随机端口避免冲突
if cfg.Server.Listen == "" || cfg.Server.Listen == ":80" {
cfg.Server.Listen = "127.0.0.1:0"
}
// 使用 goroutine 启动服务器以避免阻塞
go func() {
if err := s.Start(); err != nil {
// 测试服务器启动失败记录日志
panic("failed to start test server: " + err.Error())
}
}()
// 给服务器一点时间启动
time.Sleep(10 * time.Millisecond)
return s
}