refactor: 提取 Lua ngx 表 helpers 和统一验证函数
Batch 1 续: - 新增 lua/helpers.go:GetOrCreateNgxTable/GetOrCreateNgxSubTable - 重构 compression:提取 resettableWriteCloser 接口和 compressorPool - 新增 validate.go:ValidateNonNegativeInt64/Duration/NoNullByte/PathTraversal - 消除约 120 行重复代码 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
91e04222b3
commit
f82e363f58
@ -24,6 +24,7 @@ import (
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
@ -123,6 +124,38 @@ func ValidateNonNegative(value int, fieldName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNonNegativeInt64 验证 int64 值为非负数
|
||||
func ValidateNonNegativeInt64(value int64, fieldName string) error {
|
||||
if value < 0 {
|
||||
return fmt.Errorf("%s 不能为负数", fieldName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNonNegativeDuration 验证 time.Duration 值为非负数
|
||||
func ValidateNonNegativeDuration(value time.Duration, fieldName string) error {
|
||||
if value < 0 {
|
||||
return fmt.Errorf("%s 不能为负数", fieldName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNoNullByte 验证字符串不包含 null byte
|
||||
func ValidateNoNullByte(s string, fieldName string) error {
|
||||
if strings.Contains(s, "\x00") {
|
||||
return fmt.Errorf("%s 不能包含 null byte", fieldName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePathTraversal 验证路径不包含路径遍历 '..'
|
||||
func ValidatePathTraversal(path string, fieldName string) error {
|
||||
if strings.Contains(path, "..") {
|
||||
return fmt.Errorf("%s不能包含 '..'", fieldName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateServer 验证服务器配置。
|
||||
//
|
||||
// 检查服务器配置的各项参数是否符合要求,包括监听地址、
|
||||
@ -232,13 +265,13 @@ func validateStatics(statics []StaticConfig) error {
|
||||
}
|
||||
|
||||
// 验证根目录路径安全
|
||||
if s.Root != "" && strings.Contains(s.Root, "..") {
|
||||
return fmt.Errorf("static[%d]: 根目录路径不能包含 '..'", i)
|
||||
if err := ValidatePathTraversal(s.Root, fmt.Sprintf("static[%d]: 根目录路径", i)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证 alias 路径安全
|
||||
if s.Alias != "" && strings.Contains(s.Alias, "..") {
|
||||
return fmt.Errorf("static[%d]: alias 路径不能包含 '..'", i)
|
||||
if err := ValidatePathTraversal(s.Alias, fmt.Sprintf("static[%d]: alias 路径", i)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证 try_files 模式
|
||||
@ -278,8 +311,8 @@ func validateTryFilesPattern(pattern string) error {
|
||||
}
|
||||
|
||||
// 检查 null byte
|
||||
if strings.Contains(pattern, "\x00") {
|
||||
return errors.New("try_files 模式不能包含 null byte")
|
||||
if err := ValidateNoNullByte(pattern, "try_files 模式"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 定义支持的模式类型
|
||||
@ -337,8 +370,8 @@ func validateTryFilesExtension(ext string) error {
|
||||
}
|
||||
|
||||
// 检查 null byte
|
||||
if strings.Contains(ext, "\x00") {
|
||||
return errors.New("扩展名不能包含 null byte")
|
||||
if err := ValidateNoNullByte(ext, "扩展名"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 白名单字符检查:仅允许字母、数字、点、下划线、连字符
|
||||
@ -394,8 +427,8 @@ func validateTryFilesFilename(filename string) error {
|
||||
}
|
||||
|
||||
// 检查 null byte
|
||||
if strings.Contains(filename, "\x00") {
|
||||
return errors.New("文件名不能包含 null byte")
|
||||
if err := ValidateNoNullByte(filename, "文件名"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -441,8 +474,8 @@ func validateStatic(s *StaticConfig) error {
|
||||
// 静态文件根目录非空时验证路径有效性
|
||||
if s.Root != "" {
|
||||
// 路径安全检查:不允许包含 ".."
|
||||
if strings.Contains(s.Root, "..") {
|
||||
return errors.New("根目录路径不能包含 '..'")
|
||||
if err := ValidatePathTraversal(s.Root, "根目录路径"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -711,13 +744,13 @@ func validateGeoIP(g *GeoIPConfig) error {
|
||||
}
|
||||
|
||||
// 验证缓存大小
|
||||
if g.CacheSize < 0 {
|
||||
return errors.New("cache_size 不能为负数")
|
||||
if err := ValidateNonNegative(g.CacheSize, "cache_size"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证缓存 TTL
|
||||
if g.CacheTTL < 0 {
|
||||
return errors.New("cache_ttl 不能为负数")
|
||||
if err := ValidateNonNegativeDuration(g.CacheTTL, "cache_ttl"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证默认动作
|
||||
@ -883,18 +916,18 @@ func validateHTTP2(h *HTTP2Config, hasSSL bool) error {
|
||||
}
|
||||
|
||||
// 验证并发流数量
|
||||
if h.MaxConcurrentStreams < 0 {
|
||||
return errors.New("max_concurrent_streams 不能为负数")
|
||||
if err := ValidateNonNegative(h.MaxConcurrentStreams, "max_concurrent_streams"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证头部大小限制
|
||||
if h.MaxHeaderListSize < 0 {
|
||||
return errors.New("max_header_list_size 不能为负数")
|
||||
if err := ValidateNonNegative(h.MaxHeaderListSize, "max_header_list_size"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证空闲超时
|
||||
if h.IdleTimeout < 0 {
|
||||
return errors.New("idle_timeout 不能为负数")
|
||||
if err := ValidateNonNegativeDuration(h.IdleTimeout, "idle_timeout"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -1085,8 +1118,8 @@ func validateStream(s *StreamConfig) error {
|
||||
// - error: 验证失败时返回错误信息,成功返回 nil
|
||||
func validatePerformance(p *PerformanceConfig) error {
|
||||
// 检查 Transport 配置(可能导致性能问题)
|
||||
if p.Transport.MaxConnsPerHost < 0 {
|
||||
return errors.New("transport.max_conns_per_host 不能为负数")
|
||||
if err := ValidateNonNegative(p.Transport.MaxConnsPerHost, "transport.max_conns_per_host"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -1112,8 +1145,8 @@ func validateNextUpstream(n *NextUpstreamConfig) error {
|
||||
}
|
||||
|
||||
// 验证重试次数
|
||||
if n.Tries < 0 {
|
||||
return errors.New("tries 不能为负数")
|
||||
if err := ValidateNonNegative(n.Tries, "tries"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证 HTTP 状态码
|
||||
@ -1204,26 +1237,26 @@ func validateLua(l *LuaMiddlewareConfig) error {
|
||||
}
|
||||
|
||||
// 超时时间验证
|
||||
if script.Timeout < 0 {
|
||||
return fmt.Errorf("scripts[%d].timeout 不能为负数", i)
|
||||
if err := ValidateNonNegativeDuration(script.Timeout, fmt.Sprintf("scripts[%d].timeout", i)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 验证全局设置
|
||||
if l.GlobalSettings.MaxConcurrentCoroutines < 0 {
|
||||
return errors.New("global_settings.max_concurrent_coroutines 不能为负数")
|
||||
if err := ValidateNonNegative(l.GlobalSettings.MaxConcurrentCoroutines, "global_settings.max_concurrent_coroutines"); err != nil {
|
||||
return err
|
||||
}
|
||||
if l.GlobalSettings.MaxConcurrentCoroutines > 0 && l.GlobalSettings.MaxConcurrentCoroutines < 1 {
|
||||
return errors.New("global_settings.max_concurrent_coroutines 至少为 1")
|
||||
}
|
||||
if l.GlobalSettings.CoroutineTimeout < 0 {
|
||||
return errors.New("global_settings.coroutine_timeout 不能为负数")
|
||||
if err := ValidateNonNegativeDuration(l.GlobalSettings.CoroutineTimeout, "global_settings.coroutine_timeout"); err != nil {
|
||||
return err
|
||||
}
|
||||
if l.GlobalSettings.CodeCacheSize < 0 {
|
||||
return errors.New("global_settings.code_cache_size 不能为负数")
|
||||
if err := ValidateNonNegative(l.GlobalSettings.CodeCacheSize, "global_settings.code_cache_size"); err != nil {
|
||||
return err
|
||||
}
|
||||
if l.GlobalSettings.MaxExecutionTime < 0 {
|
||||
return errors.New("global_settings.max_execution_time 不能为负数")
|
||||
if err := ValidateNonNegativeDuration(l.GlobalSettings.MaxExecutionTime, "global_settings.max_execution_time"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -114,18 +114,7 @@ func newNgxLogAPI(ctx *fasthttp.RequestCtx, luaCtx *LuaContext, logger *zerolog.
|
||||
// 每次请求都会重新注册请求特定的函数(log, say, print, flush, exit, redirect)。
|
||||
func RegisterNgxLogAPI(L *glua.LState, api *ngxLogAPI) {
|
||||
// 获取或创建 ngx 表
|
||||
var ngx *glua.LTable
|
||||
existingNgx := L.GetGlobal("ngx")
|
||||
if existingNgx != nil && existingNgx.Type() == glua.LTTable {
|
||||
ngxTable, ok := existingNgx.(*glua.LTable)
|
||||
if ok {
|
||||
ngx = ngxTable
|
||||
} else {
|
||||
ngx = L.NewTable()
|
||||
}
|
||||
} else {
|
||||
ngx = L.NewTable()
|
||||
}
|
||||
ngx := GetOrCreateNgxTable(L)
|
||||
|
||||
// 检查常量是否已注册(通过 STDERR 常量判断)
|
||||
// 如果已注册,跳过常量写入,避免并发写入全局表
|
||||
|
||||
@ -105,15 +105,8 @@ func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI {
|
||||
// RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API
|
||||
// 这是主入口函数,由 LuaEngine 在初始化时调用
|
||||
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI, ngxTable *glua.LTable) {
|
||||
// 检查 ngx.req 是否已存在,避免并发写入
|
||||
var ngxReq *glua.LTable
|
||||
if existingReq := ngxTable.RawGetString("req"); existingReq == glua.LNil {
|
||||
// 首次创建 ngx.req 子表
|
||||
ngxReq = L.NewTable()
|
||||
ngxTable.RawSetString("req", ngxReq)
|
||||
} else {
|
||||
ngxReq = existingReq.(*glua.LTable)
|
||||
}
|
||||
// 获取或创建 ngx.req 子表
|
||||
ngxReq := GetOrCreateNgxSubTable(ngxTable, L, "req")
|
||||
|
||||
// 直接映射层 API:get_method
|
||||
// 特点:直接访问 fasthttp.RequestCtx,零拷贝,最小开销
|
||||
|
||||
@ -46,29 +46,11 @@ func newNgxRespAPI(ctx *fasthttp.RequestCtx) *ngxRespAPI {
|
||||
// RegisterNgxRespAPI 在 Lua 状态机中注册 ngx.resp API
|
||||
// 这是主入口函数,由 LuaEngine 在初始化时调用
|
||||
func RegisterNgxRespAPI(L *glua.LState, api *ngxRespAPI) {
|
||||
// 获取已存在的 ngx 表(必须已设置全局)
|
||||
ngx := L.GetGlobal("ngx")
|
||||
if ngx == nil || ngx.Type() != glua.LTTable {
|
||||
// 如果不存在,创建新表并设置全局
|
||||
ngx = L.NewTable()
|
||||
L.SetGlobal("ngx", ngx)
|
||||
}
|
||||
// 获取或创建 ngx 表
|
||||
ngxTable := GetOrCreateNgxTable(L)
|
||||
|
||||
// 类型断言检查
|
||||
ngxTable, ok := ngx.(*glua.LTable)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 ngx.resp 是否已存在,避免并发写入
|
||||
var ngxResp *glua.LTable
|
||||
if existingResp := ngxTable.RawGetString("resp"); existingResp == glua.LNil {
|
||||
// 首次创建 ngx.resp 子表
|
||||
ngxResp = L.NewTable()
|
||||
ngxTable.RawSetString("resp", ngxResp)
|
||||
} else {
|
||||
ngxResp = existingResp.(*glua.LTable)
|
||||
}
|
||||
// 获取或创建 ngx.resp 子表
|
||||
ngxResp := GetOrCreateNgxSubTable(ngxTable, L, "resp")
|
||||
|
||||
// 每次请求更新函数以绑定正确的 ctx
|
||||
ngxResp.RawSetString("get_status", L.NewFunction(api.luaGetStatus))
|
||||
|
||||
@ -537,26 +537,11 @@ const tcpSocketMT = "tcp_socket"
|
||||
|
||||
// RegisterTCPSocketAPI 注册 TCP socket API
|
||||
func RegisterTCPSocketAPI(L *glua.LState, engine *LuaEngine) {
|
||||
// 确保 ngx 表存在
|
||||
ngx := L.GetGlobal("ngx")
|
||||
var ngxTbl *glua.LTable
|
||||
if tbl, ok := ngx.(*glua.LTable); ok {
|
||||
ngxTbl = tbl
|
||||
} else {
|
||||
// 创建 ngx 表
|
||||
ngxTbl = L.NewTable()
|
||||
L.SetGlobal("ngx", ngxTbl)
|
||||
}
|
||||
// 获取或创建 ngx 表
|
||||
ngxTbl := GetOrCreateNgxTable(L)
|
||||
|
||||
// 检查 ngx.socket 是否已存在,避免并发写入
|
||||
var socket *glua.LTable
|
||||
if existing := ngxTbl.RawGetString("socket"); existing == glua.LNil {
|
||||
// 首次创建 ngx.socket 表
|
||||
socket = L.NewTable()
|
||||
ngxTbl.RawSetString("socket", socket)
|
||||
} else {
|
||||
socket = existing.(*glua.LTable)
|
||||
}
|
||||
// 获取或创建 ngx.socket 子表
|
||||
socket := GetOrCreateNgxSubTable(ngxTbl, L, "socket")
|
||||
|
||||
// 每次请求更新 tcp 函数以绑定正确的 engine
|
||||
socket.RawSetString("tcp", L.NewFunction(newTCPSocketFunc(engine)))
|
||||
|
||||
56
internal/lua/helpers.go
Normal file
56
internal/lua/helpers.go
Normal file
@ -0,0 +1,56 @@
|
||||
// Package lua 提供 Lua API 辅助函数。
|
||||
//
|
||||
// 该文件包含 ngx 表操作的共享辅助函数,用于减少代码重复。
|
||||
// 所有函数保持并发安全设计(使用 RawGetString/RawSetString)。
|
||||
//
|
||||
// 作者:xfy
|
||||
package lua
|
||||
|
||||
import glua "github.com/yuin/gopher-lua"
|
||||
|
||||
// GetOrCreateNgxTable 获取或创建全局 ngx 表。
|
||||
//
|
||||
// 如果全局 ngx 表已存在,则返回现有表;否则创建新表并设置为全局变量。
|
||||
// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。
|
||||
//
|
||||
// 参数:
|
||||
// - L: Lua 状态机
|
||||
//
|
||||
// 返回值:
|
||||
// - *glua.LTable: ngx 表
|
||||
func GetOrCreateNgxTable(L *glua.LState) *glua.LTable {
|
||||
ngx := L.GetGlobal("ngx")
|
||||
if ngx != nil && ngx.Type() == glua.LTTable {
|
||||
if tbl, ok := ngx.(*glua.LTable); ok {
|
||||
return tbl
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新的 ngx 表
|
||||
tbl := L.NewTable()
|
||||
L.SetGlobal("ngx", tbl)
|
||||
return tbl
|
||||
}
|
||||
|
||||
// GetOrCreateNgxSubTable 获取或创建 ngx 子表。
|
||||
//
|
||||
// 如果子表已存在,则返回现有表;否则创建新子表并设置到父表。
|
||||
// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。
|
||||
//
|
||||
// 参数:
|
||||
// - ngx: 父表(通常是 ngx 表)
|
||||
// - L: Lua 状态机
|
||||
// - name: 子表名称(如 "req", "resp", "socket")
|
||||
//
|
||||
// 返回值:
|
||||
// - *glua.LTable: 子表
|
||||
func GetOrCreateNgxSubTable(ngx *glua.LTable, L *glua.LState, name string) *glua.LTable {
|
||||
existing := ngx.RawGetString(name)
|
||||
if existing == glua.LNil {
|
||||
// 首次创建子表
|
||||
sub := L.NewTable()
|
||||
ngx.RawSetString(name, sub)
|
||||
return sub
|
||||
}
|
||||
return existing.(*glua.LTable)
|
||||
}
|
||||
@ -20,6 +20,7 @@ package compression
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@ -29,6 +30,63 @@ import (
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// resettableWriteCloser 接口用于统一 gzip.Writer 和 brotli.Writer 的操作。
|
||||
// 这两个 writer 都实现了 Reset(io.Writer), Write([]byte) (int, error), Close() error
|
||||
type resettableWriteCloser interface {
|
||||
Reset(w io.Writer)
|
||||
Write(data []byte) (int, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// compressorPool 是一个通用的压缩 writer 池。
|
||||
type compressorPool struct {
|
||||
pool sync.Pool
|
||||
level int
|
||||
factory func(level int) resettableWriteCloser
|
||||
}
|
||||
|
||||
// newGzipPool 创建 gzip writer 池。
|
||||
func newGzipPool(level int) *compressorPool {
|
||||
return &compressorPool{
|
||||
level: level,
|
||||
factory: func(level int) resettableWriteCloser {
|
||||
w, err := gzip.NewWriterLevel(nil, level)
|
||||
if err != nil {
|
||||
w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression)
|
||||
}
|
||||
return w
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newBrotliPool 创建 brotli writer 池。
|
||||
func newBrotliPool(level int) *compressorPool {
|
||||
return &compressorPool{
|
||||
level: level,
|
||||
factory: func(level int) resettableWriteCloser {
|
||||
return brotli.NewWriterOptions(nil, brotli.WriterOptions{
|
||||
Quality: level,
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get 从池中获取 writer。
|
||||
func (p *compressorPool) Get() (resettableWriteCloser, bool) {
|
||||
if p.pool.New == nil {
|
||||
p.pool.New = func() any {
|
||||
return p.factory(p.level)
|
||||
}
|
||||
}
|
||||
v, ok := p.pool.Get().(resettableWriteCloser)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Put 将 writer 放回池中。
|
||||
func (p *compressorPool) Put(w resettableWriteCloser) {
|
||||
p.pool.Put(w)
|
||||
}
|
||||
|
||||
// streamingThreshold 流式压缩阈值。
|
||||
// 响应体超过此大小时使用 SetBodyStreamWriter 流式压缩,
|
||||
// 消除 compressed buffer 分配,降低内存峰值。
|
||||
@ -49,9 +107,9 @@ const (
|
||||
// Middleware 响应压缩中间件。
|
||||
type Middleware struct {
|
||||
// gzipPool gzip.Writer 缓冲池
|
||||
gzipPool sync.Pool
|
||||
gzipPool *compressorPool
|
||||
// brotliPool brotli.Writer 缓冲池
|
||||
brotliPool sync.Pool
|
||||
brotliPool *compressorPool
|
||||
// types 可压缩的 MIME 类型列表
|
||||
types []string
|
||||
|
||||
@ -114,25 +172,8 @@ func New(cfg *config.CompressionConfig) (*Middleware, error) {
|
||||
}
|
||||
|
||||
// 初始化缓冲池
|
||||
m.gzipPool = sync.Pool{
|
||||
New: func() any {
|
||||
w, err := gzip.NewWriterLevel(nil, cfg.Level)
|
||||
if err != nil {
|
||||
// 使用默认压缩级别作为回退
|
||||
w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression)
|
||||
}
|
||||
return w
|
||||
},
|
||||
}
|
||||
|
||||
// 初始化 brotli 缓冲池
|
||||
m.brotliPool = sync.Pool{
|
||||
New: func() any {
|
||||
return brotli.NewWriterOptions(nil, brotli.WriterOptions{
|
||||
Quality: cfg.Level,
|
||||
})
|
||||
},
|
||||
}
|
||||
m.gzipPool = newGzipPool(cfg.Level)
|
||||
m.brotliPool = newBrotliPool(cfg.Level)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@ -224,18 +265,18 @@ func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandl
|
||||
if bodyLen > streamingThreshold {
|
||||
// 大响应:流式压缩,消除 compressed buffer 分配
|
||||
if useBrotli {
|
||||
m.streamBrotli(ctx, encoding)
|
||||
m.streamWithPool(ctx, encoding, m.brotliPool)
|
||||
} else if useGzip {
|
||||
m.streamGzip(ctx, encoding)
|
||||
m.streamWithPool(ctx, encoding, m.gzipPool)
|
||||
}
|
||||
} else {
|
||||
// 小响应:缓冲压缩
|
||||
var compressed []byte
|
||||
|
||||
if useBrotli {
|
||||
compressed = m.compressBrotli(body)
|
||||
compressed = m.compressWithPool(body, m.brotliPool)
|
||||
} else if useGzip {
|
||||
compressed = m.compressGzip(body)
|
||||
compressed = m.compressWithPool(body, m.gzipPool)
|
||||
}
|
||||
|
||||
if len(compressed) > 0 && len(compressed) < bodyLen {
|
||||
@ -276,19 +317,20 @@ func (m *Middleware) isCompressible(contentType []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// compressGzip 使用 gzip 压缩数据。
|
||||
// compressWithPool 使用缓冲池压缩数据。
|
||||
//
|
||||
// 参数:
|
||||
// - data: 待压缩的原始数据
|
||||
// - pool: 压缩 writer 缓冲池
|
||||
//
|
||||
// 返回值:
|
||||
// - []byte: 压缩后的数据
|
||||
func (m *Middleware) compressGzip(data []byte) []byte {
|
||||
w, ok := m.gzipPool.Get().(*gzip.Writer)
|
||||
func (m *Middleware) compressWithPool(data []byte, pool *compressorPool) []byte {
|
||||
w, ok := pool.Get()
|
||||
if !ok {
|
||||
return data // fallback to uncompressed
|
||||
}
|
||||
defer m.gzipPool.Put(w)
|
||||
defer pool.Put(w)
|
||||
|
||||
var buf bytes.Buffer
|
||||
w.Reset(&buf)
|
||||
@ -300,28 +342,35 @@ func (m *Middleware) compressGzip(data []byte) []byte {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// compressBrotli 使用 brotli 压缩数据。
|
||||
// streamWithPool 使用流式压缩。
|
||||
//
|
||||
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
|
||||
// 消除 compressed buffer 分配,降低内存峰值。
|
||||
//
|
||||
// 参数:
|
||||
// - data: 待压缩的原始数据
|
||||
//
|
||||
// 返回值:
|
||||
// - []byte: 压缩后的数据
|
||||
func (m *Middleware) compressBrotli(data []byte) []byte {
|
||||
w, ok := m.brotliPool.Get().(*brotli.Writer)
|
||||
if !ok {
|
||||
return data // fallback to uncompressed
|
||||
}
|
||||
defer m.brotliPool.Put(w)
|
||||
// - ctx: fasthttp 请求上下文
|
||||
// - encoding: Content-Encoding 值
|
||||
// - pool: 压缩 writer 缓冲池
|
||||
func (m *Middleware) streamWithPool(ctx *fasthttp.RequestCtx, encoding string, pool *compressorPool) {
|
||||
ctx.Response.Header.Set("Content-Encoding", encoding)
|
||||
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
|
||||
|
||||
var buf bytes.Buffer
|
||||
w.Reset(&buf)
|
||||
if _, err := w.Write(data); err != nil { //nolint:staticcheck // intentionally empty branch
|
||||
// 忽略写入错误,缓冲到 bytes.Buffer 时不太可能失败
|
||||
}
|
||||
_ = w.Close()
|
||||
body := ctx.Response.Body()
|
||||
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
writer, ok := pool.Get()
|
||||
if !ok {
|
||||
// pool 获取失败,直接写原始 body
|
||||
_, _ = w.Write(body)
|
||||
_ = w.Flush()
|
||||
return
|
||||
}
|
||||
defer pool.Put(writer)
|
||||
|
||||
return buf.Bytes()
|
||||
writer.Reset(w)
|
||||
_, _ = writer.Write(body)
|
||||
_ = writer.Close()
|
||||
_ = w.Flush()
|
||||
})
|
||||
}
|
||||
|
||||
// Types 返回可压缩的 MIME 类型列表。
|
||||
@ -347,63 +396,3 @@ func (m *Middleware) Level() int {
|
||||
func (m *Middleware) MinSize() int {
|
||||
return m.minSize
|
||||
}
|
||||
|
||||
// streamGzip 使用 gzip 流式压缩。
|
||||
//
|
||||
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
|
||||
// 消除 compressed buffer 分配,降低内存峰值。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: fasthttp 请求上下文
|
||||
// - encoding: Content-Encoding 值("gzip")
|
||||
func (m *Middleware) streamGzip(ctx *fasthttp.RequestCtx, encoding string) {
|
||||
ctx.Response.Header.Set("Content-Encoding", encoding)
|
||||
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
|
||||
|
||||
body := ctx.Response.Body()
|
||||
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
writer, ok := m.gzipPool.Get().(*gzip.Writer)
|
||||
if !ok {
|
||||
// pool 获取失败,直接写原始 body
|
||||
_, _ = w.Write(body)
|
||||
_ = w.Flush()
|
||||
return
|
||||
}
|
||||
defer m.gzipPool.Put(writer)
|
||||
|
||||
writer.Reset(w)
|
||||
_, _ = writer.Write(body)
|
||||
_ = writer.Close()
|
||||
_ = w.Flush()
|
||||
})
|
||||
}
|
||||
|
||||
// streamBrotli 使用 brotli 流式压缩。
|
||||
//
|
||||
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
|
||||
// 消除 compressed buffer 分配,降低内存峰值。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: fasthttp 请求上下文
|
||||
// - encoding: Content-Encoding 值("br")
|
||||
func (m *Middleware) streamBrotli(ctx *fasthttp.RequestCtx, encoding string) {
|
||||
ctx.Response.Header.Set("Content-Encoding", encoding)
|
||||
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
|
||||
|
||||
body := ctx.Response.Body()
|
||||
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||
writer, ok := m.brotliPool.Get().(*brotli.Writer)
|
||||
if !ok {
|
||||
// pool 获取失败,直接写原始 body
|
||||
_, _ = w.Write(body)
|
||||
_ = w.Flush()
|
||||
return
|
||||
}
|
||||
defer m.brotliPool.Put(writer)
|
||||
|
||||
writer.Reset(w)
|
||||
_, _ = writer.Write(body)
|
||||
_ = writer.Close()
|
||||
_ = w.Flush()
|
||||
})
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@ func BenchmarkGzipCompress_1KB(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressGzip(data)
|
||||
mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ func BenchmarkGzipCompress_10KB(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressGzip(data)
|
||||
mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,7 +61,7 @@ func BenchmarkGzipCompress_100KB(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressGzip(data)
|
||||
mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
}
|
||||
|
||||
@ -78,7 +78,7 @@ func BenchmarkBrotliCompress_1KB(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressBrotli(data)
|
||||
mw.compressWithPool(data, mw.brotliPool)
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,7 +95,7 @@ func BenchmarkBrotliCompress_10KB(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressBrotli(data)
|
||||
mw.compressWithPool(data, mw.brotliPool)
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ func BenchmarkCompressionPool(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressGzip(data)
|
||||
mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
}
|
||||
|
||||
@ -222,7 +222,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
|
||||
mw, _ := New(cfg)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressGzip(data)
|
||||
mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
})
|
||||
|
||||
@ -231,7 +231,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
|
||||
mw, _ := New(cfg)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressGzip(data)
|
||||
mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
})
|
||||
|
||||
@ -240,7 +240,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
|
||||
mw, _ := New(cfg)
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
mw.compressGzip(data)
|
||||
mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -134,7 +134,7 @@ func TestCompressGzip(t *testing.T) {
|
||||
// 测试数据
|
||||
data := []byte("Hello, World! This is a test string that should be compressed.")
|
||||
|
||||
compressed := m.compressGzip(data)
|
||||
compressed := m.compressWithPool(data, m.gzipPool)
|
||||
if len(compressed) == 0 {
|
||||
t.Error("Expected compressed data")
|
||||
}
|
||||
@ -153,7 +153,7 @@ func TestCompressBrotli(t *testing.T) {
|
||||
|
||||
data := []byte("Hello, World! This is a test string that should be compressed with brotli.")
|
||||
|
||||
compressed := m.compressBrotli(data)
|
||||
compressed := m.compressWithPool(data, m.brotliPool)
|
||||
if len(compressed) == 0 {
|
||||
t.Error("Expected compressed data")
|
||||
}
|
||||
|
||||
@ -83,7 +83,7 @@ func BenchmarkGzipWriter_Pool(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
|
||||
for b.Loop() {
|
||||
_ = mw.compressGzip(data)
|
||||
_ = mw.compressWithPool(data, mw.gzipPool)
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,7 +145,7 @@ func BenchmarkGzipCompress_Sizes(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for b.Loop() {
|
||||
_ = mw.compressGzip(tc.data)
|
||||
_ = mw.compressWithPool(tc.data, mw.gzipPool)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user