From 2e9ddc7400c6cfc16af9fcc50b3ee5f9d2b8907e Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 10:20:33 +0800 Subject: [PATCH] feat(config): implement include directive with glob support Support loading config fragments from external files via include directive. Servers and streams are appended, variables merged with main config priority. Includes glob expansion, nested includes (depth limit 10), and circular include detection. --- internal/config/config.go | 70 ++++++++++++ internal/config/config_test.go | 195 +++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+) diff --git a/internal/config/config.go b/internal/config/config.go index ff5d0a1..c1c6076 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "strconv" "strings" @@ -138,6 +139,12 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("解析配置文件失败: %w", err) } + if len(cfg.Include) > 0 { + if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil { + return nil, fmt.Errorf("处理配置引入失败: %w", err) + } + } + if err := Validate(&cfg); err != nil { return nil, fmt.Errorf("配置验证失败: %w", err) } @@ -145,6 +152,69 @@ func Load(path string) (*Config, error) { return &cfg, nil } +const maxIncludeDepth = 10 + +func processIncludes(cfg *Config, baseDir string, depth int) error { + if depth >= maxIncludeDepth { + return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth) + } + + for _, inc := range cfg.Include { + pattern := inc.Path + if !filepath.IsAbs(pattern) { + pattern = filepath.Join(baseDir, pattern) + } + + matches, err := filepath.Glob(pattern) + if err != nil { + return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err) + } + if len(matches) == 0 { + return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path) + } + + for _, match := range matches { + info, err := os.Stat(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + if info.IsDir() { + continue + } + + data, err := os.ReadFile(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + + var included Config + if err := yaml.Unmarshal(data, &included); err != nil { + return fmt.Errorf("解析引入文件 %q 失败: %w", match, err) + } + + if len(included.Include) > 0 { + if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil { + return err + } + } + + cfg.Servers = append(cfg.Servers, included.Servers...) + cfg.Stream = append(cfg.Stream, included.Stream...) + for k, v := range included.Variables.Set { + if _, exists := cfg.Variables.Set[k]; !exists { + if cfg.Variables.Set == nil { + cfg.Variables.Set = make(map[string]string) + } + cfg.Variables.Set[k] = v + } + } + } + } + + cfg.Include = nil + return nil +} + // LoadFromString 从 YAML 字符串加载配置。 // // 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。 diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 131b27e..acb854f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -491,3 +491,198 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) { }) } } + +func TestLoad_Include(t *testing.T) { + t.Run("append servers from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" + name: main +include: + - path: "conf.d/*.yaml" +` + incCfg := ` +servers: + - listen: ":9090" + name: included +` + confDir := filepath.Join(tmpDir, "conf.d") + if err := os.MkdirAll(confDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(cfg.Servers)) + } + if cfg.Servers[0].Name != "main" { + t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main") + } + if cfg.Servers[1].Name != "included" { + t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included") + } + }) + + t.Run("merge variables with main priority", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +variables: + set: + app: lolly + env: production +include: + - path: "extra.yaml" +` + incCfg := ` +servers: + - listen: ":9090" +variables: + set: + app: other + debug: "true" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if cfg.Variables.Set["app"] != "lolly" { + t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly") + } + if cfg.Variables.Set["debug"] != "true" { + t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true") + } + if cfg.Variables.Set["env"] != "production" { + t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production") + } + }) + + t.Run("no matches returns error", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +include: + - path: "nonexistent/*.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err == nil { + t.Error("expected error for non-matching glob pattern") + } + }) + + t.Run("circular include detected", func(t *testing.T) { + tmpDir := t.TempDir() + + cfg1 := ` +servers: + - listen: ":8080" +include: + - path: "b.yaml" +` + cfg2 := ` +servers: + - listen: ":9090" +include: + - path: "a.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "a.yaml")) + if err == nil { + t.Error("expected error for circular include") + } + }) + + t.Run("empty include list is no-op", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if len(cfg.Servers) != 1 { + t.Errorf("expected 1 server, got %d", len(cfg.Servers)) + } + }) + + t.Run("append stream from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +stream: + - listen: ":5432" + protocol: tcp + upstream: + targets: + - addr: "127.0.0.1:9000" +include: + - path: "stream.yaml" +` + incCfg := ` +stream: + - listen: ":5433" + protocol: udp + upstream: + targets: + - addr: "127.0.0.1:9001" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Stream) != 2 { + t.Fatalf("expected 2 streams, got %d", len(cfg.Stream)) + } + }) +}