From f33117b940e7346688147eb22fc764b953b28ee8 Mon Sep 17 00:00:00 2001 From: xfy Date: Thu, 11 Jun 2026 17:03:17 +0800 Subject: [PATCH] fix(handler,http2,loadbalance,logging,resolver,ssl): fix high severity issues - handler/static.go: add sync.RWMutex to StaticHandler; protect Handle with RLock and all setters with Lock to prevent data races - http2/server.go: delete empty connection slice keys from pool map to prevent memory leak under high client churn - loadbalance/slow_start.go: recreate stopCh in Start() to support Start-Stop-Start cycles; guard double-close in Stop() - resolver/resolver.go: recreate stopCh in Start() to support restart - logging/logging.go: save *os.File handles from getOutput so Close() actually closes log files; exclude os.Stdout/os.Stderr from closing - ssl/session_tickets.go: protect started/rotateTimer access in scheduleRotation with mu; support Start-Stop-Start cycles - ssl/ssl.go: cache parsed default certificate to avoid re-parsing on every TLS handshake for OCSP stapling --- internal/handler/static.go | 23 ++++++++++++++++ internal/http2/server.go | 7 ++++- internal/loadbalance/slow_start.go | 14 +++++++++- internal/logging/logging.go | 13 ++++++++- internal/resolver/resolver.go | 7 +++++ internal/ssl/session_tickets.go | 31 ++++++++++++++++----- internal/ssl/ssl.go | 44 ++++++++++++++++++------------ 7 files changed, 111 insertions(+), 28 deletions(-) diff --git a/internal/handler/static.go b/internal/handler/static.go index 272c493..a926d2c 100644 --- a/internal/handler/static.go +++ b/internal/handler/static.go @@ -23,6 +23,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" "github.com/valyala/fasthttp" @@ -84,6 +85,7 @@ func (h *StaticHandler) statWithCache(filePath string) (os.FileInfo, bool, error // - 大文件(>= 8KB)自动启用零拷贝传输 // - alias 与 root 互斥,同时配置时 alias 优先 type StaticHandler struct { + mu sync.RWMutex // 保护以下字段的并发访问 // 指针类型字段(按大小排列) fileCache *cache.FileCache fileInfoCache *FileInfoCache // FileInfo 缓存,减少 os.Stat 调用 @@ -161,6 +163,8 @@ func NewStaticHandler(root, pathPrefix string, index []string, useSendfile bool) // 参数: // - alias: 路径别名 func (h *StaticHandler) SetAlias(alias string) { + h.mu.Lock() + defer h.mu.Unlock() h.alias = alias if alias != "" { h.root = "" @@ -201,6 +205,8 @@ func (h *StaticHandler) buildFilePath(relPath string) string { // - 仅对小于 1MB 的文件启用缓存 // - 缓存会自动检测文件修改并更新 func (h *StaticHandler) SetFileCache(fc *cache.FileCache) { + h.mu.Lock() + defer h.mu.Unlock() h.fileCache = fc } @@ -217,6 +223,8 @@ func (h *StaticHandler) SetFileCache(fc *cache.FileCache) { // // handler.SetGzipStatic(true, nil, []string{".gz", ".br"}) func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExtensions []string) { + h.mu.Lock() + defer h.mu.Unlock() if enabled { h.gzipStatic = compression.NewGzipStatic(true, h.root, extensions, precompressedExtensions) } @@ -236,6 +244,8 @@ func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExt // // handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil) func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router *Router) { + h.mu.Lock() + defer h.mu.Unlock() h.tryFiles = tryFiles h.tryFilesPass = tryFilesPass h.router = router @@ -249,6 +259,8 @@ func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router // 参数: // - enabled: 是否启用符号链接安全检查 func (h *StaticHandler) SetSymlinkCheck(enabled bool) { + h.mu.Lock() + defer h.mu.Unlock() h.symlinkCheck = enabled } @@ -260,6 +272,8 @@ func (h *StaticHandler) SetSymlinkCheck(enabled bool) { // 参数: // - enabled: 是否启用内部访问限制 func (h *StaticHandler) SetInternal(enabled bool) { + h.mu.Lock() + defer h.mu.Unlock() h.internal = enabled } @@ -271,6 +285,8 @@ func (h *StaticHandler) SetInternal(enabled bool) { // 参数: // - expires: 过期时间字符串 func (h *StaticHandler) SetExpires(expires string) { + h.mu.Lock() + defer h.mu.Unlock() h.expires = expires } @@ -284,6 +300,8 @@ func (h *StaticHandler) SetExpires(expires string) { // - localtime: 使用本地时间 // - exactSize: 显示精确大小 func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exactSize bool) { + h.mu.Lock() + defer h.mu.Unlock() h.autoIndex = enabled h.autoIndexFormat = format h.autoIndexLocaltime = localtime @@ -304,6 +322,8 @@ func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exa // // 默认 TTL 为 5 秒。 func (h *StaticHandler) SetCacheTTL(ttl time.Duration) { + h.mu.Lock() + defer h.mu.Unlock() h.cacheTTL = ttl if h.fileInfoCache != nil { h.fileInfoCache.SetTTL(ttl) @@ -328,6 +348,9 @@ func (h *StaticHandler) SetCacheTTL(ttl time.Duration) { // 7. 大文件使用零拷贝传输 // 8. 读取文件并存入缓存 func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) { + h.mu.RLock() + defer h.mu.RUnlock() + reqPath := string(ctx.Path()) // 检查 internal 限制 diff --git a/internal/http2/server.go b/internal/http2/server.go index 63b0773..9dba7be 100644 --- a/internal/http2/server.go +++ b/internal/http2/server.go @@ -308,7 +308,12 @@ func (p *connectionPool) remove(key string, conn net.Conn) { conns := p.conns[key] for i, c := range conns { if c == conn { - p.conns[key] = append(conns[:i], conns[i+1:]...) + conns = append(conns[:i], conns[i+1:]...) + if len(conns) == 0 { + delete(p.conns, key) + } else { + p.conns[key] = conns + } break } } diff --git a/internal/loadbalance/slow_start.go b/internal/loadbalance/slow_start.go index 1aebf21..52236c7 100644 --- a/internal/loadbalance/slow_start.go +++ b/internal/loadbalance/slow_start.go @@ -94,6 +94,13 @@ func (m *SlowStartManager) Start() { return // 已经在运行 } + // 重建 stopCh 以支持 Start-Stop-Start 周期 + select { + case <-m.stopCh: + m.stopCh = make(chan struct{}) + default: + } + go m.updateLoop() } @@ -102,7 +109,12 @@ func (m *SlowStartManager) Stop() { if !m.running.Swap(false) { return } - close(m.stopCh) + select { + case <-m.stopCh: + // 已经关闭 + default: + close(m.stopCh) + } } // updateLoop 后台更新循环。 diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 58c691f..82786ae 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -94,12 +94,15 @@ func New(cfg *config.LoggingConfig) *Logger { } accessWriter := getOutput(cfg.Access.Path) + errorWriter := getOutput(cfg.Error.Path) logger := &Logger{ accessFormat: cfg.Access.Format, accessWriter: accessWriter, + accessFile: writerFile(accessWriter), + errorFile: writerFile(errorWriter), accessLog: zerolog.New(accessWriter).With().Timestamp().Logger(), - errorLog: zerolog.New(getOutput(cfg.Error.Path)).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(), + errorLog: zerolog.New(errorWriter).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(), } return logger @@ -133,6 +136,14 @@ func getOutput(path string) io.Writer { return f } +// writerFile 从 io.Writer 中提取底层的 *os.File,如果不是或为标准流则返回 nil。 +func writerFile(w io.Writer) *os.File { + if f, ok := w.(*os.File); ok && f != os.Stdout && f != os.Stderr { + return f + } + return nil +} + // LogAccess 记录访问日志(全局实例)。 // // 使用全局 log 实例记录 HTTP 请求的基本信息,包括方法、路径、状态码、 diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 6189ad1..0d0cd33 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -357,6 +357,13 @@ func (r *DNSResolver) Start() error { r.started.Store(true) + // 重建 stopCh 以支持 Start-Stop-Start 周期 + select { + case <-r.stopCh: + r.stopCh = make(chan struct{}) + default: + } + // 启动后台刷新协程 go r.refreshLoop() diff --git a/internal/ssl/session_tickets.go b/internal/ssl/session_tickets.go index 47742ca..8f682ea 100644 --- a/internal/ssl/session_tickets.go +++ b/internal/ssl/session_tickets.go @@ -117,15 +117,21 @@ func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketMan // 必须在调用 GetKeys 之前启动。 func (m *SessionTicketManager) Start() { m.mu.Lock() + defer m.mu.Unlock() if m.started { - m.mu.Unlock() return } m.started = true - m.mu.Unlock() + + // 重建 stopCh 以支持 Start-Stop-Start 周期 + select { + case <-m.stopCh: + m.stopCh = make(chan struct{}) + default: + } // 启动轮换定时器 - m.scheduleRotation() + m.scheduleRotationLocked() } // Stop 停止密钥轮换定时器。 @@ -138,12 +144,17 @@ func (m *SessionTicketManager) Stop() { return } m.started = false - m.mu.Unlock() - - close(m.stopCh) - if m.rotateTimer != nil { m.rotateTimer.Stop() + m.rotateTimer = nil + } + m.mu.Unlock() + + select { + case <-m.stopCh: + // 已经关闭 + default: + close(m.stopCh) } } @@ -230,6 +241,12 @@ func (m *SessionTicketManager) ApplyToTLSConfig(tlsCfg *tls.Config) { // // 使用定时器在指定间隔后执行密钥轮换。 func (m *SessionTicketManager) scheduleRotation() { + m.mu.Lock() + defer m.mu.Unlock() + m.scheduleRotationLocked() +} + +func (m *SessionTicketManager) scheduleRotationLocked() { if !m.started { return } diff --git a/internal/ssl/ssl.go b/internal/ssl/ssl.go index 4a4cfec..5ea9381 100644 --- a/internal/ssl/ssl.go +++ b/internal/ssl/ssl.go @@ -74,6 +74,9 @@ type TLSManager struct { // certificates 解析后的证书映射,用于 OCSP certificates map[string]*x509.Certificate + // defaultCert 默认证书的解析结果,避免每次握手重新解析 + defaultCert *x509.Certificate + // issuers 颁发者证书映射,用于 OCSP issuers map[string]*x509.Certificate @@ -163,27 +166,30 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) { // 解析证书用于 OCSP if len(cert.Certificate) > 0 { parsedCert, err := x509.ParseCertificate(cert.Certificate[0]) - if err == nil && len(parsedCert.OCSPServer) > 0 { - // 存储证书用于 OCSP 查询 - serial := parsedCert.SerialNumber.String() - manager.certificates[serial] = parsedCert + if err == nil { + manager.defaultCert = parsedCert + if len(parsedCert.OCSPServer) > 0 { + // 存储证书用于 OCSP 查询 + serial := parsedCert.SerialNumber.String() + manager.certificates[serial] = parsedCert - // 尝试从证书链解析颁发者证书 - if len(cert.Certificate) > 1 { - issuerCert, err := x509.ParseCertificate(cert.Certificate[1]) - if err == nil { - manager.issuers[serial] = issuerCert - if err := ocspMgr.RegisterCertificate(parsedCert, issuerCert); err != nil { - logging.Warn().Err(err).Msg("OCSP Stapling 注册失败") + // 尝试从证书链解析颁发者证书 + if len(cert.Certificate) > 1 { + issuerCert, err := x509.ParseCertificate(cert.Certificate[1]) + if err == nil { + manager.issuers[serial] = issuerCert + if err := ocspMgr.RegisterCertificate(parsedCert, issuerCert); err != nil { + logging.Warn().Err(err).Msg("OCSP Stapling 注册失败") + } } } } - - // 设置 GetConfigForClient 回调用于 OCSP Stapling - tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP } } + // 设置 GetConfigForClient 回调用于 OCSP Stapling + tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP + ocspMgr.Start() } @@ -262,10 +268,12 @@ func (m *TLSManager) getConfigForClientWithOCSP(hello *tls.ClientHelloInfo) (*tl // 将 OCSP 响应附加到证书 cert := &cfgCopy.Certificates[0] if len(cert.Certificate) > 0 { - // 解析叶子证书以获取序列号 - leafCert, err := x509.ParseCertificate(cert.Certificate[0]) - if err == nil { - serial := leafCert.SerialNumber.String() + // 使用已缓存的证书解析结果获取序列号 + m.mu.RLock() + parsedCert := m.defaultCert + m.mu.RUnlock() + if parsedCert != nil { + serial := parsedCert.SerialNumber.String() ocspResp := m.ocspManager.GetOCSPResponse(serial) if ocspResp != nil { // 将 OCSP 响应附加到证书