diff --git a/internal/config/validate.go b/internal/config/validate.go index 2114c6f..121f700 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -27,6 +27,24 @@ import ( "rua.plus/lolly/internal/variable" ) +// ValidateEnum 验证值是否在有效枚举列表中 +func ValidateEnum(value string, validValues []string, fieldName string) error { + for _, v := range validValues { + if value == v { + return nil + } + } + return fmt.Errorf("无效的 %s: %s(仅支持 %v)", fieldName, value, validValues) +} + +// ValidateNonNegative 验证值为非负数 +func ValidateNonNegative(value int, fieldName string) error { + if value < 0 { + return fmt.Errorf("%s 不能为负数", fieldName) + } + return nil +} + // validateServer 验证服务器配置。 // // 检查服务器配置的各项参数是否符合要求,包括监听地址、 @@ -391,16 +409,11 @@ func validateProxy(p *ProxyConfig) error { // 验证一致性哈希键格式 if p.HashKey != "" { validHashKeys := []string{"ip", "uri"} - valid := false - for _, k := range validHashKeys { - if p.HashKey == k || strings.HasPrefix(p.HashKey, "header:") { - valid = true - break + if !strings.HasPrefix(p.HashKey, "header:") { + if err := ValidateEnum(p.HashKey, validHashKeys, "hash_key"); err != nil { + return fmt.Errorf("无效的 hash_key: %s(仅支持 ip, uri 或 header:X-Name 格式)", p.HashKey) } } - if !valid { - return fmt.Errorf("无效的 hash_key: %s(仅支持 ip, uri 或 header:X-Name 格式)", p.HashKey) - } } return nil @@ -566,14 +579,7 @@ func validateAuth(a *AuthConfig) error { // 验证哈希算法 validAlgorithms := []string{"", "bcrypt", "argon2id"} - valid := false - for _, alg := range validAlgorithms { - if a.Algorithm == alg { - valid = true - break - } - } - if !valid { + if err := ValidateEnum(a.Algorithm, validAlgorithms, "哈希算法"); err != nil { return fmt.Errorf("不支持的哈希算法: %s(仅支持 bcrypt 或 argon2id)", a.Algorithm) } @@ -623,52 +629,31 @@ func validateRateLimit(r *RateLimitConfig) error { } // 验证速率限制值 - if r.RequestRate < 0 { - return errors.New("request_rate 不能为负数") + if err := ValidateNonNegative(r.RequestRate, "request_rate"); err != nil { + return err } - if r.Burst < 0 { - return errors.New("burst 不能为负数") + if err := ValidateNonNegative(r.Burst, "burst"); err != nil { + return err } - if r.ConnLimit < 0 { - return errors.New("conn_limit 不能为负数") + if err := ValidateNonNegative(r.ConnLimit, "conn_limit"); err != nil { + return err } // 验证 key 来源 validKeys := []string{"", "ip", "header"} - valid := false - for _, k := range validKeys { - if r.Key == k { - valid = true - break - } - } - if !valid { + if err := ValidateEnum(r.Key, validKeys, "key 来源"); err != nil { return fmt.Errorf("无效的 key 来源: %s(仅支持 ip 或 header)", r.Key) } // 验证限流算法 validAlgorithms := []string{"", "token_bucket", "sliding_window"} - valid = false - for _, alg := range validAlgorithms { - if r.Algorithm == alg { - valid = true - break - } - } - if !valid { + if err := ValidateEnum(r.Algorithm, validAlgorithms, "限流算法"); err != nil { return fmt.Errorf("无效的限流算法: %s(仅支持 token_bucket 或 sliding_window)", r.Algorithm) } // 验证滑动窗口模式 validModes := []string{"", "approximate", "precise"} - valid = false - for _, mode := range validModes { - if r.SlidingWindowMode == mode { - valid = true - break - } - } - if !valid { + if err := ValidateEnum(r.SlidingWindowMode, validModes, "滑动窗口模式"); err != nil { return fmt.Errorf("无效的滑动窗口模式: %s(仅支持 approximate 或 precise)", r.SlidingWindowMode) } @@ -737,14 +722,7 @@ func validateCompression(c *CompressionConfig) error { // 验证压缩类型 validTypes := []string{"gzip", "brotli", "both"} - valid := false - for _, t := range validTypes { - if c.Type == t { - valid = true - break - } - } - if !valid { + if err := ValidateEnum(c.Type, validTypes, "压缩类型"); err != nil { return fmt.Errorf("无效的压缩类型: %s(仅支持 gzip, brotli 或 both)", c.Type) } @@ -754,8 +732,8 @@ func validateCompression(c *CompressionConfig) error { } // 验证最小压缩大小 - if c.MinSize < 0 { - return errors.New("min_size 不能为负数") + if err := ValidateNonNegative(c.MinSize, "min_size"); err != nil { + return err } return nil @@ -812,27 +790,13 @@ func validateRewrite(r *RewriteRule) error { func validateLogging(l *LoggingConfig) error { // 验证日志格式 validFormats := []string{"", "text", "json"} - valid := false - for _, f := range validFormats { - if l.Format == f { - valid = true - break - } - } - if !valid { + if err := ValidateEnum(l.Format, validFormats, "日志格式"); err != nil { return fmt.Errorf("无效的日志格式: %s(仅支持 text 或 json)", l.Format) } // 验证错误日志级别 validLevels := []string{"", "debug", "info", "warn", "error"} - valid = false - for _, lvl := range validLevels { - if l.Error.Level == lvl { - valid = true - break - } - } - if !valid { + if err := ValidateEnum(l.Error.Level, validLevels, "日志级别"); err != nil { return fmt.Errorf("无效的日志级别: %s(仅支持 debug, info, warn, error)", l.Error.Level) }