fix(config): real circular include detection with visited set

Replace depth-only detection with path-based visited set tracking.
Detects cycles immediately on first revisit instead of after 10 depth
iterations. Supports diamond patterns (A->B->shared, A->C->shared)
via backtracking. Add self-include and diamond tests. Document that
only servers/stream/variables are merged in defaults.go.
This commit is contained in:
xfy 2026-06-03 11:51:17 +08:00
parent 556d40ceb0
commit 9b8ce2a08a
3 changed files with 93 additions and 5 deletions

View File

@ -140,7 +140,9 @@ func Load(path string) (*Config, error) {
}
if len(cfg.Include) > 0 {
if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil {
absPath, _ := filepath.Abs(path)
visited := map[string]bool{absPath: true}
if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil {
return nil, fmt.Errorf("处理配置引入失败: %w", err)
}
}
@ -154,9 +156,9 @@ func Load(path string) (*Config, error) {
const maxIncludeDepth = 10
func processIncludes(cfg *Config, baseDir string, depth int) error {
func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error {
if depth >= maxIncludeDepth {
return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth)
return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth)
}
for _, inc := range cfg.Include {
@ -174,11 +176,18 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
}
for _, match := range matches {
absMatch, _ := filepath.Abs(match)
if visited[absMatch] {
return fmt.Errorf("检测到循环引入: %s", absMatch)
}
visited[absMatch] = true
info, err := os.Stat(match)
if err != nil {
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
}
if info.IsDir() {
delete(visited, absMatch)
continue
}
@ -193,7 +202,7 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
}
if len(included.Include) > 0 {
if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil {
if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil {
return err
}
}
@ -208,6 +217,8 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
cfg.Variables.Set[k] = v
}
}
delete(visited, absMatch)
}
}

View File

@ -4,6 +4,7 @@ package config
import (
"os"
"path/filepath"
"strings"
"testing"
)
@ -598,7 +599,7 @@ include:
}
})
t.Run("circular include detected", func(t *testing.T) {
t.Run("circular include detected immediately", func(t *testing.T) {
tmpDir := t.TempDir()
cfg1 := `
@ -624,6 +625,81 @@ include:
if err == nil {
t.Error("expected error for circular include")
}
if !strings.Contains(err.Error(), "循环引入") {
t.Errorf("error should mention circular include, got: %v", err)
}
})
t.Run("self include detected", func(t *testing.T) {
tmpDir := t.TempDir()
cfg := `
servers:
- listen: ":8080"
include:
- path: "a.yaml"
`
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil {
t.Fatal(err)
}
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
if err == nil {
t.Error("expected error for self include")
}
})
t.Run("diamond include works", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
include:
- path: "b.yaml"
- path: "c.yaml"
`
bCfg := `
servers:
- listen: ":9090"
include:
- path: "shared.yaml"
`
cCfg := `
servers:
- listen: ":9091"
include:
- path: "shared.yaml"
`
sharedCfg := `
variables:
set:
shared: value
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("diamond include should work: %v", err)
}
if len(cfg.Servers) != 3 {
t.Errorf("expected 3 servers, got %d", len(cfg.Servers))
}
if cfg.Variables.Set["shared"] != "value" {
t.Error("shared variable should be merged")
}
})
t.Run("empty include list is no-op", func(t *testing.T) {

View File

@ -734,6 +734,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n")
buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n")
buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n")
buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n")
return buf.Bytes(), nil
}