Merge pull request #2 from DefectingCat/fix/identified-issues

fix/identified issues
This commit is contained in:
Sonetto 2026-06-03 13:24:15 +08:00 committed by GitHub
commit 634fc5b51b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 896 additions and 138 deletions

View File

@ -14,7 +14,6 @@
| `go.sum` | 依赖版本锁定 | | `go.sum` | 依赖版本锁定 |
| `Makefile` | 构建脚本,支持多平台编译、测试、覆盖率 | | `Makefile` | 构建脚本,支持多平台编译、测试、覆盖率 |
| `lolly.yaml` | 默认配置文件示例 | | `lolly.yaml` | 默认配置文件示例 |
| `config.example.yaml` | 完整配置文件示例(所有字段枚举) |
| `.gitignore` | Git 忽略规则 | | `.gitignore` | Git 忽略规则 |
## Subdirectories ## Subdirectories
@ -27,7 +26,6 @@
| `examples/` | Lua 脚本示例 | | `examples/` | Lua 脚本示例 |
| `html/` | 静态 HTML 文件(测试/示例) | | `html/` | 静态 HTML 文件(测试/示例) |
| `scripts/` | 构建/测试辅助脚本(回归检测) | | `scripts/` | 构建/测试辅助脚本(回归检测) |
| `.github/` | CI/CD 工作流定义 |
## For AI Agents ## For AI Agents
@ -43,6 +41,7 @@
- 运行测试前确保依赖已下载:`go mod download` - 运行测试前确保依赖已下载:`go mod download`
- 测试覆盖率目标 >80% - 测试覆盖率目标 >80%
- 使用 `make check` 运行完整检查fmt + lint + test - 使用 `make check` 运行完整检查fmt + lint + test
- 使用 `lolly --generate-config` 生成完整配置文件模板
### Common Patterns ### Common Patterns
- 配置结构体使用 `yaml` 标签,通过 `gopkg.in/yaml.v3` 解析 - 配置结构体使用 `yaml` 标签,通过 `gopkg.in/yaml.v3` 解析

View File

@ -58,8 +58,7 @@ http3: {} # HTTP/3 配置
resolver: {} # DNS 解析配置 resolver: {} # DNS 解析配置
performance: {} # 性能配置 performance: {} # 性能配置
shutdown: {} # 关闭配置 shutdown: {} # 关闭配置
include: [] # 配置引入 include: [] # 配置引入
cache_path: {} # 缓存路径配置
``` ```
### 运行模式 ### 运行模式

View File

@ -4,6 +4,7 @@ package app
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
@ -164,11 +165,129 @@ func (a *App) reloadConfig() {
return return
} }
if a.srv == nil {
a.cfg = newCfg
a.logger = logging.NewAppLogger(&newCfg.Logging)
a.logger.LogStartup("Config reloaded (no running server)", nil)
return
}
if a.requiresFullRestart(newCfg) {
logging.Warn().Msg("Config requires full restart (listen address or mode changed). Use SIGUSR2 for graceful upgrade.")
return
}
listeners := a.srv.GetListeners()
if len(listeners) == 0 {
a.logger.Error().Msg("Cannot reload: server has no saved listeners")
return
}
duped := make([]net.Listener, len(listeners))
for i, ln := range listeners {
duped[i], err = server.DupListener(ln)
if err != nil {
for j := 0; j < i; j++ {
_ = duped[j].Close()
}
a.logger.Error().Err(err).Msg("Failed to dup listener for reload")
return
}
}
newSrv := server.New(newCfg)
if a.resv != nil {
newSrv.SetResolver(a.resv)
}
newSrv.SetListeners(duped)
startErr := make(chan error, 1)
go func() {
if err := newSrv.Start(); err != nil {
startErr <- err
}
}()
reloadTimeout := a.cfg.Shutdown.ReloadTimeout
if reloadTimeout <= 0 {
reloadTimeout = 5 * time.Second
}
select {
case err := <-startErr:
a.logger.Error().Err(err).Msg("Failed to start new server with reloaded config")
for _, ln := range duped {
_ = ln.Close()
}
return
case <-time.After(reloadTimeout):
}
oldSrv := a.srv
oldHTTP2 := a.http2Srv
oldHTTP3 := a.http3Srv
a.srv = newSrv
a.cfg = newCfg a.cfg = newCfg
a.logger = logging.NewAppLogger(&newCfg.Logging) a.logger = logging.NewAppLogger(&newCfg.Logging)
a.http2Srv = nil
a.http3Srv = nil
a.initVariables()
a.initHTTP2()
a.initHTTP3()
if a.upgradeMgr != nil {
a.upgradeMgr.SetListeners(newSrv.GetListeners())
}
go func() {
if oldHTTP2 != nil {
_ = oldHTTP2.Stop()
}
if oldHTTP3 != nil {
_ = oldHTTP3.Stop()
}
_ = oldSrv.GracefulStop(30 * time.Second)
}()
a.logger.LogStartup("Config reloaded successfully", nil) a.logger.LogStartup("Config reloaded successfully", nil)
} }
func (a *App) requiresFullRestart(newCfg *config.Config) bool {
if a.cfg.GetMode() != newCfg.GetMode() {
return true
}
oldMode := a.cfg.GetMode()
switch oldMode {
case config.ServerModeSingle:
if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 {
if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen {
return true
}
}
case config.ServerModeVHost:
if len(a.cfg.Servers) != len(newCfg.Servers) {
return true
}
if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 {
if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen {
return true
}
}
case config.ServerModeMultiServer:
if len(a.cfg.Servers) != len(newCfg.Servers) {
return true
}
for i := range a.cfg.Servers {
if a.cfg.Servers[i].Listen != newCfg.Servers[i].Listen {
return true
}
}
}
return false
}
func (a *App) gracefulUpgrade() { func (a *App) gracefulUpgrade() {
execPath, err := os.Executable() execPath, err := os.Executable()
if err != nil { if err != nil {

View File

@ -461,7 +461,6 @@ func TestHandleSignal_SIGINT(t *testing.T) {
// TestHandleSignal_SIGHUP 测试 SIGHUP 信号处理(重载配置) // TestHandleSignal_SIGHUP 测试 SIGHUP 信号处理(重载配置)
func TestHandleSignal_SIGHUP(t *testing.T) { func TestHandleSignal_SIGHUP(t *testing.T) {
// 创建临时配置文件
tmpDir := t.TempDir() tmpDir := t.TempDir()
cfgPath := filepath.Join(tmpDir, "config.yaml") cfgPath := filepath.Join(tmpDir, "config.yaml")
cfgContent := ` cfgContent := `
@ -1448,14 +1447,12 @@ logging:
} }
app.logger = setupTestLogger() app.logger = setupTestLogger()
// 发送 SIGHUP 信号
result := app.handleSignal(syscall.SIGHUP) result := app.handleSignal(syscall.SIGHUP)
if result != true { if result != true {
t.Error("Expected handleSignal(SIGHUP) to return true") t.Error("Expected handleSignal(SIGHUP) to return true")
} }
// 验证配置已更新
if app.cfg.Servers[0].Listen != ":7070" { if app.cfg.Servers[0].Listen != ":7070" {
t.Errorf("Expected listen ':7070', got '%s'", app.cfg.Servers[0].Listen) t.Errorf("Expected listen ':7070', got '%s'", app.cfg.Servers[0].Listen)
} }

View File

@ -2,54 +2,6 @@ package config
import "time" import "time"
// ProxyCachePathConfig 缓存路径配置(磁盘持久化)。
//
// 配置磁盘缓存路径和相关参数,支持 L1/L2 分层缓存架构。
// 配置后,代理缓存将持久化到磁盘,服务重启后可恢复。
//
// 注意事项:
// - Path 为必填项,指定缓存根目录
// - Levels 支持最多 3 级目录(如 "1:2:2"
// - MaxSize 为 0 表示不限制大小
// - L1MaxEntries/L1MaxSize 为 0 时使用默认值
//
// 使用示例:
//
// cache_path:
// path: "/var/cache/lolly"
// levels: "1:2"
// max_size: "1GB"
// inactive: "60m"
// l1_max_entries: 10000
type ProxyCachePathConfig struct {
// Path 缓存根目录
Path string `yaml:"path"`
// Levels 目录层级,如 "1:2" 表示两级目录
Levels string `yaml:"levels"`
// MaxSize 最大缓存大小(字节)
MaxSize int64 `yaml:"max_size"`
// Inactive 未访问淘汰时间
Inactive time.Duration `yaml:"inactive"`
// Purger 是否启用后台清理
Purger bool `yaml:"purger"`
// PurgerInterval 清理间隔
PurgerInterval time.Duration `yaml:"purger_interval"`
// L1MaxEntries L1 最大条目数
L1MaxEntries int64 `yaml:"l1_max_entries"`
// L1MaxSize L1 最大内存大小
L1MaxSize int64 `yaml:"l1_max_size"`
// PromoteThreshold 提升到 L1 的访问阈值
PromoteThreshold int `yaml:"promote_threshold"`
}
// ProxyCacheConfig 代理缓存配置。 // ProxyCacheConfig 代理缓存配置。
// //
// 缓存后端响应,减少重复请求,提高响应速度。 // 缓存后端响应,减少重复请求,提高响应速度。

View File

@ -20,6 +20,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
@ -81,8 +82,7 @@ type Config struct {
Resolver ResolverConfig `yaml:"resolver"` Resolver ResolverConfig `yaml:"resolver"`
Performance PerformanceConfig `yaml:"performance"` Performance PerformanceConfig `yaml:"performance"`
Shutdown ShutdownConfig `yaml:"shutdown"` Shutdown ShutdownConfig `yaml:"shutdown"`
Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段 Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段
CachePath *ProxyCachePathConfig `yaml:"cache_path"` // 缓存路径配置(磁盘持久化)
} }
// parseSize 解析大小字符串(支持 k, m 单位)。 // parseSize 解析大小字符串(支持 k, m 单位)。
@ -139,6 +139,17 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("解析配置文件失败: %w", err) return nil, fmt.Errorf("解析配置文件失败: %w", err)
} }
if len(cfg.Include) > 0 {
absPath, err := filepath.Abs(path)
if err != nil {
return nil, fmt.Errorf("获取配置文件绝对路径失败: %w", err)
}
visited := map[string]bool{absPath: true}
if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil {
return nil, fmt.Errorf("处理配置引入失败: %w", err)
}
}
if err := Validate(&cfg); err != nil { if err := Validate(&cfg); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err) return nil, fmt.Errorf("配置验证失败: %w", err)
} }
@ -146,6 +157,81 @@ func Load(path string) (*Config, error) {
return &cfg, nil return &cfg, nil
} }
const maxIncludeDepth = 10
func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error {
if depth >= maxIncludeDepth {
return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth)
}
for _, inc := range cfg.Include {
pattern := inc.Path
if !filepath.IsAbs(pattern) {
pattern = filepath.Join(baseDir, pattern)
}
matches, err := filepath.Glob(pattern)
if err != nil {
return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err)
}
if len(matches) == 0 {
return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path)
}
for _, match := range matches {
absMatch, err := filepath.Abs(match)
if err != nil {
return fmt.Errorf("获取引入文件绝对路径失败 %q: %w", match, err)
}
if visited[absMatch] {
return fmt.Errorf("检测到循环引入: %s", absMatch)
}
visited[absMatch] = true
info, err := os.Stat(match)
if err != nil {
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
}
if info.IsDir() {
delete(visited, absMatch)
continue
}
data, err := os.ReadFile(match)
if err != nil {
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
}
var included Config
if err := yaml.Unmarshal(data, &included); err != nil {
return fmt.Errorf("解析引入文件 %q 失败: %w", match, err)
}
if len(included.Include) > 0 {
if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil {
return err
}
}
cfg.Servers = append(cfg.Servers, included.Servers...)
cfg.Stream = append(cfg.Stream, included.Stream...)
for k, v := range included.Variables.Set {
if _, exists := cfg.Variables.Set[k]; !exists {
if cfg.Variables.Set == nil {
cfg.Variables.Set = make(map[string]string)
}
cfg.Variables.Set[k] = v
}
}
delete(visited, absMatch)
}
}
cfg.Include = nil
return nil
}
// LoadFromString 从 YAML 字符串加载配置。 // LoadFromString 从 YAML 字符串加载配置。
// //
// 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。 // 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。

View File

@ -4,6 +4,7 @@ package config
import ( import (
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
) )
@ -491,3 +492,273 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) {
}) })
} }
} }
func TestLoad_Include(t *testing.T) {
t.Run("append servers from include", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
name: main
include:
- path: "conf.d/*.yaml"
`
incCfg := `
servers:
- listen: ":9090"
name: included
`
confDir := filepath.Join(tmpDir, "conf.d")
if err := os.MkdirAll(confDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if len(cfg.Servers) != 2 {
t.Fatalf("expected 2 servers, got %d", len(cfg.Servers))
}
if cfg.Servers[0].Name != "main" {
t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main")
}
if cfg.Servers[1].Name != "included" {
t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included")
}
})
t.Run("merge variables with main priority", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
variables:
set:
app: lolly
env: production
include:
- path: "extra.yaml"
`
incCfg := `
servers:
- listen: ":9090"
variables:
set:
app: other
debug: "true"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if cfg.Variables.Set["app"] != "lolly" {
t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly")
}
if cfg.Variables.Set["debug"] != "true" {
t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true")
}
if cfg.Variables.Set["env"] != "production" {
t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production")
}
})
t.Run("no matches returns error", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
include:
- path: "nonexistent/*.yaml"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
_, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err == nil {
t.Error("expected error for non-matching glob pattern")
}
})
t.Run("circular include detected immediately", func(t *testing.T) {
tmpDir := t.TempDir()
cfg1 := `
servers:
- listen: ":8080"
include:
- path: "b.yaml"
`
cfg2 := `
servers:
- listen: ":9090"
include:
- path: "a.yaml"
`
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil {
t.Fatal(err)
}
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
if err == nil {
t.Error("expected error for circular include")
}
if !strings.Contains(err.Error(), "循环引入") {
t.Errorf("error should mention circular include, got: %v", err)
}
})
t.Run("self include detected", func(t *testing.T) {
tmpDir := t.TempDir()
cfg := `
servers:
- listen: ":8080"
include:
- path: "a.yaml"
`
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil {
t.Fatal(err)
}
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
if err == nil {
t.Error("expected error for self include")
}
})
t.Run("diamond include works", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
include:
- path: "b.yaml"
- path: "c.yaml"
`
bCfg := `
servers:
- listen: ":9090"
include:
- path: "shared.yaml"
`
cCfg := `
servers:
- listen: ":9091"
include:
- path: "shared.yaml"
`
sharedCfg := `
variables:
set:
shared: value
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("diamond include should work: %v", err)
}
if len(cfg.Servers) != 3 {
t.Errorf("expected 3 servers, got %d", len(cfg.Servers))
}
if cfg.Variables.Set["shared"] != "value" {
t.Error("shared variable should be merged")
}
})
t.Run("empty include list is no-op", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if len(cfg.Servers) != 1 {
t.Errorf("expected 1 server, got %d", len(cfg.Servers))
}
})
t.Run("append stream from include", func(t *testing.T) {
tmpDir := t.TempDir()
mainCfg := `
servers:
- listen: ":8080"
stream:
- listen: ":5432"
protocol: tcp
upstream:
targets:
- addr: "127.0.0.1:9000"
include:
- path: "stream.yaml"
`
incCfg := `
stream:
- listen: ":5433"
protocol: udp
upstream:
targets:
- addr: "127.0.0.1:9001"
`
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
if err != nil {
t.Fatalf("Load() failed: %v", err)
}
if len(cfg.Stream) != 2 {
t.Fatalf("expected 2 streams, got %d", len(cfg.Stream))
}
})
}

View File

@ -222,6 +222,7 @@ func DefaultConfig() *Config {
Shutdown: ShutdownConfig{ Shutdown: ShutdownConfig{
GracefulTimeout: 30 * time.Second, GracefulTimeout: 30 * time.Second,
FastTimeout: 5 * time.Second, FastTimeout: 5 * time.Second,
ReloadTimeout: 5 * time.Second,
}, },
} }
} }
@ -612,6 +613,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString("shutdown:\n") buf.WriteString("shutdown:\n")
fmt.Fprintf(&buf, " graceful_timeout: %ds # 优雅停止超时SIGQUIT等待活跃请求完成0=使用默认30s\n", int(cfg.Shutdown.GracefulTimeout.Seconds())) fmt.Fprintf(&buf, " graceful_timeout: %ds # 优雅停止超时SIGQUIT等待活跃请求完成0=使用默认30s\n", int(cfg.Shutdown.GracefulTimeout.Seconds()))
fmt.Fprintf(&buf, " fast_timeout: %ds # 快速停止超时SIGINT/SIGTERM0=使用默认5s\n", int(cfg.Shutdown.FastTimeout.Seconds())) fmt.Fprintf(&buf, " fast_timeout: %ds # 快速停止超时SIGINT/SIGTERM0=使用默认5s\n", int(cfg.Shutdown.FastTimeout.Seconds()))
fmt.Fprintf(&buf, " reload_timeout: %ds # 热重载启动等待超时SIGHUP0=使用默认5s\n", int(cfg.Shutdown.ReloadTimeout.Seconds()))
buf.WriteString("\n") buf.WriteString("\n")
// stream 配置 // stream 配置
@ -734,6 +736,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n") buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n")
buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n") buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n")
buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n") buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n")
buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n")
return buf.Bytes(), nil return buf.Bytes(), nil
} }

View File

@ -181,6 +181,11 @@ type ShutdownConfig struct {
// 接收到 SIGINT 或 SIGTERM 信号后,等待服务器关闭的最大时间 // 接收到 SIGINT 或 SIGTERM 信号后,等待服务器关闭的最大时间
// 默认: 5s当值为 0 时使用默认值) // 默认: 5s当值为 0 时使用默认值)
FastTimeout time.Duration `yaml:"fast_timeout"` FastTimeout time.Duration `yaml:"fast_timeout"`
// ReloadTimeout 热重载启动等待超时SIGHUP
// 等待新服务器启动完成的最大时间,超时后视为启动成功
// 默认: 5s当值为 0 时使用默认值)
ReloadTimeout time.Duration `yaml:"reload_timeout"`
} }
// ResolverConfig DNS 解析器配置。 // ResolverConfig DNS 解析器配置。

View File

@ -462,15 +462,6 @@ func (b *ConfigBuilder) WithRewrite(pattern, replacement string, opts ...Rewrite
return b return b
} }
// WithCachePath 配置缓存路径。
func (b *ConfigBuilder) WithCachePath(path string, maxSize int64) *ConfigBuilder {
b.cfg.CachePath = &config.ProxyCachePathConfig{
Path: path,
MaxSize: maxSize,
}
return b
}
// WithResolver 配置 DNS 解析器。 // WithResolver 配置 DNS 解析器。
func (b *ConfigBuilder) WithResolver(addresses []string, valid, timeout time.Duration) *ConfigBuilder { func (b *ConfigBuilder) WithResolver(addresses []string, valid, timeout time.Duration) *ConfigBuilder {
b.cfg.Resolver = config.ResolverConfig{ b.cfg.Resolver = config.ResolverConfig{

View File

@ -1,6 +1,7 @@
package matcher package matcher
import ( import (
"errors"
"testing" "testing"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
@ -107,4 +108,8 @@ func TestLocationEngine_PathConflict(t *testing.T) {
if err == nil { if err == nil {
t.Error("should fail on path conflict") t.Error("should fail on path conflict")
} }
var ce *ConflictError
if !errors.As(err, &ce) {
t.Errorf("expected *ConflictError, got %T: %v", err, err)
}
} }

View File

@ -173,8 +173,12 @@ func (e *LocationEngine) AddNamed(name string, handler fasthttp.RequestHandler)
return errors.New("LocationEngine already initialized") return errors.New("LocationEngine already initialized")
} }
if existing, ok := e.namedMatchers[name]; ok { if _, ok := e.namedMatchers[name]; ok {
return fmt.Errorf("named location '@%s' already registered", existing.name) return &ConflictError{
Path: "@" + name,
ExistingType: "named",
NewType: "named",
}
} }
matcher := NewNamedMatcher(name, handler) matcher := NewNamedMatcher(name, handler)
@ -240,6 +244,21 @@ func (e *LocationEngine) MarkInitialized() {
e.prefixTree.MarkInitialized() e.prefixTree.MarkInitialized()
} }
// ConflictError 路径冲突错误。
//
// 当同一路径被重复注册为不同类型的 location 时返回此错误。
// 调用方可通过 errors.As 检测此类型,区分冲突与致命错误。
type ConflictError struct {
Path string
ExistingType string
NewType string
}
func (e *ConflictError) Error() string {
return fmt.Sprintf("path conflict: '%s' already registered as '%s', trying to register as '%s'",
e.Path, e.ExistingType, e.NewType)
}
// checkConflict 检查路径冲突。 // checkConflict 检查路径冲突。
// //
// 参数: // 参数:
@ -247,11 +266,10 @@ func (e *LocationEngine) MarkInitialized() {
// - locationType: location 类型 // - locationType: location 类型
// //
// 返回值: // 返回值:
// - error: 路径已存在时返回冲突错误 // - error: 路径已存在时返回 *ConflictError
func (e *LocationEngine) checkConflict(path, locationType string) error { func (e *LocationEngine) checkConflict(path, locationType string) error {
if existing, ok := e.registeredPaths[path]; ok { if existing, ok := e.registeredPaths[path]; ok {
return fmt.Errorf("path conflict: '%s' already registered as '%s', trying to register as '%s'", return &ConflictError{Path: path, ExistingType: existing, NewType: locationType}
path, existing, locationType)
} }
e.registeredPaths[path] = locationType e.registeredPaths[path] = locationType
return nil return nil

View File

@ -67,7 +67,7 @@ func (s *Server) createProxyForConfig(proxyCfg *config.ProxyConfig) *proxy.Proxy
// //
// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。 // 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。
// 支持通过 LocationType 配置不同的匹配方式。 // 支持通过 LocationType 配置不同的匹配方式。
func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) { func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) error {
for i := range serverCfg.Proxy { for i := range serverCfg.Proxy {
proxyCfg := &serverCfg.Proxy[i] proxyCfg := &serverCfg.Proxy[i]
@ -76,7 +76,6 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC
continue continue
} }
// 根据 LocationType 注册路由
locType := proxyCfg.LocationType locType := proxyCfg.LocationType
if locType == "" { if locType == "" {
locType = matcher.LocationTypePrefix locType = matcher.LocationTypePrefix
@ -84,22 +83,47 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC
switch locType { switch locType {
case matcher.LocationTypeExact: case matcher.LocationTypeExact:
_ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) if err := s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
return err
}
}
case matcher.LocationTypePrefixPriority: case matcher.LocationTypePrefixPriority:
_ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) if err := s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
return err
}
}
case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless: case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless:
caseInsensitive := locType == matcher.LocationTypeRegexCaseless caseInsensitive := locType == matcher.LocationTypeRegexCaseless
_ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal) if err := s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal); err != nil {
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
return err
}
}
case matcher.LocationTypeNamed: case matcher.LocationTypeNamed:
if proxyCfg.LocationName != "" { if proxyCfg.LocationName != "" {
_ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP) if err := s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP); err != nil {
if err := s.handleRegistrationError("proxy", "@"+proxyCfg.LocationName, err); err != nil {
return err
}
}
} }
case matcher.LocationTypePrefix: case matcher.LocationTypePrefix:
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
return err
}
}
default: default:
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
return err
}
}
} }
} }
return nil
} }
// configureStaticHandler 配置静态文件处理器。 // configureStaticHandler 配置静态文件处理器。
@ -156,7 +180,7 @@ func (s *Server) configureStaticHandler(static *config.StaticConfig, cfg *config
} }
// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。 // registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。
func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) { func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) error {
for _, static := range cfg.Static { for _, static := range cfg.Static {
staticHandler := s.configureStaticHandler(&static, cfg) staticHandler := s.configureStaticHandler(&static, cfg)
path := static.Path path := static.Path
@ -164,7 +188,6 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf
path = "/" path = "/"
} }
// 根据 LocationType 注册路由
locType := static.LocationType locType := static.LocationType
if locType == "" { if locType == "" {
locType = matcher.LocationTypePrefix locType = matcher.LocationTypePrefix
@ -172,15 +195,32 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf
switch locType { switch locType {
case matcher.LocationTypeExact: case matcher.LocationTypeExact:
_ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal) if err := s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal); err != nil {
if err := s.handleRegistrationError("static", path, err); err != nil {
return err
}
}
case matcher.LocationTypePrefixPriority: case matcher.LocationTypePrefixPriority:
_ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal) if err := s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal); err != nil {
if err := s.handleRegistrationError("static", path, err); err != nil {
return err
}
}
case matcher.LocationTypePrefix: case matcher.LocationTypePrefix:
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil {
if err := s.handleRegistrationError("static", path, err); err != nil {
return err
}
}
default: default:
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil {
if err := s.handleRegistrationError("static", path, err); err != nil {
return err
}
}
} }
} }
return nil
} }
// registerProxyRoutes 注册代理路由。 // registerProxyRoutes 注册代理路由。
@ -324,9 +364,9 @@ func (s *Server) registerLuaRoutes(router *handler.Router, serverCfg *config.Ser
// - 只有设置了 Route 字段的脚本才会被注册 // - 只有设置了 Route 字段的脚本才会被注册
// - 路由脚本不经过完整中间件链,只应用 accesslog 和 errorintercept // - 路由脚本不经过完整中间件链,只应用 accesslog 和 errorintercept
// - 支持 exact、prefix、prefix_priority、regex、regex_caseless 匹配类型 // - 支持 exact、prefix、prefix_priority、regex、regex_caseless 匹配类型
func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) { func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) error {
if s.luaEngine == nil || serverCfg.Lua == nil || !serverCfg.Lua.Enabled { if s.luaEngine == nil || serverCfg.Lua == nil || !serverCfg.Lua.Enabled {
return return nil
} }
for _, script := range serverCfg.Lua.Scripts { for _, script := range serverCfg.Lua.Scripts {
@ -348,17 +388,38 @@ func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerCon
switch routeType { switch routeType {
case matcher.LocationTypeExact: case matcher.LocationTypeExact:
_ = s.locationEngine.AddExact(script.Route, handler, false) if err := s.locationEngine.AddExact(script.Route, handler, false); err != nil {
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
return err
}
}
case matcher.LocationTypePrefixPriority: case matcher.LocationTypePrefixPriority:
_ = s.locationEngine.AddPrefixPriority(script.Route, handler, false) if err := s.locationEngine.AddPrefixPriority(script.Route, handler, false); err != nil {
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
return err
}
}
case matcher.LocationTypeRegex: case matcher.LocationTypeRegex:
_ = s.locationEngine.AddRegex(script.Route, handler, false, false) if err := s.locationEngine.AddRegex(script.Route, handler, false, false); err != nil {
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
return err
}
}
case matcher.LocationTypeRegexCaseless: case matcher.LocationTypeRegexCaseless:
_ = s.locationEngine.AddRegex(script.Route, handler, true, false) if err := s.locationEngine.AddRegex(script.Route, handler, true, false); err != nil {
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
return err
}
}
default: default:
_ = s.locationEngine.AddPrefix(script.Route, handler, false) if err := s.locationEngine.AddPrefix(script.Route, handler, false); err != nil {
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
return err
}
}
} }
} }
return nil
} }
// wrapRoutedHandler 为路由处理器包装基础中间件链。 // wrapRoutedHandler 为路由处理器包装基础中间件链。

View File

@ -21,6 +21,7 @@ package server
import ( import (
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"net" "net"
"os" "os"
@ -96,6 +97,15 @@ func New(cfg *config.Config) *Server {
return &Server{config: cfg} return &Server{config: cfg}
} }
func (s *Server) handleRegistrationError(source, path string, err error) error {
var ce *matcher.ConflictError
if errors.As(err, &ce) {
logging.Warn().Msgf("Route registration skipped (%s %s): %s", source, path, err)
return nil
}
return fmt.Errorf("%s route %s: %w", source, path, err)
}
// getServerName 根据配置返回服务器名称。 // getServerName 根据配置返回服务器名称。
// //
// 当 ServerTokens 为 false 时隐藏版本号,仅返回 "lolly"。 // 当 ServerTokens 为 false 时隐藏版本号,仅返回 "lolly"。
@ -308,32 +318,31 @@ func (s *Server) Start() error {
func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) { func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) {
listenAddr := cfg.Listen listenAddr := cfg.Listen
if s.upgradeManager != nil && s.upgradeManager.IsChild() {
inherited, _ := s.upgradeManager.GetInheritedListeners()
if ln := s.matchInheritedListener(inherited, listenAddr); ln != nil {
return ln, nil
}
}
if len(s.listeners) > 0 {
if ln := s.matchInheritedListener(s.listeners, listenAddr); ln != nil {
return ln, nil
}
}
if strings.HasPrefix(listenAddr, "unix:") { if strings.HasPrefix(listenAddr, "unix:") {
// Unix Socket 模式
socketPath := listenAddr[5:] socketPath := listenAddr[5:]
// 1. 检查继承的监听器(热升级场景)
if s.upgradeManager != nil && s.upgradeManager.IsChild() {
inherited, _ := s.upgradeManager.GetInheritedListeners()
for _, ln := range inherited {
if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath {
return ln, nil
}
}
}
// 2. 清理旧 socket 文件
if _, err := os.Stat(socketPath); err == nil { if _, err := os.Stat(socketPath); err == nil {
_ = os.Remove(socketPath) _ = os.Remove(socketPath)
} }
// 3. 创建 Unix socket listener
listener, err := net.Listen("unix", socketPath) listener, err := net.Listen("unix", socketPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("create unix socket failed: %w", err) return nil, fmt.Errorf("create unix socket failed: %w", err)
} }
// 4. 设置 socket 文件权限
mode := 0o666 mode := 0o666
if cfg.UnixSocket.Mode > 0 { if cfg.UnixSocket.Mode > 0 {
mode = cfg.UnixSocket.Mode mode = cfg.UnixSocket.Mode
@ -342,19 +351,95 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error)
logging.Warn().Err(err).Msg("Failed to set socket file permissions") logging.Warn().Err(err).Msg("Failed to set socket file permissions")
} }
// 5. 设置文件所有权(需要 root 权限)
if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" { if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" {
// 简化处理:仅记录警告,实际实现需要 syscall.Chown
logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped") logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped")
} }
return listener, nil return listener, nil
} }
// TCP 模式
return net.Listen("tcp", listenAddr) return net.Listen("tcp", listenAddr)
} }
func (s *Server) matchInheritedListener(inherited []net.Listener, listenAddr string) net.Listener {
if len(inherited) == 0 {
return nil
}
if strings.HasPrefix(listenAddr, "unix:") {
socketPath := listenAddr[5:]
for _, ln := range inherited {
if ln == nil {
continue
}
if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath {
return ln
}
}
return nil
}
for _, ln := range inherited {
if ln == nil {
continue
}
if ln.Addr().Network() != "tcp" {
continue
}
if s.tcpAddrMatch(ln.Addr().String(), listenAddr) {
return ln
}
}
return nil
}
func (s *Server) tcpAddrMatch(inherited, target string) bool {
if inherited == target {
return true
}
host1, port1, err1 := net.SplitHostPort(inherited)
host2, port2, err2 := net.SplitHostPort(target)
if err1 != nil || err2 != nil {
return false
}
if port1 != port2 {
return false
}
return host1 == host2 || isAnyAddr(host1) || isAnyAddr(host2)
}
func isAnyAddr(host string) bool {
if host == "" {
return true
}
ip := net.ParseIP(host)
return ip != nil && ip.IsUnspecified()
}
// DupListener 复制 listener 的文件描述符,返回独立的 listener。
//
// 用于热重载场景:新旧 server 各自持有独立 FD互不影响关闭操作。
func DupListener(ln net.Listener) (net.Listener, error) {
switch l := ln.(type) {
case *net.TCPListener:
file, err := l.File()
if err != nil {
return nil, fmt.Errorf("dup tcp listener: %w", err)
}
defer file.Close()
return net.FileListener(file)
case *net.UnixListener:
file, err := l.File()
if err != nil {
return nil, fmt.Errorf("dup unix listener: %w", err)
}
defer file.Close()
return net.FileListener(file)
default:
return nil, fmt.Errorf("unsupported listener type: %T", ln)
}
}
// startSingleMode 单服务器模式启动。 // startSingleMode 单服务器模式启动。
// //
// 在单服务器模式下,创建单一路由器,注册代理路由和静态文件服务, // 在单服务器模式下,创建单一路由器,注册代理路由和静态文件服务,
@ -382,39 +467,56 @@ func (s *Server) startSingleMode() error {
if err != nil { if err != nil {
logging.Error().Msg("Failed to create status handler: " + err.Error()) logging.Error().Msg("Failed to create status handler: " + err.Error())
} else { } else {
_ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false) if err := s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false); err != nil {
if err := s.handleRegistrationError("status", statusHandler.Path(), err); err != nil {
return err
}
}
} }
} }
// 注册 pprof 性能分析端点(如果配置)
if s.config.Monitoring.Pprof.Enabled { if s.config.Monitoring.Pprof.Enabled {
pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof) pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof)
if err != nil { if err != nil {
logging.Error().Msg("Failed to create pprof handler: " + err.Error()) logging.Error().Msg("Failed to create pprof handler: " + err.Error())
} else { } else {
_ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false) if err := s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false); err != nil {
_ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false) if err := s.handleRegistrationError("pprof", pprofHandler.Path(), err); err != nil {
return err
}
}
if err := s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false); err != nil {
if err := s.handleRegistrationError("pprof", pprofHandler.Path()+"/", err); err != nil {
return err
}
}
} }
} }
// 注册缓存清理 API如果配置
if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled { if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled {
purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI) purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI)
if err != nil { if err != nil {
logging.Error().Msg("Failed to create cache purge handler: " + err.Error()) logging.Error().Msg("Failed to create cache purge handler: " + err.Error())
} else { } else {
_ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false) if err := s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false); err != nil {
if err := s.handleRegistrationError("cache-purge", purgeHandler.Path(), err); err != nil {
return err
}
}
} }
} }
// 注册代理路由 if err := s.registerProxyRoutesWithLocationEngine(serverCfg); err != nil {
s.registerProxyRoutesWithLocationEngine(serverCfg) return err
}
// Lua 路由 if err := s.registerLuaRoutesWithLocationEngine(serverCfg); err != nil {
s.registerLuaRoutesWithLocationEngine(serverCfg) return err
}
// 静态文件服务 if err := s.registerStaticHandlersWithLocationEngine(serverCfg); err != nil {
s.registerStaticHandlersWithLocationEngine(serverCfg) return err
}
// 标记 LocationEngine 初始化完成 // 标记 LocationEngine 初始化完成
s.locationEngine.MarkInitialized() s.locationEngine.MarkInitialized()
@ -613,22 +715,28 @@ func (s *Server) startVHostMode() error {
// //
// 注意事项: // 注意事项:
// - 每个服务器有独立的中间件配置 // - 每个服务器有独立的中间件配置
// - 热升级场景下回退到虚拟主机模式
// - 使用 goroutine 并行启动多个服务器 // - 使用 goroutine 并行启动多个服务器
func (s *Server) startMultiServerMode() error { func (s *Server) startMultiServerMode() error {
// 热升级检测multi_server 热升级未实现,回退到 vhost 模式
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
logging.Warn().Msg("multi_server mode not implemented for graceful upgrade, falling back to vhost mode")
return s.startVHostMode()
}
s.fastServers = make([]*fasthttp.Server, len(s.config.Servers)) s.fastServers = make([]*fasthttp.Server, len(s.config.Servers))
s.listeners = make([]net.Listener, len(s.config.Servers)) s.listeners = make([]net.Listener, len(s.config.Servers))
for i := range s.config.Servers {
serverCfg := &s.config.Servers[i]
ln, err := s.createListener(serverCfg)
if err != nil {
for j := 0; j < i; j++ {
if s.listeners[j] != nil {
_ = s.listeners[j].Close()
}
}
return fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err)
}
s.listeners[i] = ln
}
var wg sync.WaitGroup var wg sync.WaitGroup
errCh := make(chan error, len(s.config.Servers)) errCh := make(chan error, len(s.config.Servers))
// 并行创建监听器和 fasthttp.Server
for i := range s.config.Servers { for i := range s.config.Servers {
wg.Add(1) wg.Add(1)
go func(idx int) { go func(idx int) {
@ -636,15 +744,6 @@ func (s *Server) startMultiServerMode() error {
serverCfg := &s.config.Servers[idx] serverCfg := &s.config.Servers[idx]
// 创建监听器
ln, err := s.createListener(serverCfg)
if err != nil {
errCh <- fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err)
return
}
s.listeners[idx] = ln
// 创建路由器
router := handler.NewRouter() router := handler.NewRouter()
// 注册状态监控端点(仅默认服务器) // 注册状态监控端点(仅默认服务器)

View File

@ -18,6 +18,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -26,6 +27,7 @@ import (
"rua.plus/lolly/internal/config" "rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/lua" "rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/matcher"
"rua.plus/lolly/internal/middleware/accesslog" "rua.plus/lolly/internal/middleware/accesslog"
"rua.plus/lolly/internal/middleware/security" "rua.plus/lolly/internal/middleware/security"
"rua.plus/lolly/internal/proxy" "rua.plus/lolly/internal/proxy"
@ -989,6 +991,157 @@ func TestCreateListener_UnixSocketCleanup(t *testing.T) {
defer ln.Close() defer ln.Close()
} }
func TestHandleRegistrationError_ConflictWarning(t *testing.T) {
s := &Server{}
err := s.handleRegistrationError("proxy", "/api",
&matcher.ConflictError{Path: "/api", ExistingType: "exact", NewType: "prefix"})
if err != nil {
t.Errorf("conflict should return nil, got: %v", err)
}
}
func TestHandleRegistrationError_FatalError(t *testing.T) {
s := &Server{}
err := s.handleRegistrationError("proxy", "/api",
fmt.Errorf("invalid regex pattern: missing closing parenthesis"))
if err == nil {
t.Error("fatal error should return non-nil")
}
if !strings.Contains(err.Error(), "proxy route /api") {
t.Errorf("error should wrap context, got: %v", err)
}
}
func TestDupListener_TCP(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
duped, err := DupListener(ln)
if err != nil {
t.Fatalf("DupListener() error: %v", err)
}
defer duped.Close()
if duped.Addr().Network() != "tcp" {
t.Errorf("expected tcp, got %s", duped.Addr().Network())
}
if duped.Addr().String() != ln.Addr().String() {
t.Errorf("expected same address %s, got %s", ln.Addr().String(), duped.Addr().String())
}
}
func TestDupListener_Unix(t *testing.T) {
dir := t.TempDir()
socketPath := dir + "/dup.sock"
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatal(err)
}
defer ln.Close()
duped, err := DupListener(ln)
if err != nil {
t.Fatalf("DupListener() error: %v", err)
}
defer duped.Close()
}
func TestDupListener_Unsupported(t *testing.T) {
_, err := DupListener(struct{ net.Listener }{})
if err == nil {
t.Error("expected error for unsupported type")
}
}
func TestTcpAddrMatch(t *testing.T) {
s := &Server{}
tests := []struct {
inherited string
target string
want bool
}{
{"127.0.0.1:8080", "127.0.0.1:8080", true},
{"0.0.0.0:8080", ":8080", true},
{"[::]:8080", ":8080", true},
{"0.0.0.0:8080", "0.0.0.0:8080", true},
{"0.0.0.0:8080", "127.0.0.1:8080", true},
{"127.0.0.1:8080", "0.0.0.0:8080", true},
{"127.0.0.1:8080", ":9090", false},
{"127.0.0.1:8080", "192.168.1.1:8080", false},
}
for _, tt := range tests {
got := s.tcpAddrMatch(tt.inherited, tt.target)
if got != tt.want {
t.Errorf("tcpAddrMatch(%q, %q) = %v, want %v", tt.inherited, tt.target, got, tt.want)
}
}
}
func TestMatchInheritedListener_TCP(t *testing.T) {
s := &Server{}
ln1, _ := net.Listen("tcp", "127.0.0.1:0")
defer ln1.Close()
ln2, _ := net.Listen("tcp", "127.0.0.1:0")
defer ln2.Close()
inherited := []net.Listener{ln1, ln2}
result := s.matchInheritedListener(inherited, "0.0.0.0:99999")
if result != nil {
t.Error("expected nil for non-matching address")
}
addr1 := ln1.Addr().String()
result = s.matchInheritedListener(inherited, addr1)
if result != ln1 {
t.Errorf("expected ln1 for address %s", addr1)
}
}
func TestMatchInheritedListener_Empty(t *testing.T) {
s := &Server{}
result := s.matchInheritedListener(nil, ":8080")
if result != nil {
t.Error("expected nil for empty inherited list")
}
}
func TestMatchInheritedListener_PresetListeners(t *testing.T) {
cfg := &config.Config{
Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}},
}
s := New(cfg)
ln, err := s.createListener(&cfg.Servers[0])
if err != nil {
t.Fatal(err)
}
defer ln.Close()
s.SetListeners([]net.Listener{ln})
addr := ln.Addr().String()
cfg.Servers[0].Listen = addr
matched, err := s.createListener(&cfg.Servers[0])
if err != nil {
t.Fatalf("createListener with preset should reuse: %v", err)
}
if matched == nil {
t.Fatal("expected non-nil listener from preset match")
}
if matched.Addr().String() != addr {
t.Errorf("expected same address %s, got %s", addr, matched.Addr().String())
}
}
// TestServer_StatsMethods 测试服务器统计方法。 // TestServer_StatsMethods 测试服务器统计方法。
func TestServer_StatsMethods(t *testing.T) { func TestServer_StatsMethods(t *testing.T) {
cfg := &config.Config{ cfg := &config.Config{