fix(handler,http2,loadbalance,logging,resolver,ssl): fix high severity issues

- handler/static.go: add sync.RWMutex to StaticHandler; protect Handle
  with RLock and all setters with Lock to prevent data races
- http2/server.go: delete empty connection slice keys from pool map to
  prevent memory leak under high client churn
- loadbalance/slow_start.go: recreate stopCh in Start() to support
  Start-Stop-Start cycles; guard double-close in Stop()
- resolver/resolver.go: recreate stopCh in Start() to support restart
- logging/logging.go: save *os.File handles from getOutput so Close()
  actually closes log files; exclude os.Stdout/os.Stderr from closing
- ssl/session_tickets.go: protect started/rotateTimer access in
  scheduleRotation with mu; support Start-Stop-Start cycles
- ssl/ssl.go: cache parsed default certificate to avoid re-parsing on
  every TLS handshake for OCSP stapling
This commit is contained in:
xfy 2026-06-11 17:03:17 +08:00
parent 27e00b84a8
commit f33117b940
7 changed files with 111 additions and 28 deletions

View File

@ -23,6 +23,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
@ -84,6 +85,7 @@ func (h *StaticHandler) statWithCache(filePath string) (os.FileInfo, bool, error
// - 大文件(>= 8KB自动启用零拷贝传输 // - 大文件(>= 8KB自动启用零拷贝传输
// - alias 与 root 互斥,同时配置时 alias 优先 // - alias 与 root 互斥,同时配置时 alias 优先
type StaticHandler struct { type StaticHandler struct {
mu sync.RWMutex // 保护以下字段的并发访问
// 指针类型字段(按大小排列) // 指针类型字段(按大小排列)
fileCache *cache.FileCache fileCache *cache.FileCache
fileInfoCache *FileInfoCache // FileInfo 缓存,减少 os.Stat 调用 fileInfoCache *FileInfoCache // FileInfo 缓存,减少 os.Stat 调用
@ -161,6 +163,8 @@ func NewStaticHandler(root, pathPrefix string, index []string, useSendfile bool)
// 参数: // 参数:
// - alias: 路径别名 // - alias: 路径别名
func (h *StaticHandler) SetAlias(alias string) { func (h *StaticHandler) SetAlias(alias string) {
h.mu.Lock()
defer h.mu.Unlock()
h.alias = alias h.alias = alias
if alias != "" { if alias != "" {
h.root = "" h.root = ""
@ -201,6 +205,8 @@ func (h *StaticHandler) buildFilePath(relPath string) string {
// - 仅对小于 1MB 的文件启用缓存 // - 仅对小于 1MB 的文件启用缓存
// - 缓存会自动检测文件修改并更新 // - 缓存会自动检测文件修改并更新
func (h *StaticHandler) SetFileCache(fc *cache.FileCache) { func (h *StaticHandler) SetFileCache(fc *cache.FileCache) {
h.mu.Lock()
defer h.mu.Unlock()
h.fileCache = fc h.fileCache = fc
} }
@ -217,6 +223,8 @@ func (h *StaticHandler) SetFileCache(fc *cache.FileCache) {
// //
// handler.SetGzipStatic(true, nil, []string{".gz", ".br"}) // handler.SetGzipStatic(true, nil, []string{".gz", ".br"})
func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExtensions []string) { func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExtensions []string) {
h.mu.Lock()
defer h.mu.Unlock()
if enabled { if enabled {
h.gzipStatic = compression.NewGzipStatic(true, h.root, extensions, precompressedExtensions) h.gzipStatic = compression.NewGzipStatic(true, h.root, extensions, precompressedExtensions)
} }
@ -236,6 +244,8 @@ func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExt
// //
// handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil) // handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil)
func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router *Router) { func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router *Router) {
h.mu.Lock()
defer h.mu.Unlock()
h.tryFiles = tryFiles h.tryFiles = tryFiles
h.tryFilesPass = tryFilesPass h.tryFilesPass = tryFilesPass
h.router = router h.router = router
@ -249,6 +259,8 @@ func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router
// 参数: // 参数:
// - enabled: 是否启用符号链接安全检查 // - enabled: 是否启用符号链接安全检查
func (h *StaticHandler) SetSymlinkCheck(enabled bool) { func (h *StaticHandler) SetSymlinkCheck(enabled bool) {
h.mu.Lock()
defer h.mu.Unlock()
h.symlinkCheck = enabled h.symlinkCheck = enabled
} }
@ -260,6 +272,8 @@ func (h *StaticHandler) SetSymlinkCheck(enabled bool) {
// 参数: // 参数:
// - enabled: 是否启用内部访问限制 // - enabled: 是否启用内部访问限制
func (h *StaticHandler) SetInternal(enabled bool) { func (h *StaticHandler) SetInternal(enabled bool) {
h.mu.Lock()
defer h.mu.Unlock()
h.internal = enabled h.internal = enabled
} }
@ -271,6 +285,8 @@ func (h *StaticHandler) SetInternal(enabled bool) {
// 参数: // 参数:
// - expires: 过期时间字符串 // - expires: 过期时间字符串
func (h *StaticHandler) SetExpires(expires string) { func (h *StaticHandler) SetExpires(expires string) {
h.mu.Lock()
defer h.mu.Unlock()
h.expires = expires h.expires = expires
} }
@ -284,6 +300,8 @@ func (h *StaticHandler) SetExpires(expires string) {
// - localtime: 使用本地时间 // - localtime: 使用本地时间
// - exactSize: 显示精确大小 // - exactSize: 显示精确大小
func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exactSize bool) { func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exactSize bool) {
h.mu.Lock()
defer h.mu.Unlock()
h.autoIndex = enabled h.autoIndex = enabled
h.autoIndexFormat = format h.autoIndexFormat = format
h.autoIndexLocaltime = localtime h.autoIndexLocaltime = localtime
@ -304,6 +322,8 @@ func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exa
// //
// 默认 TTL 为 5 秒。 // 默认 TTL 为 5 秒。
func (h *StaticHandler) SetCacheTTL(ttl time.Duration) { func (h *StaticHandler) SetCacheTTL(ttl time.Duration) {
h.mu.Lock()
defer h.mu.Unlock()
h.cacheTTL = ttl h.cacheTTL = ttl
if h.fileInfoCache != nil { if h.fileInfoCache != nil {
h.fileInfoCache.SetTTL(ttl) h.fileInfoCache.SetTTL(ttl)
@ -328,6 +348,9 @@ func (h *StaticHandler) SetCacheTTL(ttl time.Duration) {
// 7. 大文件使用零拷贝传输 // 7. 大文件使用零拷贝传输
// 8. 读取文件并存入缓存 // 8. 读取文件并存入缓存
func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) { func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) {
h.mu.RLock()
defer h.mu.RUnlock()
reqPath := string(ctx.Path()) reqPath := string(ctx.Path())
// 检查 internal 限制 // 检查 internal 限制

View File

@ -308,7 +308,12 @@ func (p *connectionPool) remove(key string, conn net.Conn) {
conns := p.conns[key] conns := p.conns[key]
for i, c := range conns { for i, c := range conns {
if c == conn { if c == conn {
p.conns[key] = append(conns[:i], conns[i+1:]...) conns = append(conns[:i], conns[i+1:]...)
if len(conns) == 0 {
delete(p.conns, key)
} else {
p.conns[key] = conns
}
break break
} }
} }

View File

@ -94,6 +94,13 @@ func (m *SlowStartManager) Start() {
return // 已经在运行 return // 已经在运行
} }
// 重建 stopCh 以支持 Start-Stop-Start 周期
select {
case <-m.stopCh:
m.stopCh = make(chan struct{})
default:
}
go m.updateLoop() go m.updateLoop()
} }
@ -102,8 +109,13 @@ func (m *SlowStartManager) Stop() {
if !m.running.Swap(false) { if !m.running.Swap(false) {
return return
} }
select {
case <-m.stopCh:
// 已经关闭
default:
close(m.stopCh) close(m.stopCh)
} }
}
// updateLoop 后台更新循环。 // updateLoop 后台更新循环。
func (m *SlowStartManager) updateLoop() { func (m *SlowStartManager) updateLoop() {

View File

@ -94,12 +94,15 @@ func New(cfg *config.LoggingConfig) *Logger {
} }
accessWriter := getOutput(cfg.Access.Path) accessWriter := getOutput(cfg.Access.Path)
errorWriter := getOutput(cfg.Error.Path)
logger := &Logger{ logger := &Logger{
accessFormat: cfg.Access.Format, accessFormat: cfg.Access.Format,
accessWriter: accessWriter, accessWriter: accessWriter,
accessFile: writerFile(accessWriter),
errorFile: writerFile(errorWriter),
accessLog: zerolog.New(accessWriter).With().Timestamp().Logger(), accessLog: zerolog.New(accessWriter).With().Timestamp().Logger(),
errorLog: zerolog.New(getOutput(cfg.Error.Path)).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(), errorLog: zerolog.New(errorWriter).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(),
} }
return logger return logger
@ -133,6 +136,14 @@ func getOutput(path string) io.Writer {
return f return f
} }
// writerFile 从 io.Writer 中提取底层的 *os.File如果不是或为标准流则返回 nil。
func writerFile(w io.Writer) *os.File {
if f, ok := w.(*os.File); ok && f != os.Stdout && f != os.Stderr {
return f
}
return nil
}
// LogAccess 记录访问日志(全局实例)。 // LogAccess 记录访问日志(全局实例)。
// //
// 使用全局 log 实例记录 HTTP 请求的基本信息,包括方法、路径、状态码、 // 使用全局 log 实例记录 HTTP 请求的基本信息,包括方法、路径、状态码、

View File

@ -357,6 +357,13 @@ func (r *DNSResolver) Start() error {
r.started.Store(true) r.started.Store(true)
// 重建 stopCh 以支持 Start-Stop-Start 周期
select {
case <-r.stopCh:
r.stopCh = make(chan struct{})
default:
}
// 启动后台刷新协程 // 启动后台刷新协程
go r.refreshLoop() go r.refreshLoop()

View File

@ -117,15 +117,21 @@ func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketMan
// 必须在调用 GetKeys 之前启动。 // 必须在调用 GetKeys 之前启动。
func (m *SessionTicketManager) Start() { func (m *SessionTicketManager) Start() {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock()
if m.started { if m.started {
m.mu.Unlock()
return return
} }
m.started = true m.started = true
m.mu.Unlock()
// 重建 stopCh 以支持 Start-Stop-Start 周期
select {
case <-m.stopCh:
m.stopCh = make(chan struct{})
default:
}
// 启动轮换定时器 // 启动轮换定时器
m.scheduleRotation() m.scheduleRotationLocked()
} }
// Stop 停止密钥轮换定时器。 // Stop 停止密钥轮换定时器。
@ -138,12 +144,17 @@ func (m *SessionTicketManager) Stop() {
return return
} }
m.started = false m.started = false
m.mu.Unlock()
close(m.stopCh)
if m.rotateTimer != nil { if m.rotateTimer != nil {
m.rotateTimer.Stop() m.rotateTimer.Stop()
m.rotateTimer = nil
}
m.mu.Unlock()
select {
case <-m.stopCh:
// 已经关闭
default:
close(m.stopCh)
} }
} }
@ -230,6 +241,12 @@ func (m *SessionTicketManager) ApplyToTLSConfig(tlsCfg *tls.Config) {
// //
// 使用定时器在指定间隔后执行密钥轮换。 // 使用定时器在指定间隔后执行密钥轮换。
func (m *SessionTicketManager) scheduleRotation() { func (m *SessionTicketManager) scheduleRotation() {
m.mu.Lock()
defer m.mu.Unlock()
m.scheduleRotationLocked()
}
func (m *SessionTicketManager) scheduleRotationLocked() {
if !m.started { if !m.started {
return return
} }

View File

@ -74,6 +74,9 @@ type TLSManager struct {
// certificates 解析后的证书映射,用于 OCSP // certificates 解析后的证书映射,用于 OCSP
certificates map[string]*x509.Certificate certificates map[string]*x509.Certificate
// defaultCert 默认证书的解析结果,避免每次握手重新解析
defaultCert *x509.Certificate
// issuers 颁发者证书映射,用于 OCSP // issuers 颁发者证书映射,用于 OCSP
issuers map[string]*x509.Certificate issuers map[string]*x509.Certificate
@ -163,7 +166,9 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
// 解析证书用于 OCSP // 解析证书用于 OCSP
if len(cert.Certificate) > 0 { if len(cert.Certificate) > 0 {
parsedCert, err := x509.ParseCertificate(cert.Certificate[0]) parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
if err == nil && len(parsedCert.OCSPServer) > 0 { if err == nil {
manager.defaultCert = parsedCert
if len(parsedCert.OCSPServer) > 0 {
// 存储证书用于 OCSP 查询 // 存储证书用于 OCSP 查询
serial := parsedCert.SerialNumber.String() serial := parsedCert.SerialNumber.String()
manager.certificates[serial] = parsedCert manager.certificates[serial] = parsedCert
@ -178,11 +183,12 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
} }
} }
} }
}
}
}
// 设置 GetConfigForClient 回调用于 OCSP Stapling // 设置 GetConfigForClient 回调用于 OCSP Stapling
tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP
}
}
ocspMgr.Start() ocspMgr.Start()
} }
@ -262,10 +268,12 @@ func (m *TLSManager) getConfigForClientWithOCSP(hello *tls.ClientHelloInfo) (*tl
// 将 OCSP 响应附加到证书 // 将 OCSP 响应附加到证书
cert := &cfgCopy.Certificates[0] cert := &cfgCopy.Certificates[0]
if len(cert.Certificate) > 0 { if len(cert.Certificate) > 0 {
// 解析叶子证书以获取序列号 // 使用已缓存的证书解析结果获取序列号
leafCert, err := x509.ParseCertificate(cert.Certificate[0]) m.mu.RLock()
if err == nil { parsedCert := m.defaultCert
serial := leafCert.SerialNumber.String() m.mu.RUnlock()
if parsedCert != nil {
serial := parsedCert.SerialNumber.String()
ocspResp := m.ocspManager.GetOCSPResponse(serial) ocspResp := m.ocspManager.GetOCSPResponse(serial)
if ocspResp != nil { if ocspResp != nil {
// 将 OCSP 响应附加到证书 // 将 OCSP 响应附加到证书