feat(config): implement include directive with glob support
Support loading config fragments from external files via include directive. Servers and streams are appended, variables merged with main config priority. Includes glob expansion, nested includes (depth limit 10), and circular include detection.
This commit is contained in:
parent
53ac4c84cd
commit
2e9ddc7400
@ -20,6 +20,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -138,6 +139,12 @@ func Load(path string) (*Config, error) {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
if len(cfg.Include) > 0 {
|
||||
if err := processIncludes(&cfg, filepath.Dir(path), 0); err != nil {
|
||||
return nil, fmt.Errorf("处理配置引入失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := Validate(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("配置验证失败: %w", err)
|
||||
}
|
||||
@ -145,6 +152,69 @@ func Load(path string) (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
const maxIncludeDepth = 10
|
||||
|
||||
func processIncludes(cfg *Config, baseDir string, depth int) error {
|
||||
if depth >= maxIncludeDepth {
|
||||
return fmt.Errorf("配置引入嵌套深度超过 %d 层,可能存在循环引入", maxIncludeDepth)
|
||||
}
|
||||
|
||||
for _, inc := range cfg.Include {
|
||||
pattern := inc.Path
|
||||
if !filepath.IsAbs(pattern) {
|
||||
pattern = filepath.Join(baseDir, pattern)
|
||||
}
|
||||
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err)
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path)
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
info, err := os.Stat(match)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(match)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
|
||||
}
|
||||
|
||||
var included Config
|
||||
if err := yaml.Unmarshal(data, &included); err != nil {
|
||||
return fmt.Errorf("解析引入文件 %q 失败: %w", match, err)
|
||||
}
|
||||
|
||||
if len(included.Include) > 0 {
|
||||
if err := processIncludes(&included, filepath.Dir(match), depth+1); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cfg.Servers = append(cfg.Servers, included.Servers...)
|
||||
cfg.Stream = append(cfg.Stream, included.Stream...)
|
||||
for k, v := range included.Variables.Set {
|
||||
if _, exists := cfg.Variables.Set[k]; !exists {
|
||||
if cfg.Variables.Set == nil {
|
||||
cfg.Variables.Set = make(map[string]string)
|
||||
}
|
||||
cfg.Variables.Set[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg.Include = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromString 从 YAML 字符串加载配置。
|
||||
//
|
||||
// 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。
|
||||
|
||||
@ -491,3 +491,198 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Include(t *testing.T) {
|
||||
t.Run("append servers from include", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
name: main
|
||||
include:
|
||||
- path: "conf.d/*.yaml"
|
||||
`
|
||||
incCfg := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
name: included
|
||||
`
|
||||
confDir := filepath.Join(tmpDir, "conf.d")
|
||||
if err := os.MkdirAll(confDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Servers) != 2 {
|
||||
t.Fatalf("expected 2 servers, got %d", len(cfg.Servers))
|
||||
}
|
||||
if cfg.Servers[0].Name != "main" {
|
||||
t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main")
|
||||
}
|
||||
if cfg.Servers[1].Name != "included" {
|
||||
t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge variables with main priority", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
variables:
|
||||
set:
|
||||
app: lolly
|
||||
env: production
|
||||
include:
|
||||
- path: "extra.yaml"
|
||||
`
|
||||
incCfg := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
variables:
|
||||
set:
|
||||
app: other
|
||||
debug: "true"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Variables.Set["app"] != "lolly" {
|
||||
t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly")
|
||||
}
|
||||
if cfg.Variables.Set["debug"] != "true" {
|
||||
t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true")
|
||||
}
|
||||
if cfg.Variables.Set["env"] != "production" {
|
||||
t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no matches returns error", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "nonexistent/*.yaml"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err == nil {
|
||||
t.Error("expected error for non-matching glob pattern")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("circular include detected", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg1 := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "b.yaml"
|
||||
`
|
||||
cfg2 := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
include:
|
||||
- path: "a.yaml"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
|
||||
if err == nil {
|
||||
t.Error("expected error for circular include")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty include list is no-op", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
if len(cfg.Servers) != 1 {
|
||||
t.Errorf("expected 1 server, got %d", len(cfg.Servers))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("append stream from include", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
stream:
|
||||
- listen: ":5432"
|
||||
protocol: tcp
|
||||
upstream:
|
||||
targets:
|
||||
- addr: "127.0.0.1:9000"
|
||||
include:
|
||||
- path: "stream.yaml"
|
||||
`
|
||||
incCfg := `
|
||||
stream:
|
||||
- listen: ":5433"
|
||||
protocol: udp
|
||||
upstream:
|
||||
targets:
|
||||
- addr: "127.0.0.1:9001"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Stream) != 2 {
|
||||
t.Fatalf("expected 2 streams, got %d", len(cfg.Stream))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user