diff --git a/internal/config/config.go b/internal/config/config.go index 1d04ea0..a70e4f2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -56,7 +56,7 @@ const ( // 是配置文件的顶级结构体,所有其他配置都作为其子结构。 // // 注意事项: -// - 必须配置 server 或 servers 中的至少一个 +// - 必须配置 servers 列表中的至少一个 // - 加载后会自动进行配置验证 // - Stream 配置为可选,用于 TCP/UDP 层代理 // - HTTP/3 配置为可选,需 SSL 配置配合才能生效 @@ -67,8 +67,7 @@ const ( // if err != nil { // log.Fatal(err) // } -// server := cfg.Server -// // 或使用多虚拟主机 +// // 使用多虚拟主机模式 // for _, s := range cfg.Servers { // // 处理每个服务器配置 // } @@ -81,7 +80,6 @@ type Config struct { Monitoring MonitoringConfig `yaml:"monitoring"` HTTP3 HTTP3Config `yaml:"http3"` Resolver ResolverConfig `yaml:"resolver"` - Server ServerConfig `yaml:"server"` Performance PerformanceConfig `yaml:"performance"` Shutdown ShutdownConfig `yaml:"shutdown"` } @@ -188,6 +186,7 @@ type ServerConfig struct { Lua *LuaMiddlewareConfig `yaml:"lua"` ClientMaxBodySize string `yaml:"client_max_body_size"` Name string `yaml:"name"` + Default bool `yaml:"default,omitempty"` // VHost 默认主机标记 Listen string `yaml:"listen"` Security SecurityConfig `yaml:"security"` Static []StaticConfig `yaml:"static"` @@ -1669,23 +1668,18 @@ func (c *Config) HasServers() bool { return len(c.Servers) > 0 } -// HasDefaultServer 检查是否有默认服务器配置。 -// -// 返回值: -// - bool: 如果 server.listen 已配置,返回 true -func (c *Config) HasDefaultServer() bool { - return c.Server.Listen != "" -} - -// GetDefaultServer 获取默认服务器配置。 +// GetDefaultServerFromList 从 servers 列表中获取默认服务器配置。 // +// 遍历 servers 列表,返回第一个 Default 标记为 true 的服务器。 // 用于在虚拟主机模式下获取默认服务器的配置作为 fallback。 // // 返回值: -// - *ServerConfig: 默认服务器配置,如未配置则返回 nil -func (c *Config) GetDefaultServer() *ServerConfig { - if c.HasDefaultServer() { - return &c.Server +// - *ServerConfig: 默认服务器配置,如无则返回 nil +func (c *Config) GetDefaultServerFromList() *ServerConfig { + for i := range c.Servers { + if c.Servers[i].Default { + return &c.Servers[i] + } } return nil } @@ -1766,6 +1760,11 @@ func Validate(cfg *Config) error { return err } + // 验证 default 服务器唯一性 + if err := validateDefaultServer(cfg.Servers); err != nil { + return err + } + // 验证所有服务器 for i := range cfg.Servers { if err := validateServer(&cfg.Servers[i], false); err != nil { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 348f3fc..9ff76e5 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -357,45 +357,31 @@ func TestConfigMethods(t *testing.T) { } }) - t.Run("HasDefaultServer_有默认服务器", func(t *testing.T) { + t.Run("GetDefaultServerFromList_有默认服务器", func(t *testing.T) { cfg := &Config{ - Server: ServerConfig{ - Listen: ":8080", + Servers: []ServerConfig{ + {Listen: ":8080", Name: "api"}, + {Listen: ":8081", Name: "default", Default: true}, }, } - if !cfg.HasDefaultServer() { - t.Error("HasDefaultServer() = false, want true") - } - }) - - t.Run("HasDefaultServer_无默认服务器", func(t *testing.T) { - cfg := &Config{} - if cfg.HasDefaultServer() { - t.Error("HasDefaultServer() = true, want false") - } - }) - - t.Run("GetDefaultServer_有默认服务器", func(t *testing.T) { - cfg := &Config{ - Server: ServerConfig{ - Listen: ":8080", - Name: "default", - }, - } - server := cfg.GetDefaultServer() + server := cfg.GetDefaultServerFromList() if server == nil { - t.Fatal("GetDefaultServer() = nil, want non-nil") + t.Fatal("GetDefaultServerFromList() = nil, want non-nil") } - if server.Listen != ":8080" { - t.Errorf("server.Listen = %q, want %q", server.Listen, ":8080") + if server.Listen != ":8081" { + t.Errorf("server.Listen = %q, want %q", server.Listen, ":8081") } }) - t.Run("GetDefaultServer_无默认服务器", func(t *testing.T) { - cfg := &Config{} - server := cfg.GetDefaultServer() + t.Run("GetDefaultServerFromList_无默认服务器", func(t *testing.T) { + cfg := &Config{ + Servers: []ServerConfig{ + {Listen: ":8080", Name: "api"}, + }, + } + server := cfg.GetDefaultServerFromList() if server != nil { - t.Errorf("GetDefaultServer() = %v, want nil", server) + t.Errorf("GetDefaultServerFromList() = %v, want nil", server) } }) @@ -404,28 +390,18 @@ func TestConfigMethods(t *testing.T) { cfg *Config name string wantHasServers bool - wantHasDefault bool }{ - { - name: "仅默认服务器", - cfg: &Config{Server: ServerConfig{Listen: ":8080"}}, - wantHasServers: false, - wantHasDefault: true, - }, { name: "仅多虚拟主机", cfg: &Config{Servers: []ServerConfig{{Listen: ":8080"}}}, wantHasServers: true, - wantHasDefault: false, }, { name: "混合模式", cfg: &Config{ - Server: ServerConfig{Listen: ":8080"}, Servers: []ServerConfig{{Listen: ":8081"}}, }, wantHasServers: true, - wantHasDefault: true, }, } @@ -434,9 +410,6 @@ func TestConfigMethods(t *testing.T) { if got := tt.cfg.HasServers(); got != tt.wantHasServers { t.Errorf("HasServers() = %v, want %v", got, tt.wantHasServers) } - if got := tt.cfg.HasDefaultServer(); got != tt.wantHasDefault { - t.Errorf("HasDefaultServer() = %v, want %v", got, tt.wantHasDefault) - } }) } }) diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 6f6a692..c523fa3 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -234,6 +234,15 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { // buf.WriteString("# 文档: https://github.com/xfy/lolly\n") buf.WriteString("\n") + // mode 配置 + buf.WriteString("# 运行模式配置\n") + buf.WriteString("# mode: auto # 运行模式(有效值: single, vhost, multi_server, auto)\n") + buf.WriteString("# auto: 自动推断模式(根据 servers 配置自动选择,默认)\n") + buf.WriteString("# single: 单服务器模式(只有一个 server)\n") + buf.WriteString("# vhost: 虚拟主机模式(多个 server 共享相同监听地址)\n") + buf.WriteString("# multi_server: 多服务器模式(多个 server 监听不同地址)\n") + buf.WriteString("\n") + // servers 配置 buf.WriteString("# 服务器配置(多服务器模式)\n") buf.WriteString("servers:\n") @@ -277,6 +286,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(" # code_cache_size: 1000 # 字节码缓存条目数\n") buf.WriteString(" # enable_file_watch: true # 启用文件变更检测\n") buf.WriteString(" # max_execution_time: 30s # 最大执行时间\n") + buf.WriteString(" # coroutine_stack_size: 0 # 协程栈大小(0=使用默认值)\n") + buf.WriteString(" # coroutine_pool_warmup: 0 # 协程池预热数量(0=不预热)\n") + buf.WriteString(" # minimize_stack_memory: false # 最小化栈内存使用\n") buf.WriteString("\n") // static 配置 @@ -291,6 +303,7 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { } buf.WriteString(" try_files: [] # SPA 部署示例: [\"$uri\", \"$uri/\", \"/index.html\"]\n") buf.WriteString(" try_files_pass: false # 内部重定向是否触发中间件\n") + buf.WriteString(" symlink_check: false # 是否检查符号链接安全(防止路径遍历攻击)\n") } buf.WriteString(" # 示例:额外的静态目录\n") buf.WriteString(" # - path: \"/assets/\"\n") @@ -331,6 +344,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(" # next_upstream: # 故障转移配置\n") buf.WriteString(" # tries: 1 # 最大尝试次数(1 表示禁用故障转移)\n") buf.WriteString(" # http_codes: [502, 503, 504] # 触发重试的 HTTP 状态码\n") + buf.WriteString(" # balancer_by_lua: # Lua 动态负载均衡(在 load_balance 基础上自定义选择逻辑)\n") + buf.WriteString(" # enabled: false # 是否启用 Lua 负载均衡\n") + buf.WriteString(" # script: \"\" # Lua 脚本路径,返回目标索引\n") + buf.WriteString(" # fallback: \"round_robin\" # Lua 失败时的备用算法(有效值: round_robin, weighted_round_robin, least_conn)\n") + buf.WriteString(" # timeout: 5s # Lua 执行超时\n") buf.WriteString("\n") // SSL 配置 @@ -375,6 +393,12 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { fmt.Fprintf(&buf, " # graceful_shutdown_timeout: %ds # HTTP/2 优雅关闭超时\n", int(cfg.Servers[0].SSL.HTTP2.GracefulShutdownTimeout.Seconds())) buf.WriteString("\n") + // SSL 默认值说明(即使不启用也展示默认配置) + buf.WriteString(" # SSL/TLS 默认配置说明(未配置证书时不启用)\n") + buf.WriteString(" # 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1)\n") + buf.WriteString(" # 默认 HSTS 配置: max_age=31536000(1年), include_sub_domains=true\n") + buf.WriteString("\n") + // security 配置 buf.WriteString(" # 安全配置\n") buf.WriteString(" security:\n") @@ -385,6 +409,17 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { fmt.Fprintf(&buf, " default: \"%s\" # 默认动作(有效值: allow, deny)\n", cfg.Servers[0].Security.Access.Default) buf.WriteString(" trusted_proxies: [] # 可信代理 CIDR 列表,用于 X-Forwarded-For 解析\n") buf.WriteString("\n") + buf.WriteString(" # GeoIP 地理访问控制(基于 IP 所属国家/地区)\n") + buf.WriteString(" geoip:\n") + buf.WriteString(" enabled: false # 是否启用 GeoIP 访问控制\n") + buf.WriteString(" database: \"\" # GeoIP 数据库文件路径(如 /usr/share/GeoIP/GeoLite2-Country.mmdb)\n") + buf.WriteString(" default: \"allow\" # 未匹配时的默认动作(有效值: allow, deny)\n") + buf.WriteString(" private_ip_behavior: \"bypass\" # 私有 IP 处理方式(有效值: bypass, apply_default, deny)\n") + buf.WriteString(" allow_countries: [] # 允许的国家代码列表(如 [\"CN\", \"US\"])\n") + buf.WriteString(" deny_countries: [] # 拒绝的国家代码列表\n") + buf.WriteString(" cache_size: 10000 # GeoIP 查询缓存大小\n") + buf.WriteString(" cache_ttl: 3600 # 缓存有效期(秒)\n") + buf.WriteString("\n") buf.WriteString(" # 速率限制\n") buf.WriteString(" rate_limit:\n") fmt.Fprintf(&buf, " request_rate: %d # 每秒请求数(0 表示不限制)\n", cfg.Servers[0].Security.RateLimit.RequestRate) @@ -460,12 +495,6 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { fmt.Fprintf(&buf, " fast_timeout: %ds # 快速停止超时(SIGINT/SIGTERM,0=使用默认5s)\n", int(cfg.Shutdown.FastTimeout.Seconds())) buf.WriteString("\n") - // SSL 默认值说明(即使不启用也展示默认配置) - buf.WriteString("# SSL/TLS 默认配置说明(未配置证书时不启用)\n") - buf.WriteString("# 默认 TLS 协议: TLSv1.2, TLSv1.3(不支持 TLSv1.0/1.1)\n") - buf.WriteString("# 默认 HSTS 配置: max_age=31536000(1年), include_sub_domains=true\n") - buf.WriteString("\n") - // stream 配置 buf.WriteString("# TCP/UDP Stream 代理配置(可选)\n") buf.WriteString("# stream:\n") diff --git a/internal/config/validate.go b/internal/config/validate.go index f86b6d8..017177a 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -28,6 +28,26 @@ import ( "rua.plus/lolly/internal/variable" ) +// validateDefaultServer 验证 servers 中最多只有一个 default: true 服务器。 +// +// 参数: +// - servers: 服务器配置列表 +// +// 返回值: +// - error: 超过一个 default 时返回错误信息,成功返回 nil +func validateDefaultServer(servers []ServerConfig) error { + count := 0 + for _, s := range servers { + if s.Default { + count++ + } + } + if count > 1 { + return errors.New("只能有一个 default: true 服务器") + } + return nil +} + // validateMode 验证服务器运行模式有效值。 // // 检查 Mode 是否为 ServerModeSingle, ServerModeVHost, diff --git a/internal/server/lua_integration_test.go b/internal/server/lua_integration_test.go index ed9c5c4..7027f3c 100644 --- a/internal/server/lua_integration_test.go +++ b/internal/server/lua_integration_test.go @@ -15,8 +15,10 @@ import ( // TestBuildLuaMiddlewares_NilEngine 测试 LuaEngine 为 nil 时 func TestBuildLuaMiddlewares_NilEngine(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -44,8 +46,10 @@ func TestBuildLuaMiddlewares_NilEngine(t *testing.T) { // TestBuildLuaMiddlewares_InvalidPhase 测试无效阶段 func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -77,8 +81,10 @@ func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) { // TestBuildLuaMiddlewares_WithTimeout 测试超时配置 func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -111,8 +117,10 @@ func TestBuildLuaMiddlewares_WithTimeout(t *testing.T) { // TestBuildLuaMiddlewares_EmptyScripts 测试空脚本列表 func TestBuildLuaMiddlewares_EmptyScripts(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -143,8 +151,10 @@ func TestBuildLuaMiddlewares_DisabledLua(t *testing.T) { require.NoError(t, err) cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -185,8 +195,10 @@ func TestBuildLuaMiddlewares_DisabledScript(t *testing.T) { require.NoError(t, err) cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -234,8 +246,10 @@ ngx.var.uri = "/test" require.NoError(t, err) cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -289,8 +303,10 @@ ngx.say("hello") } cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) @@ -340,8 +356,10 @@ func TestBuildLuaMiddlewares_DefaultTimeout(t *testing.T) { require.NoError(t, err) cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: ":8080", + Servers: []config.ServerConfig{ + { + Listen: ":8080", + }, }, } s := New(cfg) diff --git a/internal/server/server.go b/internal/server/server.go index e0b9d3a..189e24c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -573,7 +573,7 @@ func (s *Server) startVHostMode() error { } // 默认主机 - if s.config.HasDefaultServer() { + if s.config.GetDefaultServerFromList() != nil { router := handler.NewRouter() // 注册状态监控端点(如果配置) @@ -597,12 +597,14 @@ func (s *Server) startVHostMode() error { } } - s.registerProxyRoutes(router, &s.config.Server) + defaultSrv := s.config.GetDefaultServerFromList() + + s.registerProxyRoutes(router, defaultSrv) // 静态文件 - s.registerStaticHandlers(router, &s.config.Server) + s.registerStaticHandlers(router, defaultSrv) - chain, err := s.buildMiddlewareChain(&s.config.Server) + chain, err := s.buildMiddlewareChain(defaultSrv) if err != nil { return err } diff --git a/internal/server/start_integration_test.go b/internal/server/start_integration_test.go index 2e151b0..1ce2c81 100644 --- a/internal/server/start_integration_test.go +++ b/internal/server/start_integration_test.go @@ -41,15 +41,17 @@ func TestStart_Integration(t *testing.T) { serverAddr := "127.0.0.1:0" cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: serverAddr, - Proxy: []config.ProxyConfig{ - { - Path: "/api", - Targets: []config.ProxyTarget{ - {URL: "http://" + backendAddr, Weight: 1}, + Servers: []config.ServerConfig{ + { + Listen: serverAddr, + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, + HealthCheck: config.HealthCheckConfig{}, }, - HealthCheck: config.HealthCheckConfig{}, }, }, }, @@ -77,20 +79,22 @@ func TestStart_Integration(t *testing.T) { // TestStart_WithSecurity 测试安全配置 func TestStart_WithSecurity(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Security: config.SecurityConfig{ - Access: config.AccessConfig{ - Allow: []string{"127.0.0.1"}, - Deny: []string{}, - }, - RateLimit: config.RateLimitConfig{ - RequestRate: 100, - Burst: 200, - }, - Headers: config.SecurityHeaders{ - XFrameOptions: "DENY", - XContentTypeOptions: "nosniff", + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + Deny: []string{}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + XContentTypeOptions: "nosniff", + }, }, }, }, @@ -102,20 +106,22 @@ func TestStart_WithSecurity(t *testing.T) { } // 验证安全配置 - if len(s.config.Server.Security.Access.Allow) != 1 { - t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Server.Security.Access.Allow)) + if len(s.config.Servers[0].Security.Access.Allow) != 1 { + t.Errorf("Expected 1 allowed IP, got %d", len(s.config.Servers[0].Security.Access.Allow)) } } // TestStart_WithRewrite 测试 URL 重写配置 func TestStart_WithRewrite(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Rewrite: []config.RewriteRule{ - { - Pattern: "/old/(.*)", - Replacement: "/new/$1", + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Rewrite: []config.RewriteRule{ + { + Pattern: "/old/(.*)", + Replacement: "/new/$1", + }, }, }, }, @@ -127,16 +133,18 @@ func TestStart_WithRewrite(t *testing.T) { } // 验证重写配置 - if len(s.config.Server.Rewrite) != 1 { - t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Server.Rewrite)) + if len(s.config.Servers[0].Rewrite) != 1 { + t.Errorf("Expected 1 rewrite rule, got %d", len(s.config.Servers[0].Rewrite)) } } // TestStart_WithMonitoring 测试监控配置 func TestStart_WithMonitoring(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + }, }, Monitoring: config.MonitoringConfig{ Status: config.StatusConfig{ @@ -174,12 +182,14 @@ func TestStart_WithErrorPage(t *testing.T) { } cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Security: config.SecurityConfig{ - ErrorPage: config.ErrorPageConfig{ - Pages: map[int]string{ - 404: errorPagePath, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + ErrorPage: config.ErrorPageConfig{ + Pages: map[int]string{ + 404: errorPagePath, + }, }, }, }, @@ -192,7 +202,7 @@ func TestStart_WithErrorPage(t *testing.T) { } // 验证错误页面配置 - if s.config.Server.Security.ErrorPage.Pages == nil { + if s.config.Servers[0].Security.ErrorPage.Pages == nil { t.Error("Error page pages should not be nil") } } @@ -200,15 +210,17 @@ func TestStart_WithErrorPage(t *testing.T) { // TestStart_WithLuaEnabled 测试 Lua 配置 func TestStart_WithLuaEnabled(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Lua: &config.LuaMiddlewareConfig{ - Enabled: true, - GlobalSettings: config.LuaGlobalSettings{ - MaxConcurrentCoroutines: 100, - CoroutineTimeout: 30 * time.Second, - CodeCacheSize: 100, - MaxExecutionTime: 30 * time.Second, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Lua: &config.LuaMiddlewareConfig{ + Enabled: true, + GlobalSettings: config.LuaGlobalSettings{ + MaxConcurrentCoroutines: 100, + CoroutineTimeout: 30 * time.Second, + CodeCacheSize: 100, + MaxExecutionTime: 30 * time.Second, + }, }, }, }, @@ -220,7 +232,7 @@ func TestStart_WithLuaEnabled(t *testing.T) { } // 验证 Lua 配置 - if s.config.Server.Lua == nil || !s.config.Server.Lua.Enabled { + if s.config.Servers[0].Lua == nil || !s.config.Servers[0].Lua.Enabled { t.Error("Lua should be enabled") } } @@ -241,19 +253,21 @@ func TestStart_WithMultipleProxies(t *testing.T) { defer cleanup2() cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Proxy: []config.ProxyConfig{ - { - Path: "/api1", - Targets: []config.ProxyTarget{ - {URL: "http://" + backend1, Weight: 1}, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api1", + Targets: []config.ProxyTarget{ + {URL: "http://" + backend1, Weight: 1}, + }, }, - }, - { - Path: "/api2", - Targets: []config.ProxyTarget{ - {URL: "http://" + backend2, Weight: 1}, + { + Path: "/api2", + Targets: []config.ProxyTarget{ + {URL: "http://" + backend2, Weight: 1}, + }, }, }, }, @@ -266,16 +280,18 @@ func TestStart_WithMultipleProxies(t *testing.T) { } // 验证代理配置 - if len(s.config.Server.Proxy) != 2 { - t.Errorf("Expected 2 proxies, got %d", len(s.config.Server.Proxy)) + if len(s.config.Servers[0].Proxy) != 2 { + t.Errorf("Expected 2 proxies, got %d", len(s.config.Servers[0].Proxy)) } } // TestStart_EmptyConfig 测试空配置 func TestStart_EmptyConfig(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + }, }, } @@ -298,38 +314,40 @@ func TestStart_WithAllFeatures(t *testing.T) { writeFile(errorPagePath, []byte("Not Found")) cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Static: []config.StaticConfig{ - { - Path: "/static", - Root: tempDir, - Index: []string{"index.html"}, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }, }, - }, - Compression: config.CompressionConfig{ - Type: "gzip", - Level: 6, - }, - Security: config.SecurityConfig{ - Access: config.AccessConfig{ - Allow: []string{"127.0.0.1"}, + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, }, - RateLimit: config.RateLimitConfig{ - RequestRate: 100, - Burst: 200, + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + }, + ErrorPage: config.ErrorPageConfig{ + Default: errorPagePath, + }, }, - Headers: config.SecurityHeaders{ - XFrameOptions: "DENY", - }, - ErrorPage: config.ErrorPageConfig{ - Default: errorPagePath, - }, - }, - Rewrite: []config.RewriteRule{ - { - Pattern: "/old/(.*)", - Replacement: "/new/$1", + Rewrite: []config.RewriteRule{ + { + Pattern: "/old/(.*)", + Replacement: "/new/$1", + }, }, }, }, @@ -360,7 +378,7 @@ func TestStart_WithAllFeatures(t *testing.T) { if !s.config.Performance.GoroutinePool.Enabled { t.Error("GoroutinePool should be enabled") } - if s.config.Server.Compression.Type != "gzip" { + if s.config.Servers[0].Compression.Type != "gzip" { t.Error("Compression should be gzip") } } @@ -368,14 +386,16 @@ func TestStart_WithAllFeatures(t *testing.T) { // TestStart_ServerOptions 测试服务器配置选项 func TestStart_ServerOptions(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, - MaxConnsPerIP: 100, - MaxRequestsPerConn: 1000, - ClientMaxBodySize: "10MB", + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + MaxConnsPerIP: 100, + MaxRequestsPerConn: 1000, + ClientMaxBodySize: "10MB", + }, }, } @@ -385,29 +405,31 @@ func TestStart_ServerOptions(t *testing.T) { } // 验证服务器选项 - if s.config.Server.ReadTimeout != 30*time.Second { - t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Server.ReadTimeout) + if s.config.Servers[0].ReadTimeout != 30*time.Second { + t.Errorf("Expected ReadTimeout 30s, got %v", s.config.Servers[0].ReadTimeout) } - if s.config.Server.MaxConnsPerIP != 100 { - t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Server.MaxConnsPerIP) + if s.config.Servers[0].MaxConnsPerIP != 100 { + t.Errorf("Expected MaxConnsPerIP 100, got %d", s.config.Servers[0].MaxConnsPerIP) } } // TestStart_HealthCheckConfig 测试健康检查配置 func TestStart_HealthCheckConfig(t *testing.T) { cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Proxy: []config.ProxyConfig{ - { - Path: "/api", - Targets: []config.ProxyTarget{ - {URL: "http://127.0.0.1:8081", Weight: 1}, - }, - HealthCheck: config.HealthCheckConfig{ - Interval: 10 * time.Second, - Timeout: 5 * time.Second, - Path: "/health", + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8081", Weight: 1}, + }, + HealthCheck: config.HealthCheckConfig{ + Interval: 10 * time.Second, + Timeout: 5 * time.Second, + Path: "/health", + }, }, }, }, @@ -420,7 +442,7 @@ func TestStart_HealthCheckConfig(t *testing.T) { } // 验证健康检查配置 - if s.config.Server.Proxy[0].HealthCheck.Path != "/health" { + if s.config.Servers[0].Proxy[0].HealthCheck.Path != "/health" { t.Error("Health check path should be /health") } } @@ -438,9 +460,6 @@ func TestStart_VHostMode(t *testing.T) { Listen: "127.0.0.1:0", }, }, - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - }, } s := New(cfg) @@ -461,13 +480,15 @@ func TestStart_WithProxyBackendError(t *testing.T) { defer cleanup() cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Proxy: []config.ProxyConfig{ - { - Path: "/api", - Targets: []config.ProxyTarget{ - {URL: "http://" + backendAddr, Weight: 1}, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, }, }, }, @@ -480,8 +501,8 @@ func TestStart_WithProxyBackendError(t *testing.T) { } // 验证代理配置 - if len(s.config.Server.Proxy) != 1 { - t.Errorf("Expected 1 proxy, got %d", len(s.config.Server.Proxy)) + if len(s.config.Servers[0].Proxy) != 1 { + t.Errorf("Expected 1 proxy, got %d", len(s.config.Servers[0].Proxy)) } } @@ -495,13 +516,15 @@ func TestStart_WithDelayedBackend(t *testing.T) { defer cleanup() cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Proxy: []config.ProxyConfig{ - { - Path: "/api", - Targets: []config.ProxyTarget{ - {URL: "http://" + backendAddr, Weight: 1}, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, }, }, }, @@ -525,13 +548,15 @@ func TestStart_WithRandomResponse(t *testing.T) { defer cleanup() cfg := &config.Config{ - Server: config.ServerConfig{ - Listen: "127.0.0.1:0", - Proxy: []config.ProxyConfig{ - { - Path: "/api", - Targets: []config.ProxyTarget{ - {URL: "http://" + backendAddr, Weight: 1}, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://" + backendAddr, Weight: 1}, + }, }, }, }, diff --git a/internal/server/testutil.go b/internal/server/testutil.go index 0b9a9df..3311337 100644 --- a/internal/server/testutil.go +++ b/internal/server/testutil.go @@ -105,14 +105,10 @@ func MustStartTestServer(cfg *config.Config) *Server { listenAddr := "" if len(cfg.Servers) > 0 { listenAddr = cfg.Servers[0].Listen - } else { - listenAddr = cfg.Server.Listen } if listenAddr == "" || listenAddr == ":80" { if len(cfg.Servers) > 0 { cfg.Servers[0].Listen = "127.0.0.1:0" - } else { - cfg.Server.Listen = "127.0.0.1:0" } }