fix(handler,http2,loadbalance,logging,resolver,ssl): fix high severity issues

- handler/static.go: add sync.RWMutex to StaticHandler; protect Handle
  with RLock and all setters with Lock to prevent data races
- http2/server.go: delete empty connection slice keys from pool map to
  prevent memory leak under high client churn
- loadbalance/slow_start.go: recreate stopCh in Start() to support
  Start-Stop-Start cycles; guard double-close in Stop()
- resolver/resolver.go: recreate stopCh in Start() to support restart
- logging/logging.go: save *os.File handles from getOutput so Close()
  actually closes log files; exclude os.Stdout/os.Stderr from closing
- ssl/session_tickets.go: protect started/rotateTimer access in
  scheduleRotation with mu; support Start-Stop-Start cycles
- ssl/ssl.go: cache parsed default certificate to avoid re-parsing on
  every TLS handshake for OCSP stapling
This commit is contained in:
xfy 2026-06-11 17:03:17 +08:00
parent 27e00b84a8
commit f33117b940
7 changed files with 111 additions and 28 deletions

View File

@ -23,6 +23,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
@ -84,6 +85,7 @@ func (h *StaticHandler) statWithCache(filePath string) (os.FileInfo, bool, error
// - 大文件(>= 8KB自动启用零拷贝传输
// - alias 与 root 互斥,同时配置时 alias 优先
type StaticHandler struct {
mu sync.RWMutex // 保护以下字段的并发访问
// 指针类型字段(按大小排列)
fileCache *cache.FileCache
fileInfoCache *FileInfoCache // FileInfo 缓存,减少 os.Stat 调用
@ -161,6 +163,8 @@ func NewStaticHandler(root, pathPrefix string, index []string, useSendfile bool)
// 参数:
// - alias: 路径别名
func (h *StaticHandler) SetAlias(alias string) {
h.mu.Lock()
defer h.mu.Unlock()
h.alias = alias
if alias != "" {
h.root = ""
@ -201,6 +205,8 @@ func (h *StaticHandler) buildFilePath(relPath string) string {
// - 仅对小于 1MB 的文件启用缓存
// - 缓存会自动检测文件修改并更新
func (h *StaticHandler) SetFileCache(fc *cache.FileCache) {
h.mu.Lock()
defer h.mu.Unlock()
h.fileCache = fc
}
@ -217,6 +223,8 @@ func (h *StaticHandler) SetFileCache(fc *cache.FileCache) {
//
// handler.SetGzipStatic(true, nil, []string{".gz", ".br"})
func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExtensions []string) {
h.mu.Lock()
defer h.mu.Unlock()
if enabled {
h.gzipStatic = compression.NewGzipStatic(true, h.root, extensions, precompressedExtensions)
}
@ -236,6 +244,8 @@ func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExt
//
// handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil)
func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router *Router) {
h.mu.Lock()
defer h.mu.Unlock()
h.tryFiles = tryFiles
h.tryFilesPass = tryFilesPass
h.router = router
@ -249,6 +259,8 @@ func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router
// 参数:
// - enabled: 是否启用符号链接安全检查
func (h *StaticHandler) SetSymlinkCheck(enabled bool) {
h.mu.Lock()
defer h.mu.Unlock()
h.symlinkCheck = enabled
}
@ -260,6 +272,8 @@ func (h *StaticHandler) SetSymlinkCheck(enabled bool) {
// 参数:
// - enabled: 是否启用内部访问限制
func (h *StaticHandler) SetInternal(enabled bool) {
h.mu.Lock()
defer h.mu.Unlock()
h.internal = enabled
}
@ -271,6 +285,8 @@ func (h *StaticHandler) SetInternal(enabled bool) {
// 参数:
// - expires: 过期时间字符串
func (h *StaticHandler) SetExpires(expires string) {
h.mu.Lock()
defer h.mu.Unlock()
h.expires = expires
}
@ -284,6 +300,8 @@ func (h *StaticHandler) SetExpires(expires string) {
// - localtime: 使用本地时间
// - exactSize: 显示精确大小
func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exactSize bool) {
h.mu.Lock()
defer h.mu.Unlock()
h.autoIndex = enabled
h.autoIndexFormat = format
h.autoIndexLocaltime = localtime
@ -304,6 +322,8 @@ func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exa
//
// 默认 TTL 为 5 秒。
func (h *StaticHandler) SetCacheTTL(ttl time.Duration) {
h.mu.Lock()
defer h.mu.Unlock()
h.cacheTTL = ttl
if h.fileInfoCache != nil {
h.fileInfoCache.SetTTL(ttl)
@ -328,6 +348,9 @@ func (h *StaticHandler) SetCacheTTL(ttl time.Duration) {
// 7. 大文件使用零拷贝传输
// 8. 读取文件并存入缓存
func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) {
h.mu.RLock()
defer h.mu.RUnlock()
reqPath := string(ctx.Path())
// 检查 internal 限制

View File

@ -308,7 +308,12 @@ func (p *connectionPool) remove(key string, conn net.Conn) {
conns := p.conns[key]
for i, c := range conns {
if c == conn {
p.conns[key] = append(conns[:i], conns[i+1:]...)
conns = append(conns[:i], conns[i+1:]...)
if len(conns) == 0 {
delete(p.conns, key)
} else {
p.conns[key] = conns
}
break
}
}

View File

@ -94,6 +94,13 @@ func (m *SlowStartManager) Start() {
return // 已经在运行
}
// 重建 stopCh 以支持 Start-Stop-Start 周期
select {
case <-m.stopCh:
m.stopCh = make(chan struct{})
default:
}
go m.updateLoop()
}
@ -102,7 +109,12 @@ func (m *SlowStartManager) Stop() {
if !m.running.Swap(false) {
return
}
close(m.stopCh)
select {
case <-m.stopCh:
// 已经关闭
default:
close(m.stopCh)
}
}
// updateLoop 后台更新循环。

View File

@ -94,12 +94,15 @@ func New(cfg *config.LoggingConfig) *Logger {
}
accessWriter := getOutput(cfg.Access.Path)
errorWriter := getOutput(cfg.Error.Path)
logger := &Logger{
accessFormat: cfg.Access.Format,
accessWriter: accessWriter,
accessFile: writerFile(accessWriter),
errorFile: writerFile(errorWriter),
accessLog: zerolog.New(accessWriter).With().Timestamp().Logger(),
errorLog: zerolog.New(getOutput(cfg.Error.Path)).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(),
errorLog: zerolog.New(errorWriter).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(),
}
return logger
@ -133,6 +136,14 @@ func getOutput(path string) io.Writer {
return f
}
// writerFile 从 io.Writer 中提取底层的 *os.File如果不是或为标准流则返回 nil。
func writerFile(w io.Writer) *os.File {
if f, ok := w.(*os.File); ok && f != os.Stdout && f != os.Stderr {
return f
}
return nil
}
// LogAccess 记录访问日志(全局实例)。
//
// 使用全局 log 实例记录 HTTP 请求的基本信息,包括方法、路径、状态码、

View File

@ -357,6 +357,13 @@ func (r *DNSResolver) Start() error {
r.started.Store(true)
// 重建 stopCh 以支持 Start-Stop-Start 周期
select {
case <-r.stopCh:
r.stopCh = make(chan struct{})
default:
}
// 启动后台刷新协程
go r.refreshLoop()

View File

@ -117,15 +117,21 @@ func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketMan
// 必须在调用 GetKeys 之前启动。
func (m *SessionTicketManager) Start() {
m.mu.Lock()
defer m.mu.Unlock()
if m.started {
m.mu.Unlock()
return
}
m.started = true
m.mu.Unlock()
// 重建 stopCh 以支持 Start-Stop-Start 周期
select {
case <-m.stopCh:
m.stopCh = make(chan struct{})
default:
}
// 启动轮换定时器
m.scheduleRotation()
m.scheduleRotationLocked()
}
// Stop 停止密钥轮换定时器。
@ -138,12 +144,17 @@ func (m *SessionTicketManager) Stop() {
return
}
m.started = false
m.mu.Unlock()
close(m.stopCh)
if m.rotateTimer != nil {
m.rotateTimer.Stop()
m.rotateTimer = nil
}
m.mu.Unlock()
select {
case <-m.stopCh:
// 已经关闭
default:
close(m.stopCh)
}
}
@ -230,6 +241,12 @@ func (m *SessionTicketManager) ApplyToTLSConfig(tlsCfg *tls.Config) {
//
// 使用定时器在指定间隔后执行密钥轮换。
func (m *SessionTicketManager) scheduleRotation() {
m.mu.Lock()
defer m.mu.Unlock()
m.scheduleRotationLocked()
}
func (m *SessionTicketManager) scheduleRotationLocked() {
if !m.started {
return
}

View File

@ -74,6 +74,9 @@ type TLSManager struct {
// certificates 解析后的证书映射,用于 OCSP
certificates map[string]*x509.Certificate
// defaultCert 默认证书的解析结果,避免每次握手重新解析
defaultCert *x509.Certificate
// issuers 颁发者证书映射,用于 OCSP
issuers map[string]*x509.Certificate
@ -163,27 +166,30 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
// 解析证书用于 OCSP
if len(cert.Certificate) > 0 {
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
if err == nil && len(parsedCert.OCSPServer) > 0 {
// 存储证书用于 OCSP 查询
serial := parsedCert.SerialNumber.String()
manager.certificates[serial] = parsedCert
if err == nil {
manager.defaultCert = parsedCert
if len(parsedCert.OCSPServer) > 0 {
// 存储证书用于 OCSP 查询
serial := parsedCert.SerialNumber.String()
manager.certificates[serial] = parsedCert
// 尝试从证书链解析颁发者证书
if len(cert.Certificate) > 1 {
issuerCert, err := x509.ParseCertificate(cert.Certificate[1])
if err == nil {
manager.issuers[serial] = issuerCert
if err := ocspMgr.RegisterCertificate(parsedCert, issuerCert); err != nil {
logging.Warn().Err(err).Msg("OCSP Stapling 注册失败")
// 尝试从证书链解析颁发者证书
if len(cert.Certificate) > 1 {
issuerCert, err := x509.ParseCertificate(cert.Certificate[1])
if err == nil {
manager.issuers[serial] = issuerCert
if err := ocspMgr.RegisterCertificate(parsedCert, issuerCert); err != nil {
logging.Warn().Err(err).Msg("OCSP Stapling 注册失败")
}
}
}
}
// 设置 GetConfigForClient 回调用于 OCSP Stapling
tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP
}
}
// 设置 GetConfigForClient 回调用于 OCSP Stapling
tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP
ocspMgr.Start()
}
@ -262,10 +268,12 @@ func (m *TLSManager) getConfigForClientWithOCSP(hello *tls.ClientHelloInfo) (*tl
// 将 OCSP 响应附加到证书
cert := &cfgCopy.Certificates[0]
if len(cert.Certificate) > 0 {
// 解析叶子证书以获取序列号
leafCert, err := x509.ParseCertificate(cert.Certificate[0])
if err == nil {
serial := leafCert.SerialNumber.String()
// 使用已缓存的证书解析结果获取序列号
m.mu.RLock()
parsedCert := m.defaultCert
m.mu.RUnlock()
if parsedCert != nil {
serial := parsedCert.SerialNumber.String()
ocspResp := m.ocspManager.GetOCSPResponse(serial)
if ocspResp != nil {
// 将 OCSP 响应附加到证书