diff --git a/internal/adapter/common.go b/internal/adapter/common.go new file mode 100644 index 0000000..9e864fa --- /dev/null +++ b/internal/adapter/common.go @@ -0,0 +1,187 @@ +// Package adapter 提供 HTTP/2 和 HTTP/3 适配器的共享组件。 +// +// 该包提取了两个适配器中通用的功能,避免代码重复: +// +// - 共享的 bufferPool singleton(零拷贝优化) +// - 统一的请求体处理阈值 +// - 通用的上下文重置逻辑 +// - 流式请求体读取 +// +// 关键设计决策: +// +// 1. bufferPool 使用 singleton 模式,ctxPool 保持独立 +// 2. CommonAdapter 不包含 ConvertResponse(HTTP/2/HTTP/3 行为不同) +// 3. 阈值常量统一,避免 HTTP/2 inline 和 HTTP/3 constant 不一致 +// +// 作者:xfy +package adapter + +import ( + "io" + "net/http" + "sync" + + "github.com/valyala/fasthttp" +) + +// DefaultBodyThreshold 是请求体大小阈值,超过此值使用流式处理。 +// +// 64KB 是经过测试的平衡点: +// - 小于此值:直接读取到内存,避免 pool 开销 +// - 大于此值:使用流式缓冲区,避免大内存分配 +const DefaultBodyThreshold = 64 * 1024 // 64KB + +// bufferPoolInstance 是全局共享的缓冲区池 singleton。 +// +// 使用 singleton 模式避免多个适配器实例创建多个 pool, +// 提高内存复用效率。该 pool 被 HTTP/2 和 HTTP/3 适配器共享。 +var bufferPoolInstance = &sync.Pool{ + New: func() interface{} { + buf := make([]byte, 4096) // 4KB 初始缓冲区 + return &buf + }, +} + +// SharedBufferPool 返回全局共享的缓冲区池实例。 +// +// HTTP/2 和 HTTP/3 适配器都使用此 pool 来复用字节缓冲区, +// 避免大请求体处理时的频繁内存分配。 +// +// 返回值: +// - *sync.Pool: 全局缓冲区池实例 +func SharedBufferPool() *sync.Pool { + return bufferPoolInstance +} + +// CommonAdapter 提供 HTTP/2 和 HTTP/3 适配器的共享基础结构。 +// +// 该结构体提取了两个适配器共用的字段和方法, +// 但不包含 ConvertResponse(HTTP/2 和 HTTP/3 的响应转换逻辑不同)。 +type CommonAdapter struct { + // CtxPool 用于复用 fasthttp.RequestCtx 对象 + // 每个协议适配器实例独立维护自己的 ctxPool + CtxPool sync.Pool +} + +// NewCommonAdapter 创建新的共享适配器实例。 +// +// 初始化 CommonAdapter,设置 ctxPool 的 New 函数。 +// bufferPool 使用全局 singleton,不需要在实例中存储。 +// +// 返回值: +// - *CommonAdapter: 初始化的共享适配器实例 +func NewCommonAdapter() *CommonAdapter { + return &CommonAdapter{ + CtxPool: sync.Pool{ + New: func() interface{} { + return &fasthttp.RequestCtx{} + }, + }, + } +} + +// ResetContext 重置 fasthttp.RequestCtx 状态。 +// +// 从 pool 获取的 ctx 可能带有之前请求的残留状态, +// 必须在每次使用前调用此方法进行清理。 +// +// 参数: +// - ctx: 需要重置的 fasthttp 请求上下文 +func (a *CommonAdapter) ResetContext(ctx *fasthttp.RequestCtx) { + // 禁用头部规范化以保持原始大小写 + ctx.Request.Header.DisableNormalizing() + // 重置请求和响应状态 + ctx.Request.Reset() + ctx.Response.Reset() + // 清除用户自定义值 + ctx.SetUserValueBytes(nil, nil) +} + +// StreamRequestBody 流式读取 HTTP 请求体到 fasthttp。 +// +// 对于小于等于 DefaultBodyThreshold(64KB)的请求体,直接读取到内存; +// 对于大于阈值的请求体,使用共享 bufferPool 进行流式处理,避免内存峰值。 +// +// 参数: +// - r: 标准库的 HTTP 请求 +// - ctx: fasthttp 请求上下文,用于存储读取的请求体 +func (a *CommonAdapter) StreamRequestBody(r *http.Request, ctx *fasthttp.RequestCtx) { + if r.Body == nil || r.Body == http.NoBody { + return + } + + defer func() { + _ = r.Body.Close() + }() + + // 小请求体:直接读取到内存(<= 64KB) + if r.ContentLength > 0 && r.ContentLength <= DefaultBodyThreshold { + body, err := io.ReadAll(r.Body) + if err == nil { + ctx.Request.SetBody(body) + } + return + } + + // 大请求体:使用流式缓冲区(> 64KB 或未知长度) + // 从全局 pool 获取缓冲区 + bufPtr, ok := bufferPoolInstance.Get().(*[]byte) + if !ok { + // 如果类型断言失败,创建新的缓冲区(不应该发生) + buf := make([]byte, 4096) + bufPtr = &buf + } + defer bufferPoolInstance.Put(bufPtr) + + buf := *bufPtr + var body []byte + + // 如果已知 ContentLength,预分配精确大小的缓冲区 + if r.ContentLength > 0 { + body = make([]byte, 0, r.ContentLength) + } + + // 分块读取请求体 + for { + n, err := r.Body.Read(buf) + if n > 0 { + body = append(body, buf[:n]...) + } + if err == io.EOF { + break + } + if err != nil { + break + } + } + + if len(body) > 0 { + ctx.Request.SetBody(body) + } +} + +// GetContext 从 pool 获取一个 fasthttp.RequestCtx。 +// +// 使用 pool 复用 RequestCtx 对象,减少 GC 压力。 +// 获取的 ctx 必须通过 ResetContext 重置后才能使用。 +// +// 返回值: +// - *fasthttp.RequestCtx: fasthttp 请求上下文 +// - bool: 如果为 false,表示类型断言失败,ctx 是新创建的 +func (a *CommonAdapter) GetContext() (*fasthttp.RequestCtx, bool) { + ctx, ok := a.CtxPool.Get().(*fasthttp.RequestCtx) + if !ok { + ctx = &fasthttp.RequestCtx{} + } + return ctx, ok +} + +// PutContext 将 fasthttp.RequestCtx 放回 pool。 +// +// 在放回 pool 前应该调用 ResetContext 清理状态。 +// +// 参数: +// - ctx: 要放回 pool 的 fasthttp 请求上下文 +func (a *CommonAdapter) PutContext(ctx *fasthttp.RequestCtx) { + a.CtxPool.Put(ctx) +} diff --git a/internal/http2/adapter.go b/internal/http2/adapter.go index bf5ff1d..fe1a84f 100644 --- a/internal/http2/adapter.go +++ b/internal/http2/adapter.go @@ -13,13 +13,12 @@ package http2 import ( - "io" "net" "net/http" - "sync" "time" "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/adapter" ) // FastHTTPHandlerAdapter 将 fasthttp.RequestHandler 适配为 http.Handler。 @@ -27,16 +26,8 @@ import ( // 由于 HTTP/2 服务器使用标准库的 http.Handler 接口, // 而 lolly 使用 fasthttp,需要通过适配层进行转换。 type FastHTTPHandlerAdapter struct { + *adapter.CommonAdapter handler fasthttp.RequestHandler - - // ctxPool 用于复用 fasthttp.RequestCtx 对象 - ctxPool sync.Pool - - // bufferPool 用于复用字节缓冲区(零拷贝优化) - bufferPool sync.Pool - - // headerBufferPool 用于复用头部缓冲区 - headerBufferPool sync.Pool } // NewFastHTTPHandlerAdapter 创建新的 HTTP/2 适配器。 @@ -48,23 +39,8 @@ type FastHTTPHandlerAdapter struct { // - *FastHTTPHandlerAdapter: 适配器实例 func NewFastHTTPHandlerAdapter(handler fasthttp.RequestHandler) *FastHTTPHandlerAdapter { return &FastHTTPHandlerAdapter{ - handler: handler, - ctxPool: sync.Pool{ - New: func() interface{} { - return &fasthttp.RequestCtx{} - }, - }, - bufferPool: sync.Pool{ - New: func() interface{} { - buf := make([]byte, 4096) // 4KB 初始缓冲区 - return &buf - }, - }, - headerBufferPool: sync.Pool{ - New: func() interface{} { - return &fasthttp.RequestHeader{} - }, - }, + CommonAdapter: adapter.NewCommonAdapter(), + handler: handler, } } @@ -78,21 +54,17 @@ func NewFastHTTPHandlerAdapter(handler fasthttp.RequestHandler) *FastHTTPHandler // - r: 标准库 HTTP 请求 func (a *FastHTTPHandlerAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 从池中获取 RequestCtx - ctx, ok := a.ctxPool.Get().(*fasthttp.RequestCtx) - if !ok { - // 如果类型断言失败,创建新的上下文(不应该发生,但为了安全) - ctx = &fasthttp.RequestCtx{} - } - defer a.ctxPool.Put(ctx) + ctx, _ := a.GetContext() + defer a.PutContext(ctx) // 重置 ctx 状态以避免污染 - a.resetContext(ctx) + a.ResetContext(ctx) // 转换请求(零拷贝头部转换) a.convertRequest(r, ctx) // 流式处理请求体 - a.streamRequestBody(r, ctx) + a.StreamRequestBody(r, ctx) // 调用 fasthttp handler a.handler(ctx) @@ -101,18 +73,6 @@ func (a *FastHTTPHandlerAdapter) ServeHTTP(w http.ResponseWriter, r *http.Reques a.convertResponse(ctx, w) } -// resetContext 重置 fasthttp.RequestCtx 状态。 -// -// 参数: -// - ctx: 需要重置的上下文 -func (a *FastHTTPHandlerAdapter) resetContext(ctx *fasthttp.RequestCtx) { - // 清空请求头 - ctx.Request.Header.DisableNormalizing() - ctx.Request.Reset() - ctx.Response.Reset() - ctx.SetUserValueBytes(nil, nil) -} - // convertRequest 将 net/http.Request 转换为 fasthttp.RequestCtx。 // // 使用零拷贝策略转换请求头和元数据。 @@ -204,61 +164,6 @@ func (a *FastHTTPHandlerAdapter) setRemoteAddr(r *http.Request, ctx *fasthttp.Re } } -// streamRequestBody 流式读取请求体到 fasthttp。 -// -// 对于大请求体,使用流式处理避免内存峰值。 -// -// 参数: -// - r: 标准库 HTTP 请求 -// - ctx: FastHTTP 请求上下文 -func (a *FastHTTPHandlerAdapter) streamRequestBody(r *http.Request, ctx *fasthttp.RequestCtx) { - if r.Body == nil || r.Body == http.NoBody { - return - } - - defer func() { - _ = r.Body.Close() - }() - - // 小请求体:直接读取到内存 - if r.ContentLength > 0 && r.ContentLength <= 64*1024 { - body, err := io.ReadAll(r.Body) - if err == nil { - ctx.Request.SetBody(body) - } - return - } - - // 大请求体:使用流式缓冲区 - bufPtr, ok := a.bufferPool.Get().(*[]byte) - if !ok { - // 如果类型断言失败,创建新的缓冲区 - buf := make([]byte, 4096) - bufPtr = &buf - } - defer a.bufferPool.Put(bufPtr) - - buf := *bufPtr - var body []byte - - for { - n, err := r.Body.Read(buf) - if n > 0 { - body = append(body, buf[:n]...) - } - if err == io.EOF { - break - } - if err != nil { - break - } - } - - if len(body) > 0 { - ctx.Request.SetBody(body) - } -} - // convertResponse 将 fasthttp.RequestCtx 响应写入 http.ResponseWriter。 // // 参数: diff --git a/internal/http3/adapter.go b/internal/http3/adapter.go index 2ac530e..b780ca1 100644 --- a/internal/http3/adapter.go +++ b/internal/http3/adapter.go @@ -7,57 +7,39 @@ // // - 流式请求体处理:对于大请求体使用流式读取避免内存峰值 // - 阈值控制:64KB 以下全量读取,以上使用流式处理 +// - 共享适配器:使用 internal/adapter 包中的 CommonAdapter // // 作者:xfy package http3 import ( - "io" "net" "net/http" - "sync" + "rua.plus/lolly/internal/adapter" "github.com/valyala/fasthttp" ) -const ( - // bodySizeThreshold 是请求体大小阈值,超过此值使用流式处理 - bodySizeThreshold = 64 * 1024 // 64KB -) - // Adapter 将 fasthttp.RequestHandler 适配为 http.Handler。 // // 由于 quic-go 使用标准库的 http.Handler 接口, // 而 lolly 使用 fasthttp,需要通过适配层进行转换。 +// 使用 struct embedding 复用 CommonAdapter 的功能。 type Adapter struct { - // ctxPool 用于复用 fasthttp.RequestCtx 对象 - ctxPool sync.Pool - - // bufferPool 用于复用字节缓冲区(流式处理优化) - bufferPool sync.Pool + *adapter.CommonAdapter } // NewAdapter 创建 HTTP/3 适配器实例。 // // 初始化用于将 fasthttp.RequestHandler 适配为标准库 http.Handler -// 的适配器。内部使用 sync.Pool 复用 RequestCtx 和缓冲区对象, -// 以降低内存分配开销。 +// 的适配器。内部使用 sync.Pool 复用 RequestCtx 对象, +// 并使用共享的 bufferPool 降低内存分配开销。 // // 返回值: // - *Adapter: 初始化的 HTTP/3 适配器实例 func NewAdapter() *Adapter { return &Adapter{ - ctxPool: sync.Pool{ - New: func() interface{} { - return &fasthttp.RequestCtx{} - }, - }, - bufferPool: sync.Pool{ - New: func() interface{} { - buf := make([]byte, 4096) // 4KB 初始缓冲区 - return &buf - }, - }, + CommonAdapter: adapter.NewCommonAdapter(), } } @@ -74,15 +56,15 @@ func NewAdapter() *Adapter { func (a *Adapter) Wrap(handler fasthttp.RequestHandler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 从池中获取 RequestCtx - ctx, ok := a.ctxPool.Get().(*fasthttp.RequestCtx) + ctx, ok := a.GetContext() if !ok { // 如果类型断言失败,创建新的上下文(不应该发生,但为了安全) ctx = &fasthttp.RequestCtx{} } - defer a.ctxPool.Put(ctx) + defer a.PutContext(ctx) // 重置 ctx 状态以避免污染 - a.resetContext(ctx) + a.ResetContext(ctx) // 转换请求 a.convertRequest(r, ctx) @@ -98,18 +80,6 @@ func (a *Adapter) Wrap(handler fasthttp.RequestHandler) http.Handler { }) } -// resetContext 重置 fasthttp.RequestCtx 状态。 -// -// 参数: -// - ctx: 需要重置的上下文 -func (a *Adapter) resetContext(ctx *fasthttp.RequestCtx) { - // 清空请求头 - ctx.Request.Header.DisableNormalizing() - ctx.Request.Reset() - ctx.Response.Reset() - ctx.SetUserValueBytes(nil, nil) -} - // convertRequest 将 net/http.Request 转换为 fasthttp.RequestCtx。 // // 参数: @@ -137,7 +107,7 @@ func (a *Adapter) convertRequest(r *http.Request, ctx *fasthttp.RequestCtx) { } // 设置请求体(使用流式处理优化) - a.streamRequestBody(r, ctx) + a.StreamRequestBody(r, ctx) // 设置远程地址 if r.RemoteAddr != "" { @@ -152,6 +122,8 @@ func (a *Adapter) convertRequest(r *http.Request, ctx *fasthttp.RequestCtx) { // convertResponse 将 fasthttp.RequestCtx 响应写入 http.ResponseWriter。 // +// HTTP/3 版本:简单写入响应 +// // 参数: // - ctx: FastHTTP 请求上下文 // - w: 标准库 ResponseWriter @@ -177,67 +149,6 @@ func (a *Adapter) convertResponse(ctx *fasthttp.RequestCtx, w http.ResponseWrite } } -// streamRequestBody 流式读取请求体到 fasthttp。 -// -// 对于小于等于 64KB 的请求体,直接读取到内存; -// 对于大于 64KB 的请求体,使用流式缓冲区避免内存峰值。 -// -// 参数: -// - r: 标准库 HTTP 请求 -// - ctx: FastHTTP 请求上下文 -func (a *Adapter) streamRequestBody(r *http.Request, ctx *fasthttp.RequestCtx) { - if r.Body == nil || r.Body == http.NoBody { - return - } - - defer func() { - _ = r.Body.Close() - }() - - // 小请求体(<=64KB):直接读取到内存 - if r.ContentLength > 0 && r.ContentLength <= bodySizeThreshold { - body, err := io.ReadAll(r.Body) - if err == nil { - ctx.Request.SetBody(body) - } - return - } - - // 大请求体(>64KB 或未知长度):使用流式缓冲区 - // 如果已知 ContentLength,预分配精确大小的缓冲区 - var body []byte - if r.ContentLength > 0 { - body = make([]byte, 0, r.ContentLength) - } - - // 从 pool 获取缓冲区进行分块读取 - bufPtr, ok := a.bufferPool.Get().(*[]byte) - if !ok { - buf := make([]byte, 4096) - bufPtr = &buf - } - defer a.bufferPool.Put(bufPtr) - - buf := *bufPtr - - for { - n, err := r.Body.Read(buf) - if n > 0 { - body = append(body, buf[:n]...) - } - if err == io.EOF { - break - } - if err != nil { - break - } - } - - if len(body) > 0 { - ctx.Request.SetBody(body) - } -} - // WrapHandler 包装特定的 fasthttp handler。 // // 返回一个可以直接用于 http3.Server 的 http.Handler。 diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index 8c9ed83..1bb305a 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -582,7 +582,7 @@ func (cl *ConnLimiter) Acquire(ctx *fasthttp.RequestCtx) bool { // - ctx: FastHTTP 请求上下文 func (cl *ConnLimiter) Release(ctx *fasthttp.RequestCtx) { if !cl.perKey { - addInt64(&cl.current, -1) + atomic.AddInt64(&cl.current, -1) return } @@ -598,21 +598,16 @@ func (cl *ConnLimiter) Release(ctx *fasthttp.RequestCtx) { // Middleware 返回连接限制的中间件包装。 // // 返回值: -// - middleware.Middleware: 可用于中间件链的限制器 +// - middleware.Middleware: 可用于中间件链的限制器(返回自身) func (cl *ConnLimiter) Middleware() middleware.Middleware { - return &connLimiterMiddleware{limiter: cl} -} - -// connLimiterMiddleware 连接限制器的中间件包装。 -type connLimiterMiddleware struct { - limiter *ConnLimiter // 连接限制器实例 + return cl } // Name 返回中间件名称。 // // 返回值: // - string: 中间件标识名 "conn_limiter" -func (m *connLimiterMiddleware) Name() string { +func (cl *ConnLimiter) Name() string { return "conn_limiter" } @@ -626,28 +621,20 @@ func (m *connLimiterMiddleware) Name() string { // // 返回值: // - fasthttp.RequestHandler: 包装后的处理器 -func (m *connLimiterMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { +func (cl *ConnLimiter) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - if !m.limiter.Acquire(ctx) { + if !cl.Acquire(ctx) { utils.SendErrorWithDetail(ctx, utils.ErrServiceUnavailable, "Connection limit exceeded") return } - defer m.limiter.Release(ctx) + defer cl.Release(ctx) next(ctx) } } -// 连接数原子操作辅助函数 - -// addInt64 原子添加 int64 增量。 - -func addInt64(ptr *int64, delta int64) { - atomic.AddInt64(ptr, delta) -} - // 验证接口实现 var ( _ middleware.Middleware = (*RateLimiter)(nil) - _ middleware.Middleware = (*connLimiterMiddleware)(nil) + _ middleware.Middleware = (*ConnLimiter)(nil) ) diff --git a/internal/middleware/security/ratelimit_test.go b/internal/middleware/security/ratelimit_test.go index abebde1..5156924 100644 --- a/internal/middleware/security/ratelimit_test.go +++ b/internal/middleware/security/ratelimit_test.go @@ -848,3 +848,22 @@ func TestKeyByIP_Unknown(t *testing.T) { t.Error("keyByIP() should return non-empty string") } } + +// TestConnLimiter_MiddlewareIdentity 验证 Middleware() 返回相同实例 +func TestConnLimiter_MiddlewareIdentity(t *testing.T) { + cl, err := NewConnLimiter(100, false, "") + if err != nil { + t.Fatalf("NewConnLimiter() error: %v", err) + } + + // Middleware() 应该返回自身 + middleware := cl.Middleware() + if middleware != cl { + t.Error("Middleware() should return the same ConnLimiter instance") + } + + // 验证返回的实例实现了 Middleware 接口 + if middleware.Name() != "conn_limiter" { + t.Errorf("Name() = %s, want 'conn_limiter'", middleware.Name()) + } +} diff --git a/internal/server/purge_test.go b/internal/server/purge_test.go index a34c1fc..886c0a4 100644 --- a/internal/server/purge_test.go +++ b/internal/server/purge_test.go @@ -785,15 +785,6 @@ func TestPurgeHandler_checkAccess_WithAllowedIP(t *testing.T) { }) } -// mockProxyWithCache 是一个用于测试的 mock Proxy,可以返回指定的缓存。 -type mockProxyWithCache struct { - cache *cache.ProxyCache -} - -func (m *mockProxyWithCache) GetCache() *cache.ProxyCache { - return m.cache -} - // TestPurgeHandler_PurgeByPath_WithRealCache 测试 purgeByPath 在有真实缓存时的行为。 func TestPurgeHandler_PurgeByPath_WithRealCache(t *testing.T) { // 创建启用缓存的代理