fix(config): real circular include detection with visited set
Replace depth-only detection with path-based visited set tracking. Detects cycles immediately on first revisit instead of after 10 depth iterations. Supports diamond patterns (A->B->shared, A->C->shared) via backtracking. Add self-include and diamond tests. Document that only servers/stream/variables are merged in defaults.go.
This commit is contained in:
parent
556d40ceb0
commit
9b8ce2a08a
@ -140,7 +140,9 @@ func Load(path string) (*Config, error) {
|
||||
}
|
||||
|
||||
if len(cfg.Include) > 0 {
|
||||
if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil {
|
||||
absPath, _ := filepath.Abs(path)
|
||||
visited := map[string]bool{absPath: true}
|
||||
if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil {
|
||||
return nil, fmt.Errorf("处理配置引入失败: %w", err)
|
||||
}
|
||||
}
|
||||
@ -154,9 +156,9 @@ func Load(path string) (*Config, error) {
|
||||
|
||||
const maxIncludeDepth = 10
|
||||
|
||||
func processIncludes(cfg *Config, baseDir string, depth int) error {
|
||||
func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error {
|
||||
if depth >= maxIncludeDepth {
|
||||
return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth)
|
||||
return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth)
|
||||
}
|
||||
|
||||
for _, inc := range cfg.Include {
|
||||
@ -174,11 +176,18 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
absMatch, _ := filepath.Abs(match)
|
||||
if visited[absMatch] {
|
||||
return fmt.Errorf("检测到循环引入: %s", absMatch)
|
||||
}
|
||||
visited[absMatch] = true
|
||||
|
||||
info, err := os.Stat(match)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
delete(visited, absMatch)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -193,7 +202,7 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
|
||||
}
|
||||
|
||||
if len(included.Include) > 0 {
|
||||
if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil {
|
||||
if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -208,6 +217,8 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
|
||||
cfg.Variables.Set[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
delete(visited, absMatch)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ package config
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -598,7 +599,7 @@ include:
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("circular include detected", func(t *testing.T) {
|
||||
t.Run("circular include detected immediately", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg1 := `
|
||||
@ -624,6 +625,81 @@ include:
|
||||
if err == nil {
|
||||
t.Error("expected error for circular include")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "循环引入") {
|
||||
t.Errorf("error should mention circular include, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("self include detected", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "a.yaml"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
|
||||
if err == nil {
|
||||
t.Error("expected error for self include")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("diamond include works", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "b.yaml"
|
||||
- path: "c.yaml"
|
||||
`
|
||||
bCfg := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
include:
|
||||
- path: "shared.yaml"
|
||||
`
|
||||
cCfg := `
|
||||
servers:
|
||||
- listen: ":9091"
|
||||
include:
|
||||
- path: "shared.yaml"
|
||||
`
|
||||
sharedCfg := `
|
||||
variables:
|
||||
set:
|
||||
shared: value
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("diamond include should work: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Servers) != 3 {
|
||||
t.Errorf("expected 3 servers, got %d", len(cfg.Servers))
|
||||
}
|
||||
if cfg.Variables.Set["shared"] != "value" {
|
||||
t.Error("shared variable should be merged")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty include list is no-op", func(t *testing.T) {
|
||||
|
||||
@ -734,6 +734,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
||||
buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n")
|
||||
buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n")
|
||||
buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n")
|
||||
buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n")
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user