feat(config): implement include directive with glob support

Support loading config fragments from external files via include
directive. Servers and streams are appended, variables merged with
main config priority. Includes glob expansion, nested includes
(depth limit 10), and circular include detection.
This commit is contained in:
xfy 2026-06-03 10:20:33 +08:00
parent 53ac4c84cd
commit 2e9ddc7400
2 changed files with 265 additions and 0 deletions

View File

@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
@ -138,6 +139,12 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
if len(cfg.Include) > 0 {
if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil {
return nil, fmt.Errorf("处理配置引入失败: %w", err)
}
}
if err := Validate(&cfg); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err)
}
@ -145,6 +152,69 @@ func Load(path string) (*Config, error) {
return &cfg, nil
}
const maxIncludeDepth = 10
func processIncludes(cfg *Config, baseDir string, depth int) error {
if depth >= maxIncludeDepth {
return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth)
}
for _, inc := range cfg.Include {
pattern := inc.Path
if !filepath.IsAbs(pattern) {
pattern = filepath.Join(baseDir, pattern)
}
matches, err := filepath.Glob(pattern)
if err != nil {
return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err)
}
if len(matches) == 0 {
return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path)
}
for _, match := range matches {
info, err := os.Stat(match)
if err != nil {
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
}
if info.IsDir() {
continue
}
data, err := os.ReadFile(match)
if err != nil {
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
}
var included Config
if err := yaml.Unmarshal(data, &included); err != nil {
return fmt.Errorf("解析引入文件 %q 失败: %w", match, err)
}
if len(included.Include) > 0 {
if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil {
return err
}
}
cfg.Servers = append(cfg.Servers, included.Servers...)
cfg.Stream = append(cfg.Stream, included.Stream...)
for k, v := range included.Variables.Set {
if _, exists := cfg.Variables.Set[k]; !exists {
if cfg.Variables.Set == nil {
cfg.Variables.Set = make(map[string]string)
}
cfg.Variables.Set[k] = v
}
}
}
}
cfg.Include = nil
return nil
}
// LoadFromString 从 YAML 字符串加载配置。
//
// 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。

View File

@ -491,3 +491,198 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) {
})
}
}
func TestLoad_Include(t *testing.T) {
t.Run("append servers from include", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
name: main
include:
- path: "conf.d/*.yaml"
`
incCfg := `
servers:
- listen: ":9090"
name: included
`
confDir := filepath.Join(tmpDir, "conf.d")
if err := os.MkdirAll(confDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if len(cfg.Servers) != 2 {
t.Fatalf("expected 2 servers, got %d", len(cfg.Servers))
}
if cfg.Servers[0].Name != "main" {
t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main")
}
if cfg.Servers[1].Name != "included" {
t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included")
}
})
t.Run("merge variables with main priority", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
variables:
set:
app: lolly
env: production
include:
- path: "extra.yaml"
`
incCfg := `
servers:
- listen: ":9090"
variables:
set:
app: other
debug: "true"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.Variables.Set["app"] != "lolly" {
t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly")
}
if cfg.Variables.Set["debug"] != "true" {
t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true")
}
if cfg.Variables.Set["env"] != "production" {
t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production")
}
})
t.Run("no matches returns error", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
include:
- path: "nonexistent/*.yaml"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
_, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err == nil {
t.Error("expected error for non-matching glob pattern")
}
})
t.Run("circular include detected", func(t *testing.T) {
tmpDir := t.TempDir()
cfg1 := `
servers:
- listen: ":8080"
include:
- path: "b.yaml"
`
cfg2 := `
servers:
- listen: ":9090"
include:
- path: "a.yaml"
`
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil {
t.Fatal(err)
}
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
if err == nil {
t.Error("expected error for circular include")
}
})
t.Run("empty include list is no-op", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if len(cfg.Servers) != 1 {
t.Errorf("expected 1 server, got %d", len(cfg.Servers))
}
})
t.Run("append stream from include", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
stream:
- listen: ":5432"
protocol: tcp
upstream:
targets:
- addr: "127.0.0.1:9000"
include:
- path: "stream.yaml"
`
incCfg := `
stream:
- listen: ":5433"
protocol: udp
upstream:
targets:
- addr: "127.0.0.1:9001"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if len(cfg.Stream) != 2 {
t.Fatalf("expected 2 streams, got %d", len(cfg.Stream))
}
})
}