refactor(config,server): 移除 Config.Server 字段,完善 servers 多服务器配置

- 移除 Config.Server 单服务器字段,统一使用 Servers 列表
- 为 ServerConfig 添加 Default 标记支持虚拟主机默认主机
- 重命名 GetDefaultServer 为 GetDefaultServerFromList
- 更新验证逻辑确保 servers 列表必填
- 更新默认配置生成和测试适配

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-15 13:32:36 +08:00
parent 37d8f9eebc
commit d5b4509014
8 changed files with 302 additions and 240 deletions

View File

@ -56,7 +56,7 @@ const (
// 是配置文件的顶级结构体,所有其他配置都作为其子结构。
//
// 注意事项:
// - 必须配置 server 或 servers 中的至少一个
// - 必须配置 servers 列表中的至少一个
// - 加载后会自动进行配置验证
// - Stream 配置为可选,用于 TCP/UDP 层代理
// - HTTP/3 配置为可选,需 SSL 配置配合才能生效
@ -67,8 +67,7 @@ const (
// if err != nil {
// log.Fatal(err)
// }
// server := cfg.Server
// // 或使用多虚拟主机
// // 使用多虚拟主机模式
// for _, s := range cfg.Servers {
// // 处理每个服务器配置
// }
@ -81,7 +80,6 @@ type Config struct {
Monitoring MonitoringConfig `yaml:"monitoring"`
HTTP3 HTTP3Config `yaml:"http3"`
Resolver ResolverConfig `yaml:"resolver"`
Server ServerConfig `yaml:"server"`
Performance PerformanceConfig `yaml:"performance"`
Shutdown ShutdownConfig `yaml:"shutdown"`
}
@ -188,6 +186,7 @@ type ServerConfig struct {
Lua *LuaMiddlewareConfig `yaml:"lua"`
ClientMaxBodySize string `yaml:"client_max_body_size"`
Name string `yaml:"name"`
Default bool `yaml:"default,omitempty"` // VHost 默认主机标记
Listen string `yaml:"listen"`
Security SecurityConfig `yaml:"security"`
Static []StaticConfig `yaml:"static"`
@ -1669,23 +1668,18 @@ func (c *Config) HasServers() bool {
return len(c.Servers) > 0
}
// HasDefaultServer 检查是否有默认服务器配置。
//
// 返回值:
// - bool: 如果 server.listen 已配置,返回 true
func (c *Config) HasDefaultServer() bool {
return c.Server.Listen != ""
}
// GetDefaultServer 获取默认服务器配置。
// GetDefaultServerFromList 从 servers 列表中获取默认服务器配置。
//
// 遍历 servers 列表,返回第一个 Default 标记为 true 的服务器。
// 用于在虚拟主机模式下获取默认服务器的配置作为 fallback。
//
// 返回值:
// - *ServerConfig: 默认服务器配置,如未配置则返回 nil
func (c *Config) GetDefaultServer() *ServerConfig {
if c.HasDefaultServer() {
return &c.Server
// - *ServerConfig: 默认服务器配置,如无则返回 nil
func (c *Config) GetDefaultServerFromList() *ServerConfig {
for i := range c.Servers {
if c.Servers[i].Default {
return &c.Servers[i]
}
}
return nil
}
@ -1766,6 +1760,11 @@ func Validate(cfg *Config) error {
return err
}
// 验证 default 服务器唯一性
if err := validateDefaultServer(cfg.Servers); err != nil {
return err
}
// 验证所有服务器
for i := range cfg.Servers {
if err := validateServer(&cfg.Servers[i], false); err != nil {

View File

@ -357,45 +357,31 @@ func TestConfigMethods(t *testing.T) {
}
})
t.Run("HasDefaultServer_有默认服务器", func(t *testing.T) {
t.Run("GetDefaultServerFromList_有默认服务器", func(t *testing.T) {
cfg := &Config{
Server: ServerConfig{
Listen: ":8080",
Servers: []ServerConfig{
{Listen: ":8080", Name: "api"},
{Listen: ":8081", Name: "default", Default: true},
},
}
if !cfg.HasDefaultServer() {
t.Error("HasDefaultServer() = false, want true")
}
})
t.Run("HasDefaultServer_无默认服务器", func(t *testing.T) {
cfg := &Config{}
if cfg.HasDefaultServer() {
t.Error("HasDefaultServer() = true, want false")
}
})
t.Run("GetDefaultServer_有默认服务器", func(t *testing.T) {
cfg := &Config{
Server: ServerConfig{
Listen: ":8080",
Name: "default",
},
}
server := cfg.GetDefaultServer()
server := cfg.GetDefaultServerFromList()
if server == nil {
t.Fatal("GetDefaultServer() = nil, want non-nil")
t.Fatal("GetDefaultServerFromList() = nil, want non-nil")
}
if server.Listen != ":8080" {
t.Errorf("server.Listen = %q, want %q", server.Listen, ":8080")
if server.Listen != ":8081" {
t.Errorf("server.Listen = %q, want %q", server.Listen, ":8081")
}
})
t.Run("GetDefaultServer_无默认服务器", func(t *testing.T) {
cfg := &Config{}
server := cfg.GetDefaultServer()
t.Run("GetDefaultServerFromList_无默认服务器", func(t *testing.T) {
cfg := &Config{
Servers: []ServerConfig{
{Listen: ":8080", Name: "api"},
},
}
server := cfg.GetDefaultServerFromList()
if server != nil {
t.Errorf("GetDefaultServer() = %v, want nil", server)
t.Errorf("GetDefaultServerFromList() = %v, want nil", server)
}
})
@ -404,28 +390,18 @@ func TestConfigMethods(t *testing.T) {
cfg *Config
name string
wantHasServers bool
wantHasDefault bool
}{
{
name: "仅默认服务器",
cfg: &Config{Server: ServerConfig{Listen: ":8080"}},
wantHasServers: false,
wantHasDefault: true,
},
{
name: "仅多虚拟主机",
cfg: &Config{Servers: []ServerConfig{{Listen: ":8080"}}},
wantHasServers: true,
wantHasDefault: false,
},
{
name: "混合模式",
cfg: &Config{
Server: ServerConfig{Listen: ":8080"},
Servers: []ServerConfig{{Listen: ":8081"}},
},
wantHasServers: true,
wantHasDefault: true,
},
}
@ -434,9 +410,6 @@ func TestConfigMethods(t *testing.T) {
if got := tt.cfg.HasServers(); got != tt.wantHasServers {
t.Errorf("HasServers() = %v, want %v", got, tt.wantHasServers)
}
if got := tt.cfg.HasDefaultServer(); got != tt.wantHasDefault {
t.Errorf("HasDefaultServer() = %v, want %v", got, tt.wantHasDefault)
}
})
}
})

View File

@ -234,6 +234,15 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
// buf.WriteString("# 文档: https://github.com/xfy/lolly\n")
buf.WriteString("\n")
// mode 配置
buf.WriteString("# 运行模式配置\n")
buf.WriteString("# mode: auto # 运行模式(有效值: single, vhost, multi_server, auto\n")
buf.WriteString("# auto: 自动推断模式(根据 servers 配置自动选择,默认)\n")
buf.WriteString("# single: 单服务器模式(只有一个 server\n")
buf.WriteString("# vhost: 虚拟主机模式(多个 server 共享相同监听地址)\n")
buf.WriteString("# multi_server: 多服务器模式(多个 server 监听不同地址)\n")
buf.WriteString("\n")
// servers 配置
buf.WriteString("# 服务器配置(多服务器模式)\n")
buf.WriteString("servers:\n")
@ -277,6 +286,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString(" # code_cache_size: 1000 # 字节码缓存条目数\n")
buf.WriteString(" # enable_file_watch: true # 启用文件变更检测\n")
buf.WriteString(" # max_execution_time: 30s # 最大执行时间\n")
buf.WriteString(" # coroutine_stack_size: 0 # 协程栈大小0=使用默认值)\n")
buf.WriteString(" # coroutine_pool_warmup: 0 # 协程池预热数量0=不预热)\n")
buf.WriteString(" # minimize_stack_memory: false # 最小化栈内存使用\n")
buf.WriteString("\n")
// static 配置
@ -291,6 +303,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
}
buf.WriteString(" try_files: [] # SPA 部署示例: [\"$uri\", \"$uri/\", \"/index.html\"]\n")
buf.WriteString(" try_files_pass: false # 内部重定向是否触发中间件\n")
buf.WriteString(" symlink_check: false # 是否检查符号链接安全(防止路径遍历攻击)\n")
}
buf.WriteString(" # 示例:额外的静态目录\n")
buf.WriteString(" # - path: \"/assets/\"\n")
@ -331,6 +344,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString(" # next_upstream: # 故障转移配置\n")
buf.WriteString(" # tries: 1 # 最大尝试次数1 表示禁用故障转移)\n")
buf.WriteString(" # http_codes: [502, 503, 504] # 触发重试的 HTTP 状态码\n")
buf.WriteString(" # balancer_by_lua: # Lua 动态负载均衡(在 load_balance 基础上自定义选择逻辑)\n")
buf.WriteString(" # enabled: false # 是否启用 Lua 负载均衡\n")
buf.WriteString(" # script: \"\" # Lua 脚本路径,返回目标索引\n")
buf.WriteString(" # fallback: \"round_robin\" # Lua 失败时的备用算法(有效值: round_robin, weighted_round_robin, least_conn\n")
buf.WriteString(" # timeout: 5s # Lua 执行超时\n")
buf.WriteString("\n")
// SSL 配置
@ -375,6 +393,12 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
fmt.Fprintf(&buf, " # graceful_shutdown_timeout: %ds # HTTP/2 优雅关闭超时\n", int(cfg.Servers[0].SSL.HTTP2.GracefulShutdownTimeout.Seconds()))
buf.WriteString("\n")
// SSL 默认值说明(即使不启用也展示默认配置)
buf.WriteString(" # SSL/TLS 默认配置说明(未配置证书时不启用)\n")
buf.WriteString(" # 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1\n")
buf.WriteString(" # 默认 HSTS 配置: max_age=315360001年, include_sub_domains=true\n")
buf.WriteString("\n")
// security 配置
buf.WriteString(" # 安全配置\n")
buf.WriteString(" security:\n")
@ -385,6 +409,17 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
fmt.Fprintf(&buf, " default: \"%s\" # 默认动作(有效值: allow, deny\n", cfg.Servers[0].Security.Access.Default)
buf.WriteString(" trusted_proxies: [] # 可信代理 CIDR 列表,用于 X-Forwarded-For 解析\n")
buf.WriteString("\n")
buf.WriteString(" # GeoIP 地理访问控制(基于 IP 所属国家/地区)\n")
buf.WriteString(" geoip:\n")
buf.WriteString(" enabled: false # 是否启用 GeoIP 访问控制\n")
buf.WriteString(" database: \"\" # GeoIP 数据库文件路径(如 /usr/share/GeoIP/GeoLite2-Country.mmdb\n")
buf.WriteString(" default: \"allow\" # 未匹配时的默认动作(有效值: allow, deny\n")
buf.WriteString(" private_ip_behavior: \"bypass\" # 私有 IP 处理方式(有效值: bypass, apply_default, deny\n")
buf.WriteString(" allow_countries: [] # 允许的国家代码列表(如 [\"CN\", \"US\"]\n")
buf.WriteString(" deny_countries: [] # 拒绝的国家代码列表\n")
buf.WriteString(" cache_size: 10000 # GeoIP 查询缓存大小\n")
buf.WriteString(" cache_ttl: 3600 # 缓存有效期(秒)\n")
buf.WriteString("\n")
buf.WriteString(" # 速率限制\n")
buf.WriteString(" rate_limit:\n")
fmt.Fprintf(&buf, " request_rate: %d # 每秒请求数0 表示不限制)\n", cfg.Servers[0].Security.RateLimit.RequestRate)
@ -460,12 +495,6 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
fmt.Fprintf(&buf, " fast_timeout: %ds # 快速停止超时SIGINT/SIGTERM0=使用默认5s\n", int(cfg.Shutdown.FastTimeout.Seconds()))
buf.WriteString("\n")
// SSL 默认值说明(即使不启用也展示默认配置)
buf.WriteString("# SSL/TLS 默认配置说明(未配置证书时不启用)\n")
buf.WriteString("# 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1\n")
buf.WriteString("# 默认 HSTS 配置: max_age=315360001年, include_sub_domains=true\n")
buf.WriteString("\n")
// stream 配置
buf.WriteString("# TCP/UDP Stream 代理配置(可选)\n")
buf.WriteString("# stream:\n")

View File

@ -28,6 +28,26 @@ import (
"rua.plus/lolly/internal/variable"
)
// validateDefaultServer 验证 servers 中最多只有一个 default: true 服务器。
//
// 参数:
// - servers: 服务器配置列表
//
// 返回值:
// - error: 超过一个 default 时返回错误信息,成功返回 nil
func validateDefaultServer(servers []ServerConfig) error {
count := 0
for _, s := range servers {
if s.Default {
count++
}
}
if count > 1 {
return errors.New("只能有一个 default: true 服务器")
}
return nil
}
// validateMode 验证服务器运行模式有效值。
//
// 检查 Mode 是否为 ServerModeSingle, ServerModeVHost,

View File

@ -15,8 +15,10 @@ import (
// TestBuildLuaMiddlewares_NilEngine 测试 LuaEngine 为 nil 时
func TestBuildLuaMiddlewares_NilEngine(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -44,8 +46,10 @@ func TestBuildLuaMiddlewares_NilEngine(t *testing.T) {
// TestBuildLuaMiddlewares_InvalidPhase 测试无效阶段
func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -77,8 +81,10 @@ func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) {
// TestBuildLuaMiddlewares_WithTimeout 测试超时配置
func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -111,8 +117,10 @@ func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) {
// TestBuildLuaMiddlewares_EmptyScripts 测试空脚本列表
func TestBuildLuaMiddlewares_EmptyScripts(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -143,8 +151,10 @@ func TestBuildLuaMiddlewares_DisabledLua(t *testing.T) {
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -185,8 +195,10 @@ func TestBuildLuaMiddlewares_DisabledScript(t *testing.T) {
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -234,8 +246,10 @@ ngx.var.uri = "/test"
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -289,8 +303,10 @@ ngx.say("hello")
}
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)
@ -340,8 +356,10 @@ func TestBuildLuaMiddlewares_DefaultTimeout(t *testing.T) {
require.NoError(t, err)
cfg := &config.Config{
Server: config.ServerConfig{
Listen: ":8080",
Servers: []config.ServerConfig{
{
Listen: ":8080",
},
},
}
s := New(cfg)

View File

@ -573,7 +573,7 @@ func (s *Server) startVHostMode() error {
}
// 默认主机
if s.config.HasDefaultServer() {
if s.config.GetDefaultServerFromList() != nil {
router := handler.NewRouter()
// 注册状态监控端点(如果配置)
@ -597,12 +597,14 @@ func (s *Server) startVHostMode() error {
}
}
s.registerProxyRoutes(router, &s.config.Server)
defaultSrv := s.config.GetDefaultServerFromList()
s.registerProxyRoutes(router, defaultSrv)
// 静态文件
s.registerStaticHandlers(router, &s.config.Server)
s.registerStaticHandlers(router, defaultSrv)
chain, err := s.buildMiddlewareChain(&s.config.Server)
chain, err := s.buildMiddlewareChain(defaultSrv)
if err != nil {
return err
}

View File

@ -41,15 +41,17 @@ func TestStart_Integration(t *testing.T) {
serverAddr := "127.0.0.1:0"
cfg := &config.Config{
Server: config.ServerConfig{
Listen: serverAddr,
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
Servers: []config.ServerConfig{
{
Listen: serverAddr,
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
HealthCheck: config.HealthCheckConfig{},
},
HealthCheck: config.HealthCheckConfig{},
},
},
},
@ -77,20 +79,22 @@ func TestStart_Integration(t *testing.T) {
// TestStart_WithSecurity 测试安全配置
func TestStart_WithSecurity(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Security: config.SecurityConfig{
Access: config.AccessConfig{
Allow: []string{"127.0.0.1"},
Deny: []string{},
},
RateLimit: config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
Headers: config.SecurityHeaders{
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Security: config.SecurityConfig{
Access: config.AccessConfig{
Allow: []string{"127.0.0.1"},
Deny: []string{},
},
RateLimit: config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
Headers: config.SecurityHeaders{
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
},
},
},
},
@ -102,20 +106,22 @@ func TestStart_WithSecurity(t *testing.T) {
}
// 验证安全配置
if len(s.config.Server.Security.Access.Allow) != 1 {
t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Server.Security.Access.Allow))
if len(s.config.Servers[0].Security.Access.Allow) != 1 {
t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Servers[0].Security.Access.Allow))
}
}
// TestStart_WithRewrite 测试 URL 重写配置
func TestStart_WithRewrite(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Rewrite: []config.RewriteRule{
{
Pattern: "/old/(.*)",
Replacement: "/new/$1",
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Rewrite: []config.RewriteRule{
{
Pattern: "/old/(.*)",
Replacement: "/new/$1",
},
},
},
},
@ -127,16 +133,18 @@ func TestStart_WithRewrite(t *testing.T) {
}
// 验证重写配置
if len(s.config.Server.Rewrite) != 1 {
t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Server.Rewrite))
if len(s.config.Servers[0].Rewrite) != 1 {
t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Servers[0].Rewrite))
}
}
// TestStart_WithMonitoring 测试监控配置
func TestStart_WithMonitoring(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
},
},
Monitoring: config.MonitoringConfig{
Status: config.StatusConfig{
@ -174,12 +182,14 @@ func TestStart_WithErrorPage(t *testing.T) {
}
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Security: config.SecurityConfig{
ErrorPage: config.ErrorPageConfig{
Pages: map[int]string{
404: errorPagePath,
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Security: config.SecurityConfig{
ErrorPage: config.ErrorPageConfig{
Pages: map[int]string{
404: errorPagePath,
},
},
},
},
@ -192,7 +202,7 @@ func TestStart_WithErrorPage(t *testing.T) {
}
// 验证错误页面配置
if s.config.Server.Security.ErrorPage.Pages == nil {
if s.config.Servers[0].Security.ErrorPage.Pages == nil {
t.Error("Error page pages should not be nil")
}
}
@ -200,15 +210,17 @@ func TestStart_WithErrorPage(t *testing.T) {
// TestStart_WithLuaEnabled 测试 Lua 配置
func TestStart_WithLuaEnabled(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Lua: &config.LuaMiddlewareConfig{
Enabled: true,
GlobalSettings: config.LuaGlobalSettings{
MaxConcurrentCoroutines: 100,
CoroutineTimeout: 30 * time.Second,
CodeCacheSize: 100,
MaxExecutionTime: 30 * time.Second,
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Lua: &config.LuaMiddlewareConfig{
Enabled: true,
GlobalSettings: config.LuaGlobalSettings{
MaxConcurrentCoroutines: 100,
CoroutineTimeout: 30 * time.Second,
CodeCacheSize: 100,
MaxExecutionTime: 30 * time.Second,
},
},
},
},
@ -220,7 +232,7 @@ func TestStart_WithLuaEnabled(t *testing.T) {
}
// 验证 Lua 配置
if s.config.Server.Lua == nil || !s.config.Server.Lua.Enabled {
if s.config.Servers[0].Lua == nil || !s.config.Servers[0].Lua.Enabled {
t.Error("Lua should be enabled")
}
}
@ -241,19 +253,21 @@ func TestStart_WithMultipleProxies(t *testing.T) {
defer cleanup2()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api1",
Targets: []config.ProxyTarget{
{URL: "http://" + backend1, Weight: 1},
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api1",
Targets: []config.ProxyTarget{
{URL: "http://" + backend1, Weight: 1},
},
},
},
{
Path: "/api2",
Targets: []config.ProxyTarget{
{URL: "http://" + backend2, Weight: 1},
{
Path: "/api2",
Targets: []config.ProxyTarget{
{URL: "http://" + backend2, Weight: 1},
},
},
},
},
@ -266,16 +280,18 @@ func TestStart_WithMultipleProxies(t *testing.T) {
}
// 验证代理配置
if len(s.config.Server.Proxy) != 2 {
t.Errorf("Expected 2 proxies, got %d", len(s.config.Server.Proxy))
if len(s.config.Servers[0].Proxy) != 2 {
t.Errorf("Expected 2 proxies, got %d", len(s.config.Servers[0].Proxy))
}
}
// TestStart_EmptyConfig 测试空配置
func TestStart_EmptyConfig(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
},
},
}
@ -298,38 +314,40 @@ func TestStart_WithAllFeatures(t *testing.T) {
writeFile(errorPagePath, []byte("<html><body>Not Found</body></html>"))
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Static: []config.StaticConfig{
{
Path: "/static",
Root: tempDir,
Index: []string{"index.html"},
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Static: []config.StaticConfig{
{
Path: "/static",
Root: tempDir,
Index: []string{"index.html"},
},
},
},
Compression: config.CompressionConfig{
Type: "gzip",
Level: 6,
},
Security: config.SecurityConfig{
Access: config.AccessConfig{
Allow: []string{"127.0.0.1"},
Compression: config.CompressionConfig{
Type: "gzip",
Level: 6,
},
RateLimit: config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
Security: config.SecurityConfig{
Access: config.AccessConfig{
Allow: []string{"127.0.0.1"},
},
RateLimit: config.RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
Headers: config.SecurityHeaders{
XFrameOptions: "DENY",
},
ErrorPage: config.ErrorPageConfig{
Default: errorPagePath,
},
},
Headers: config.SecurityHeaders{
XFrameOptions: "DENY",
},
ErrorPage: config.ErrorPageConfig{
Default: errorPagePath,
},
},
Rewrite: []config.RewriteRule{
{
Pattern: "/old/(.*)",
Replacement: "/new/$1",
Rewrite: []config.RewriteRule{
{
Pattern: "/old/(.*)",
Replacement: "/new/$1",
},
},
},
},
@ -360,7 +378,7 @@ func TestStart_WithAllFeatures(t *testing.T) {
if !s.config.Performance.GoroutinePool.Enabled {
t.Error("GoroutinePool should be enabled")
}
if s.config.Server.Compression.Type != "gzip" {
if s.config.Servers[0].Compression.Type != "gzip" {
t.Error("Compression should be gzip")
}
}
@ -368,14 +386,16 @@ func TestStart_WithAllFeatures(t *testing.T) {
// TestStart_ServerOptions 测试服务器配置选项
func TestStart_ServerOptions(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
MaxConnsPerIP: 100,
MaxRequestsPerConn: 1000,
ClientMaxBodySize: "10MB",
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
MaxConnsPerIP: 100,
MaxRequestsPerConn: 1000,
ClientMaxBodySize: "10MB",
},
},
}
@ -385,29 +405,31 @@ func TestStart_ServerOptions(t *testing.T) {
}
// 验证服务器选项
if s.config.Server.ReadTimeout != 30*time.Second {
t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Server.ReadTimeout)
if s.config.Servers[0].ReadTimeout != 30*time.Second {
t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Servers[0].ReadTimeout)
}
if s.config.Server.MaxConnsPerIP != 100 {
t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Server.MaxConnsPerIP)
if s.config.Servers[0].MaxConnsPerIP != 100 {
t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Servers[0].MaxConnsPerIP)
}
}
// TestStart_HealthCheckConfig 测试健康检查配置
func TestStart_HealthCheckConfig(t *testing.T) {
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://127.0.0.1:8081", Weight: 1},
},
HealthCheck: config.HealthCheckConfig{
Interval: 10 * time.Second,
Timeout: 5 * time.Second,
Path: "/health",
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://127.0.0.1:8081", Weight: 1},
},
HealthCheck: config.HealthCheckConfig{
Interval: 10 * time.Second,
Timeout: 5 * time.Second,
Path: "/health",
},
},
},
},
@ -420,7 +442,7 @@ func TestStart_HealthCheckConfig(t *testing.T) {
}
// 验证健康检查配置
if s.config.Server.Proxy[0].HealthCheck.Path != "/health" {
if s.config.Servers[0].Proxy[0].HealthCheck.Path != "/health" {
t.Error("Health check path should be /health")
}
}
@ -438,9 +460,6 @@ func TestStart_VHostMode(t *testing.T) {
Listen: "127.0.0.1:0",
},
},
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
},
}
s := New(cfg)
@ -461,13 +480,15 @@ func TestStart_WithProxyBackendError(t *testing.T) {
defer cleanup()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
},
},
},
@ -480,8 +501,8 @@ func TestStart_WithProxyBackendError(t *testing.T) {
}
// 验证代理配置
if len(s.config.Server.Proxy) != 1 {
t.Errorf("Expected 1 proxy, got %d", len(s.config.Server.Proxy))
if len(s.config.Servers[0].Proxy) != 1 {
t.Errorf("Expected 1 proxy, got %d", len(s.config.Servers[0].Proxy))
}
}
@ -495,13 +516,15 @@ func TestStart_WithDelayedBackend(t *testing.T) {
defer cleanup()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
},
},
},
@ -525,13 +548,15 @@ func TestStart_WithRandomResponse(t *testing.T) {
defer cleanup()
cfg := &config.Config{
Server: config.ServerConfig{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
Servers: []config.ServerConfig{
{
Listen: "127.0.0.1:0",
Proxy: []config.ProxyConfig{
{
Path: "/api",
Targets: []config.ProxyTarget{
{URL: "http://" + backendAddr, Weight: 1},
},
},
},
},

View File

@ -105,14 +105,10 @@ func MustStartTestServer(cfg *config.Config) *Server {
listenAddr := ""
if len(cfg.Servers) > 0 {
listenAddr = cfg.Servers[0].Listen
} else {
listenAddr = cfg.Server.Listen
}
if listenAddr == "" || listenAddr == ":80" {
if len(cfg.Servers) > 0 {
cfg.Servers[0].Listen = "127.0.0.1:0"
} else {
cfg.Server.Listen = "127.0.0.1:0"
}
}