refactor: 提取 Lua ngx 表 helpers 和统一验证函数
Batch 1 续: - 新增 lua/helpers.go:GetOrCreateNgxTable/GetOrCreateNgxSubTable - 重构 compression:提取 resettableWriteCloser 接口和 compressorPool - 新增 validate.go:ValidateNonNegativeInt64/Duration/NoNullByte/PathTraversal - 消除约 120 行重复代码 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
91e04222b3
commit
f82e363f58
@ -24,6 +24,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"rua.plus/lolly/internal/loadbalance"
|
"rua.plus/lolly/internal/loadbalance"
|
||||||
"rua.plus/lolly/internal/variable"
|
"rua.plus/lolly/internal/variable"
|
||||||
@ -123,6 +124,38 @@ func ValidateNonNegative(value int, fieldName string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateNonNegativeInt64 验证 int64 值为非负数
|
||||||
|
func ValidateNonNegativeInt64(value int64, fieldName string) error {
|
||||||
|
if value < 0 {
|
||||||
|
return fmt.Errorf("%s 不能为负数", fieldName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateNonNegativeDuration 验证 time.Duration 值为非负数
|
||||||
|
func ValidateNonNegativeDuration(value time.Duration, fieldName string) error {
|
||||||
|
if value < 0 {
|
||||||
|
return fmt.Errorf("%s 不能为负数", fieldName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateNoNullByte 验证字符串不包含 null byte
|
||||||
|
func ValidateNoNullByte(s string, fieldName string) error {
|
||||||
|
if strings.Contains(s, "\x00") {
|
||||||
|
return fmt.Errorf("%s 不能包含 null byte", fieldName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidatePathTraversal 验证路径不包含路径遍历 '..'
|
||||||
|
func ValidatePathTraversal(path string, fieldName string) error {
|
||||||
|
if strings.Contains(path, "..") {
|
||||||
|
return fmt.Errorf("%s不能包含 '..'", fieldName)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// validateServer 验证服务器配置。
|
// validateServer 验证服务器配置。
|
||||||
//
|
//
|
||||||
// 检查服务器配置的各项参数是否符合要求,包括监听地址、
|
// 检查服务器配置的各项参数是否符合要求,包括监听地址、
|
||||||
@ -232,13 +265,13 @@ func validateStatics(statics []StaticConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证根目录路径安全
|
// 验证根目录路径安全
|
||||||
if s.Root != "" && strings.Contains(s.Root, "..") {
|
if err := ValidatePathTraversal(s.Root, fmt.Sprintf("static[%d]: 根目录路径", i)); err != nil {
|
||||||
return fmt.Errorf("static[%d]: 根目录路径不能包含 '..'", i)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证 alias 路径安全
|
// 验证 alias 路径安全
|
||||||
if s.Alias != "" && strings.Contains(s.Alias, "..") {
|
if err := ValidatePathTraversal(s.Alias, fmt.Sprintf("static[%d]: alias 路径", i)); err != nil {
|
||||||
return fmt.Errorf("static[%d]: alias 路径不能包含 '..'", i)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证 try_files 模式
|
// 验证 try_files 模式
|
||||||
@ -278,8 +311,8 @@ func validateTryFilesPattern(pattern string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查 null byte
|
// 检查 null byte
|
||||||
if strings.Contains(pattern, "\x00") {
|
if err := ValidateNoNullByte(pattern, "try_files 模式"); err != nil {
|
||||||
return errors.New("try_files 模式不能包含 null byte")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定义支持的模式类型
|
// 定义支持的模式类型
|
||||||
@ -337,8 +370,8 @@ func validateTryFilesExtension(ext string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查 null byte
|
// 检查 null byte
|
||||||
if strings.Contains(ext, "\x00") {
|
if err := ValidateNoNullByte(ext, "扩展名"); err != nil {
|
||||||
return errors.New("扩展名不能包含 null byte")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 白名单字符检查:仅允许字母、数字、点、下划线、连字符
|
// 白名单字符检查:仅允许字母、数字、点、下划线、连字符
|
||||||
@ -394,8 +427,8 @@ func validateTryFilesFilename(filename string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查 null byte
|
// 检查 null byte
|
||||||
if strings.Contains(filename, "\x00") {
|
if err := ValidateNoNullByte(filename, "文件名"); err != nil {
|
||||||
return errors.New("文件名不能包含 null byte")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -441,8 +474,8 @@ func validateStatic(s *StaticConfig) error {
|
|||||||
// 静态文件根目录非空时验证路径有效性
|
// 静态文件根目录非空时验证路径有效性
|
||||||
if s.Root != "" {
|
if s.Root != "" {
|
||||||
// 路径安全检查:不允许包含 ".."
|
// 路径安全检查:不允许包含 ".."
|
||||||
if strings.Contains(s.Root, "..") {
|
if err := ValidatePathTraversal(s.Root, "根目录路径"); err != nil {
|
||||||
return errors.New("根目录路径不能包含 '..'")
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -711,13 +744,13 @@ func validateGeoIP(g *GeoIPConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证缓存大小
|
// 验证缓存大小
|
||||||
if g.CacheSize < 0 {
|
if err := ValidateNonNegative(g.CacheSize, "cache_size"); err != nil {
|
||||||
return errors.New("cache_size 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证缓存 TTL
|
// 验证缓存 TTL
|
||||||
if g.CacheTTL < 0 {
|
if err := ValidateNonNegativeDuration(g.CacheTTL, "cache_ttl"); err != nil {
|
||||||
return errors.New("cache_ttl 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证默认动作
|
// 验证默认动作
|
||||||
@ -883,18 +916,18 @@ func validateHTTP2(h *HTTP2Config, hasSSL bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证并发流数量
|
// 验证并发流数量
|
||||||
if h.MaxConcurrentStreams < 0 {
|
if err := ValidateNonNegative(h.MaxConcurrentStreams, "max_concurrent_streams"); err != nil {
|
||||||
return errors.New("max_concurrent_streams 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证头部大小限制
|
// 验证头部大小限制
|
||||||
if h.MaxHeaderListSize < 0 {
|
if err := ValidateNonNegative(h.MaxHeaderListSize, "max_header_list_size"); err != nil {
|
||||||
return errors.New("max_header_list_size 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证空闲超时
|
// 验证空闲超时
|
||||||
if h.IdleTimeout < 0 {
|
if err := ValidateNonNegativeDuration(h.IdleTimeout, "idle_timeout"); err != nil {
|
||||||
return errors.New("idle_timeout 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -1085,8 +1118,8 @@ func validateStream(s *StreamConfig) error {
|
|||||||
// - error: 验证失败时返回错误信息,成功返回 nil
|
// - error: 验证失败时返回错误信息,成功返回 nil
|
||||||
func validatePerformance(p *PerformanceConfig) error {
|
func validatePerformance(p *PerformanceConfig) error {
|
||||||
// 检查 Transport 配置(可能导致性能问题)
|
// 检查 Transport 配置(可能导致性能问题)
|
||||||
if p.Transport.MaxConnsPerHost < 0 {
|
if err := ValidateNonNegative(p.Transport.MaxConnsPerHost, "transport.max_conns_per_host"); err != nil {
|
||||||
return errors.New("transport.max_conns_per_host 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -1112,8 +1145,8 @@ func validateNextUpstream(n *NextUpstreamConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证重试次数
|
// 验证重试次数
|
||||||
if n.Tries < 0 {
|
if err := ValidateNonNegative(n.Tries, "tries"); err != nil {
|
||||||
return errors.New("tries 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证 HTTP 状态码
|
// 验证 HTTP 状态码
|
||||||
@ -1204,26 +1237,26 @@ func validateLua(l *LuaMiddlewareConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 超时时间验证
|
// 超时时间验证
|
||||||
if script.Timeout < 0 {
|
if err := ValidateNonNegativeDuration(script.Timeout, fmt.Sprintf("scripts[%d].timeout", i)); err != nil {
|
||||||
return fmt.Errorf("scripts[%d].timeout 不能为负数", i)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证全局设置
|
// 验证全局设置
|
||||||
if l.GlobalSettings.MaxConcurrentCoroutines < 0 {
|
if err := ValidateNonNegative(l.GlobalSettings.MaxConcurrentCoroutines, "global_settings.max_concurrent_coroutines"); err != nil {
|
||||||
return errors.New("global_settings.max_concurrent_coroutines 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
if l.GlobalSettings.MaxConcurrentCoroutines > 0 && l.GlobalSettings.MaxConcurrentCoroutines < 1 {
|
if l.GlobalSettings.MaxConcurrentCoroutines > 0 && l.GlobalSettings.MaxConcurrentCoroutines < 1 {
|
||||||
return errors.New("global_settings.max_concurrent_coroutines 至少为 1")
|
return errors.New("global_settings.max_concurrent_coroutines 至少为 1")
|
||||||
}
|
}
|
||||||
if l.GlobalSettings.CoroutineTimeout < 0 {
|
if err := ValidateNonNegativeDuration(l.GlobalSettings.CoroutineTimeout, "global_settings.coroutine_timeout"); err != nil {
|
||||||
return errors.New("global_settings.coroutine_timeout 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
if l.GlobalSettings.CodeCacheSize < 0 {
|
if err := ValidateNonNegative(l.GlobalSettings.CodeCacheSize, "global_settings.code_cache_size"); err != nil {
|
||||||
return errors.New("global_settings.code_cache_size 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
if l.GlobalSettings.MaxExecutionTime < 0 {
|
if err := ValidateNonNegativeDuration(l.GlobalSettings.MaxExecutionTime, "global_settings.max_execution_time"); err != nil {
|
||||||
return errors.New("global_settings.max_execution_time 不能为负数")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -114,18 +114,7 @@ func newNgxLogAPI(ctx *fasthttp.RequestCtx, luaCtx *LuaContext, logger *zerolog.
|
|||||||
// 每次请求都会重新注册请求特定的函数(log, say, print, flush, exit, redirect)。
|
// 每次请求都会重新注册请求特定的函数(log, say, print, flush, exit, redirect)。
|
||||||
func RegisterNgxLogAPI(L *glua.LState, api *ngxLogAPI) {
|
func RegisterNgxLogAPI(L *glua.LState, api *ngxLogAPI) {
|
||||||
// 获取或创建 ngx 表
|
// 获取或创建 ngx 表
|
||||||
var ngx *glua.LTable
|
ngx := GetOrCreateNgxTable(L)
|
||||||
existingNgx := L.GetGlobal("ngx")
|
|
||||||
if existingNgx != nil && existingNgx.Type() == glua.LTTable {
|
|
||||||
ngxTable, ok := existingNgx.(*glua.LTable)
|
|
||||||
if ok {
|
|
||||||
ngx = ngxTable
|
|
||||||
} else {
|
|
||||||
ngx = L.NewTable()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ngx = L.NewTable()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查常量是否已注册(通过 STDERR 常量判断)
|
// 检查常量是否已注册(通过 STDERR 常量判断)
|
||||||
// 如果已注册,跳过常量写入,避免并发写入全局表
|
// 如果已注册,跳过常量写入,避免并发写入全局表
|
||||||
|
|||||||
@ -105,15 +105,8 @@ func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI {
|
|||||||
// RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API
|
// RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API
|
||||||
// 这是主入口函数,由 LuaEngine 在初始化时调用
|
// 这是主入口函数,由 LuaEngine 在初始化时调用
|
||||||
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI, ngxTable *glua.LTable) {
|
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI, ngxTable *glua.LTable) {
|
||||||
// 检查 ngx.req 是否已存在,避免并发写入
|
// 获取或创建 ngx.req 子表
|
||||||
var ngxReq *glua.LTable
|
ngxReq := GetOrCreateNgxSubTable(ngxTable, L, "req")
|
||||||
if existingReq := ngxTable.RawGetString("req"); existingReq == glua.LNil {
|
|
||||||
// 首次创建 ngx.req 子表
|
|
||||||
ngxReq = L.NewTable()
|
|
||||||
ngxTable.RawSetString("req", ngxReq)
|
|
||||||
} else {
|
|
||||||
ngxReq = existingReq.(*glua.LTable)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 直接映射层 API:get_method
|
// 直接映射层 API:get_method
|
||||||
// 特点:直接访问 fasthttp.RequestCtx,零拷贝,最小开销
|
// 特点:直接访问 fasthttp.RequestCtx,零拷贝,最小开销
|
||||||
|
|||||||
@ -46,29 +46,11 @@ func newNgxRespAPI(ctx *fasthttp.RequestCtx) *ngxRespAPI {
|
|||||||
// RegisterNgxRespAPI 在 Lua 状态机中注册 ngx.resp API
|
// RegisterNgxRespAPI 在 Lua 状态机中注册 ngx.resp API
|
||||||
// 这是主入口函数,由 LuaEngine 在初始化时调用
|
// 这是主入口函数,由 LuaEngine 在初始化时调用
|
||||||
func RegisterNgxRespAPI(L *glua.LState, api *ngxRespAPI) {
|
func RegisterNgxRespAPI(L *glua.LState, api *ngxRespAPI) {
|
||||||
// 获取已存在的 ngx 表(必须已设置全局)
|
// 获取或创建 ngx 表
|
||||||
ngx := L.GetGlobal("ngx")
|
ngxTable := GetOrCreateNgxTable(L)
|
||||||
if ngx == nil || ngx.Type() != glua.LTTable {
|
|
||||||
// 如果不存在,创建新表并设置全局
|
|
||||||
ngx = L.NewTable()
|
|
||||||
L.SetGlobal("ngx", ngx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 类型断言检查
|
// 获取或创建 ngx.resp 子表
|
||||||
ngxTable, ok := ngx.(*glua.LTable)
|
ngxResp := GetOrCreateNgxSubTable(ngxTable, L, "resp")
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查 ngx.resp 是否已存在,避免并发写入
|
|
||||||
var ngxResp *glua.LTable
|
|
||||||
if existingResp := ngxTable.RawGetString("resp"); existingResp == glua.LNil {
|
|
||||||
// 首次创建 ngx.resp 子表
|
|
||||||
ngxResp = L.NewTable()
|
|
||||||
ngxTable.RawSetString("resp", ngxResp)
|
|
||||||
} else {
|
|
||||||
ngxResp = existingResp.(*glua.LTable)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 每次请求更新函数以绑定正确的 ctx
|
// 每次请求更新函数以绑定正确的 ctx
|
||||||
ngxResp.RawSetString("get_status", L.NewFunction(api.luaGetStatus))
|
ngxResp.RawSetString("get_status", L.NewFunction(api.luaGetStatus))
|
||||||
|
|||||||
@ -537,26 +537,11 @@ const tcpSocketMT = "tcp_socket"
|
|||||||
|
|
||||||
// RegisterTCPSocketAPI 注册 TCP socket API
|
// RegisterTCPSocketAPI 注册 TCP socket API
|
||||||
func RegisterTCPSocketAPI(L *glua.LState, engine *LuaEngine) {
|
func RegisterTCPSocketAPI(L *glua.LState, engine *LuaEngine) {
|
||||||
// 确保 ngx 表存在
|
// 获取或创建 ngx 表
|
||||||
ngx := L.GetGlobal("ngx")
|
ngxTbl := GetOrCreateNgxTable(L)
|
||||||
var ngxTbl *glua.LTable
|
|
||||||
if tbl, ok := ngx.(*glua.LTable); ok {
|
|
||||||
ngxTbl = tbl
|
|
||||||
} else {
|
|
||||||
// 创建 ngx 表
|
|
||||||
ngxTbl = L.NewTable()
|
|
||||||
L.SetGlobal("ngx", ngxTbl)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查 ngx.socket 是否已存在,避免并发写入
|
// 获取或创建 ngx.socket 子表
|
||||||
var socket *glua.LTable
|
socket := GetOrCreateNgxSubTable(ngxTbl, L, "socket")
|
||||||
if existing := ngxTbl.RawGetString("socket"); existing == glua.LNil {
|
|
||||||
// 首次创建 ngx.socket 表
|
|
||||||
socket = L.NewTable()
|
|
||||||
ngxTbl.RawSetString("socket", socket)
|
|
||||||
} else {
|
|
||||||
socket = existing.(*glua.LTable)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 每次请求更新 tcp 函数以绑定正确的 engine
|
// 每次请求更新 tcp 函数以绑定正确的 engine
|
||||||
socket.RawSetString("tcp", L.NewFunction(newTCPSocketFunc(engine)))
|
socket.RawSetString("tcp", L.NewFunction(newTCPSocketFunc(engine)))
|
||||||
|
|||||||
56
internal/lua/helpers.go
Normal file
56
internal/lua/helpers.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
// Package lua 提供 Lua API 辅助函数。
|
||||||
|
//
|
||||||
|
// 该文件包含 ngx 表操作的共享辅助函数,用于减少代码重复。
|
||||||
|
// 所有函数保持并发安全设计(使用 RawGetString/RawSetString)。
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package lua
|
||||||
|
|
||||||
|
import glua "github.com/yuin/gopher-lua"
|
||||||
|
|
||||||
|
// GetOrCreateNgxTable 获取或创建全局 ngx 表。
|
||||||
|
//
|
||||||
|
// 如果全局 ngx 表已存在,则返回现有表;否则创建新表并设置为全局变量。
|
||||||
|
// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - L: Lua 状态机
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *glua.LTable: ngx 表
|
||||||
|
func GetOrCreateNgxTable(L *glua.LState) *glua.LTable {
|
||||||
|
ngx := L.GetGlobal("ngx")
|
||||||
|
if ngx != nil && ngx.Type() == glua.LTTable {
|
||||||
|
if tbl, ok := ngx.(*glua.LTable); ok {
|
||||||
|
return tbl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建新的 ngx 表
|
||||||
|
tbl := L.NewTable()
|
||||||
|
L.SetGlobal("ngx", tbl)
|
||||||
|
return tbl
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateNgxSubTable 获取或创建 ngx 子表。
|
||||||
|
//
|
||||||
|
// 如果子表已存在,则返回现有表;否则创建新子表并设置到父表。
|
||||||
|
// 该函数是并发安全的,使用 RawGetString/RawSetString 操作。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - ngx: 父表(通常是 ngx 表)
|
||||||
|
// - L: Lua 状态机
|
||||||
|
// - name: 子表名称(如 "req", "resp", "socket")
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *glua.LTable: 子表
|
||||||
|
func GetOrCreateNgxSubTable(ngx *glua.LTable, L *glua.LState, name string) *glua.LTable {
|
||||||
|
existing := ngx.RawGetString(name)
|
||||||
|
if existing == glua.LNil {
|
||||||
|
// 首次创建子表
|
||||||
|
sub := L.NewTable()
|
||||||
|
ngx.RawSetString(name, sub)
|
||||||
|
return sub
|
||||||
|
}
|
||||||
|
return existing.(*glua.LTable)
|
||||||
|
}
|
||||||
@ -20,6 +20,7 @@ package compression
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -29,6 +30,63 @@ import (
|
|||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// resettableWriteCloser 接口用于统一 gzip.Writer 和 brotli.Writer 的操作。
|
||||||
|
// 这两个 writer 都实现了 Reset(io.Writer), Write([]byte) (int, error), Close() error
|
||||||
|
type resettableWriteCloser interface {
|
||||||
|
Reset(w io.Writer)
|
||||||
|
Write(data []byte) (int, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressorPool 是一个通用的压缩 writer 池。
|
||||||
|
type compressorPool struct {
|
||||||
|
pool sync.Pool
|
||||||
|
level int
|
||||||
|
factory func(level int) resettableWriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
// newGzipPool 创建 gzip writer 池。
|
||||||
|
func newGzipPool(level int) *compressorPool {
|
||||||
|
return &compressorPool{
|
||||||
|
level: level,
|
||||||
|
factory: func(level int) resettableWriteCloser {
|
||||||
|
w, err := gzip.NewWriterLevel(nil, level)
|
||||||
|
if err != nil {
|
||||||
|
w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression)
|
||||||
|
}
|
||||||
|
return w
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newBrotliPool 创建 brotli writer 池。
|
||||||
|
func newBrotliPool(level int) *compressorPool {
|
||||||
|
return &compressorPool{
|
||||||
|
level: level,
|
||||||
|
factory: func(level int) resettableWriteCloser {
|
||||||
|
return brotli.NewWriterOptions(nil, brotli.WriterOptions{
|
||||||
|
Quality: level,
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 从池中获取 writer。
|
||||||
|
func (p *compressorPool) Get() (resettableWriteCloser, bool) {
|
||||||
|
if p.pool.New == nil {
|
||||||
|
p.pool.New = func() any {
|
||||||
|
return p.factory(p.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v, ok := p.pool.Get().(resettableWriteCloser)
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 将 writer 放回池中。
|
||||||
|
func (p *compressorPool) Put(w resettableWriteCloser) {
|
||||||
|
p.pool.Put(w)
|
||||||
|
}
|
||||||
|
|
||||||
// streamingThreshold 流式压缩阈值。
|
// streamingThreshold 流式压缩阈值。
|
||||||
// 响应体超过此大小时使用 SetBodyStreamWriter 流式压缩,
|
// 响应体超过此大小时使用 SetBodyStreamWriter 流式压缩,
|
||||||
// 消除 compressed buffer 分配,降低内存峰值。
|
// 消除 compressed buffer 分配,降低内存峰值。
|
||||||
@ -49,9 +107,9 @@ const (
|
|||||||
// Middleware 响应压缩中间件。
|
// Middleware 响应压缩中间件。
|
||||||
type Middleware struct {
|
type Middleware struct {
|
||||||
// gzipPool gzip.Writer 缓冲池
|
// gzipPool gzip.Writer 缓冲池
|
||||||
gzipPool sync.Pool
|
gzipPool *compressorPool
|
||||||
// brotliPool brotli.Writer 缓冲池
|
// brotliPool brotli.Writer 缓冲池
|
||||||
brotliPool sync.Pool
|
brotliPool *compressorPool
|
||||||
// types 可压缩的 MIME 类型列表
|
// types 可压缩的 MIME 类型列表
|
||||||
types []string
|
types []string
|
||||||
|
|
||||||
@ -114,25 +172,8 @@ func New(cfg *config.CompressionConfig) (*Middleware, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 初始化缓冲池
|
// 初始化缓冲池
|
||||||
m.gzipPool = sync.Pool{
|
m.gzipPool = newGzipPool(cfg.Level)
|
||||||
New: func() any {
|
m.brotliPool = newBrotliPool(cfg.Level)
|
||||||
w, err := gzip.NewWriterLevel(nil, cfg.Level)
|
|
||||||
if err != nil {
|
|
||||||
// 使用默认压缩级别作为回退
|
|
||||||
w, _ = gzip.NewWriterLevel(nil, gzip.DefaultCompression)
|
|
||||||
}
|
|
||||||
return w
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化 brotli 缓冲池
|
|
||||||
m.brotliPool = sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return brotli.NewWriterOptions(nil, brotli.WriterOptions{
|
|
||||||
Quality: cfg.Level,
|
|
||||||
})
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
@ -224,18 +265,18 @@ func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandl
|
|||||||
if bodyLen > streamingThreshold {
|
if bodyLen > streamingThreshold {
|
||||||
// 大响应:流式压缩,消除 compressed buffer 分配
|
// 大响应:流式压缩,消除 compressed buffer 分配
|
||||||
if useBrotli {
|
if useBrotli {
|
||||||
m.streamBrotli(ctx, encoding)
|
m.streamWithPool(ctx, encoding, m.brotliPool)
|
||||||
} else if useGzip {
|
} else if useGzip {
|
||||||
m.streamGzip(ctx, encoding)
|
m.streamWithPool(ctx, encoding, m.gzipPool)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 小响应:缓冲压缩
|
// 小响应:缓冲压缩
|
||||||
var compressed []byte
|
var compressed []byte
|
||||||
|
|
||||||
if useBrotli {
|
if useBrotli {
|
||||||
compressed = m.compressBrotli(body)
|
compressed = m.compressWithPool(body, m.brotliPool)
|
||||||
} else if useGzip {
|
} else if useGzip {
|
||||||
compressed = m.compressGzip(body)
|
compressed = m.compressWithPool(body, m.gzipPool)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(compressed) > 0 && len(compressed) < bodyLen {
|
if len(compressed) > 0 && len(compressed) < bodyLen {
|
||||||
@ -276,19 +317,20 @@ func (m *Middleware) isCompressible(contentType []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// compressGzip 使用 gzip 压缩数据。
|
// compressWithPool 使用缓冲池压缩数据。
|
||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - data: 待压缩的原始数据
|
// - data: 待压缩的原始数据
|
||||||
|
// - pool: 压缩 writer 缓冲池
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - []byte: 压缩后的数据
|
// - []byte: 压缩后的数据
|
||||||
func (m *Middleware) compressGzip(data []byte) []byte {
|
func (m *Middleware) compressWithPool(data []byte, pool *compressorPool) []byte {
|
||||||
w, ok := m.gzipPool.Get().(*gzip.Writer)
|
w, ok := pool.Get()
|
||||||
if !ok {
|
if !ok {
|
||||||
return data // fallback to uncompressed
|
return data // fallback to uncompressed
|
||||||
}
|
}
|
||||||
defer m.gzipPool.Put(w)
|
defer pool.Put(w)
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
w.Reset(&buf)
|
w.Reset(&buf)
|
||||||
@ -300,28 +342,35 @@ func (m *Middleware) compressGzip(data []byte) []byte {
|
|||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// compressBrotli 使用 brotli 压缩数据。
|
// streamWithPool 使用流式压缩。
|
||||||
|
//
|
||||||
|
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
|
||||||
|
// 消除 compressed buffer 分配,降低内存峰值。
|
||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - data: 待压缩的原始数据
|
// - ctx: fasthttp 请求上下文
|
||||||
//
|
// - encoding: Content-Encoding 值
|
||||||
// 返回值:
|
// - pool: 压缩 writer 缓冲池
|
||||||
// - []byte: 压缩后的数据
|
func (m *Middleware) streamWithPool(ctx *fasthttp.RequestCtx, encoding string, pool *compressorPool) {
|
||||||
func (m *Middleware) compressBrotli(data []byte) []byte {
|
ctx.Response.Header.Set("Content-Encoding", encoding)
|
||||||
w, ok := m.brotliPool.Get().(*brotli.Writer)
|
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
|
||||||
|
|
||||||
|
body := ctx.Response.Body()
|
||||||
|
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
||||||
|
writer, ok := pool.Get()
|
||||||
if !ok {
|
if !ok {
|
||||||
return data // fallback to uncompressed
|
// pool 获取失败,直接写原始 body
|
||||||
|
_, _ = w.Write(body)
|
||||||
|
_ = w.Flush()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
defer m.brotliPool.Put(w)
|
defer pool.Put(writer)
|
||||||
|
|
||||||
var buf bytes.Buffer
|
writer.Reset(w)
|
||||||
w.Reset(&buf)
|
_, _ = writer.Write(body)
|
||||||
if _, err := w.Write(data); err != nil { //nolint:staticcheck // intentionally empty branch
|
_ = writer.Close()
|
||||||
// 忽略写入错误,缓冲到 bytes.Buffer 时不太可能失败
|
_ = w.Flush()
|
||||||
}
|
})
|
||||||
_ = w.Close()
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Types 返回可压缩的 MIME 类型列表。
|
// Types 返回可压缩的 MIME 类型列表。
|
||||||
@ -347,63 +396,3 @@ func (m *Middleware) Level() int {
|
|||||||
func (m *Middleware) MinSize() int {
|
func (m *Middleware) MinSize() int {
|
||||||
return m.minSize
|
return m.minSize
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamGzip 使用 gzip 流式压缩。
|
|
||||||
//
|
|
||||||
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
|
|
||||||
// 消除 compressed buffer 分配,降低内存峰值。
|
|
||||||
//
|
|
||||||
// 参数:
|
|
||||||
// - ctx: fasthttp 请求上下文
|
|
||||||
// - encoding: Content-Encoding 值("gzip")
|
|
||||||
func (m *Middleware) streamGzip(ctx *fasthttp.RequestCtx, encoding string) {
|
|
||||||
ctx.Response.Header.Set("Content-Encoding", encoding)
|
|
||||||
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
|
|
||||||
|
|
||||||
body := ctx.Response.Body()
|
|
||||||
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
|
||||||
writer, ok := m.gzipPool.Get().(*gzip.Writer)
|
|
||||||
if !ok {
|
|
||||||
// pool 获取失败,直接写原始 body
|
|
||||||
_, _ = w.Write(body)
|
|
||||||
_ = w.Flush()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer m.gzipPool.Put(writer)
|
|
||||||
|
|
||||||
writer.Reset(w)
|
|
||||||
_, _ = writer.Write(body)
|
|
||||||
_ = writer.Close()
|
|
||||||
_ = w.Flush()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// streamBrotli 使用 brotli 流式压缩。
|
|
||||||
//
|
|
||||||
// 通过 SetBodyStreamWriter 将压缩数据直接写入响应流,
|
|
||||||
// 消除 compressed buffer 分配,降低内存峰值。
|
|
||||||
//
|
|
||||||
// 参数:
|
|
||||||
// - ctx: fasthttp 请求上下文
|
|
||||||
// - encoding: Content-Encoding 值("br")
|
|
||||||
func (m *Middleware) streamBrotli(ctx *fasthttp.RequestCtx, encoding string) {
|
|
||||||
ctx.Response.Header.Set("Content-Encoding", encoding)
|
|
||||||
ctx.Response.Header.Del("Content-Length") // 使用 chunked encoding
|
|
||||||
|
|
||||||
body := ctx.Response.Body()
|
|
||||||
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
|
|
||||||
writer, ok := m.brotliPool.Get().(*brotli.Writer)
|
|
||||||
if !ok {
|
|
||||||
// pool 获取失败,直接写原始 body
|
|
||||||
_, _ = w.Write(body)
|
|
||||||
_ = w.Flush()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer m.brotliPool.Put(writer)
|
|
||||||
|
|
||||||
writer.Reset(w)
|
|
||||||
_, _ = writer.Write(body)
|
|
||||||
_ = writer.Close()
|
|
||||||
_ = w.Flush()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ func BenchmarkGzipCompress_1KB(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressGzip(data)
|
mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,7 +44,7 @@ func BenchmarkGzipCompress_10KB(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressGzip(data)
|
mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ func BenchmarkGzipCompress_100KB(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressGzip(data)
|
mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ func BenchmarkBrotliCompress_1KB(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressBrotli(data)
|
mw.compressWithPool(data, mw.brotliPool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,7 +95,7 @@ func BenchmarkBrotliCompress_10KB(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressBrotli(data)
|
mw.compressWithPool(data, mw.brotliPool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ func BenchmarkCompressionPool(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressGzip(data)
|
mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,7 +222,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
|
|||||||
mw, _ := New(cfg)
|
mw, _ := New(cfg)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressGzip(data)
|
mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -231,7 +231,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
|
|||||||
mw, _ := New(cfg)
|
mw, _ := New(cfg)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressGzip(data)
|
mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -240,7 +240,7 @@ func BenchmarkCompressionLevelComparison(b *testing.B) {
|
|||||||
mw, _ := New(cfg)
|
mw, _ := New(cfg)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
mw.compressGzip(data)
|
mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -134,7 +134,7 @@ func TestCompressGzip(t *testing.T) {
|
|||||||
// 测试数据
|
// 测试数据
|
||||||
data := []byte("Hello, World! This is a test string that should be compressed.")
|
data := []byte("Hello, World! This is a test string that should be compressed.")
|
||||||
|
|
||||||
compressed := m.compressGzip(data)
|
compressed := m.compressWithPool(data, m.gzipPool)
|
||||||
if len(compressed) == 0 {
|
if len(compressed) == 0 {
|
||||||
t.Error("Expected compressed data")
|
t.Error("Expected compressed data")
|
||||||
}
|
}
|
||||||
@ -153,7 +153,7 @@ func TestCompressBrotli(t *testing.T) {
|
|||||||
|
|
||||||
data := []byte("Hello, World! This is a test string that should be compressed with brotli.")
|
data := []byte("Hello, World! This is a test string that should be compressed with brotli.")
|
||||||
|
|
||||||
compressed := m.compressBrotli(data)
|
compressed := m.compressWithPool(data, m.brotliPool)
|
||||||
if len(compressed) == 0 {
|
if len(compressed) == 0 {
|
||||||
t.Error("Expected compressed data")
|
t.Error("Expected compressed data")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -83,7 +83,7 @@ func BenchmarkGzipWriter_Pool(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = mw.compressGzip(data)
|
_ = mw.compressWithPool(data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ func BenchmarkGzipCompress_Sizes(b *testing.B) {
|
|||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for b.Loop() {
|
||||||
_ = mw.compressGzip(tc.data)
|
_ = mw.compressWithPool(tc.data, mw.gzipPool)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user