diff --git a/internal/config/config.go b/internal/config/config.go index ff5d0a1..c1c6076 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "strconv" "strings" @@ -138,6 +139,12 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("解析配置文件失败: %w", err) } + if len(cfg.Include) > 0 { + if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil { + return nil, fmt.Errorf("处理配置引入失败: %w", err) + } + } + if err := Validate(&cfg); err != nil { return nil, fmt.Errorf("配置验证失败: %w", err) } @@ -145,6 +152,69 @@ func Load(path string) (*Config, error) { return &cfg, nil } +const maxIncludeDepth = 10 + +func processIncludes(cfg *Config, baseDir string, depth int) error { + if depth >= maxIncludeDepth { + return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth) + } + + for _, inc := range cfg.Include { + pattern := inc.Path + if !filepath.IsAbs(pattern) { + pattern = filepath.Join(baseDir, pattern) + } + + matches, err := filepath.Glob(pattern) + if err != nil { + return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err) + } + if len(matches) == 0 { + return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path) + } + + for _, match := range matches { + info, err := os.Stat(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + if info.IsDir() { + continue + } + + data, err := os.ReadFile(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + + var included Config + if err := yaml.Unmarshal(data, &included); err != nil { + return fmt.Errorf("解析引入文件 %q 失败: %w", match, err) + } + + if len(included.Include) > 0 { + if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil { + return err + } + } + + cfg.Servers = append(cfg.Servers, included.Servers...) + cfg.Stream = append(cfg.Stream, included.Stream...) + for k, v := range included.Variables.Set { + if _, exists := cfg.Variables.Set[k]; !exists { + if cfg.Variables.Set == nil { + cfg.Variables.Set = make(map[string]string) + } + cfg.Variables.Set[k] = v + } + } + } + } + + cfg.Include = nil + return nil +} + // LoadFromString 从 YAML 字符串加载配置。 // // 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。 diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 131b27e..acb854f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -491,3 +491,198 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) { }) } } + +func TestLoad_Include(t *testing.T) { + t.Run("append servers from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" + name: main +include: + - path: "conf.d/*.yaml" +` + incCfg := ` +servers: + - listen: ":9090" + name: included +` + confDir := filepath.Join(tmpDir, "conf.d") + if err := os.MkdirAll(confDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(cfg.Servers)) + } + if cfg.Servers[0].Name != "main" { + t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main") + } + if cfg.Servers[1].Name != "included" { + t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included") + } + }) + + t.Run("merge variables with main priority", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +variables: + set: + app: lolly + env: production +include: + - path: "extra.yaml" +` + incCfg := ` +servers: + - listen: ":9090" +variables: + set: + app: other + debug: "true" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if cfg.Variables.Set["app"] != "lolly" { + t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly") + } + if cfg.Variables.Set["debug"] != "true" { + t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true") + } + if cfg.Variables.Set["env"] != "production" { + t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production") + } + }) + + t.Run("no matches returns error", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +include: + - path: "nonexistent/*.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err == nil { + t.Error("expected error for non-matching glob pattern") + } + }) + + t.Run("circular include detected", func(t *testing.T) { + tmpDir := t.TempDir() + + cfg1 := ` +servers: + - listen: ":8080" +include: + - path: "b.yaml" +` + cfg2 := ` +servers: + - listen: ":9090" +include: + - path: "a.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "a.yaml")) + if err == nil { + t.Error("expected error for circular include") + } + }) + + t.Run("empty include list is no-op", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if len(cfg.Servers) != 1 { + t.Errorf("expected 1 server, got %d", len(cfg.Servers)) + } + }) + + t.Run("append stream from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +stream: + - listen: ":5432" + protocol: tcp + upstream: + targets: + - addr: "127.0.0.1:9000" +include: + - path: "stream.yaml" +` + incCfg := ` +stream: + - listen: ":5433" + protocol: udp + upstream: + targets: + - addr: "127.0.0.1:9001" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Stream) != 2 { + t.Fatalf("expected 2 streams, got %d", len(cfg.Stream)) + } + }) +}