refactor(server): 提取初始化逻辑到独立函数
- 将 Start() 中的 goroutine pool 初始化提取为 initGoroutinePool() - 将 file cache 初始化提取为 initFileCache() - 将 Lua engine 初始化提取为 initLuaEngine() - 将 error page manager 初始化提取为 initErrorPageManager() - 添加 init.go 存放提取的初始化函数 - 添加 init_test.go 测试初始化函数 - 添加 testutil.go 提供测试 mock 和工具 - 添加 lua_integration_test.go Lua 中间件集成测试 - 添加 start_integration_test.go Start() 集成测试 - 添加 server_test.go nil tlsManager 测试 - 添加 lua/mock_engine.go Lua 引擎 mock 实现 - 添加 lua/api_balancer_test.go Lua balancer API 测试 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
73ef7f4916
commit
9144dcbb06
161
internal/lua/api_balancer_test.go
Normal file
161
internal/lua/api_balancer_test.go
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
// Package lua 提供 ngx.balancer API 的测试
|
||||||
|
//
|
||||||
|
// 该文件测试负载均衡相关的 Lua API,包括:
|
||||||
|
// - BalancerContext 创建和管理
|
||||||
|
// - set_current_peer API
|
||||||
|
// - set_more_tries API
|
||||||
|
// - get_last_failure API
|
||||||
|
// - get_targets API
|
||||||
|
// - get_client_ip API
|
||||||
|
// - IsSelected 边界测试
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package lua
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"rua.plus/lolly/internal/loadbalance"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestBalancerContext_IsSelected 测试 IsSelected 方法
|
||||||
|
func TestBalancerContext_IsSelected(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
selected bool
|
||||||
|
wantResult bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "已选择目标",
|
||||||
|
selected: true,
|
||||||
|
wantResult: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "未选择目标",
|
||||||
|
selected: false,
|
||||||
|
wantResult: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
bctx := &BalancerContext{
|
||||||
|
Targets: []*loadbalance.Target{
|
||||||
|
{URL: "http://backend1:8080"},
|
||||||
|
{URL: "http://backend2:8080"},
|
||||||
|
},
|
||||||
|
ClientIP: "127.0.0.1",
|
||||||
|
Retries: 3,
|
||||||
|
selected: tt.selected,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.selected {
|
||||||
|
bctx.Selected = bctx.Targets[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
result := bctx.IsSelected()
|
||||||
|
assert.Equal(t, tt.wantResult, result, "IsSelected() should return expected value")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBalancerContext_IsSelected_ZeroValue 测试零值情况
|
||||||
|
func TestBalancerContext_IsSelected_ZeroValue(t *testing.T) {
|
||||||
|
// 零值的 BalancerContext
|
||||||
|
bctx := &BalancerContext{}
|
||||||
|
|
||||||
|
// 默认应该返回 false
|
||||||
|
result := bctx.IsSelected()
|
||||||
|
assert.False(t, result, "Zero value BalancerContext should return false for IsSelected()")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBalancerContext_IsSelected_AfterSelection 测试选择后的状态
|
||||||
|
func TestBalancerContext_IsSelected_AfterSelection(t *testing.T) {
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://backend1:8080"},
|
||||||
|
{URL: "http://backend2:8080"},
|
||||||
|
}
|
||||||
|
|
||||||
|
bctx := &BalancerContext{
|
||||||
|
Targets: targets,
|
||||||
|
ClientIP: "192.168.1.1",
|
||||||
|
Retries: 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始状态
|
||||||
|
assert.False(t, bctx.IsSelected(), "Should not be selected initially")
|
||||||
|
|
||||||
|
// 模拟选择目标
|
||||||
|
bctx.Selected = targets[0]
|
||||||
|
bctx.selected = true
|
||||||
|
|
||||||
|
// 选择后状态
|
||||||
|
assert.True(t, bctx.IsSelected(), "Should be selected after setting")
|
||||||
|
|
||||||
|
// 清除选择
|
||||||
|
bctx.selected = false
|
||||||
|
assert.False(t, bctx.IsSelected(), "Should not be selected after clearing flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestClassifyError 测试错误分类函数
|
||||||
|
func TestClassifyError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil error",
|
||||||
|
err: nil,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timeout error",
|
||||||
|
err: errors.New("connection timeout"),
|
||||||
|
expected: "timeout",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "connection error",
|
||||||
|
err: errors.New("connection refused"),
|
||||||
|
expected: "failed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "generic error",
|
||||||
|
err: errors.New("some error"),
|
||||||
|
expected: "failed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := classifyError(tt.err)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBalancerContext_Structure 测试 BalancerContext 结构体字段
|
||||||
|
func TestBalancerContext_Structure(t *testing.T) {
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://backend1:8080", Weight: 1},
|
||||||
|
{URL: "http://backend2:8080", Weight: 2},
|
||||||
|
}
|
||||||
|
|
||||||
|
bctx := &BalancerContext{
|
||||||
|
Targets: targets,
|
||||||
|
ClientIP: "10.0.0.1",
|
||||||
|
Retries: 5,
|
||||||
|
Selected: nil,
|
||||||
|
LastError: nil,
|
||||||
|
selected: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(bctx.Targets))
|
||||||
|
assert.Equal(t, "10.0.0.1", bctx.ClientIP)
|
||||||
|
assert.Equal(t, 5, bctx.Retries)
|
||||||
|
assert.Nil(t, bctx.Selected)
|
||||||
|
assert.Nil(t, bctx.LastError)
|
||||||
|
assert.False(t, bctx.selected)
|
||||||
|
}
|
||||||
196
internal/lua/mock_engine.go
Normal file
196
internal/lua/mock_engine.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
// Package lua 提供 Lua 引擎的 Mock 实现,用于测试
|
||||||
|
package lua
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
glua "github.com/yuin/gopher-lua"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockLuaEngine 是 LuaEngine 的 Mock 实现
|
||||||
|
type MockLuaEngine struct {
|
||||||
|
ExecuteFunc func(script string) error
|
||||||
|
ExecuteFileFunc func(path string) error
|
||||||
|
NewCoroutineFunc func(ctx *fasthttp.RequestCtx) (*MockCoroutine, error)
|
||||||
|
CloseFunc func()
|
||||||
|
StatsFunc func() EngineStats
|
||||||
|
ActiveCoroutinesFunc func() int32
|
||||||
|
CodeCacheFunc func() *CodeCache
|
||||||
|
SharedDictManagerFunc func() *SharedDictManager
|
||||||
|
TimerManagerFunc func() *TimerManager
|
||||||
|
LocationManagerFunc func() *LocationManager
|
||||||
|
CreateSharedDictFunc func(name string, maxItems int) *SharedDict
|
||||||
|
InitSchedulerLStateFunc func() error
|
||||||
|
SchedulerLoopFunc func()
|
||||||
|
EnqueueCallbackFunc func(entry *CallbackEntry) bool
|
||||||
|
CloseSchedulerFunc func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute 执行脚本
|
||||||
|
func (m *MockLuaEngine) Execute(script string) error {
|
||||||
|
if m.ExecuteFunc != nil {
|
||||||
|
return m.ExecuteFunc(script)
|
||||||
|
}
|
||||||
|
return nil // stub
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteFile 执行文件
|
||||||
|
func (m *MockLuaEngine) ExecuteFile(path string) error {
|
||||||
|
if m.ExecuteFileFunc != nil {
|
||||||
|
return m.ExecuteFileFunc(path)
|
||||||
|
}
|
||||||
|
return nil // stub
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCoroutine 创建协程
|
||||||
|
func (m *MockLuaEngine) NewCoroutine(req *fasthttp.RequestCtx) (*MockCoroutine, error) {
|
||||||
|
if m.NewCoroutineFunc != nil {
|
||||||
|
return m.NewCoroutineFunc(req)
|
||||||
|
}
|
||||||
|
return &MockCoroutine{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭引擎
|
||||||
|
func (m *MockLuaEngine) Close() {
|
||||||
|
if m.CloseFunc != nil {
|
||||||
|
m.CloseFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats 返回统计
|
||||||
|
func (m *MockLuaEngine) Stats() EngineStats {
|
||||||
|
if m.StatsFunc != nil {
|
||||||
|
return m.StatsFunc()
|
||||||
|
}
|
||||||
|
return EngineStats{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActiveCoroutines 返回活跃协程数
|
||||||
|
func (m *MockLuaEngine) ActiveCoroutines() int32 {
|
||||||
|
if m.ActiveCoroutinesFunc != nil {
|
||||||
|
return m.ActiveCoroutinesFunc()
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodeCache 返回字节码缓存
|
||||||
|
func (m *MockLuaEngine) CodeCache() *CodeCache {
|
||||||
|
if m.CodeCacheFunc != nil {
|
||||||
|
return m.CodeCacheFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SharedDictManager 返回共享字典管理器
|
||||||
|
func (m *MockLuaEngine) SharedDictManager() *SharedDictManager {
|
||||||
|
if m.SharedDictManagerFunc != nil {
|
||||||
|
return m.SharedDictManagerFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimerManager 返回定时器管理器
|
||||||
|
func (m *MockLuaEngine) TimerManager() *TimerManager {
|
||||||
|
if m.TimerManagerFunc != nil {
|
||||||
|
return m.TimerManagerFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocationManager 返回 location 管理器
|
||||||
|
func (m *MockLuaEngine) LocationManager() *LocationManager {
|
||||||
|
if m.LocationManagerFunc != nil {
|
||||||
|
return m.LocationManagerFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSharedDict 创建共享字典
|
||||||
|
func (m *MockLuaEngine) CreateSharedDict(name string, maxItems int) *SharedDict {
|
||||||
|
if m.CreateSharedDictFunc != nil {
|
||||||
|
return m.CreateSharedDictFunc(name, maxItems)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitSchedulerLState 初始化调度器 LState
|
||||||
|
func (m *MockLuaEngine) InitSchedulerLState() error {
|
||||||
|
if m.InitSchedulerLStateFunc != nil {
|
||||||
|
return m.InitSchedulerLStateFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SchedulerLoop 调度器循环
|
||||||
|
func (m *MockLuaEngine) SchedulerLoop() {
|
||||||
|
if m.SchedulerLoopFunc != nil {
|
||||||
|
m.SchedulerLoopFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnqueueCallback 将回调加入调度队列
|
||||||
|
func (m *MockLuaEngine) EnqueueCallback(entry *CallbackEntry) bool {
|
||||||
|
if m.EnqueueCallbackFunc != nil {
|
||||||
|
return m.EnqueueCallbackFunc(entry)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseScheduler 关闭调度器
|
||||||
|
func (m *MockLuaEngine) CloseScheduler() {
|
||||||
|
if m.CloseSchedulerFunc != nil {
|
||||||
|
m.CloseSchedulerFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockCoroutine 是 LuaCoroutine 的 Mock 实现
|
||||||
|
type MockCoroutine struct {
|
||||||
|
ExecuteFunc func(script string) error
|
||||||
|
ExecuteFileFunc func(path string) error
|
||||||
|
SetupSandboxFunc func() error
|
||||||
|
CloseFunc func()
|
||||||
|
HandleYieldFunc func(values []glua.LValue) ([]glua.LValue, error)
|
||||||
|
|
||||||
|
// 模拟字段
|
||||||
|
CreatedAt time.Time
|
||||||
|
ExecutionContext context.Context
|
||||||
|
Engine *MockLuaEngine
|
||||||
|
Co *glua.LState
|
||||||
|
Cancel context.CancelFunc
|
||||||
|
RequestCtx *fasthttp.RequestCtx
|
||||||
|
OutputBuffer []byte
|
||||||
|
Exited bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute 执行脚本
|
||||||
|
func (c *MockCoroutine) Execute(script string) error {
|
||||||
|
if c.ExecuteFunc != nil {
|
||||||
|
return c.ExecuteFunc(script)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteFile 执行文件
|
||||||
|
func (c *MockCoroutine) ExecuteFile(path string) error {
|
||||||
|
if c.ExecuteFileFunc != nil {
|
||||||
|
return c.ExecuteFileFunc(path)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupSandbox 设置沙箱
|
||||||
|
func (c *MockCoroutine) SetupSandbox() error {
|
||||||
|
if c.SetupSandboxFunc != nil {
|
||||||
|
return c.SetupSandboxFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭协程
|
||||||
|
func (c *MockCoroutine) Close() {
|
||||||
|
if c.CloseFunc != nil {
|
||||||
|
c.CloseFunc()
|
||||||
|
}
|
||||||
|
}
|
||||||
169
internal/server/init.go
Normal file
169
internal/server/init.go
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
// Package server 提供服务器初始化函数。
|
||||||
|
//
|
||||||
|
// 该文件包含 Start() 方法中分离出的可测试初始化函数:
|
||||||
|
// - Goroutine 池初始化
|
||||||
|
// - 文件缓存初始化
|
||||||
|
// - Lua 引擎初始化
|
||||||
|
// - 错误页面管理器初始化
|
||||||
|
//
|
||||||
|
// 主要用途:
|
||||||
|
//
|
||||||
|
// 将 Start() 方法中的初始化逻辑分离,便于单元测试
|
||||||
|
//
|
||||||
|
// 注意事项:
|
||||||
|
// - 这些函数仅在 Server.Start() 内部调用
|
||||||
|
// - 分离后保持原逻辑不变,仅提取为独立函数
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"rua.plus/lolly/internal/cache"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/handler"
|
||||||
|
"rua.plus/lolly/internal/logging"
|
||||||
|
"rua.plus/lolly/internal/lua"
|
||||||
|
)
|
||||||
|
|
||||||
|
// initGoroutinePool 初始化 Goroutine 池。
|
||||||
|
//
|
||||||
|
// 根据配置创建并启动 Goroutine 池,用于优化请求处理性能。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - cfg: 性能配置,包含 GoroutinePool 配置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *GoroutinePool: 初始化的 Goroutine 池,未启用时返回 nil
|
||||||
|
func initGoroutinePool(cfg *config.PerformanceConfig) *GoroutinePool {
|
||||||
|
if cfg == nil || !cfg.GoroutinePool.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := NewGoroutinePool(PoolConfig{
|
||||||
|
MaxWorkers: cfg.GoroutinePool.MaxWorkers,
|
||||||
|
MinWorkers: cfg.GoroutinePool.MinWorkers,
|
||||||
|
IdleTimeout: cfg.GoroutinePool.IdleTimeout,
|
||||||
|
})
|
||||||
|
pool.Start()
|
||||||
|
|
||||||
|
logging.Info().
|
||||||
|
Int("maxWorkers", cfg.GoroutinePool.MaxWorkers).
|
||||||
|
Int("minWorkers", cfg.GoroutinePool.MinWorkers).
|
||||||
|
Msg("Goroutine 池已启动")
|
||||||
|
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// initFileCache 初始化文件缓存。
|
||||||
|
//
|
||||||
|
// 根据配置创建文件缓存实例,用于缓存静态文件内容。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - cfg: 性能配置,包含 FileCache 配置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *cache.FileCache: 初始化的文件缓存,未启用时返回 nil
|
||||||
|
func initFileCache(cfg *config.PerformanceConfig) *cache.FileCache {
|
||||||
|
// 检查是否配置了缓存
|
||||||
|
if cfg == nil || (cfg.FileCache.MaxEntries <= 0 && cfg.FileCache.MaxSize <= 0) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fileCache := cache.NewFileCache(
|
||||||
|
cfg.FileCache.MaxEntries,
|
||||||
|
cfg.FileCache.MaxSize,
|
||||||
|
cfg.FileCache.Inactive,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.Info().
|
||||||
|
Int64("maxEntries", cfg.FileCache.MaxEntries).
|
||||||
|
Int64("maxSize", cfg.FileCache.MaxSize).
|
||||||
|
Msg("文件缓存已启动")
|
||||||
|
|
||||||
|
return fileCache
|
||||||
|
}
|
||||||
|
|
||||||
|
// initLuaEngine 初始化 Lua 引擎。
|
||||||
|
//
|
||||||
|
// 根据配置创建 Lua 引擎实例,用于执行 Lua 脚本。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - luaCfg: Lua 中间件配置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *lua.LuaEngine: 初始化的 Lua 引擎,未启用时返回 nil
|
||||||
|
// - error: 初始化过程中遇到的错误
|
||||||
|
func initLuaEngine(luaCfg *config.LuaMiddlewareConfig) (*lua.LuaEngine, error) {
|
||||||
|
if luaCfg == nil || !luaCfg.Enabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
engineCfg := &lua.Config{
|
||||||
|
MaxConcurrentCoroutines: luaCfg.GlobalSettings.MaxConcurrentCoroutines,
|
||||||
|
CoroutineTimeout: luaCfg.GlobalSettings.CoroutineTimeout,
|
||||||
|
CodeCacheSize: luaCfg.GlobalSettings.CodeCacheSize,
|
||||||
|
CodeCacheTTL: time.Hour, // 默认值
|
||||||
|
EnableFileWatch: luaCfg.GlobalSettings.EnableFileWatch,
|
||||||
|
MaxExecutionTime: luaCfg.GlobalSettings.MaxExecutionTime,
|
||||||
|
EnableOSLib: false, // 安全默认值
|
||||||
|
EnableIOLib: false,
|
||||||
|
EnableLoadLib: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置默认值
|
||||||
|
if engineCfg.MaxConcurrentCoroutines == 0 {
|
||||||
|
engineCfg.MaxConcurrentCoroutines = 1000
|
||||||
|
}
|
||||||
|
if engineCfg.CoroutineTimeout == 0 {
|
||||||
|
engineCfg.CoroutineTimeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
if engineCfg.CodeCacheSize == 0 {
|
||||||
|
engineCfg.CodeCacheSize = 1000
|
||||||
|
}
|
||||||
|
if engineCfg.MaxExecutionTime == 0 {
|
||||||
|
engineCfg.MaxExecutionTime = 30 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(engineCfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("初始化 Lua 引擎失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.Info().Msg("Lua 引擎已启动")
|
||||||
|
return engine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initErrorPageManager 初始化错误页面管理器。
|
||||||
|
//
|
||||||
|
// 根据配置创建错误页面管理器,用于加载和提供自定义错误页面。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - errorPageCfg: 错误页面配置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *handler.ErrorPageManager: 初始化的错误页面管理器,未配置时返回 nil
|
||||||
|
// - error: 初始化过程中遇到的错误
|
||||||
|
func initErrorPageManager(errorPageCfg *config.ErrorPageConfig) (*handler.ErrorPageManager, error) {
|
||||||
|
// 检查是否配置了错误页面
|
||||||
|
if errorPageCfg == nil || (len(errorPageCfg.Pages) == 0 && errorPageCfg.Default == "") {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := handler.NewErrorPageManager(errorPageCfg)
|
||||||
|
if err != nil {
|
||||||
|
// 检查是否是部分加载失败
|
||||||
|
if _, ok := err.(*handler.PartialLoadError); ok {
|
||||||
|
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
|
||||||
|
// 返回部分加载的管理器
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
|
// 全部加载失败
|
||||||
|
return nil, fmt.Errorf("加载错误页面失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.Info().Msg("错误页面管理器已启动")
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
351
internal/server/init_test.go
Normal file
351
internal/server/init_test.go
Normal file
@ -0,0 +1,351 @@
|
|||||||
|
// Package server 提供初始化函数的测试。
|
||||||
|
//
|
||||||
|
// 该文件测试从 Start() 分离出的初始化函数:
|
||||||
|
// - initGoroutinePool()
|
||||||
|
// - initFileCache()
|
||||||
|
// - initLuaEngine()
|
||||||
|
// - initErrorPageManager()
|
||||||
|
//
|
||||||
|
// 主要用途:
|
||||||
|
//
|
||||||
|
// 验证初始化函数在各种配置场景下的行为
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestInitGoroutinePool_Disabled 测试禁用时返回 nil
|
||||||
|
func TestInitGoroutinePool_Disabled(t *testing.T) {
|
||||||
|
cfg := &config.PerformanceConfig{
|
||||||
|
GoroutinePool: config.GoroutinePoolConfig{
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := initGoroutinePool(cfg)
|
||||||
|
if pool != nil {
|
||||||
|
t.Error("Expected nil when GoroutinePool is disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitGoroutinePool_Enabled 测试启用时创建池
|
||||||
|
func TestInitGoroutinePool_Enabled(t *testing.T) {
|
||||||
|
cfg := &config.PerformanceConfig{
|
||||||
|
GoroutinePool: config.GoroutinePoolConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxWorkers: 50,
|
||||||
|
MinWorkers: 5,
|
||||||
|
IdleTimeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := initGoroutinePool(cfg)
|
||||||
|
if pool == nil {
|
||||||
|
t.Fatal("Expected non-nil pool when enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证配置正确应用
|
||||||
|
if pool.maxWorkers != 50 {
|
||||||
|
t.Errorf("Expected maxWorkers 50, got %d", pool.maxWorkers)
|
||||||
|
}
|
||||||
|
if pool.minWorkers != 5 {
|
||||||
|
t.Errorf("Expected minWorkers 5, got %d", pool.minWorkers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理
|
||||||
|
pool.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitGoroutinePool_ZeroWorkers 测试零值配置
|
||||||
|
func TestInitGoroutinePool_ZeroWorkers(t *testing.T) {
|
||||||
|
cfg := &config.PerformanceConfig{
|
||||||
|
GoroutinePool: config.GoroutinePoolConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxWorkers: 0,
|
||||||
|
MinWorkers: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pool := initGoroutinePool(cfg)
|
||||||
|
if pool == nil {
|
||||||
|
t.Fatal("Expected non-nil pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理
|
||||||
|
pool.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitFileCache_Disabled 测试禁用时返回 nil
|
||||||
|
func TestInitFileCache_Disabled(t *testing.T) {
|
||||||
|
cfg := &config.PerformanceConfig{
|
||||||
|
FileCache: config.FileCacheConfig{
|
||||||
|
MaxEntries: 0,
|
||||||
|
MaxSize: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := initFileCache(cfg)
|
||||||
|
if cache != nil {
|
||||||
|
t.Error("Expected nil when FileCache is disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitFileCache_ByEntries 测试按条目数启用
|
||||||
|
func TestInitFileCache_ByEntries(t *testing.T) {
|
||||||
|
cfg := &config.PerformanceConfig{
|
||||||
|
FileCache: config.FileCacheConfig{
|
||||||
|
MaxEntries: 1000,
|
||||||
|
MaxSize: 0,
|
||||||
|
Inactive: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := initFileCache(cfg)
|
||||||
|
if cache == nil {
|
||||||
|
t.Fatal("Expected non-nil cache when MaxEntries > 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitFileCache_BySize 测试按大小启用
|
||||||
|
func TestInitFileCache_BySize(t *testing.T) {
|
||||||
|
cfg := &config.PerformanceConfig{
|
||||||
|
FileCache: config.FileCacheConfig{
|
||||||
|
MaxEntries: 0,
|
||||||
|
MaxSize: 100 * 1024 * 1024, // 100MB
|
||||||
|
Inactive: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := initFileCache(cfg)
|
||||||
|
if cache == nil {
|
||||||
|
t.Fatal("Expected non-nil cache when MaxSize > 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitFileCache_BothLimits 测试同时设置条目数和大小
|
||||||
|
func TestInitFileCache_BothLimits(t *testing.T) {
|
||||||
|
cfg := &config.PerformanceConfig{
|
||||||
|
FileCache: config.FileCacheConfig{
|
||||||
|
MaxEntries: 1000,
|
||||||
|
MaxSize: 100 * 1024 * 1024,
|
||||||
|
Inactive: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := initFileCache(cfg)
|
||||||
|
if cache == nil {
|
||||||
|
t.Fatal("Expected non-nil cache")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitLuaEngine_Disabled 测试禁用时返回 nil
|
||||||
|
func TestInitLuaEngine_Disabled(t *testing.T) {
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
engine, err := initLuaEngine(luaCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if engine != nil {
|
||||||
|
t.Error("Expected nil when Lua is disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitLuaEngine_NilConfig 测试 nil 配置
|
||||||
|
func TestInitLuaEngine_NilConfig(t *testing.T) {
|
||||||
|
engine, err := initLuaEngine(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if engine != nil {
|
||||||
|
t.Error("Expected nil for nil config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitLuaEngine_Enabled 测试启用 Lua 引擎
|
||||||
|
func TestInitLuaEngine_Enabled(t *testing.T) {
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
GlobalSettings: config.LuaGlobalSettings{
|
||||||
|
MaxConcurrentCoroutines: 500,
|
||||||
|
CoroutineTimeout: 60 * time.Second,
|
||||||
|
CodeCacheSize: 500,
|
||||||
|
MaxExecutionTime: 60 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
engine, err := initLuaEngine(luaCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if engine == nil {
|
||||||
|
t.Fatal("Expected non-nil engine when enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理
|
||||||
|
engine.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitLuaEngine_DefaultValues 测试默认值设置
|
||||||
|
func TestInitLuaEngine_DefaultValues(t *testing.T) {
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
GlobalSettings: config.LuaGlobalSettings{
|
||||||
|
// 所有值为零,应使用默认值
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
engine, err := initLuaEngine(luaCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if engine == nil {
|
||||||
|
t.Fatal("Expected non-nil engine")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理
|
||||||
|
engine.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitErrorPageManager_NoConfig 测试无配置时返回 nil
|
||||||
|
func TestInitErrorPageManager_NoConfig(t *testing.T) {
|
||||||
|
cfg := &config.ErrorPageConfig{
|
||||||
|
Pages: nil,
|
||||||
|
Default: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := initErrorPageManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if manager != nil {
|
||||||
|
t.Error("Expected nil when no error page configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitErrorPageManager_NilConfig 测试 nil 配置
|
||||||
|
func TestInitErrorPageManager_NilConfig(t *testing.T) {
|
||||||
|
manager, err := initErrorPageManager(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if manager != nil {
|
||||||
|
t.Error("Expected nil for nil config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitErrorPageManager_WithPages 测试配置错误页面
|
||||||
|
func TestInitErrorPageManager_WithPages(t *testing.T) {
|
||||||
|
// 创建临时目录和文件
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
errorPagePath := filepath.Join(tempDir, "404.html")
|
||||||
|
content := []byte("<html><body>Not Found</body></html>")
|
||||||
|
if err := os.WriteFile(errorPagePath, content, 0o644); err != nil {
|
||||||
|
t.Fatalf("Failed to create error page file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.ErrorPageConfig{
|
||||||
|
Pages: map[int]string{
|
||||||
|
404: errorPagePath,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := initErrorPageManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if manager == nil {
|
||||||
|
t.Fatal("Expected non-nil manager when pages configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitErrorPageManager_WithDefault 测试配置默认错误页面
|
||||||
|
func TestInitErrorPageManager_WithDefault(t *testing.T) {
|
||||||
|
// 创建临时目录和文件
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
defaultPagePath := filepath.Join(tempDir, "error.html")
|
||||||
|
content := []byte("<html><body>Error</body></html>")
|
||||||
|
if err := os.WriteFile(defaultPagePath, content, 0o644); err != nil {
|
||||||
|
t.Fatalf("Failed to create error page file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.ErrorPageConfig{
|
||||||
|
Default: defaultPagePath,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := initErrorPageManager(cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
if manager == nil {
|
||||||
|
t.Fatal("Expected non-nil manager when default page configured")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitErrorPageManager_NonExistentFile 测试不存在的文件
|
||||||
|
func TestInitErrorPageManager_NonExistentFile(t *testing.T) {
|
||||||
|
cfg := &config.ErrorPageConfig{
|
||||||
|
Default: "/nonexistent/path/error.html",
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := initErrorPageManager(cfg)
|
||||||
|
// 应该返回错误
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for non-existent file")
|
||||||
|
}
|
||||||
|
if manager != nil {
|
||||||
|
t.Error("Expected nil manager when file doesn't exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitErrorPageManager_PartialLoad 测试部分加载(部分文件存在)
|
||||||
|
func TestInitErrorPageManager_PartialLoad(t *testing.T) {
|
||||||
|
// 创建临时目录和一个存在的文件
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
existingPath := filepath.Join(tempDir, "404.html")
|
||||||
|
content := []byte("<html><body>Not Found</body></html>")
|
||||||
|
if err := os.WriteFile(existingPath, content, 0o644); err != nil {
|
||||||
|
t.Fatalf("Failed to create error page file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.ErrorPageConfig{
|
||||||
|
Pages: map[int]string{
|
||||||
|
404: existingPath,
|
||||||
|
500: "/nonexistent/path/500.html", // 不存在的文件
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := initErrorPageManager(cfg)
|
||||||
|
// 部分加载应该成功返回 manager,但可能有警告
|
||||||
|
// 注意:具体行为取决于 handler.NewErrorPageManager 的实现
|
||||||
|
if manager == nil {
|
||||||
|
t.Logf("Manager is nil with error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInitFunctions_NilPerformanceConfig 测试 nil PerformanceConfig
|
||||||
|
func TestInitFunctions_NilPerformanceConfig(t *testing.T) {
|
||||||
|
// 这个测试验证函数能正确处理空配置
|
||||||
|
var cfg *config.PerformanceConfig
|
||||||
|
|
||||||
|
// 应该能处理 nil 而不 panic
|
||||||
|
pool := initGoroutinePool(cfg)
|
||||||
|
if pool != nil {
|
||||||
|
t.Error("Expected nil for nil config")
|
||||||
|
}
|
||||||
|
|
||||||
|
cache := initFileCache(cfg)
|
||||||
|
if cache != nil {
|
||||||
|
t.Error("Expected nil for nil config")
|
||||||
|
}
|
||||||
|
}
|
||||||
372
internal/server/lua_integration_test.go
Normal file
372
internal/server/lua_integration_test.go
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
// Package server 提供 buildLuaMiddlewares 的 Mock 测试
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/lua"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_NilEngine 测试 LuaEngine 为 nil 时
|
||||||
|
func TestBuildLuaMiddlewares_NilEngine(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
// 确保 luaEngine 为 nil
|
||||||
|
s.luaEngine = nil
|
||||||
|
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: "/test/script.lua",
|
||||||
|
Phase: "rewrite",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, middlewares)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_InvalidPhase 测试无效阶段
|
||||||
|
func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
// 使用真实的 LuaEngine 来测试无效阶段
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: "/test/script.lua",
|
||||||
|
Phase: "invalid_phase",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.buildLuaMiddlewares(luaCfg)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "无效的阶段")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_WithTimeout 测试超时配置
|
||||||
|
func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
// 测试自定义超时
|
||||||
|
customTimeout := 60 * time.Second
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: "/test/script.lua",
|
||||||
|
Phase: "rewrite",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: customTimeout,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, middlewares, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_EmptyScripts 测试空脚本列表
|
||||||
|
func TestBuildLuaMiddlewares_EmptyScripts(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{},
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, middlewares)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_DisabledLua 测试禁用 Lua 配置
|
||||||
|
// buildLuaMiddlewares 函数本身不检查 Enabled 字段,由调用者检查
|
||||||
|
func TestBuildLuaMiddlewares_DisabledLua(t *testing.T) {
|
||||||
|
// 创建临时脚本文件
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
scriptPath := tempDir + "/test.lua"
|
||||||
|
err := os.WriteFile(scriptPath, []byte("-- test"), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
// Enabled=false 但 buildLuaMiddlewares 本身不检查这个字段
|
||||||
|
// 它会正常处理 Scripts 中的脚本
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: scriptPath,
|
||||||
|
Phase: "rewrite",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 由于 Enabled 检查在调用者处,buildLuaMiddlewares 会正常执行
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// 脚本文件存在,应该能创建中间件
|
||||||
|
assert.Len(t, middlewares, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_DisabledScript 测试禁用特定脚本
|
||||||
|
func TestBuildLuaMiddlewares_DisabledScript(t *testing.T) {
|
||||||
|
// 创建临时脚本文件(只有一个启用的脚本)
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
enabledScriptPath := tempDir + "/enabled.lua"
|
||||||
|
err := os.WriteFile(enabledScriptPath, []byte("-- enabled script"), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: enabledScriptPath,
|
||||||
|
Phase: "rewrite",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/test/disabled.lua",
|
||||||
|
Phase: "access",
|
||||||
|
Enabled: false, // 禁用的脚本应该被过滤
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只有一个启用的脚本,应该能正常创建中间件
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// 只有 rewrite 阶段的脚本被创建
|
||||||
|
assert.Len(t, middlewares, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_WithExistingScript 测试使用存在的脚本文件
|
||||||
|
func TestBuildLuaMiddlewares_WithExistingScript(t *testing.T) {
|
||||||
|
// 创建临时脚本文件
|
||||||
|
scriptContent := `-- test script
|
||||||
|
ngx.var.uri = "/test"
|
||||||
|
`
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
scriptPath := tempDir + "/test.lua"
|
||||||
|
err := os.WriteFile(scriptPath, []byte(scriptContent), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: scriptPath,
|
||||||
|
Phase: "rewrite",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, middlewares, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_MultiplePhases 测试多个不同阶段的脚本
|
||||||
|
func TestBuildLuaMiddlewares_MultiplePhases(t *testing.T) {
|
||||||
|
// 创建临时脚本文件
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
scripts := map[string]string{
|
||||||
|
"rewrite.lua": `-- rewrite script
|
||||||
|
ngx.var.uri = "/rewrite"
|
||||||
|
`,
|
||||||
|
"access.lua": `-- access script
|
||||||
|
-- access control logic
|
||||||
|
`,
|
||||||
|
"content.lua": `-- content script
|
||||||
|
ngx.say("hello")
|
||||||
|
`,
|
||||||
|
}
|
||||||
|
|
||||||
|
scriptPaths := make(map[string]string)
|
||||||
|
for name, content := range scripts {
|
||||||
|
path := tempDir + "/" + name
|
||||||
|
err := os.WriteFile(path, []byte(content), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
scriptPaths[name] = path
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: scriptPaths["rewrite.lua"],
|
||||||
|
Phase: "rewrite",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: scriptPaths["access.lua"],
|
||||||
|
Phase: "access",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: scriptPaths["content.lua"],
|
||||||
|
Phase: "content",
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// 三个阶段应该创建三个中间件
|
||||||
|
assert.Len(t, middlewares, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildLuaMiddlewares_DefaultTimeout 测试默认超时值
|
||||||
|
func TestBuildLuaMiddlewares_DefaultTimeout(t *testing.T) {
|
||||||
|
// 创建临时脚本文件
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
scriptPath := tempDir + "/test.lua"
|
||||||
|
err := os.WriteFile(scriptPath, []byte("-- test"), 0o644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":8080",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
engine, err := lua.NewEngine(lua.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer engine.Close()
|
||||||
|
|
||||||
|
s.luaEngine = engine
|
||||||
|
|
||||||
|
// 不设置 Timeout,使用默认值
|
||||||
|
luaCfg := &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Scripts: []config.LuaScriptConfig{
|
||||||
|
{
|
||||||
|
Path: scriptPath,
|
||||||
|
Phase: "rewrite",
|
||||||
|
Enabled: true,
|
||||||
|
// Timeout 为 0,应该使用默认值 30s
|
||||||
|
Timeout: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
middlewares, err := s.buildLuaMiddlewares(luaCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, middlewares, 1)
|
||||||
|
}
|
||||||
@ -405,73 +405,23 @@ func (s *Server) Start() error {
|
|||||||
// 记录启动时间
|
// 记录启动时间
|
||||||
s.startTime = time.Now()
|
s.startTime = time.Now()
|
||||||
|
|
||||||
// 启用 GoroutinePool(如果配置)
|
// 初始化 GoroutinePool
|
||||||
if s.config.Performance.GoroutinePool.Enabled {
|
s.pool = initGoroutinePool(&s.config.Performance)
|
||||||
s.pool = NewGoroutinePool(PoolConfig{
|
|
||||||
MaxWorkers: s.config.Performance.GoroutinePool.MaxWorkers,
|
// 初始化文件缓存
|
||||||
MinWorkers: s.config.Performance.GoroutinePool.MinWorkers,
|
s.fileCache = initFileCache(&s.config.Performance)
|
||||||
IdleTimeout: s.config.Performance.GoroutinePool.IdleTimeout,
|
|
||||||
})
|
// 初始化错误页面管理器
|
||||||
s.pool.Start()
|
var err error
|
||||||
|
s.errorPageManager, err = initErrorPageManager(&s.config.Server.Security.ErrorPage)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启用文件缓存(如果配置)
|
// 初始化 Lua 引擎
|
||||||
if s.config.Performance.FileCache.MaxEntries > 0 || s.config.Performance.FileCache.MaxSize > 0 {
|
s.luaEngine, err = initLuaEngine(s.config.Server.Lua)
|
||||||
s.fileCache = cache.NewFileCache(
|
if err != nil {
|
||||||
s.config.Performance.FileCache.MaxEntries,
|
return err
|
||||||
s.config.Performance.FileCache.MaxSize,
|
|
||||||
s.config.Performance.FileCache.Inactive,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 预加载错误页面(如果配置)
|
|
||||||
if s.config.Server.Security.ErrorPage.Pages != nil || s.config.Server.Security.ErrorPage.Default != "" {
|
|
||||||
var err error
|
|
||||||
s.errorPageManager, err = handler.NewErrorPageManager(&s.config.Server.Security.ErrorPage)
|
|
||||||
if err != nil {
|
|
||||||
// 检查是否是部分加载失败
|
|
||||||
if _, ok := err.(*handler.PartialLoadError); ok {
|
|
||||||
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
|
|
||||||
} else {
|
|
||||||
// 全部加载失败,阻止启动
|
|
||||||
return fmt.Errorf("加载错误页面失败: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化 Lua 引擎(如果配置)
|
|
||||||
if s.config.Server.Lua != nil && s.config.Server.Lua.Enabled {
|
|
||||||
engineCfg := &lua.Config{
|
|
||||||
MaxConcurrentCoroutines: s.config.Server.Lua.GlobalSettings.MaxConcurrentCoroutines,
|
|
||||||
CoroutineTimeout: s.config.Server.Lua.GlobalSettings.CoroutineTimeout,
|
|
||||||
CodeCacheSize: s.config.Server.Lua.GlobalSettings.CodeCacheSize,
|
|
||||||
CodeCacheTTL: time.Hour, // 默认值
|
|
||||||
EnableFileWatch: s.config.Server.Lua.GlobalSettings.EnableFileWatch,
|
|
||||||
MaxExecutionTime: s.config.Server.Lua.GlobalSettings.MaxExecutionTime,
|
|
||||||
EnableOSLib: false, // 安全默认值
|
|
||||||
EnableIOLib: false,
|
|
||||||
EnableLoadLib: false,
|
|
||||||
}
|
|
||||||
// 设置默认值
|
|
||||||
if engineCfg.MaxConcurrentCoroutines == 0 {
|
|
||||||
engineCfg.MaxConcurrentCoroutines = 1000
|
|
||||||
}
|
|
||||||
if engineCfg.CoroutineTimeout == 0 {
|
|
||||||
engineCfg.CoroutineTimeout = 30 * time.Second
|
|
||||||
}
|
|
||||||
if engineCfg.CodeCacheSize == 0 {
|
|
||||||
engineCfg.CodeCacheSize = 1000
|
|
||||||
}
|
|
||||||
if engineCfg.MaxExecutionTime == 0 {
|
|
||||||
engineCfg.MaxExecutionTime = 30 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
s.luaEngine, err = lua.NewEngine(engineCfg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("初始化 Lua 引擎失败: %w", err)
|
|
||||||
}
|
|
||||||
logging.Info().Msg("Lua 引擎已启动")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.config.HasServers() {
|
if s.config.HasServers() {
|
||||||
|
|||||||
@ -769,3 +769,32 @@ func TestGetTLSConfig_Nil(t *testing.T) {
|
|||||||
t.Error("GetTLSConfig() should return nil when TLS not configured")
|
t.Error("GetTLSConfig() should return nil when TLS not configured")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestGetTLSConfig_NilServer 测试 nil 服务器调用 GetTLSConfig
|
||||||
|
func TestGetTLSConfig_NilServer(t *testing.T) {
|
||||||
|
var s *Server
|
||||||
|
// 防御性:如果 s 为 nil,调用方法会 panic,这是预期的行为
|
||||||
|
// 这里我们只测试非 nil 但 tlsManager 为 nil 的情况
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: ":0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s = New(cfg)
|
||||||
|
|
||||||
|
// 确保 tlsManager 为 nil
|
||||||
|
if s.tlsManager != nil {
|
||||||
|
t.Skip("tlsManager should be nil initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsCfg, err := s.GetTLSConfig()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when tlsManager is nil")
|
||||||
|
}
|
||||||
|
if tlsCfg != nil {
|
||||||
|
t.Error("Expected nil TLS config when tlsManager is nil")
|
||||||
|
}
|
||||||
|
if err.Error() != "TLS not configured" {
|
||||||
|
t.Errorf("Expected error 'TLS not configured', got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
556
internal/server/start_integration_test.go
Normal file
556
internal/server/start_integration_test.go
Normal file
@ -0,0 +1,556 @@
|
|||||||
|
// Package server 提供 Server.Start() 的集成测试。
|
||||||
|
//
|
||||||
|
// 该文件使用 mock_backend 模拟上游服务,测试完整的服务器启动流程:
|
||||||
|
// - 服务器配置初始化
|
||||||
|
// - 代理路由注册
|
||||||
|
// - 静态文件服务
|
||||||
|
// - 中间件链构建
|
||||||
|
// - 请求处理和转发
|
||||||
|
//
|
||||||
|
// 主要用途:
|
||||||
|
//
|
||||||
|
// 验证服务器启动和请求处理的完整流程
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/benchmark/tools"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestStart_Integration 测试完整的服务器启动和请求处理流程
|
||||||
|
func TestStart_Integration(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动 mock 上游服务器
|
||||||
|
backendAddr, cleanup := tools.SimpleMockBackend(
|
||||||
|
fasthttp.StatusOK,
|
||||||
|
[]byte(`{"message": "Hello from backend"}`),
|
||||||
|
)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// 使用随机端口避免冲突
|
||||||
|
serverAddr := "127.0.0.1:0"
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: serverAddr,
|
||||||
|
Proxy: []config.ProxyConfig{
|
||||||
|
{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []config.ProxyTarget{
|
||||||
|
{URL: "http://" + backendAddr, Weight: 1},
|
||||||
|
},
|
||||||
|
HealthCheck: config.HealthCheckConfig{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
// 验证服务器初始化
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试初始状态
|
||||||
|
if s.running {
|
||||||
|
t.Error("Server should not be running initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试 GetListeners 初始状态
|
||||||
|
listeners := s.GetListeners()
|
||||||
|
if listeners != nil {
|
||||||
|
t.Error("Listeners should be nil before Start")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithSecurity 测试安全配置
|
||||||
|
func TestStart_WithSecurity(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Security: config.SecurityConfig{
|
||||||
|
Access: config.AccessConfig{
|
||||||
|
Allow: []string{"127.0.0.1"},
|
||||||
|
Deny: []string{},
|
||||||
|
},
|
||||||
|
RateLimit: config.RateLimitConfig{
|
||||||
|
RequestRate: 100,
|
||||||
|
Burst: 200,
|
||||||
|
},
|
||||||
|
Headers: config.SecurityHeaders{
|
||||||
|
XFrameOptions: "DENY",
|
||||||
|
XContentTypeOptions: "nosniff",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证安全配置
|
||||||
|
if len(s.config.Server.Security.Access.Allow) != 1 {
|
||||||
|
t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Server.Security.Access.Allow))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithRewrite 测试 URL 重写配置
|
||||||
|
func TestStart_WithRewrite(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Rewrite: []config.RewriteRule{
|
||||||
|
{
|
||||||
|
Pattern: "/old/(.*)",
|
||||||
|
Replacement: "/new/$1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证重写配置
|
||||||
|
if len(s.config.Server.Rewrite) != 1 {
|
||||||
|
t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Server.Rewrite))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithMonitoring 测试监控配置
|
||||||
|
func TestStart_WithMonitoring(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
Monitoring: config.MonitoringConfig{
|
||||||
|
Status: config.StatusConfig{
|
||||||
|
Path: "/status",
|
||||||
|
Allow: []string{"127.0.0.1"},
|
||||||
|
},
|
||||||
|
Pprof: config.PprofConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Path: "/debug/pprof",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证监控配置
|
||||||
|
if s.config.Monitoring.Status.Path != "/status" {
|
||||||
|
t.Errorf("Expected status path '/status', got '%s'", s.config.Monitoring.Status.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithErrorPage 测试错误页面配置
|
||||||
|
func TestStart_WithErrorPage(t *testing.T) {
|
||||||
|
// 创建临时错误页面文件
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
errorPagePath := tempDir + "/404.html"
|
||||||
|
|
||||||
|
// 创建错误页面文件
|
||||||
|
content := []byte("<html><body>Not Found</body></html>")
|
||||||
|
if err := writeFile(errorPagePath, content); err != nil {
|
||||||
|
t.Fatalf("Failed to create error page file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Security: config.SecurityConfig{
|
||||||
|
ErrorPage: config.ErrorPageConfig{
|
||||||
|
Pages: map[int]string{
|
||||||
|
404: errorPagePath,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证错误页面配置
|
||||||
|
if s.config.Server.Security.ErrorPage.Pages == nil {
|
||||||
|
t.Error("Error page pages should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithLuaEnabled 测试 Lua 配置
|
||||||
|
func TestStart_WithLuaEnabled(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Lua: &config.LuaMiddlewareConfig{
|
||||||
|
Enabled: true,
|
||||||
|
GlobalSettings: config.LuaGlobalSettings{
|
||||||
|
MaxConcurrentCoroutines: 100,
|
||||||
|
CoroutineTimeout: 30 * time.Second,
|
||||||
|
CodeCacheSize: 100,
|
||||||
|
MaxExecutionTime: 30 * time.Second,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 Lua 配置
|
||||||
|
if s.config.Server.Lua == nil || !s.config.Server.Lua.Enabled {
|
||||||
|
t.Error("Lua should be enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithMultipleProxies 测试多个代理配置
|
||||||
|
func TestStart_WithMultipleProxies(t *testing.T) {
|
||||||
|
// 启动多个 mock 上游服务器
|
||||||
|
backend1, cleanup1 := tools.SimpleMockBackend(
|
||||||
|
fasthttp.StatusOK,
|
||||||
|
[]byte(`{"service": "api1"}`),
|
||||||
|
)
|
||||||
|
defer cleanup1()
|
||||||
|
|
||||||
|
backend2, cleanup2 := tools.SimpleMockBackend(
|
||||||
|
fasthttp.StatusOK,
|
||||||
|
[]byte(`{"service": "api2"}`),
|
||||||
|
)
|
||||||
|
defer cleanup2()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Proxy: []config.ProxyConfig{
|
||||||
|
{
|
||||||
|
Path: "/api1",
|
||||||
|
Targets: []config.ProxyTarget{
|
||||||
|
{URL: "http://" + backend1, Weight: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/api2",
|
||||||
|
Targets: []config.ProxyTarget{
|
||||||
|
{URL: "http://" + backend2, Weight: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证代理配置
|
||||||
|
if len(s.config.Server.Proxy) != 2 {
|
||||||
|
t.Errorf("Expected 2 proxies, got %d", len(s.config.Server.Proxy))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_EmptyConfig 测试空配置
|
||||||
|
func TestStart_EmptyConfig(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 空配置应该能正常初始化
|
||||||
|
if s.config == nil {
|
||||||
|
t.Error("Config should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithAllFeatures 测试启用所有功能的配置
|
||||||
|
func TestStart_WithAllFeatures(t *testing.T) {
|
||||||
|
// 创建临时目录
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
errorPagePath := tempDir + "/404.html"
|
||||||
|
writeFile(errorPagePath, []byte("<html><body>Not Found</body></html>"))
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Static: []config.StaticConfig{
|
||||||
|
{
|
||||||
|
Path: "/static",
|
||||||
|
Root: tempDir,
|
||||||
|
Index: []string{"index.html"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Compression: config.CompressionConfig{
|
||||||
|
Type: "gzip",
|
||||||
|
Level: 6,
|
||||||
|
},
|
||||||
|
Security: config.SecurityConfig{
|
||||||
|
Access: config.AccessConfig{
|
||||||
|
Allow: []string{"127.0.0.1"},
|
||||||
|
},
|
||||||
|
RateLimit: config.RateLimitConfig{
|
||||||
|
RequestRate: 100,
|
||||||
|
Burst: 200,
|
||||||
|
},
|
||||||
|
Headers: config.SecurityHeaders{
|
||||||
|
XFrameOptions: "DENY",
|
||||||
|
},
|
||||||
|
ErrorPage: config.ErrorPageConfig{
|
||||||
|
Default: errorPagePath,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Rewrite: []config.RewriteRule{
|
||||||
|
{
|
||||||
|
Pattern: "/old/(.*)",
|
||||||
|
Replacement: "/new/$1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Performance: config.PerformanceConfig{
|
||||||
|
GoroutinePool: config.GoroutinePoolConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxWorkers: 50,
|
||||||
|
MinWorkers: 10,
|
||||||
|
},
|
||||||
|
FileCache: config.FileCacheConfig{
|
||||||
|
MaxEntries: 1000,
|
||||||
|
MaxSize: 100 * 1024 * 1024,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Monitoring: config.MonitoringConfig{
|
||||||
|
Status: config.StatusConfig{
|
||||||
|
Path: "/status",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证所有配置
|
||||||
|
if !s.config.Performance.GoroutinePool.Enabled {
|
||||||
|
t.Error("GoroutinePool should be enabled")
|
||||||
|
}
|
||||||
|
if s.config.Server.Compression.Type != "gzip" {
|
||||||
|
t.Error("Compression should be gzip")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_ServerOptions 测试服务器配置选项
|
||||||
|
func TestStart_ServerOptions(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
MaxConnsPerIP: 100,
|
||||||
|
MaxRequestsPerConn: 1000,
|
||||||
|
ClientMaxBodySize: "10MB",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证服务器选项
|
||||||
|
if s.config.Server.ReadTimeout != 30*time.Second {
|
||||||
|
t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Server.ReadTimeout)
|
||||||
|
}
|
||||||
|
if s.config.Server.MaxConnsPerIP != 100 {
|
||||||
|
t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Server.MaxConnsPerIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_HealthCheckConfig 测试健康检查配置
|
||||||
|
func TestStart_HealthCheckConfig(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Proxy: []config.ProxyConfig{
|
||||||
|
{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []config.ProxyTarget{
|
||||||
|
{URL: "http://127.0.0.1:8081", Weight: 1},
|
||||||
|
},
|
||||||
|
HealthCheck: config.HealthCheckConfig{
|
||||||
|
Interval: 10 * time.Second,
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
Path: "/health",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证健康检查配置
|
||||||
|
if s.config.Server.Proxy[0].HealthCheck.Path != "/health" {
|
||||||
|
t.Error("Health check path should be /health")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_VHostMode 测试虚拟主机模式配置
|
||||||
|
func TestStart_VHostMode(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{
|
||||||
|
{
|
||||||
|
Name: "api.example.com",
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "www.example.com",
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证虚拟主机配置
|
||||||
|
if !s.config.HasServers() {
|
||||||
|
t.Error("Should detect virtual hosts")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithProxyBackendError 测试代理后端错误处理
|
||||||
|
func TestStart_WithProxyBackendError(t *testing.T) {
|
||||||
|
// 启动返回错误的 mock 服务器
|
||||||
|
backendAddr, cleanup := tools.ErrorMockBackend(1.0, []byte(`{"error": "backend error"}`))
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Proxy: []config.ProxyConfig{
|
||||||
|
{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []config.ProxyTarget{
|
||||||
|
{URL: "http://" + backendAddr, Weight: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证代理配置
|
||||||
|
if len(s.config.Server.Proxy) != 1 {
|
||||||
|
t.Errorf("Expected 1 proxy, got %d", len(s.config.Server.Proxy))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithDelayedBackend 测试延迟后端
|
||||||
|
func TestStart_WithDelayedBackend(t *testing.T) {
|
||||||
|
// 启动延迟的 mock 服务器
|
||||||
|
backendAddr, cleanup := tools.DelayedMockBackend(
|
||||||
|
100*time.Millisecond,
|
||||||
|
[]byte(`{"message": "delayed response"}`),
|
||||||
|
)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Proxy: []config.ProxyConfig{
|
||||||
|
{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []config.ProxyTarget{
|
||||||
|
{URL: "http://" + backendAddr, Weight: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithRandomResponse 测试随机响应后端
|
||||||
|
func TestStart_WithRandomResponse(t *testing.T) {
|
||||||
|
// 启动随机响应的 mock 服务器
|
||||||
|
backendAddr, cleanup := tools.StartMockFasthttpBackend(tools.MockBackendConfig{
|
||||||
|
Mode: tools.ModeRandomResponse,
|
||||||
|
StatusCode: fasthttp.StatusOK,
|
||||||
|
Body: []byte(`{"random": true}`),
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Server: config.ServerConfig{
|
||||||
|
Listen: "127.0.0.1:0",
|
||||||
|
Proxy: []config.ProxyConfig{
|
||||||
|
{
|
||||||
|
Path: "/api",
|
||||||
|
Targets: []config.ProxyTarget{
|
||||||
|
{URL: "http://" + backendAddr, Weight: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("Expected non-nil server")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeFile 辅助函数:写入文件
|
||||||
|
func writeFile(path string, content []byte) error {
|
||||||
|
f, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
_, err = f.Write(content)
|
||||||
|
return err
|
||||||
|
}
|
||||||
121
internal/server/testutil.go
Normal file
121
internal/server/testutil.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
// Package server 提供测试工具函数和依赖注入支持
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/lua"
|
||||||
|
"rua.plus/lolly/internal/ssl"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockFastServer 是 fasthttp.Server 的 Mock 包装
|
||||||
|
// 定义在此文件以便 TestServerOptions 可以引用
|
||||||
|
type MockFastServer struct {
|
||||||
|
Name string
|
||||||
|
Handler fasthttp.RequestHandler
|
||||||
|
TLSConfig *tls.Config
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
MaxConnsPerIP int
|
||||||
|
MaxRequestsPerConn int
|
||||||
|
CloseOnShutdown bool
|
||||||
|
ServeFunc func(ln net.Listener) error
|
||||||
|
ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error
|
||||||
|
ShutdownFunc func() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve 启动服务
|
||||||
|
func (m *MockFastServer) Serve(ln net.Listener) error {
|
||||||
|
if m.ServeFunc != nil {
|
||||||
|
return m.ServeFunc(ln)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeTLS 启动 TLS 服务
|
||||||
|
func (m *MockFastServer) ServeTLS(ln net.Listener, certFile, keyFile string) error {
|
||||||
|
if m.ServeTLSFunc != nil {
|
||||||
|
return m.ServeTLSFunc(ln, certFile, keyFile)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown 关闭服务器
|
||||||
|
func (m *MockFastServer) Shutdown() error {
|
||||||
|
if m.ShutdownFunc != nil {
|
||||||
|
return m.ShutdownFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDependencies 包含测试时可注入的依赖
|
||||||
|
// 使用具体指针类型,允许注入 Mock 实现
|
||||||
|
type TestDependencies struct {
|
||||||
|
LuaEngine *lua.LuaEngine
|
||||||
|
TLSManager *ssl.TLSManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerForTesting 创建用于测试的服务器实例
|
||||||
|
// 允许注入 Mock 依赖,不改变生产 API
|
||||||
|
func NewServerForTesting(cfg *config.Config, deps *TestDependencies) *Server {
|
||||||
|
s := New(cfg)
|
||||||
|
if deps != nil {
|
||||||
|
if deps.LuaEngine != nil {
|
||||||
|
s.luaEngine = deps.LuaEngine
|
||||||
|
}
|
||||||
|
if deps.TLSManager != nil {
|
||||||
|
s.tlsManager = deps.TLSManager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerOptions 测试服务器的可选配置
|
||||||
|
type TestServerOptions struct {
|
||||||
|
SkipListener bool
|
||||||
|
MockFastServer *MockFastServer
|
||||||
|
CustomHandler fasthttp.RequestHandler
|
||||||
|
DisableMiddleware bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestServerWithOptions 使用选项创建测试服务器
|
||||||
|
func NewTestServerWithOptions(cfg *config.Config, opts *TestServerOptions) *Server {
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
if opts != nil {
|
||||||
|
// 可以在这里应用各种测试选项
|
||||||
|
if opts.CustomHandler != nil {
|
||||||
|
s.handler = opts.CustomHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustStartTestServer 启动测试服务器,失败时 panic
|
||||||
|
// 主要用于集成测试
|
||||||
|
func MustStartTestServer(cfg *config.Config) *Server {
|
||||||
|
s := New(cfg)
|
||||||
|
// 在测试环境中使用随机端口避免冲突
|
||||||
|
if cfg.Server.Listen == "" || cfg.Server.Listen == ":80" {
|
||||||
|
cfg.Server.Listen = "127.0.0.1:0"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 goroutine 启动服务器以避免阻塞
|
||||||
|
go func() {
|
||||||
|
if err := s.Start(); err != nil {
|
||||||
|
// 测试服务器启动失败记录日志
|
||||||
|
panic("failed to start test server: " + err.Error())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 给服务器一点时间启动
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user