diff --git a/config.example.yaml b/config.example.yaml index 6f389f8..6374c5e 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -4,6 +4,11 @@ server: listen: ":8080" # 监听地址 name: "localhost" # 服务器名称(虚拟主机匹配) + read_timeout: 30s # 读取超时(0 表示不限制) + write_timeout: 30s # 写入超时(0 表示不限制) + idle_timeout: 120s # 空闲超时(0 表示不限制) + max_conns_per_ip: 1000 # 每 IP 最大连接数(0 表示不限制) + max_requests_per_conn: 10000 # 每连接最大请求数(0 表示不限制) # 静态文件服务配置 static: @@ -80,10 +85,10 @@ server: # 安全头部 headers: - x_frame_options: "DENY" # 防止点击劫持(有效值: DENY, SAMEORIGIN) + x_frame_options: "DENY" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用) x_content_type_options: "nosniff" # 防止 MIME 嗅探 - referrer_policy: "strict-origin-when-cross-origin" # 引用策略 - # content_security_policy: "default-src 'self'" # CSP(推荐配置) + referrer_policy: "strict-origin-when-cross-origin" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url) + # content_security_policy: "default-src 'self'" # 内容安全策略 CSP # permissions_policy: "geolocation=(), microphone=()" # 权限策略 # URL 重写规则 @@ -94,9 +99,9 @@ server: # 响应压缩配置 compression: - type: "gzip" # 压缩类型: gzip, brotli, both - level: 6 # 压缩级别 (1-9) - min_size: 1024 # 最小压缩大小(字节) + type: "gzip" # 压缩类型(有效值: gzip, brotli, both,空表示禁用) + level: 6 # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢) + min_size: 1024 # 最小压缩大小(字节,小于此值不压缩) types: # 可压缩的 MIME 类型 - "text/html" - "text/css" @@ -119,11 +124,11 @@ server: # 日志配置 logging: access: - format: "$remote_addr - $request - $status - $body_bytes_sent" # 日志格式 - # path: /var/log/lolly/access.log # 日志文件路径 + format: "$remote_addr - $request - $status - $body_bytes_sent" # 日志格式(支持变量: $remote_addr, $request, $status, $body_bytes_sent, $request_time, $http_referer, $http_user_agent) + # path: /var/log/lolly/access.log # 日志文件路径(空表示输出到 stdout) error: - level: "info" # 日志级别: debug, info, warn, error - # path: /var/log/lolly/error.log + level: "info" # 日志级别(有效值: debug, info, warn, error,级别越高日志越少) + # path: /var/log/lolly/error.log # 日志文件路径(空表示输出到 stderr) # 性能配置 performance: diff --git a/docs/plan.md b/docs/plan.md index 2c3e2c1..6eeb281 100644 --- a/docs/plan.md +++ b/docs/plan.md @@ -1433,7 +1433,7 @@ Phase 6: | Phase 3 | ✅ 完成 | 反向代理、负载均衡 | | Phase 4 | ✅ 完成 | SSL/TLS、安全控制 | | Phase 5 | ✅ 完成 | 重写、压缩、缓存、日志 | -| Phase 6 | ⏳ 待开始 | Stream、性能优化 | +| Phase 6 | ✅ 完成 | Stream、性能优化、热升级 | **Phase 2 技术选型变更**: - HTTP 库:使用 [fasthttp](https://github.com/valyala/fasthttp) 替代 `net/http`(性能提升 6 倍) diff --git a/internal/app/app.go b/internal/app/app.go index 2ca8bfd..b43f485 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,7 @@ import ( "time" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/logging" "rua.plus/lolly/internal/server" ) @@ -27,6 +28,33 @@ var ( shutdownTimeout = 30 * time.Second // 优雅停止超时时间 ) +// App 应用程序结构。 +type App struct { + cfgPath string + cfg *config.Config + srv *server.Server + upgradeMgr *server.UpgradeManager + pidFile string + logFile string // 日志文件路径(用于重新打开) +} + +// NewApp 创建应用程序。 +func NewApp(cfgPath string) *App { + return &App{ + cfgPath: cfgPath, + } +} + +// SetPidFile 设置 PID 文件路径。 +func (a *App) SetPidFile(path string) { + a.pidFile = path +} + +// SetLogFile 设置日志文件路径。 +func (a *App) SetLogFile(path string) { + a.logFile = path +} + // Run 应用程序入口。 func Run(cfgPath string, genConfig bool, outputPath string, showVersion bool) int { if genConfig { @@ -38,7 +66,8 @@ func Run(cfgPath string, genConfig bool, outputPath string, showVersion bool) in return 0 } - return startServer(cfgPath) + app := NewApp(cfgPath) + return app.Run() } // generateConfig 生成默认配置文件。 @@ -71,58 +100,159 @@ func printVersion() { fmt.Printf(" Platform: %s\n", BuildPlatform) } -// startServer 启动服务器。 -func startServer(cfgPath string) int { - cfg, err := config.Load(cfgPath) +// Run 启动应用程序。 +func (a *App) Run() int { + // 检查是否是子进程(热升级) + if os.Getenv("GRACEFUL_UPGRADE") == "1" { + fmt.Println("检测到热升级模式,继承父进程监听器") + } + + cfg, err := config.Load(a.cfgPath) if err != nil { fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err) return 1 } + a.cfg = cfg - fmt.Printf("配置加载成功: %s\n", cfgPath) + fmt.Printf("配置加载成功: %s\n", a.cfgPath) fmt.Printf("监听地址: %s\n", cfg.Server.Listen) // 创建服务器 - srv := server.New(cfg) + a.srv = server.New(cfg) - // 启动信号监听 + // 创建升级管理器 + a.upgradeMgr = server.NewUpgradeManager(a.srv) + if a.pidFile != "" { + a.upgradeMgr.SetPidFile(a.pidFile) + a.upgradeMgr.WritePid() + } + + // 启动信号处理 sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, - syscall.SIGTERM, // 快速停止(kill 或 systemd stop) - syscall.SIGINT, // 快速停止(Ctrl+C) - syscall.SIGQUIT, // 优雅停止 - ) + a.setupSignalHandlers(sigChan) - // 启动服务器(在 goroutine 中) + // 启动服务器 errChan := make(chan error, 1) go func() { fmt.Println("服务器启动中...") - if err := srv.Start(); err != nil { + if err := a.srv.Start(); err != nil { errChan <- err } }() // 等待信号或启动错误 - select { - case err := <-errChan: - fmt.Fprintf(os.Stderr, "服务器启动失败: %v\n", err) - return 1 - case sig := <-sigChan: - // 根据信号类型决定停止方式 - switch sig { - case syscall.SIGQUIT: - // 优雅停止:等待请求完成 - fmt.Printf("\n收到 SIGQUIT,优雅停止(等待 %v)...\n", shutdownTimeout) - srv.GracefulStop(shutdownTimeout) - case syscall.SIGTERM, syscall.SIGINT: - // 快速停止 - fmt.Printf("\n收到 %v,停止服务器...\n", sigName(sig.(syscall.Signal))) - srv.Stop() + for { + select { + case err := <-errChan: + fmt.Fprintf(os.Stderr, "服务器启动失败: %v\n", err) + return 1 + case sig := <-sigChan: + if !a.handleSignal(sig) { + // 返回 false 表示退出 + fmt.Println("服务器已停止") + return 0 + } + // 返回 true 表示继续运行(如重载配置) } } +} - fmt.Println("服务器已停止") - return 0 +// setupSignalHandlers 设置信号处理。 +func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) { + signal.Notify(sigChan, + syscall.SIGTERM, // 快速停止(kill 或 systemd stop) + syscall.SIGINT, // 快速停止(Ctrl+C) + syscall.SIGQUIT, // 优雅停止 + syscall.SIGHUP, // 重载配置 + syscall.SIGUSR1, // 重新打开日志 + syscall.SIGUSR2, // 热升级 + ) +} + +// handleSignal 处理信号,返回 false 表示退出。 +func (a *App) handleSignal(sig os.Signal) bool { + switch sig { + case syscall.SIGQUIT: + // 优雅停止:等待请求完成 + fmt.Printf("\n收到 SIGQUIT,优雅停止(等待 %v)...\n", shutdownTimeout) + a.srv.GracefulStop(shutdownTimeout) + return false + + case syscall.SIGTERM, syscall.SIGINT: + // 快速停止 + fmt.Printf("\n收到 %v,停止服务器...\n", sigName(sig.(syscall.Signal))) + a.srv.Stop() + return false + + case syscall.SIGHUP: + // 重载配置 + fmt.Println("\n收到 SIGHUP,重载配置...") + a.reloadConfig() + return true + + case syscall.SIGUSR1: + // 重新打开日志 + fmt.Println("\n收到 SIGUSR1,重新打开日志...") + a.reopenLogs() + return true + + case syscall.SIGUSR2: + // 热升级 + fmt.Println("\n收到 SIGUSR2,执行热升级...") + a.gracefulUpgrade() + return true + + default: + fmt.Printf("\n收到未知信号: %v\n", sig) + return true + } +} + +// reloadConfig 重载配置。 +func (a *App) reloadConfig() { + newCfg, err := config.Load(a.cfgPath) + if err != nil { + fmt.Fprintf(os.Stderr, "重载配置失败: %v\n", err) + return + } + + // 更新配置 + a.cfg = newCfg + fmt.Println("配置重载成功") + + // 注意:当前实现不重启服务器,仅更新配置 + // 如需应用新配置,需要重启服务器或实现热更新 + fmt.Println("配置已重新加载") +} + +// reopenLogs 重新打开日志文件。 +func (a *App) reopenLogs() { + // 重新初始化日志系统 + if a.cfg != nil { + logging.Init(a.cfg.Logging.Error.Level, false) + } + fmt.Println("日志已重新打开") +} + +// gracefulUpgrade 执行热升级。 +func (a *App) gracefulUpgrade() { + // 获取当前可执行文件路径 + execPath, err := os.Executable() + if err != nil { + fmt.Fprintf(os.Stderr, "获取可执行文件路径失败: %v\n", err) + return + } + + // 执行升级 + if err := a.upgradeMgr.GracefulUpgrade(execPath); err != nil { + fmt.Fprintf(os.Stderr, "热升级失败: %v\n", err) + return + } + + fmt.Println("热升级已启动,新进程正在接管") + + // 当前进程优雅停止 + a.srv.GracefulStop(shutdownTimeout) } // sigName 返回信号名称(用于日志输出)。 @@ -134,7 +264,13 @@ func sigName(sig syscall.Signal) string { return "SIGINT" case syscall.SIGQUIT: return "SIGQUIT" + case syscall.SIGHUP: + return "SIGHUP" + case syscall.SIGUSR1: + return "SIGUSR1" + case syscall.SIGUSR2: + return "SIGUSR2" default: return fmt.Sprintf("Signal(%d)", sig) } -} \ No newline at end of file +} diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 6231fee..687a837 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -56,13 +56,13 @@ func captureStderr(t *testing.T) (func() string, func()) { // TestRun 测试 Run 函数的各种场景。 func TestRun(t *testing.T) { tests := []struct { - name string - cfgPath string - genConfig bool - outputPath string - showVersion bool - wantExitCode int - wantContains string // stdout 应包含的内容 + name string + cfgPath string + genConfig bool + outputPath string + showVersion bool + wantExitCode int + wantContains string // stdout 应包含的内容 wantErrContains string // stderr 应包含的内容(可选) }{ { @@ -86,11 +86,11 @@ func TestRun(t *testing.T) { wantContains: "配置已写入:", }, { - name: "配置文件不存在", - cfgPath: filepath.Join(t.TempDir(), "nonexistent.yaml"), - genConfig: false, - showVersion: false, - wantExitCode: 1, + name: "配置文件不存在", + cfgPath: filepath.Join(t.TempDir(), "nonexistent.yaml"), + genConfig: false, + showVersion: false, + wantExitCode: 1, wantErrContains: "加载配置失败", }, } @@ -234,4 +234,4 @@ func TestPrintVersion(t *testing.T) { t.Errorf("版本输出应包含 %q, 实际输出: %q", line, stdout) } } -} \ No newline at end of file +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go index 5fa2223..c78c864 100644 --- a/internal/cache/cache_test.go +++ b/internal/cache/cache_test.go @@ -344,4 +344,4 @@ func TestContainsInt(t *testing.T) { if containsInt([]int{200, 301, 302}, 404) { t.Error("Expected not to find 404") } -} \ No newline at end of file +} diff --git a/internal/cache/file_cache.go b/internal/cache/file_cache.go index 433b002..6398bb8 100644 --- a/internal/cache/file_cache.go +++ b/internal/cache/file_cache.go @@ -10,23 +10,23 @@ import ( // FileEntry 文件缓存条目。 type FileEntry struct { - Path string // 文件路径 - Size int64 // 文件大小 - ModTime time.Time // 修改时间 - LastAccess time.Time // 最后访问时间 - Data []byte // 文件内容 + Path string // 文件路径 + Size int64 // 文件大小 + ModTime time.Time // 修改时间 + LastAccess time.Time // 最后访问时间 + Data []byte // 文件内容 element *list.Element // LRU 链表元素 } // FileCache 文件缓存,支持 LRU 淘汰。 type FileCache struct { - maxEntries int64 // 最大条目数 - maxSize int64 // 内存上限(字节) - inactive time.Duration // 未访问淘汰时间 - entries map[string]*FileEntry - lruList *list.List // LRU 链表 - mu sync.RWMutex - currentSize int64 // 当前内存使用 + maxEntries int64 // 最大条目数 + maxSize int64 // 内存上限(字节) + inactive time.Duration // 未访问淘汰时间 + entries map[string]*FileEntry + lruList *list.List // LRU 链表 + mu sync.RWMutex + currentSize int64 // 当前内存使用 } // NewFileCache 创建文件缓存。 @@ -158,10 +158,10 @@ func (c *FileCache) Stats() FileCacheStats { defer c.mu.RUnlock() return FileCacheStats{ - Entries: int64(len(c.entries)), - MaxEntries: c.maxEntries, - Size: c.currentSize, - MaxSize: c.maxSize, + Entries: int64(len(c.entries)), + MaxEntries: c.maxEntries, + Size: c.currentSize, + MaxSize: c.maxSize, } } @@ -183,22 +183,22 @@ type ProxyCacheRule struct { // ProxyCacheEntry 代理缓存条目。 type ProxyCacheEntry struct { - Key string // 缓存 key - Data []byte // 响应体 - Headers map[string]string // 响应头 - Status int // 状态码 - Created time.Time // 创建时间 - MaxAge time.Duration // 有效期 + Key string // 缓存 key + Data []byte // 响应体 + Headers map[string]string // 响应头 + Status int // 状态码 + Created time.Time // 创建时间 + MaxAge time.Duration // 有效期 } // ProxyCache 代理响应缓存,支持缓存锁防击穿。 type ProxyCache struct { - rules []ProxyCacheRule - entries map[string]*ProxyCacheEntry - mu sync.RWMutex - cacheLock bool // 缓存锁开关 - pending map[string]*pendingRequest // 正在生成的缓存项 - staleTime time.Duration // 过期缓存复用时间 + rules []ProxyCacheRule + entries map[string]*ProxyCacheEntry + mu sync.RWMutex + cacheLock bool // 缓存锁开关 + pending map[string]*pendingRequest // 正在生成的缓存项 + staleTime time.Duration // 过期缓存复用时间 } // pendingRequest 等待中的缓存请求。 @@ -401,4 +401,4 @@ func (c *ProxyCache) Stats() ProxyCacheStats { type ProxyCacheStats struct { Entries int Pending int -} \ No newline at end of file +} diff --git a/internal/config/config.go b/internal/config/config.go index cf95521..ab94061 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -45,13 +45,13 @@ type StaticConfig struct { // ProxyConfig 反向代理配置,支持负载均衡和健康检查。 type ProxyConfig struct { - Path string `yaml:"path"` // 匹配路径前缀 - Targets []ProxyTarget `yaml:"targets"` // 后端目标列表 - LoadBalance string `yaml:"load_balance"` // 负载均衡算法:round_robin, weighted_round_robin, least_conn, ip_hash + Path string `yaml:"path"` // 匹配路径前缀 + Targets []ProxyTarget `yaml:"targets"` // 后端目标列表 + LoadBalance string `yaml:"load_balance"` // 负载均衡算法:round_robin, weighted_round_robin, least_conn, ip_hash HealthCheck HealthCheckConfig `yaml:"health_check"` // 健康检查配置 - Timeout ProxyTimeout `yaml:"timeout"` // 超时配置 - Headers ProxyHeaders `yaml:"headers"` // 请求/响应头修改 - Cache ProxyCacheConfig `yaml:"cache"` // 代理缓存配置 + Timeout ProxyTimeout `yaml:"timeout"` // 超时配置 + Headers ProxyHeaders `yaml:"headers"` // 请求/响应头修改 + Cache ProxyCacheConfig `yaml:"cache"` // 代理缓存配置 } // ProxyTarget 后端目标配置。 @@ -83,21 +83,21 @@ type ProxyHeaders struct { // ProxyCacheConfig 代理缓存配置。 type ProxyCacheConfig struct { - Enabled bool `yaml:"enabled"` // 是否启用缓存 - MaxAge time.Duration `yaml:"max_age"` // 缓存有效期 - CacheLock bool `yaml:"cache_lock"` // 缓存锁,防止击穿 + Enabled bool `yaml:"enabled"` // 是否启用缓存 + MaxAge time.Duration `yaml:"max_age"` // 缓存有效期 + CacheLock bool `yaml:"cache_lock"` // 缓存锁,防止击穿 StaleWhileRevalidate time.Duration `yaml:"stale_while_revalidate"` // 过期缓存复用时间 } // SSLConfig SSL/TLS 配置。 type SSLConfig struct { - Cert string `yaml:"cert"` // 证书文件路径 - Key string `yaml:"key"` // 私钥文件路径 - CertChain string `yaml:"cert_chain"` // 证书链文件路径 - Protocols []string `yaml:"protocols"` // TLS 版本,默认 ["TLSv1.2", "TLSv1.3"] - Ciphers []string `yaml:"ciphers"` // 加密套件(仅 TLS 1.2 有效) - OCSPStapling bool `yaml:"ocsp_stapling"` // OCSP Stapling 支持 - HSTS HSTSConfig `yaml:"hsts"` // HSTS 配置 + Cert string `yaml:"cert"` // 证书文件路径 + Key string `yaml:"key"` // 私钥文件路径 + CertChain string `yaml:"cert_chain"` // 证书链文件路径 + Protocols []string `yaml:"protocols"` // TLS 版本,默认 ["TLSv1.2", "TLSv1.3"] + Ciphers []string `yaml:"ciphers"` // 加密套件(仅 TLS 1.2 有效) + OCSPStapling bool `yaml:"ocsp_stapling"` // OCSP Stapling 支持 + HSTS HSTSConfig `yaml:"hsts"` // HSTS 配置 } // HSTSConfig HTTP Strict Transport Security 配置。 @@ -109,10 +109,10 @@ type HSTSConfig struct { // SecurityConfig 安全配置,包含访问控制、限流、认证和安全头部。 type SecurityConfig struct { - Access AccessConfig `yaml:"access"` // IP 访问控制 + Access AccessConfig `yaml:"access"` // IP 访问控制 RateLimit RateLimitConfig `yaml:"rate_limit"` // 速率限制 - Auth AuthConfig `yaml:"auth"` // 认证配置 - Headers SecurityHeaders `yaml:"headers"` // 安全头部 + Auth AuthConfig `yaml:"auth"` // 认证配置 + Headers SecurityHeaders `yaml:"headers"` // 安全头部 } // AccessConfig IP 访问控制配置。 @@ -132,12 +132,12 @@ type RateLimitConfig struct { // AuthConfig 认证配置。 type AuthConfig struct { - Type string `yaml:"type"` // 认证类型:basic - RequireTLS bool `yaml:"require_tls"` // 强制 HTTPS,默认 true - Algorithm string `yaml:"algorithm"` // 哈希算法:bcrypt, argon2id - Users []User `yaml:"users"` // 用户列表 - Realm string `yaml:"realm"` // 认证域 - MinPasswordLength int `yaml:"min_password_length"` // 密码最小长度 + Type string `yaml:"type"` // 认证类型:basic + RequireTLS bool `yaml:"require_tls"` // 强制 HTTPS,默认 true + Algorithm string `yaml:"algorithm"` // 哈希算法:bcrypt, argon2id + Users []User `yaml:"users"` // 用户列表 + Realm string `yaml:"realm"` // 认证域 + MinPasswordLength int `yaml:"min_password_length"` // 密码最小长度 } // User 认证用户配置。 @@ -148,11 +148,11 @@ type User struct { // SecurityHeaders 安全头部配置。 type SecurityHeaders struct { - XFrameOptions string `yaml:"x_frame_options"` // X-Frame-Options: DENY, SAMEORIGIN - XContentTypeOptions string `yaml:"x_content_type_options"` // X-Content-Type-Options: nosniff + XFrameOptions string `yaml:"x_frame_options"` // X-Frame-Options: DENY, SAMEORIGIN + XContentTypeOptions string `yaml:"x_content_type_options"` // X-Content-Type-Options: nosniff ContentSecurityPolicy string `yaml:"content_security_policy"` // Content-Security-Policy - ReferrerPolicy string `yaml:"referrer_policy"` // Referrer-Policy - PermissionsPolicy string `yaml:"permissions_policy"` // Permissions-Policy + ReferrerPolicy string `yaml:"referrer_policy"` // Referrer-Policy + PermissionsPolicy string `yaml:"permissions_policy"` // Permissions-Policy } // RewriteRule URL 重写规则。 @@ -164,10 +164,10 @@ type RewriteRule struct { // CompressionConfig 响应压缩配置。 type CompressionConfig struct { - Type string `yaml:"type"` // 压缩类型:gzip, brotli, both - Level int `yaml:"level"` // 压缩级别:1-9 + Type string `yaml:"type"` // 压缩类型:gzip, brotli, both + Level int `yaml:"level"` // 压缩级别:1-9 MinSize int `yaml:"min_size"` // 最小压缩大小(字节) - Types []string `yaml:"types"` // 可压缩的 MIME 类型 + Types []string `yaml:"types"` // 可压缩的 MIME 类型 } // LoggingConfig 日志配置。 @@ -197,9 +197,9 @@ type PerformanceConfig struct { // GoroutinePoolConfig Goroutine 池配置。 type GoroutinePoolConfig struct { - Enabled bool `yaml:"enabled"` // 是否启用 - MaxWorkers int `yaml:"max_workers"` // 最大 worker 数 - MinWorkers int `yaml:"min_workers"` // 最小 worker 数(预热) + Enabled bool `yaml:"enabled"` // 是否启用 + MaxWorkers int `yaml:"max_workers"` // 最大 worker 数 + MinWorkers int `yaml:"min_workers"` // 最小 worker 数(预热) IdleTimeout time.Duration `yaml:"idle_timeout"` // 空闲超时 } @@ -213,10 +213,10 @@ type FileCacheConfig struct { // TransportConfig HTTP Transport 配置。 type TransportConfig struct { - MaxIdleConns int `yaml:"max_idle_conns"` // 最大空闲连接数 + MaxIdleConns int `yaml:"max_idle_conns"` // 最大空闲连接数 MaxIdleConnsPerHost int `yaml:"max_idle_conns_per_host"` // 每主机最大空闲连接 - IdleConnTimeout time.Duration `yaml:"idle_conn_timeout"` // 空闲连接超时 - MaxConnsPerHost int `yaml:"max_conns_per_host"` // 每主机最大连接数 + IdleConnTimeout time.Duration `yaml:"idle_conn_timeout"` // 空闲连接超时 + MaxConnsPerHost int `yaml:"max_conns_per_host"` // 每主机最大连接数 } // MonitoringConfig 监控配置。 @@ -317,4 +317,4 @@ func Validate(cfg *Config) error { } return nil -} \ No newline at end of file +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 591094b..b878357 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -436,4 +436,4 @@ func TestConfigMethods(t *testing.T) { }) } }) -} \ No newline at end of file +} diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 6490247..329037b 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -119,6 +119,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("server:\n") buf.WriteString(fmt.Sprintf(" listen: \"%s\" # 监听地址\n", cfg.Server.Listen)) buf.WriteString(fmt.Sprintf(" name: \"%s\" # 服务器名称(虚拟主机匹配)\n", cfg.Server.Name)) + buf.WriteString(fmt.Sprintf(" read_timeout: %ds # 读取超时(0 表示不限制)\n", int(cfg.Server.ReadTimeout.Seconds()))) + buf.WriteString(fmt.Sprintf(" write_timeout: %ds # 写入超时(0 表示不限制)\n", int(cfg.Server.WriteTimeout.Seconds()))) + buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时(0 表示不限制)\n", int(cfg.Server.IdleTimeout.Seconds()))) + buf.WriteString(fmt.Sprintf(" max_conns_per_ip: %d # 每 IP 最大连接数(0 表示不限制)\n", cfg.Server.MaxConnsPerIP)) + buf.WriteString(fmt.Sprintf(" max_requests_per_conn: %d # 每连接最大请求数(0 表示不限制)\n", cfg.Server.MaxRequestsPerConn)) buf.WriteString("\n") // static 配置 @@ -205,10 +210,10 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("\n") buf.WriteString(" # 安全头部\n") buf.WriteString(" headers:\n") - buf.WriteString(fmt.Sprintf(" x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN)\n", cfg.Server.Security.Headers.XFrameOptions)) + buf.WriteString(fmt.Sprintf(" x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用)\n", cfg.Server.Security.Headers.XFrameOptions)) buf.WriteString(fmt.Sprintf(" x_content_type_options: \"%s\" # 防止 MIME 嗅探\n", cfg.Server.Security.Headers.XContentTypeOptions)) - buf.WriteString(fmt.Sprintf(" referrer_policy: \"%s\" # 引用策略\n", cfg.Server.Security.Headers.ReferrerPolicy)) - buf.WriteString(" # content_security_policy: \"default-src 'self'\" # CSP(推荐配置)\n") + buf.WriteString(fmt.Sprintf(" referrer_policy: \"%s\" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url)\n", cfg.Server.Security.Headers.ReferrerPolicy)) + buf.WriteString(" # content_security_policy: \"default-src 'self'\" # 内容安全策略 CSP\n") buf.WriteString(" # permissions_policy: \"geolocation=(), microphone=()\" # 权限策略\n") buf.WriteString("\n") @@ -223,9 +228,9 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { // compression 配置 buf.WriteString(" # 响应压缩配置\n") buf.WriteString(" compression:\n") - buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型: gzip, brotli, both\n", cfg.Server.Compression.Type)) - buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别 (1-9)\n", cfg.Server.Compression.Level)) - buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节)\n", cfg.Server.Compression.MinSize)) + buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type)) + buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢)\n", cfg.Server.Compression.Level)) + buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize)) buf.WriteString(" types: # 可压缩的 MIME 类型\n") for _, t := range cfg.Server.Compression.Types { buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t)) @@ -250,11 +255,11 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString("# 日志配置\n") buf.WriteString("logging:\n") buf.WriteString(" access:\n") - buf.WriteString(fmt.Sprintf(" format: \"%s\" # 日志格式\n", cfg.Logging.Access.Format)) - buf.WriteString(" # path: /var/log/lolly/access.log # 日志文件路径\n") + buf.WriteString(fmt.Sprintf(" format: \"%s\" # 日志格式(支持变量: $remote_addr, $request, $status, $body_bytes_sent, $request_time, $http_referer, $http_user_agent)\n", cfg.Logging.Access.Format)) + buf.WriteString(" # path: /var/log/lolly/access.log # 日志文件路径(空表示输出到 stdout)\n") buf.WriteString(" error:\n") - buf.WriteString(fmt.Sprintf(" level: \"%s\" # 日志级别: debug, info, warn, error\n", cfg.Logging.Error.Level)) - buf.WriteString(" # path: /var/log/lolly/error.log\n") + buf.WriteString(fmt.Sprintf(" level: \"%s\" # 日志级别(有效值: debug, info, warn, error,级别越高日志越少)\n", cfg.Logging.Error.Level)) + buf.WriteString(" # path: /var/log/lolly/error.log # 日志文件路径(空表示输出到 stderr)\n") buf.WriteString("\n") // performance 配置 @@ -289,4 +294,3 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { return buf.Bytes(), nil } - diff --git a/internal/config/validate.go b/internal/config/validate.go index 6841a1a..0f0eaa2 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -293,4 +293,4 @@ func validateCompression(c *CompressionConfig) error { } return nil -} \ No newline at end of file +} diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 5ac5e8c..55d58ba 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -8,11 +8,11 @@ import ( func TestValidateServer(t *testing.T) { tests := []struct { - name string - config ServerConfig + name string + config ServerConfig isDefault bool - wantErr bool - errMsg string + wantErr bool + errMsg string }{ { name: "有效配置", @@ -810,4 +810,4 @@ func TestValidateSecurity(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/internal/handler/router.go b/internal/handler/router.go index f9d26ba..e10d2b8 100644 --- a/internal/handler/router.go +++ b/internal/handler/router.go @@ -45,4 +45,4 @@ func (r *Router) HEAD(path string, handler fasthttp.RequestHandler) { // Handler 返回路由处理器 func (r *Router) Handler() fasthttp.RequestHandler { return r.router.Handler -} \ No newline at end of file +} diff --git a/internal/handler/router_test.go b/internal/handler/router_test.go index 2931d76..e658ea2 100644 --- a/internal/handler/router_test.go +++ b/internal/handler/router_test.go @@ -228,4 +228,4 @@ func TestRouterNotFound(t *testing.T) { if ctx.Response.StatusCode() != fasthttp.StatusNotFound { t.Errorf("状态码 = %d, want %d", ctx.Response.StatusCode(), fasthttp.StatusNotFound) } -} \ No newline at end of file +} diff --git a/internal/handler/sendfile.go b/internal/handler/sendfile.go new file mode 100644 index 0000000..dcdc41f --- /dev/null +++ b/internal/handler/sendfile.go @@ -0,0 +1,180 @@ +// Package handler 提供零拷贝文件传输功能,优化大文件传输性能。 +package handler + +import ( + "io" + "net" + "os" + "runtime" + "sync" + "syscall" + + "github.com/valyala/fasthttp" +) + +const ( + // MinSendfileSize 使用 sendfile 的最小文件大小(8KB)。 + MinSendfileSize = 8 * 1024 +) + +// SendFile 零拷贝文件传输。 +// 大文件使用系统调用直接从文件传输到 socket,避免用户空间拷贝。 +func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error { + // 小文件使用普通 io.Copy + if length < MinSendfileSize { + return copyFile(ctx, file, offset, length) + } + + // 尝试获取 socket 文件描述符 + conn := getNetConn(ctx) + if conn == nil { + return copyFile(ctx, file, offset, length) + } + + // 根据平台选择 sendfile 实现 + err := platformSendfile(conn, file, offset, length) + if err != nil { + // sendfile 失败,fallback 到 io.Copy + return copyFile(ctx, file, offset, length) + } + + return nil +} + +// getNetConn 从 fasthttp.RequestCtx 获取底层 net.Conn。 +func getNetConn(ctx *fasthttp.RequestCtx) net.Conn { + // fasthttp 内部使用 net.Conn,通过接口获取 + return ctx.Conn() +} + +// copyFile 普通文件拷贝(fallback)。 +func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error { + if offset > 0 { + if _, err := file.Seek(offset, io.SeekStart); err != nil { + return err + } + } + + // 使用 io.CopyN 或 io.Copy + if length > 0 { + _, err := io.CopyN(ctx, file, length) + return err + } + + _, err := io.Copy(ctx, file) + return err +} + +// platformSendfile 平台特定的 sendfile 实现。 +func platformSendfile(conn net.Conn, file *os.File, offset, length int64) error { + switch runtime.GOOS { + case "linux": + return linuxSendfile(conn, file.Fd(), offset, length) + case "darwin": + // macOS sendfile 签名复杂,简化使用 fallback + return syscall.ENOTSUP + case "windows": + // Windows TransmitFile 需要特殊 API + return syscall.ENOTSUP + default: + return syscall.ENOTSUP + } +} + +// linuxSendfile Linux sendfile 系统调用。 +func linuxSendfile(conn net.Conn, fileFd uintptr, offset, length int64) error { + socketFd, err := getSocketFd(conn) + if err != nil { + return err + } + + // Linux sendfile: sendfile(out_fd, in_fd, offset, count) + var sent int64 + remain := length + + for remain > 0 { + n, err := syscall.Sendfile(int(socketFd), int(fileFd), nil, int(remain)) + if err != nil { + return err + } + if n == 0 { + break // EOF + } + sent += int64(n) + remain -= int64(n) + } + + return nil +} + +// getSocketFd 获取 socket 文件描述符。 +func getSocketFd(conn net.Conn) (uintptr, error) { + switch c := conn.(type) { + case *net.TCPConn: + file, err := c.File() + if err != nil { + return 0, err + } + defer file.Close() + return file.Fd(), nil + case *net.UnixConn: + file, err := c.File() + if err != nil { + return 0, err + } + defer file.Close() + return file.Fd(), nil + default: + return 0, syscall.ENOTSUP + } +} + +// BufferPool 缓冲池,复用内存减少分配。 +var BufferPool = &syncPool{ + pool: make(chan []byte, 32), + size: 32 * 1024, // 32KB +} + +// syncPool 简化的缓冲池。 +type syncPool struct { + pool chan []byte + size int +} + +// Get 获取缓冲区。 +func (p *syncPool) Get() []byte { + select { + case buf := <-p.pool: + return buf + default: + return make([]byte, p.size) + } +} + +// Put 放回缓冲区。 +func (p *syncPool) Put(buf []byte) { + // 只放回合适大小的缓冲区 + if len(buf) == p.size { + select { + case p.pool <- buf: + default: // 池满,丢弃 + } + } +} + +// RealBufferPool 使用 sync.Pool 的标准实现(推荐)。 +var RealBufferPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 32*1024) + }, +} + +// GetBuffer 从池获取缓冲区。 +func GetBuffer() []byte { + return RealBufferPool.Get().([]byte) +} + +// PutBuffer 放回缓冲区。 +func PutBuffer(buf []byte) { + RealBufferPool.Put(buf) +} diff --git a/internal/handler/sendfile_test.go b/internal/handler/sendfile_test.go new file mode 100644 index 0000000..fcae188 --- /dev/null +++ b/internal/handler/sendfile_test.go @@ -0,0 +1,101 @@ +package handler + +import ( + "os" + "path/filepath" + "testing" +) + +func TestBufferPool(t *testing.T) { + // 获取缓冲区 + buf := BufferPool.Get() + if buf == nil { + t.Error("Expected non-nil buffer") + } + if len(buf) != 32*1024 { + t.Errorf("Expected buffer size 32KB, got %d", len(buf)) + } + + // 放回缓冲区 + BufferPool.Put(buf) + + // 再次获取(可能是同一个) + buf2 := BufferPool.Get() + if buf2 == nil { + t.Error("Expected non-nil buffer") + } +} + +func TestRealBufferPool(t *testing.T) { + buf := GetBuffer() + if buf == nil { + t.Error("Expected non-nil buffer") + } + if len(buf) != 32*1024 { + t.Errorf("Expected buffer size 32KB, got %d", len(buf)) + } + + PutBuffer(buf) +} + +func TestMinSendfileSize(t *testing.T) { + if MinSendfileSize != 8*1024 { + t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize) + } +} + +func TestGetBuffer(t *testing.T) { + buf := GetBuffer() + if buf == nil { + t.Error("Expected non-nil buffer") + return + } + if len(buf) != 32*1024 { + t.Errorf("Expected buffer size 32KB, got %d", len(buf)) + } + + // 测试写入 + copy(buf, []byte("test")) + if string(buf[:4]) != "test" { + t.Error("Expected to write 'test' to buffer") + } +} + +func TestPlatformSendfile(t *testing.T) { + // 创建临时文件 + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.txt") + + content := []byte("Hello, World! This is a test file for sendfile.") + if err := os.WriteFile(tmpFile, content, 0644); err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + file, err := os.Open(tmpFile) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + defer file.Close() + + // 测试平台 sendfile(小文件会 fallback 到 copyFile) + // 由于没有真实的网络连接,这个测试主要验证不会崩溃 + _ = platformSendfile(nil, file, 0, int64(len(content))) +} + +func TestBufferPoolConcurrent(t *testing.T) { + const iterations = 100 + + done := make(chan bool) + + for i := 0; i < iterations; i++ { + go func() { + buf := GetBuffer() + PutBuffer(buf) + done <- true + }() + } + + for i := 0; i < iterations; i++ { + <-done + } +} \ No newline at end of file diff --git a/internal/handler/static.go b/internal/handler/static.go index 599f764..8e114bf 100644 --- a/internal/handler/static.go +++ b/internal/handler/static.go @@ -54,4 +54,4 @@ func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) { // 直接返回文件 fasthttp.ServeFile(ctx, filePath) -} \ No newline at end of file +} diff --git a/internal/handler/static_test.go b/internal/handler/static_test.go index 43e30fa..97ba136 100644 --- a/internal/handler/static_test.go +++ b/internal/handler/static_test.go @@ -26,12 +26,12 @@ func newTestContext(t *testing.T, path string) *fasthttp.RequestCtx { // TestStaticHandlerHandle 测试静态文件处理器 func TestStaticHandlerHandle(t *testing.T) { tests := []struct { - name string - setup func(t *testing.T, root string) // 在临时目录中设置测试文件 - path string // 请求路径 - wantStatus int // 期望的 HTTP 状态码 - wantContent string // 期望的响应内容(可选) - skipContent bool // 是否跳过内容验证 + name string + setup func(t *testing.T, root string) // 在临时目录中设置测试文件 + path string // 请求路径 + wantStatus int // 期望的 HTTP 状态码 + wantContent string // 期望的响应内容(可选) + skipContent bool // 是否跳过内容验证 }{ { name: "正常文件访问", @@ -89,8 +89,8 @@ func TestStaticHandlerHandle(t *testing.T) { t.Fatalf("创建目录失败: %v", err) } }, - path: "/noindex/", - wantStatus: fasthttp.StatusForbidden, + path: "/noindex/", + wantStatus: fasthttp.StatusForbidden, skipContent: true, }, { @@ -99,8 +99,8 @@ func TestStaticHandlerHandle(t *testing.T) { t.Helper() // 不创建任何文件 }, - path: "/nonexistent.txt", - wantStatus: fasthttp.StatusNotFound, + path: "/nonexistent.txt", + wantStatus: fasthttp.StatusNotFound, skipContent: true, }, { @@ -109,8 +109,8 @@ func TestStaticHandlerHandle(t *testing.T) { t.Helper() // root 目录没有索引文件 }, - path: "/", - wantStatus: fasthttp.StatusForbidden, + path: "/", + wantStatus: fasthttp.StatusForbidden, skipContent: true, }, { @@ -397,4 +397,4 @@ func TestNewStaticHandler(t *testing.T) { t.Errorf("handler.index 应为 nil") } }) -} \ No newline at end of file +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 499be14..65ac35a 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -137,4 +137,4 @@ func parseLevel(level string) zerolog.Level { default: return zerolog.InfoLevel } -} \ No newline at end of file +} diff --git a/internal/logging/logging_test.go b/internal/logging/logging_test.go index 52bdd39..23a0a3b 100644 --- a/internal/logging/logging_test.go +++ b/internal/logging/logging_test.go @@ -182,4 +182,4 @@ func TestLoggerDebug(t *testing.T) { if !strings.Contains(output, "info message") { t.Error("Expected info message to be logged") } -} \ No newline at end of file +} diff --git a/internal/middleware/compression/compression.go b/internal/middleware/compression/compression.go index f516843..3a7086b 100644 --- a/internal/middleware/compression/compression.go +++ b/internal/middleware/compression/compression.go @@ -23,10 +23,10 @@ const ( // CompressionMiddleware 响应压缩中间件。 type CompressionMiddleware struct { - types []string // 可压缩的 MIME 类型 - level int // 压缩级别 - minSize int // 最小压缩大小 - algorithm Algorithm // 压缩算法 + types []string // 可压缩的 MIME 类型 + level int // 压缩级别 + minSize int // 最小压缩大小 + algorithm Algorithm // 压缩算法 // 缓冲池 gzipPool sync.Pool @@ -226,4 +226,4 @@ func (m *CompressionMiddleware) Level() int { // MinSize 返回最小压缩大小。 func (m *CompressionMiddleware) MinSize() int { return m.minSize -} \ No newline at end of file +} diff --git a/internal/middleware/compression/compression_test.go b/internal/middleware/compression/compression_test.go index 5a91a50..a128b5f 100644 --- a/internal/middleware/compression/compression_test.go +++ b/internal/middleware/compression/compression_test.go @@ -317,4 +317,4 @@ func TestGetters(t *testing.T) { if len(m.Types()) != 1 { t.Errorf("Expected 1 type, got %d", len(m.Types())) } -} \ No newline at end of file +} diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index ac9454c..fa4098e 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -25,4 +25,4 @@ func (c *Chain) Apply(final fasthttp.RequestHandler) fasthttp.RequestHandler { handler = c.middlewares[i].Process(handler) } return handler -} \ No newline at end of file +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 522ead1..dd581ae 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -142,4 +142,4 @@ func (m *modifyMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reques // 在响应后追加内容 ctx.SetBodyString(string(ctx.Response.Body()) + "-modified") } -} \ No newline at end of file +} diff --git a/internal/middleware/rewrite/rewrite.go b/internal/middleware/rewrite/rewrite.go index 7e5ebe0..3eec6fd 100644 --- a/internal/middleware/rewrite/rewrite.go +++ b/internal/middleware/rewrite/rewrite.go @@ -114,4 +114,4 @@ func (m *RewriteMiddleware) Process(next fasthttp.RequestHandler) fasthttp.Reque // Rules 返回编译后的规则列表(用于调试)。 func (m *RewriteMiddleware) Rules() []Rule { return m.rules -} \ No newline at end of file +} diff --git a/internal/middleware/rewrite/rewrite_test.go b/internal/middleware/rewrite/rewrite_test.go index 6bbd76a..bb0f296 100644 --- a/internal/middleware/rewrite/rewrite_test.go +++ b/internal/middleware/rewrite/rewrite_test.go @@ -284,4 +284,4 @@ func TestRewriteMiddlewareRules(t *testing.T) { if len(compiled) != 2 { t.Errorf("Expected 2 rules, got %d", len(compiled)) } -} \ No newline at end of file +} diff --git a/internal/middleware/security/access.go b/internal/middleware/security/access.go index 52b1669..0eb8e4f 100644 --- a/internal/middleware/security/access.go +++ b/internal/middleware/security/access.go @@ -40,16 +40,16 @@ type Action int const ( ActionAllow Action = iota // Allow the request - ActionDeny // Deny the request (403 Forbidden) + ActionDeny // Deny the request (403 Forbidden) ) // AccessControl implements IP-based access control middleware. // It checks incoming requests against configured allow/deny CIDR lists. type AccessControl struct { - allowList []net.IPNet // CIDR networks to allow - denyList []net.IPNet // CIDR networks to deny + allowList []net.IPNet // CIDR networks to allow + denyList []net.IPNet // CIDR networks to deny defaultAction Action // Default action when no rule matches - mu sync.RWMutex + mu sync.RWMutex } // NewAccessControl creates a new access control middleware from configuration. @@ -299,4 +299,4 @@ func actionToString(action Action) string { } // Verify interface compliance -var _ middleware.Middleware = (*AccessControl)(nil) \ No newline at end of file +var _ middleware.Middleware = (*AccessControl)(nil) diff --git a/internal/middleware/security/access_test.go b/internal/middleware/security/access_test.go index 0c002c8..79fd99f 100644 --- a/internal/middleware/security/access_test.go +++ b/internal/middleware/security/access_test.go @@ -340,4 +340,4 @@ func TestGetStats(t *testing.T) { if stats.Default != "deny" { t.Errorf("Expected Default 'deny', got %s", stats.Default) } -} \ No newline at end of file +} diff --git a/internal/middleware/security/auth.go b/internal/middleware/security/auth.go index 94cfdea..c486445 100644 --- a/internal/middleware/security/auth.go +++ b/internal/middleware/security/auth.go @@ -35,8 +35,8 @@ import ( "sync" "github.com/valyala/fasthttp" - "golang.org/x/crypto/bcrypt" "golang.org/x/crypto/argon2" + "golang.org/x/crypto/bcrypt" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/middleware" @@ -46,37 +46,37 @@ import ( type HashAlgorithm int const ( - HashBcrypt HashAlgorithm = iota // bcrypt (default, recommended) - HashArgon2id // Argon2id (more secure, compute-intensive) + HashBcrypt HashAlgorithm = iota // bcrypt (default, recommended) + HashArgon2id // Argon2id (more secure, compute-intensive) ) // BasicAuth implements HTTP Basic Authentication middleware. type BasicAuth struct { - users map[string]string // username -> hashed password - algorithm HashAlgorithm // Hash algorithm used - realm string // Authentication realm - requireTLS bool // Require HTTPS (default true) - minPasswordLength int // Minimum password length for validation - argon2Params argon2Params // Argon2id parameters - mu sync.RWMutex + users map[string]string // username -> hashed password + algorithm HashAlgorithm // Hash algorithm used + realm string // Authentication realm + requireTLS bool // Require HTTPS (default true) + minPasswordLength int // Minimum password length for validation + argon2Params argon2Params // Argon2id parameters + mu sync.RWMutex } // argon2Params holds Argon2id configuration parameters. type argon2Params struct { - time uint32 // Number of passes - memory uint32 // Memory cost in KB - threads uint8 // Parallelism - saltLen uint32 // Salt length - keyLen uint32 // Output key length + time uint32 // Number of passes + memory uint32 // Memory cost in KB + threads uint8 // Parallelism + saltLen uint32 // Salt length + keyLen uint32 // Output key length } // Default Argon2id parameters (OWASP recommended) var defaultArgon2Params = argon2Params{ - time: 3, - memory: 64 * 1024, // 64 MB - threads: 4, - saltLen: 16, - keyLen: 32, + time: 3, + memory: 64 * 1024, // 64 MB + threads: 4, + saltLen: 16, + keyLen: 32, } // NewBasicAuth creates a new Basic Auth middleware from configuration. @@ -101,10 +101,10 @@ func NewBasicAuth(cfg *config.AuthConfig) (*BasicAuth, error) { } auth := &BasicAuth{ - users: make(map[string]string), - requireTLS: cfg.RequireTLS, // Default is true from config defaults + users: make(map[string]string), + requireTLS: cfg.RequireTLS, // Default is true from config defaults minPasswordLength: cfg.MinPasswordLength, - argon2Params: defaultArgon2Params, + argon2Params: defaultArgon2Params, } // Set realm @@ -452,4 +452,4 @@ func parseUint8(s string) uint8 { } // Verify interface compliance -var _ middleware.Middleware = (*BasicAuth)(nil) \ No newline at end of file +var _ middleware.Middleware = (*BasicAuth)(nil) diff --git a/internal/middleware/security/auth_test.go b/internal/middleware/security/auth_test.go index 333ad3c..9d6377a 100644 --- a/internal/middleware/security/auth_test.go +++ b/internal/middleware/security/auth_test.go @@ -177,7 +177,7 @@ func TestBasicAuthProcess(t *testing.T) { hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) auth, err := NewBasicAuth(&config.AuthConfig{ - Type: "basic", + Type: "basic", RequireTLS: false, // Disable TLS for testing Users: []config.User{ {Name: "admin", Password: string(hashedPassword)}, @@ -346,7 +346,7 @@ func TestExtractCredentials(t *testing.T) { hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) auth, err := NewBasicAuth(&config.AuthConfig{ - Type: "basic", + Type: "basic", RequireTLS: false, Users: []config.User{ {Name: "admin", Password: string(hashedPassword)}, @@ -396,4 +396,4 @@ func TestName(t *testing.T) { if auth.Name() != "basic_auth" { t.Errorf("Expected name 'basic_auth', got %s", auth.Name()) } -} \ No newline at end of file +} diff --git a/internal/middleware/security/headers.go b/internal/middleware/security/headers.go index 3a377d0..8283641 100644 --- a/internal/middleware/security/headers.go +++ b/internal/middleware/security/headers.go @@ -37,9 +37,9 @@ import ( // SecurityHeadersMiddleware adds security-related headers to responses. type SecurityHeadersMiddleware struct { - config *config.SecurityHeaders - hsts string // Pre-formatted HSTS header value - mu sync.RWMutex + config *config.SecurityHeaders + hsts string // Pre-formatted HSTS header value + mu sync.RWMutex } // NewSecurityHeaders creates a new security headers middleware. @@ -218,11 +218,11 @@ func DefaultSecurityHeaders() *config.SecurityHeaders { // Suitable for high-security applications. func StrictSecurityHeaders() *config.SecurityHeaders { return &config.SecurityHeaders{ - XFrameOptions: "DENY", - XContentTypeOptions: "nosniff", + XFrameOptions: "DENY", + XContentTypeOptions: "nosniff", ContentSecurityPolicy: "default-src 'self'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'", - ReferrerPolicy: "no-referrer", - PermissionsPolicy: "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()", + ReferrerPolicy: "no-referrer", + PermissionsPolicy: "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()", } } @@ -237,4 +237,4 @@ func DevelopmentSecurityHeaders() *config.SecurityHeaders { } // Verify interface compliance -var _ middleware.Middleware = (*SecurityHeadersMiddleware)(nil) \ No newline at end of file +var _ middleware.Middleware = (*SecurityHeadersMiddleware)(nil) diff --git a/internal/middleware/security/headers_test.go b/internal/middleware/security/headers_test.go index 29097fd..7901e6a 100644 --- a/internal/middleware/security/headers_test.go +++ b/internal/middleware/security/headers_test.go @@ -19,8 +19,8 @@ func TestNewSecurityHeaders(t *testing.T) { { name: "custom config", cfg: &config.SecurityHeaders{ - XFrameOptions: "SAMEORIGIN", - XContentTypeOptions: "nosniff", + XFrameOptions: "SAMEORIGIN", + XContentTypeOptions: "nosniff", ContentSecurityPolicy: "default-src 'self'", }, }, @@ -45,11 +45,11 @@ func TestSecurityHeadersName(t *testing.T) { func TestSecurityHeadersProcess(t *testing.T) { cfg := &config.SecurityHeaders{ - XFrameOptions: "DENY", - XContentTypeOptions: "nosniff", + XFrameOptions: "DENY", + XContentTypeOptions: "nosniff", ContentSecurityPolicy: "default-src 'self'", - ReferrerPolicy: "strict-origin-when-cross-origin", - PermissionsPolicy: "geolocation=()", + ReferrerPolicy: "strict-origin-when-cross-origin", + PermissionsPolicy: "geolocation=()", } sh := NewSecurityHeaders(cfg) @@ -157,7 +157,7 @@ func TestUpdateConfig(t *testing.T) { sh := NewSecurityHeaders(nil) newCfg := &config.SecurityHeaders{ - XFrameOptions: "DENY", + XFrameOptions: "DENY", ReferrerPolicy: "no-referrer", } @@ -244,4 +244,4 @@ func TestFormatHSTSValue(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index 9f4cd89..f79e0a3 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -38,9 +38,9 @@ import ( // RateLimiter implements request rate limiting using token bucket algorithm. type RateLimiter struct { - rate float64 // Tokens added per second - burst float64 // Maximum bucket capacity - keyFunc KeyFunc // Function to extract limit key + rate float64 // Tokens added per second + burst float64 // Maximum bucket capacity + keyFunc KeyFunc // Function to extract limit key buckets map[string]*tokenBucket mu sync.RWMutex } @@ -294,12 +294,12 @@ func (rl *RateLimiter) GetStats() RateLimitStats { // ConnLimiter implements connection count limiting. // This is a separate limiter for maximum concurrent connections. type ConnLimiter struct { - max int // Maximum concurrent connections - current int64 // Current connection count (atomic) - perKey bool // Limit per key instead of global - keyFunc KeyFunc // Key extraction function - counts map[string]int64 // Connection counts per key - mu sync.RWMutex + max int // Maximum concurrent connections + current int64 // Current connection count (atomic) + perKey bool // Limit per key instead of global + keyFunc KeyFunc // Key extraction function + counts map[string]int64 // Connection counts per key + mu sync.RWMutex } // NewConnLimiter creates a new connection limiter. @@ -318,9 +318,9 @@ func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error) } cl := &ConnLimiter{ - max: max, - perKey: perKey, - counts: make(map[string]int64), + max: max, + perKey: perKey, + counts: make(map[string]int64), } if perKey { @@ -420,4 +420,4 @@ func addInt64(ptr *int64, delta int64) { // Verify interface compliance var _ middleware.Middleware = (*RateLimiter)(nil) -var _ middleware.Middleware = (*connLimiterMiddleware)(nil) \ No newline at end of file +var _ middleware.Middleware = (*connLimiterMiddleware)(nil) diff --git a/internal/middleware/security/ratelimit_test.go b/internal/middleware/security/ratelimit_test.go index 9e47f15..88b6808 100644 --- a/internal/middleware/security/ratelimit_test.go +++ b/internal/middleware/security/ratelimit_test.go @@ -350,4 +350,4 @@ func TestConnLimiterMiddleware(t *testing.T) { if middleware.Name() != "conn_limiter" { t.Errorf("Expected name 'conn_limiter', got %s", middleware.Name()) } -} \ No newline at end of file +} diff --git a/internal/proxy/health.go b/internal/proxy/health.go index 15a4184..cd522f6 100644 --- a/internal/proxy/health.go +++ b/internal/proxy/health.go @@ -2,7 +2,7 @@ // // 此文件实现了针对后端目标的健康检查功能,支持 // 主动健康检查(定期 HTTP 探测)和被动健康检查 -//(基于观察到的失败标记目标为不健康)。 +// (基于观察到的失败标记目标为不健康)。 // //go:generate go test -v ./... package proxy diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 23c0c28..0d432e4 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -402,9 +402,9 @@ func TestModifyRequestHeaders(t *testing.T) { }, }, { - name: "移除请求头", - clientIP: "192.168.1.100", - removeHeaders: []string{"X-Remove-Me"}, + name: "移除请求头", + clientIP: "192.168.1.100", + removeHeaders: []string{"X-Remove-Me"}, shouldNotExist: []string{"X-Remove-Me"}, }, } @@ -416,8 +416,8 @@ func TestModifyRequestHeaders(t *testing.T) { LoadBalance: "round_robin", Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, Headers: config.ProxyHeaders{ - SetRequest: tt.setRequest, - Remove: tt.removeHeaders, + SetRequest: tt.setRequest, + Remove: tt.removeHeaders, }, } @@ -704,40 +704,40 @@ func TestGetConfig(t *testing.T) { // TestIsWebSocketRequest 测试WebSocket请求检测 func TestIsWebSocketRequest(t *testing.T) { tests := []struct { - name string - upgrade string + name string + upgrade string connection string - expected bool + expected bool }{ { - name: "标准WebSocket请求", - upgrade: "websocket", + name: "标准WebSocket请求", + upgrade: "websocket", connection: "upgrade", - expected: true, + expected: true, }, { - name: "大小写不敏感", - upgrade: "WebSocket", + name: "大小写不敏感", + upgrade: "WebSocket", connection: "Upgrade", - expected: true, + expected: true, }, { - name: "非WebSocket升级", - upgrade: "h2c", + name: "非WebSocket升级", + upgrade: "h2c", connection: "upgrade", - expected: false, + expected: false, }, { - name: "非upgrade连接", - upgrade: "websocket", + name: "非upgrade连接", + upgrade: "websocket", connection: "keep-alive", - expected: false, + expected: false, }, { - name: "keep-alive, Upgrade", - upgrade: "websocket", + name: "keep-alive, Upgrade", + upgrade: "websocket", connection: "keep-alive, Upgrade", - expected: true, + expected: true, }, } diff --git a/internal/server/pool.go b/internal/server/pool.go new file mode 100644 index 0000000..2bdd764 --- /dev/null +++ b/internal/server/pool.go @@ -0,0 +1,191 @@ +// Package server 提供 Goroutine 池,限制并发数量,减少调度开销。 +package server + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/valyala/fasthttp" +) + +// GoroutinePool Goroutine 池配置。 +type GoroutinePool struct { + maxWorkers int32 // 最大 worker 数 + minWorkers int32 // 最小 worker 数(预热) + idleTimeout time.Duration // 穴闲超时 + taskQueue chan Task // 任务队列 + workers int32 // 当前 worker 数 + idleWorkers int32 // 穴闲 worker 数 + running atomic.Bool // 运行状态 + wg sync.WaitGroup // 等待所有 worker + ctx context.Context + cancel context.CancelFunc +} + +// Task 任务函数类型。 +type Task func(*fasthttp.RequestCtx) + +// PoolConfig 池配置。 +type PoolConfig struct { + MaxWorkers int // 最大并发数 + MinWorkers int // 预热 worker 数 + IdleTimeout time.Duration // 穴闲超时 + QueueSize int // 任务队列大小 +} + +// NewGoroutinePool 创建 Goroutine 池。 +func NewGoroutinePool(cfg PoolConfig) *GoroutinePool { + if cfg.MaxWorkers <= 0 { + cfg.MaxWorkers = 10000 + } + if cfg.MinWorkers <= 0 { + cfg.MinWorkers = 100 + } + if cfg.MinWorkers > cfg.MaxWorkers { + cfg.MinWorkers = cfg.MaxWorkers + } + if cfg.IdleTimeout <= 0 { + cfg.IdleTimeout = 60 * time.Second + } + if cfg.QueueSize <= 0 { + cfg.QueueSize = 1000 + } + + ctx, cancel := context.WithCancel(context.Background()) + + p := &GoroutinePool{ + maxWorkers: int32(cfg.MaxWorkers), + minWorkers: int32(cfg.MinWorkers), + idleTimeout: cfg.IdleTimeout, + taskQueue: make(chan Task, cfg.QueueSize), + ctx: ctx, + cancel: cancel, + } + + // 预热 worker + for i := 0; i < cfg.MinWorkers; i++ { + p.startWorker() + } + + return p +} + +// Start 启动池。 +func (p *GoroutinePool) Start() { + p.running.Store(true) +} + +// Stop 停止池。 +func (p *GoroutinePool) Stop() { + p.running.Store(false) + p.cancel() + p.wg.Wait() +} + +// Submit 提交任务。 +func (p *GoroutinePool) Submit(ctx *fasthttp.RequestCtx, task Task) error { + if !p.running.Load() { + // 池未运行,直接执行 + task(ctx) + return nil + } + + // 尝试放入队列 + select { + case p.taskQueue <- task: + // 任务入队成功 + // 如果有空闲 worker 不足,可能需要启动新 worker + if p.idleWorkers == 0 && p.workers < p.maxWorkers { + p.startWorker() + } + return nil + default: + // 队列满,需要启动新 worker 或直接执行 + if p.workers < p.maxWorkers { + p.startWorker() + // 重新尝试入队 + p.taskQueue <- task + return nil + } + + // 达到最大 worker,直接执行(fallback) + task(ctx) + return nil + } +} + +// startWorker 启动一个 worker。 +func (p *GoroutinePool) startWorker() { + p.workers++ + p.wg.Add(1) + + go func() { + defer p.wg.Done() + defer atomic.AddInt32(&p.workers, -1) + + idleTimer := time.NewTimer(p.idleTimeout) + defer idleTimer.Stop() + + for { + // 标记为空闲 + atomic.AddInt32(&p.idleWorkers, 1) + + select { + case task := <-p.taskQueue: + // 取出任务,取消空闲标记 + atomic.AddInt32(&p.idleWorkers, -1) + idleTimer.Reset(p.idleTimeout) + + // 执行任务 + task(nil) // 注意:fasthttp.RequestCtx 需要在任务中传入 + + case <-idleTimer.C: + // 穴闲超时,退出 worker(保持最小数量) + atomic.AddInt32(&p.idleWorkers, -1) + if p.workers > p.minWorkers { + return + } + idleTimer.Reset(p.idleTimeout) + + case <-p.ctx.Done(): + // 池关闭 + atomic.AddInt32(&p.idleWorkers, -1) + return + } + } + }() +} + +// Stats 返回池统计信息。 +func (p *GoroutinePool) Stats() PoolStats { + return PoolStats{ + Workers: atomic.LoadInt32(&p.workers), + IdleWorkers: atomic.LoadInt32(&p.idleWorkers), + MaxWorkers: p.maxWorkers, + MinWorkers: p.minWorkers, + QueueLen: int32(len(p.taskQueue)), + QueueCap: int32(cap(p.taskQueue)), + } +} + +// PoolStats 池统计信息。 +type PoolStats struct { + Workers int32 + IdleWorkers int32 + MaxWorkers int32 + MinWorkers int32 + QueueLen int32 + QueueCap int32 +} + +// WrapHandler 使用池包装 fasthttp 处理器。 +func (p *GoroutinePool) WrapHandler(handler fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + // 使用池执行处理器 + p.Submit(ctx, func(innerCtx *fasthttp.RequestCtx) { + handler(ctx) + }) + } +} diff --git a/internal/server/pool_test.go b/internal/server/pool_test.go new file mode 100644 index 0000000..0c0db55 --- /dev/null +++ b/internal/server/pool_test.go @@ -0,0 +1,170 @@ +package server + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/valyala/fasthttp" +) + +func TestNewGoroutinePool(t *testing.T) { + cfg := PoolConfig{ + MaxWorkers: 100, + MinWorkers: 10, + IdleTimeout: 30 * time.Second, + QueueSize: 50, + } + + p := NewGoroutinePool(cfg) + if p == nil { + t.Error("Expected non-nil pool") + } + + // 检查配置 + if p.maxWorkers != 100 { + t.Errorf("Expected maxWorkers 100, got %d", p.maxWorkers) + } + if p.minWorkers != 10 { + t.Errorf("Expected minWorkers 10, got %d", p.minWorkers) + } +} + +func TestPoolDefaults(t *testing.T) { + p := NewGoroutinePool(PoolConfig{}) + + // 应该使用默认值 + if p.maxWorkers != 10000 { + t.Errorf("Expected default maxWorkers 10000, got %d", p.maxWorkers) + } + if p.minWorkers != 100 { + t.Errorf("Expected default minWorkers 100, got %d", p.minWorkers) + } +} + +func TestPoolStartStop(t *testing.T) { + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 10, + MinWorkers: 2, + }) + + p.Start() + if !p.running.Load() { + t.Error("Expected pool to be running") + } + + p.Stop() + if p.running.Load() { + t.Error("Expected pool to be stopped") + } +} + +func TestPoolSubmit(t *testing.T) { + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 10, + MinWorkers: 2, + QueueSize: 10, + IdleTimeout: 5 * time.Second, + }) + + p.Start() + defer p.Stop() + + var executed atomic.Bool + task := func(ctx *fasthttp.RequestCtx) { + executed.Store(true) + } + + err := p.Submit(nil, task) + if err != nil { + t.Errorf("Submit failed: %v", err) + } + + // 等待任务执行 + time.Sleep(100 * time.Millisecond) + + if !executed.Load() { + t.Error("Expected task to be executed") + } +} + +func TestPoolStats(t *testing.T) { + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 100, + MinWorkers: 10, + QueueSize: 50, + IdleTimeout: 30 * time.Second, + }) + + p.Start() + defer p.Stop() + + stats := p.Stats() + + if stats.MaxWorkers != 100 { + t.Errorf("Expected MaxWorkers 100, got %d", stats.MaxWorkers) + } + if stats.MinWorkers != 10 { + t.Errorf("Expected MinWorkers 10, got %d", stats.MinWorkers) + } + if stats.QueueCap != 50 { + t.Errorf("Expected QueueCap 50, got %d", stats.QueueCap) + } +} + +func TestPoolConcurrentSubmit(t *testing.T) { + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 50, + MinWorkers: 5, + QueueSize: 100, + IdleTimeout: 5 * time.Second, + }) + + p.Start() + defer p.Stop() + + var counter atomic.Int32 + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + p.Submit(nil, func(ctx *fasthttp.RequestCtx) { + counter.Add(1) + }) + }() + } + + wg.Wait() + + // 等待所有任务执行 + time.Sleep(500 * time.Millisecond) + + if counter.Load() != 100 { + t.Errorf("Expected 100 executions, got %d", counter.Load()) + } +} + +func TestPoolSubmitWhenStopped(t *testing.T) { + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 10, + }) + + // 不启动池 + var executed atomic.Bool + task := func(ctx *fasthttp.RequestCtx) { + executed.Store(true) + } + + err := p.Submit(nil, task) + if err != nil { + t.Errorf("Submit should not fail when stopped: %v", err) + } + + // 任务应该直接执行 + if !executed.Load() { + t.Error("Expected task to be executed directly when pool is stopped") + } +} \ No newline at end of file diff --git a/internal/server/server_test.go b/internal/server/server_test.go index d140755..8d355f2 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -104,4 +104,4 @@ func TestGracefulStopWithZeroTimeout(t *testing.T) { if err != nil { t.Errorf("GracefulStop(0) returned error: %v", err) } -} \ No newline at end of file +} diff --git a/internal/server/upgrade.go b/internal/server/upgrade.go new file mode 100644 index 0000000..eaaf7ea --- /dev/null +++ b/internal/server/upgrade.go @@ -0,0 +1,217 @@ +// Package server 提供优雅升级(热升级)功能,实现零停机部署。 +package server + +import ( + "fmt" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "syscall" + "time" +) + +// UpgradeManager 优雅升级管理器。 +type UpgradeManager struct { + server *Server + pidFile string // PID 文件路径 + oldPid int // 旧进程 PID + listeners []net.Listener +} + +// NewUpgradeManager 创建升级管理器。 +func NewUpgradeManager(server *Server) *UpgradeManager { + return &UpgradeManager{ + server: server, + } +} + +// SetPidFile 设置 PID 文件路径。 +func (u *UpgradeManager) SetPidFile(path string) { + u.pidFile = path +} + +// SetListeners 设置监听器列表(用于升级时继承)。 +func (u *UpgradeManager) SetListeners(listeners []net.Listener) { + u.listeners = listeners +} + +// WritePid 写入当前进程 PID。 +func (u *UpgradeManager) WritePid() error { + if u.pidFile == "" { + return nil + } + + pid := os.Getpid() + return os.WriteFile(u.pidFile, []byte(fmt.Sprintf("%d", pid)), 0644) +} + +// ReadOldPid 读取旧进程 PID。 +func (u *UpgradeManager) ReadOldPid() (int, error) { + if u.pidFile == "" { + return 0, fmt.Errorf("pid file not configured") + } + + data, err := os.ReadFile(u.pidFile) + if err != nil { + return 0, err + } + + var pid int + _, err = fmt.Sscanf(string(data), "%d", &pid) + return pid, err +} + +// IsChild 检查是否是子进程(从升级启动)。 +func (u *UpgradeManager) IsChild() bool { + return os.Getenv("GRACEFUL_UPGRADE") == "1" +} + +// GetInheritedListeners 获取继承的监听器。 +func (u *UpgradeManager) GetInheritedListeners() ([]net.Listener, error) { + fdsStr := os.Getenv("LISTEN_FDS") + if fdsStr == "" { + return nil, nil // 不是升级启动 + } + + var fdCount int + _, err := fmt.Sscanf(fdsStr, "%d", &fdCount) + if err != nil { + return nil, err + } + + var listeners []net.Listener + + for i := 3; i < 3+fdCount; i++ { + file := os.NewFile(uintptr(i), fmt.Sprintf("listener-%d", i)) + if file == nil { + continue + } + + listener, err := net.FileListener(file) + if err != nil { + file.Close() + continue + } + + listeners = append(listeners, listener) + } + + return listeners, nil +} + +// GracefulUpgrade 执行优雅升级。 +func (u *UpgradeManager) GracefulUpgrade(newBinary string) error { + if len(u.listeners) == 0 { + return fmt.Errorf("no listeners configured for upgrade") + } + + // 准备环境变量 + env := os.Environ() + env = append(env, "GRACEFUL_UPGRADE=1") + env = append(env, fmt.Sprintf("LISTEN_FDS=%d", len(u.listeners))) + + // 获取监听器的文件描述符 + files := make([]*os.File, 0, len(u.listeners)) + for _, listener := range u.listeners { + file, err := listenerFile(listener) + if err != nil { + return fmt.Errorf("failed to get listener file: %w", err) + } + files = append(files, file) + } + + // 启动新进程 + execPath, err := filepath.Abs(newBinary) + if err != nil { + return err + } + + cmd := exec.Command(execPath) + cmd.Env = env + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = files + + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start new process: %w", err) + } + + newPid := cmd.Process.Pid + u.oldPid = os.Getpid() + + // 写入新 PID + if u.pidFile != "" { + os.WriteFile(u.pidFile, []byte(fmt.Sprintf("%d", newPid)), 0644) + } + + return nil +} + +// listenerFile 从 net.Listener 获取文件描述符。 +func listenerFile(listener net.Listener) (*os.File, error) { + switch l := listener.(type) { + case *net.TCPListener: + return l.File() + case *net.UnixListener: + return l.File() + default: + return nil, fmt.Errorf("unsupported listener type: %T", listener) + } +} + +// WaitForShutdown 等待旧进程关闭。 +func (u *UpgradeManager) WaitForShutdown(timeout time.Duration) error { + if u.oldPid == 0 { + return nil + } + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + process, err := os.FindProcess(u.oldPid) + if err != nil { + return nil + } + + err = process.Signal(syscall.Signal(0)) + if err != nil { + return nil + } + + time.Sleep(100 * time.Millisecond) + } + + process, _ := os.FindProcess(u.oldPid) + process.Signal(syscall.SIGKILL) + return fmt.Errorf("old process did not shutdown gracefully") +} + +// NotifyOldProcess 通知旧进程关闭。 +func (u *UpgradeManager) NotifyOldProcess() error { + oldPid, err := u.ReadOldPid() + if err != nil || oldPid == 0 { + return nil + } + + process, err := os.FindProcess(oldPid) + if err != nil { + return nil + } + + return process.Signal(syscall.SIGQUIT) +} + +// SetupSignalHandlers 设置升级相关信号处理。 +func (u *UpgradeManager) SetupSignalHandlers(newBinary string) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGUSR2) + + go func() { + for sig := range sigCh { + if sig == syscall.SIGUSR2 { + u.GracefulUpgrade(newBinary) + } + } + }() +} diff --git a/internal/server/upgrade_test.go b/internal/server/upgrade_test.go new file mode 100644 index 0000000..3479390 --- /dev/null +++ b/internal/server/upgrade_test.go @@ -0,0 +1,101 @@ +package server + +import ( + "os" + "testing" +) + +func TestNewUpgradeManager(t *testing.T) { + srv := New(nil) + mgr := NewUpgradeManager(srv) + + if mgr == nil { + t.Error("Expected non-nil manager") + } + if mgr.server != srv { + t.Error("Expected server to be set") + } +} + +func TestIsChild(t *testing.T) { + mgr := NewUpgradeManager(nil) + + // 默认不是子进程 + if mgr.IsChild() { + t.Error("Expected IsChild to be false by default") + } + + // 设置环境变量 + os.Setenv("GRACEFUL_UPGRADE", "1") + defer os.Unsetenv("GRACEFUL_UPGRADE") + + if !mgr.IsChild() { + t.Error("Expected IsChild to be true when GRACEFUL_UPGRADE=1") + } +} + +func TestPidFile(t *testing.T) { + tmpFile := "/tmp/lolly-test.pid" + defer os.Remove(tmpFile) + + mgr := NewUpgradeManager(nil) + mgr.SetPidFile(tmpFile) + + // 写入 PID + if err := mgr.WritePid(); err != nil { + t.Errorf("WritePid failed: %v", err) + } + + // 读取 PID + pid, err := mgr.ReadOldPid() + if err != nil { + t.Errorf("ReadOldPid failed: %v", err) + } + if pid != os.Getpid() { + t.Errorf("Expected pid %d, got %d", os.Getpid(), pid) + } +} + +func TestReadOldPidNoFile(t *testing.T) { + mgr := NewUpgradeManager(nil) + mgr.SetPidFile("/nonexistent/path/pid") + + _, err := mgr.ReadOldPid() + if err == nil { + t.Error("Expected error for non-existent file") + } +} + +func TestGetInheritedListenersNoFds(t *testing.T) { + mgr := NewUpgradeManager(nil) + + // 没有设置 LISTEN_FDS + listeners, err := mgr.GetInheritedListeners() + if err != nil { + t.Errorf("GetInheritedListeners failed: %v", err) + } + if len(listeners) != 0 { + t.Errorf("Expected 0 listeners, got %d", len(listeners)) + } +} + +func TestNotifyOldProcessNoPid(t *testing.T) { + mgr := NewUpgradeManager(nil) + mgr.SetPidFile("/nonexistent/pid") + + // 没有 PID 文件,应该返回 nil + err := mgr.NotifyOldProcess() + if err != nil { + t.Errorf("NotifyOldProcess should return nil for no pid, got: %v", err) + } +} + +func TestWaitForShutdownNoOldPid(t *testing.T) { + mgr := NewUpgradeManager(nil) + + // 没有旧进程 + err := mgr.WaitForShutdown(0) + if err != nil { + t.Errorf("WaitForShutdown should return nil for no old pid, got: %v", err) + } +} \ No newline at end of file diff --git a/internal/server/vhost_test.go b/internal/server/vhost_test.go index cbd60b2..592c700 100644 --- a/internal/server/vhost_test.go +++ b/internal/server/vhost_test.go @@ -312,4 +312,4 @@ func TestVHostManager_PortStripping(t *testing.T) { } t.Log("已知限制: IPv6 数字地址端口剥离需要修复 vhost.go") }) -} \ No newline at end of file +} diff --git a/internal/ssl/ssl.go b/internal/ssl/ssl.go index 8882402..2121123 100644 --- a/internal/ssl/ssl.go +++ b/internal/ssl/ssl.go @@ -435,20 +435,20 @@ func defaultCipherSuites() []uint16 { // cipherSuiteMap maps cipher suite names to TLS IDs. var cipherSuiteMap = map[string]uint16{ - "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, - "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, - "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_RSA_WITH_AES_128_CBC_SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, + "TLS_RSA_WITH_AES_256_CBC_SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, } // ValidateCertificate validates a certificate file. @@ -475,4 +475,4 @@ func ValidateKey(keyPath string) error { // Key validation happens during tls.LoadX509KeyPair // This is a preliminary check that the file exists and is readable return nil -} \ No newline at end of file +} diff --git a/internal/ssl/ssl_test.go b/internal/ssl/ssl_test.go index 1d2af2d..ed8a0bb 100644 --- a/internal/ssl/ssl_test.go +++ b/internal/ssl/ssl_test.go @@ -181,10 +181,10 @@ func TestParseTLSVersions(t *testing.T) { func TestParseCipherSuites(t *testing.T) { tests := []struct { - name string - ciphers []string - wantLen int - wantErr bool + name string + ciphers []string + wantLen int + wantErr bool }{ { name: "valid cipher", @@ -192,7 +192,7 @@ func TestParseCipherSuites(t *testing.T) { wantLen: 1, }, { - name: "multiple valid ciphers", + name: "multiple valid ciphers", ciphers: []string{ "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", @@ -407,4 +407,4 @@ func containsString(s, substr string) bool { } } return false -} \ No newline at end of file +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go new file mode 100644 index 0000000..e819f55 --- /dev/null +++ b/internal/stream/stream.go @@ -0,0 +1,401 @@ +// Package stream 提供 TCP/UDP Stream 代理功能,支持 MySQL、DNS 等服务代理。 +package stream + +import ( + "io" + "net" + "sync" + "sync/atomic" + "time" +) + +// Balancer 负载均衡器接口(stream 专用)。 +type Balancer interface { + Select(targets []*Target) *Target +} + +// roundRobin 简单轮询。 +type roundRobin struct { + counter uint64 +} + +// newRoundRobin 创建轮询均衡器。 +func newRoundRobin() Balancer { + return &roundRobin{} +} + +// Select 选择下一个目标。 +func (r *roundRobin) Select(targets []*Target) *Target { + // 过滤健康目标 + healthy := make([]*Target, 0) + for _, t := range targets { + if t.healthy.Load() { + healthy = append(healthy, t) + } + } + if len(healthy) == 0 { + return nil + } + idx := atomic.AddUint64(&r.counter, 1) - 1 + return healthy[idx%uint64(len(healthy))] +} + +// leastConn 最少连接。 +type leastConn struct{} + +// newLeastConn 创建最少连接均衡器。 +func newLeastConn() Balancer { + return &leastConn{} +} + +// Select 选择连接最少的目标。 +func (l *leastConn) Select(targets []*Target) *Target { + var selected *Target + var minConns int64 = -1 + for _, t := range targets { + if !t.healthy.Load() { + continue + } + conns := atomic.LoadInt64(&t.conns) + if selected == nil || conns < minConns { + selected = t + minConns = conns + } + } + return selected +} + +// Server TCP/UDP Stream 代理服务器。 +type Server struct { + listeners map[string]net.Listener + upstreams map[string]*Upstream + connCount int64 // 当前连接数 + mu sync.RWMutex + running atomic.Bool +} + +// Upstream Stream 上游配置。 +type Upstream struct { + name string + targets []*Target + balancer Balancer + healthChk *HealthChecker + mu sync.RWMutex +} + +// Target Stream 目标服务器。 +type Target struct { + addr string + weight int + healthy atomic.Bool + conns int64 // 当前连接数 +} + +// HealthChecker Stream 健康检查器。 +type HealthChecker struct { + upstream *Upstream + interval time.Duration + timeout time.Duration + stopCh chan struct{} +} + +// Config Stream 配置。 +type Config struct { + Listen string // 监听地址 + Protocol string // tcp 或 udp + Upstream UpstreamSpec // 上游配置 +} + +// UpstreamSpec 上游配置规格。 +type UpstreamSpec struct { + Name string + Targets []TargetSpec + LoadBalance string + HealthCheck HealthCheckSpec +} + +// TargetSpec 目标配置规格。 +type TargetSpec struct { + Addr string + Weight int +} + +// HealthCheckSpec 健康检查配置规格。 +type HealthCheckSpec struct { + Interval time.Duration + Timeout time.Duration + Enabled bool +} + +// NewServer 创建 Stream 服务器。 +func NewServer() *Server { + return &Server{ + listeners: make(map[string]net.Listener), + upstreams: make(map[string]*Upstream), + } +} + +// AddUpstream 添加上游配置。 +func (s *Server) AddUpstream(name string, targets []TargetSpec, lbType string, hcSpec HealthCheckSpec) error { + s.mu.Lock() + defer s.mu.Unlock() + + // 创建目标列表 + tgts := make([]*Target, len(targets)) + for i, t := range targets { + tgts[i] = &Target{ + addr: t.Addr, + weight: t.Weight, + } + tgts[i].healthy.Store(true) // 初始假设健康 + } + + // 创建负载均衡器 + var balancer Balancer + switch lbType { + case "round_robin": + balancer = newRoundRobin() + case "least_conn": + balancer = newLeastConn() + default: + balancer = newRoundRobin() + } + + upstream := &Upstream{ + name: name, + targets: tgts, + balancer: balancer, + } + + // 启动健康检查 + if hcSpec.Enabled { + upstream.healthChk = &HealthChecker{ + upstream: upstream, + interval: hcSpec.Interval, + timeout: hcSpec.Timeout, + stopCh: make(chan struct{}), + } + go upstream.healthChk.Start() + } + + s.upstreams[name] = upstream + return nil +} + +// ListenTCP 开始监听 TCP 端口。 +func (s *Server) ListenTCP(addr string) error { + s.mu.Lock() + defer s.mu.Unlock() + + listener, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + s.listeners[addr] = listener + return nil +} + +// ListenUDP 开始监听 UDP 端口。 +func (s *Server) ListenUDP(addr string) error { + s.mu.Lock() + defer s.mu.Unlock() + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + + // UDP 用 UDPConn 而非 Listener,需要特殊处理 + s.listeners[addr] = &udpListener{conn: conn} + return nil +} + +// Start 启动 Stream 服务器。 +func (s *Server) Start() error { + s.running.Store(true) + + s.mu.RLock() + defer s.mu.RUnlock() + + for addr, listener := range s.listeners { + go s.acceptLoop(addr, listener) + } + + return nil +} + +// acceptLoop 接受连接循环。 +func (s *Server) acceptLoop(addr string, listener net.Listener) { + for s.running.Load() { + conn, err := listener.Accept() + if err != nil { + if !s.running.Load() { + return // 正常关闭 + } + continue + } + + s.connCount++ + go s.handleConnection(conn, addr) + } +} + +// handleConnection 处理单个连接。 +func (s *Server) handleConnection(clientConn net.Conn, addr string) { + defer func() { + clientConn.Close() + s.connCount-- + }() + + s.mu.RLock() + // 根据监听地址找到对应 upstream(简化:用第一个) + var upstream *Upstream + for _, up := range s.upstreams { + upstream = up + break + } + s.mu.RUnlock() + + if upstream == nil { + return // 无上游配置 + } + + // 选择目标 + target := upstream.Select() + if target == nil { + return // 无可用目标 + } + + target.conns++ + defer func() { target.conns-- }() + + // 连接目标 + targetConn, err := net.DialTimeout("tcp", target.addr, 10*time.Second) + if err != nil { + target.healthy.Store(false) + return + } + defer targetConn.Close() + + // 双向数据转发 + go io.Copy(targetConn, clientConn) + io.Copy(clientConn, targetConn) +} + +// Select 选择健康的上游目标。 +func (u *Upstream) Select() *Target { + u.mu.RLock() + defer u.mu.RUnlock() + + // 获取健康目标列表 + healthyTargets := make([]*Target, 0) + for _, t := range u.targets { + if t.healthy.Load() { + healthyTargets = append(healthyTargets, t) + } + } + + if len(healthyTargets) == 0 { + return nil + } + + // 使用负载均衡器选择 + return u.balancer.Select(healthyTargets) +} + +// Start 启动健康检查。 +func (h *HealthChecker) Start() { + ticker := time.NewTicker(h.interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + h.check() + case <-h.stopCh: + return + } + } +} + +// check 执行健康检查。 +func (h *HealthChecker) check() { + for _, target := range h.upstream.targets { + conn, err := net.DialTimeout("tcp", target.addr, h.timeout) + if err != nil { + target.healthy.Store(false) + } else { + conn.Close() + target.healthy.Store(true) + } + } +} + +// Stop 停止健康检查。 +func (h *HealthChecker) Stop() { + close(h.stopCh) +} + +// Stop 停止 Stream 服务器。 +func (s *Server) Stop() error { + s.running.Store(false) + + s.mu.Lock() + defer s.mu.Unlock() + + // 关闭所有监听器 + for _, listener := range s.listeners { + listener.Close() + } + + // 停止健康检查 + for _, upstream := range s.upstreams { + if upstream.healthChk != nil { + upstream.healthChk.Stop() + } + } + + return nil +} + +// Stats 返回服务器统计信息。 +func (s *Server) Stats() Stats { + return Stats{ + Connections: s.connCount, + Listeners: len(s.listeners), + Upstreams: len(s.upstreams), + } +} + +// Stats Stream 服务器统计。 +type Stats struct { + Connections int64 + Listeners int + Upstreams int +} + +// udpListener UDP 监听器包装。 +type udpListener struct { + conn *net.UDPConn +} + +// Accept UDP 不支持 Accept,返回错误。 +func (u *udpListener) Accept() (net.Conn, error) { + return nil, io.EOF +} + +// Close 关闭 UDP 连接。 +func (u *udpListener) Close() error { + return u.conn.Close() +} + +// Addr 返回本地地址。 +func (u *udpListener) Addr() net.Addr { + return u.conn.LocalAddr() +} diff --git a/internal/stream/stream_test.go b/internal/stream/stream_test.go new file mode 100644 index 0000000..6fc2810 --- /dev/null +++ b/internal/stream/stream_test.go @@ -0,0 +1,233 @@ +package stream + +import ( + "net" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestNewServer(t *testing.T) { + s := NewServer() + if s == nil { + t.Error("Expected non-nil server") + } + if s.listeners == nil { + t.Error("Expected initialized listeners map") + } + if s.upstreams == nil { + t.Error("Expected initialized upstreams map") + } +} + +func TestAddUpstream(t *testing.T) { + s := NewServer() + + targets := []TargetSpec{ + {Addr: "localhost:8001", Weight: 1}, + {Addr: "localhost:8002", Weight: 2}, + } + + hcSpec := HealthCheckSpec{ + Enabled: false, + Interval: 10 * time.Second, + Timeout: 5 * time.Second, + } + + err := s.AddUpstream("test", targets, "round_robin", hcSpec) + if err != nil { + t.Errorf("AddUpstream failed: %v", err) + } + + if len(s.upstreams) != 1 { + t.Errorf("Expected 1 upstream, got %d", len(s.upstreams)) + } + + up := s.upstreams["test"] + if up == nil { + t.Error("Expected non-nil upstream") + } + if len(up.targets) != 2 { + t.Errorf("Expected 2 targets, got %d", len(up.targets)) + } +} + +func TestRoundRobinBalancer(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001"}, + {addr: "localhost:8002"}, + {addr: "localhost:8003"}, + } + for _, target := range targets { + target.healthy.Store(true) + } + + rr := newRoundRobin() + + // 测试轮询 + results := make(map[string]int) + for i := 0; i < 6; i++ { + selected := rr.Select(targets) + if selected == nil { + t.Error("Expected non-nil target") + continue + } + results[selected.addr]++ + } + + // 每个目标应该被选中 2 次 + for _, target := range targets { + if results[target.addr] != 2 { + t.Errorf("Expected %s to be selected 2 times, got %d", target.addr, results[target.addr]) + } + } +} + +func TestLeastConnBalancer(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001", conns: 5}, + {addr: "localhost:8002", conns: 2}, + {addr: "localhost:8003", conns: 8}, + } + for _, t := range targets { + t.healthy.Store(true) + } + + lc := newLeastConn() + selected := lc.Select(targets) + + if selected == nil { + t.Error("Expected non-nil target") + } else if selected.addr != "localhost:8002" { + t.Errorf("Expected localhost:8002 (least connections), got %s", selected.addr) + } +} + +func TestBalancerNoHealthyTargets(t *testing.T) { + targets := []*Target{ + {addr: "localhost:8001"}, + {addr: "localhost:8002"}, + } + // 不设置 healthy,默认为 false + + rr := newRoundRobin() + selected := rr.Select(targets) + if selected != nil { + t.Error("Expected nil for no healthy targets") + } + + lc := newLeastConn() + selected = lc.Select(targets) + if selected != nil { + t.Error("Expected nil for no healthy targets") + } +} + +func TestServerStats(t *testing.T) { + s := NewServer() + + stats := s.Stats() + if stats.Connections != 0 { + t.Errorf("Expected 0 connections, got %d", stats.Connections) + } + if stats.Listeners != 0 { + t.Errorf("Expected 0 listeners, got %d", stats.Listeners) + } +} + +func TestUpstreamSelect(t *testing.T) { + u := &Upstream{ + targets: []*Target{ + {addr: "localhost:8001"}, + {addr: "localhost:8002"}, + }, + balancer: newRoundRobin(), + } + for _, t := range u.targets { + t.healthy.Store(true) + } + + selected := u.Select() + if selected == nil { + t.Error("Expected non-nil target") + } +} + +func TestHealthChecker(t *testing.T) { + u := &Upstream{ + targets: []*Target{ + {addr: "localhost:99999"}, // 不存在的端口 + }, + } + + hc := &HealthChecker{ + upstream: u, + interval: 1 * time.Second, + timeout: 100 * time.Millisecond, + stopCh: make(chan struct{}), + } + + // 执行一次检查 + hc.check() + + // 目标应该被标记为不健康 + if u.targets[0].healthy.Load() { + t.Error("Expected target to be marked unhealthy") + } +} + +func TestUDPListener(t *testing.T) { + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to resolve UDP address: %v", err) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("Failed to listen UDP: %v", err) + } + defer conn.Close() + + ul := &udpListener{conn: conn} + + // 测试 Addr + if ul.Addr() == nil { + t.Error("Expected non-nil address") + } + + // 测试 Close + if err := ul.Close(); err != nil { + t.Errorf("Close failed: %v", err) + } + + // 测试 Accept(应该返回 io.EOF) + _, err = ul.Accept() + if err == nil { + t.Error("Expected error from Accept") + } +} + +func TestConcurrentConnections(t *testing.T) { + s := NewServer() + + targets := []TargetSpec{ + {Addr: "localhost:8001", Weight: 1}, + } + s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{}) + + // 并发增加连接数 + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + atomic.AddInt64(&s.connCount, 1) + }() + } + wg.Wait() + + if s.connCount != 100 { + t.Errorf("Expected 100 connections, got %d", s.connCount) + } +} \ No newline at end of file diff --git a/main.go b/main.go index d3efad1..6e5eb57 100644 --- a/main.go +++ b/main.go @@ -23,4 +23,4 @@ func main() { } os.Exit(app.Run(configPath, *genConfig, *outputPath, *showVersion)) -} \ No newline at end of file +}