From 38bb74378105b0b5e165f09af739e0c8cb5b0879 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 10:12:09 +0800 Subject: [PATCH 1/9] fix(server): handle LocationEngine registration errors properly Add typed ConflictError for path conflicts, change register functions to return errors, handle conflicts as warnings and fatal errors as startup failures. Remove all 20 instances of ignored Add* return values. --- internal/matcher/integration_test.go | 5 ++ internal/matcher/location.go | 20 +++++- internal/server/router.go | 103 +++++++++++++++++++++------ internal/server/server.go | 51 +++++++++---- 4 files changed, 143 insertions(+), 36 deletions(-) diff --git a/internal/matcher/integration_test.go b/internal/matcher/integration_test.go index 687ba46..c2b72a0 100644 --- a/internal/matcher/integration_test.go +++ b/internal/matcher/integration_test.go @@ -1,6 +1,7 @@ package matcher import ( + "errors" "testing" "github.com/valyala/fasthttp" @@ -107,4 +108,8 @@ func TestLocationEngine_PathConflict(t *testing.T) { if err == nil { t.Error("should fail on path conflict") } + var ce *ConflictError + if !errors.As(err, &ce) { + t.Errorf("expected *ConflictError, got %T: %v", err, err) + } } diff --git a/internal/matcher/location.go b/internal/matcher/location.go index a96153d..c42a210 100644 --- a/internal/matcher/location.go +++ b/internal/matcher/location.go @@ -240,6 +240,21 @@ func (e *LocationEngine) MarkInitialized() { e.prefixTree.MarkInitialized() } +// ConflictError 路径冲突错误。 +// +// 当同一路径被重复注册为不同类型的 location 时返回此错误。 +// 调用方可通过 errors.As 检测此类型,区分冲突与致命错误。 +type ConflictError struct { + Path string + ExistingType string + NewType string +} + +func (e *ConflictError) Error() string { + return fmt.Sprintf("path conflict: '%s' already registered as '%s', trying to register as '%s'", + e.Path, e.ExistingType, e.NewType) +} + // checkConflict 检查路径冲突。 // // 参数: @@ -247,11 +262,10 @@ func (e *LocationEngine) MarkInitialized() { // - locationType: location 类型 // // 返回值: -// - error: 路径已存在时返回冲突错误 +// - error: 路径已存在时返回 *ConflictError func (e *LocationEngine) checkConflict(path, locationType string) error { if existing, ok := e.registeredPaths[path]; ok { - return fmt.Errorf("path conflict: '%s' already registered as '%s', trying to register as '%s'", - path, existing, locationType) + return &ConflictError{Path: path, ExistingType: existing, NewType: locationType} } e.registeredPaths[path] = locationType return nil diff --git a/internal/server/router.go b/internal/server/router.go index 238ba11..54c5290 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -67,7 +67,7 @@ func (s *Server) createProxyForConfig(proxyCfg *config.ProxyConfig) *proxy.Proxy // // 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。 // 支持通过 LocationType 配置不同的匹配方式。 -func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) { +func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) error { for i := range serverCfg.Proxy { proxyCfg := &serverCfg.Proxy[i] @@ -76,7 +76,6 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC continue } - // 根据 LocationType 注册路由 locType := proxyCfg.LocationType if locType == "" { locType = matcher.LocationTypePrefix @@ -84,22 +83,47 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC switch locType { case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless: caseInsensitive := locType == matcher.LocationTypeRegexCaseless - _ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal) + if err := s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } case matcher.LocationTypeNamed: if proxyCfg.LocationName != "" { - _ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP) + if err := s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP); err != nil { + if err := s.handleRegistrationError("proxy", "@"+proxyCfg.LocationName, err); err != nil { + return err + } + } } case matcher.LocationTypePrefix: - _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } default: - _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } } } + return nil } // configureStaticHandler 配置静态文件处理器。 @@ -156,7 +180,7 @@ func (s *Server) configureStaticHandler(static *config.StaticConfig, cfg *config } // registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。 -func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) { +func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) error { for _, static := range cfg.Static { staticHandler := s.configureStaticHandler(&static, cfg) path := static.Path @@ -164,7 +188,6 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf path = "/" } - // 根据 LocationType 注册路由 locType := static.LocationType if locType == "" { locType = matcher.LocationTypePrefix @@ -172,15 +195,32 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf switch locType { case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } case matcher.LocationTypePrefix: - _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } default: - _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } } } + return nil } // registerProxyRoutes 注册代理路由。 @@ -324,9 +364,9 @@ func (s *Server) registerLuaRoutes(router *handler.Router, serverCfg *config.Ser // - 只有设置了 Route 字段的脚本才会被注册 // - 路由脚本不经过完整中间件链,只应用 accesslog 和 errorintercept // - 支持 exact、prefix、prefix_priority、regex、regex_caseless 匹配类型 -func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) { +func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) error { if s.luaEngine == nil || serverCfg.Lua == nil || !serverCfg.Lua.Enabled { - return + return nil } for _, script := range serverCfg.Lua.Scripts { @@ -348,17 +388,38 @@ func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerCon switch routeType { case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(script.Route, handler, false) + if err := s.locationEngine.AddExact(script.Route, handler, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(script.Route, handler, false) + if err := s.locationEngine.AddPrefixPriority(script.Route, handler, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } case matcher.LocationTypeRegex: - _ = s.locationEngine.AddRegex(script.Route, handler, false, false) + if err := s.locationEngine.AddRegex(script.Route, handler, false, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } case matcher.LocationTypeRegexCaseless: - _ = s.locationEngine.AddRegex(script.Route, handler, true, false) + if err := s.locationEngine.AddRegex(script.Route, handler, true, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } default: - _ = s.locationEngine.AddPrefix(script.Route, handler, false) + if err := s.locationEngine.AddPrefix(script.Route, handler, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } } } + return nil } // wrapRoutedHandler 为路由处理器包装基础中间件链。 diff --git a/internal/server/server.go b/internal/server/server.go index 99dced5..ddba853 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,6 +21,7 @@ package server import ( "crypto/tls" + "errors" "fmt" "net" "os" @@ -96,6 +97,15 @@ func New(cfg *config.Config) *Server { return &Server{config: cfg} } +func (s *Server) handleRegistrationError(source, path string, err error) error { + var ce *matcher.ConflictError + if errors.As(err, &ce) { + logging.Warn().Msgf("Route registration skipped (%s %s): %s", source, path, err) + return nil + } + return fmt.Errorf("%s route %s: %w", source, path, err) +} + // getServerName 根据配置返回服务器名称。 // // 当 ServerTokens 为 false 时隐藏版本号,仅返回 "lolly"。 @@ -382,39 +392,56 @@ func (s *Server) startSingleMode() error { if err != nil { logging.Error().Msg("Failed to create status handler: " + err.Error()) } else { - _ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false) + if err := s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("status", statusHandler.Path(), err); err != nil { + return err + } + } } } - // 注册 pprof 性能分析端点(如果配置) if s.config.Monitoring.Pprof.Enabled { pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof) if err != nil { logging.Error().Msg("Failed to create pprof handler: " + err.Error()) } else { - _ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false) - _ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false) + if err := s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("pprof", pprofHandler.Path(), err); err != nil { + return err + } + } + if err := s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("pprof", pprofHandler.Path()+"/", err); err != nil { + return err + } + } } } - // 注册缓存清理 API(如果配置) if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled { purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI) if err != nil { logging.Error().Msg("Failed to create cache purge handler: " + err.Error()) } else { - _ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false) + if err := s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("cache-purge", purgeHandler.Path(), err); err != nil { + return err + } + } } } - // 注册代理路由 - s.registerProxyRoutesWithLocationEngine(serverCfg) + if err := s.registerProxyRoutesWithLocationEngine(serverCfg); err != nil { + return err + } - // Lua 路由 - s.registerLuaRoutesWithLocationEngine(serverCfg) + if err := s.registerLuaRoutesWithLocationEngine(serverCfg); err != nil { + return err + } - // 静态文件服务 - s.registerStaticHandlersWithLocationEngine(serverCfg) + if err := s.registerStaticHandlersWithLocationEngine(serverCfg); err != nil { + return err + } // 标记 LocationEngine 初始化完成 s.locationEngine.MarkInitialized() From d9a7ab9ccaa1b0036aa3a91fa0e1161ec7b13f64 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 10:14:07 +0800 Subject: [PATCH 2/9] cleanup(config): remove dead ProxyCachePathConfig and CachePath field Disk cache implementation was previously removed but config structs remained. Remove ProxyCachePathConfig, Config.CachePath field, e2e WithCachePath helper, and docs reference. --- docs/llms.txt | 3 +-- internal/config/cache_config.go | 48 --------------------------------- internal/config/config.go | 3 +-- internal/e2e/testutil/config.go | 9 ------- 4 files changed, 2 insertions(+), 61 deletions(-) diff --git a/docs/llms.txt b/docs/llms.txt index 2089cc5..a4da944 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -58,8 +58,7 @@ http3: {} # HTTP/3 配置 resolver: {} # DNS 解析配置 performance: {} # 性能配置 shutdown: {} # 关闭配置 -include: [] # 配置引入 -cache_path: {} # 缓存路径配置 + include: [] # 配置引入 ``` ### 运行模式 diff --git a/internal/config/cache_config.go b/internal/config/cache_config.go index a5d5f5f..3b7e1e7 100644 --- a/internal/config/cache_config.go +++ b/internal/config/cache_config.go @@ -2,54 +2,6 @@ package config import "time" -// ProxyCachePathConfig 缓存路径配置(磁盘持久化)。 -// -// 配置磁盘缓存路径和相关参数,支持 L1/L2 分层缓存架构。 -// 配置后,代理缓存将持久化到磁盘,服务重启后可恢复。 -// -// 注意事项: -// - Path 为必填项,指定缓存根目录 -// - Levels 支持最多 3 级目录(如 "1:2:2") -// - MaxSize 为 0 表示不限制大小 -// - L1MaxEntries/L1MaxSize 为 0 时使用默认值 -// -// 使用示例: -// -// cache_path: -// path: "/var/cache/lolly" -// levels: "1:2" -// max_size: "1GB" -// inactive: "60m" -// l1_max_entries: 10000 -type ProxyCachePathConfig struct { - // Path 缓存根目录 - Path string `yaml:"path"` - - // Levels 目录层级,如 "1:2" 表示两级目录 - Levels string `yaml:"levels"` - - // MaxSize 最大缓存大小(字节) - MaxSize int64 `yaml:"max_size"` - - // Inactive 未访问淘汰时间 - Inactive time.Duration `yaml:"inactive"` - - // Purger 是否启用后台清理 - Purger bool `yaml:"purger"` - - // PurgerInterval 清理间隔 - PurgerInterval time.Duration `yaml:"purger_interval"` - - // L1MaxEntries L1 最大条目数 - L1MaxEntries int64 `yaml:"l1_max_entries"` - - // L1MaxSize L1 最大内存大小 - L1MaxSize int64 `yaml:"l1_max_size"` - - // PromoteThreshold 提升到 L1 的访问阈值 - PromoteThreshold int `yaml:"promote_threshold"` -} - // ProxyCacheConfig 代理缓存配置。 // // 缓存后端响应,减少重复请求,提高响应速度。 diff --git a/internal/config/config.go b/internal/config/config.go index 10a6e14..ff5d0a1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -81,8 +81,7 @@ type Config struct { Resolver ResolverConfig `yaml:"resolver"` Performance PerformanceConfig `yaml:"performance"` Shutdown ShutdownConfig `yaml:"shutdown"` - Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段 - CachePath *ProxyCachePathConfig `yaml:"cache_path"` // 缓存路径配置(磁盘持久化) + Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段 } // parseSize 解析大小字符串(支持 k, m 单位)。 diff --git a/internal/e2e/testutil/config.go b/internal/e2e/testutil/config.go index ea1319a..704447f 100644 --- a/internal/e2e/testutil/config.go +++ b/internal/e2e/testutil/config.go @@ -462,15 +462,6 @@ func (b *ConfigBuilder) WithRewrite(pattern, replacement string, opts ...Rewrite return b } -// WithCachePath 配置缓存路径。 -func (b *ConfigBuilder) WithCachePath(path string, maxSize int64) *ConfigBuilder { - b.cfg.CachePath = &config.ProxyCachePathConfig{ - Path: path, - MaxSize: maxSize, - } - return b -} - // WithResolver 配置 DNS 解析器。 func (b *ConfigBuilder) WithResolver(addresses []string, valid, timeout time.Duration) *ConfigBuilder { b.cfg.Resolver = config.ResolverConfig{ From 53ac4c84cd420df64e87dcbbfea038c774c87403 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 10:16:29 +0800 Subject: [PATCH 3/9] docs(AGENTS.md): fix stale references Remove non-existent config.example.yaml and .github/ directory references. Add --generate-config usage note. --- AGENTS.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 9fb5fe0..ff3487f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,7 +14,6 @@ | `go.sum` | 依赖版本锁定 | | `Makefile` | 构建脚本,支持多平台编译、测试、覆盖率 | | `lolly.yaml` | 默认配置文件示例 | -| `config.example.yaml` | 完整配置文件示例(所有字段枚举) | | `.gitignore` | Git 忽略规则 | ## Subdirectories @@ -27,7 +26,6 @@ | `examples/` | Lua 脚本示例 | | `html/` | 静态 HTML 文件(测试/示例) | | `scripts/` | 构建/测试辅助脚本(回归检测) | -| `.github/` | CI/CD 工作流定义 | ## For AI Agents @@ -43,6 +41,7 @@ - 运行测试前确保依赖已下载:`go mod download` - 测试覆盖率目标 >80% - 使用 `make check` 运行完整检查(fmt + lint + test) +- 使用 `lolly --generate-config` 生成完整配置文件模板 ### Common Patterns - 配置结构体使用 `yaml` 标签,通过 `gopkg.in/yaml.v3` 解析 From 2e9ddc7400c6cfc16af9fcc50b3ee5f9d2b8907e Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 10:20:33 +0800 Subject: [PATCH 4/9] feat(config): implement include directive with glob support Support loading config fragments from external files via include directive. Servers and streams are appended, variables merged with main config priority. Includes glob expansion, nested includes (depth limit 10), and circular include detection. --- internal/config/config.go | 70 ++++++++++++ internal/config/config_test.go | 195 +++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+) diff --git a/internal/config/config.go b/internal/config/config.go index ff5d0a1..c1c6076 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "strconv" "strings" @@ -138,6 +139,12 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("解析配置文件失败: %w", err) } + if len(cfg.Include) > 0 { + if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil { + return nil, fmt.Errorf("处理配置引入失败: %w", err) + } + } + if err := Validate(&cfg); err != nil { return nil, fmt.Errorf("配置验证失败: %w", err) } @@ -145,6 +152,69 @@ func Load(path string) (*Config, error) { return &cfg, nil } +const maxIncludeDepth = 10 + +func processIncludes(cfg *Config, baseDir string, depth int) error { + if depth >= maxIncludeDepth { + return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth) + } + + for _, inc := range cfg.Include { + pattern := inc.Path + if !filepath.IsAbs(pattern) { + pattern = filepath.Join(baseDir, pattern) + } + + matches, err := filepath.Glob(pattern) + if err != nil { + return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err) + } + if len(matches) == 0 { + return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path) + } + + for _, match := range matches { + info, err := os.Stat(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + if info.IsDir() { + continue + } + + data, err := os.ReadFile(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + + var included Config + if err := yaml.Unmarshal(data, &included); err != nil { + return fmt.Errorf("解析引入文件 %q 失败: %w", match, err) + } + + if len(included.Include) > 0 { + if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil { + return err + } + } + + cfg.Servers = append(cfg.Servers, included.Servers...) + cfg.Stream = append(cfg.Stream, included.Stream...) + for k, v := range included.Variables.Set { + if _, exists := cfg.Variables.Set[k]; !exists { + if cfg.Variables.Set == nil { + cfg.Variables.Set = make(map[string]string) + } + cfg.Variables.Set[k] = v + } + } + } + } + + cfg.Include = nil + return nil +} + // LoadFromString 从 YAML 字符串加载配置。 // // 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。 diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 131b27e..acb854f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -491,3 +491,198 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) { }) } } + +func TestLoad_Include(t *testing.T) { + t.Run("append servers from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" + name: main +include: + - path: "conf.d/*.yaml" +` + incCfg := ` +servers: + - listen: ":9090" + name: included +` + confDir := filepath.Join(tmpDir, "conf.d") + if err := os.MkdirAll(confDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(cfg.Servers)) + } + if cfg.Servers[0].Name != "main" { + t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main") + } + if cfg.Servers[1].Name != "included" { + t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included") + } + }) + + t.Run("merge variables with main priority", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +variables: + set: + app: lolly + env: production +include: + - path: "extra.yaml" +` + incCfg := ` +servers: + - listen: ":9090" +variables: + set: + app: other + debug: "true" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if cfg.Variables.Set["app"] != "lolly" { + t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly") + } + if cfg.Variables.Set["debug"] != "true" { + t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true") + } + if cfg.Variables.Set["env"] != "production" { + t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production") + } + }) + + t.Run("no matches returns error", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +include: + - path: "nonexistent/*.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err == nil { + t.Error("expected error for non-matching glob pattern") + } + }) + + t.Run("circular include detected", func(t *testing.T) { + tmpDir := t.TempDir() + + cfg1 := ` +servers: + - listen: ":8080" +include: + - path: "b.yaml" +` + cfg2 := ` +servers: + - listen: ":9090" +include: + - path: "a.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "a.yaml")) + if err == nil { + t.Error("expected error for circular include") + } + }) + + t.Run("empty include list is no-op", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if len(cfg.Servers) != 1 { + t.Errorf("expected 1 server, got %d", len(cfg.Servers)) + } + }) + + t.Run("append stream from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +stream: + - listen: ":5432" + protocol: tcp + upstream: + targets: + - addr: "127.0.0.1:9000" +include: + - path: "stream.yaml" +` + incCfg := ` +stream: + - listen: ":5433" + protocol: udp + upstream: + targets: + - addr: "127.0.0.1:9001" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Stream) != 2 { + t.Fatalf("expected 2 streams, got %d", len(cfg.Stream)) + } + }) +} From f3f78b24a862f31dd8eab2016e6e87c707289087 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 11:42:45 +0800 Subject: [PATCH 5/9] feat(server,app): implement proper config hot reload via SIGHUP createListener now checks pre-set s.listeners (Path 2) for hot reload, not just upgradeManager.IsChild() (Path 1). Add DupListener to dup FDs so old/new servers own independent listeners. Reload rebuilds HTTP/2 and HTTP/3. Add matchInheritedListener with TCP any-addr matching. Add requiresFullRestart with VHost server count detection. --- internal/app/app.go | 111 ++++++++++++++++++++++++++++ internal/app/app_test.go | 3 - internal/server/server.go | 103 +++++++++++++++++++++----- internal/server/server_test.go | 130 +++++++++++++++++++++++++++++++++ 4 files changed, 327 insertions(+), 20 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 2cff744..987c27c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -4,6 +4,7 @@ package app import ( "fmt" + "net" "os" "os/signal" "syscall" @@ -164,11 +165,121 @@ func (a *App) reloadConfig() { return } + if a.srv == nil { + a.cfg = newCfg + a.logger = logging.NewAppLogger(&newCfg.Logging) + a.logger.LogStartup("Config reloaded (no running server)", nil) + return + } + + if a.requiresFullRestart(newCfg) { + logging.Warn().Msg("Config requires full restart (listen address or mode changed). Use SIGUSR2 for graceful upgrade.") + return + } + + listeners := a.srv.GetListeners() + if len(listeners) == 0 { + a.logger.Error().Msg("Cannot reload: server has no saved listeners") + return + } + + duped := make([]net.Listener, len(listeners)) + for i, ln := range listeners { + duped[i], err = server.DupListener(ln) + if err != nil { + a.logger.Error().Err(err).Msg("Failed to dup listener for reload") + return + } + } + + newSrv := server.New(newCfg) + if a.resv != nil { + newSrv.SetResolver(a.resv) + } + newSrv.SetListeners(duped) + + startErr := make(chan error, 1) + go func() { + if err := newSrv.Start(); err != nil { + startErr <- err + } + }() + + select { + case err := <-startErr: + a.logger.Error().Err(err).Msg("Failed to start new server with reloaded config") + for _, ln := range duped { + _ = ln.Close() + } + return + case <-time.After(5 * time.Second): + } + + oldSrv := a.srv + oldHTTP2 := a.http2Srv + oldHTTP3 := a.http3Srv + + a.srv = newSrv a.cfg = newCfg a.logger = logging.NewAppLogger(&newCfg.Logging) + a.http2Srv = nil + a.http3Srv = nil + + a.initVariables() + a.initHTTP2() + a.initHTTP3() + + if a.upgradeMgr != nil { + a.upgradeMgr.SetListeners(newSrv.GetListeners()) + } + + go func() { + if oldHTTP2 != nil { + _ = oldHTTP2.Stop() + } + if oldHTTP3 != nil { + _ = oldHTTP3.Stop() + } + _ = oldSrv.GracefulStop(30 * time.Second) + }() + a.logger.LogStartup("Config reloaded successfully", nil) } +func (a *App) requiresFullRestart(newCfg *config.Config) bool { + if a.cfg.GetMode() != newCfg.GetMode() { + return true + } + oldMode := a.cfg.GetMode() + switch oldMode { + case config.ServerModeSingle: + if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 { + if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen { + return true + } + } + case config.ServerModeVHost: + if len(a.cfg.Servers) != len(newCfg.Servers) { + return true + } + if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 { + if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen { + return true + } + } + case config.ServerModeMultiServer: + if len(a.cfg.Servers) != len(newCfg.Servers) { + return true + } + for i := range a.cfg.Servers { + if a.cfg.Servers[i].Listen != newCfg.Servers[i].Listen { + return true + } + } + } + return false +} + func (a *App) gracefulUpgrade() { execPath, err := os.Executable() if err != nil { diff --git a/internal/app/app_test.go b/internal/app/app_test.go index dd8b32f..42e4dba 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -461,7 +461,6 @@ func TestHandleSignal_SIGINT(t *testing.T) { // TestHandleSignal_SIGHUP 测试 SIGHUP 信号处理(重载配置) func TestHandleSignal_SIGHUP(t *testing.T) { - // 创建临时配置文件 tmpDir := t.TempDir() cfgPath := filepath.Join(tmpDir, "config.yaml") cfgContent := ` @@ -1448,14 +1447,12 @@ logging: } app.logger = setupTestLogger() - // 发送 SIGHUP 信号 result := app.handleSignal(syscall.SIGHUP) if result != true { t.Error("Expected handleSignal(SIGHUP) to return true") } - // 验证配置已更新 if app.cfg.Servers[0].Listen != ":7070" { t.Errorf("Expected listen ':7070', got '%s'", app.cfg.Servers[0].Listen) } diff --git a/internal/server/server.go b/internal/server/server.go index ddba853..834d282 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -318,32 +318,31 @@ func (s *Server) Start() error { func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) { listenAddr := cfg.Listen + if s.upgradeManager != nil && s.upgradeManager.IsChild() { + inherited, _ := s.upgradeManager.GetInheritedListeners() + if ln := s.matchInheritedListener(inherited, listenAddr); ln != nil { + return ln, nil + } + } + + if len(s.listeners) > 0 { + if ln := s.matchInheritedListener(s.listeners, listenAddr); ln != nil { + return ln, nil + } + } + if strings.HasPrefix(listenAddr, "unix:") { - // Unix Socket 模式 socketPath := listenAddr[5:] - // 1. 检查继承的监听器(热升级场景) - if s.upgradeManager != nil && s.upgradeManager.IsChild() { - inherited, _ := s.upgradeManager.GetInheritedListeners() - for _, ln := range inherited { - if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath { - return ln, nil - } - } - } - - // 2. 清理旧 socket 文件 if _, err := os.Stat(socketPath); err == nil { _ = os.Remove(socketPath) } - // 3. 创建 Unix socket listener listener, err := net.Listen("unix", socketPath) if err != nil { return nil, fmt.Errorf("create unix socket failed: %w", err) } - // 4. 设置 socket 文件权限 mode := 0o666 if cfg.UnixSocket.Mode > 0 { mode = cfg.UnixSocket.Mode @@ -352,19 +351,89 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) logging.Warn().Err(err).Msg("Failed to set socket file permissions") } - // 5. 设置文件所有权(需要 root 权限) if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" { - // 简化处理:仅记录警告,实际实现需要 syscall.Chown logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped") } return listener, nil } - // TCP 模式 return net.Listen("tcp", listenAddr) } +func (s *Server) matchInheritedListener(inherited []net.Listener, listenAddr string) net.Listener { + if len(inherited) == 0 { + return nil + } + + if strings.HasPrefix(listenAddr, "unix:") { + socketPath := listenAddr[5:] + for _, ln := range inherited { + if ln == nil { + continue + } + if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath { + return ln + } + } + return nil + } + + for _, ln := range inherited { + if ln == nil { + continue + } + if ln.Addr().Network() != "tcp" { + continue + } + if s.tcpAddrMatch(ln.Addr().String(), listenAddr) { + return ln + } + } + return nil +} + +func (s *Server) tcpAddrMatch(inherited, target string) bool { + if inherited == target { + return true + } + host1, port1, err1 := net.SplitHostPort(inherited) + host2, port2, err2 := net.SplitHostPort(target) + if err1 != nil || err2 != nil { + return false + } + if port1 != port2 { + return false + } + return host1 == host2 || isAnyAddr(host1) || isAnyAddr(host2) +} + +func isAnyAddr(host string) bool { + return host == "" || host == "0.0.0.0" || host == "::" || host == "[::]" +} + +// DupListener 复制 listener 的文件描述符,返回独立的 listener。 +// +// 用于热重载场景:新旧 server 各自持有独立 FD,互不影响关闭操作。 +func DupListener(ln net.Listener) (net.Listener, error) { + switch l := ln.(type) { + case *net.TCPListener: + file, err := l.File() + if err != nil { + return nil, fmt.Errorf("dup tcp listener: %w", err) + } + return net.FileListener(file) + case *net.UnixListener: + file, err := l.File() + if err != nil { + return nil, fmt.Errorf("dup unix listener: %w", err) + } + return net.FileListener(file) + default: + return nil, fmt.Errorf("unsupported listener type: %T", ln) + } +} + // startSingleMode 单服务器模式启动。 // // 在单服务器模式下,创建单一路由器,注册代理路由和静态文件服务, diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 938cb80..2a4b76c 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -989,6 +989,136 @@ func TestCreateListener_UnixSocketCleanup(t *testing.T) { defer ln.Close() } +func TestDupListener_TCP(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + duped, err := DupListener(ln) + if err != nil { + t.Fatalf("DupListener() error: %v", err) + } + defer duped.Close() + + if duped.Addr().Network() != "tcp" { + t.Errorf("expected tcp, got %s", duped.Addr().Network()) + } + if duped.Addr().String() != ln.Addr().String() { + t.Errorf("expected same address %s, got %s", ln.Addr().String(), duped.Addr().String()) + } +} + +func TestDupListener_Unix(t *testing.T) { + dir := t.TempDir() + socketPath := dir + "/dup.sock" + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + duped, err := DupListener(ln) + if err != nil { + t.Fatalf("DupListener() error: %v", err) + } + defer duped.Close() +} + +func TestDupListener_Unsupported(t *testing.T) { + _, err := DupListener(struct{ net.Listener }{}) + if err == nil { + t.Error("expected error for unsupported type") + } +} + +func TestTcpAddrMatch(t *testing.T) { + s := &Server{} + + tests := []struct { + inherited string + target string + want bool + }{ + {"127.0.0.1:8080", "127.0.0.1:8080", true}, + {"0.0.0.0:8080", ":8080", true}, + {"[::]:8080", ":8080", true}, + {"0.0.0.0:8080", "0.0.0.0:8080", true}, + {"0.0.0.0:8080", "127.0.0.1:8080", true}, + {"127.0.0.1:8080", "0.0.0.0:8080", true}, + {"127.0.0.1:8080", ":9090", false}, + {"127.0.0.1:8080", "192.168.1.1:8080", false}, + } + + for _, tt := range tests { + got := s.tcpAddrMatch(tt.inherited, tt.target) + if got != tt.want { + t.Errorf("tcpAddrMatch(%q, %q) = %v, want %v", tt.inherited, tt.target, got, tt.want) + } + } +} + +func TestMatchInheritedListener_TCP(t *testing.T) { + s := &Server{} + + ln1, _ := net.Listen("tcp", "127.0.0.1:0") + defer ln1.Close() + + ln2, _ := net.Listen("tcp", "127.0.0.1:0") + defer ln2.Close() + + inherited := []net.Listener{ln1, ln2} + + result := s.matchInheritedListener(inherited, "0.0.0.0:99999") + if result != nil { + t.Error("expected nil for non-matching address") + } + + addr1 := ln1.Addr().String() + result = s.matchInheritedListener(inherited, addr1) + if result != ln1 { + t.Errorf("expected ln1 for address %s", addr1) + } +} + +func TestMatchInheritedListener_Empty(t *testing.T) { + s := &Server{} + result := s.matchInheritedListener(nil, ":8080") + if result != nil { + t.Error("expected nil for empty inherited list") + } +} + +func TestMatchInheritedListener_PresetListeners(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}}, + } + s := New(cfg) + + ln, err := s.createListener(&cfg.Servers[0]) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + s.SetListeners([]net.Listener{ln}) + + addr := ln.Addr().String() + cfg.Servers[0].Listen = addr + + matched, err := s.createListener(&cfg.Servers[0]) + if err != nil { + t.Fatalf("createListener with preset should reuse: %v", err) + } + if matched == nil { + t.Fatal("expected non-nil listener from preset match") + } + if matched.Addr().String() != addr { + t.Errorf("expected same address %s, got %s", addr, matched.Addr().String()) + } +} + // TestServer_StatsMethods 测试服务器统计方法。 func TestServer_StatsMethods(t *testing.T) { cfg := &config.Config{ From f58f19475249668d4bcc4a6d517f81d7e6db763f Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 11:44:14 +0800 Subject: [PATCH 6/9] fix(server): serialize listener creation in multi-server mode Remove VHost fallback during graceful upgrade. Serialize listener creation before parallel router/middleware setup to prevent concurrent inherited listener consumption. Fix tcpAddrMatch to match when either side is any-addr (0.0.0.0/::). --- internal/server/server.go | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 834d282..c55747d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -709,22 +709,28 @@ func (s *Server) startVHostMode() error { // // 注意事项: // - 每个服务器有独立的中间件配置 -// - 热升级场景下回退到虚拟主机模式 // - 使用 goroutine 并行启动多个服务器 func (s *Server) startMultiServerMode() error { - // 热升级检测:multi_server 热升级未实现,回退到 vhost 模式 - if os.Getenv("GRACEFUL_UPGRADE") == "1" { - logging.Warn().Msg("multi_server mode not implemented for graceful upgrade, falling back to vhost mode") - return s.startVHostMode() - } - s.fastServers = make([]*fasthttp.Server, len(s.config.Servers)) s.listeners = make([]net.Listener, len(s.config.Servers)) + for i := range s.config.Servers { + serverCfg := &s.config.Servers[i] + ln, err := s.createListener(serverCfg) + if err != nil { + for j := 0; j < i; j++ { + if s.listeners[j] != nil { + _ = s.listeners[j].Close() + } + } + return fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err) + } + s.listeners[i] = ln + } + var wg sync.WaitGroup errCh := make(chan error, len(s.config.Servers)) - // 并行创建监听器和 fasthttp.Server for i := range s.config.Servers { wg.Add(1) go func(idx int) { @@ -732,15 +738,6 @@ func (s *Server) startMultiServerMode() error { serverCfg := &s.config.Servers[idx] - // 创建监听器 - ln, err := s.createListener(serverCfg) - if err != nil { - errCh <- fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err) - return - } - s.listeners[idx] = ln - - // 创建路由器 router := handler.NewRouter() // 注册状态监控端点(仅默认服务器) From 556d40ceb04b939556ea6b977ebaecc45ed8e843 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 11:47:06 +0800 Subject: [PATCH 7/9] fix(matcher,server): use ConflictError in AddNamed and add tests Make AddNamed return *ConflictError for consistency with other Add* methods so handleRegistrationError treats named location conflicts as warnings instead of fatal errors. Add tests for handleRegistrationError covering both conflict and fatal error paths. --- internal/matcher/location.go | 8 ++++++-- internal/server/server_test.go | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/internal/matcher/location.go b/internal/matcher/location.go index c42a210..1df4521 100644 --- a/internal/matcher/location.go +++ b/internal/matcher/location.go @@ -173,8 +173,12 @@ func (e *LocationEngine) AddNamed(name string, handler fasthttp.RequestHandler) return errors.New("LocationEngine already initialized") } - if existing, ok := e.namedMatchers[name]; ok { - return fmt.Errorf("named location '@%s' already registered", existing.name) + if _, ok := e.namedMatchers[name]; ok { + return &ConflictError{ + Path: "@" + name, + ExistingType: "named", + NewType: "named", + } } matcher := NewNamedMatcher(name, handler) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 2a4b76c..a68c82d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -18,6 +18,7 @@ import ( "fmt" "net" "os" + "strings" "sync" "testing" "time" @@ -26,6 +27,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/matcher" "rua.plus/lolly/internal/middleware/accesslog" "rua.plus/lolly/internal/middleware/security" "rua.plus/lolly/internal/proxy" @@ -989,6 +991,27 @@ func TestCreateListener_UnixSocketCleanup(t *testing.T) { defer ln.Close() } +func TestHandleRegistrationError_ConflictWarning(t *testing.T) { + s := &Server{} + err := s.handleRegistrationError("proxy", "/api", + &matcher.ConflictError{Path: "/api", ExistingType: "exact", NewType: "prefix"}) + if err != nil { + t.Errorf("conflict should return nil, got: %v", err) + } +} + +func TestHandleRegistrationError_FatalError(t *testing.T) { + s := &Server{} + err := s.handleRegistrationError("proxy", "/api", + fmt.Errorf("invalid regex pattern: missing closing parenthesis")) + if err == nil { + t.Error("fatal error should return non-nil") + } + if !strings.Contains(err.Error(), "proxy route /api") { + t.Errorf("error should wrap context, got: %v", err) + } +} + func TestDupListener_TCP(t *testing.T) { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { From 9b8ce2a08a96e78fc5be080b367842cc7c63c4e4 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 11:51:17 +0800 Subject: [PATCH 8/9] fix(config): real circular include detection with visited set Replace depth-only detection with path-based visited set tracking. Detects cycles immediately on first revisit instead of after 10 depth iterations. Supports diamond patterns (A->B->shared, A->C->shared) via backtracking. Add self-include and diamond tests. Document that only servers/stream/variables are merged in defaults.go. --- internal/config/config.go | 19 +++++++-- internal/config/config_test.go | 78 +++++++++++++++++++++++++++++++++- internal/config/defaults.go | 1 + 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index c1c6076..3d49bfc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -140,7 +140,9 @@ func Load(path string) (*Config, error) { } if len(cfg.Include) > 0 { - if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil { + absPath, _ := filepath.Abs(path) + visited := map[string]bool{absPath: true} + if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil { return nil, fmt.Errorf("处理配置引入失败: %w", err) } } @@ -154,9 +156,9 @@ func Load(path string) (*Config, error) { const maxIncludeDepth = 10 -func processIncludes(cfg *Config, baseDir string, depth int) error { +func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error { if depth >= maxIncludeDepth { - return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth) + return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth) } for _, inc := range cfg.Include { @@ -174,11 +176,18 @@ func processIncludes(cfg *Config, baseDir string, depth int) error { } for _, match := range matches { + absMatch, _ := filepath.Abs(match) + if visited[absMatch] { + return fmt.Errorf("检测到循环引入: %s", absMatch) + } + visited[absMatch] = true + info, err := os.Stat(match) if err != nil { return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) } if info.IsDir() { + delete(visited, absMatch) continue } @@ -193,7 +202,7 @@ func processIncludes(cfg *Config, baseDir string, depth int) error { } if len(included.Include) > 0 { - if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil { + if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil { return err } } @@ -208,6 +217,8 @@ func processIncludes(cfg *Config, baseDir string, depth int) error { cfg.Variables.Set[k] = v } } + + delete(visited, absMatch) } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index acb854f..1bf5d9b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,6 +4,7 @@ package config import ( "os" "path/filepath" + "strings" "testing" ) @@ -598,7 +599,7 @@ include: } }) - t.Run("circular include detected", func(t *testing.T) { + t.Run("circular include detected immediately", func(t *testing.T) { tmpDir := t.TempDir() cfg1 := ` @@ -624,6 +625,81 @@ include: if err == nil { t.Error("expected error for circular include") } + if !strings.Contains(err.Error(), "循环引入") { + t.Errorf("error should mention circular include, got: %v", err) + } + }) + + t.Run("self include detected", func(t *testing.T) { + tmpDir := t.TempDir() + + cfg := ` +servers: + - listen: ":8080" +include: + - path: "a.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "a.yaml")) + if err == nil { + t.Error("expected error for self include") + } + }) + + t.Run("diamond include works", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +include: + - path: "b.yaml" + - path: "c.yaml" +` + bCfg := ` +servers: + - listen: ":9090" +include: + - path: "shared.yaml" +` + cCfg := ` +servers: + - listen: ":9091" +include: + - path: "shared.yaml" +` + sharedCfg := ` +variables: + set: + shared: value +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("diamond include should work: %v", err) + } + + if len(cfg.Servers) != 3 { + t.Errorf("expected 3 servers, got %d", len(cfg.Servers)) + } + if cfg.Variables.Set["shared"] != "value" { + t.Error("shared variable should be merged") + } }) t.Run("empty include list is no-op", func(t *testing.T) { diff --git a/internal/config/defaults.go b/internal/config/defaults.go index b4f4bce..1f9ec9d 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -734,6 +734,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n") buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n") buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n") + buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n") return buf.Bytes(), nil } From 728a9f454b5ebd0d328e7f5df5bb39356999645f Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 13:16:05 +0800 Subject: [PATCH 9/9] fix(server,app,config): address code review findings - Fix FD leak in DupListener: close *os.File after net.FileListener - Add cleanup of partially-duped listeners on DupListener failure - Make reload timeout configurable via shutdown.reload_timeout - Handle filepath.Abs errors in processIncludes instead of ignoring - Use net.ParseIP in isAnyAddr for robust IPv6 support --- internal/app/app.go | 10 +++++++++- internal/config/config.go | 10 ++++++++-- internal/config/defaults.go | 2 ++ internal/config/performance_config.go | 5 +++++ internal/server/server.go | 8 +++++++- 5 files changed, 31 insertions(+), 4 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 987c27c..aa47e68 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -187,6 +187,9 @@ func (a *App) reloadConfig() { for i, ln := range listeners { duped[i], err = server.DupListener(ln) if err != nil { + for j := 0; j < i; j++ { + _ = duped[j].Close() + } a.logger.Error().Err(err).Msg("Failed to dup listener for reload") return } @@ -205,6 +208,11 @@ func (a *App) reloadConfig() { } }() + reloadTimeout := a.cfg.Shutdown.ReloadTimeout + if reloadTimeout <= 0 { + reloadTimeout = 5 * time.Second + } + select { case err := <-startErr: a.logger.Error().Err(err).Msg("Failed to start new server with reloaded config") @@ -212,7 +220,7 @@ func (a *App) reloadConfig() { _ = ln.Close() } return - case <-time.After(5 * time.Second): + case <-time.After(reloadTimeout): } oldSrv := a.srv diff --git a/internal/config/config.go b/internal/config/config.go index 3d49bfc..b790326 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -140,7 +140,10 @@ func Load(path string) (*Config, error) { } if len(cfg.Include) > 0 { - absPath, _ := filepath.Abs(path) + absPath, err := filepath.Abs(path) + if err != nil { + return nil, fmt.Errorf("获取配置文件绝对路径失败: %w", err) + } visited := map[string]bool{absPath: true} if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil { return nil, fmt.Errorf("处理配置引入失败: %w", err) @@ -176,7 +179,10 @@ func processIncludes(cfg *Config, baseDir string, depth int, visited map[string] } for _, match := range matches { - absMatch, _ := filepath.Abs(match) + absMatch, err := filepath.Abs(match) + if err != nil { + return fmt.Errorf("获取引入文件绝对路径失败 %q: %w", match, err) + } if visited[absMatch] { return fmt.Errorf("检测到循环引入: %s", absMatch) } diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 1f9ec9d..dc2e0e4 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -222,6 +222,7 @@ func DefaultConfig() *Config { Shutdown: ShutdownConfig{ GracefulTimeout: 30 * time.Second, FastTimeout: 5 * time.Second, + ReloadTimeout: 5 * time.Second, }, } } @@ -612,6 +613,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("shutdown:\n") fmt.Fprintf(&buf, " graceful_timeout: %ds # 优雅停止超时(SIGQUIT),等待活跃请求完成(0=使用默认30s)\n", int(cfg.Shutdown.GracefulTimeout.Seconds())) fmt.Fprintf(&buf, " fast_timeout: %ds # 快速停止超时(SIGINT/SIGTERM,0=使用默认5s)\n", int(cfg.Shutdown.FastTimeout.Seconds())) + fmt.Fprintf(&buf, " reload_timeout: %ds # 热重载启动等待超时(SIGHUP,0=使用默认5s)\n", int(cfg.Shutdown.ReloadTimeout.Seconds())) buf.WriteString("\n") // stream 配置 diff --git a/internal/config/performance_config.go b/internal/config/performance_config.go index a644264..b2926b0 100644 --- a/internal/config/performance_config.go +++ b/internal/config/performance_config.go @@ -181,6 +181,11 @@ type ShutdownConfig struct { // 接收到 SIGINT 或 SIGTERM 信号后,等待服务器关闭的最大时间 // 默认: 5s(当值为 0 时使用默认值) FastTimeout time.Duration `yaml:"fast_timeout"` + + // ReloadTimeout 热重载启动等待超时(SIGHUP) + // 等待新服务器启动完成的最大时间,超时后视为启动成功 + // 默认: 5s(当值为 0 时使用默认值) + ReloadTimeout time.Duration `yaml:"reload_timeout"` } // ResolverConfig DNS 解析器配置。 diff --git a/internal/server/server.go b/internal/server/server.go index c55747d..5c0e4ce 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -409,7 +409,11 @@ func (s *Server) tcpAddrMatch(inherited, target string) bool { } func isAnyAddr(host string) bool { - return host == "" || host == "0.0.0.0" || host == "::" || host == "[::]" + if host == "" { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsUnspecified() } // DupListener 复制 listener 的文件描述符,返回独立的 listener。 @@ -422,12 +426,14 @@ func DupListener(ln net.Listener) (net.Listener, error) { if err != nil { return nil, fmt.Errorf("dup tcp listener: %w", err) } + defer file.Close() return net.FileListener(file) case *net.UnixListener: file, err := l.File() if err != nil { return nil, fmt.Errorf("dup unix listener: %w", err) } + defer file.Close() return net.FileListener(file) default: return nil, fmt.Errorf("unsupported listener type: %T", ln)