fix(config): real circular include detection with visited set
Replace depth-only detection with path-based visited set tracking. Detects cycles immediately on first revisit instead of after 10 depth iterations. Supports diamond patterns (A->B->shared, A->C->shared) via backtracking. Add self-include and diamond tests. Document that only servers/stream/variables are merged in defaults.go.
This commit is contained in:
parent
556d40ceb0
commit
9b8ce2a08a
@ -140,7 +140,9 @@ func Load(path string) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(cfg.Include) > 0 {
|
if len(cfg.Include) > 0 {
|
||||||
if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil {
|
absPath, _ := filepath.Abs(path)
|
||||||
|
visited := map[string]bool{absPath: true}
|
||||||
|
if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil {
|
||||||
return nil, fmt.Errorf("处理配置引入失败: %w", err)
|
return nil, fmt.Errorf("处理配置引入失败: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,9 +156,9 @@ func Load(path string) (*Config, error) {
|
|||||||
|
|
||||||
const maxIncludeDepth = 10
|
const maxIncludeDepth = 10
|
||||||
|
|
||||||
func processIncludes(cfg *Config, baseDir string, depth int) error {
|
func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error {
|
||||||
if depth >= maxIncludeDepth {
|
if depth >= maxIncludeDepth {
|
||||||
return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth)
|
return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, inc := range cfg.Include {
|
for _, inc := range cfg.Include {
|
||||||
@ -174,11 +176,18 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, match := range matches {
|
for _, match := range matches {
|
||||||
|
absMatch, _ := filepath.Abs(match)
|
||||||
|
if visited[absMatch] {
|
||||||
|
return fmt.Errorf("检测到循环引入: %s", absMatch)
|
||||||
|
}
|
||||||
|
visited[absMatch] = true
|
||||||
|
|
||||||
info, err := os.Stat(match)
|
info, err := os.Stat(match)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
|
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
|
||||||
}
|
}
|
||||||
if info.IsDir() {
|
if info.IsDir() {
|
||||||
|
delete(visited, absMatch)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,7 +202,7 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(included.Include) > 0 {
|
if len(included.Include) > 0 {
|
||||||
if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil {
|
if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -208,6 +217,8 @@ func processIncludes(cfg *Config, baseDir string, depth int) error {
|
|||||||
cfg.Variables.Set[k] = v
|
cfg.Variables.Set[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
delete(visited, absMatch)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -598,7 +599,7 @@ include:
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("circular include detected", func(t *testing.T) {
|
t.Run("circular include detected immediately", func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
cfg1 := `
|
cfg1 := `
|
||||||
@ -624,6 +625,81 @@ include:
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error for circular include")
|
t.Error("expected error for circular include")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(err.Error(), "循环引入") {
|
||||||
|
t.Errorf("error should mention circular include, got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("self include detected", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
cfg := `
|
||||||
|
servers:
|
||||||
|
- listen: ":8080"
|
||||||
|
include:
|
||||||
|
- path: "a.yaml"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for self include")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("diamond include works", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
mainCfg := `
|
||||||
|
servers:
|
||||||
|
- listen: ":8080"
|
||||||
|
include:
|
||||||
|
- path: "b.yaml"
|
||||||
|
- path: "c.yaml"
|
||||||
|
`
|
||||||
|
bCfg := `
|
||||||
|
servers:
|
||||||
|
- listen: ":9090"
|
||||||
|
include:
|
||||||
|
- path: "shared.yaml"
|
||||||
|
`
|
||||||
|
cCfg := `
|
||||||
|
servers:
|
||||||
|
- listen: ":9091"
|
||||||
|
include:
|
||||||
|
- path: "shared.yaml"
|
||||||
|
`
|
||||||
|
sharedCfg := `
|
||||||
|
variables:
|
||||||
|
set:
|
||||||
|
shared: value
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("diamond include should work: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.Servers) != 3 {
|
||||||
|
t.Errorf("expected 3 servers, got %d", len(cfg.Servers))
|
||||||
|
}
|
||||||
|
if cfg.Variables.Set["shared"] != "value" {
|
||||||
|
t.Error("shared variable should be merged")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("empty include list is no-op", func(t *testing.T) {
|
t.Run("empty include list is no-op", func(t *testing.T) {
|
||||||
|
|||||||
@ -734,6 +734,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n")
|
buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n")
|
||||||
buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n")
|
buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n")
|
||||||
buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n")
|
buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n")
|
||||||
|
buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n")
|
||||||
|
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user