From 9b8ce2a08a96e78fc5be080b367842cc7c63c4e4 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 11:51:17 +0800 Subject: [PATCH] fix(config): real circular include detection with visited set Replace depth-only detection with path-based visited set tracking. Detects cycles immediately on first revisit instead of after 10 depth iterations. Supports diamond patterns (A->B->shared, A->C->shared) via backtracking. Add self-include and diamond tests. Document that only servers/stream/variables are merged in defaults.go. --- internal/config/config.go | 19 +++++++-- internal/config/config_test.go | 78 +++++++++++++++++++++++++++++++++- internal/config/defaults.go | 1 + 3 files changed, 93 insertions(+), 5 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index c1c6076..3d49bfc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -140,7 +140,9 @@ func Load(path string) (*Config, error) { } if len(cfg.Include) > 0 { - if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil { + absPath, _ := filepath.Abs(path) + visited := map[string]bool{absPath: true} + if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil { return nil, fmt.Errorf("处理配置引入失败: %w", err) } } @@ -154,9 +156,9 @@ func Load(path string) (*Config, error) { const maxIncludeDepth = 10 -func processIncludes(cfg *Config, baseDir string, depth int) error { +func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error { if depth >= maxIncludeDepth { - return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth) + return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth) } for _, inc := range cfg.Include { @@ -174,11 +176,18 @@ func processIncludes(cfg *Config, baseDir string, depth int) error { } for _, match := range matches { + absMatch, _ := filepath.Abs(match) + if visited[absMatch] { + return fmt.Errorf("检测到循环引入: %s", absMatch) + } + visited[absMatch] = true + info, err := os.Stat(match) if err != nil { return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) } if info.IsDir() { + delete(visited, absMatch) continue } @@ -193,7 +202,7 @@ func processIncludes(cfg *Config, baseDir string, depth int) error { } if len(included.Include) > 0 { - if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil { + if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil { return err } } @@ -208,6 +217,8 @@ func processIncludes(cfg *Config, baseDir string, depth int) error { cfg.Variables.Set[k] = v } } + + delete(visited, absMatch) } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index acb854f..1bf5d9b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,6 +4,7 @@ package config import ( "os" "path/filepath" + "strings" "testing" ) @@ -598,7 +599,7 @@ include: } }) - t.Run("circular include detected", func(t *testing.T) { + t.Run("circular include detected immediately", func(t *testing.T) { tmpDir := t.TempDir() cfg1 := ` @@ -624,6 +625,81 @@ include: if err == nil { t.Error("expected error for circular include") } + if !strings.Contains(err.Error(), "循环引入") { + t.Errorf("error should mention circular include, got: %v", err) + } + }) + + t.Run("self include detected", func(t *testing.T) { + tmpDir := t.TempDir() + + cfg := ` +servers: + - listen: ":8080" +include: + - path: "a.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "a.yaml")) + if err == nil { + t.Error("expected error for self include") + } + }) + + t.Run("diamond include works", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +include: + - path: "b.yaml" + - path: "c.yaml" +` + bCfg := ` +servers: + - listen: ":9090" +include: + - path: "shared.yaml" +` + cCfg := ` +servers: + - listen: ":9091" +include: + - path: "shared.yaml" +` + sharedCfg := ` +variables: + set: + shared: value +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("diamond include should work: %v", err) + } + + if len(cfg.Servers) != 3 { + t.Errorf("expected 3 servers, got %d", len(cfg.Servers)) + } + if cfg.Variables.Set["shared"] != "value" { + t.Error("shared variable should be merged") + } }) t.Run("empty include list is no-op", func(t *testing.T) { diff --git a/internal/config/defaults.go b/internal/config/defaults.go index b4f4bce..1f9ec9d 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -734,6 +734,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n") buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n") buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n") + buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n") return buf.Bytes(), nil }