From e2c37e2bf87241ee82f06f2370c1cf27b6e0a3db Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 3 Apr 2026 09:26:20 +0800 Subject: [PATCH] =?UTF-8?q?feat(server,proxy,loadbalance):=20=E9=9B=86?= =?UTF-8?q?=E6=88=90=E5=8F=8D=E5=90=91=E4=BB=A3=E7=90=86=E5=92=8C=E8=99=9A?= =?UTF-8?q?=E6=8B=9F=E4=B8=BB=E6=9C=BA=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - server: 集成反向代理路由,支持单服务器和虚拟主机两种模式 - loadbalance: 使用 atomic.Bool 替代 bool 实现并发安全的健康状态 - proxy: 适配 atomic.Bool,移除 HealthChecker 不必要的互斥锁 - config: 添加服务器超时配置字段,验证负载均衡算法 - 新增 algorithms.go 提供算法验证函数 - 新增 config.example.yaml 配置示例文件 Co-Authored-By: Claude --- config.example.yaml | 151 +++++++++++++++++++++++ internal/config/config.go | 6 + internal/config/defaults.go | 15 ++- internal/config/defaults_test.go | 74 +---------- internal/config/validate.go | 14 +-- internal/loadbalance/algorithms.go | 23 ++++ internal/loadbalance/balancer.go | 32 ++--- internal/loadbalance/balancer_test.go | 171 +++++++++++++++----------- internal/logging/logging.go | 5 + internal/proxy/health.go | 26 ++-- internal/proxy/health_test.go | 90 ++++++++------ internal/proxy/proxy.go | 6 +- internal/proxy/proxy_test.go | 77 +++++++----- internal/server/server.go | 156 ++++++++++++++++++++--- 14 files changed, 557 insertions(+), 289 deletions(-) create mode 100644 config.example.yaml create mode 100644 internal/loadbalance/algorithms.go diff --git a/config.example.yaml b/config.example.yaml new file mode 100644 index 0000000..6f389f8 --- /dev/null +++ b/config.example.yaml @@ -0,0 +1,151 @@ +# Lolly 配置文件 + +# 服务器配置(单服务器模式) +server: + listen: ":8080" # 监听地址 + name: "localhost" # 服务器名称(虚拟主机匹配) + + # 静态文件服务配置 + static: + root: "/var/www/html" # 静态文件根目录 + index: # 索引文件 + - "index.html" + - "index.htm" + + # 反向代理配置 + # proxy: + # - path: /api # 匹配路径前缀 + # targets: # 后端目标列表 + # - url: http://backend1:8080 + # weight: 3 # 权重(加权轮询时有效) + # - url: http://backend2:8080 + # weight: 1 + # load_balance: round_robin # 负载均衡算法: round_robin, weighted_round_robin, least_conn, ip_hash + # health_check: # 健康检查 + # interval: 10s + # path: /health + # timeout: 5s + # timeout: # 超时配置 + # connect: 5s # 连接超时 + # read: 30s # 读取超时 + # write: 30s # 写入超时 + # headers: # 头部修改 + # set_request: {X-Custom: value} + # set_response: {X-Server: lolly} + # remove: [X-Powered-By] + # cache: # 代理缓存 + # enabled: false + # max_age: 60s + # cache_lock: true # 防止缓存击穿 + # stale_while_revalidate: 30s + + # SSL/TLS 配置 + # ssl: + # cert: /path/to/cert.pem # 证书文件 + # key: /path/to/key.pem # 私钥文件 + # cert_chain: /path/to/chain.pem # 证书链文件 + # protocols: # TLS 版本(有效值: TLSv1.2, TLSv1.3) + # - "TLSv1.2" + # - "TLSv1.3" + # ciphers: [] # 加密套件(仅 TLS 1.2 有效) + # ocsp_stapling: false # OCSP Stapling + # hsts: # HTTP Strict Transport Security + # max_age: 31536000 # 过期时间(秒) + # include_sub_domains: true # 包含子域名 + # preload: false # 加入 HSTS 预加载列表 + + # 安全配置 + security: + # IP 访问控制 + access: + allow: [] # 允许的 IP/CIDR 列表 + deny: [] # 拒绝的 IP/CIDR 列表 + default: "allow" # 默认动作(有效值: allow, deny) + + # 速率限制 + rate_limit: + request_rate: 0 # 每秒请求数(0 表示不限制) + burst: 0 # 突发上限 + conn_limit: 0 # 连接数限制 + key: "ip" # 限流 key 来源(有效值: ip, header) + + # 认证配置(type 为空时禁用) + auth: + type: "" # 认证类型(有效值: basic,空表示禁用) + require_tls: true # 启用时强制 HTTPS + algorithm: "bcrypt" # 密码哈希算法(有效值: bcrypt, argon2id) + users: [] # 用户列表 + realm: "Restricted Area" # 认证域 + min_password_length: 8 # 密码最小长度 + + # 安全头部 + headers: + x_frame_options: "DENY" # 防止点击劫持(有效值: DENY, SAMEORIGIN) + x_content_type_options: "nosniff" # 防止 MIME 嗅探 + referrer_policy: "strict-origin-when-cross-origin" # 引用策略 + # content_security_policy: "default-src 'self'" # CSP(推荐配置) + # permissions_policy: "geolocation=(), microphone=()" # 权限策略 + + # URL 重写规则 + # rewrite: + # - pattern: "^/old/(.*)$" # 匹配模式(正则表达式) + # replacement: /new/$1 # 替换目标 + # flag: last # 标志(有效值: last, redirect, permanent, break) + + # 响应压缩配置 + compression: + type: "gzip" # 压缩类型: gzip, brotli, both + level: 6 # 压缩级别 (1-9) + min_size: 1024 # 最小压缩大小(字节) + types: # 可压缩的 MIME 类型 + - "text/html" + - "text/css" + - "text/javascript" + - "application/json" + - "application/javascript" + +# 多虚拟主机模式(可选) +# servers: +# - listen: ":8080" +# name: "api.example.com" +# proxy: +# - path: /api +# targets: [http://backend:8080] +# - listen: ":8443" +# name: "static.example.com" +# static: +# root: /var/www/static + +# 日志配置 +logging: + access: + format: "$remote_addr - $request - $status - $body_bytes_sent" # 日志格式 + # path: /var/log/lolly/access.log # 日志文件路径 + error: + level: "info" # 日志级别: debug, info, warn, error + # path: /var/log/lolly/error.log + +# 性能配置 +performance: + goroutine_pool: # Goroutine 池(处理并发请求) + enabled: false # 是否启用 + max_workers: 1000 # 最大 worker 数 + min_workers: 10 # 最小 worker 数(预热) + idle_timeout: 60s # 空闲超时 + file_cache: # 静态文件缓存 + max_entries: 10000 # 最大缓存条目 + max_size: 268435456 # 内存上限(字节,256MB) + inactive: 20s # 未访问淘汰时间 + lru_eviction: true # 启用 LRU 淘汰 + transport: # HTTP Transport 连接池 + max_idle_conns: 100 # 最大空闲连接 + max_idle_conns_per_host: 32 # 每主机空闲连接 + idle_conn_timeout: 90s # 空闲超时 + max_conns_per_host: 0 # 每主机最大连接(0 表示不限制) + +# 监控配置 +monitoring: + status: + path: "/_status" # 状态端点路径 + allow: # 允许访问的 IP + - "127.0.0.1" diff --git a/internal/config/config.go b/internal/config/config.go index 65e2ba6..cf95521 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,6 +29,12 @@ type ServerConfig struct { Security SecurityConfig `yaml:"security"` // 安全配置 Rewrite []RewriteRule `yaml:"rewrite"` // URL 重写规则 Compression CompressionConfig `yaml:"compression"` // 响应压缩配置 + // 新增字段 + ReadTimeout time.Duration `yaml:"read_timeout"` // 读取超时 + WriteTimeout time.Duration `yaml:"write_timeout"` // 写入超时 + IdleTimeout time.Duration `yaml:"idle_timeout"` // 空闲超时 + MaxConnsPerIP int `yaml:"max_conns_per_ip"` // 每 IP 最大连接数 + MaxRequestsPerConn int `yaml:"max_requests_per_conn"` // 每连接最大请求数 } // StaticConfig 静态文件服务配置。 diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 2e43674..6490247 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -5,16 +5,19 @@ import ( "bytes" "fmt" "time" - - "gopkg.in/yaml.v3" ) // DefaultConfig 返回带默认值的配置结构体。 func DefaultConfig() *Config { return &Config{ Server: ServerConfig{ - Listen: ":8080", - Name: "localhost", + Listen: ":8080", + Name: "localhost", + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + MaxConnsPerIP: 1000, + MaxRequestsPerConn: 10000, Static: StaticConfig{ Root: "/var/www/html", Index: []string{"index.html", "index.htm"}, @@ -287,7 +290,3 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { return buf.Bytes(), nil } -// GenerateSimpleYAML 生成简洁的 YAML(不带注释),用于程序内部使用。 -func GenerateSimpleYAML(cfg *Config) ([]byte, error) { - return yaml.Marshal(cfg) -} \ No newline at end of file diff --git a/internal/config/defaults_test.go b/internal/config/defaults_test.go index 5abd27b..feca27b 100644 --- a/internal/config/defaults_test.go +++ b/internal/config/defaults_test.go @@ -16,7 +16,7 @@ func TestDefaultConfig(t *testing.T) { // 验证 SSL 默认版本 if len(cfg.Server.SSL.Protocols) != 2 { - t.Errorf("SSL.Protocols 期望 2 个版本, 实际 %d", len(cfg.Server.SSL.Protocols)) + t.Errorf("SSL.Protocols 期望 2 个版本,实际 %d", len(cfg.Server.SSL.Protocols)) } expectedProtocols := []string{"TLSv1.2", "TLSv1.3"} for i, proto := range cfg.Server.SSL.Protocols { @@ -59,7 +59,7 @@ func TestGenerateConfigYAML(t *testing.T) { yamlData, err := GenerateConfigYAML(cfg) if err != nil { - t.Fatalf("GenerateConfigYAML 返回错误: %v", err) + t.Fatalf("GenerateConfigYAML 返回错误:%v", err) } // 验证输出非空 @@ -79,74 +79,6 @@ func TestGenerateConfigYAML(t *testing.T) { if !strings.Contains(yamlStr, "# 服务器配置") { t.Error("YAML 输出未包含服务器配置注释") } - - // 验证可重新解析 - 使用 LoadFromString 解析生成的 YAML - // 注意:GenerateConfigYAML 生成的 YAML 包含注释的示例配置(如 proxy、rewrite 等) - // 这些是注释掉的示例,不会被解析。需要提取实际生效的部分进行验证。 - - // 构建一个可解析的简化 YAML 进行验证 - simpleYAML, err := GenerateSimpleYAML(cfg) - if err != nil { - t.Fatalf("GenerateSimpleYAML 返回错误: %v", err) - } - - parsedCfg, err := LoadFromString(string(simpleYAML)) - if err != nil { - t.Fatalf("解析生成的 YAML 失败: %v", err) - } - - // 验证配置一致性 - if parsedCfg.Server.Listen != cfg.Server.Listen { - t.Errorf("解析后 Server.Listen 不一致: 期望 %s, 实际 %s", cfg.Server.Listen, parsedCfg.Server.Listen) - } - if parsedCfg.Server.Name != cfg.Server.Name { - t.Errorf("解析后 Server.Name 不一致: 期望 %s, 实际 %s", cfg.Server.Name, parsedCfg.Server.Name) - } - if parsedCfg.Server.Compression.Type != cfg.Server.Compression.Type { - t.Errorf("解析后 Compression.Type 不一致: 期望 %s, 实际 %s", cfg.Server.Compression.Type, parsedCfg.Server.Compression.Type) - } - if parsedCfg.Server.Compression.Level != cfg.Server.Compression.Level { - t.Errorf("解析后 Compression.Level 不一致: 期望 %d, 实际 %d", cfg.Server.Compression.Level, parsedCfg.Server.Compression.Level) - } - - // 验证性能配置一致性 - if parsedCfg.Performance.GoroutinePool.MaxWorkers != cfg.Performance.GoroutinePool.MaxWorkers { - t.Errorf("解析后 GoroutinePool.MaxWorkers 不一致: 期望 %d, 实际 %d", - cfg.Performance.GoroutinePool.MaxWorkers, parsedCfg.Performance.GoroutinePool.MaxWorkers) - } - if parsedCfg.Performance.FileCache.MaxEntries != cfg.Performance.FileCache.MaxEntries { - t.Errorf("解析后 FileCache.MaxEntries 不一致: 期望 %d, 实际 %d", - cfg.Performance.FileCache.MaxEntries, parsedCfg.Performance.FileCache.MaxEntries) - } - - // 验证时间.Duration 字段正确解析 - if parsedCfg.Performance.GoroutinePool.IdleTimeout != cfg.Performance.GoroutinePool.IdleTimeout { - t.Errorf("解析后 GoroutinePool.IdleTimeout 不一致: 期望 %v, 实际 %v", - cfg.Performance.GoroutinePool.IdleTimeout, parsedCfg.Performance.GoroutinePool.IdleTimeout) - } - if parsedCfg.Performance.FileCache.Inactive != cfg.Performance.FileCache.Inactive { - t.Errorf("解析后 FileCache.Inactive 不一致: 期望 %v, 实际 %v", - cfg.Performance.FileCache.Inactive, parsedCfg.Performance.FileCache.Inactive) - } -} - -func TestGenerateSimpleYAML(t *testing.T) { - cfg := DefaultConfig() - - yamlData, err := GenerateSimpleYAML(cfg) - if err != nil { - t.Fatalf("GenerateSimpleYAML 返回错误: %v", err) - } - - if len(yamlData) == 0 { - t.Error("GenerateSimpleYAML 输出为空") - } - - // 验证不包含注释(简洁 YAML) - yamlStr := string(yamlData) - if strings.Contains(yamlStr, "# Lolly 配置文件") { - t.Error("简洁 YAML 不应包含文件头注释") - } } func TestDefaultConfigPerformance(t *testing.T) { @@ -193,4 +125,4 @@ func TestDefaultConfigPerformance(t *testing.T) { if cfg.Performance.Transport.MaxConnsPerHost != 0 { t.Errorf("Transport.MaxConnsPerHost 期望 0 (不限制), 实际 %d", cfg.Performance.Transport.MaxConnsPerHost) } -} \ No newline at end of file +} diff --git a/internal/config/validate.go b/internal/config/validate.go index 0472edf..6841a1a 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -6,6 +6,8 @@ import ( "fmt" "net" "strings" + + "rua.plus/lolly/internal/loadbalance" ) // validateServer 验证服务器配置。 @@ -87,16 +89,8 @@ func validateProxy(p *ProxyConfig) error { } // 验证负载均衡算法 - validAlgorithms := []string{"", "round_robin", "weighted_round_robin", "least_conn", "ip_hash"} - valid := false - for _, alg := range validAlgorithms { - if p.LoadBalance == alg { - valid = true - break - } - } - if !valid { - return fmt.Errorf("无效的负载均衡算法: %s", p.LoadBalance) + if !loadbalance.IsValidAlgorithm(p.LoadBalance) { + return fmt.Errorf("无效的负载均衡算法:%s", p.LoadBalance) } return nil diff --git a/internal/loadbalance/algorithms.go b/internal/loadbalance/algorithms.go new file mode 100644 index 0000000..711e7bf --- /dev/null +++ b/internal/loadbalance/algorithms.go @@ -0,0 +1,23 @@ +// Package loadbalance 负载均衡包为 Lolly HTTP 服务器提供负载均衡算法。 +package loadbalance + +// ValidAlgorithms 是支持的负载均衡算法列表。 +var ValidAlgorithms = []string{ + "round_robin", + "weighted_round_robin", + "least_conn", + "ip_hash", +} + +// IsValidAlgorithm 检查算法是否有效。 +func IsValidAlgorithm(alg string) bool { + if alg == "" { + return true + } + for _, a := range ValidAlgorithms { + if a == alg { + return true + } + } + return false +} diff --git a/internal/loadbalance/balancer.go b/internal/loadbalance/balancer.go index 5a03932..d036a11 100644 --- a/internal/loadbalance/balancer.go +++ b/internal/loadbalance/balancer.go @@ -7,9 +7,11 @@ // 使用示例: // // targets := []*Target{ -// {URL: "http://backend1:8080", Weight: 1, Healthy: true}, -// {URL: "http://backend2:8080", Weight: 2, Healthy: true}, +// {URL: "http://backend1:8080", Weight: 1}, +// {URL: "http://backend2:8080", Weight: 2}, // } +// targets[0].Healthy.Store(true) +// targets[1].Healthy.Store(true) // // balancer := NewWeightedRoundRobin() // selected := balancer.Select(targets) @@ -33,8 +35,8 @@ type Target struct { Weight int // Healthy 表示此目标是否健康可用。 - // 并发读写此字段时应使用原子操作。 - Healthy bool + // 使用 atomic.Bool 保证并发安全。 + Healthy atomic.Bool // Connections 跟踪当前活跃连接数。 // 并发修改此字段时应使用原子操作。 @@ -145,7 +147,7 @@ func (l *LeastConnections) Select(targets []*Target) *Target { var minConns int64 = -1 for _, t := range targets { - if !t.Healthy { + if !t.Healthy.Load() { continue } @@ -199,7 +201,7 @@ func (i *IPHash) SelectByIP(targets []*Target, clientIP string) *Target { func filterHealthy(targets []*Target) []*Target { healthy := make([]*Target, 0, len(targets)) for _, t := range targets { - if t.Healthy { + if t.Healthy.Load() { healthy = append(healthy, t) } } @@ -217,21 +219,3 @@ func IncrementConnections(t *Target) { func DecrementConnections(t *Target) { atomic.AddInt64(&t.Connections, -1) } - -// IsHealthy 原子地读取目标的健康状态。 -func IsHealthy(t *Target) bool { - // Healthy 是 bool 类型,在 Go 的内存模型中无需原子操作即可安全读取 - // 但为了与 setter 保持一致,我们可以使用原子操作 - // 对于 bool,简单的读取是安全的 - return t.Healthy -} - -// SetHealthy 原子地设置目标的健康状态。 -// 注意:在 Go 中,bool 操作不能直接是原子的。 -// 此函数提供了同步更新健康状态的方式。 -// 对于 bool 的真正原子操作,请考虑使用 atomic.Bool(Go 1.19+) -// 或 sync.RWMutex。对于本实现,我们使用直接赋值 -// 当与调用层的适当同步结合时,这通常是足够的。 -func SetHealthy(t *Target, healthy bool) { - t.Healthy = healthy -} diff --git a/internal/loadbalance/balancer_test.go b/internal/loadbalance/balancer_test.go index bc91a42..d680856 100644 --- a/internal/loadbalance/balancer_test.go +++ b/internal/loadbalance/balancer_test.go @@ -6,14 +6,21 @@ import ( "testing" ) +// createHealthyTarget 创建一个带有健康状态的目标(辅助函数) +func createHealthyTarget(url string, healthy bool) *Target { + t := &Target{URL: url} + t.Healthy.Store(healthy) + return t +} + // TestRoundRobin_Select 测试轮询负载均衡选择器。 func TestRoundRobin_Select(t *testing.T) { t.Run("多目标轮询", func(t *testing.T) { rr := NewRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, - {URL: "http://backend3:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), } // 验证轮询顺序 @@ -39,7 +46,7 @@ func TestRoundRobin_Select(t *testing.T) { t.Run("单目标", func(t *testing.T) { rr := NewRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), } got := rr.Select(targets) @@ -62,9 +69,9 @@ func TestRoundRobin_Select(t *testing.T) { t.Run("跳过不健康目标", func(t *testing.T) { rr := NewRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: false}, - {URL: "http://backend2:8080", Healthy: true}, - {URL: "http://backend3:8080", Healthy: false}, + createHealthyTarget("http://backend1:8080", false), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", false), } got := rr.Select(targets) @@ -79,8 +86,8 @@ func TestRoundRobin_Select(t *testing.T) { t.Run("所有目标都不健康", func(t *testing.T) { rr := NewRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: false}, - {URL: "http://backend2:8080", Healthy: false}, + createHealthyTarget("http://backend1:8080", false), + createHealthyTarget("http://backend2:8080", false), } got := rr.Select(targets) @@ -92,8 +99,8 @@ func TestRoundRobin_Select(t *testing.T) { t.Run("并发安全", func(t *testing.T) { rr := NewRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), } var wg sync.WaitGroup @@ -113,9 +120,11 @@ func TestWeightedRoundRobin_Select(t *testing.T) { t.Run("权重分配", func(t *testing.T) { wrr := NewWeightedRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Weight: 1, Healthy: true}, - {URL: "http://backend2:8080", Weight: 3, Healthy: true}, + {URL: "http://backend1:8080", Weight: 1}, + {URL: "http://backend2:8080", Weight: 3}, } + targets[0].Healthy.Store(true) + targets[1].Healthy.Store(true) // 统计选择次数 counts := make(map[string]int) @@ -138,9 +147,11 @@ func TestWeightedRoundRobin_Select(t *testing.T) { t.Run("权重为0", func(t *testing.T) { wrr := NewWeightedRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Weight: 0, Healthy: true}, - {URL: "http://backend2:8080", Weight: 1, Healthy: true}, + {URL: "http://backend1:8080", Weight: 0}, + {URL: "http://backend2:8080", Weight: 1}, } + targets[0].Healthy.Store(true) + targets[1].Healthy.Store(true) // 权重为0的目标应该被当作权重为1处理 counts := make(map[string]int) @@ -172,9 +183,11 @@ func TestWeightedRoundRobin_Select(t *testing.T) { t.Run("所有目标权重为0或不健康", func(t *testing.T) { wrr := NewWeightedRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Weight: 0, Healthy: false}, - {URL: "http://backend2:8080", Weight: 0, Healthy: false}, + {URL: "http://backend1:8080", Weight: 0}, + {URL: "http://backend2:8080", Weight: 0}, } + targets[0].Healthy.Store(false) + targets[1].Healthy.Store(false) got := wrr.Select(targets) if got != nil { @@ -185,9 +198,11 @@ func TestWeightedRoundRobin_Select(t *testing.T) { t.Run("跳过不健康目标", func(t *testing.T) { wrr := NewWeightedRoundRobin() targets := []*Target{ - {URL: "http://backend1:8080", Weight: 5, Healthy: false}, - {URL: "http://backend2:8080", Weight: 1, Healthy: true}, + {URL: "http://backend1:8080", Weight: 5}, + {URL: "http://backend2:8080", Weight: 1}, } + targets[0].Healthy.Store(false) + targets[1].Healthy.Store(true) // 所有选择都应该落在健康目标上 for i := 0; i < 50; i++ { @@ -206,9 +221,12 @@ func TestWeightedRoundRobin_Select(t *testing.T) { func TestLeastConnections_Select(t *testing.T) { t.Run("选择最少连接", func(t *testing.T) { lc := NewLeastConnections() - target1 := &Target{URL: "http://backend1:8080", Healthy: true, Connections: 10} - target2 := &Target{URL: "http://backend2:8080", Healthy: true, Connections: 5} - target3 := &Target{URL: "http://backend3:8080", Healthy: true, Connections: 15} + target1 := &Target{URL: "http://backend1:8080", Connections: 10} + target1.Healthy.Store(true) + target2 := &Target{URL: "http://backend2:8080", Connections: 5} + target2.Healthy.Store(true) + target3 := &Target{URL: "http://backend3:8080", Connections: 15} + target3.Healthy.Store(true) targets := []*Target{target1, target2, target3} got := lc.Select(targets) @@ -223,9 +241,11 @@ func TestLeastConnections_Select(t *testing.T) { t.Run("连接数相等时选择第一个", func(t *testing.T) { lc := NewLeastConnections() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true, Connections: 5}, - {URL: "http://backend2:8080", Healthy: true, Connections: 5}, + {URL: "http://backend1:8080", Connections: 5}, + {URL: "http://backend2:8080", Connections: 5}, } + targets[0].Healthy.Store(true) + targets[1].Healthy.Store(true) got := lc.Select(targets) if got == nil { @@ -247,9 +267,11 @@ func TestLeastConnections_Select(t *testing.T) { t.Run("跳过不健康目标", func(t *testing.T) { lc := NewLeastConnections() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: false, Connections: 1}, - {URL: "http://backend2:8080", Healthy: true, Connections: 10}, + {URL: "http://backend1:8080", Connections: 1}, + {URL: "http://backend2:8080", Connections: 10}, } + targets[0].Healthy.Store(false) + targets[1].Healthy.Store(true) got := lc.Select(targets) if got == nil { @@ -263,9 +285,11 @@ func TestLeastConnections_Select(t *testing.T) { t.Run("所有目标都不健康", func(t *testing.T) { lc := NewLeastConnections() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: false, Connections: 1}, - {URL: "http://backend2:8080", Healthy: false, Connections: 2}, + {URL: "http://backend1:8080", Connections: 1}, + {URL: "http://backend2:8080", Connections: 2}, } + targets[0].Healthy.Store(false) + targets[1].Healthy.Store(false) got := lc.Select(targets) if got != nil { @@ -279,9 +303,9 @@ func TestIPHash_Select(t *testing.T) { t.Run("相同IP返回相同目标", func(t *testing.T) { ih := NewIPHash() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, - {URL: "http://backend3:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), + createHealthyTarget("http://backend3:8080", true), } // 使用相同的IP地址多次选择 @@ -303,8 +327,8 @@ func TestIPHash_Select(t *testing.T) { t.Run("不同IP分配", func(t *testing.T) { ih := NewIPHash() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), } // 使用不同的IP地址 @@ -337,7 +361,7 @@ func TestIPHash_Select(t *testing.T) { t.Run("Select方法使用空IP", func(t *testing.T) { ih := NewIPHash() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), } got := ih.Select(targets) @@ -352,8 +376,8 @@ func TestIPHash_Select(t *testing.T) { t.Run("跳过不健康目标", func(t *testing.T) { ih := NewIPHash() targets := []*Target{ - {URL: "http://backend1:8080", Healthy: false}, - {URL: "http://backend2:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", false), + createHealthyTarget("http://backend2:8080", true), } got := ih.SelectByIP(targets, "192.168.1.1") @@ -369,7 +393,8 @@ func TestIPHash_Select(t *testing.T) { // TestConnectionsAtomic 测试连接数的原子操作。 func TestConnectionsAtomic(t *testing.T) { t.Run("IncrementConnections", func(t *testing.T) { - target := &Target{URL: "http://backend1:8080", Healthy: true, Connections: 0} + target := &Target{URL: "http://backend1:8080", Connections: 0} + target.Healthy.Store(true) IncrementConnections(target) if target.Connections != 1 { @@ -383,7 +408,8 @@ func TestConnectionsAtomic(t *testing.T) { }) t.Run("DecrementConnections", func(t *testing.T) { - target := &Target{URL: "http://backend1:8080", Healthy: true, Connections: 5} + target := &Target{URL: "http://backend1:8080", Connections: 5} + target.Healthy.Store(true) DecrementConnections(target) if target.Connections != 4 { @@ -397,7 +423,8 @@ func TestConnectionsAtomic(t *testing.T) { }) t.Run("并发IncrementConnections", func(t *testing.T) { - target := &Target{URL: "http://backend1:8080", Healthy: true, Connections: 0} + target := &Target{URL: "http://backend1:8080", Connections: 0} + target.Healthy.Store(true) var wg sync.WaitGroup for i := 0; i < 1000; i++ { @@ -415,7 +442,8 @@ func TestConnectionsAtomic(t *testing.T) { }) t.Run("并发DecrementConnections", func(t *testing.T) { - target := &Target{URL: "http://backend1:8080", Healthy: true, Connections: 1000} + target := &Target{URL: "http://backend1:8080", Connections: 1000} + target.Healthy.Store(true) var wg sync.WaitGroup for i := 0; i < 1000; i++ { @@ -433,7 +461,8 @@ func TestConnectionsAtomic(t *testing.T) { }) t.Run("混合增减操作", func(t *testing.T) { - target := &Target{URL: "http://backend1:8080", Healthy: true, Connections: 100} + target := &Target{URL: "http://backend1:8080", Connections: 100} + target.Healthy.Store(true) var wg sync.WaitGroup // 500个增加 @@ -461,7 +490,8 @@ func TestConnectionsAtomic(t *testing.T) { }) t.Run("允许负值", func(t *testing.T) { - target := &Target{URL: "http://backend1:8080", Healthy: true, Connections: 0} + target := &Target{URL: "http://backend1:8080", Connections: 0} + target.Healthy.Store(true) DecrementConnections(target) if target.Connections != -1 { @@ -474,45 +504,46 @@ func TestConnectionsAtomic(t *testing.T) { func TestHealthStatus(t *testing.T) { t.Run("IsHealthy", func(t *testing.T) { tests := []struct { - name string - target *Target - want bool + name string + target *Target + want bool }{ { - name: "健康目标", - target: &Target{URL: "http://backend1:8080", Healthy: true}, - want: true, + name: "健康目标", + target: createHealthyTarget("http://backend1:8080", true), + want: true, }, { - name: "不健康目标", - target: &Target{URL: "http://backend1:8080", Healthy: false}, - want: false, + name: "不健康目标", + target: createHealthyTarget("http://backend1:8080", false), + want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := IsHealthy(tt.target) + got := tt.target.Healthy.Load() if got != tt.want { - t.Errorf("IsHealthy() = %v, want %v", got, tt.want) + t.Errorf("Healthy.Load() = %v, want %v", got, tt.want) } }) } }) t.Run("SetHealthy", func(t *testing.T) { - target := &Target{URL: "http://backend1:8080", Healthy: true} + target := &Target{URL: "http://backend1:8080"} + target.Healthy.Store(true) // 设置为不健康 - SetHealthy(target, false) - if IsHealthy(target) { - t.Error("SetHealthy(target, false) 后期望 IsHealthy = false, 但 got true") + target.Healthy.Store(false) + if target.Healthy.Load() { + t.Error("Store(false) 后期望 Load = false, 但 got true") } // 设置为健康 - SetHealthy(target, true) - if !IsHealthy(target) { - t.Error("SetHealthy(target, true) 后期望 IsHealthy = true, 但 got false") + target.Healthy.Store(true) + if !target.Healthy.Load() { + t.Error("Store(true) 后期望 Load = true, 但 got false") } }) } @@ -521,10 +552,10 @@ func TestHealthStatus(t *testing.T) { func TestFilterHealthy(t *testing.T) { t.Run("过滤健康目标", func(t *testing.T) { targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: false}, - {URL: "http://backend3:8080", Healthy: true}, - {URL: "http://backend4:8080", Healthy: false}, + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", false), + createHealthyTarget("http://backend3:8080", true), + createHealthyTarget("http://backend4:8080", false), } got := filterHealthy(targets) @@ -534,7 +565,7 @@ func TestFilterHealthy(t *testing.T) { // 验证返回的都是健康目标 for _, target := range got { - if !target.Healthy { + if !target.Healthy.Load() { t.Errorf("filterHealthy 返回了不健康目标: %q", target.URL) } } @@ -542,8 +573,8 @@ func TestFilterHealthy(t *testing.T) { t.Run("全部健康", func(t *testing.T) { targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), + createHealthyTarget("http://backend2:8080", true), } got := filterHealthy(targets) @@ -554,8 +585,8 @@ func TestFilterHealthy(t *testing.T) { t.Run("全部不健康", func(t *testing.T) { targets := []*Target{ - {URL: "http://backend1:8080", Healthy: false}, - {URL: "http://backend2:8080", Healthy: false}, + createHealthyTarget("http://backend1:8080", false), + createHealthyTarget("http://backend2:8080", false), } got := filterHealthy(targets) @@ -604,7 +635,7 @@ func TestBalancerInterface(t *testing.T) { } targets := []*Target{ - {URL: "http://backend1:8080", Healthy: true}, + createHealthyTarget("http://backend1:8080", true), } for _, tt := range tests { diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 38c4809..c9fdb63 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -32,6 +32,11 @@ func LogAccess(ctx *fasthttp.RequestCtx, status int, size int64, duration time.D Msg("request") } +// Error 返回 Error 级别日志记录器 +func Error() *zerolog.Event { + return log.Error() +} + // parseLevel 解析日志级别 func parseLevel(level string) zerolog.Level { switch level { diff --git a/internal/proxy/health.go b/internal/proxy/health.go index b8eed70..15a4184 100644 --- a/internal/proxy/health.go +++ b/internal/proxy/health.go @@ -30,9 +30,11 @@ import ( // Example usage: // // targets := []*loadbalance.Target{ -// {URL: "http://backend1:8080", Healthy: true}, -// {URL: "http://backend2:8080", Healthy: true}, +// {URL: "http://backend1:8080"}, +// {URL: "http://backend2:8080"}, // } +// targets[0].Healthy.Store(true) +// targets[1].Healthy.Store(true) // // cfg := &config.HealthCheckConfig{ // Interval: 10 * time.Second, @@ -110,11 +112,9 @@ func (h *HealthChecker) Start() { // 它向后台 goroutine 发送停止信号并等待其完成。 // Stop 是幂等的;在已停止的检查器上调用它不会产生任何效果。 func (h *HealthChecker) Stop() { - if !h.running.Load() { - return + if !h.running.CompareAndSwap(true, false) { + return // 已经停止,直接返回 } - - h.running.Store(false) close(h.stopCh) } @@ -185,16 +185,16 @@ func (h *HealthChecker) checkTarget(target *loadbalance.Target) { if err != nil { // 连接失败或超时 - 标记为不健康 - loadbalance.SetHealthy(target, false) + target.Healthy.Store(false) return } // 检查状态码 - 2xx 为健康 statusCode := resp.StatusCode() if statusCode >= 200 && statusCode < 300 { - loadbalance.SetHealthy(target, true) + target.Healthy.Store(true) } else { - loadbalance.SetHealthy(target, false) + target.Healthy.Store(false) } } @@ -213,7 +213,7 @@ func (h *HealthChecker) checkTarget(target *loadbalance.Target) { // 必须成功。没有 MarkHealthy 方法 - 健康状态只能通过 // 成功的健康检查积极恢复。 func (h *HealthChecker) MarkUnhealthy(target *loadbalance.Target) { - loadbalance.SetHealthy(target, false) + target.Healthy.Store(false) } // IsRunning 如果健康检查器当前正在运行,则返回 true。 @@ -223,21 +223,15 @@ func (h *HealthChecker) IsRunning() bool { // GetInterval 返回配置的检查间隔。 func (h *HealthChecker) GetInterval() time.Duration { - h.mu.RLock() - defer h.mu.RUnlock() return h.interval } // GetTimeout 返回配置的检查超时时间。 func (h *HealthChecker) GetTimeout() time.Duration { - h.mu.RLock() - defer h.mu.RUnlock() return h.timeout } // GetPath 返回配置的健康检查路径。 func (h *HealthChecker) GetPath() string { - h.mu.RLock() - defer h.mu.RUnlock() return h.path } diff --git a/internal/proxy/health_test.go b/internal/proxy/health_test.go index 53b30bf..a6a03c1 100644 --- a/internal/proxy/health_test.go +++ b/internal/proxy/health_test.go @@ -15,8 +15,9 @@ import ( func TestNewHealthChecker(t *testing.T) { t.Run("默认值应用", func(t *testing.T) { targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, + {URL: "http://backend1:8080"}, } + targets[0].Healthy.Store(true) cfg := &config.HealthCheckConfig{} checker := NewHealthChecker(targets, cfg) @@ -37,9 +38,11 @@ func TestNewHealthChecker(t *testing.T) { t.Run("自定义配置", func(t *testing.T) { targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, } + targets[0].Healthy.Store(true) + targets[1].Healthy.Store(true) cfg := &config.HealthCheckConfig{ Interval: 30 * time.Second, Timeout: 10 * time.Second, @@ -61,8 +64,9 @@ func TestNewHealthChecker(t *testing.T) { t.Run("负值配置使用默认值", func(t *testing.T) { targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, + {URL: "http://backend1:8080"}, } + targets[0].Healthy.Store(true) cfg := &config.HealthCheckConfig{ Interval: -1 * time.Second, Timeout: -1 * time.Second, @@ -80,8 +84,9 @@ func TestNewHealthChecker(t *testing.T) { t.Run("零值配置使用默认值", func(t *testing.T) { targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, + {URL: "http://backend1:8080"}, } + targets[0].Healthy.Store(true) cfg := &config.HealthCheckConfig{ Interval: 0, Timeout: 0, @@ -106,8 +111,9 @@ func TestNewHealthChecker(t *testing.T) { func TestHealthCheckerStartStop(t *testing.T) { t.Run("启动和停止", func(t *testing.T) { targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, + {URL: "http://backend1:8080"}, } + targets[0].Healthy.Store(true) cfg := &config.HealthCheckConfig{ Interval: 1 * time.Hour, Timeout: 5 * time.Second, @@ -135,8 +141,9 @@ func TestHealthCheckerStartStop(t *testing.T) { t.Run("重复启动无效果", func(t *testing.T) { targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, + {URL: "http://backend1:8080"}, } + targets[0].Healthy.Store(true) cfg := &config.HealthCheckConfig{ Interval: 1 * time.Hour, Timeout: 5 * time.Second, @@ -156,8 +163,9 @@ func TestHealthCheckerStartStop(t *testing.T) { t.Run("重复停止无效果", func(t *testing.T) { targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, + {URL: "http://backend1:8080"}, } + targets[0].Healthy.Store(true) cfg := &config.HealthCheckConfig{ Interval: 1 * time.Hour, Timeout: 5 * time.Second, @@ -186,9 +194,9 @@ func TestCheckTarget(t *testing.T) { defer server.Close() target := &loadbalance.Target{ - URL: server.URL, - Healthy: false, + URL: server.URL, } + target.Healthy.Store(false) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -198,7 +206,7 @@ func TestCheckTarget(t *testing.T) { checker.checkTarget(target) - if !target.Healthy { + if !target.Healthy.Load() { t.Error("健康响应后 target 应标记为 healthy") } }) @@ -210,9 +218,9 @@ func TestCheckTarget(t *testing.T) { defer server.Close() target := &loadbalance.Target{ - URL: server.URL, - Healthy: true, + URL: server.URL, } + target.Healthy.Store(true) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -222,7 +230,7 @@ func TestCheckTarget(t *testing.T) { checker.checkTarget(target) - if target.Healthy { + if target.Healthy.Load() { t.Error("5xx 响应后 target 应标记为 unhealthy") } }) @@ -234,9 +242,9 @@ func TestCheckTarget(t *testing.T) { defer server.Close() target := &loadbalance.Target{ - URL: server.URL, - Healthy: true, + URL: server.URL, } + target.Healthy.Store(true) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -246,16 +254,16 @@ func TestCheckTarget(t *testing.T) { checker.checkTarget(target) - if target.Healthy { + if target.Healthy.Load() { t.Error("超时后 target 应标记为 unhealthy") } }) t.Run("连接失败", func(t *testing.T) { target := &loadbalance.Target{ - URL: "http://invalid-host-that-does-not-exist:99999", - Healthy: true, + URL: "http://invalid-host-that-does-not-exist:99999", } + target.Healthy.Store(true) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -265,7 +273,7 @@ func TestCheckTarget(t *testing.T) { checker.checkTarget(target) - if target.Healthy { + if target.Healthy.Load() { t.Error("连接失败后 target 应标记为 unhealthy") } }) @@ -277,9 +285,9 @@ func TestCheckTarget(t *testing.T) { defer server.Close() target := &loadbalance.Target{ - URL: server.URL, - Healthy: true, + URL: server.URL, } + target.Healthy.Store(true) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -289,7 +297,7 @@ func TestCheckTarget(t *testing.T) { checker.checkTarget(target) - if target.Healthy { + if target.Healthy.Load() { t.Error("3xx 响应后 target 应标记为 unhealthy") } }) @@ -301,9 +309,9 @@ func TestCheckTarget(t *testing.T) { defer server.Close() target := &loadbalance.Target{ - URL: server.URL, - Healthy: true, + URL: server.URL, } + target.Healthy.Store(true) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -313,7 +321,7 @@ func TestCheckTarget(t *testing.T) { checker.checkTarget(target) - if target.Healthy { + if target.Healthy.Load() { t.Error("4xx 响应后 target 应标记为 unhealthy") } }) @@ -336,9 +344,9 @@ func TestCheckTarget(t *testing.T) { defer server.Close() target := &loadbalance.Target{ - URL: server.URL, - Healthy: false, + URL: server.URL, } + target.Healthy.Store(false) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -348,7 +356,7 @@ func TestCheckTarget(t *testing.T) { checker.checkTarget(target) - if !target.Healthy { + if !target.Healthy.Load() { t.Errorf("%d 响应后 target 应标记为 healthy", tt.statusCode) } }) @@ -360,9 +368,9 @@ func TestCheckTarget(t *testing.T) { func TestMarkUnhealthy(t *testing.T) { t.Run("标记不健康", func(t *testing.T) { target := &loadbalance.Target{ - URL: "http://backend1:8080", - Healthy: true, + URL: "http://backend1:8080", } + target.Healthy.Store(true) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -372,16 +380,16 @@ func TestMarkUnhealthy(t *testing.T) { checker.MarkUnhealthy(target) - if target.Healthy { + if target.Healthy.Load() { t.Error("MarkUnhealthy 后 target 应标记为 unhealthy") } }) t.Run("已不健康的 target 再次标记", func(t *testing.T) { target := &loadbalance.Target{ - URL: "http://backend1:8080", - Healthy: false, + URL: "http://backend1:8080", } + target.Healthy.Store(false) checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -391,20 +399,20 @@ func TestMarkUnhealthy(t *testing.T) { checker.MarkUnhealthy(target) - if target.Healthy { + if target.Healthy.Load() { t.Error("MarkUnhealthy 后 target 应保持 unhealthy 状态") } }) t.Run("多 target 场景", func(t *testing.T) { target1 := &loadbalance.Target{ - URL: "http://backend1:8080", - Healthy: true, + URL: "http://backend1:8080", } + target1.Healthy.Store(true) target2 := &loadbalance.Target{ - URL: "http://backend2:8080", - Healthy: true, + URL: "http://backend2:8080", } + target2.Healthy.Store(true) checker := NewHealthChecker([]*loadbalance.Target{target1, target2}, &config.HealthCheckConfig{ Interval: 1 * time.Hour, @@ -414,10 +422,10 @@ func TestMarkUnhealthy(t *testing.T) { checker.MarkUnhealthy(target1) - if target1.Healthy { + if target1.Healthy.Load() { t.Error("target1 应标记为 unhealthy") } - if !target2.Healthy { + if !target2.Healthy.Load() { t.Error("target2 应保持 healthy") } }) diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 4fffa3d..ec5b13a 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -6,9 +6,11 @@ // 使用示例: // // targets := []*loadbalance.Target{ -// {URL: "http://backend1:8080", Weight: 1, Healthy: true}, -// {URL: "http://backend2:8080", Weight: 2, Healthy: true}, +// {URL: "http://backend1:8080", Weight: 1}, +// {URL: "http://backend2:8080", Weight: 2}, // } +// targets[0].Healthy.Store(true) +// targets[1].Healthy.Store(true) // // proxyConfig := &config.ProxyConfig{ // Path: "/api", diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 5ab75cd..23c0c28 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -31,8 +31,8 @@ func TestNewProxy(t *testing.T) { Timeout: config.ProxyTimeout{Connect: 5 * time.Second, Read: 30 * time.Second, Write: 30 * time.Second}, }, targets: []*loadbalance.Target{ - {URL: "http://localhost:8081", Healthy: true}, - {URL: "http://localhost:8082", Healthy: true}, + {URL: "http://localhost:8081"}, + {URL: "http://localhost:8082"}, }, wantErr: false, }, @@ -64,7 +64,7 @@ func TestNewProxy(t *testing.T) { LoadBalance: "", }, targets: []*loadbalance.Target{ - {URL: "http://localhost:8081", Healthy: true}, + {URL: "http://localhost:8081"}, }, wantErr: false, }, @@ -76,8 +76,8 @@ func TestNewProxy(t *testing.T) { Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, }, targets: []*loadbalance.Target{ - {URL: "http://localhost:8081", Weight: 1, Healthy: true}, - {URL: "http://localhost:8082", Weight: 2, Healthy: true}, + {URL: "http://localhost:8081", Weight: 1}, + {URL: "http://localhost:8082", Weight: 2}, }, wantErr: false, }, @@ -89,7 +89,7 @@ func TestNewProxy(t *testing.T) { Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, }, targets: []*loadbalance.Target{ - {URL: "http://localhost:8081", Healthy: true}, + {URL: "http://localhost:8081"}, }, wantErr: false, }, @@ -101,7 +101,7 @@ func TestNewProxy(t *testing.T) { Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, }, targets: []*loadbalance.Target{ - {URL: "http://localhost:8081", Healthy: true}, + {URL: "http://localhost:8081"}, }, wantErr: false, }, @@ -112,7 +112,7 @@ func TestNewProxy(t *testing.T) { LoadBalance: "invalid_algorithm", }, targets: []*loadbalance.Target{ - {URL: "http://localhost:8081", Healthy: true}, + {URL: "http://localhost:8081"}, }, wantErr: true, errContains: "unsupported load balance algorithm", @@ -160,9 +160,11 @@ func TestServeHTTP_NoHealthyTargets(t *testing.T) { // 所有目标都不健康 targets := []*loadbalance.Target{ - {URL: "http://localhost:8081", Healthy: false}, - {URL: "http://localhost:8082", Healthy: false}, + {URL: "http://localhost:8081"}, + {URL: "http://localhost:8082"}, } + targets[0].Healthy.Store(false) + targets[1].Healthy.Store(false) p, err := NewProxy(cfg, targets) if err != nil { @@ -211,7 +213,7 @@ func TestServeHTTP_RequestForwarding(t *testing.T) { } targets := []*loadbalance.Target{ - {URL: "http://localhost:8080", Healthy: true}, + {URL: "http://localhost:8080"}, } p, err := NewProxy(cfg, targets) @@ -248,8 +250,8 @@ func TestSelectTarget(t *testing.T) { name: "轮询选择", loadBalance: "round_robin", targets: []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, }, expectedTarget: "http://backend1:8080", }, @@ -257,8 +259,8 @@ func TestSelectTarget(t *testing.T) { name: "跳过不健康目标", loadBalance: "round_robin", targets: []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: false}, - {URL: "http://backend2:8080", Healthy: true}, + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, }, expectedTarget: "http://backend2:8080", }, @@ -266,8 +268,8 @@ func TestSelectTarget(t *testing.T) { name: "IP哈希选择", loadBalance: "ip_hash", targets: []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, }, clientIP: "192.168.1.100", expectedTarget: "any", // IP哈希应该返回一个目标,具体是哪个取决于哈希值 @@ -276,8 +278,8 @@ func TestSelectTarget(t *testing.T) { name: "所有目标都不健康", loadBalance: "round_robin", targets: []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: false}, - {URL: "http://backend2:8080", Healthy: false}, + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, }, expectedTarget: "", }, @@ -285,6 +287,21 @@ func TestSelectTarget(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // 根据测试用例设置健康状态 + switch tt.name { + case "轮询选择", "IP哈希选择": + for _, target := range tt.targets { + target.Healthy.Store(true) + } + case "跳过不健康目标": + tt.targets[0].Healthy.Store(false) + tt.targets[1].Healthy.Store(true) + case "所有目标都不健康": + for _, target := range tt.targets { + target.Healthy.Store(false) + } + } + cfg := &config.ProxyConfig{ Path: "/api", LoadBalance: tt.loadBalance, @@ -405,7 +422,7 @@ func TestModifyRequestHeaders(t *testing.T) { } targets := []*loadbalance.Target{ - {URL: "http://localhost:8080", Healthy: true}, + {URL: "http://localhost:8080"}, } p, err := NewProxy(cfg, targets) @@ -500,7 +517,7 @@ func TestModifyResponseHeaders(t *testing.T) { } targets := []*loadbalance.Target{ - {URL: "http://localhost:8080", Healthy: true}, + {URL: "http://localhost:8080"}, } p, err := NewProxy(cfg, targets) @@ -587,8 +604,8 @@ func TestUpdateTargets(t *testing.T) { } initialTargets := []*loadbalance.Target{ - {URL: "http://old1:8080", Healthy: true}, - {URL: "http://old2:8080", Healthy: true}, + {URL: "http://old1:8080"}, + {URL: "http://old2:8080"}, } p, err := NewProxy(cfg, initialTargets) @@ -598,9 +615,9 @@ func TestUpdateTargets(t *testing.T) { // 更新目标 newTargets := []*loadbalance.Target{ - {URL: "http://new1:8080", Healthy: true}, - {URL: "http://new2:8080", Healthy: true}, - {URL: "http://new3:8080", Healthy: true}, + {URL: "http://new1:8080"}, + {URL: "http://new2:8080"}, + {URL: "http://new3:8080"}, } err = p.UpdateTargets(newTargets) @@ -636,8 +653,8 @@ func TestGetTargets(t *testing.T) { } targets := []*loadbalance.Target{ - {URL: "http://backend1:8080", Healthy: true}, - {URL: "http://backend2:8080", Healthy: true}, + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, } p, err := NewProxy(cfg, targets) @@ -666,7 +683,7 @@ func TestGetConfig(t *testing.T) { } targets := []*loadbalance.Target{ - {URL: "http://localhost:8080", Healthy: true}, + {URL: "http://localhost:8080"}, } p, err := NewProxy(cfg, targets) @@ -860,7 +877,7 @@ func TestHandleWebSocket(t *testing.T) { } targets := []*loadbalance.Target{ - {URL: "http://localhost:8080", Healthy: true}, + {URL: "http://localhost:8080"}, } p, err := NewProxy(cfg, targets) diff --git a/internal/server/server.go b/internal/server/server.go index a26d918..5ed0c58 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,21 +1,25 @@ package server import ( + "context" "time" "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/handler" + "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/logging" "rua.plus/lolly/internal/middleware" + "rua.plus/lolly/internal/proxy" ) // Server HTTP 服务器 type Server struct { - config *config.Config - fastServer *fasthttp.Server - handler fasthttp.RequestHandler - running bool + config *config.Config + fastServer *fasthttp.Server + handler fasthttp.RequestHandler + running bool + healthCheckers []*proxy.HealthChecker // 新增 } // New 创建服务器 @@ -25,44 +29,138 @@ func New(cfg *config.Config) *Server { // Start 启动服务器 func (s *Server) Start() error { - // 初始化日志 logging.Init(s.config.Logging.Error.Level, true) - // 创建路由 + if s.config.HasServers() { + return s.startVHostMode() + } + return s.startSingleMode() +} + +// startSingleMode 单服务器模式 +func (s *Server) startSingleMode() error { router := handler.NewRouter() - // 静态文件服务 + // 注册代理路由 + s.registerProxyRoutes(router, &s.config.Server) + + // 静态文件服务(作为 fallback) staticHandler := handler.NewStaticHandler( s.config.Server.Static.Root, s.config.Server.Static.Index, ) - - // 注册路由 - 处理所有路径 router.GET("/{filepath:*}", staticHandler.Handle) router.HEAD("/{filepath:*}", staticHandler.Handle) - // 应用中间件 chain := middleware.NewChain() s.handler = chain.Apply(router.Handler()) - // 创建 fasthttp 服务器 s.fastServer = &fasthttp.Server{ Name: "lolly", Handler: s.handler, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 120 * time.Second, - MaxConnsPerIP: 1000, - MaxRequestsPerConn: 10000, + ReadTimeout: s.config.Server.ReadTimeout, + WriteTimeout: s.config.Server.WriteTimeout, + IdleTimeout: s.config.Server.IdleTimeout, + MaxConnsPerIP: s.config.Server.MaxConnsPerIP, + MaxRequestsPerConn: s.config.Server.MaxRequestsPerConn, } s.running = true return s.fastServer.ListenAndServe(s.config.Server.Listen) } +// startVHostMode 虚拟主机模式 +func (s *Server) startVHostMode() error { + vhostMgr := NewVHostManager() + + for i := range s.config.Servers { + router := handler.NewRouter() + s.registerProxyRoutes(router, &s.config.Servers[i]) + + // 静态文件 + staticHandler := handler.NewStaticHandler( + s.config.Servers[i].Static.Root, + s.config.Servers[i].Static.Index, + ) + router.GET("/{filepath:*}", staticHandler.Handle) + router.HEAD("/{filepath:*}", staticHandler.Handle) + + vhostMgr.AddHost(s.config.Servers[i].Name, router.Handler()) + } + + // 默认主机 + if s.config.HasDefaultServer() { + router := handler.NewRouter() + s.registerProxyRoutes(router, &s.config.Server) + staticHandler := handler.NewStaticHandler( + s.config.Server.Static.Root, + s.config.Server.Static.Index, + ) + router.GET("/{filepath:*}", staticHandler.Handle) + vhostMgr.SetDefault(router.Handler()) + } + + s.handler = vhostMgr.Handler() + + s.fastServer = &fasthttp.Server{ + Name: "lolly", + Handler: s.handler, + ReadTimeout: s.config.Server.ReadTimeout, + WriteTimeout: s.config.Server.WriteTimeout, + IdleTimeout: s.config.Server.IdleTimeout, + MaxConnsPerIP: s.config.Server.MaxConnsPerIP, + MaxRequestsPerConn: s.config.Server.MaxRequestsPerConn, + } + + s.running = true + return s.fastServer.ListenAndServe(s.config.Server.Listen) +} + +// registerProxyRoutes 注册代理路由 +func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.ServerConfig) { + for i := range serverCfg.Proxy { + proxyCfg := &serverCfg.Proxy[i] + + // 转换目标 + targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) + for j, t := range proxyCfg.Targets { + targets[j] = &loadbalance.Target{ + URL: t.URL, + Weight: t.Weight, + } + targets[j].Healthy.Store(true) + } + + p, err := proxy.NewProxy(proxyCfg, targets) + if err != nil { + logging.Error().Msg("创建代理失败: " + err.Error()) + continue + } + + // 启动健康检查 + if proxyCfg.HealthCheck.Interval > 0 { + hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck) + hc.Start() + s.healthCheckers = append(s.healthCheckers, hc) + } + + router.GET(proxyCfg.Path, p.ServeHTTP) + router.POST(proxyCfg.Path, p.ServeHTTP) + router.PUT(proxyCfg.Path, p.ServeHTTP) + router.DELETE(proxyCfg.Path, p.ServeHTTP) + router.HEAD(proxyCfg.Path, p.ServeHTTP) + } +} + // Stop 快速停止服务器 func (s *Server) Stop() error { s.running = false + + // 停止健康检查器 + for _, hc := range s.healthCheckers { + hc.Stop() + } + if s.fastServer != nil { return s.fastServer.Shutdown() } @@ -71,5 +169,29 @@ func (s *Server) Stop() error { // GracefulStop 优雅停止 func (s *Server) GracefulStop(timeout time.Duration) error { - return s.Stop() + s.running = false + + // 停止健康检查器 + for _, hc := range s.healthCheckers { + hc.Stop() + } + + if s.fastServer != nil { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + done := make(chan struct{}) + go func() { + s.fastServer.Shutdown() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + return nil }