refactor: 提取 Lua ngx 表 helpers 和统一验证函数

Batch 1 续:
- 新增 lua/helpers.go:GetOrCreateNgxTable/GetOrCreateNgxSubTable
- 重构 compression:提取 resettableWriteCloser 接口和 compressorPool
- 新增 validate.go:ValidateNonNegativeInt64/Duration/NoNullByte/PathTraversal
- 消除约 120 行重复代码

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-29 17:00:11 +08:00
parent 91e04222b3
commit f82e363f58
10 changed files with 245 additions and 218 deletions

View File

@ -24,6 +24,7 @@ import (
"regexp"
"slices"
"strings"
"time"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/variable"
@ -123,6 +124,38 @@ func ValidateNonNegative(value int, fieldName string) error {
return nil
}
// ValidateNonNegativeInt64 验证 int64 值为非负数
func ValidateNonNegativeInt64(value int64, fieldName string) error {
if value < 0 {
return fmt.Errorf("%s 不能为负数", fieldName)
}
return nil
}
// ValidateNonNegativeDuration 验证 time.Duration 值为非负数
func ValidateNonNegativeDuration(value time.Duration, fieldName string) error {
if value < 0 {
return fmt.Errorf("%s 不能为负数", fieldName)
}
return nil
}
// ValidateNoNullByte 验证字符串不包含 null byte
func ValidateNoNullByte(s string, fieldName string) error {
if strings.Contains(s, "\x00") {
return fmt.Errorf("%s 不能包含 null byte", fieldName)
}
return nil
}
// ValidatePathTraversal 验证路径不包含路径遍历 '..'
func ValidatePathTraversal(path string, fieldName string) error {
if strings.Contains(path, "..") {
return fmt.Errorf("%s不能包含 '..'", fieldName)
}
return nil
}
// validateServer 验证服务器配置。
//
// 检查服务器配置的各项参数是否符合要求,包括监听地址、
@ -232,13 +265,13 @@ func validateStatics(statics []StaticConfig) error {
}
// 验证根目录路径安全
if s.Root != "" && strings.Contains(s.Root, "..") {
return fmt.Errorf("static[%d]: 根目录路径不能包含 '..'", i)
if err := ValidatePathTraversal(s.Root, fmt.Sprintf("static[%d]: 根目录路径", i)); err != nil {
return err
}
// 验证 alias 路径安全
if s.Alias != "" && strings.Contains(s.Alias, "..") {
return fmt.Errorf("static[%d]: alias 路径不能包含 '..'", i)
if err := ValidatePathTraversal(s.Alias, fmt.Sprintf("static[%d]: alias 路径", i)); err != nil {
return err
}
// 验证 try_files 模式
@ -278,8 +311,8 @@ func validateTryFilesPattern(pattern string) error {
}
// 检查 null byte
if strings.Contains(pattern, "\x00") {
return errors.New("try_files 模式不能包含 null byte")
if err := ValidateNoNullByte(pattern, "try_files 模式"); err != nil {
return err
}
// 定义支持的模式类型
@ -337,8 +370,8 @@ func validateTryFilesExtension(ext string) error {
}
// 检查 null byte
if strings.Contains(ext, "\x00") {
return errors.New("扩展名不能包含 null byte")
if err := ValidateNoNullByte(ext, "扩展名"); err != nil {
return err
}
// 白名单字符检查:仅允许字母、数字、点、下划线、连字符
@ -394,8 +427,8 @@ func validateTryFilesFilename(filename string) error {
}
// 检查 null byte
if strings.Contains(filename, "\x00") {
return errors.New("文件名不能包含 null byte")
if err := ValidateNoNullByte(filename, "文件名"); err != nil {
return err
}
return nil
@ -441,8 +474,8 @@ func validateStatic(s *StaticConfig) error {
// 静态文件根目录非空时验证路径有效性
if s.Root != "" {
// 路径安全检查:不允许包含 ".."
if strings.Contains(s.Root, "..") {
return errors.New("根目录路径不能包含 '..'")
if err := ValidatePathTraversal(s.Root, "根目录路径"); err != nil {
return err
}
}
return nil
@ -711,13 +744,13 @@ func validateGeoIP(g *GeoIPConfig) error {
}
// 验证缓存大小
if g.CacheSize < 0 {
return errors.New("cache_size 不能为负数")
if err := ValidateNonNegative(g.CacheSize, "cache_size"); err != nil {
return err
}
// 验证缓存 TTL
if g.CacheTTL < 0 {
return errors.New("cache_ttl 不能为负数")
if err := ValidateNonNegativeDuration(g.CacheTTL, "cache_ttl"); err != nil {
return err
}
// 验证默认动作
@ -883,18 +916,18 @@ func validateHTTP2(h *HTTP2Config, hasSSL bool) error {
}
// 验证并发流数量
if h.MaxConcurrentStreams < 0 {
return errors.New("max_concurrent_streams 不能为负数")
if err := ValidateNonNegative(h.MaxConcurrentStreams, "max_concurrent_streams"); err != nil {
return err
}
// 验证头部大小限制
if h.MaxHeaderListSize < 0 {
return errors.New("max_header_list_size 不能为负数")
if err := ValidateNonNegative(h.MaxHeaderListSize, "max_header_list_size"); err != nil {
return err
}
// 验证空闲超时
if h.IdleTimeout < 0 {
return errors.New("idle_timeout 不能为负数")
if err := ValidateNonNegativeDuration(h.IdleTimeout, "idle_timeout"); err != nil {
return err
}
return nil
@ -1085,8 +1118,8 @@ func validateStream(s *StreamConfig) error {
// - error: 验证失败时返回错误信息,成功返回 nil
func validatePerformance(p *PerformanceConfig) error {
// 检查 Transport 配置(可能导致性能问题)
if p.Transport.MaxConnsPerHost < 0 {
return errors.New("transport.max_conns_per_host 不能为负数")
if err := ValidateNonNegative(p.Transport.MaxConnsPerHost, "transport.max_conns_per_host"); err != nil {
return err
}
return nil
@ -1112,8 +1145,8 @@ func validateNextUpstream(n *NextUpstreamConfig) error {
}
// 验证重试次数
if n.Tries < 0 {
return errors.New("tries 不能为负数")
if err := ValidateNonNegative(n.Tries, "tries"); err != nil {
return err
}
// 验证 HTTP 状态码
@ -1204,26 +1237,26 @@ func validateLua(l *LuaMiddlewareConfig) error {
}
// 超时时间验证
if script.Timeout < 0 {
return fmt.Errorf("scripts[%d].timeout 不能为负数", i)
if err := ValidateNonNegativeDuration(script.Timeout, fmt.Sprintf("scripts[%d].timeout", i)); err != nil {
return err
}
}
// 验证全局设置
if l.GlobalSettings.MaxConcurrentCoroutines < 0 {
return errors.New("global_settings.max_concurrent_coroutines 不能为负数")
if err := ValidateNonNegative(l.GlobalSettings.MaxConcurrentCoroutines, "global_settings.max_concurrent_coroutines"); err != nil {
return err
}
if l.GlobalSettings.MaxConcurrentCoroutines > 0 && l.GlobalSettings.MaxConcurrentCoroutines < 1 {
return errors.New("global_settings.max_concurrent_coroutines 至少为 1")
}
if l.GlobalSettings.CoroutineTimeout < 0 {
return errors.New("global_settings.coroutine_timeout 不能为负数")
if err := ValidateNonNegativeDuration(l.GlobalSettings.CoroutineTimeout, "global_settings.coroutine_timeout"); err != nil {
return err
}
if l.GlobalSettings.CodeCacheSize < 0 {
return errors.New("global_settings.code_cache_size 不能为负数")
if err := ValidateNonNegative(l.GlobalSettings.CodeCacheSize, "global_settings.code_cache_size"); err != nil {
return err
}
if l.GlobalSettings.MaxExecutionTime < 0 {
return errors.New("global_settings.max_execution_time 不能为负数")
if err := ValidateNonNegativeDuration(l.GlobalSettings.MaxExecutionTime, "global_settings.max_execution_time"); err != nil {
return err
}
return nil

View File

@ -114,18 +114,7 @@ func newNgxLogAPI(ctx *fasthttp.RequestCtx, luaCtx *LuaContext, logger *zerolog.
// 每次请求都会重新注册请求特定的函数log, say, print, flush, exit, redirect
func RegisterNgxLogAPI(L *glua.LState, api *ngxLogAPI) {
// 获取或创建 ngx 表
var ngx *glua.LTable
existingNgx := L.GetGlobal("ngx")
if existingNgx != nil && existingNgx.Type() == glua.LTTable {
ngxTable, ok := existingNgx.(*glua.LTable)
if ok {
ngx = ngxTable
} else {
ngx = L.NewTable()
}
} else {
ngx = L.NewTable()
}
ngx := GetOrCreateNgxTable(L)
// 检查常量是否已注册(通过 STDERR 常量判断)
// 如果已注册,跳过常量写入,避免并发写入全局表

View File

@ -105,15 +105,8 @@ func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI {
// RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API
// 这是主入口函数,由 LuaEngine 在初始化时调用
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI, ngxTable *glua.LTable) {
// 检查 ngx.req 是否已存在,避免并发写入
var ngxReq *glua.LTable
if existingReq := ngxTable.RawGetString("req"); existingReq == glua.LNil {
// 首次创建 ngx.req 子表
ngxReq = L.NewTable()
ngxTable.RawSetString("req", ngxReq)
} else {
ngxReq = existingReq.(*glua.LTable)
}
// 获取或创建 ngx.req 子表
ngxReq := GetOrCreateNgxSubTable(ngxTable, L, "req")
// 直接映射层 APIget_method
// 特点:直接访问 fasthttp.RequestCtx零拷贝最小开销

View File

@ -46,29 +46,11 @@ func newNgxRespAPI(ctx *fasthttp.RequestCtx) *ngxRespAPI {
// RegisterNgxRespAPI 在 Lua 状态机中注册 ngx.resp API
// 这是主入口函数,由 LuaEngine 在初始化时调用
func RegisterNgxRespAPI(L *glua.LState, api *ngxRespAPI) {
// 获取已存在的 ngx 表(必须已设置全局)
ngx := L.GetGlobal("ngx")
if ngx == nil || ngx.Type() != glua.LTTable {
// 如果不存在,创建新表并设置全局
ngx = L.NewTable()
L.SetGlobal("ngx", ngx)
}
// 获取或创建 ngx 表
ngxTable := GetOrCreateNgxTable(L)
// 类型断言检查
ngxTable, ok := ngx.(*glua.LTable)
if !ok {
return
}
// 检查 ngx.resp 是否已存在,避免并发写入
var ngxResp *glua.LTable
if existingResp := ngxTable.RawGetString("resp"); existingResp == glua.LNil {
// 首次创建 ngx.resp 子表
ngxResp = L.NewTable()
ngxTable.RawSetString("resp", ngxResp)
} else {
ngxResp = existingResp.(*glua.LTable)
}
// 获取或创建 ngx.resp 子表
ngxResp := GetOrCreateNgxSubTable(ngxTable, L, "resp")
// 每次请求更新函数以绑定正确的 ctx
ngxResp.RawSetString("get_status", L.NewFunction(api.luaGetStatus))

View File

@ -537,26 +537,11 @@ const tcpSocketMT = "tcp_socket"
// RegisterTCPSocketAPI 注册 TCP socket API
func RegisterTCPSocketAPI(L *glua.LState, engine *LuaEngine) {
// 确保 ngx 表存在
ngx := L.GetGlobal("ngx")
var ngxTbl *glua.LTable
if tbl, ok := ngx.(*glua.LTable); ok {
ngxTbl = tbl
} else {
// 创建 ngx 表
ngxTbl = L.NewTable()
L.SetGlobal("ngx", ngxTbl)
}
// 获取或创建 ngx 表
ngxTbl := GetOrCreateNgxTable(L)
// 检查 ngx.socket 是否已存在,避免并发写入
var socket *glua.LTable
if existing := ngxTbl.RawGetString("socket"); existing == glua.LNil {
// 首次创建 ngx.socket 表
socket = L.NewTable()
ngxTbl.RawSetString("socket", socket)
} else {
socket = existing.(*glua.LTable)
}
// 获取或创建 ngx.socket 子表
socket := GetOrCreateNgxSubTable(ngxTbl, L, "socket")
// 每次请求更新 tcp 函数以绑定正确的 engine
socket.RawSetString("tcp", L.NewFunction(newTCPSocketFunc(engine)))

56
internal/lua/helpers.go Normal file
View File

@ -0,0 +1,56 @@
// Package lua 提供 Lua API 辅助函数。
//
// 该文件包含 ngx 表操作的共享辅助函数,用于减少代码重复。
// 所有函数保持并发安全设计(使用 RawGetString/RawSetString
//
// 作者xfy
package lua
import glua "github.com/yuin/gopher-lua"
// GetOrCreateNgxTable 获取或创建全局 ngx 表。
//
// 如果全局 ngx 表已存在,则返回现有表;否则创建新表并设置为全局变量。
// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。
//
// 参数:
// - L: Lua 状态机
//
// 返回值:
// - *glua.LTable: ngx 表
func GetOrCreateNgxTable(L *glua.LState) *glua.LTable {
ngx := L.GetGlobal("ngx")
if ngx != nil && ngx.Type() == glua.LTTable {
if tbl, ok := ngx.(*glua.LTable); ok {
return tbl
}
}
// 创建新的 ngx 表
tbl := L.NewTable()
L.SetGlobal("ngx", tbl)
return tbl
}
// GetOrCreateNgxSubTable 获取或创建 ngx 子表。
//
// 如果子表已存在,则返回现有表;否则创建新子表并设置到父表。
// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。
//
// 参数:
// - ngx: 父表(通常是 ngx 表)
// - L: Lua 状态机
// - name: 子表名称(如 "req", "resp", "socket"
//
// 返回值:
// - *glua.LTable: 子表
func GetOrCreateNgxSubTable(ngx *glua.LTable, L *glua.LState, name string) *glua.LTable {
existing := ngx.RawGetString(name)
if existing == glua.LNil {
// 首次创建子表
sub := L.NewTable()
ngx.RawSetString(name, sub)
return sub
}
return existing.(*glua.LTable)
}

View File

@ -20,6 +20,7 @@ package compression
import (
"bufio"
"bytes"
"io"
"strings"
"sync"
@ -29,6 +30,63 @@ import (
"rua.plus/lolly/internal/config"
)
// resettableWriteCloser 接口用于统一 gzip.Writer 和 brotli.Writer 的操作。
// 这两个 writer 都实现了 Reset(io.Writer), Write([]byte) (int, error), Close() error
type resettableWriteCloser interface {
Reset(w io.Writer)
Write(data []byte) (int, error)
Close() error
}
// compressorPool 是一个通用的压缩 writer 池。
type compressorPool struct {
pool sync.Pool
level int
factory func(level int) resettableWriteCloser
}
// newGzipPool 创建 gzip writer 池。
func newGzipPool(level int) *compressorPool {
return &compressorPool{
level: level,
factory: func(level int) resettableWriteCloser {
w, err := gzip.NewWriterLevel(nil, level)
if err != nil {
w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression)
}
return w
},
}
}
// newBrotliPool 创建 brotli writer 池。
func newBrotliPool(level int) *compressorPool {
return &compressorPool{
level: level,
factory: func(level int) resettableWriteCloser {
return brotli.NewWriterOptions(nil, brotli.WriterOptions{
Quality: level,
})
},
}
}
// Get 从池中获取 writer。
func (p *compressorPool) Get() (resettableWriteCloser, bool) {
if p.pool.New == nil {
p.pool.New = func() any {
return p.factory(p.level)
}
}
v, ok := p.pool.Get().(resettableWriteCloser)
return v, ok
}
// Put 将 writer 放回池中。
func (p *compressorPool) Put(w resettableWriteCloser) {
p.pool.Put(w)
}
// streamingThreshold 流式压缩阈值。
// 响应体超过此大小时使用 SetBodyStreamWriter 流式压缩,
// 消除 compressed buffer 分配,降低内存峰值。
@ -49,9 +107,9 @@ const (
// Middleware 响应压缩中间件。
type Middleware struct {
// gzipPool gzip.Writer 缓冲池
gzipPool sync.Pool
gzipPool *compressorPool
// brotliPool brotli.Writer 缓冲池
brotliPool sync.Pool
brotliPool *compressorPool
// types 可压缩的 MIME 类型列表
types []string
@ -114,25 +172,8 @@ func New(cfg *config.CompressionConfig) (*Middleware, error) {
}
// 初始化缓冲池
m.gzipPool = sync.Pool{
New: func() any {
w, err := gzip.NewWriterLevel(nil, cfg.Level)
if err != nil {
// 使用默认压缩级别作为回退
w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression)
}
return w
},
}
// 初始化 brotli 缓冲池
m.brotliPool = sync.Pool{
New: func() any {
return brotli.NewWriterOptions(nil, brotli.WriterOptions{
Quality: cfg.Level,
})
},
}
m.gzipPool = newGzipPool(cfg.Level)
m.brotliPool = newBrotliPool(cfg.Level)
return m, nil
}
@ -224,18 +265,18 @@ func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandl
if bodyLen > streamingThreshold {
// 大响应:流式压缩,消除 compressed buffer 分配
if useBrotli {
m.streamBrotli(ctx, encoding)
m.streamWithPool(ctx, encoding, m.brotliPool)
} else if useGzip {
m.streamGzip(ctx, encoding)
m.streamWithPool(ctx, encoding, m.gzipPool)
}
} else {
// 小响应:缓冲压缩
var compressed []byte
if useBrotli {
compressed = m.compressBrotli(body)
compressed = m.compressWithPool(body, m.brotliPool)
} else if useGzip {
compressed = m.compressGzip(body)
compressed = m.compressWithPool(body, m.gzipPool)
}
if len(compressed) > 0 && len(compressed) < bodyLen {
@ -276,19 +317,20 @@ func (m *Middleware) isCompressible(contentType []byte) bool {
return false
}
// compressGzip 使用 gzip 压缩数据。
// compressWithPool 使用缓冲池压缩数据。
//
// 参数:
// - data: 待压缩的原始数据
// - pool: 压缩 writer 缓冲池
//
// 返回值:
// - []byte: 压缩后的数据
func (m *Middleware) compressGzip(data []byte) []byte {
w, ok := m.gzipPool.Get().(*gzip.Writer)
func (m *Middleware) compressWithPool(data []byte, pool *compressorPool) []byte {
w, ok := pool.Get()
if !ok {
return data // fallback to uncompressed
}
defer m.gzipPool.Put(w)
defer pool.Put(w)
var buf bytes.Buffer
w.Reset(&buf)
@ -300,28 +342,35 @@ func (m *Middleware) compressGzip(data []byte) []byte {
return buf.Bytes()
}
// compressBrotli 使用 brotli 压缩数据。
// streamWithPool 使用流式压缩。
//
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
// 消除 compressed buffer 分配,降低内存峰值。
//
// 参数:
// - data: 待压缩的原始数据
//
// 返回值:
// - []byte: 压缩后的数据
func (m *Middleware) compressBrotli(data []byte) []byte {
w, ok := m.brotliPool.Get().(*brotli.Writer)
if !ok {
return data // fallback to uncompressed
}
defer m.brotliPool.Put(w)
// - ctx: fasthttp 请求上下文
// - encoding: Content-Encoding 值
// - pool: 压缩 writer 缓冲池
func (m *Middleware) streamWithPool(ctx *fasthttp.RequestCtx, encoding string, pool *compressorPool) {
ctx.Response.Header.Set("Content-Encoding", encoding)
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
var buf bytes.Buffer
w.Reset(&buf)
if _, err := w.Write(data); err != nil { //nolint:staticcheck // intentionally empty branch
// 忽略写入错误,缓冲到 bytes.Buffer 时不太可能失败
}
_ = w.Close()
body := ctx.Response.Body()
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
writer, ok := pool.Get()
if !ok {
// pool 获取失败,直接写原始 body
_, _ = w.Write(body)
_ = w.Flush()
return
}
defer pool.Put(writer)
return buf.Bytes()
writer.Reset(w)
_, _ = writer.Write(body)
_ = writer.Close()
_ = w.Flush()
})
}
// Types 返回可压缩的 MIME 类型列表。
@ -347,63 +396,3 @@ func (m *Middleware) Level() int {
func (m *Middleware) MinSize() int {
return m.minSize
}
// streamGzip 使用 gzip 流式压缩。
//
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
// 消除 compressed buffer 分配,降低内存峰值。
//
// 参数:
// - ctx: fasthttp 请求上下文
// - encoding: Content-Encoding 值("gzip"
func (m *Middleware) streamGzip(ctx *fasthttp.RequestCtx, encoding string) {
ctx.Response.Header.Set("Content-Encoding", encoding)
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
body := ctx.Response.Body()
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
writer, ok := m.gzipPool.Get().(*gzip.Writer)
if !ok {
// pool 获取失败,直接写原始 body
_, _ = w.Write(body)
_ = w.Flush()
return
}
defer m.gzipPool.Put(writer)
writer.Reset(w)
_, _ = writer.Write(body)
_ = writer.Close()
_ = w.Flush()
})
}
// streamBrotli 使用 brotli 流式压缩。
//
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
// 消除 compressed buffer 分配,降低内存峰值。
//
// 参数:
// - ctx: fasthttp 请求上下文
// - encoding: Content-Encoding 值("br"
func (m *Middleware) streamBrotli(ctx *fasthttp.RequestCtx, encoding string) {
ctx.Response.Header.Set("Content-Encoding", encoding)
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
body := ctx.Response.Body()
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
writer, ok := m.brotliPool.Get().(*brotli.Writer)
if !ok {
// pool 获取失败,直接写原始 body
_, _ = w.Write(body)
_ = w.Flush()
return
}
defer m.brotliPool.Put(writer)
writer.Reset(w)
_, _ = writer.Write(body)
_ = writer.Close()
_ = w.Flush()
})
}

View File

@ -27,7 +27,7 @@ func BenchmarkGzipCompress_1KB(b *testing.B) {
b.ResetTimer()
for b.Loop() {
mw.compressGzip(data)
mw.compressWithPool(data, mw.gzipPool)
}
}
@ -44,7 +44,7 @@ func BenchmarkGzipCompress_10KB(b *testing.B) {
b.ResetTimer()
for b.Loop() {
mw.compressGzip(data)
mw.compressWithPool(data, mw.gzipPool)
}
}
@ -61,7 +61,7 @@ func BenchmarkGzipCompress_100KB(b *testing.B) {
b.ResetTimer()
for b.Loop() {
mw.compressGzip(data)
mw.compressWithPool(data, mw.gzipPool)
}
}
@ -78,7 +78,7 @@ func BenchmarkBrotliCompress_1KB(b *testing.B) {
b.ResetTimer()
for b.Loop() {
mw.compressBrotli(data)
mw.compressWithPool(data, mw.brotliPool)
}
}
@ -95,7 +95,7 @@ func BenchmarkBrotliCompress_10KB(b *testing.B) {
b.ResetTimer()
for b.Loop() {
mw.compressBrotli(data)
mw.compressWithPool(data, mw.brotliPool)
}
}
@ -114,7 +114,7 @@ func BenchmarkCompressionPool(b *testing.B) {
b.ResetTimer()
for b.Loop() {
mw.compressGzip(data)
mw.compressWithPool(data, mw.gzipPool)
}
}
@ -222,7 +222,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
mw, _ := New(cfg)
b.ResetTimer()
for b.Loop() {
mw.compressGzip(data)
mw.compressWithPool(data, mw.gzipPool)
}
})
@ -231,7 +231,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
mw, _ := New(cfg)
b.ResetTimer()
for b.Loop() {
mw.compressGzip(data)
mw.compressWithPool(data, mw.gzipPool)
}
})
@ -240,7 +240,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
mw, _ := New(cfg)
b.ResetTimer()
for b.Loop() {
mw.compressGzip(data)
mw.compressWithPool(data, mw.gzipPool)
}
})
}

View File

@ -134,7 +134,7 @@ func TestCompressGzip(t *testing.T) {
// 测试数据
data := []byte("Hello, World! This is a test string that should be compressed.")
compressed := m.compressGzip(data)
compressed := m.compressWithPool(data, m.gzipPool)
if len(compressed) == 0 {
t.Error("Expected compressed data")
}
@ -153,7 +153,7 @@ func TestCompressBrotli(t *testing.T) {
data := []byte("Hello, World! This is a test string that should be compressed with brotli.")
compressed := m.compressBrotli(data)
compressed := m.compressWithPool(data, m.brotliPool)
if len(compressed) == 0 {
t.Error("Expected compressed data")
}

View File

@ -83,7 +83,7 @@ func BenchmarkGzipWriter_Pool(b *testing.B) {
b.ResetTimer()
for b.Loop() {
_ = mw.compressGzip(data)
_ = mw.compressWithPool(data, mw.gzipPool)
}
}
@ -145,7 +145,7 @@ func BenchmarkGzipCompress_Sizes(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for b.Loop() {
_ = mw.compressGzip(tc.data)
_ = mw.compressWithPool(tc.data, mw.gzipPool)
}
})
}