refactor(server): 提取初始化逻辑到独立函数
- 将 Start() 中的 goroutine pool 初始化提取为 initGoroutinePool() - 将 file cache 初始化提取为 initFileCache() - 将 Lua engine 初始化提取为 initLuaEngine() - 将 error page manager 初始化提取为 initErrorPageManager() - 添加 init.go 存放提取的初始化函数 - 添加 init_test.go 测试初始化函数 - 添加 testutil.go 提供测试 mock 和工具 - 添加 lua_integration_test.go Lua 中间件集成测试 - 添加 start_integration_test.go Start() 集成测试 - 添加 server_test.go nil tlsManager 测试 - 添加 lua/mock_engine.go Lua 引擎 mock 实现 - 添加 lua/api_balancer_test.go Lua balancer API 测试 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
73ef7f4916
commit
9144dcbb06
161
internal/lua/api_balancer_test.go
Normal file
161
internal/lua/api_balancer_test.go
Normal file
@ -0,0 +1,161 @@
|
||||
// Package lua 提供 ngx.balancer API 的测试
|
||||
//
|
||||
// 该文件测试负载均衡相关的 Lua API,包括:
|
||||
// - BalancerContext 创建和管理
|
||||
// - set_current_peer API
|
||||
// - set_more_tries API
|
||||
// - get_last_failure API
|
||||
// - get_targets API
|
||||
// - get_client_ip API
|
||||
// - IsSelected 边界测试
|
||||
//
|
||||
// 作者:xfy
|
||||
package lua
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
)
|
||||
|
||||
// TestBalancerContext_IsSelected 测试 IsSelected 方法
|
||||
func TestBalancerContext_IsSelected(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
selected bool
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "已选择目标",
|
||||
selected: true,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "未选择目标",
|
||||
selected: false,
|
||||
wantResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
bctx := &BalancerContext{
|
||||
Targets: []*loadbalance.Target{
|
||||
{URL: "http://backend1:8080"},
|
||||
{URL: "http://backend2:8080"},
|
||||
},
|
||||
ClientIP: "127.0.0.1",
|
||||
Retries: 3,
|
||||
selected: tt.selected,
|
||||
}
|
||||
|
||||
if tt.selected {
|
||||
bctx.Selected = bctx.Targets[0]
|
||||
}
|
||||
|
||||
result := bctx.IsSelected()
|
||||
assert.Equal(t, tt.wantResult, result, "IsSelected() should return expected value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBalancerContext_IsSelected_ZeroValue 测试零值情况
|
||||
func TestBalancerContext_IsSelected_ZeroValue(t *testing.T) {
|
||||
// 零值的 BalancerContext
|
||||
bctx := &BalancerContext{}
|
||||
|
||||
// 默认应该返回 false
|
||||
result := bctx.IsSelected()
|
||||
assert.False(t, result, "Zero value BalancerContext should return false for IsSelected()")
|
||||
}
|
||||
|
||||
// TestBalancerContext_IsSelected_AfterSelection 测试选择后的状态
|
||||
func TestBalancerContext_IsSelected_AfterSelection(t *testing.T) {
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://backend1:8080"},
|
||||
{URL: "http://backend2:8080"},
|
||||
}
|
||||
|
||||
bctx := &BalancerContext{
|
||||
Targets: targets,
|
||||
ClientIP: "192.168.1.1",
|
||||
Retries: 3,
|
||||
}
|
||||
|
||||
// 初始状态
|
||||
assert.False(t, bctx.IsSelected(), "Should not be selected initially")
|
||||
|
||||
// 模拟选择目标
|
||||
bctx.Selected = targets[0]
|
||||
bctx.selected = true
|
||||
|
||||
// 选择后状态
|
||||
assert.True(t, bctx.IsSelected(), "Should be selected after setting")
|
||||
|
||||
// 清除选择
|
||||
bctx.selected = false
|
||||
assert.False(t, bctx.IsSelected(), "Should not be selected after clearing flag")
|
||||
}
|
||||
|
||||
// TestClassifyError 测试错误分类函数
|
||||
func TestClassifyError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "timeout error",
|
||||
err: errors.New("connection timeout"),
|
||||
expected: "timeout",
|
||||
},
|
||||
{
|
||||
name: "connection error",
|
||||
err: errors.New("connection refused"),
|
||||
expected: "failed",
|
||||
},
|
||||
{
|
||||
name: "generic error",
|
||||
err: errors.New("some error"),
|
||||
expected: "failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := classifyError(tt.err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBalancerContext_Structure 测试 BalancerContext 结构体字段
|
||||
func TestBalancerContext_Structure(t *testing.T) {
|
||||
targets := []*loadbalance.Target{
|
||||
{URL: "http://backend1:8080", Weight: 1},
|
||||
{URL: "http://backend2:8080", Weight: 2},
|
||||
}
|
||||
|
||||
bctx := &BalancerContext{
|
||||
Targets: targets,
|
||||
ClientIP: "10.0.0.1",
|
||||
Retries: 5,
|
||||
Selected: nil,
|
||||
LastError: nil,
|
||||
selected: false,
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, len(bctx.Targets))
|
||||
assert.Equal(t, "10.0.0.1", bctx.ClientIP)
|
||||
assert.Equal(t, 5, bctx.Retries)
|
||||
assert.Nil(t, bctx.Selected)
|
||||
assert.Nil(t, bctx.LastError)
|
||||
assert.False(t, bctx.selected)
|
||||
}
|
||||
196
internal/lua/mock_engine.go
Normal file
196
internal/lua/mock_engine.go
Normal file
@ -0,0 +1,196 @@
|
||||
// Package lua 提供 Lua 引擎的 Mock 实现,用于测试
|
||||
package lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// MockLuaEngine 是 LuaEngine 的 Mock 实现
|
||||
type MockLuaEngine struct {
|
||||
ExecuteFunc func(script string) error
|
||||
ExecuteFileFunc func(path string) error
|
||||
NewCoroutineFunc func(ctx *fasthttp.RequestCtx) (*MockCoroutine, error)
|
||||
CloseFunc func()
|
||||
StatsFunc func() EngineStats
|
||||
ActiveCoroutinesFunc func() int32
|
||||
CodeCacheFunc func() *CodeCache
|
||||
SharedDictManagerFunc func() *SharedDictManager
|
||||
TimerManagerFunc func() *TimerManager
|
||||
LocationManagerFunc func() *LocationManager
|
||||
CreateSharedDictFunc func(name string, maxItems int) *SharedDict
|
||||
InitSchedulerLStateFunc func() error
|
||||
SchedulerLoopFunc func()
|
||||
EnqueueCallbackFunc func(entry *CallbackEntry) bool
|
||||
CloseSchedulerFunc func()
|
||||
}
|
||||
|
||||
// Execute 执行脚本
|
||||
func (m *MockLuaEngine) Execute(script string) error {
|
||||
if m.ExecuteFunc != nil {
|
||||
return m.ExecuteFunc(script)
|
||||
}
|
||||
return nil // stub
|
||||
}
|
||||
|
||||
// ExecuteFile 执行文件
|
||||
func (m *MockLuaEngine) ExecuteFile(path string) error {
|
||||
if m.ExecuteFileFunc != nil {
|
||||
return m.ExecuteFileFunc(path)
|
||||
}
|
||||
return nil // stub
|
||||
}
|
||||
|
||||
// NewCoroutine 创建协程
|
||||
func (m *MockLuaEngine) NewCoroutine(req *fasthttp.RequestCtx) (*MockCoroutine, error) {
|
||||
if m.NewCoroutineFunc != nil {
|
||||
return m.NewCoroutineFunc(req)
|
||||
}
|
||||
return &MockCoroutine{}, nil
|
||||
}
|
||||
|
||||
// Close 关闭引擎
|
||||
func (m *MockLuaEngine) Close() {
|
||||
if m.CloseFunc != nil {
|
||||
m.CloseFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// Stats 返回统计
|
||||
func (m *MockLuaEngine) Stats() EngineStats {
|
||||
if m.StatsFunc != nil {
|
||||
return m.StatsFunc()
|
||||
}
|
||||
return EngineStats{}
|
||||
}
|
||||
|
||||
// ActiveCoroutines 返回活跃协程数
|
||||
func (m *MockLuaEngine) ActiveCoroutines() int32 {
|
||||
if m.ActiveCoroutinesFunc != nil {
|
||||
return m.ActiveCoroutinesFunc()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// CodeCache 返回字节码缓存
|
||||
func (m *MockLuaEngine) CodeCache() *CodeCache {
|
||||
if m.CodeCacheFunc != nil {
|
||||
return m.CodeCacheFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SharedDictManager 返回共享字典管理器
|
||||
func (m *MockLuaEngine) SharedDictManager() *SharedDictManager {
|
||||
if m.SharedDictManagerFunc != nil {
|
||||
return m.SharedDictManagerFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TimerManager 返回定时器管理器
|
||||
func (m *MockLuaEngine) TimerManager() *TimerManager {
|
||||
if m.TimerManagerFunc != nil {
|
||||
return m.TimerManagerFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocationManager 返回 location 管理器
|
||||
func (m *MockLuaEngine) LocationManager() *LocationManager {
|
||||
if m.LocationManagerFunc != nil {
|
||||
return m.LocationManagerFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSharedDict 创建共享字典
|
||||
func (m *MockLuaEngine) CreateSharedDict(name string, maxItems int) *SharedDict {
|
||||
if m.CreateSharedDictFunc != nil {
|
||||
return m.CreateSharedDictFunc(name, maxItems)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitSchedulerLState 初始化调度器 LState
|
||||
func (m *MockLuaEngine) InitSchedulerLState() error {
|
||||
if m.InitSchedulerLStateFunc != nil {
|
||||
return m.InitSchedulerLStateFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SchedulerLoop 调度器循环
|
||||
func (m *MockLuaEngine) SchedulerLoop() {
|
||||
if m.SchedulerLoopFunc != nil {
|
||||
m.SchedulerLoopFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// EnqueueCallback 将回调加入调度队列
|
||||
func (m *MockLuaEngine) EnqueueCallback(entry *CallbackEntry) bool {
|
||||
if m.EnqueueCallbackFunc != nil {
|
||||
return m.EnqueueCallbackFunc(entry)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CloseScheduler 关闭调度器
|
||||
func (m *MockLuaEngine) CloseScheduler() {
|
||||
if m.CloseSchedulerFunc != nil {
|
||||
m.CloseSchedulerFunc()
|
||||
}
|
||||
}
|
||||
|
||||
// MockCoroutine 是 LuaCoroutine 的 Mock 实现
|
||||
type MockCoroutine struct {
|
||||
ExecuteFunc func(script string) error
|
||||
ExecuteFileFunc func(path string) error
|
||||
SetupSandboxFunc func() error
|
||||
CloseFunc func()
|
||||
HandleYieldFunc func(values []glua.LValue) ([]glua.LValue, error)
|
||||
|
||||
// 模拟字段
|
||||
CreatedAt time.Time
|
||||
ExecutionContext context.Context
|
||||
Engine *MockLuaEngine
|
||||
Co *glua.LState
|
||||
Cancel context.CancelFunc
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
OutputBuffer []byte
|
||||
Exited bool
|
||||
}
|
||||
|
||||
// Execute 执行脚本
|
||||
func (c *MockCoroutine) Execute(script string) error {
|
||||
if c.ExecuteFunc != nil {
|
||||
return c.ExecuteFunc(script)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteFile 执行文件
|
||||
func (c *MockCoroutine) ExecuteFile(path string) error {
|
||||
if c.ExecuteFileFunc != nil {
|
||||
return c.ExecuteFileFunc(path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetupSandbox 设置沙箱
|
||||
func (c *MockCoroutine) SetupSandbox() error {
|
||||
if c.SetupSandboxFunc != nil {
|
||||
return c.SetupSandboxFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭协程
|
||||
func (c *MockCoroutine) Close() {
|
||||
if c.CloseFunc != nil {
|
||||
c.CloseFunc()
|
||||
}
|
||||
}
|
||||
169
internal/server/init.go
Normal file
169
internal/server/init.go
Normal file
@ -0,0 +1,169 @@
|
||||
// Package server 提供服务器初始化函数。
|
||||
//
|
||||
// 该文件包含 Start() 方法中分离出的可测试初始化函数:
|
||||
// - Goroutine 池初始化
|
||||
// - 文件缓存初始化
|
||||
// - Lua 引擎初始化
|
||||
// - 错误页面管理器初始化
|
||||
//
|
||||
// 主要用途:
|
||||
//
|
||||
// 将 Start() 方法中的初始化逻辑分离,便于单元测试
|
||||
//
|
||||
// 注意事项:
|
||||
// - 这些函数仅在 Server.Start() 内部调用
|
||||
// - 分离后保持原逻辑不变,仅提取为独立函数
|
||||
//
|
||||
// 作者:xfy
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/handler"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
)
|
||||
|
||||
// initGoroutinePool 初始化 Goroutine 池。
|
||||
//
|
||||
// 根据配置创建并启动 Goroutine 池,用于优化请求处理性能。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 性能配置,包含 GoroutinePool 配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *GoroutinePool: 初始化的 Goroutine 池,未启用时返回 nil
|
||||
func initGoroutinePool(cfg *config.PerformanceConfig) *GoroutinePool {
|
||||
if cfg == nil || !cfg.GoroutinePool.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
pool := NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: cfg.GoroutinePool.MaxWorkers,
|
||||
MinWorkers: cfg.GoroutinePool.MinWorkers,
|
||||
IdleTimeout: cfg.GoroutinePool.IdleTimeout,
|
||||
})
|
||||
pool.Start()
|
||||
|
||||
logging.Info().
|
||||
Int("maxWorkers", cfg.GoroutinePool.MaxWorkers).
|
||||
Int("minWorkers", cfg.GoroutinePool.MinWorkers).
|
||||
Msg("Goroutine 池已启动")
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// initFileCache 初始化文件缓存。
|
||||
//
|
||||
// 根据配置创建文件缓存实例,用于缓存静态文件内容。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 性能配置,包含 FileCache 配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *cache.FileCache: 初始化的文件缓存,未启用时返回 nil
|
||||
func initFileCache(cfg *config.PerformanceConfig) *cache.FileCache {
|
||||
// 检查是否配置了缓存
|
||||
if cfg == nil || (cfg.FileCache.MaxEntries <= 0 && cfg.FileCache.MaxSize <= 0) {
|
||||
return nil
|
||||
}
|
||||
|
||||
fileCache := cache.NewFileCache(
|
||||
cfg.FileCache.MaxEntries,
|
||||
cfg.FileCache.MaxSize,
|
||||
cfg.FileCache.Inactive,
|
||||
)
|
||||
|
||||
logging.Info().
|
||||
Int64("maxEntries", cfg.FileCache.MaxEntries).
|
||||
Int64("maxSize", cfg.FileCache.MaxSize).
|
||||
Msg("文件缓存已启动")
|
||||
|
||||
return fileCache
|
||||
}
|
||||
|
||||
// initLuaEngine 初始化 Lua 引擎。
|
||||
//
|
||||
// 根据配置创建 Lua 引擎实例,用于执行 Lua 脚本。
|
||||
//
|
||||
// 参数:
|
||||
// - luaCfg: Lua 中间件配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *lua.LuaEngine: 初始化的 Lua 引擎,未启用时返回 nil
|
||||
// - error: 初始化过程中遇到的错误
|
||||
func initLuaEngine(luaCfg *config.LuaMiddlewareConfig) (*lua.LuaEngine, error) {
|
||||
if luaCfg == nil || !luaCfg.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
engineCfg := &lua.Config{
|
||||
MaxConcurrentCoroutines: luaCfg.GlobalSettings.MaxConcurrentCoroutines,
|
||||
CoroutineTimeout: luaCfg.GlobalSettings.CoroutineTimeout,
|
||||
CodeCacheSize: luaCfg.GlobalSettings.CodeCacheSize,
|
||||
CodeCacheTTL: time.Hour, // 默认值
|
||||
EnableFileWatch: luaCfg.GlobalSettings.EnableFileWatch,
|
||||
MaxExecutionTime: luaCfg.GlobalSettings.MaxExecutionTime,
|
||||
EnableOSLib: false, // 安全默认值
|
||||
EnableIOLib: false,
|
||||
EnableLoadLib: false,
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if engineCfg.MaxConcurrentCoroutines == 0 {
|
||||
engineCfg.MaxConcurrentCoroutines = 1000
|
||||
}
|
||||
if engineCfg.CoroutineTimeout == 0 {
|
||||
engineCfg.CoroutineTimeout = 30 * time.Second
|
||||
}
|
||||
if engineCfg.CodeCacheSize == 0 {
|
||||
engineCfg.CodeCacheSize = 1000
|
||||
}
|
||||
if engineCfg.MaxExecutionTime == 0 {
|
||||
engineCfg.MaxExecutionTime = 30 * time.Second
|
||||
}
|
||||
|
||||
engine, err := lua.NewEngine(engineCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化 Lua 引擎失败: %w", err)
|
||||
}
|
||||
|
||||
logging.Info().Msg("Lua 引擎已启动")
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// initErrorPageManager 初始化错误页面管理器。
|
||||
//
|
||||
// 根据配置创建错误页面管理器,用于加载和提供自定义错误页面。
|
||||
//
|
||||
// 参数:
|
||||
// - errorPageCfg: 错误页面配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *handler.ErrorPageManager: 初始化的错误页面管理器,未配置时返回 nil
|
||||
// - error: 初始化过程中遇到的错误
|
||||
func initErrorPageManager(errorPageCfg *config.ErrorPageConfig) (*handler.ErrorPageManager, error) {
|
||||
// 检查是否配置了错误页面
|
||||
if errorPageCfg == nil || (len(errorPageCfg.Pages) == 0 && errorPageCfg.Default == "") {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
manager, err := handler.NewErrorPageManager(errorPageCfg)
|
||||
if err != nil {
|
||||
// 检查是否是部分加载失败
|
||||
if _, ok := err.(*handler.PartialLoadError); ok {
|
||||
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
|
||||
// 返回部分加载的管理器
|
||||
return manager, nil
|
||||
}
|
||||
// 全部加载失败
|
||||
return nil, fmt.Errorf("加载错误页面失败: %w", err)
|
||||
}
|
||||
|
||||
logging.Info().Msg("错误页面管理器已启动")
|
||||
return manager, nil
|
||||
}
|
||||
351
internal/server/init_test.go
Normal file
351
internal/server/init_test.go
Normal file
@ -0,0 +1,351 @@
|
||||
// Package server 提供初始化函数的测试。
|
||||
//
|
||||
// 该文件测试从 Start() 分离出的初始化函数:
|
||||
// - initGoroutinePool()
|
||||
// - initFileCache()
|
||||
// - initLuaEngine()
|
||||
// - initErrorPageManager()
|
||||
//
|
||||
// 主要用途:
|
||||
//
|
||||
// 验证初始化函数在各种配置场景下的行为
|
||||
//
|
||||
// 作者:xfy
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestInitGoroutinePool_Disabled 测试禁用时返回 nil
|
||||
func TestInitGoroutinePool_Disabled(t *testing.T) {
|
||||
cfg := &config.PerformanceConfig{
|
||||
GoroutinePool: config.GoroutinePoolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
pool := initGoroutinePool(cfg)
|
||||
if pool != nil {
|
||||
t.Error("Expected nil when GoroutinePool is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitGoroutinePool_Enabled 测试启用时创建池
|
||||
func TestInitGoroutinePool_Enabled(t *testing.T) {
|
||||
cfg := &config.PerformanceConfig{
|
||||
GoroutinePool: config.GoroutinePoolConfig{
|
||||
Enabled: true,
|
||||
MaxWorkers: 50,
|
||||
MinWorkers: 5,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
pool := initGoroutinePool(cfg)
|
||||
if pool == nil {
|
||||
t.Fatal("Expected non-nil pool when enabled")
|
||||
}
|
||||
|
||||
// 验证配置正确应用
|
||||
if pool.maxWorkers != 50 {
|
||||
t.Errorf("Expected maxWorkers 50, got %d", pool.maxWorkers)
|
||||
}
|
||||
if pool.minWorkers != 5 {
|
||||
t.Errorf("Expected minWorkers 5, got %d", pool.minWorkers)
|
||||
}
|
||||
|
||||
// 清理
|
||||
pool.Stop()
|
||||
}
|
||||
|
||||
// TestInitGoroutinePool_ZeroWorkers 测试零值配置
|
||||
func TestInitGoroutinePool_ZeroWorkers(t *testing.T) {
|
||||
cfg := &config.PerformanceConfig{
|
||||
GoroutinePool: config.GoroutinePoolConfig{
|
||||
Enabled: true,
|
||||
MaxWorkers: 0,
|
||||
MinWorkers: 0,
|
||||
},
|
||||
}
|
||||
|
||||
pool := initGoroutinePool(cfg)
|
||||
if pool == nil {
|
||||
t.Fatal("Expected non-nil pool")
|
||||
}
|
||||
|
||||
// 清理
|
||||
pool.Stop()
|
||||
}
|
||||
|
||||
// TestInitFileCache_Disabled 测试禁用时返回 nil
|
||||
func TestInitFileCache_Disabled(t *testing.T) {
|
||||
cfg := &config.PerformanceConfig{
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 0,
|
||||
MaxSize: 0,
|
||||
},
|
||||
}
|
||||
|
||||
cache := initFileCache(cfg)
|
||||
if cache != nil {
|
||||
t.Error("Expected nil when FileCache is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitFileCache_ByEntries 测试按条目数启用
|
||||
func TestInitFileCache_ByEntries(t *testing.T) {
|
||||
cfg := &config.PerformanceConfig{
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 1000,
|
||||
MaxSize: 0,
|
||||
Inactive: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
cache := initFileCache(cfg)
|
||||
if cache == nil {
|
||||
t.Fatal("Expected non-nil cache when MaxEntries > 0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitFileCache_BySize 测试按大小启用
|
||||
func TestInitFileCache_BySize(t *testing.T) {
|
||||
cfg := &config.PerformanceConfig{
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 0,
|
||||
MaxSize: 100 * 1024 * 1024, // 100MB
|
||||
Inactive: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
cache := initFileCache(cfg)
|
||||
if cache == nil {
|
||||
t.Fatal("Expected non-nil cache when MaxSize > 0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitFileCache_BothLimits 测试同时设置条目数和大小
|
||||
func TestInitFileCache_BothLimits(t *testing.T) {
|
||||
cfg := &config.PerformanceConfig{
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 1000,
|
||||
MaxSize: 100 * 1024 * 1024,
|
||||
Inactive: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
cache := initFileCache(cfg)
|
||||
if cache == nil {
|
||||
t.Fatal("Expected non-nil cache")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitLuaEngine_Disabled 测试禁用时返回 nil
|
||||
func TestInitLuaEngine_Disabled(t *testing.T) {
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
engine, err := initLuaEngine(luaCfg)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
if engine != nil {
|
||||
t.Error("Expected nil when Lua is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitLuaEngine_NilConfig 测试 nil 配置
|
||||
func TestInitLuaEngine_NilConfig(t *testing.T) {
|
||||
engine, err := initLuaEngine(nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
if engine != nil {
|
||||
t.Error("Expected nil for nil config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitLuaEngine_Enabled 测试启用 Lua 引擎
|
||||
func TestInitLuaEngine_Enabled(t *testing.T) {
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
GlobalSettings: config.LuaGlobalSettings{
|
||||
MaxConcurrentCoroutines: 500,
|
||||
CoroutineTimeout: 60 * time.Second,
|
||||
CodeCacheSize: 500,
|
||||
MaxExecutionTime: 60 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
engine, err := initLuaEngine(luaCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if engine == nil {
|
||||
t.Fatal("Expected non-nil engine when enabled")
|
||||
}
|
||||
|
||||
// 清理
|
||||
engine.Close()
|
||||
}
|
||||
|
||||
// TestInitLuaEngine_DefaultValues 测试默认值设置
|
||||
func TestInitLuaEngine_DefaultValues(t *testing.T) {
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
GlobalSettings: config.LuaGlobalSettings{
|
||||
// 所有值为零,应使用默认值
|
||||
},
|
||||
}
|
||||
|
||||
engine, err := initLuaEngine(luaCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if engine == nil {
|
||||
t.Fatal("Expected non-nil engine")
|
||||
}
|
||||
|
||||
// 清理
|
||||
engine.Close()
|
||||
}
|
||||
|
||||
// TestInitErrorPageManager_NoConfig 测试无配置时返回 nil
|
||||
func TestInitErrorPageManager_NoConfig(t *testing.T) {
|
||||
cfg := &config.ErrorPageConfig{
|
||||
Pages: nil,
|
||||
Default: "",
|
||||
}
|
||||
|
||||
manager, err := initErrorPageManager(cfg)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("Expected nil when no error page configured")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitErrorPageManager_NilConfig 测试 nil 配置
|
||||
func TestInitErrorPageManager_NilConfig(t *testing.T) {
|
||||
manager, err := initErrorPageManager(nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("Expected nil for nil config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitErrorPageManager_WithPages 测试配置错误页面
|
||||
func TestInitErrorPageManager_WithPages(t *testing.T) {
|
||||
// 创建临时目录和文件
|
||||
tempDir := t.TempDir()
|
||||
errorPagePath := filepath.Join(tempDir, "404.html")
|
||||
content := []byte("<html><body>Not Found</body></html>")
|
||||
if err := os.WriteFile(errorPagePath, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create error page file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: errorPagePath,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := initErrorPageManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if manager == nil {
|
||||
t.Fatal("Expected non-nil manager when pages configured")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitErrorPageManager_WithDefault 测试配置默认错误页面
|
||||
func TestInitErrorPageManager_WithDefault(t *testing.T) {
|
||||
// 创建临时目录和文件
|
||||
tempDir := t.TempDir()
|
||||
defaultPagePath := filepath.Join(tempDir, "error.html")
|
||||
content := []byte("<html><body>Error</body></html>")
|
||||
if err := os.WriteFile(defaultPagePath, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create error page file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.ErrorPageConfig{
|
||||
Default: defaultPagePath,
|
||||
}
|
||||
|
||||
manager, err := initErrorPageManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
if manager == nil {
|
||||
t.Fatal("Expected non-nil manager when default page configured")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitErrorPageManager_NonExistentFile 测试不存在的文件
|
||||
func TestInitErrorPageManager_NonExistentFile(t *testing.T) {
|
||||
cfg := &config.ErrorPageConfig{
|
||||
Default: "/nonexistent/path/error.html",
|
||||
}
|
||||
|
||||
manager, err := initErrorPageManager(cfg)
|
||||
// 应该返回错误
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent file")
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("Expected nil manager when file doesn't exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitErrorPageManager_PartialLoad 测试部分加载(部分文件存在)
|
||||
func TestInitErrorPageManager_PartialLoad(t *testing.T) {
|
||||
// 创建临时目录和一个存在的文件
|
||||
tempDir := t.TempDir()
|
||||
existingPath := filepath.Join(tempDir, "404.html")
|
||||
content := []byte("<html><body>Not Found</body></html>")
|
||||
if err := os.WriteFile(existingPath, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create error page file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: existingPath,
|
||||
500: "/nonexistent/path/500.html", // 不存在的文件
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := initErrorPageManager(cfg)
|
||||
// 部分加载应该成功返回 manager,但可能有警告
|
||||
// 注意:具体行为取决于 handler.NewErrorPageManager 的实现
|
||||
if manager == nil {
|
||||
t.Logf("Manager is nil with error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitFunctions_NilPerformanceConfig 测试 nil PerformanceConfig
|
||||
func TestInitFunctions_NilPerformanceConfig(t *testing.T) {
|
||||
// 这个测试验证函数能正确处理空配置
|
||||
var cfg *config.PerformanceConfig
|
||||
|
||||
// 应该能处理 nil 而不 panic
|
||||
pool := initGoroutinePool(cfg)
|
||||
if pool != nil {
|
||||
t.Error("Expected nil for nil config")
|
||||
}
|
||||
|
||||
cache := initFileCache(cfg)
|
||||
if cache != nil {
|
||||
t.Error("Expected nil for nil config")
|
||||
}
|
||||
}
|
||||
372
internal/server/lua_integration_test.go
Normal file
372
internal/server/lua_integration_test.go
Normal file
@ -0,0 +1,372 @@
|
||||
// Package server 提供 buildLuaMiddlewares 的 Mock 测试
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
)
|
||||
|
||||
// TestBuildLuaMiddlewares_NilEngine 测试 LuaEngine 为 nil 时
|
||||
func TestBuildLuaMiddlewares_NilEngine(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
// 确保 luaEngine 为 nil
|
||||
s.luaEngine = nil
|
||||
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: "/test/script.lua",
|
||||
Phase: "rewrite",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, middlewares)
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_InvalidPhase 测试无效阶段
|
||||
func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
// 使用真实的 LuaEngine 来测试无效阶段
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: "/test/script.lua",
|
||||
Phase: "invalid_phase",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err = s.buildLuaMiddlewares(luaCfg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "无效的阶段")
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_WithTimeout 测试超时配置
|
||||
func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
// 测试自定义超时
|
||||
customTimeout := 60 * time.Second
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: "/test/script.lua",
|
||||
Phase: "rewrite",
|
||||
Enabled: true,
|
||||
Timeout: customTimeout,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, middlewares, 1)
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_EmptyScripts 测试空脚本列表
|
||||
func TestBuildLuaMiddlewares_EmptyScripts(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{},
|
||||
}
|
||||
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, middlewares)
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_DisabledLua 测试禁用 Lua 配置
|
||||
// buildLuaMiddlewares 函数本身不检查 Enabled 字段,由调用者检查
|
||||
func TestBuildLuaMiddlewares_DisabledLua(t *testing.T) {
|
||||
// 创建临时脚本文件
|
||||
tempDir := t.TempDir()
|
||||
scriptPath := tempDir + "/test.lua"
|
||||
err := os.WriteFile(scriptPath, []byte("-- test"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
// Enabled=false 但 buildLuaMiddlewares 本身不检查这个字段
|
||||
// 它会正常处理 Scripts 中的脚本
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: false,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: scriptPath,
|
||||
Phase: "rewrite",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 由于 Enabled 检查在调用者处,buildLuaMiddlewares 会正常执行
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
// 脚本文件存在,应该能创建中间件
|
||||
assert.Len(t, middlewares, 1)
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_DisabledScript 测试禁用特定脚本
|
||||
func TestBuildLuaMiddlewares_DisabledScript(t *testing.T) {
|
||||
// 创建临时脚本文件(只有一个启用的脚本)
|
||||
tempDir := t.TempDir()
|
||||
enabledScriptPath := tempDir + "/enabled.lua"
|
||||
err := os.WriteFile(enabledScriptPath, []byte("-- enabled script"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: enabledScriptPath,
|
||||
Phase: "rewrite",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
Path: "/test/disabled.lua",
|
||||
Phase: "access",
|
||||
Enabled: false, // 禁用的脚本应该被过滤
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 只有一个启用的脚本,应该能正常创建中间件
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
// 只有 rewrite 阶段的脚本被创建
|
||||
assert.Len(t, middlewares, 1)
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_WithExistingScript 测试使用存在的脚本文件
|
||||
func TestBuildLuaMiddlewares_WithExistingScript(t *testing.T) {
|
||||
// 创建临时脚本文件
|
||||
scriptContent := `-- test script
|
||||
ngx.var.uri = "/test"
|
||||
`
|
||||
tempDir := t.TempDir()
|
||||
scriptPath := tempDir + "/test.lua"
|
||||
err := os.WriteFile(scriptPath, []byte(scriptContent), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: scriptPath,
|
||||
Phase: "rewrite",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, middlewares, 1)
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_MultiplePhases 测试多个不同阶段的脚本
|
||||
func TestBuildLuaMiddlewares_MultiplePhases(t *testing.T) {
|
||||
// 创建临时脚本文件
|
||||
tempDir := t.TempDir()
|
||||
|
||||
scripts := map[string]string{
|
||||
"rewrite.lua": `-- rewrite script
|
||||
ngx.var.uri = "/rewrite"
|
||||
`,
|
||||
"access.lua": `-- access script
|
||||
-- access control logic
|
||||
`,
|
||||
"content.lua": `-- content script
|
||||
ngx.say("hello")
|
||||
`,
|
||||
}
|
||||
|
||||
scriptPaths := make(map[string]string)
|
||||
for name, content := range scripts {
|
||||
path := tempDir + "/" + name
|
||||
err := os.WriteFile(path, []byte(content), 0o644)
|
||||
require.NoError(t, err)
|
||||
scriptPaths[name] = path
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: scriptPaths["rewrite.lua"],
|
||||
Phase: "rewrite",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
Path: scriptPaths["access.lua"],
|
||||
Phase: "access",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
Path: scriptPaths["content.lua"],
|
||||
Phase: "content",
|
||||
Enabled: true,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
// 三个阶段应该创建三个中间件
|
||||
assert.Len(t, middlewares, 3)
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_DefaultTimeout 测试默认超时值
|
||||
func TestBuildLuaMiddlewares_DefaultTimeout(t *testing.T) {
|
||||
// 创建临时脚本文件
|
||||
tempDir := t.TempDir()
|
||||
scriptPath := tempDir + "/test.lua"
|
||||
err := os.WriteFile(scriptPath, []byte("-- test"), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
s.luaEngine = engine
|
||||
|
||||
// 不设置 Timeout,使用默认值
|
||||
luaCfg := &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
Scripts: []config.LuaScriptConfig{
|
||||
{
|
||||
Path: scriptPath,
|
||||
Phase: "rewrite",
|
||||
Enabled: true,
|
||||
// Timeout 为 0,应该使用默认值 30s
|
||||
Timeout: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, middlewares, 1)
|
||||
}
|
||||
@ -405,73 +405,23 @@ func (s *Server) Start() error {
|
||||
// 记录启动时间
|
||||
s.startTime = time.Now()
|
||||
|
||||
// 启用 GoroutinePool(如果配置)
|
||||
if s.config.Performance.GoroutinePool.Enabled {
|
||||
s.pool = NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: s.config.Performance.GoroutinePool.MaxWorkers,
|
||||
MinWorkers: s.config.Performance.GoroutinePool.MinWorkers,
|
||||
IdleTimeout: s.config.Performance.GoroutinePool.IdleTimeout,
|
||||
})
|
||||
s.pool.Start()
|
||||
// 初始化 GoroutinePool
|
||||
s.pool = initGoroutinePool(&s.config.Performance)
|
||||
|
||||
// 初始化文件缓存
|
||||
s.fileCache = initFileCache(&s.config.Performance)
|
||||
|
||||
// 初始化错误页面管理器
|
||||
var err error
|
||||
s.errorPageManager, err = initErrorPageManager(&s.config.Server.Security.ErrorPage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 启用文件缓存(如果配置)
|
||||
if s.config.Performance.FileCache.MaxEntries > 0 || s.config.Performance.FileCache.MaxSize > 0 {
|
||||
s.fileCache = cache.NewFileCache(
|
||||
s.config.Performance.FileCache.MaxEntries,
|
||||
s.config.Performance.FileCache.MaxSize,
|
||||
s.config.Performance.FileCache.Inactive,
|
||||
)
|
||||
}
|
||||
|
||||
// 预加载错误页面(如果配置)
|
||||
if s.config.Server.Security.ErrorPage.Pages != nil || s.config.Server.Security.ErrorPage.Default != "" {
|
||||
var err error
|
||||
s.errorPageManager, err = handler.NewErrorPageManager(&s.config.Server.Security.ErrorPage)
|
||||
if err != nil {
|
||||
// 检查是否是部分加载失败
|
||||
if _, ok := err.(*handler.PartialLoadError); ok {
|
||||
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
|
||||
} else {
|
||||
// 全部加载失败,阻止启动
|
||||
return fmt.Errorf("加载错误页面失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化 Lua 引擎(如果配置)
|
||||
if s.config.Server.Lua != nil && s.config.Server.Lua.Enabled {
|
||||
engineCfg := &lua.Config{
|
||||
MaxConcurrentCoroutines: s.config.Server.Lua.GlobalSettings.MaxConcurrentCoroutines,
|
||||
CoroutineTimeout: s.config.Server.Lua.GlobalSettings.CoroutineTimeout,
|
||||
CodeCacheSize: s.config.Server.Lua.GlobalSettings.CodeCacheSize,
|
||||
CodeCacheTTL: time.Hour, // 默认值
|
||||
EnableFileWatch: s.config.Server.Lua.GlobalSettings.EnableFileWatch,
|
||||
MaxExecutionTime: s.config.Server.Lua.GlobalSettings.MaxExecutionTime,
|
||||
EnableOSLib: false, // 安全默认值
|
||||
EnableIOLib: false,
|
||||
EnableLoadLib: false,
|
||||
}
|
||||
// 设置默认值
|
||||
if engineCfg.MaxConcurrentCoroutines == 0 {
|
||||
engineCfg.MaxConcurrentCoroutines = 1000
|
||||
}
|
||||
if engineCfg.CoroutineTimeout == 0 {
|
||||
engineCfg.CoroutineTimeout = 30 * time.Second
|
||||
}
|
||||
if engineCfg.CodeCacheSize == 0 {
|
||||
engineCfg.CodeCacheSize = 1000
|
||||
}
|
||||
if engineCfg.MaxExecutionTime == 0 {
|
||||
engineCfg.MaxExecutionTime = 30 * time.Second
|
||||
}
|
||||
|
||||
var err error
|
||||
s.luaEngine, err = lua.NewEngine(engineCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化 Lua 引擎失败: %w", err)
|
||||
}
|
||||
logging.Info().Msg("Lua 引擎已启动")
|
||||
// 初始化 Lua 引擎
|
||||
s.luaEngine, err = initLuaEngine(s.config.Server.Lua)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.config.HasServers() {
|
||||
|
||||
@ -769,3 +769,32 @@ func TestGetTLSConfig_Nil(t *testing.T) {
|
||||
t.Error("GetTLSConfig() should return nil when TLS not configured")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTLSConfig_NilServer 测试 nil 服务器调用 GetTLSConfig
|
||||
func TestGetTLSConfig_NilServer(t *testing.T) {
|
||||
var s *Server
|
||||
// 防御性:如果 s 为 nil,调用方法会 panic,这是预期的行为
|
||||
// 这里我们只测试非 nil 但 tlsManager 为 nil 的情况
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":0",
|
||||
},
|
||||
}
|
||||
s = New(cfg)
|
||||
|
||||
// 确保 tlsManager 为 nil
|
||||
if s.tlsManager != nil {
|
||||
t.Skip("tlsManager should be nil initially")
|
||||
}
|
||||
|
||||
tlsCfg, err := s.GetTLSConfig()
|
||||
if err == nil {
|
||||
t.Error("Expected error when tlsManager is nil")
|
||||
}
|
||||
if tlsCfg != nil {
|
||||
t.Error("Expected nil TLS config when tlsManager is nil")
|
||||
}
|
||||
if err.Error() != "TLS not configured" {
|
||||
t.Errorf("Expected error 'TLS not configured', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
556
internal/server/start_integration_test.go
Normal file
556
internal/server/start_integration_test.go
Normal file
@ -0,0 +1,556 @@
|
||||
// Package server 提供 Server.Start() 的集成测试。
|
||||
//
|
||||
// 该文件使用 mock_backend 模拟上游服务,测试完整的服务器启动流程:
|
||||
// - 服务器配置初始化
|
||||
// - 代理路由注册
|
||||
// - 静态文件服务
|
||||
// - 中间件链构建
|
||||
// - 请求处理和转发
|
||||
//
|
||||
// 主要用途:
|
||||
//
|
||||
// 验证服务器启动和请求处理的完整流程
|
||||
//
|
||||
// 作者:xfy
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/benchmark/tools"
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestStart_Integration 测试完整的服务器启动和请求处理流程
|
||||
func TestStart_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// 启动 mock 上游服务器
|
||||
backendAddr, cleanup := tools.SimpleMockBackend(
|
||||
fasthttp.StatusOK,
|
||||
[]byte(`{"message": "Hello from backend"}`),
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
// 使用随机端口避免冲突
|
||||
serverAddr := "127.0.0.1:0"
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: serverAddr,
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://" + backendAddr, Weight: 1},
|
||||
},
|
||||
HealthCheck: config.HealthCheckConfig{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
// 验证服务器初始化
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 测试初始状态
|
||||
if s.running {
|
||||
t.Error("Server should not be running initially")
|
||||
}
|
||||
|
||||
// 测试 GetListeners 初始状态
|
||||
listeners := s.GetListeners()
|
||||
if listeners != nil {
|
||||
t.Error("Listeners should be nil before Start")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithSecurity 测试安全配置
|
||||
func TestStart_WithSecurity(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Security: config.SecurityConfig{
|
||||
Access: config.AccessConfig{
|
||||
Allow: []string{"127.0.0.1"},
|
||||
Deny: []string{},
|
||||
},
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 200,
|
||||
},
|
||||
Headers: config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
XContentTypeOptions: "nosniff",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证安全配置
|
||||
if len(s.config.Server.Security.Access.Allow) != 1 {
|
||||
t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Server.Security.Access.Allow))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithRewrite 测试 URL 重写配置
|
||||
func TestStart_WithRewrite(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Rewrite: []config.RewriteRule{
|
||||
{
|
||||
Pattern: "/old/(.*)",
|
||||
Replacement: "/new/$1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证重写配置
|
||||
if len(s.config.Server.Rewrite) != 1 {
|
||||
t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Server.Rewrite))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithMonitoring 测试监控配置
|
||||
func TestStart_WithMonitoring(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
},
|
||||
Monitoring: config.MonitoringConfig{
|
||||
Status: config.StatusConfig{
|
||||
Path: "/status",
|
||||
Allow: []string{"127.0.0.1"},
|
||||
},
|
||||
Pprof: config.PprofConfig{
|
||||
Enabled: false,
|
||||
Path: "/debug/pprof",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证监控配置
|
||||
if s.config.Monitoring.Status.Path != "/status" {
|
||||
t.Errorf("Expected status path '/status', got '%s'", s.config.Monitoring.Status.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithErrorPage 测试错误页面配置
|
||||
func TestStart_WithErrorPage(t *testing.T) {
|
||||
// 创建临时错误页面文件
|
||||
tempDir := t.TempDir()
|
||||
errorPagePath := tempDir + "/404.html"
|
||||
|
||||
// 创建错误页面文件
|
||||
content := []byte("<html><body>Not Found</body></html>")
|
||||
if err := writeFile(errorPagePath, content); err != nil {
|
||||
t.Fatalf("Failed to create error page file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Security: config.SecurityConfig{
|
||||
ErrorPage: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: errorPagePath,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证错误页面配置
|
||||
if s.config.Server.Security.ErrorPage.Pages == nil {
|
||||
t.Error("Error page pages should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithLuaEnabled 测试 Lua 配置
|
||||
func TestStart_WithLuaEnabled(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Lua: &config.LuaMiddlewareConfig{
|
||||
Enabled: true,
|
||||
GlobalSettings: config.LuaGlobalSettings{
|
||||
MaxConcurrentCoroutines: 100,
|
||||
CoroutineTimeout: 30 * time.Second,
|
||||
CodeCacheSize: 100,
|
||||
MaxExecutionTime: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证 Lua 配置
|
||||
if s.config.Server.Lua == nil || !s.config.Server.Lua.Enabled {
|
||||
t.Error("Lua should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithMultipleProxies 测试多个代理配置
|
||||
func TestStart_WithMultipleProxies(t *testing.T) {
|
||||
// 启动多个 mock 上游服务器
|
||||
backend1, cleanup1 := tools.SimpleMockBackend(
|
||||
fasthttp.StatusOK,
|
||||
[]byte(`{"service": "api1"}`),
|
||||
)
|
||||
defer cleanup1()
|
||||
|
||||
backend2, cleanup2 := tools.SimpleMockBackend(
|
||||
fasthttp.StatusOK,
|
||||
[]byte(`{"service": "api2"}`),
|
||||
)
|
||||
defer cleanup2()
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api1",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://" + backend1, Weight: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
Path: "/api2",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://" + backend2, Weight: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证代理配置
|
||||
if len(s.config.Server.Proxy) != 2 {
|
||||
t.Errorf("Expected 2 proxies, got %d", len(s.config.Server.Proxy))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_EmptyConfig 测试空配置
|
||||
func TestStart_EmptyConfig(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 空配置应该能正常初始化
|
||||
if s.config == nil {
|
||||
t.Error("Config should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithAllFeatures 测试启用所有功能的配置
|
||||
func TestStart_WithAllFeatures(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
errorPagePath := tempDir + "/404.html"
|
||||
writeFile(errorPagePath, []byte("<html><body>Not Found</body></html>"))
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Static: []config.StaticConfig{
|
||||
{
|
||||
Path: "/static",
|
||||
Root: tempDir,
|
||||
Index: []string{"index.html"},
|
||||
},
|
||||
},
|
||||
Compression: config.CompressionConfig{
|
||||
Type: "gzip",
|
||||
Level: 6,
|
||||
},
|
||||
Security: config.SecurityConfig{
|
||||
Access: config.AccessConfig{
|
||||
Allow: []string{"127.0.0.1"},
|
||||
},
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 200,
|
||||
},
|
||||
Headers: config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
},
|
||||
ErrorPage: config.ErrorPageConfig{
|
||||
Default: errorPagePath,
|
||||
},
|
||||
},
|
||||
Rewrite: []config.RewriteRule{
|
||||
{
|
||||
Pattern: "/old/(.*)",
|
||||
Replacement: "/new/$1",
|
||||
},
|
||||
},
|
||||
},
|
||||
Performance: config.PerformanceConfig{
|
||||
GoroutinePool: config.GoroutinePoolConfig{
|
||||
Enabled: true,
|
||||
MaxWorkers: 50,
|
||||
MinWorkers: 10,
|
||||
},
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 1000,
|
||||
MaxSize: 100 * 1024 * 1024,
|
||||
},
|
||||
},
|
||||
Monitoring: config.MonitoringConfig{
|
||||
Status: config.StatusConfig{
|
||||
Path: "/status",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证所有配置
|
||||
if !s.config.Performance.GoroutinePool.Enabled {
|
||||
t.Error("GoroutinePool should be enabled")
|
||||
}
|
||||
if s.config.Server.Compression.Type != "gzip" {
|
||||
t.Error("Compression should be gzip")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_ServerOptions 测试服务器配置选项
|
||||
func TestStart_ServerOptions(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
MaxConnsPerIP: 100,
|
||||
MaxRequestsPerConn: 1000,
|
||||
ClientMaxBodySize: "10MB",
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证服务器选项
|
||||
if s.config.Server.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Server.ReadTimeout)
|
||||
}
|
||||
if s.config.Server.MaxConnsPerIP != 100 {
|
||||
t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Server.MaxConnsPerIP)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_HealthCheckConfig 测试健康检查配置
|
||||
func TestStart_HealthCheckConfig(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://127.0.0.1:8081", Weight: 1},
|
||||
},
|
||||
HealthCheck: config.HealthCheckConfig{
|
||||
Interval: 10 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
Path: "/health",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证健康检查配置
|
||||
if s.config.Server.Proxy[0].HealthCheck.Path != "/health" {
|
||||
t.Error("Health check path should be /health")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_VHostMode 测试虚拟主机模式配置
|
||||
func TestStart_VHostMode(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{
|
||||
{
|
||||
Name: "api.example.com",
|
||||
Listen: "127.0.0.1:0",
|
||||
},
|
||||
{
|
||||
Name: "www.example.com",
|
||||
Listen: "127.0.0.1:0",
|
||||
},
|
||||
},
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证虚拟主机配置
|
||||
if !s.config.HasServers() {
|
||||
t.Error("Should detect virtual hosts")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithProxyBackendError 测试代理后端错误处理
|
||||
func TestStart_WithProxyBackendError(t *testing.T) {
|
||||
// 启动返回错误的 mock 服务器
|
||||
backendAddr, cleanup := tools.ErrorMockBackend(1.0, []byte(`{"error": "backend error"}`))
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://" + backendAddr, Weight: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
|
||||
// 验证代理配置
|
||||
if len(s.config.Server.Proxy) != 1 {
|
||||
t.Errorf("Expected 1 proxy, got %d", len(s.config.Server.Proxy))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithDelayedBackend 测试延迟后端
|
||||
func TestStart_WithDelayedBackend(t *testing.T) {
|
||||
// 启动延迟的 mock 服务器
|
||||
backendAddr, cleanup := tools.DelayedMockBackend(
|
||||
100*time.Millisecond,
|
||||
[]byte(`{"message": "delayed response"}`),
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://" + backendAddr, Weight: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithRandomResponse 测试随机响应后端
|
||||
func TestStart_WithRandomResponse(t *testing.T) {
|
||||
// 启动随机响应的 mock 服务器
|
||||
backendAddr, cleanup := tools.StartMockFasthttpBackend(tools.MockBackendConfig{
|
||||
Mode: tools.ModeRandomResponse,
|
||||
StatusCode: fasthttp.StatusOK,
|
||||
Body: []byte(`{"random": true}`),
|
||||
})
|
||||
defer cleanup()
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://" + backendAddr, Weight: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("Expected non-nil server")
|
||||
}
|
||||
}
|
||||
|
||||
// writeFile 辅助函数:写入文件
|
||||
func writeFile(path string, content []byte) error {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
_, err = f.Write(content)
|
||||
return err
|
||||
}
|
||||
121
internal/server/testutil.go
Normal file
121
internal/server/testutil.go
Normal file
@ -0,0 +1,121 @@
|
||||
// Package server 提供测试工具函数和依赖注入支持
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
"rua.plus/lolly/internal/ssl"
|
||||
)
|
||||
|
||||
// MockFastServer 是 fasthttp.Server 的 Mock 包装
|
||||
// 定义在此文件以便 TestServerOptions 可以引用
|
||||
type MockFastServer struct {
|
||||
Name string
|
||||
Handler fasthttp.RequestHandler
|
||||
TLSConfig *tls.Config
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
MaxConnsPerIP int
|
||||
MaxRequestsPerConn int
|
||||
CloseOnShutdown bool
|
||||
ServeFunc func(ln net.Listener) error
|
||||
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
|
||||
ShutdownFunc func() error
|
||||
}
|
||||
|
||||
// Serve 启动服务
|
||||
func (m *MockFastServer) Serve(ln net.Listener) error {
|
||||
if m.ServeFunc != nil {
|
||||
return m.ServeFunc(ln)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeTLS 启动 TLS 服务
|
||||
func (m *MockFastServer) ServeTLS(ln net.Listener, certFile, keyFile string) error {
|
||||
if m.ServeTLSFunc != nil {
|
||||
return m.ServeTLSFunc(ln, certFile, keyFile)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务器
|
||||
func (m *MockFastServer) Shutdown() error {
|
||||
if m.ShutdownFunc != nil {
|
||||
return m.ShutdownFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestDependencies 包含测试时可注入的依赖
|
||||
// 使用具体指针类型,允许注入 Mock 实现
|
||||
type TestDependencies struct {
|
||||
LuaEngine *lua.LuaEngine
|
||||
TLSManager *ssl.TLSManager
|
||||
}
|
||||
|
||||
// NewServerForTesting 创建用于测试的服务器实例
|
||||
// 允许注入 Mock 依赖,不改变生产 API
|
||||
func NewServerForTesting(cfg *config.Config, deps *TestDependencies) *Server {
|
||||
s := New(cfg)
|
||||
if deps != nil {
|
||||
if deps.LuaEngine != nil {
|
||||
s.luaEngine = deps.LuaEngine
|
||||
}
|
||||
if deps.TLSManager != nil {
|
||||
s.tlsManager = deps.TLSManager
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// TestServerOptions 测试服务器的可选配置
|
||||
type TestServerOptions struct {
|
||||
SkipListener bool
|
||||
MockFastServer *MockFastServer
|
||||
CustomHandler fasthttp.RequestHandler
|
||||
DisableMiddleware bool
|
||||
}
|
||||
|
||||
// NewTestServerWithOptions 使用选项创建测试服务器
|
||||
func NewTestServerWithOptions(cfg *config.Config, opts *TestServerOptions) *Server {
|
||||
s := New(cfg)
|
||||
|
||||
if opts != nil {
|
||||
// 可以在这里应用各种测试选项
|
||||
if opts.CustomHandler != nil {
|
||||
s.handler = opts.CustomHandler
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// MustStartTestServer 启动测试服务器,失败时 panic
|
||||
// 主要用于集成测试
|
||||
func MustStartTestServer(cfg *config.Config) *Server {
|
||||
s := New(cfg)
|
||||
// 在测试环境中使用随机端口避免冲突
|
||||
if cfg.Server.Listen == "" || cfg.Server.Listen == ":80" {
|
||||
cfg.Server.Listen = "127.0.0.1:0"
|
||||
}
|
||||
|
||||
// 使用 goroutine 启动服务器以避免阻塞
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
// 测试服务器启动失败记录日志
|
||||
panic("failed to start test server: " + err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
// 给服务器一点时间启动
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
return s
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user