diff --git a/AGENTS.md b/AGENTS.md index 9fb5fe0..ff3487f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,7 +14,6 @@ | `go.sum` | 依赖版本锁定 | | `Makefile` | 构建脚本,支持多平台编译、测试、覆盖率 | | `lolly.yaml` | 默认配置文件示例 | -| `config.example.yaml` | 完整配置文件示例(所有字段枚举) | | `.gitignore` | Git 忽略规则 | ## Subdirectories @@ -27,7 +26,6 @@ | `examples/` | Lua 脚本示例 | | `html/` | 静态 HTML 文件(测试/示例) | | `scripts/` | 构建/测试辅助脚本(回归检测) | -| `.github/` | CI/CD 工作流定义 | ## For AI Agents @@ -43,6 +41,7 @@ - 运行测试前确保依赖已下载:`go mod download` - 测试覆盖率目标 >80% - 使用 `make check` 运行完整检查(fmt + lint + test) +- 使用 `lolly --generate-config` 生成完整配置文件模板 ### Common Patterns - 配置结构体使用 `yaml` 标签,通过 `gopkg.in/yaml.v3` 解析 diff --git a/docs/llms.txt b/docs/llms.txt index 2089cc5..a4da944 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -58,8 +58,7 @@ http3: {} # HTTP/3 配置 resolver: {} # DNS 解析配置 performance: {} # 性能配置 shutdown: {} # 关闭配置 -include: [] # 配置引入 -cache_path: {} # 缓存路径配置 + include: [] # 配置引入 ``` ### 运行模式 diff --git a/internal/app/app.go b/internal/app/app.go index 2cff744..aa47e68 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -4,6 +4,7 @@ package app import ( "fmt" + "net" "os" "os/signal" "syscall" @@ -164,11 +165,129 @@ func (a *App) reloadConfig() { return } + if a.srv == nil { + a.cfg = newCfg + a.logger = logging.NewAppLogger(&newCfg.Logging) + a.logger.LogStartup("Config reloaded (no running server)", nil) + return + } + + if a.requiresFullRestart(newCfg) { + logging.Warn().Msg("Config requires full restart (listen address or mode changed). Use SIGUSR2 for graceful upgrade.") + return + } + + listeners := a.srv.GetListeners() + if len(listeners) == 0 { + a.logger.Error().Msg("Cannot reload: server has no saved listeners") + return + } + + duped := make([]net.Listener, len(listeners)) + for i, ln := range listeners { + duped[i], err = server.DupListener(ln) + if err != nil { + for j := 0; j < i; j++ { + _ = duped[j].Close() + } + a.logger.Error().Err(err).Msg("Failed to dup listener for reload") + return + } + } + + newSrv := server.New(newCfg) + if a.resv != nil { + newSrv.SetResolver(a.resv) + } + newSrv.SetListeners(duped) + + startErr := make(chan error, 1) + go func() { + if err := newSrv.Start(); err != nil { + startErr <- err + } + }() + + reloadTimeout := a.cfg.Shutdown.ReloadTimeout + if reloadTimeout <= 0 { + reloadTimeout = 5 * time.Second + } + + select { + case err := <-startErr: + a.logger.Error().Err(err).Msg("Failed to start new server with reloaded config") + for _, ln := range duped { + _ = ln.Close() + } + return + case <-time.After(reloadTimeout): + } + + oldSrv := a.srv + oldHTTP2 := a.http2Srv + oldHTTP3 := a.http3Srv + + a.srv = newSrv a.cfg = newCfg a.logger = logging.NewAppLogger(&newCfg.Logging) + a.http2Srv = nil + a.http3Srv = nil + + a.initVariables() + a.initHTTP2() + a.initHTTP3() + + if a.upgradeMgr != nil { + a.upgradeMgr.SetListeners(newSrv.GetListeners()) + } + + go func() { + if oldHTTP2 != nil { + _ = oldHTTP2.Stop() + } + if oldHTTP3 != nil { + _ = oldHTTP3.Stop() + } + _ = oldSrv.GracefulStop(30 * time.Second) + }() + a.logger.LogStartup("Config reloaded successfully", nil) } +func (a *App) requiresFullRestart(newCfg *config.Config) bool { + if a.cfg.GetMode() != newCfg.GetMode() { + return true + } + oldMode := a.cfg.GetMode() + switch oldMode { + case config.ServerModeSingle: + if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 { + if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen { + return true + } + } + case config.ServerModeVHost: + if len(a.cfg.Servers) != len(newCfg.Servers) { + return true + } + if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 { + if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen { + return true + } + } + case config.ServerModeMultiServer: + if len(a.cfg.Servers) != len(newCfg.Servers) { + return true + } + for i := range a.cfg.Servers { + if a.cfg.Servers[i].Listen != newCfg.Servers[i].Listen { + return true + } + } + } + return false +} + func (a *App) gracefulUpgrade() { execPath, err := os.Executable() if err != nil { diff --git a/internal/app/app_test.go b/internal/app/app_test.go index dd8b32f..42e4dba 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -461,7 +461,6 @@ func TestHandleSignal_SIGINT(t *testing.T) { // TestHandleSignal_SIGHUP 测试 SIGHUP 信号处理(重载配置) func TestHandleSignal_SIGHUP(t *testing.T) { - // 创建临时配置文件 tmpDir := t.TempDir() cfgPath := filepath.Join(tmpDir, "config.yaml") cfgContent := ` @@ -1448,14 +1447,12 @@ logging: } app.logger = setupTestLogger() - // 发送 SIGHUP 信号 result := app.handleSignal(syscall.SIGHUP) if result != true { t.Error("Expected handleSignal(SIGHUP) to return true") } - // 验证配置已更新 if app.cfg.Servers[0].Listen != ":7070" { t.Errorf("Expected listen ':7070', got '%s'", app.cfg.Servers[0].Listen) } diff --git a/internal/config/cache_config.go b/internal/config/cache_config.go index a5d5f5f..3b7e1e7 100644 --- a/internal/config/cache_config.go +++ b/internal/config/cache_config.go @@ -2,54 +2,6 @@ package config import "time" -// ProxyCachePathConfig 缓存路径配置(磁盘持久化)。 -// -// 配置磁盘缓存路径和相关参数,支持 L1/L2 分层缓存架构。 -// 配置后,代理缓存将持久化到磁盘,服务重启后可恢复。 -// -// 注意事项: -// - Path 为必填项,指定缓存根目录 -// - Levels 支持最多 3 级目录(如 "1:2:2") -// - MaxSize 为 0 表示不限制大小 -// - L1MaxEntries/L1MaxSize 为 0 时使用默认值 -// -// 使用示例: -// -// cache_path: -// path: "/var/cache/lolly" -// levels: "1:2" -// max_size: "1GB" -// inactive: "60m" -// l1_max_entries: 10000 -type ProxyCachePathConfig struct { - // Path 缓存根目录 - Path string `yaml:"path"` - - // Levels 目录层级,如 "1:2" 表示两级目录 - Levels string `yaml:"levels"` - - // MaxSize 最大缓存大小(字节) - MaxSize int64 `yaml:"max_size"` - - // Inactive 未访问淘汰时间 - Inactive time.Duration `yaml:"inactive"` - - // Purger 是否启用后台清理 - Purger bool `yaml:"purger"` - - // PurgerInterval 清理间隔 - PurgerInterval time.Duration `yaml:"purger_interval"` - - // L1MaxEntries L1 最大条目数 - L1MaxEntries int64 `yaml:"l1_max_entries"` - - // L1MaxSize L1 最大内存大小 - L1MaxSize int64 `yaml:"l1_max_size"` - - // PromoteThreshold 提升到 L1 的访问阈值 - PromoteThreshold int `yaml:"promote_threshold"` -} - // ProxyCacheConfig 代理缓存配置。 // // 缓存后端响应,减少重复请求,提高响应速度。 diff --git a/internal/config/config.go b/internal/config/config.go index 10a6e14..b790326 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "os" + "path/filepath" "strconv" "strings" @@ -81,8 +82,7 @@ type Config struct { Resolver ResolverConfig `yaml:"resolver"` Performance PerformanceConfig `yaml:"performance"` Shutdown ShutdownConfig `yaml:"shutdown"` - Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段 - CachePath *ProxyCachePathConfig `yaml:"cache_path"` // 缓存路径配置(磁盘持久化) + Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段 } // parseSize 解析大小字符串(支持 k, m 单位)。 @@ -139,6 +139,17 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("解析配置文件失败: %w", err) } + if len(cfg.Include) > 0 { + absPath, err := filepath.Abs(path) + if err != nil { + return nil, fmt.Errorf("获取配置文件绝对路径失败: %w", err) + } + visited := map[string]bool{absPath: true} + if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil { + return nil, fmt.Errorf("处理配置引入失败: %w", err) + } + } + if err := Validate(&cfg); err != nil { return nil, fmt.Errorf("配置验证失败: %w", err) } @@ -146,6 +157,81 @@ func Load(path string) (*Config, error) { return &cfg, nil } +const maxIncludeDepth = 10 + +func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error { + if depth >= maxIncludeDepth { + return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth) + } + + for _, inc := range cfg.Include { + pattern := inc.Path + if !filepath.IsAbs(pattern) { + pattern = filepath.Join(baseDir, pattern) + } + + matches, err := filepath.Glob(pattern) + if err != nil { + return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err) + } + if len(matches) == 0 { + return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path) + } + + for _, match := range matches { + absMatch, err := filepath.Abs(match) + if err != nil { + return fmt.Errorf("获取引入文件绝对路径失败 %q: %w", match, err) + } + if visited[absMatch] { + return fmt.Errorf("检测到循环引入: %s", absMatch) + } + visited[absMatch] = true + + info, err := os.Stat(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + if info.IsDir() { + delete(visited, absMatch) + continue + } + + data, err := os.ReadFile(match) + if err != nil { + return fmt.Errorf("读取引入文件 %q 失败: %w", match, err) + } + + var included Config + if err := yaml.Unmarshal(data, &included); err != nil { + return fmt.Errorf("解析引入文件 %q 失败: %w", match, err) + } + + if len(included.Include) > 0 { + if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil { + return err + } + } + + cfg.Servers = append(cfg.Servers, included.Servers...) + cfg.Stream = append(cfg.Stream, included.Stream...) + for k, v := range included.Variables.Set { + if _, exists := cfg.Variables.Set[k]; !exists { + if cfg.Variables.Set == nil { + cfg.Variables.Set = make(map[string]string) + } + cfg.Variables.Set[k] = v + } + } + + delete(visited, absMatch) + } + } + + cfg.Include = nil + return nil +} + // LoadFromString 从 YAML 字符串加载配置。 // // 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。 diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 131b27e..1bf5d9b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -4,6 +4,7 @@ package config import ( "os" "path/filepath" + "strings" "testing" ) @@ -491,3 +492,273 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) { }) } } + +func TestLoad_Include(t *testing.T) { + t.Run("append servers from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" + name: main +include: + - path: "conf.d/*.yaml" +` + incCfg := ` +servers: + - listen: ":9090" + name: included +` + confDir := filepath.Join(tmpDir, "conf.d") + if err := os.MkdirAll(confDir, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Servers) != 2 { + t.Fatalf("expected 2 servers, got %d", len(cfg.Servers)) + } + if cfg.Servers[0].Name != "main" { + t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main") + } + if cfg.Servers[1].Name != "included" { + t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included") + } + }) + + t.Run("merge variables with main priority", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +variables: + set: + app: lolly + env: production +include: + - path: "extra.yaml" +` + incCfg := ` +servers: + - listen: ":9090" +variables: + set: + app: other + debug: "true" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if cfg.Variables.Set["app"] != "lolly" { + t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly") + } + if cfg.Variables.Set["debug"] != "true" { + t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true") + } + if cfg.Variables.Set["env"] != "production" { + t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production") + } + }) + + t.Run("no matches returns error", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +include: + - path: "nonexistent/*.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err == nil { + t.Error("expected error for non-matching glob pattern") + } + }) + + t.Run("circular include detected immediately", func(t *testing.T) { + tmpDir := t.TempDir() + + cfg1 := ` +servers: + - listen: ":8080" +include: + - path: "b.yaml" +` + cfg2 := ` +servers: + - listen: ":9090" +include: + - path: "a.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "a.yaml")) + if err == nil { + t.Error("expected error for circular include") + } + if !strings.Contains(err.Error(), "循环引入") { + t.Errorf("error should mention circular include, got: %v", err) + } + }) + + t.Run("self include detected", func(t *testing.T) { + tmpDir := t.TempDir() + + cfg := ` +servers: + - listen: ":8080" +include: + - path: "a.yaml" +` + if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil { + t.Fatal(err) + } + + _, err := Load(filepath.Join(tmpDir, "a.yaml")) + if err == nil { + t.Error("expected error for self include") + } + }) + + t.Run("diamond include works", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +include: + - path: "b.yaml" + - path: "c.yaml" +` + bCfg := ` +servers: + - listen: ":9090" +include: + - path: "shared.yaml" +` + cCfg := ` +servers: + - listen: ":9091" +include: + - path: "shared.yaml" +` + sharedCfg := ` +variables: + set: + shared: value +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("diamond include should work: %v", err) + } + + if len(cfg.Servers) != 3 { + t.Errorf("expected 3 servers, got %d", len(cfg.Servers)) + } + if cfg.Variables.Set["shared"] != "value" { + t.Error("shared variable should be merged") + } + }) + + t.Run("empty include list is no-op", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + if len(cfg.Servers) != 1 { + t.Errorf("expected 1 server, got %d", len(cfg.Servers)) + } + }) + + t.Run("append stream from include", func(t *testing.T) { + tmpDir := t.TempDir() + + mainCfg := ` +servers: + - listen: ":8080" +stream: + - listen: ":5432" + protocol: tcp + upstream: + targets: + - addr: "127.0.0.1:9000" +include: + - path: "stream.yaml" +` + incCfg := ` +stream: + - listen: ":5433" + protocol: udp + upstream: + targets: + - addr: "127.0.0.1:9001" +` + if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load(filepath.Join(tmpDir, "config.yaml")) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + if len(cfg.Stream) != 2 { + t.Fatalf("expected 2 streams, got %d", len(cfg.Stream)) + } + }) +} diff --git a/internal/config/defaults.go b/internal/config/defaults.go index b4f4bce..dc2e0e4 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -222,6 +222,7 @@ func DefaultConfig() *Config { Shutdown: ShutdownConfig{ GracefulTimeout: 30 * time.Second, FastTimeout: 5 * time.Second, + ReloadTimeout: 5 * time.Second, }, } } @@ -612,6 +613,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("shutdown:\n") fmt.Fprintf(&buf, " graceful_timeout: %ds # 优雅停止超时(SIGQUIT),等待活跃请求完成(0=使用默认30s)\n", int(cfg.Shutdown.GracefulTimeout.Seconds())) fmt.Fprintf(&buf, " fast_timeout: %ds # 快速停止超时(SIGINT/SIGTERM,0=使用默认5s)\n", int(cfg.Shutdown.FastTimeout.Seconds())) + fmt.Fprintf(&buf, " reload_timeout: %ds # 热重载启动等待超时(SIGHUP,0=使用默认5s)\n", int(cfg.Shutdown.ReloadTimeout.Seconds())) buf.WriteString("\n") // stream 配置 @@ -734,6 +736,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n") buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n") buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n") + buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n") return buf.Bytes(), nil } diff --git a/internal/config/performance_config.go b/internal/config/performance_config.go index a644264..b2926b0 100644 --- a/internal/config/performance_config.go +++ b/internal/config/performance_config.go @@ -181,6 +181,11 @@ type ShutdownConfig struct { // 接收到 SIGINT 或 SIGTERM 信号后,等待服务器关闭的最大时间 // 默认: 5s(当值为 0 时使用默认值) FastTimeout time.Duration `yaml:"fast_timeout"` + + // ReloadTimeout 热重载启动等待超时(SIGHUP) + // 等待新服务器启动完成的最大时间,超时后视为启动成功 + // 默认: 5s(当值为 0 时使用默认值) + ReloadTimeout time.Duration `yaml:"reload_timeout"` } // ResolverConfig DNS 解析器配置。 diff --git a/internal/e2e/testutil/config.go b/internal/e2e/testutil/config.go index ea1319a..704447f 100644 --- a/internal/e2e/testutil/config.go +++ b/internal/e2e/testutil/config.go @@ -462,15 +462,6 @@ func (b *ConfigBuilder) WithRewrite(pattern, replacement string, opts ...Rewrite return b } -// WithCachePath 配置缓存路径。 -func (b *ConfigBuilder) WithCachePath(path string, maxSize int64) *ConfigBuilder { - b.cfg.CachePath = &config.ProxyCachePathConfig{ - Path: path, - MaxSize: maxSize, - } - return b -} - // WithResolver 配置 DNS 解析器。 func (b *ConfigBuilder) WithResolver(addresses []string, valid, timeout time.Duration) *ConfigBuilder { b.cfg.Resolver = config.ResolverConfig{ diff --git a/internal/matcher/integration_test.go b/internal/matcher/integration_test.go index 687ba46..c2b72a0 100644 --- a/internal/matcher/integration_test.go +++ b/internal/matcher/integration_test.go @@ -1,6 +1,7 @@ package matcher import ( + "errors" "testing" "github.com/valyala/fasthttp" @@ -107,4 +108,8 @@ func TestLocationEngine_PathConflict(t *testing.T) { if err == nil { t.Error("should fail on path conflict") } + var ce *ConflictError + if !errors.As(err, &ce) { + t.Errorf("expected *ConflictError, got %T: %v", err, err) + } } diff --git a/internal/matcher/location.go b/internal/matcher/location.go index a96153d..1df4521 100644 --- a/internal/matcher/location.go +++ b/internal/matcher/location.go @@ -173,8 +173,12 @@ func (e *LocationEngine) AddNamed(name string, handler fasthttp.RequestHandler) return errors.New("LocationEngine already initialized") } - if existing, ok := e.namedMatchers[name]; ok { - return fmt.Errorf("named location '@%s' already registered", existing.name) + if _, ok := e.namedMatchers[name]; ok { + return &ConflictError{ + Path: "@" + name, + ExistingType: "named", + NewType: "named", + } } matcher := NewNamedMatcher(name, handler) @@ -240,6 +244,21 @@ func (e *LocationEngine) MarkInitialized() { e.prefixTree.MarkInitialized() } +// ConflictError 路径冲突错误。 +// +// 当同一路径被重复注册为不同类型的 location 时返回此错误。 +// 调用方可通过 errors.As 检测此类型,区分冲突与致命错误。 +type ConflictError struct { + Path string + ExistingType string + NewType string +} + +func (e *ConflictError) Error() string { + return fmt.Sprintf("path conflict: '%s' already registered as '%s', trying to register as '%s'", + e.Path, e.ExistingType, e.NewType) +} + // checkConflict 检查路径冲突。 // // 参数: @@ -247,11 +266,10 @@ func (e *LocationEngine) MarkInitialized() { // - locationType: location 类型 // // 返回值: -// - error: 路径已存在时返回冲突错误 +// - error: 路径已存在时返回 *ConflictError func (e *LocationEngine) checkConflict(path, locationType string) error { if existing, ok := e.registeredPaths[path]; ok { - return fmt.Errorf("path conflict: '%s' already registered as '%s', trying to register as '%s'", - path, existing, locationType) + return &ConflictError{Path: path, ExistingType: existing, NewType: locationType} } e.registeredPaths[path] = locationType return nil diff --git a/internal/server/router.go b/internal/server/router.go index 238ba11..54c5290 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -67,7 +67,7 @@ func (s *Server) createProxyForConfig(proxyCfg *config.ProxyConfig) *proxy.Proxy // // 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。 // 支持通过 LocationType 配置不同的匹配方式。 -func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) { +func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) error { for i := range serverCfg.Proxy { proxyCfg := &serverCfg.Proxy[i] @@ -76,7 +76,6 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC continue } - // 根据 LocationType 注册路由 locType := proxyCfg.LocationType if locType == "" { locType = matcher.LocationTypePrefix @@ -84,22 +83,47 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC switch locType { case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless: caseInsensitive := locType == matcher.LocationTypeRegexCaseless - _ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal) + if err := s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } case matcher.LocationTypeNamed: if proxyCfg.LocationName != "" { - _ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP) + if err := s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP); err != nil { + if err := s.handleRegistrationError("proxy", "@"+proxyCfg.LocationName, err); err != nil { + return err + } + } } case matcher.LocationTypePrefix: - _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } default: - _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil { + if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil { + return err + } + } } } + return nil } // configureStaticHandler 配置静态文件处理器。 @@ -156,7 +180,7 @@ func (s *Server) configureStaticHandler(static *config.StaticConfig, cfg *config } // registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。 -func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) { +func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) error { for _, static := range cfg.Static { staticHandler := s.configureStaticHandler(&static, cfg) path := static.Path @@ -164,7 +188,6 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf path = "/" } - // 根据 LocationType 注册路由 locType := static.LocationType if locType == "" { locType = matcher.LocationTypePrefix @@ -172,15 +195,32 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf switch locType { case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } case matcher.LocationTypePrefix: - _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } default: - _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) + if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil { + if err := s.handleRegistrationError("static", path, err); err != nil { + return err + } + } } } + return nil } // registerProxyRoutes 注册代理路由。 @@ -324,9 +364,9 @@ func (s *Server) registerLuaRoutes(router *handler.Router, serverCfg *config.Ser // - 只有设置了 Route 字段的脚本才会被注册 // - 路由脚本不经过完整中间件链,只应用 accesslog 和 errorintercept // - 支持 exact、prefix、prefix_priority、regex、regex_caseless 匹配类型 -func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) { +func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) error { if s.luaEngine == nil || serverCfg.Lua == nil || !serverCfg.Lua.Enabled { - return + return nil } for _, script := range serverCfg.Lua.Scripts { @@ -348,17 +388,38 @@ func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerCon switch routeType { case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(script.Route, handler, false) + if err := s.locationEngine.AddExact(script.Route, handler, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(script.Route, handler, false) + if err := s.locationEngine.AddPrefixPriority(script.Route, handler, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } case matcher.LocationTypeRegex: - _ = s.locationEngine.AddRegex(script.Route, handler, false, false) + if err := s.locationEngine.AddRegex(script.Route, handler, false, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } case matcher.LocationTypeRegexCaseless: - _ = s.locationEngine.AddRegex(script.Route, handler, true, false) + if err := s.locationEngine.AddRegex(script.Route, handler, true, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } default: - _ = s.locationEngine.AddPrefix(script.Route, handler, false) + if err := s.locationEngine.AddPrefix(script.Route, handler, false); err != nil { + if err := s.handleRegistrationError("lua", script.Route, err); err != nil { + return err + } + } } } + return nil } // wrapRoutedHandler 为路由处理器包装基础中间件链。 diff --git a/internal/server/server.go b/internal/server/server.go index 99dced5..5c0e4ce 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,6 +21,7 @@ package server import ( "crypto/tls" + "errors" "fmt" "net" "os" @@ -96,6 +97,15 @@ func New(cfg *config.Config) *Server { return &Server{config: cfg} } +func (s *Server) handleRegistrationError(source, path string, err error) error { + var ce *matcher.ConflictError + if errors.As(err, &ce) { + logging.Warn().Msgf("Route registration skipped (%s %s): %s", source, path, err) + return nil + } + return fmt.Errorf("%s route %s: %w", source, path, err) +} + // getServerName 根据配置返回服务器名称。 // // 当 ServerTokens 为 false 时隐藏版本号,仅返回 "lolly"。 @@ -308,32 +318,31 @@ func (s *Server) Start() error { func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) { listenAddr := cfg.Listen + if s.upgradeManager != nil && s.upgradeManager.IsChild() { + inherited, _ := s.upgradeManager.GetInheritedListeners() + if ln := s.matchInheritedListener(inherited, listenAddr); ln != nil { + return ln, nil + } + } + + if len(s.listeners) > 0 { + if ln := s.matchInheritedListener(s.listeners, listenAddr); ln != nil { + return ln, nil + } + } + if strings.HasPrefix(listenAddr, "unix:") { - // Unix Socket 模式 socketPath := listenAddr[5:] - // 1. 检查继承的监听器(热升级场景) - if s.upgradeManager != nil && s.upgradeManager.IsChild() { - inherited, _ := s.upgradeManager.GetInheritedListeners() - for _, ln := range inherited { - if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath { - return ln, nil - } - } - } - - // 2. 清理旧 socket 文件 if _, err := os.Stat(socketPath); err == nil { _ = os.Remove(socketPath) } - // 3. 创建 Unix socket listener listener, err := net.Listen("unix", socketPath) if err != nil { return nil, fmt.Errorf("create unix socket failed: %w", err) } - // 4. 设置 socket 文件权限 mode := 0o666 if cfg.UnixSocket.Mode > 0 { mode = cfg.UnixSocket.Mode @@ -342,19 +351,95 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) logging.Warn().Err(err).Msg("Failed to set socket file permissions") } - // 5. 设置文件所有权(需要 root 权限) if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" { - // 简化处理:仅记录警告,实际实现需要 syscall.Chown logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped") } return listener, nil } - // TCP 模式 return net.Listen("tcp", listenAddr) } +func (s *Server) matchInheritedListener(inherited []net.Listener, listenAddr string) net.Listener { + if len(inherited) == 0 { + return nil + } + + if strings.HasPrefix(listenAddr, "unix:") { + socketPath := listenAddr[5:] + for _, ln := range inherited { + if ln == nil { + continue + } + if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath { + return ln + } + } + return nil + } + + for _, ln := range inherited { + if ln == nil { + continue + } + if ln.Addr().Network() != "tcp" { + continue + } + if s.tcpAddrMatch(ln.Addr().String(), listenAddr) { + return ln + } + } + return nil +} + +func (s *Server) tcpAddrMatch(inherited, target string) bool { + if inherited == target { + return true + } + host1, port1, err1 := net.SplitHostPort(inherited) + host2, port2, err2 := net.SplitHostPort(target) + if err1 != nil || err2 != nil { + return false + } + if port1 != port2 { + return false + } + return host1 == host2 || isAnyAddr(host1) || isAnyAddr(host2) +} + +func isAnyAddr(host string) bool { + if host == "" { + return true + } + ip := net.ParseIP(host) + return ip != nil && ip.IsUnspecified() +} + +// DupListener 复制 listener 的文件描述符,返回独立的 listener。 +// +// 用于热重载场景:新旧 server 各自持有独立 FD,互不影响关闭操作。 +func DupListener(ln net.Listener) (net.Listener, error) { + switch l := ln.(type) { + case *net.TCPListener: + file, err := l.File() + if err != nil { + return nil, fmt.Errorf("dup tcp listener: %w", err) + } + defer file.Close() + return net.FileListener(file) + case *net.UnixListener: + file, err := l.File() + if err != nil { + return nil, fmt.Errorf("dup unix listener: %w", err) + } + defer file.Close() + return net.FileListener(file) + default: + return nil, fmt.Errorf("unsupported listener type: %T", ln) + } +} + // startSingleMode 单服务器模式启动。 // // 在单服务器模式下,创建单一路由器,注册代理路由和静态文件服务, @@ -382,39 +467,56 @@ func (s *Server) startSingleMode() error { if err != nil { logging.Error().Msg("Failed to create status handler: " + err.Error()) } else { - _ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false) + if err := s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("status", statusHandler.Path(), err); err != nil { + return err + } + } } } - // 注册 pprof 性能分析端点(如果配置) if s.config.Monitoring.Pprof.Enabled { pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof) if err != nil { logging.Error().Msg("Failed to create pprof handler: " + err.Error()) } else { - _ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false) - _ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false) + if err := s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("pprof", pprofHandler.Path(), err); err != nil { + return err + } + } + if err := s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("pprof", pprofHandler.Path()+"/", err); err != nil { + return err + } + } } } - // 注册缓存清理 API(如果配置) if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled { purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI) if err != nil { logging.Error().Msg("Failed to create cache purge handler: " + err.Error()) } else { - _ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false) + if err := s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false); err != nil { + if err := s.handleRegistrationError("cache-purge", purgeHandler.Path(), err); err != nil { + return err + } + } } } - // 注册代理路由 - s.registerProxyRoutesWithLocationEngine(serverCfg) + if err := s.registerProxyRoutesWithLocationEngine(serverCfg); err != nil { + return err + } - // Lua 路由 - s.registerLuaRoutesWithLocationEngine(serverCfg) + if err := s.registerLuaRoutesWithLocationEngine(serverCfg); err != nil { + return err + } - // 静态文件服务 - s.registerStaticHandlersWithLocationEngine(serverCfg) + if err := s.registerStaticHandlersWithLocationEngine(serverCfg); err != nil { + return err + } // 标记 LocationEngine 初始化完成 s.locationEngine.MarkInitialized() @@ -613,22 +715,28 @@ func (s *Server) startVHostMode() error { // // 注意事项: // - 每个服务器有独立的中间件配置 -// - 热升级场景下回退到虚拟主机模式 // - 使用 goroutine 并行启动多个服务器 func (s *Server) startMultiServerMode() error { - // 热升级检测:multi_server 热升级未实现,回退到 vhost 模式 - if os.Getenv("GRACEFUL_UPGRADE") == "1" { - logging.Warn().Msg("multi_server mode not implemented for graceful upgrade, falling back to vhost mode") - return s.startVHostMode() - } - s.fastServers = make([]*fasthttp.Server, len(s.config.Servers)) s.listeners = make([]net.Listener, len(s.config.Servers)) + for i := range s.config.Servers { + serverCfg := &s.config.Servers[i] + ln, err := s.createListener(serverCfg) + if err != nil { + for j := 0; j < i; j++ { + if s.listeners[j] != nil { + _ = s.listeners[j].Close() + } + } + return fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err) + } + s.listeners[i] = ln + } + var wg sync.WaitGroup errCh := make(chan error, len(s.config.Servers)) - // 并行创建监听器和 fasthttp.Server for i := range s.config.Servers { wg.Add(1) go func(idx int) { @@ -636,15 +744,6 @@ func (s *Server) startMultiServerMode() error { serverCfg := &s.config.Servers[idx] - // 创建监听器 - ln, err := s.createListener(serverCfg) - if err != nil { - errCh <- fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err) - return - } - s.listeners[idx] = ln - - // 创建路由器 router := handler.NewRouter() // 注册状态监控端点(仅默认服务器) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 938cb80..a68c82d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -18,6 +18,7 @@ import ( "fmt" "net" "os" + "strings" "sync" "testing" "time" @@ -26,6 +27,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/matcher" "rua.plus/lolly/internal/middleware/accesslog" "rua.plus/lolly/internal/middleware/security" "rua.plus/lolly/internal/proxy" @@ -989,6 +991,157 @@ func TestCreateListener_UnixSocketCleanup(t *testing.T) { defer ln.Close() } +func TestHandleRegistrationError_ConflictWarning(t *testing.T) { + s := &Server{} + err := s.handleRegistrationError("proxy", "/api", + &matcher.ConflictError{Path: "/api", ExistingType: "exact", NewType: "prefix"}) + if err != nil { + t.Errorf("conflict should return nil, got: %v", err) + } +} + +func TestHandleRegistrationError_FatalError(t *testing.T) { + s := &Server{} + err := s.handleRegistrationError("proxy", "/api", + fmt.Errorf("invalid regex pattern: missing closing parenthesis")) + if err == nil { + t.Error("fatal error should return non-nil") + } + if !strings.Contains(err.Error(), "proxy route /api") { + t.Errorf("error should wrap context, got: %v", err) + } +} + +func TestDupListener_TCP(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + duped, err := DupListener(ln) + if err != nil { + t.Fatalf("DupListener() error: %v", err) + } + defer duped.Close() + + if duped.Addr().Network() != "tcp" { + t.Errorf("expected tcp, got %s", duped.Addr().Network()) + } + if duped.Addr().String() != ln.Addr().String() { + t.Errorf("expected same address %s, got %s", ln.Addr().String(), duped.Addr().String()) + } +} + +func TestDupListener_Unix(t *testing.T) { + dir := t.TempDir() + socketPath := dir + "/dup.sock" + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + duped, err := DupListener(ln) + if err != nil { + t.Fatalf("DupListener() error: %v", err) + } + defer duped.Close() +} + +func TestDupListener_Unsupported(t *testing.T) { + _, err := DupListener(struct{ net.Listener }{}) + if err == nil { + t.Error("expected error for unsupported type") + } +} + +func TestTcpAddrMatch(t *testing.T) { + s := &Server{} + + tests := []struct { + inherited string + target string + want bool + }{ + {"127.0.0.1:8080", "127.0.0.1:8080", true}, + {"0.0.0.0:8080", ":8080", true}, + {"[::]:8080", ":8080", true}, + {"0.0.0.0:8080", "0.0.0.0:8080", true}, + {"0.0.0.0:8080", "127.0.0.1:8080", true}, + {"127.0.0.1:8080", "0.0.0.0:8080", true}, + {"127.0.0.1:8080", ":9090", false}, + {"127.0.0.1:8080", "192.168.1.1:8080", false}, + } + + for _, tt := range tests { + got := s.tcpAddrMatch(tt.inherited, tt.target) + if got != tt.want { + t.Errorf("tcpAddrMatch(%q, %q) = %v, want %v", tt.inherited, tt.target, got, tt.want) + } + } +} + +func TestMatchInheritedListener_TCP(t *testing.T) { + s := &Server{} + + ln1, _ := net.Listen("tcp", "127.0.0.1:0") + defer ln1.Close() + + ln2, _ := net.Listen("tcp", "127.0.0.1:0") + defer ln2.Close() + + inherited := []net.Listener{ln1, ln2} + + result := s.matchInheritedListener(inherited, "0.0.0.0:99999") + if result != nil { + t.Error("expected nil for non-matching address") + } + + addr1 := ln1.Addr().String() + result = s.matchInheritedListener(inherited, addr1) + if result != ln1 { + t.Errorf("expected ln1 for address %s", addr1) + } +} + +func TestMatchInheritedListener_Empty(t *testing.T) { + s := &Server{} + result := s.matchInheritedListener(nil, ":8080") + if result != nil { + t.Error("expected nil for empty inherited list") + } +} + +func TestMatchInheritedListener_PresetListeners(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}}, + } + s := New(cfg) + + ln, err := s.createListener(&cfg.Servers[0]) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + s.SetListeners([]net.Listener{ln}) + + addr := ln.Addr().String() + cfg.Servers[0].Listen = addr + + matched, err := s.createListener(&cfg.Servers[0]) + if err != nil { + t.Fatalf("createListener with preset should reuse: %v", err) + } + if matched == nil { + t.Fatal("expected non-nil listener from preset match") + } + if matched.Addr().String() != addr { + t.Errorf("expected same address %s, got %s", addr, matched.Addr().String()) + } +} + // TestServer_StatsMethods 测试服务器统计方法。 func TestServer_StatsMethods(t *testing.T) { cfg := &config.Config{