fix(handler,http2,loadbalance,logging,resolver,ssl): fix high severity issues
- handler/static.go: add sync.RWMutex to StaticHandler; protect Handle with RLock and all setters with Lock to prevent data races - http2/server.go: delete empty connection slice keys from pool map to prevent memory leak under high client churn - loadbalance/slow_start.go: recreate stopCh in Start() to support Start-Stop-Start cycles; guard double-close in Stop() - resolver/resolver.go: recreate stopCh in Start() to support restart - logging/logging.go: save *os.File handles from getOutput so Close() actually closes log files; exclude os.Stdout/os.Stderr from closing - ssl/session_tickets.go: protect started/rotateTimer access in scheduleRotation with mu; support Start-Stop-Start cycles - ssl/ssl.go: cache parsed default certificate to avoid re-parsing on every TLS handshake for OCSP stapling
This commit is contained in:
parent
27e00b84a8
commit
f33117b940
@ -23,6 +23,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
@ -84,6 +85,7 @@ func (h *StaticHandler) statWithCache(filePath string) (os.FileInfo, bool, error
|
|||||||
// - 大文件(>= 8KB)自动启用零拷贝传输
|
// - 大文件(>= 8KB)自动启用零拷贝传输
|
||||||
// - alias 与 root 互斥,同时配置时 alias 优先
|
// - alias 与 root 互斥,同时配置时 alias 优先
|
||||||
type StaticHandler struct {
|
type StaticHandler struct {
|
||||||
|
mu sync.RWMutex // 保护以下字段的并发访问
|
||||||
// 指针类型字段(按大小排列)
|
// 指针类型字段(按大小排列)
|
||||||
fileCache *cache.FileCache
|
fileCache *cache.FileCache
|
||||||
fileInfoCache *FileInfoCache // FileInfo 缓存,减少 os.Stat 调用
|
fileInfoCache *FileInfoCache // FileInfo 缓存,减少 os.Stat 调用
|
||||||
@ -161,6 +163,8 @@ func NewStaticHandler(root, pathPrefix string, index []string, useSendfile bool)
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - alias: 路径别名
|
// - alias: 路径别名
|
||||||
func (h *StaticHandler) SetAlias(alias string) {
|
func (h *StaticHandler) SetAlias(alias string) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.alias = alias
|
h.alias = alias
|
||||||
if alias != "" {
|
if alias != "" {
|
||||||
h.root = ""
|
h.root = ""
|
||||||
@ -201,6 +205,8 @@ func (h *StaticHandler) buildFilePath(relPath string) string {
|
|||||||
// - 仅对小于 1MB 的文件启用缓存
|
// - 仅对小于 1MB 的文件启用缓存
|
||||||
// - 缓存会自动检测文件修改并更新
|
// - 缓存会自动检测文件修改并更新
|
||||||
func (h *StaticHandler) SetFileCache(fc *cache.FileCache) {
|
func (h *StaticHandler) SetFileCache(fc *cache.FileCache) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.fileCache = fc
|
h.fileCache = fc
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -217,6 +223,8 @@ func (h *StaticHandler) SetFileCache(fc *cache.FileCache) {
|
|||||||
//
|
//
|
||||||
// handler.SetGzipStatic(true, nil, []string{".gz", ".br"})
|
// handler.SetGzipStatic(true, nil, []string{".gz", ".br"})
|
||||||
func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExtensions []string) {
|
func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExtensions []string) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
if enabled {
|
if enabled {
|
||||||
h.gzipStatic = compression.NewGzipStatic(true, h.root, extensions, precompressedExtensions)
|
h.gzipStatic = compression.NewGzipStatic(true, h.root, extensions, precompressedExtensions)
|
||||||
}
|
}
|
||||||
@ -236,6 +244,8 @@ func (h *StaticHandler) SetGzipStatic(enabled bool, extensions, precompressedExt
|
|||||||
//
|
//
|
||||||
// handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil)
|
// handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil)
|
||||||
func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router *Router) {
|
func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router *Router) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.tryFiles = tryFiles
|
h.tryFiles = tryFiles
|
||||||
h.tryFilesPass = tryFilesPass
|
h.tryFilesPass = tryFilesPass
|
||||||
h.router = router
|
h.router = router
|
||||||
@ -249,6 +259,8 @@ func (h *StaticHandler) SetTryFiles(tryFiles []string, tryFilesPass bool, router
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - enabled: 是否启用符号链接安全检查
|
// - enabled: 是否启用符号链接安全检查
|
||||||
func (h *StaticHandler) SetSymlinkCheck(enabled bool) {
|
func (h *StaticHandler) SetSymlinkCheck(enabled bool) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.symlinkCheck = enabled
|
h.symlinkCheck = enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,6 +272,8 @@ func (h *StaticHandler) SetSymlinkCheck(enabled bool) {
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - enabled: 是否启用内部访问限制
|
// - enabled: 是否启用内部访问限制
|
||||||
func (h *StaticHandler) SetInternal(enabled bool) {
|
func (h *StaticHandler) SetInternal(enabled bool) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.internal = enabled
|
h.internal = enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -271,6 +285,8 @@ func (h *StaticHandler) SetInternal(enabled bool) {
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - expires: 过期时间字符串
|
// - expires: 过期时间字符串
|
||||||
func (h *StaticHandler) SetExpires(expires string) {
|
func (h *StaticHandler) SetExpires(expires string) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.expires = expires
|
h.expires = expires
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,6 +300,8 @@ func (h *StaticHandler) SetExpires(expires string) {
|
|||||||
// - localtime: 使用本地时间
|
// - localtime: 使用本地时间
|
||||||
// - exactSize: 显示精确大小
|
// - exactSize: 显示精确大小
|
||||||
func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exactSize bool) {
|
func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exactSize bool) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.autoIndex = enabled
|
h.autoIndex = enabled
|
||||||
h.autoIndexFormat = format
|
h.autoIndexFormat = format
|
||||||
h.autoIndexLocaltime = localtime
|
h.autoIndexLocaltime = localtime
|
||||||
@ -304,6 +322,8 @@ func (h *StaticHandler) SetAutoIndex(enabled bool, format string, localtime, exa
|
|||||||
//
|
//
|
||||||
// 默认 TTL 为 5 秒。
|
// 默认 TTL 为 5 秒。
|
||||||
func (h *StaticHandler) SetCacheTTL(ttl time.Duration) {
|
func (h *StaticHandler) SetCacheTTL(ttl time.Duration) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
h.cacheTTL = ttl
|
h.cacheTTL = ttl
|
||||||
if h.fileInfoCache != nil {
|
if h.fileInfoCache != nil {
|
||||||
h.fileInfoCache.SetTTL(ttl)
|
h.fileInfoCache.SetTTL(ttl)
|
||||||
@ -328,6 +348,9 @@ func (h *StaticHandler) SetCacheTTL(ttl time.Duration) {
|
|||||||
// 7. 大文件使用零拷贝传输
|
// 7. 大文件使用零拷贝传输
|
||||||
// 8. 读取文件并存入缓存
|
// 8. 读取文件并存入缓存
|
||||||
func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) {
|
func (h *StaticHandler) Handle(ctx *fasthttp.RequestCtx) {
|
||||||
|
h.mu.RLock()
|
||||||
|
defer h.mu.RUnlock()
|
||||||
|
|
||||||
reqPath := string(ctx.Path())
|
reqPath := string(ctx.Path())
|
||||||
|
|
||||||
// 检查 internal 限制
|
// 检查 internal 限制
|
||||||
|
|||||||
@ -308,7 +308,12 @@ func (p *connectionPool) remove(key string, conn net.Conn) {
|
|||||||
conns := p.conns[key]
|
conns := p.conns[key]
|
||||||
for i, c := range conns {
|
for i, c := range conns {
|
||||||
if c == conn {
|
if c == conn {
|
||||||
p.conns[key] = append(conns[:i], conns[i+1:]...)
|
conns = append(conns[:i], conns[i+1:]...)
|
||||||
|
if len(conns) == 0 {
|
||||||
|
delete(p.conns, key)
|
||||||
|
} else {
|
||||||
|
p.conns[key] = conns
|
||||||
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -94,6 +94,13 @@ func (m *SlowStartManager) Start() {
|
|||||||
return // 已经在运行
|
return // 已经在运行
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 重建 stopCh 以支持 Start-Stop-Start 周期
|
||||||
|
select {
|
||||||
|
case <-m.stopCh:
|
||||||
|
m.stopCh = make(chan struct{})
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
go m.updateLoop()
|
go m.updateLoop()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +109,12 @@ func (m *SlowStartManager) Stop() {
|
|||||||
if !m.running.Swap(false) {
|
if !m.running.Swap(false) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
select {
|
||||||
|
case <-m.stopCh:
|
||||||
|
// 已经关闭
|
||||||
|
default:
|
||||||
close(m.stopCh)
|
close(m.stopCh)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateLoop 后台更新循环。
|
// updateLoop 后台更新循环。
|
||||||
|
|||||||
@ -94,12 +94,15 @@ func New(cfg *config.LoggingConfig) *Logger {
|
|||||||
}
|
}
|
||||||
|
|
||||||
accessWriter := getOutput(cfg.Access.Path)
|
accessWriter := getOutput(cfg.Access.Path)
|
||||||
|
errorWriter := getOutput(cfg.Error.Path)
|
||||||
|
|
||||||
logger := &Logger{
|
logger := &Logger{
|
||||||
accessFormat: cfg.Access.Format,
|
accessFormat: cfg.Access.Format,
|
||||||
accessWriter: accessWriter,
|
accessWriter: accessWriter,
|
||||||
|
accessFile: writerFile(accessWriter),
|
||||||
|
errorFile: writerFile(errorWriter),
|
||||||
accessLog: zerolog.New(accessWriter).With().Timestamp().Logger(),
|
accessLog: zerolog.New(accessWriter).With().Timestamp().Logger(),
|
||||||
errorLog: zerolog.New(getOutput(cfg.Error.Path)).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(),
|
errorLog: zerolog.New(errorWriter).Level(parseLevel(cfg.Error.Level)).With().Timestamp().Logger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return logger
|
return logger
|
||||||
@ -133,6 +136,14 @@ func getOutput(path string) io.Writer {
|
|||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writerFile 从 io.Writer 中提取底层的 *os.File,如果不是或为标准流则返回 nil。
|
||||||
|
func writerFile(w io.Writer) *os.File {
|
||||||
|
if f, ok := w.(*os.File); ok && f != os.Stdout && f != os.Stderr {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// LogAccess 记录访问日志(全局实例)。
|
// LogAccess 记录访问日志(全局实例)。
|
||||||
//
|
//
|
||||||
// 使用全局 log 实例记录 HTTP 请求的基本信息,包括方法、路径、状态码、
|
// 使用全局 log 实例记录 HTTP 请求的基本信息,包括方法、路径、状态码、
|
||||||
|
|||||||
@ -357,6 +357,13 @@ func (r *DNSResolver) Start() error {
|
|||||||
|
|
||||||
r.started.Store(true)
|
r.started.Store(true)
|
||||||
|
|
||||||
|
// 重建 stopCh 以支持 Start-Stop-Start 周期
|
||||||
|
select {
|
||||||
|
case <-r.stopCh:
|
||||||
|
r.stopCh = make(chan struct{})
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
// 启动后台刷新协程
|
// 启动后台刷新协程
|
||||||
go r.refreshLoop()
|
go r.refreshLoop()
|
||||||
|
|
||||||
|
|||||||
@ -117,15 +117,21 @@ func NewSessionTicketManager(cfg config.SessionTicketsConfig) (*SessionTicketMan
|
|||||||
// 必须在调用 GetKeys 之前启动。
|
// 必须在调用 GetKeys 之前启动。
|
||||||
func (m *SessionTicketManager) Start() {
|
func (m *SessionTicketManager) Start() {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
if m.started {
|
if m.started {
|
||||||
m.mu.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.started = true
|
m.started = true
|
||||||
m.mu.Unlock()
|
|
||||||
|
// 重建 stopCh 以支持 Start-Stop-Start 周期
|
||||||
|
select {
|
||||||
|
case <-m.stopCh:
|
||||||
|
m.stopCh = make(chan struct{})
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
// 启动轮换定时器
|
// 启动轮换定时器
|
||||||
m.scheduleRotation()
|
m.scheduleRotationLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop 停止密钥轮换定时器。
|
// Stop 停止密钥轮换定时器。
|
||||||
@ -138,12 +144,17 @@ func (m *SessionTicketManager) Stop() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.started = false
|
m.started = false
|
||||||
m.mu.Unlock()
|
|
||||||
|
|
||||||
close(m.stopCh)
|
|
||||||
|
|
||||||
if m.rotateTimer != nil {
|
if m.rotateTimer != nil {
|
||||||
m.rotateTimer.Stop()
|
m.rotateTimer.Stop()
|
||||||
|
m.rotateTimer = nil
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-m.stopCh:
|
||||||
|
// 已经关闭
|
||||||
|
default:
|
||||||
|
close(m.stopCh)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,6 +241,12 @@ func (m *SessionTicketManager) ApplyToTLSConfig(tlsCfg *tls.Config) {
|
|||||||
//
|
//
|
||||||
// 使用定时器在指定间隔后执行密钥轮换。
|
// 使用定时器在指定间隔后执行密钥轮换。
|
||||||
func (m *SessionTicketManager) scheduleRotation() {
|
func (m *SessionTicketManager) scheduleRotation() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.scheduleRotationLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SessionTicketManager) scheduleRotationLocked() {
|
||||||
if !m.started {
|
if !m.started {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -74,6 +74,9 @@ type TLSManager struct {
|
|||||||
// certificates 解析后的证书映射,用于 OCSP
|
// certificates 解析后的证书映射,用于 OCSP
|
||||||
certificates map[string]*x509.Certificate
|
certificates map[string]*x509.Certificate
|
||||||
|
|
||||||
|
// defaultCert 默认证书的解析结果,避免每次握手重新解析
|
||||||
|
defaultCert *x509.Certificate
|
||||||
|
|
||||||
// issuers 颁发者证书映射,用于 OCSP
|
// issuers 颁发者证书映射,用于 OCSP
|
||||||
issuers map[string]*x509.Certificate
|
issuers map[string]*x509.Certificate
|
||||||
|
|
||||||
@ -163,7 +166,9 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
|||||||
// 解析证书用于 OCSP
|
// 解析证书用于 OCSP
|
||||||
if len(cert.Certificate) > 0 {
|
if len(cert.Certificate) > 0 {
|
||||||
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
|
parsedCert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||||
if err == nil && len(parsedCert.OCSPServer) > 0 {
|
if err == nil {
|
||||||
|
manager.defaultCert = parsedCert
|
||||||
|
if len(parsedCert.OCSPServer) > 0 {
|
||||||
// 存储证书用于 OCSP 查询
|
// 存储证书用于 OCSP 查询
|
||||||
serial := parsedCert.SerialNumber.String()
|
serial := parsedCert.SerialNumber.String()
|
||||||
manager.certificates[serial] = parsedCert
|
manager.certificates[serial] = parsedCert
|
||||||
@ -178,11 +183,12 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 设置 GetConfigForClient 回调用于 OCSP Stapling
|
// 设置 GetConfigForClient 回调用于 OCSP Stapling
|
||||||
tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP
|
tlsCfg.GetConfigForClient = manager.getConfigForClientWithOCSP
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ocspMgr.Start()
|
ocspMgr.Start()
|
||||||
}
|
}
|
||||||
@ -262,10 +268,12 @@ func (m *TLSManager) getConfigForClientWithOCSP(hello *tls.ClientHelloInfo) (*tl
|
|||||||
// 将 OCSP 响应附加到证书
|
// 将 OCSP 响应附加到证书
|
||||||
cert := &cfgCopy.Certificates[0]
|
cert := &cfgCopy.Certificates[0]
|
||||||
if len(cert.Certificate) > 0 {
|
if len(cert.Certificate) > 0 {
|
||||||
// 解析叶子证书以获取序列号
|
// 使用已缓存的证书解析结果获取序列号
|
||||||
leafCert, err := x509.ParseCertificate(cert.Certificate[0])
|
m.mu.RLock()
|
||||||
if err == nil {
|
parsedCert := m.defaultCert
|
||||||
serial := leafCert.SerialNumber.String()
|
m.mu.RUnlock()
|
||||||
|
if parsedCert != nil {
|
||||||
|
serial := parsedCert.SerialNumber.String()
|
||||||
ocspResp := m.ocspManager.GetOCSPResponse(serial)
|
ocspResp := m.ocspManager.GetOCSPResponse(serial)
|
||||||
if ocspResp != nil {
|
if ocspResp != nil {
|
||||||
// 将 OCSP 响应附加到证书
|
// 将 OCSP 响应附加到证书
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user