Merge pull request #2 from DefectingCat/fix/identified-issues
fix/identified issues
This commit is contained in:
commit
634fc5b51b
@ -14,7 +14,6 @@
|
||||
| `go.sum` | 依赖版本锁定 |
|
||||
| `Makefile` | 构建脚本,支持多平台编译、测试、覆盖率 |
|
||||
| `lolly.yaml` | 默认配置文件示例 |
|
||||
| `config.example.yaml` | 完整配置文件示例(所有字段枚举) |
|
||||
| `.gitignore` | Git 忽略规则 |
|
||||
|
||||
## Subdirectories
|
||||
@ -27,7 +26,6 @@
|
||||
| `examples/` | Lua 脚本示例 |
|
||||
| `html/` | 静态 HTML 文件(测试/示例) |
|
||||
| `scripts/` | 构建/测试辅助脚本(回归检测) |
|
||||
| `.github/` | CI/CD 工作流定义 |
|
||||
|
||||
## For AI Agents
|
||||
|
||||
@ -43,6 +41,7 @@
|
||||
- 运行测试前确保依赖已下载:`go mod download`
|
||||
- 测试覆盖率目标 >80%
|
||||
- 使用 `make check` 运行完整检查(fmt + lint + test)
|
||||
- 使用 `lolly --generate-config` 生成完整配置文件模板
|
||||
|
||||
### Common Patterns
|
||||
- 配置结构体使用 `yaml` 标签,通过 `gopkg.in/yaml.v3` 解析
|
||||
|
||||
@ -58,8 +58,7 @@ http3: {} # HTTP/3 配置
|
||||
resolver: {} # DNS 解析配置
|
||||
performance: {} # 性能配置
|
||||
shutdown: {} # 关闭配置
|
||||
include: [] # 配置引入
|
||||
cache_path: {} # 缓存路径配置
|
||||
include: [] # 配置引入
|
||||
```
|
||||
|
||||
### 运行模式
|
||||
|
||||
@ -4,6 +4,7 @@ package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
@ -164,11 +165,129 @@ func (a *App) reloadConfig() {
|
||||
return
|
||||
}
|
||||
|
||||
if a.srv == nil {
|
||||
a.cfg = newCfg
|
||||
a.logger = logging.NewAppLogger(&newCfg.Logging)
|
||||
a.logger.LogStartup("Config reloaded (no running server)", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if a.requiresFullRestart(newCfg) {
|
||||
logging.Warn().Msg("Config requires full restart (listen address or mode changed). Use SIGUSR2 for graceful upgrade.")
|
||||
return
|
||||
}
|
||||
|
||||
listeners := a.srv.GetListeners()
|
||||
if len(listeners) == 0 {
|
||||
a.logger.Error().Msg("Cannot reload: server has no saved listeners")
|
||||
return
|
||||
}
|
||||
|
||||
duped := make([]net.Listener, len(listeners))
|
||||
for i, ln := range listeners {
|
||||
duped[i], err = server.DupListener(ln)
|
||||
if err != nil {
|
||||
for j := 0; j < i; j++ {
|
||||
_ = duped[j].Close()
|
||||
}
|
||||
a.logger.Error().Err(err).Msg("Failed to dup listener for reload")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newSrv := server.New(newCfg)
|
||||
if a.resv != nil {
|
||||
newSrv.SetResolver(a.resv)
|
||||
}
|
||||
newSrv.SetListeners(duped)
|
||||
|
||||
startErr := make(chan error, 1)
|
||||
go func() {
|
||||
if err := newSrv.Start(); err != nil {
|
||||
startErr <- err
|
||||
}
|
||||
}()
|
||||
|
||||
reloadTimeout := a.cfg.Shutdown.ReloadTimeout
|
||||
if reloadTimeout <= 0 {
|
||||
reloadTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-startErr:
|
||||
a.logger.Error().Err(err).Msg("Failed to start new server with reloaded config")
|
||||
for _, ln := range duped {
|
||||
_ = ln.Close()
|
||||
}
|
||||
return
|
||||
case <-time.After(reloadTimeout):
|
||||
}
|
||||
|
||||
oldSrv := a.srv
|
||||
oldHTTP2 := a.http2Srv
|
||||
oldHTTP3 := a.http3Srv
|
||||
|
||||
a.srv = newSrv
|
||||
a.cfg = newCfg
|
||||
a.logger = logging.NewAppLogger(&newCfg.Logging)
|
||||
a.http2Srv = nil
|
||||
a.http3Srv = nil
|
||||
|
||||
a.initVariables()
|
||||
a.initHTTP2()
|
||||
a.initHTTP3()
|
||||
|
||||
if a.upgradeMgr != nil {
|
||||
a.upgradeMgr.SetListeners(newSrv.GetListeners())
|
||||
}
|
||||
|
||||
go func() {
|
||||
if oldHTTP2 != nil {
|
||||
_ = oldHTTP2.Stop()
|
||||
}
|
||||
if oldHTTP3 != nil {
|
||||
_ = oldHTTP3.Stop()
|
||||
}
|
||||
_ = oldSrv.GracefulStop(30 * time.Second)
|
||||
}()
|
||||
|
||||
a.logger.LogStartup("Config reloaded successfully", nil)
|
||||
}
|
||||
|
||||
func (a *App) requiresFullRestart(newCfg *config.Config) bool {
|
||||
if a.cfg.GetMode() != newCfg.GetMode() {
|
||||
return true
|
||||
}
|
||||
oldMode := a.cfg.GetMode()
|
||||
switch oldMode {
|
||||
case config.ServerModeSingle:
|
||||
if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 {
|
||||
if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case config.ServerModeVHost:
|
||||
if len(a.cfg.Servers) != len(newCfg.Servers) {
|
||||
return true
|
||||
}
|
||||
if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 {
|
||||
if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case config.ServerModeMultiServer:
|
||||
if len(a.cfg.Servers) != len(newCfg.Servers) {
|
||||
return true
|
||||
}
|
||||
for i := range a.cfg.Servers {
|
||||
if a.cfg.Servers[i].Listen != newCfg.Servers[i].Listen {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *App) gracefulUpgrade() {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
|
||||
@ -461,7 +461,6 @@ func TestHandleSignal_SIGINT(t *testing.T) {
|
||||
|
||||
// TestHandleSignal_SIGHUP 测试 SIGHUP 信号处理(重载配置)
|
||||
func TestHandleSignal_SIGHUP(t *testing.T) {
|
||||
// 创建临时配置文件
|
||||
tmpDir := t.TempDir()
|
||||
cfgPath := filepath.Join(tmpDir, "config.yaml")
|
||||
cfgContent := `
|
||||
@ -1448,14 +1447,12 @@ logging:
|
||||
}
|
||||
app.logger = setupTestLogger()
|
||||
|
||||
// 发送 SIGHUP 信号
|
||||
result := app.handleSignal(syscall.SIGHUP)
|
||||
|
||||
if result != true {
|
||||
t.Error("Expected handleSignal(SIGHUP) to return true")
|
||||
}
|
||||
|
||||
// 验证配置已更新
|
||||
if app.cfg.Servers[0].Listen != ":7070" {
|
||||
t.Errorf("Expected listen ':7070', got '%s'", app.cfg.Servers[0].Listen)
|
||||
}
|
||||
|
||||
@ -2,54 +2,6 @@ package config
|
||||
|
||||
import "time"
|
||||
|
||||
// ProxyCachePathConfig 缓存路径配置(磁盘持久化)。
|
||||
//
|
||||
// 配置磁盘缓存路径和相关参数,支持 L1/L2 分层缓存架构。
|
||||
// 配置后,代理缓存将持久化到磁盘,服务重启后可恢复。
|
||||
//
|
||||
// 注意事项:
|
||||
// - Path 为必填项,指定缓存根目录
|
||||
// - Levels 支持最多 3 级目录(如 "1:2:2")
|
||||
// - MaxSize 为 0 表示不限制大小
|
||||
// - L1MaxEntries/L1MaxSize 为 0 时使用默认值
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// cache_path:
|
||||
// path: "/var/cache/lolly"
|
||||
// levels: "1:2"
|
||||
// max_size: "1GB"
|
||||
// inactive: "60m"
|
||||
// l1_max_entries: 10000
|
||||
type ProxyCachePathConfig struct {
|
||||
// Path 缓存根目录
|
||||
Path string `yaml:"path"`
|
||||
|
||||
// Levels 目录层级,如 "1:2" 表示两级目录
|
||||
Levels string `yaml:"levels"`
|
||||
|
||||
// MaxSize 最大缓存大小(字节)
|
||||
MaxSize int64 `yaml:"max_size"`
|
||||
|
||||
// Inactive 未访问淘汰时间
|
||||
Inactive time.Duration `yaml:"inactive"`
|
||||
|
||||
// Purger 是否启用后台清理
|
||||
Purger bool `yaml:"purger"`
|
||||
|
||||
// PurgerInterval 清理间隔
|
||||
PurgerInterval time.Duration `yaml:"purger_interval"`
|
||||
|
||||
// L1MaxEntries L1 最大条目数
|
||||
L1MaxEntries int64 `yaml:"l1_max_entries"`
|
||||
|
||||
// L1MaxSize L1 最大内存大小
|
||||
L1MaxSize int64 `yaml:"l1_max_size"`
|
||||
|
||||
// PromoteThreshold 提升到 L1 的访问阈值
|
||||
PromoteThreshold int `yaml:"promote_threshold"`
|
||||
}
|
||||
|
||||
// ProxyCacheConfig 代理缓存配置。
|
||||
//
|
||||
// 缓存后端响应,减少重复请求,提高响应速度。
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@ -81,8 +82,7 @@ type Config struct {
|
||||
Resolver ResolverConfig `yaml:"resolver"`
|
||||
Performance PerformanceConfig `yaml:"performance"`
|
||||
Shutdown ShutdownConfig `yaml:"shutdown"`
|
||||
Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段
|
||||
CachePath *ProxyCachePathConfig `yaml:"cache_path"` // 缓存路径配置(磁盘持久化)
|
||||
Include []IncludeConfig `yaml:"include"` // 配置引入,支持从其他文件引入配置片段
|
||||
}
|
||||
|
||||
// parseSize 解析大小字符串(支持 k, m 单位)。
|
||||
@ -139,6 +139,17 @@ func Load(path string) (*Config, error) {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
if len(cfg.Include) > 0 {
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取配置文件绝对路径失败: %w", err)
|
||||
}
|
||||
visited := map[string]bool{absPath: true}
|
||||
if err := processIncludes(&cfg, filepath.Dir(path), 0, visited); err != nil {
|
||||
return nil, fmt.Errorf("处理配置引入失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := Validate(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("配置验证失败: %w", err)
|
||||
}
|
||||
@ -146,6 +157,81 @@ func Load(path string) (*Config, error) {
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
const maxIncludeDepth = 10
|
||||
|
||||
func processIncludes(cfg *Config, baseDir string, depth int, visited map[string]bool) error {
|
||||
if depth >= maxIncludeDepth {
|
||||
return fmt.Errorf("配置引入嵌套深度超过 %d 层", maxIncludeDepth)
|
||||
}
|
||||
|
||||
for _, inc := range cfg.Include {
|
||||
pattern := inc.Path
|
||||
if !filepath.IsAbs(pattern) {
|
||||
pattern = filepath.Join(baseDir, pattern)
|
||||
}
|
||||
|
||||
matches, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("展开引入路径 %q 失败: %w", inc.Path, err)
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return fmt.Errorf("引入路径 %q 未匹配到任何文件", inc.Path)
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
absMatch, err := filepath.Abs(match)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取引入文件绝对路径失败 %q: %w", match, err)
|
||||
}
|
||||
if visited[absMatch] {
|
||||
return fmt.Errorf("检测到循环引入: %s", absMatch)
|
||||
}
|
||||
visited[absMatch] = true
|
||||
|
||||
info, err := os.Stat(match)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
delete(visited, absMatch)
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(match)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取引入文件 %q 失败: %w", match, err)
|
||||
}
|
||||
|
||||
var included Config
|
||||
if err := yaml.Unmarshal(data, &included); err != nil {
|
||||
return fmt.Errorf("解析引入文件 %q 失败: %w", match, err)
|
||||
}
|
||||
|
||||
if len(included.Include) > 0 {
|
||||
if err := processIncludes(&included, filepath.Dir(match), depth+1, visited); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cfg.Servers = append(cfg.Servers, included.Servers...)
|
||||
cfg.Stream = append(cfg.Stream, included.Stream...)
|
||||
for k, v := range included.Variables.Set {
|
||||
if _, exists := cfg.Variables.Set[k]; !exists {
|
||||
if cfg.Variables.Set == nil {
|
||||
cfg.Variables.Set = make(map[string]string)
|
||||
}
|
||||
cfg.Variables.Set[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
delete(visited, absMatch)
|
||||
}
|
||||
}
|
||||
|
||||
cfg.Include = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromString 从 YAML 字符串加载配置。
|
||||
//
|
||||
// 解析 YAML 格式的配置字符串,适用于从环境变量或命令行参数加载配置。
|
||||
|
||||
@ -4,6 +4,7 @@ package config
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -491,3 +492,273 @@ func TestProxyBufferingConfig_ParseBuffers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_Include(t *testing.T) {
|
||||
t.Run("append servers from include", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
name: main
|
||||
include:
|
||||
- path: "conf.d/*.yaml"
|
||||
`
|
||||
incCfg := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
name: included
|
||||
`
|
||||
confDir := filepath.Join(tmpDir, "conf.d")
|
||||
if err := os.MkdirAll(confDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(confDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Servers) != 2 {
|
||||
t.Fatalf("expected 2 servers, got %d", len(cfg.Servers))
|
||||
}
|
||||
if cfg.Servers[0].Name != "main" {
|
||||
t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "main")
|
||||
}
|
||||
if cfg.Servers[1].Name != "included" {
|
||||
t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "included")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("merge variables with main priority", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
variables:
|
||||
set:
|
||||
app: lolly
|
||||
env: production
|
||||
include:
|
||||
- path: "extra.yaml"
|
||||
`
|
||||
incCfg := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
variables:
|
||||
set:
|
||||
app: other
|
||||
debug: "true"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "extra.yaml"), []byte(incCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Variables.Set["app"] != "lolly" {
|
||||
t.Errorf("app = %q, want %q (main should win)", cfg.Variables.Set["app"], "lolly")
|
||||
}
|
||||
if cfg.Variables.Set["debug"] != "true" {
|
||||
t.Errorf("debug = %q, want %q (included should fill missing)", cfg.Variables.Set["debug"], "true")
|
||||
}
|
||||
if cfg.Variables.Set["env"] != "production" {
|
||||
t.Errorf("env = %q, want %q", cfg.Variables.Set["env"], "production")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no matches returns error", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "nonexistent/*.yaml"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err == nil {
|
||||
t.Error("expected error for non-matching glob pattern")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("circular include detected immediately", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg1 := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "b.yaml"
|
||||
`
|
||||
cfg2 := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
include:
|
||||
- path: "a.yaml"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg1), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(cfg2), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
|
||||
if err == nil {
|
||||
t.Error("expected error for circular include")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "循环引入") {
|
||||
t.Errorf("error should mention circular include, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("self include detected", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
cfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "a.yaml"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "a.yaml"), []byte(cfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := Load(filepath.Join(tmpDir, "a.yaml"))
|
||||
if err == nil {
|
||||
t.Error("expected error for self include")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("diamond include works", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
include:
|
||||
- path: "b.yaml"
|
||||
- path: "c.yaml"
|
||||
`
|
||||
bCfg := `
|
||||
servers:
|
||||
- listen: ":9090"
|
||||
include:
|
||||
- path: "shared.yaml"
|
||||
`
|
||||
cCfg := `
|
||||
servers:
|
||||
- listen: ":9091"
|
||||
include:
|
||||
- path: "shared.yaml"
|
||||
`
|
||||
sharedCfg := `
|
||||
variables:
|
||||
set:
|
||||
shared: value
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "b.yaml"), []byte(bCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "c.yaml"), []byte(cCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "shared.yaml"), []byte(sharedCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("diamond include should work: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Servers) != 3 {
|
||||
t.Errorf("expected 3 servers, got %d", len(cfg.Servers))
|
||||
}
|
||||
if cfg.Variables.Set["shared"] != "value" {
|
||||
t.Error("shared variable should be merged")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty include list is no-op", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
if len(cfg.Servers) != 1 {
|
||||
t.Errorf("expected 1 server, got %d", len(cfg.Servers))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("append stream from include", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
mainCfg := `
|
||||
servers:
|
||||
- listen: ":8080"
|
||||
stream:
|
||||
- listen: ":5432"
|
||||
protocol: tcp
|
||||
upstream:
|
||||
targets:
|
||||
- addr: "127.0.0.1:9000"
|
||||
include:
|
||||
- path: "stream.yaml"
|
||||
`
|
||||
incCfg := `
|
||||
stream:
|
||||
- listen: ":5433"
|
||||
protocol: udp
|
||||
upstream:
|
||||
targets:
|
||||
- addr: "127.0.0.1:9001"
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "config.yaml"), []byte(mainCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "stream.yaml"), []byte(incCfg), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cfg, err := Load(filepath.Join(tmpDir, "config.yaml"))
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Stream) != 2 {
|
||||
t.Fatalf("expected 2 streams, got %d", len(cfg.Stream))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -222,6 +222,7 @@ func DefaultConfig() *Config {
|
||||
Shutdown: ShutdownConfig{
|
||||
GracefulTimeout: 30 * time.Second,
|
||||
FastTimeout: 5 * time.Second,
|
||||
ReloadTimeout: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -612,6 +613,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
||||
buf.WriteString("shutdown:\n")
|
||||
fmt.Fprintf(&buf, " graceful_timeout: %ds # 优雅停止超时(SIGQUIT),等待活跃请求完成(0=使用默认30s)\n", int(cfg.Shutdown.GracefulTimeout.Seconds()))
|
||||
fmt.Fprintf(&buf, " fast_timeout: %ds # 快速停止超时(SIGINT/SIGTERM,0=使用默认5s)\n", int(cfg.Shutdown.FastTimeout.Seconds()))
|
||||
fmt.Fprintf(&buf, " reload_timeout: %ds # 热重载启动等待超时(SIGHUP,0=使用默认5s)\n", int(cfg.Shutdown.ReloadTimeout.Seconds()))
|
||||
buf.WriteString("\n")
|
||||
|
||||
// stream 配置
|
||||
@ -734,6 +736,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
||||
buf.WriteString("# - path: \"conf.d/*.yaml\" # 相对路径 + glob 模式\n")
|
||||
buf.WriteString("# - path: \"sites/example.yaml\" # 单个文件引入\n")
|
||||
buf.WriteString("# 支持循环检测和深度限制(最大 10 层)\n")
|
||||
buf.WriteString("# 注意:只有 servers、stream、variables 会被合并,其他字段忽略\n")
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
@ -181,6 +181,11 @@ type ShutdownConfig struct {
|
||||
// 接收到 SIGINT 或 SIGTERM 信号后,等待服务器关闭的最大时间
|
||||
// 默认: 5s(当值为 0 时使用默认值)
|
||||
FastTimeout time.Duration `yaml:"fast_timeout"`
|
||||
|
||||
// ReloadTimeout 热重载启动等待超时(SIGHUP)
|
||||
// 等待新服务器启动完成的最大时间,超时后视为启动成功
|
||||
// 默认: 5s(当值为 0 时使用默认值)
|
||||
ReloadTimeout time.Duration `yaml:"reload_timeout"`
|
||||
}
|
||||
|
||||
// ResolverConfig DNS 解析器配置。
|
||||
|
||||
@ -462,15 +462,6 @@ func (b *ConfigBuilder) WithRewrite(pattern, replacement string, opts ...Rewrite
|
||||
return b
|
||||
}
|
||||
|
||||
// WithCachePath 配置缓存路径。
|
||||
func (b *ConfigBuilder) WithCachePath(path string, maxSize int64) *ConfigBuilder {
|
||||
b.cfg.CachePath = &config.ProxyCachePathConfig{
|
||||
Path: path,
|
||||
MaxSize: maxSize,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithResolver 配置 DNS 解析器。
|
||||
func (b *ConfigBuilder) WithResolver(addresses []string, valid, timeout time.Duration) *ConfigBuilder {
|
||||
b.cfg.Resolver = config.ResolverConfig{
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
@ -107,4 +108,8 @@ func TestLocationEngine_PathConflict(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Error("should fail on path conflict")
|
||||
}
|
||||
var ce *ConflictError
|
||||
if !errors.As(err, &ce) {
|
||||
t.Errorf("expected *ConflictError, got %T: %v", err, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -173,8 +173,12 @@ func (e *LocationEngine) AddNamed(name string, handler fasthttp.RequestHandler)
|
||||
return errors.New("LocationEngine already initialized")
|
||||
}
|
||||
|
||||
if existing, ok := e.namedMatchers[name]; ok {
|
||||
return fmt.Errorf("named location '@%s' already registered", existing.name)
|
||||
if _, ok := e.namedMatchers[name]; ok {
|
||||
return &ConflictError{
|
||||
Path: "@" + name,
|
||||
ExistingType: "named",
|
||||
NewType: "named",
|
||||
}
|
||||
}
|
||||
|
||||
matcher := NewNamedMatcher(name, handler)
|
||||
@ -240,6 +244,21 @@ func (e *LocationEngine) MarkInitialized() {
|
||||
e.prefixTree.MarkInitialized()
|
||||
}
|
||||
|
||||
// ConflictError 路径冲突错误。
|
||||
//
|
||||
// 当同一路径被重复注册为不同类型的 location 时返回此错误。
|
||||
// 调用方可通过 errors.As 检测此类型,区分冲突与致命错误。
|
||||
type ConflictError struct {
|
||||
Path string
|
||||
ExistingType string
|
||||
NewType string
|
||||
}
|
||||
|
||||
func (e *ConflictError) Error() string {
|
||||
return fmt.Sprintf("path conflict: '%s' already registered as '%s', trying to register as '%s'",
|
||||
e.Path, e.ExistingType, e.NewType)
|
||||
}
|
||||
|
||||
// checkConflict 检查路径冲突。
|
||||
//
|
||||
// 参数:
|
||||
@ -247,11 +266,10 @@ func (e *LocationEngine) MarkInitialized() {
|
||||
// - locationType: location 类型
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 路径已存在时返回冲突错误
|
||||
// - error: 路径已存在时返回 *ConflictError
|
||||
func (e *LocationEngine) checkConflict(path, locationType string) error {
|
||||
if existing, ok := e.registeredPaths[path]; ok {
|
||||
return fmt.Errorf("path conflict: '%s' already registered as '%s', trying to register as '%s'",
|
||||
path, existing, locationType)
|
||||
return &ConflictError{Path: path, ExistingType: existing, NewType: locationType}
|
||||
}
|
||||
e.registeredPaths[path] = locationType
|
||||
return nil
|
||||
|
||||
@ -67,7 +67,7 @@ func (s *Server) createProxyForConfig(proxyCfg *config.ProxyConfig) *proxy.Proxy
|
||||
//
|
||||
// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。
|
||||
// 支持通过 LocationType 配置不同的匹配方式。
|
||||
func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) {
|
||||
func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) error {
|
||||
for i := range serverCfg.Proxy {
|
||||
proxyCfg := &serverCfg.Proxy[i]
|
||||
|
||||
@ -76,7 +76,6 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC
|
||||
continue
|
||||
}
|
||||
|
||||
// 根据 LocationType 注册路由
|
||||
locType := proxyCfg.LocationType
|
||||
if locType == "" {
|
||||
locType = matcher.LocationTypePrefix
|
||||
@ -84,22 +83,47 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC
|
||||
|
||||
switch locType {
|
||||
case matcher.LocationTypeExact:
|
||||
_ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
if err := s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypePrefixPriority:
|
||||
_ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
if err := s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless:
|
||||
caseInsensitive := locType == matcher.LocationTypeRegexCaseless
|
||||
_ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal)
|
||||
if err := s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypeNamed:
|
||||
if proxyCfg.LocationName != "" {
|
||||
_ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP)
|
||||
if err := s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP); err != nil {
|
||||
if err := s.handleRegistrationError("proxy", "@"+proxyCfg.LocationName, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypePrefix:
|
||||
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
if err := s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("proxy", proxyCfg.Path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureStaticHandler 配置静态文件处理器。
|
||||
@ -156,7 +180,7 @@ func (s *Server) configureStaticHandler(static *config.StaticConfig, cfg *config
|
||||
}
|
||||
|
||||
// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。
|
||||
func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) {
|
||||
func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) error {
|
||||
for _, static := range cfg.Static {
|
||||
staticHandler := s.configureStaticHandler(&static, cfg)
|
||||
path := static.Path
|
||||
@ -164,7 +188,6 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf
|
||||
path = "/"
|
||||
}
|
||||
|
||||
// 根据 LocationType 注册路由
|
||||
locType := static.LocationType
|
||||
if locType == "" {
|
||||
locType = matcher.LocationTypePrefix
|
||||
@ -172,15 +195,32 @@ func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConf
|
||||
|
||||
switch locType {
|
||||
case matcher.LocationTypeExact:
|
||||
_ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal)
|
||||
if err := s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("static", path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypePrefixPriority:
|
||||
_ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal)
|
||||
if err := s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("static", path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypePrefix:
|
||||
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
|
||||
if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("static", path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
|
||||
if err := s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal); err != nil {
|
||||
if err := s.handleRegistrationError("static", path, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerProxyRoutes 注册代理路由。
|
||||
@ -324,9 +364,9 @@ func (s *Server) registerLuaRoutes(router *handler.Router, serverCfg *config.Ser
|
||||
// - 只有设置了 Route 字段的脚本才会被注册
|
||||
// - 路由脚本不经过完整中间件链,只应用 accesslog 和 errorintercept
|
||||
// - 支持 exact、prefix、prefix_priority、regex、regex_caseless 匹配类型
|
||||
func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) {
|
||||
func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerConfig) error {
|
||||
if s.luaEngine == nil || serverCfg.Lua == nil || !serverCfg.Lua.Enabled {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, script := range serverCfg.Lua.Scripts {
|
||||
@ -348,17 +388,38 @@ func (s *Server) registerLuaRoutesWithLocationEngine(serverCfg *config.ServerCon
|
||||
|
||||
switch routeType {
|
||||
case matcher.LocationTypeExact:
|
||||
_ = s.locationEngine.AddExact(script.Route, handler, false)
|
||||
if err := s.locationEngine.AddExact(script.Route, handler, false); err != nil {
|
||||
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypePrefixPriority:
|
||||
_ = s.locationEngine.AddPrefixPriority(script.Route, handler, false)
|
||||
if err := s.locationEngine.AddPrefixPriority(script.Route, handler, false); err != nil {
|
||||
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypeRegex:
|
||||
_ = s.locationEngine.AddRegex(script.Route, handler, false, false)
|
||||
if err := s.locationEngine.AddRegex(script.Route, handler, false, false); err != nil {
|
||||
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case matcher.LocationTypeRegexCaseless:
|
||||
_ = s.locationEngine.AddRegex(script.Route, handler, true, false)
|
||||
if err := s.locationEngine.AddRegex(script.Route, handler, true, false); err != nil {
|
||||
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
_ = s.locationEngine.AddPrefix(script.Route, handler, false)
|
||||
if err := s.locationEngine.AddPrefix(script.Route, handler, false); err != nil {
|
||||
if err := s.handleRegistrationError("lua", script.Route, err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapRoutedHandler 为路由处理器包装基础中间件链。
|
||||
|
||||
@ -21,6 +21,7 @@ package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@ -96,6 +97,15 @@ func New(cfg *config.Config) *Server {
|
||||
return &Server{config: cfg}
|
||||
}
|
||||
|
||||
func (s *Server) handleRegistrationError(source, path string, err error) error {
|
||||
var ce *matcher.ConflictError
|
||||
if errors.As(err, &ce) {
|
||||
logging.Warn().Msgf("Route registration skipped (%s %s): %s", source, path, err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%s route %s: %w", source, path, err)
|
||||
}
|
||||
|
||||
// getServerName 根据配置返回服务器名称。
|
||||
//
|
||||
// 当 ServerTokens 为 false 时隐藏版本号,仅返回 "lolly"。
|
||||
@ -308,32 +318,31 @@ func (s *Server) Start() error {
|
||||
func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) {
|
||||
listenAddr := cfg.Listen
|
||||
|
||||
if s.upgradeManager != nil && s.upgradeManager.IsChild() {
|
||||
inherited, _ := s.upgradeManager.GetInheritedListeners()
|
||||
if ln := s.matchInheritedListener(inherited, listenAddr); ln != nil {
|
||||
return ln, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.listeners) > 0 {
|
||||
if ln := s.matchInheritedListener(s.listeners, listenAddr); ln != nil {
|
||||
return ln, nil
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(listenAddr, "unix:") {
|
||||
// Unix Socket 模式
|
||||
socketPath := listenAddr[5:]
|
||||
|
||||
// 1. 检查继承的监听器(热升级场景)
|
||||
if s.upgradeManager != nil && s.upgradeManager.IsChild() {
|
||||
inherited, _ := s.upgradeManager.GetInheritedListeners()
|
||||
for _, ln := range inherited {
|
||||
if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath {
|
||||
return ln, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 清理旧 socket 文件
|
||||
if _, err := os.Stat(socketPath); err == nil {
|
||||
_ = os.Remove(socketPath)
|
||||
}
|
||||
|
||||
// 3. 创建 Unix socket listener
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create unix socket failed: %w", err)
|
||||
}
|
||||
|
||||
// 4. 设置 socket 文件权限
|
||||
mode := 0o666
|
||||
if cfg.UnixSocket.Mode > 0 {
|
||||
mode = cfg.UnixSocket.Mode
|
||||
@ -342,19 +351,95 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error)
|
||||
logging.Warn().Err(err).Msg("Failed to set socket file permissions")
|
||||
}
|
||||
|
||||
// 5. 设置文件所有权(需要 root 权限)
|
||||
if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" {
|
||||
// 简化处理:仅记录警告,实际实现需要 syscall.Chown
|
||||
logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped")
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// TCP 模式
|
||||
return net.Listen("tcp", listenAddr)
|
||||
}
|
||||
|
||||
func (s *Server) matchInheritedListener(inherited []net.Listener, listenAddr string) net.Listener {
|
||||
if len(inherited) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(listenAddr, "unix:") {
|
||||
socketPath := listenAddr[5:]
|
||||
for _, ln := range inherited {
|
||||
if ln == nil {
|
||||
continue
|
||||
}
|
||||
if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath {
|
||||
return ln
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, ln := range inherited {
|
||||
if ln == nil {
|
||||
continue
|
||||
}
|
||||
if ln.Addr().Network() != "tcp" {
|
||||
continue
|
||||
}
|
||||
if s.tcpAddrMatch(ln.Addr().String(), listenAddr) {
|
||||
return ln
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) tcpAddrMatch(inherited, target string) bool {
|
||||
if inherited == target {
|
||||
return true
|
||||
}
|
||||
host1, port1, err1 := net.SplitHostPort(inherited)
|
||||
host2, port2, err2 := net.SplitHostPort(target)
|
||||
if err1 != nil || err2 != nil {
|
||||
return false
|
||||
}
|
||||
if port1 != port2 {
|
||||
return false
|
||||
}
|
||||
return host1 == host2 || isAnyAddr(host1) || isAnyAddr(host2)
|
||||
}
|
||||
|
||||
func isAnyAddr(host string) bool {
|
||||
if host == "" {
|
||||
return true
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsUnspecified()
|
||||
}
|
||||
|
||||
// DupListener 复制 listener 的文件描述符,返回独立的 listener。
|
||||
//
|
||||
// 用于热重载场景:新旧 server 各自持有独立 FD,互不影响关闭操作。
|
||||
func DupListener(ln net.Listener) (net.Listener, error) {
|
||||
switch l := ln.(type) {
|
||||
case *net.TCPListener:
|
||||
file, err := l.File()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dup tcp listener: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
return net.FileListener(file)
|
||||
case *net.UnixListener:
|
||||
file, err := l.File()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dup unix listener: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
return net.FileListener(file)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported listener type: %T", ln)
|
||||
}
|
||||
}
|
||||
|
||||
// startSingleMode 单服务器模式启动。
|
||||
//
|
||||
// 在单服务器模式下,创建单一路由器,注册代理路由和静态文件服务,
|
||||
@ -382,39 +467,56 @@ func (s *Server) startSingleMode() error {
|
||||
if err != nil {
|
||||
logging.Error().Msg("Failed to create status handler: " + err.Error())
|
||||
} else {
|
||||
_ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false)
|
||||
if err := s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false); err != nil {
|
||||
if err := s.handleRegistrationError("status", statusHandler.Path(), err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 注册 pprof 性能分析端点(如果配置)
|
||||
if s.config.Monitoring.Pprof.Enabled {
|
||||
pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof)
|
||||
if err != nil {
|
||||
logging.Error().Msg("Failed to create pprof handler: " + err.Error())
|
||||
} else {
|
||||
_ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false)
|
||||
_ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false)
|
||||
if err := s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false); err != nil {
|
||||
if err := s.handleRegistrationError("pprof", pprofHandler.Path(), err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false); err != nil {
|
||||
if err := s.handleRegistrationError("pprof", pprofHandler.Path()+"/", err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 注册缓存清理 API(如果配置)
|
||||
if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled {
|
||||
purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI)
|
||||
if err != nil {
|
||||
logging.Error().Msg("Failed to create cache purge handler: " + err.Error())
|
||||
} else {
|
||||
_ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false)
|
||||
if err := s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false); err != nil {
|
||||
if err := s.handleRegistrationError("cache-purge", purgeHandler.Path(), err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 注册代理路由
|
||||
s.registerProxyRoutesWithLocationEngine(serverCfg)
|
||||
if err := s.registerProxyRoutesWithLocationEngine(serverCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Lua 路由
|
||||
s.registerLuaRoutesWithLocationEngine(serverCfg)
|
||||
if err := s.registerLuaRoutesWithLocationEngine(serverCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 静态文件服务
|
||||
s.registerStaticHandlersWithLocationEngine(serverCfg)
|
||||
if err := s.registerStaticHandlersWithLocationEngine(serverCfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 标记 LocationEngine 初始化完成
|
||||
s.locationEngine.MarkInitialized()
|
||||
@ -613,22 +715,28 @@ func (s *Server) startVHostMode() error {
|
||||
//
|
||||
// 注意事项:
|
||||
// - 每个服务器有独立的中间件配置
|
||||
// - 热升级场景下回退到虚拟主机模式
|
||||
// - 使用 goroutine 并行启动多个服务器
|
||||
func (s *Server) startMultiServerMode() error {
|
||||
// 热升级检测:multi_server 热升级未实现,回退到 vhost 模式
|
||||
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
|
||||
logging.Warn().Msg("multi_server mode not implemented for graceful upgrade, falling back to vhost mode")
|
||||
return s.startVHostMode()
|
||||
}
|
||||
|
||||
s.fastServers = make([]*fasthttp.Server, len(s.config.Servers))
|
||||
s.listeners = make([]net.Listener, len(s.config.Servers))
|
||||
|
||||
for i := range s.config.Servers {
|
||||
serverCfg := &s.config.Servers[i]
|
||||
ln, err := s.createListener(serverCfg)
|
||||
if err != nil {
|
||||
for j := 0; j < i; j++ {
|
||||
if s.listeners[j] != nil {
|
||||
_ = s.listeners[j].Close()
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err)
|
||||
}
|
||||
s.listeners[i] = ln
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, len(s.config.Servers))
|
||||
|
||||
// 并行创建监听器和 fasthttp.Server
|
||||
for i := range s.config.Servers {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
@ -636,15 +744,6 @@ func (s *Server) startMultiServerMode() error {
|
||||
|
||||
serverCfg := &s.config.Servers[idx]
|
||||
|
||||
// 创建监听器
|
||||
ln, err := s.createListener(serverCfg)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err)
|
||||
return
|
||||
}
|
||||
s.listeners[idx] = ln
|
||||
|
||||
// 创建路由器
|
||||
router := handler.NewRouter()
|
||||
|
||||
// 注册状态监控端点(仅默认服务器)
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -26,6 +27,7 @@ import (
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
"rua.plus/lolly/internal/matcher"
|
||||
"rua.plus/lolly/internal/middleware/accesslog"
|
||||
"rua.plus/lolly/internal/middleware/security"
|
||||
"rua.plus/lolly/internal/proxy"
|
||||
@ -989,6 +991,157 @@ func TestCreateListener_UnixSocketCleanup(t *testing.T) {
|
||||
defer ln.Close()
|
||||
}
|
||||
|
||||
func TestHandleRegistrationError_ConflictWarning(t *testing.T) {
|
||||
s := &Server{}
|
||||
err := s.handleRegistrationError("proxy", "/api",
|
||||
&matcher.ConflictError{Path: "/api", ExistingType: "exact", NewType: "prefix"})
|
||||
if err != nil {
|
||||
t.Errorf("conflict should return nil, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleRegistrationError_FatalError(t *testing.T) {
|
||||
s := &Server{}
|
||||
err := s.handleRegistrationError("proxy", "/api",
|
||||
fmt.Errorf("invalid regex pattern: missing closing parenthesis"))
|
||||
if err == nil {
|
||||
t.Error("fatal error should return non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "proxy route /api") {
|
||||
t.Errorf("error should wrap context, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDupListener_TCP(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
duped, err := DupListener(ln)
|
||||
if err != nil {
|
||||
t.Fatalf("DupListener() error: %v", err)
|
||||
}
|
||||
defer duped.Close()
|
||||
|
||||
if duped.Addr().Network() != "tcp" {
|
||||
t.Errorf("expected tcp, got %s", duped.Addr().Network())
|
||||
}
|
||||
if duped.Addr().String() != ln.Addr().String() {
|
||||
t.Errorf("expected same address %s, got %s", ln.Addr().String(), duped.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDupListener_Unix(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
socketPath := dir + "/dup.sock"
|
||||
ln, err := net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
duped, err := DupListener(ln)
|
||||
if err != nil {
|
||||
t.Fatalf("DupListener() error: %v", err)
|
||||
}
|
||||
defer duped.Close()
|
||||
}
|
||||
|
||||
func TestDupListener_Unsupported(t *testing.T) {
|
||||
_, err := DupListener(struct{ net.Listener }{})
|
||||
if err == nil {
|
||||
t.Error("expected error for unsupported type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTcpAddrMatch(t *testing.T) {
|
||||
s := &Server{}
|
||||
|
||||
tests := []struct {
|
||||
inherited string
|
||||
target string
|
||||
want bool
|
||||
}{
|
||||
{"127.0.0.1:8080", "127.0.0.1:8080", true},
|
||||
{"0.0.0.0:8080", ":8080", true},
|
||||
{"[::]:8080", ":8080", true},
|
||||
{"0.0.0.0:8080", "0.0.0.0:8080", true},
|
||||
{"0.0.0.0:8080", "127.0.0.1:8080", true},
|
||||
{"127.0.0.1:8080", "0.0.0.0:8080", true},
|
||||
{"127.0.0.1:8080", ":9090", false},
|
||||
{"127.0.0.1:8080", "192.168.1.1:8080", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := s.tcpAddrMatch(tt.inherited, tt.target)
|
||||
if got != tt.want {
|
||||
t.Errorf("tcpAddrMatch(%q, %q) = %v, want %v", tt.inherited, tt.target, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchInheritedListener_TCP(t *testing.T) {
|
||||
s := &Server{}
|
||||
|
||||
ln1, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
defer ln1.Close()
|
||||
|
||||
ln2, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
defer ln2.Close()
|
||||
|
||||
inherited := []net.Listener{ln1, ln2}
|
||||
|
||||
result := s.matchInheritedListener(inherited, "0.0.0.0:99999")
|
||||
if result != nil {
|
||||
t.Error("expected nil for non-matching address")
|
||||
}
|
||||
|
||||
addr1 := ln1.Addr().String()
|
||||
result = s.matchInheritedListener(inherited, addr1)
|
||||
if result != ln1 {
|
||||
t.Errorf("expected ln1 for address %s", addr1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchInheritedListener_Empty(t *testing.T) {
|
||||
s := &Server{}
|
||||
result := s.matchInheritedListener(nil, ":8080")
|
||||
if result != nil {
|
||||
t.Error("expected nil for empty inherited list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchInheritedListener_PresetListeners(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}},
|
||||
}
|
||||
s := New(cfg)
|
||||
|
||||
ln, err := s.createListener(&cfg.Servers[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
s.SetListeners([]net.Listener{ln})
|
||||
|
||||
addr := ln.Addr().String()
|
||||
cfg.Servers[0].Listen = addr
|
||||
|
||||
matched, err := s.createListener(&cfg.Servers[0])
|
||||
if err != nil {
|
||||
t.Fatalf("createListener with preset should reuse: %v", err)
|
||||
}
|
||||
if matched == nil {
|
||||
t.Fatal("expected non-nil listener from preset match")
|
||||
}
|
||||
if matched.Addr().String() != addr {
|
||||
t.Errorf("expected same address %s, got %s", addr, matched.Addr().String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestServer_StatsMethods 测试服务器统计方法。
|
||||
func TestServer_StatsMethods(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user