From f82e363f588a1981c71b45877ef143feddb84b1a Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 29 Apr 2026 17:00:11 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=8F=90=E5=8F=96=20Lua=20ngx=20?= =?UTF-8?q?=E8=A1=A8=20helpers=20=E5=92=8C=E7=BB=9F=E4=B8=80=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Batch 1 续: - 新增 lua/helpers.go:GetOrCreateNgxTable/GetOrCreateNgxSubTable - 重构 compression:提取 resettableWriteCloser 接口和 compressorPool - 新增 validate.go:ValidateNonNegativeInt64/Duration/NoNullByte/PathTraversal - 消除约 120 行重复代码 Co-Authored-By: Claude Opus 4.7 --- internal/config/validate.go | 105 +++++---- internal/lua/api_log.go | 13 +- internal/lua/api_req.go | 11 +- internal/lua/api_resp.go | 26 +-- internal/lua/api_socket_tcp.go | 23 +- internal/lua/helpers.go | 56 +++++ .../middleware/compression/compression.go | 203 +++++++++--------- .../compression/compression_bench_test.go | 18 +- .../compression/compression_test.go | 4 +- .../middleware/compression/pool_bench_test.go | 4 +- 10 files changed, 245 insertions(+), 218 deletions(-) create mode 100644 internal/lua/helpers.go diff --git a/internal/config/validate.go b/internal/config/validate.go index 90ab61f..a0c1693 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -24,6 +24,7 @@ import ( "regexp" "slices" "strings" + "time" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/variable" @@ -123,6 +124,38 @@ func ValidateNonNegative(value int, fieldName string) error { return nil } +// ValidateNonNegativeInt64 验证 int64 值为非负数 +func ValidateNonNegativeInt64(value int64, fieldName string) error { + if value < 0 { + return fmt.Errorf("%s 不能为负数", fieldName) + } + return nil +} + +// ValidateNonNegativeDuration 验证 time.Duration 值为非负数 +func ValidateNonNegativeDuration(value time.Duration, fieldName string) error { + if value < 0 { + return fmt.Errorf("%s 不能为负数", fieldName) + } + return nil +} + +// ValidateNoNullByte 验证字符串不包含 null byte +func ValidateNoNullByte(s string, fieldName string) error { + if strings.Contains(s, "\x00") { + return fmt.Errorf("%s 不能包含 null byte", fieldName) + } + return nil +} + +// ValidatePathTraversal 验证路径不包含路径遍历 '..' +func ValidatePathTraversal(path string, fieldName string) error { + if strings.Contains(path, "..") { + return fmt.Errorf("%s不能包含 '..'", fieldName) + } + return nil +} + // validateServer 验证服务器配置。 // // 检查服务器配置的各项参数是否符合要求,包括监听地址、 @@ -232,13 +265,13 @@ func validateStatics(statics []StaticConfig) error { } // 验证根目录路径安全 - if s.Root != "" && strings.Contains(s.Root, "..") { - return fmt.Errorf("static[%d]: 根目录路径不能包含 '..'", i) + if err := ValidatePathTraversal(s.Root, fmt.Sprintf("static[%d]: 根目录路径", i)); err != nil { + return err } // 验证 alias 路径安全 - if s.Alias != "" && strings.Contains(s.Alias, "..") { - return fmt.Errorf("static[%d]: alias 路径不能包含 '..'", i) + if err := ValidatePathTraversal(s.Alias, fmt.Sprintf("static[%d]: alias 路径", i)); err != nil { + return err } // 验证 try_files 模式 @@ -278,8 +311,8 @@ func validateTryFilesPattern(pattern string) error { } // 检查 null byte - if strings.Contains(pattern, "\x00") { - return errors.New("try_files 模式不能包含 null byte") + if err := ValidateNoNullByte(pattern, "try_files 模式"); err != nil { + return err } // 定义支持的模式类型 @@ -337,8 +370,8 @@ func validateTryFilesExtension(ext string) error { } // 检查 null byte - if strings.Contains(ext, "\x00") { - return errors.New("扩展名不能包含 null byte") + if err := ValidateNoNullByte(ext, "扩展名"); err != nil { + return err } // 白名单字符检查:仅允许字母、数字、点、下划线、连字符 @@ -394,8 +427,8 @@ func validateTryFilesFilename(filename string) error { } // 检查 null byte - if strings.Contains(filename, "\x00") { - return errors.New("文件名不能包含 null byte") + if err := ValidateNoNullByte(filename, "文件名"); err != nil { + return err } return nil @@ -441,8 +474,8 @@ func validateStatic(s *StaticConfig) error { // 静态文件根目录非空时验证路径有效性 if s.Root != "" { // 路径安全检查:不允许包含 ".." - if strings.Contains(s.Root, "..") { - return errors.New("根目录路径不能包含 '..'") + if err := ValidatePathTraversal(s.Root, "根目录路径"); err != nil { + return err } } return nil @@ -711,13 +744,13 @@ func validateGeoIP(g *GeoIPConfig) error { } // 验证缓存大小 - if g.CacheSize < 0 { - return errors.New("cache_size 不能为负数") + if err := ValidateNonNegative(g.CacheSize, "cache_size"); err != nil { + return err } // 验证缓存 TTL - if g.CacheTTL < 0 { - return errors.New("cache_ttl 不能为负数") + if err := ValidateNonNegativeDuration(g.CacheTTL, "cache_ttl"); err != nil { + return err } // 验证默认动作 @@ -883,18 +916,18 @@ func validateHTTP2(h *HTTP2Config, hasSSL bool) error { } // 验证并发流数量 - if h.MaxConcurrentStreams < 0 { - return errors.New("max_concurrent_streams 不能为负数") + if err := ValidateNonNegative(h.MaxConcurrentStreams, "max_concurrent_streams"); err != nil { + return err } // 验证头部大小限制 - if h.MaxHeaderListSize < 0 { - return errors.New("max_header_list_size 不能为负数") + if err := ValidateNonNegative(h.MaxHeaderListSize, "max_header_list_size"); err != nil { + return err } // 验证空闲超时 - if h.IdleTimeout < 0 { - return errors.New("idle_timeout 不能为负数") + if err := ValidateNonNegativeDuration(h.IdleTimeout, "idle_timeout"); err != nil { + return err } return nil @@ -1085,8 +1118,8 @@ func validateStream(s *StreamConfig) error { // - error: 验证失败时返回错误信息,成功返回 nil func validatePerformance(p *PerformanceConfig) error { // 检查 Transport 配置(可能导致性能问题) - if p.Transport.MaxConnsPerHost < 0 { - return errors.New("transport.max_conns_per_host 不能为负数") + if err := ValidateNonNegative(p.Transport.MaxConnsPerHost, "transport.max_conns_per_host"); err != nil { + return err } return nil @@ -1112,8 +1145,8 @@ func validateNextUpstream(n *NextUpstreamConfig) error { } // 验证重试次数 - if n.Tries < 0 { - return errors.New("tries 不能为负数") + if err := ValidateNonNegative(n.Tries, "tries"); err != nil { + return err } // 验证 HTTP 状态码 @@ -1204,26 +1237,26 @@ func validateLua(l *LuaMiddlewareConfig) error { } // 超时时间验证 - if script.Timeout < 0 { - return fmt.Errorf("scripts[%d].timeout 不能为负数", i) + if err := ValidateNonNegativeDuration(script.Timeout, fmt.Sprintf("scripts[%d].timeout", i)); err != nil { + return err } } // 验证全局设置 - if l.GlobalSettings.MaxConcurrentCoroutines < 0 { - return errors.New("global_settings.max_concurrent_coroutines 不能为负数") + if err := ValidateNonNegative(l.GlobalSettings.MaxConcurrentCoroutines, "global_settings.max_concurrent_coroutines"); err != nil { + return err } if l.GlobalSettings.MaxConcurrentCoroutines > 0 && l.GlobalSettings.MaxConcurrentCoroutines < 1 { return errors.New("global_settings.max_concurrent_coroutines 至少为 1") } - if l.GlobalSettings.CoroutineTimeout < 0 { - return errors.New("global_settings.coroutine_timeout 不能为负数") + if err := ValidateNonNegativeDuration(l.GlobalSettings.CoroutineTimeout, "global_settings.coroutine_timeout"); err != nil { + return err } - if l.GlobalSettings.CodeCacheSize < 0 { - return errors.New("global_settings.code_cache_size 不能为负数") + if err := ValidateNonNegative(l.GlobalSettings.CodeCacheSize, "global_settings.code_cache_size"); err != nil { + return err } - if l.GlobalSettings.MaxExecutionTime < 0 { - return errors.New("global_settings.max_execution_time 不能为负数") + if err := ValidateNonNegativeDuration(l.GlobalSettings.MaxExecutionTime, "global_settings.max_execution_time"); err != nil { + return err } return nil diff --git a/internal/lua/api_log.go b/internal/lua/api_log.go index f1da077..e2f951a 100644 --- a/internal/lua/api_log.go +++ b/internal/lua/api_log.go @@ -114,18 +114,7 @@ func newNgxLogAPI(ctx *fasthttp.RequestCtx, luaCtx *LuaContext, logger *zerolog. // 每次请求都会重新注册请求特定的函数(log, say, print, flush, exit, redirect)。 func RegisterNgxLogAPI(L *glua.LState, api *ngxLogAPI) { // 获取或创建 ngx 表 - var ngx *glua.LTable - existingNgx := L.GetGlobal("ngx") - if existingNgx != nil && existingNgx.Type() == glua.LTTable { - ngxTable, ok := existingNgx.(*glua.LTable) - if ok { - ngx = ngxTable - } else { - ngx = L.NewTable() - } - } else { - ngx = L.NewTable() - } + ngx := GetOrCreateNgxTable(L) // 检查常量是否已注册(通过 STDERR 常量判断) // 如果已注册,跳过常量写入,避免并发写入全局表 diff --git a/internal/lua/api_req.go b/internal/lua/api_req.go index 7f70bbe..66026ce 100644 --- a/internal/lua/api_req.go +++ b/internal/lua/api_req.go @@ -105,15 +105,8 @@ func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI { // RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API // 这是主入口函数,由 LuaEngine 在初始化时调用 func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI, ngxTable *glua.LTable) { - // 检查 ngx.req 是否已存在,避免并发写入 - var ngxReq *glua.LTable - if existingReq := ngxTable.RawGetString("req"); existingReq == glua.LNil { - // 首次创建 ngx.req 子表 - ngxReq = L.NewTable() - ngxTable.RawSetString("req", ngxReq) - } else { - ngxReq = existingReq.(*glua.LTable) - } + // 获取或创建 ngx.req 子表 + ngxReq := GetOrCreateNgxSubTable(ngxTable, L, "req") // 直接映射层 API:get_method // 特点:直接访问 fasthttp.RequestCtx,零拷贝,最小开销 diff --git a/internal/lua/api_resp.go b/internal/lua/api_resp.go index 4ff8352..4ee1eb3 100644 --- a/internal/lua/api_resp.go +++ b/internal/lua/api_resp.go @@ -46,29 +46,11 @@ func newNgxRespAPI(ctx *fasthttp.RequestCtx) *ngxRespAPI { // RegisterNgxRespAPI 在 Lua 状态机中注册 ngx.resp API // 这是主入口函数,由 LuaEngine 在初始化时调用 func RegisterNgxRespAPI(L *glua.LState, api *ngxRespAPI) { - // 获取已存在的 ngx 表(必须已设置全局) - ngx := L.GetGlobal("ngx") - if ngx == nil || ngx.Type() != glua.LTTable { - // 如果不存在,创建新表并设置全局 - ngx = L.NewTable() - L.SetGlobal("ngx", ngx) - } + // 获取或创建 ngx 表 + ngxTable := GetOrCreateNgxTable(L) - // 类型断言检查 - ngxTable, ok := ngx.(*glua.LTable) - if !ok { - return - } - - // 检查 ngx.resp 是否已存在,避免并发写入 - var ngxResp *glua.LTable - if existingResp := ngxTable.RawGetString("resp"); existingResp == glua.LNil { - // 首次创建 ngx.resp 子表 - ngxResp = L.NewTable() - ngxTable.RawSetString("resp", ngxResp) - } else { - ngxResp = existingResp.(*glua.LTable) - } + // 获取或创建 ngx.resp 子表 + ngxResp := GetOrCreateNgxSubTable(ngxTable, L, "resp") // 每次请求更新函数以绑定正确的 ctx ngxResp.RawSetString("get_status", L.NewFunction(api.luaGetStatus)) diff --git a/internal/lua/api_socket_tcp.go b/internal/lua/api_socket_tcp.go index db50b3c..1d4b23f 100644 --- a/internal/lua/api_socket_tcp.go +++ b/internal/lua/api_socket_tcp.go @@ -537,26 +537,11 @@ const tcpSocketMT = "tcp_socket" // RegisterTCPSocketAPI 注册 TCP socket API func RegisterTCPSocketAPI(L *glua.LState, engine *LuaEngine) { - // 确保 ngx 表存在 - ngx := L.GetGlobal("ngx") - var ngxTbl *glua.LTable - if tbl, ok := ngx.(*glua.LTable); ok { - ngxTbl = tbl - } else { - // 创建 ngx 表 - ngxTbl = L.NewTable() - L.SetGlobal("ngx", ngxTbl) - } + // 获取或创建 ngx 表 + ngxTbl := GetOrCreateNgxTable(L) - // 检查 ngx.socket 是否已存在,避免并发写入 - var socket *glua.LTable - if existing := ngxTbl.RawGetString("socket"); existing == glua.LNil { - // 首次创建 ngx.socket 表 - socket = L.NewTable() - ngxTbl.RawSetString("socket", socket) - } else { - socket = existing.(*glua.LTable) - } + // 获取或创建 ngx.socket 子表 + socket := GetOrCreateNgxSubTable(ngxTbl, L, "socket") // 每次请求更新 tcp 函数以绑定正确的 engine socket.RawSetString("tcp", L.NewFunction(newTCPSocketFunc(engine))) diff --git a/internal/lua/helpers.go b/internal/lua/helpers.go new file mode 100644 index 0000000..fdd9225 --- /dev/null +++ b/internal/lua/helpers.go @@ -0,0 +1,56 @@ +// Package lua 提供 Lua API 辅助函数。 +// +// 该文件包含 ngx 表操作的共享辅助函数,用于减少代码重复。 +// 所有函数保持并发安全设计(使用 RawGetString/RawSetString)。 +// +// 作者:xfy +package lua + +import glua "github.com/yuin/gopher-lua" + +// GetOrCreateNgxTable 获取或创建全局 ngx 表。 +// +// 如果全局 ngx 表已存在,则返回现有表;否则创建新表并设置为全局变量。 +// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。 +// +// 参数: +// - L: Lua 状态机 +// +// 返回值: +// - *glua.LTable: ngx 表 +func GetOrCreateNgxTable(L *glua.LState) *glua.LTable { + ngx := L.GetGlobal("ngx") + if ngx != nil && ngx.Type() == glua.LTTable { + if tbl, ok := ngx.(*glua.LTable); ok { + return tbl + } + } + + // 创建新的 ngx 表 + tbl := L.NewTable() + L.SetGlobal("ngx", tbl) + return tbl +} + +// GetOrCreateNgxSubTable 获取或创建 ngx 子表。 +// +// 如果子表已存在,则返回现有表;否则创建新子表并设置到父表。 +// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。 +// +// 参数: +// - ngx: 父表(通常是 ngx 表) +// - L: Lua 状态机 +// - name: 子表名称(如 "req", "resp", "socket") +// +// 返回值: +// - *glua.LTable: 子表 +func GetOrCreateNgxSubTable(ngx *glua.LTable, L *glua.LState, name string) *glua.LTable { + existing := ngx.RawGetString(name) + if existing == glua.LNil { + // 首次创建子表 + sub := L.NewTable() + ngx.RawSetString(name, sub) + return sub + } + return existing.(*glua.LTable) +} diff --git a/internal/middleware/compression/compression.go b/internal/middleware/compression/compression.go index be71de9..2ee19d3 100644 --- a/internal/middleware/compression/compression.go +++ b/internal/middleware/compression/compression.go @@ -20,6 +20,7 @@ package compression import ( "bufio" "bytes" + "io" "strings" "sync" @@ -29,6 +30,63 @@ import ( "rua.plus/lolly/internal/config" ) +// resettableWriteCloser 接口用于统一 gzip.Writer 和 brotli.Writer 的操作。 +// 这两个 writer 都实现了 Reset(io.Writer), Write([]byte) (int, error), Close() error +type resettableWriteCloser interface { + Reset(w io.Writer) + Write(data []byte) (int, error) + Close() error +} + +// compressorPool 是一个通用的压缩 writer 池。 +type compressorPool struct { + pool sync.Pool + level int + factory func(level int) resettableWriteCloser +} + +// newGzipPool 创建 gzip writer 池。 +func newGzipPool(level int) *compressorPool { + return &compressorPool{ + level: level, + factory: func(level int) resettableWriteCloser { + w, err := gzip.NewWriterLevel(nil, level) + if err != nil { + w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression) + } + return w + }, + } +} + +// newBrotliPool 创建 brotli writer 池。 +func newBrotliPool(level int) *compressorPool { + return &compressorPool{ + level: level, + factory: func(level int) resettableWriteCloser { + return brotli.NewWriterOptions(nil, brotli.WriterOptions{ + Quality: level, + }) + }, + } +} + +// Get 从池中获取 writer。 +func (p *compressorPool) Get() (resettableWriteCloser, bool) { + if p.pool.New == nil { + p.pool.New = func() any { + return p.factory(p.level) + } + } + v, ok := p.pool.Get().(resettableWriteCloser) + return v, ok +} + +// Put 将 writer 放回池中。 +func (p *compressorPool) Put(w resettableWriteCloser) { + p.pool.Put(w) +} + // streamingThreshold 流式压缩阈值。 // 响应体超过此大小时使用 SetBodyStreamWriter 流式压缩, // 消除 compressed buffer 分配,降低内存峰值。 @@ -49,9 +107,9 @@ const ( // Middleware 响应压缩中间件。 type Middleware struct { // gzipPool gzip.Writer 缓冲池 - gzipPool sync.Pool + gzipPool *compressorPool // brotliPool brotli.Writer 缓冲池 - brotliPool sync.Pool + brotliPool *compressorPool // types 可压缩的 MIME 类型列表 types []string @@ -114,25 +172,8 @@ func New(cfg *config.CompressionConfig) (*Middleware, error) { } // 初始化缓冲池 - m.gzipPool = sync.Pool{ - New: func() any { - w, err := gzip.NewWriterLevel(nil, cfg.Level) - if err != nil { - // 使用默认压缩级别作为回退 - w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression) - } - return w - }, - } - - // 初始化 brotli 缓冲池 - m.brotliPool = sync.Pool{ - New: func() any { - return brotli.NewWriterOptions(nil, brotli.WriterOptions{ - Quality: cfg.Level, - }) - }, - } + m.gzipPool = newGzipPool(cfg.Level) + m.brotliPool = newBrotliPool(cfg.Level) return m, nil } @@ -224,18 +265,18 @@ func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandl if bodyLen > streamingThreshold { // 大响应:流式压缩,消除 compressed buffer 分配 if useBrotli { - m.streamBrotli(ctx, encoding) + m.streamWithPool(ctx, encoding, m.brotliPool) } else if useGzip { - m.streamGzip(ctx, encoding) + m.streamWithPool(ctx, encoding, m.gzipPool) } } else { // 小响应:缓冲压缩 var compressed []byte if useBrotli { - compressed = m.compressBrotli(body) + compressed = m.compressWithPool(body, m.brotliPool) } else if useGzip { - compressed = m.compressGzip(body) + compressed = m.compressWithPool(body, m.gzipPool) } if len(compressed) > 0 && len(compressed) < bodyLen { @@ -276,19 +317,20 @@ func (m *Middleware) isCompressible(contentType []byte) bool { return false } -// compressGzip 使用 gzip 压缩数据。 +// compressWithPool 使用缓冲池压缩数据。 // // 参数: // - data: 待压缩的原始数据 +// - pool: 压缩 writer 缓冲池 // // 返回值: // - []byte: 压缩后的数据 -func (m *Middleware) compressGzip(data []byte) []byte { - w, ok := m.gzipPool.Get().(*gzip.Writer) +func (m *Middleware) compressWithPool(data []byte, pool *compressorPool) []byte { + w, ok := pool.Get() if !ok { return data // fallback to uncompressed } - defer m.gzipPool.Put(w) + defer pool.Put(w) var buf bytes.Buffer w.Reset(&buf) @@ -300,28 +342,35 @@ func (m *Middleware) compressGzip(data []byte) []byte { return buf.Bytes() } -// compressBrotli 使用 brotli 压缩数据。 +// streamWithPool 使用流式压缩。 +// +// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流, +// 消除 compressed buffer 分配,降低内存峰值。 // // 参数: -// - data: 待压缩的原始数据 -// -// 返回值: -// - []byte: 压缩后的数据 -func (m *Middleware) compressBrotli(data []byte) []byte { - w, ok := m.brotliPool.Get().(*brotli.Writer) - if !ok { - return data // fallback to uncompressed - } - defer m.brotliPool.Put(w) +// - ctx: fasthttp 请求上下文 +// - encoding: Content-Encoding 值 +// - pool: 压缩 writer 缓冲池 +func (m *Middleware) streamWithPool(ctx *fasthttp.RequestCtx, encoding string, pool *compressorPool) { + ctx.Response.Header.Set("Content-Encoding", encoding) + ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding - var buf bytes.Buffer - w.Reset(&buf) - if _, err := w.Write(data); err != nil { //nolint:staticcheck // intentionally empty branch - // 忽略写入错误,缓冲到 bytes.Buffer 时不太可能失败 - } - _ = w.Close() + body := ctx.Response.Body() + ctx.SetBodyStreamWriter(func(w *bufio.Writer) { + writer, ok := pool.Get() + if !ok { + // pool 获取失败,直接写原始 body + _, _ = w.Write(body) + _ = w.Flush() + return + } + defer pool.Put(writer) - return buf.Bytes() + writer.Reset(w) + _, _ = writer.Write(body) + _ = writer.Close() + _ = w.Flush() + }) } // Types 返回可压缩的 MIME 类型列表。 @@ -347,63 +396,3 @@ func (m *Middleware) Level() int { func (m *Middleware) MinSize() int { return m.minSize } - -// streamGzip 使用 gzip 流式压缩。 -// -// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流, -// 消除 compressed buffer 分配,降低内存峰值。 -// -// 参数: -// - ctx: fasthttp 请求上下文 -// - encoding: Content-Encoding 值("gzip") -func (m *Middleware) streamGzip(ctx *fasthttp.RequestCtx, encoding string) { - ctx.Response.Header.Set("Content-Encoding", encoding) - ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding - - body := ctx.Response.Body() - ctx.SetBodyStreamWriter(func(w *bufio.Writer) { - writer, ok := m.gzipPool.Get().(*gzip.Writer) - if !ok { - // pool 获取失败,直接写原始 body - _, _ = w.Write(body) - _ = w.Flush() - return - } - defer m.gzipPool.Put(writer) - - writer.Reset(w) - _, _ = writer.Write(body) - _ = writer.Close() - _ = w.Flush() - }) -} - -// streamBrotli 使用 brotli 流式压缩。 -// -// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流, -// 消除 compressed buffer 分配,降低内存峰值。 -// -// 参数: -// - ctx: fasthttp 请求上下文 -// - encoding: Content-Encoding 值("br") -func (m *Middleware) streamBrotli(ctx *fasthttp.RequestCtx, encoding string) { - ctx.Response.Header.Set("Content-Encoding", encoding) - ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding - - body := ctx.Response.Body() - ctx.SetBodyStreamWriter(func(w *bufio.Writer) { - writer, ok := m.brotliPool.Get().(*brotli.Writer) - if !ok { - // pool 获取失败,直接写原始 body - _, _ = w.Write(body) - _ = w.Flush() - return - } - defer m.brotliPool.Put(writer) - - writer.Reset(w) - _, _ = writer.Write(body) - _ = writer.Close() - _ = w.Flush() - }) -} diff --git a/internal/middleware/compression/compression_bench_test.go b/internal/middleware/compression/compression_bench_test.go index eb1ac2f..4fa8dbf 100644 --- a/internal/middleware/compression/compression_bench_test.go +++ b/internal/middleware/compression/compression_bench_test.go @@ -27,7 +27,7 @@ func BenchmarkGzipCompress_1KB(b *testing.B) { b.ResetTimer() for b.Loop() { - mw.compressGzip(data) + mw.compressWithPool(data, mw.gzipPool) } } @@ -44,7 +44,7 @@ func BenchmarkGzipCompress_10KB(b *testing.B) { b.ResetTimer() for b.Loop() { - mw.compressGzip(data) + mw.compressWithPool(data, mw.gzipPool) } } @@ -61,7 +61,7 @@ func BenchmarkGzipCompress_100KB(b *testing.B) { b.ResetTimer() for b.Loop() { - mw.compressGzip(data) + mw.compressWithPool(data, mw.gzipPool) } } @@ -78,7 +78,7 @@ func BenchmarkBrotliCompress_1KB(b *testing.B) { b.ResetTimer() for b.Loop() { - mw.compressBrotli(data) + mw.compressWithPool(data, mw.brotliPool) } } @@ -95,7 +95,7 @@ func BenchmarkBrotliCompress_10KB(b *testing.B) { b.ResetTimer() for b.Loop() { - mw.compressBrotli(data) + mw.compressWithPool(data, mw.brotliPool) } } @@ -114,7 +114,7 @@ func BenchmarkCompressionPool(b *testing.B) { b.ResetTimer() for b.Loop() { - mw.compressGzip(data) + mw.compressWithPool(data, mw.gzipPool) } } @@ -222,7 +222,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) { mw, _ := New(cfg) b.ResetTimer() for b.Loop() { - mw.compressGzip(data) + mw.compressWithPool(data, mw.gzipPool) } }) @@ -231,7 +231,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) { mw, _ := New(cfg) b.ResetTimer() for b.Loop() { - mw.compressGzip(data) + mw.compressWithPool(data, mw.gzipPool) } }) @@ -240,7 +240,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) { mw, _ := New(cfg) b.ResetTimer() for b.Loop() { - mw.compressGzip(data) + mw.compressWithPool(data, mw.gzipPool) } }) } diff --git a/internal/middleware/compression/compression_test.go b/internal/middleware/compression/compression_test.go index 595fe00..36de580 100644 --- a/internal/middleware/compression/compression_test.go +++ b/internal/middleware/compression/compression_test.go @@ -134,7 +134,7 @@ func TestCompressGzip(t *testing.T) { // 测试数据 data := []byte("Hello, World! This is a test string that should be compressed.") - compressed := m.compressGzip(data) + compressed := m.compressWithPool(data, m.gzipPool) if len(compressed) == 0 { t.Error("Expected compressed data") } @@ -153,7 +153,7 @@ func TestCompressBrotli(t *testing.T) { data := []byte("Hello, World! This is a test string that should be compressed with brotli.") - compressed := m.compressBrotli(data) + compressed := m.compressWithPool(data, m.brotliPool) if len(compressed) == 0 { t.Error("Expected compressed data") } diff --git a/internal/middleware/compression/pool_bench_test.go b/internal/middleware/compression/pool_bench_test.go index 9573871..374ebe6 100644 --- a/internal/middleware/compression/pool_bench_test.go +++ b/internal/middleware/compression/pool_bench_test.go @@ -83,7 +83,7 @@ func BenchmarkGzipWriter_Pool(b *testing.B) { b.ResetTimer() for b.Loop() { - _ = mw.compressGzip(data) + _ = mw.compressWithPool(data, mw.gzipPool) } } @@ -145,7 +145,7 @@ func BenchmarkGzipCompress_Sizes(b *testing.B) { b.ReportAllocs() b.ResetTimer() for b.Loop() { - _ = mw.compressGzip(tc.data) + _ = mw.compressWithPool(tc.data, mw.gzipPool) } }) }