diff --git a/internal/config/config.go b/internal/config/config.go index 460cffb..6f33077 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -281,6 +281,10 @@ type ServerConfig struct { // CacheAPI 缓存 API 配置 // 用于主动清理代理缓存 CacheAPI *CacheAPIConfig `yaml:"cache_api"` + + // Lua Lua 中间件配置 + // 用于嵌入 Lua 脚本处理请求 + Lua *LuaMiddlewareConfig `yaml:"lua"` } // StaticConfig 静态文件服务配置。 @@ -1525,6 +1529,102 @@ type CacheAPIAuthConfig struct { Token string `yaml:"token"` } +// LuaMiddlewareConfig Lua 中间件配置(配置文件格式) +// +// 用于配置 Lua 中间件的行为,包括脚本路径、执行阶段和全局设置。 +// +// 注意事项: +// - Enabled 为 true 时启用 Lua 中间件 +// - Scripts 配置要执行的脚本列表 +// - GlobalSettings 控制 Lua 引擎的全局行为 +// +// 使用示例: +// +// lua: +// enabled: true +// scripts: +// - path: "/scripts/auth.lua" +// phase: "access" +// timeout: 10s +// global_settings: +// max_concurrent_coroutines: 1000 +// coroutine_timeout: 30s +type LuaMiddlewareConfig struct { + // Enabled 是否启用 Lua 中间件 + Enabled bool `yaml:"enabled"` + + // Scripts 脚本配置列表 + Scripts []LuaScriptConfig `yaml:"scripts"` + + // GlobalSettings 全局设置 + GlobalSettings LuaGlobalSettings `yaml:"global_settings"` +} + +// LuaScriptConfig 单个脚本配置 +// +// 定义单个 Lua 脚本的执行参数。 +// +// 注意事项: +// - Path 为脚本文件路径,必需字段 +// - Phase 为执行阶段,必需字段 +// - Timeout 控制脚本执行超时 +// +// 使用示例: +// +// scripts: +// - path: "/scripts/auth.lua" +// phase: "access" +// timeout: 10s +// enabled: true +type LuaScriptConfig struct { + // Path 脚本路径 + Path string `yaml:"path"` + + // Phase 执行阶段 + // 可选值:rewrite、access、content、log、header_filter、body_filter + Phase string `yaml:"phase"` + + // Timeout 执行超时 + Timeout time.Duration `yaml:"timeout"` + + // Enabled 是否启用此脚本(默认 true) + Enabled bool `yaml:"enabled"` +} + +// LuaGlobalSettings 全局 Lua 设置 +// +// 控制 Lua 引擎的全局行为。 +// +// 注意事项: +// - MaxConcurrentCoroutines 控制最大并发协程数 +// - CoroutineTimeout 控制协程执行超时 +// - CodeCacheSize 控制字节码缓存大小 +// +// 使用示例: +// +// global_settings: +// max_concurrent_coroutines: 1000 +// coroutine_timeout: 30s +// code_cache_size: 1000 +// enable_file_watch: true +// max_execution_time: 30s +type LuaGlobalSettings struct { + // MaxConcurrentCoroutines 最大并发协程数 + MaxConcurrentCoroutines int `yaml:"max_concurrent_coroutines"` + + // CoroutineTimeout 协程执行超时 + CoroutineTimeout time.Duration `yaml:"coroutine_timeout"` + + // CodeCacheSize 字节码缓存条目数 + CodeCacheSize int `yaml:"code_cache_size"` + + // EnableFileWatch 启用文件变更检测 + EnableFileWatch bool `yaml:"enable_file_watch"` + + // MaxExecutionTime 单脚本最大执行时间 + MaxExecutionTime time.Duration `yaml:"max_execution_time"` +} + // StreamConfig TCP/UDP Stream 代理配置。 // // 用于四层网络代理,如数据库、Redis 等 TCP/UDP 服务。 diff --git a/internal/config/lua_config_test.go b/internal/config/lua_config_test.go new file mode 100644 index 0000000..288302e --- /dev/null +++ b/internal/config/lua_config_test.go @@ -0,0 +1,156 @@ +// Package config 提供 Lua 配置测试 +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestLuaMiddlewareConfigValidation 测试 Lua 配置验证 +func TestLuaMiddlewareConfigValidation(t *testing.T) { + // 未配置时跳过验证 + cfg := &LuaMiddlewareConfig{} + require.NoError(t, validateLua(cfg)) + + // 禁用时跳过验证 + cfg = &LuaMiddlewareConfig{Enabled: false} + require.NoError(t, validateLua(cfg)) + + // 启用但无脚本也允许 + cfg = &LuaMiddlewareConfig{Enabled: true} + require.NoError(t, validateLua(cfg)) + + // 有效配置 + cfg = &LuaMiddlewareConfig{ + Enabled: true, + Scripts: []LuaScriptConfig{ + {Path: "/scripts/test.lua", Phase: "rewrite", Timeout: 10 * time.Second}, + }, + GlobalSettings: LuaGlobalSettings{ + MaxConcurrentCoroutines: 1000, + CoroutineTimeout: 30 * time.Second, + CodeCacheSize: 100, + EnableFileWatch: true, + MaxExecutionTime: 30 * time.Second, + }, + } + require.NoError(t, validateLua(cfg)) +} + +// TestLuaMiddlewareConfigInvalidPhase 测试无效阶段 +func TestLuaMiddlewareConfigInvalidPhase(t *testing.T) { + cfg := &LuaMiddlewareConfig{ + Enabled: true, + Scripts: []LuaScriptConfig{ + {Path: "/scripts/test.lua", Phase: "invalid_phase"}, + }, + } + + err := validateLua(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "scripts[0].phase 无效") +} + +// TestLuaMiddlewareConfigMissingPath 测试缺少路径 +func TestLuaMiddlewareConfigMissingPath(t *testing.T) { + cfg := &LuaMiddlewareConfig{ + Enabled: true, + Scripts: []LuaScriptConfig{ + {Phase: "rewrite"}, + }, + } + + err := validateLua(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "scripts[0].path 必填") +} + +// TestLuaMiddlewareConfigNegativeTimeout 测试负超时 +func TestLuaMiddlewareConfigNegativeTimeout(t *testing.T) { + cfg := &LuaMiddlewareConfig{ + Enabled: true, + Scripts: []LuaScriptConfig{ + {Path: "/scripts/test.lua", Phase: "rewrite", Timeout: -5 * time.Second}, + }, + } + + err := validateLua(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "timeout 不能为负数") +} + +// TestLuaMiddlewareConfigGlobalSettingsValidation 测试全局设置验证 +func TestLuaMiddlewareConfigGlobalSettingsValidation(t *testing.T) { + // MaxConcurrentCoroutines 为负 + cfg := &LuaMiddlewareConfig{ + Enabled: true, + GlobalSettings: LuaGlobalSettings{ + MaxConcurrentCoroutines: -1, + }, + } + err := validateLua(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "max_concurrent_coroutines 不能为负数") + + // CoroutineTimeout 为负 + cfg = &LuaMiddlewareConfig{ + Enabled: true, + GlobalSettings: LuaGlobalSettings{ + CoroutineTimeout: -1 * time.Second, + }, + } + err = validateLua(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "coroutine_timeout 不能为负数") +} + +// TestServerConfigLuaField 测试 ServerConfig 包含 Lua 字段 +func TestServerConfigLuaField(t *testing.T) { + cfg := &ServerConfig{ + Listen: ":8080", + Lua: &LuaMiddlewareConfig{ + Enabled: true, + Scripts: []LuaScriptConfig{ + {Path: "/scripts/auth.lua", Phase: "access"}, + }, + }, + } + + require.NotNil(t, cfg.Lua) + assert.True(t, cfg.Lua.Enabled) + assert.Len(t, cfg.Lua.Scripts, 1) + assert.Equal(t, "/scripts/auth.lua", cfg.Lua.Scripts[0].Path) + assert.Equal(t, "access", cfg.Lua.Scripts[0].Phase) +} + +// TestValidateServerWithLuaConfig 测试服务器验证包含 Lua +func TestValidateServerWithLuaConfig(t *testing.T) { + // 有效 Lua 配置 + cfg := &ServerConfig{ + Listen: ":8080", + Lua: &LuaMiddlewareConfig{ + Enabled: true, + Scripts: []LuaScriptConfig{ + {Path: "/scripts/test.lua", Phase: "content"}, + }, + }, + } + require.NoError(t, validateServer(cfg, false)) + + // 无效 Lua 配置 + cfg = &ServerConfig{ + Listen: ":8080", + Lua: &LuaMiddlewareConfig{ + Enabled: true, + Scripts: []LuaScriptConfig{ + {Path: "", Phase: "content"}, // 空路径 + }, + }, + } + err := validateServer(cfg, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "lua:") +} diff --git a/internal/config/validate.go b/internal/config/validate.go index 6a3b454..2114c6f 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -94,6 +94,11 @@ func validateServer(s *ServerConfig, isDefault bool) error { return fmt.Errorf("compression: %w", err) } + // 验证 Lua 中间件配置 + if err := validateLua(s.Lua); err != nil { + return fmt.Errorf("lua: %w", err) + } + return nil } @@ -1025,3 +1030,68 @@ func validateVariables(v *VariablesConfig) error { } return nil } + +// validateLua 验证 Lua 中间件配置。 +// +// 检查 Lua 脚本配置的有效性,包括脚本路径、执行阶段和全局设置。 +// +// 参数: +// - l: Lua 配置对象 +// +// 返回值: +// - error: 验证失败时返回错误信息,成功返回 nil +// +// 验证规则: +// - scripts[].path 必填 +// - scripts[].phase 必须是有效阶段 +// - global_settings.max_concurrent_coroutines 必须 >= 1 +func validateLua(l *LuaMiddlewareConfig) error { + // 未配置时跳过 + if l == nil || !l.Enabled { + return nil + } + + // 验证脚本配置 + for i, script := range l.Scripts { + if script.Path == "" { + return fmt.Errorf("scripts[%d].path 必填", i) + } + + // 验证阶段值 + validPhases := []string{"rewrite", "access", "content", "log", "header_filter", "body_filter"} + valid := false + for _, phase := range validPhases { + if script.Phase == phase { + valid = true + break + } + } + if !valid { + return fmt.Errorf("scripts[%d].phase 无效: %s(仅支持 rewrite, access, content, log, header_filter, body_filter)", i, script.Phase) + } + + // 超时时间验证 + if script.Timeout < 0 { + return fmt.Errorf("scripts[%d].timeout 不能为负数", i) + } + } + + // 验证全局设置 + if l.GlobalSettings.MaxConcurrentCoroutines < 0 { + return errors.New("global_settings.max_concurrent_coroutines 不能为负数") + } + if l.GlobalSettings.MaxConcurrentCoroutines > 0 && l.GlobalSettings.MaxConcurrentCoroutines < 1 { + return errors.New("global_settings.max_concurrent_coroutines 至少为 1") + } + if l.GlobalSettings.CoroutineTimeout < 0 { + return errors.New("global_settings.coroutine_timeout 不能为负数") + } + if l.GlobalSettings.CodeCacheSize < 0 { + return errors.New("global_settings.code_cache_size 不能为负数") + } + if l.GlobalSettings.MaxExecutionTime < 0 { + return errors.New("global_settings.max_execution_time 不能为负数") + } + + return nil +}