feat(http2): 新增 HTTP/2 支持,集成到服务器和应用

This commit is contained in:
xfy 2026-04-09 12:18:52 +08:00
parent 42533c31d2
commit 412bfebdd8
11 changed files with 2782 additions and 3 deletions

View File

@ -26,11 +26,13 @@ import (
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/http2"
"rua.plus/lolly/internal/http3"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/resolver"
"rua.plus/lolly/internal/server"
"rua.plus/lolly/internal/stream"
"rua.plus/lolly/internal/variable"
)
// 版本信息,通过 -ldflags 注入。
@ -72,6 +74,9 @@ type App struct {
// http3Srv HTTP/3 服务器实例(可选)
http3Srv *http3.Server
// http2Srv HTTP/2 服务器实例(可选)
http2Srv *http2.Server
// streamSrv Stream 服务器实例(可选)
streamSrv *stream.Server
@ -167,6 +172,14 @@ func (a *App) Run() int {
a.cfg = cfg
a.logger = logging.NewAppLogger(&cfg.Logging)
// 设置全局变量
variable.SetGlobalVariables(cfg.Variables.Set)
if len(cfg.Variables.Set) > 0 {
a.logger.LogStartup("全局变量已加载", map[string]string{
"count": fmt.Sprintf("%d", len(cfg.Variables.Set)),
})
}
// 检查是否是子进程(热升级)
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
a.logger.LogStartup("检测到热升级模式,继承父进程监听器", nil)
@ -263,6 +276,38 @@ func (a *App) Run() int {
}
}
// 创建并启动 HTTP/2 服务器(如果启用且配置了 TLS
if a.cfg.Server.SSL.HTTP2.Enabled && a.cfg.Server.SSL.Cert != "" {
tlsConfig, err := a.srv.GetTLSConfig()
if err != nil {
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2")
} else {
// 创建 HTTP/2 服务器,共享同一个 handler
a.http2Srv, err = http2.NewServer(&a.cfg.Server.SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
if err != nil {
a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败")
} else {
go func() {
a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{
"listen": a.cfg.Server.Listen,
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Server.SSL.HTTP2.MaxConcurrentStreams),
"push_enabled": fmt.Sprintf("%t", a.cfg.Server.SSL.HTTP2.PushEnabled),
})
// HTTP/2 服务器使用与主服务器相同的监听器
// 通过 ALPN 协商自动处理协议选择
listeners := a.srv.GetListeners()
if len(listeners) > 0 {
if err := a.http2Srv.Serve(listeners[0]); err != nil {
a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败")
}
} else {
a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器")
}
}()
}
}
}
// 创建升级管理器
a.upgradeMgr = server.NewUpgradeManager(a.srv)
if a.pidFile != "" {
@ -318,6 +363,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
case syscall.SIGQUIT:
// 优雅停止:等待请求完成
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v", shutdownTimeout))
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.GracefulStop(shutdownTimeout)
return false
@ -325,6 +371,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
case syscall.SIGTERM, syscall.SIGINT:
// 快速停止
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "停止服务器")
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.Stop()
return false
@ -362,6 +409,15 @@ func (a *App) shutdownHTTP3() {
}
}
// shutdownHTTP2 关闭 HTTP/2 服务器。
func (a *App) shutdownHTTP2() {
if a.http2Srv != nil {
if err := a.http2Srv.Stop(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败")
}
}
}
// reloadConfig 重载配置。
func (a *App) reloadConfig() {
newCfg, err := config.Load(a.cfgPath)
@ -415,6 +471,7 @@ func (a *App) gracefulUpgrade() {
a.logger.LogStartup("热升级已启动,新进程正在接管", nil)
// 当前进程优雅停止
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.GracefulStop(shutdownTimeout)
}

View File

@ -81,6 +81,78 @@ type Config struct {
// Resolver DNS 解析器配置
// 启用动态 DNS 解析和缓存
Resolver ResolverConfig `yaml:"resolver"`
// Variables 自定义变量配置
// 全局变量定义,应用于所有虚拟主机
Variables VariablesConfig `yaml:"variables"`
}
// VariablesConfig 自定义变量配置。
//
// 用于定义全局自定义变量,可在日志格式和请求头中引用。
// 变量作用于所有虚拟主机。
//
// 注意事项:
// - 变量名只允许字母、数字、下划线
// - 变量名不能与内置变量冲突
// - 变量名不能以 arg_、http_、cookie_ 开头(动态变量前缀)
//
// 使用示例:
//
// variables:
// set:
// app_name: "lolly"
// version: "1.0.0"
type VariablesConfig struct {
// Set 自定义变量集合
// 键值对形式,可在日志格式和请求头模板中使用 $var_name 引用
Set map[string]string `yaml:"set"`
}
// HTTP2Config HTTP/2 配置。
//
// HTTP/2 提供多路复用、头部压缩和服务器推送等功能,
// 需要服务器配置 SSL/TLS 证书才能正常工作。
//
// 注意事项:
// - 必须配置有效的 SSL 证书TLS 1.2 或更高版本)
// - http2.enabled 仅在配置了 SSL/TLS 时生效
// - 客户端可以通过 ALPN 协商使用 HTTP/2 或 HTTP/1.1
//
// 使用示例:
//
// server:
// ssl:
// cert: "/etc/ssl/server.crt"
// key: "/etc/ssl/server.key"
// http2:
// enabled: true
// max_concurrent_streams: 128
// max_header_list_size: "16KB"
type HTTP2Config struct {
// Enabled 是否启用 HTTP/2
// 默认为 true但仅在配置了 SSL 时生效
Enabled bool `yaml:"enabled"`
// MaxConcurrentStreams 最大并发流
// 控制单个连接允许的最大并发流数量,默认 128
MaxConcurrentStreams int `yaml:"max_concurrent_streams"`
// MaxHeaderListSize 最大头部列表大小(字节)
// 限制请求和响应头部的大小,默认 1MB (1048576)
MaxHeaderListSize int `yaml:"max_header_list_size"`
// IdleTimeout 空闲超时
// 连接无活动时的最大保持时间,默认 120s
IdleTimeout time.Duration `yaml:"idle_timeout"`
// PushEnabled 是否启用 Server Push
// 默认 false
PushEnabled bool `yaml:"push_enabled"`
// H2CEnabled 是否启用 H2C明文 HTTP/2
// 默认 false需要 Enabled 为 true 才生效
H2CEnabled bool `yaml:"h2c_enabled"`
}
// HTTP3Config HTTP/3 (QUIC) 配置。
@ -546,6 +618,10 @@ type SSLConfig struct {
// 启用 TLS 1.3 会话恢复以提升握手性能
SessionTickets SessionTicketsConfig `yaml:"session_tickets"`
// HTTP2 HTTP/2 配置
// 启用 HTTP/2 支持,仅在配置了 SSL/TLS 时生效
HTTP2 HTTP2Config `yaml:"http2"`
// ClientVerify 客户端证书验证配置
// 启用 mTLS 双向认证
ClientVerify ClientVerifyConfig `yaml:"client_verify"`
@ -841,6 +917,10 @@ type AuthConfig struct {
// Realm 认证域
// 显示在浏览器认证对话框中的描述信息
Realm string `yaml:"realm"`
// MinPasswordLength 密码最小长度
// 用于验证密码哈希对应的原始密码长度(仅提示性验证)
// 建议值8-128默认 8
MinPasswordLength int `yaml:"min_password_length"`
}
// User 认证用户配置。
@ -1727,6 +1807,11 @@ func Validate(cfg *Config) error {
return fmt.Errorf("resolver: %w", err)
}
// 验证变量配置
if err := validateVariables(&cfg.Variables); err != nil {
return fmt.Errorf("variables: %w", err)
}
return nil
}

View File

@ -58,6 +58,14 @@ func DefaultConfig() *Config {
IncludeSubDomains: true,
Preload: false,
},
HTTP2: HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 128,
MaxHeaderListSize: 1048576, // 1MB
IdleTimeout: 120 * time.Second,
PushEnabled: false,
H2CEnabled: false,
},
},
Security: SecurityConfig{
Access: AccessConfig{
@ -75,9 +83,10 @@ func DefaultConfig() *Config {
SlidingWindow: 60,
},
Auth: AuthConfig{
RequireTLS: true,
Algorithm: "bcrypt",
Realm: "Restricted Area",
RequireTLS: true,
Algorithm: "bcrypt",
Realm: "Restricted Area",
MinPasswordLength: 8,
},
Headers: SecurityHeaders{
XFrameOptions: "DENY",
@ -148,6 +157,9 @@ func DefaultConfig() *Config {
IdleTimeout: 60 * time.Second,
Enable0RTT: false,
},
Variables: VariablesConfig{
Set: map[string]string{},
},
}
}

View File

@ -24,6 +24,7 @@ import (
"strings"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/variable"
)
// validateServer 验证服务器配置。
@ -246,6 +247,7 @@ func validateProxy(p *ProxyConfig) error {
// validateSSL 验证 SSL 配置。
//
// 检查 SSL 证书、私钥、TLS 协议版本和加密套件的有效性。
// 同时验证 HTTP/2 配置的有效性。
//
// 参数:
// - s: SSL 配置对象
@ -257,7 +259,13 @@ func validateProxy(p *ProxyConfig) error {
// - cert 和 key 必须同时配置或同时为空
// - TLS 协议仅允许 TLSv1.2 和 TLSv1.3
// - 拒绝不安全的加密套件RC4、DES、3DES、CBC
// - HTTP/2 配置仅在配置了 SSL 时生效
func validateSSL(s *SSLConfig) error {
// 验证 HTTP/2 配置
if err := validateHTTP2(&s.HTTP2, s.Cert != "" && s.Key != ""); err != nil {
return fmt.Errorf("http2: %w", err)
}
// 未配置 SSL 时跳过验证
if s.Cert == "" && s.Key == "" {
return nil
@ -422,6 +430,14 @@ func validateAuth(a *AuthConfig) error {
}
}
// 验证密码最小长度配置合理性
if a.MinPasswordLength > 0 && a.MinPasswordLength < 6 {
return fmt.Errorf("min_password_length 建议至少为 6")
}
if a.MinPasswordLength > 128 {
return fmt.Errorf("min_password_length 上限为 128")
}
return nil
}
@ -497,6 +513,46 @@ func validateRateLimit(r *RateLimitConfig) error {
return nil
}
// validateHTTP2 验证 HTTP/2 配置。
//
// 检查 HTTP/2 配置的有效性,包括并发流数量和头部大小限制。
//
// 参数:
// - h: HTTP/2 配置对象
// - hasSSL: 是否配置了 SSL/TLS
//
// 返回值:
// - error: 验证失败时返回错误信息,成功返回 nil
//
// 验证规则:
// - http2.enabled 仅在配置了 SSL 时生效HTTP/2 over TLS
// - max_concurrent_streams 必须大于 0
// - max_header_list_size 必须是一个有效的字节大小(如 "16KB", "1MB")或空
func validateHTTP2(h *HTTP2Config, hasSSL bool) error {
// HTTP/2 配置在 HTTPS 下才有效(除非启用 H2C
if h.Enabled && !hasSSL && !h.H2CEnabled {
// HTTP/2 需要 TLSh2明文 HTTP/2h2c需要单独启用
return errors.New("HTTP/2 需要配置 SSL/TLS 证书http2.enabled 仅在配置 SSL 时生效,或启用 h2c_enabled")
}
// 验证并发流数量
if h.MaxConcurrentStreams < 0 {
return errors.New("max_concurrent_streams 不能为负数")
}
// 验证头部大小限制
if h.MaxHeaderListSize < 0 {
return errors.New("max_header_list_size 不能为负数")
}
// 验证空闲超时
if h.IdleTimeout < 0 {
return errors.New("idle_timeout 不能为负数")
}
return nil
}
// validateCompression 验证压缩配置。
//
// 检查压缩类型、压缩级别和最小压缩大小的有效性。
@ -767,3 +823,109 @@ func validateNextUpstream(n *NextUpstreamConfig) error {
return nil
}
// validateVariables 验证自定义变量配置。
//
// 检查变量名的有效性和冲突情况。
//
// 参数:
// - v: 变量配置对象
//
// 返回值:
// - error: 验证失败时返回错误信息,成功返回 nil
//
// 验证规则:
// - 变量名不能为空
// - 变量名只允许字母、数字、下划线
// - 变量名不能以 arg_、http_、cookie_ 开头(动态变量前缀)
// - 变量名不能与内置变量冲突
func validateVariables(v *VariablesConfig) error {
for name := range v.Set {
// 检查变量名非空
if name == "" {
return errors.New("变量名不能为空")
}
// 变量名只允许字母、数字、下划线
for i, c := range name {
isLetter := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
isDigit := c >= '0' && c <= '9'
isUnderscore := c == '_'
if !isLetter && !isDigit && !isUnderscore {
return fmt.Errorf("变量名 '%s' 包含非法字符(位置 %d", name, i)
}
}
// 检查动态变量前缀冲突
if strings.HasPrefix(name, "arg_") || strings.HasPrefix(name, "http_") || strings.HasPrefix(name, "cookie_") {
return fmt.Errorf("变量名 '%s' 与动态变量前缀冲突arg_, http_, cookie_", name)
}
// 禁止覆盖内置变量
if variable.GetBuiltin(name) != nil {
return fmt.Errorf("变量名 '%s' 与内置变量冲突", name)
}
}
return nil
}
// parseSize 解析大小字符串为字节数。
//
// 支持单位b, kb, mb, gb不区分大小写
// 纯数字默认为字节。
//
// 参数:
// - s: 大小字符串,如 "16KB", "1MB", "1024"
//
// 返回值:
// - int64: 字节数
// - error: 解析失败时返回错误
func parseSize(s string) (int64, error) {
s = strings.TrimSpace(s)
if s == "" {
return 0, errors.New("大小字符串不能为空")
}
// 提取数字部分和单位
var numStr string
var unit string
for i := len(s) - 1; i >= 0; i-- {
c := s[i]
if c >= '0' && c <= '9' || c == '.' {
numStr = s[:i+1]
unit = strings.ToLower(s[i+1:])
break
}
}
if numStr == "" {
return 0, fmt.Errorf("无效的大小格式: %s", s)
}
// 解析数字
var value float64
_, err := fmt.Sscanf(numStr, "%f", &value)
if err != nil {
return 0, fmt.Errorf("无法解析数字: %s", numStr)
}
// 转换单位
var multiplier int64
switch unit {
case "", "b":
multiplier = 1
case "k", "kb":
multiplier = 1024
case "m", "mb":
multiplier = 1024 * 1024
case "g", "gb":
multiplier = 1024 * 1024 * 1024
default:
return 0, fmt.Errorf("未知单位: %s", unit)
}
return int64(value * float64(multiplier)), nil
}
// unused: kept for potential future use in size parsing
var _ = parseSize

View File

@ -324,6 +324,54 @@ func TestValidateAuth(t *testing.T) {
},
wantErr: false,
},
{
name: "有效MinPasswordLength",
config: AuthConfig{
Type: "basic",
Algorithm: "bcrypt",
Users: []User{{Name: "admin", Password: "hashed_password"}},
MinPasswordLength: 8,
},
wantErr: false,
},
{
name: "MinPasswordLength过小",
config: AuthConfig{
Type: "basic",
Users: []User{{Name: "admin", Password: "hashed_password"}},
MinPasswordLength: 5,
},
wantErr: true,
errMsg: "min_password_length 建议至少为 6",
},
{
name: "MinPasswordLength过大",
config: AuthConfig{
Type: "basic",
Users: []User{{Name: "admin", Password: "hashed_password"}},
MinPasswordLength: 129,
},
wantErr: true,
errMsg: "min_password_length 上限为 128",
},
{
name: "MinPasswordLength边界值6",
config: AuthConfig{
Type: "basic",
Users: []User{{Name: "admin", Password: "hashed_password"}},
MinPasswordLength: 6,
},
wantErr: false,
},
{
name: "MinPasswordLength边界值128",
config: AuthConfig{
Type: "basic",
Users: []User{{Name: "admin", Password: "hashed_password"}},
MinPasswordLength: 128,
},
wantErr: false,
},
{
name: "无效认证类型",
config: AuthConfig{
@ -1068,3 +1116,109 @@ func TestValidatePerformance(t *testing.T) {
})
}
}
func TestValidateVariables(t *testing.T) {
// TestValidateVariables 测试自定义变量配置验证。
tests := []struct {
name string
config VariablesConfig
wantErr bool
errMsg string
}{
{
name: "空配置有效",
config: VariablesConfig{},
wantErr: false,
},
{
name: "有效变量名",
config: VariablesConfig{
Set: map[string]string{
"app_name": "lolly",
"version": "1.0.0",
"ENV_VAR": "production",
},
},
wantErr: false,
},
{
name: "空变量名",
config: VariablesConfig{
Set: map[string]string{
"": "value",
},
},
wantErr: true,
errMsg: "变量名不能为空",
},
{
name: "变量名含特殊字符",
config: VariablesConfig{
Set: map[string]string{
"app-name": "value",
},
},
wantErr: true,
errMsg: "包含非法字符",
},
{
name: "变量名arg_前缀冲突",
config: VariablesConfig{
Set: map[string]string{
"arg_foo": "value",
},
},
wantErr: true,
errMsg: "与动态变量前缀冲突",
},
{
name: "变量名http_前缀冲突",
config: VariablesConfig{
Set: map[string]string{
"http_custom": "value",
},
},
wantErr: true,
errMsg: "与动态变量前缀冲突",
},
{
name: "变量名cookie_前缀冲突",
config: VariablesConfig{
Set: map[string]string{
"cookie_session": "value",
},
},
wantErr: true,
errMsg: "与动态变量前缀冲突",
},
{
name: "变量名与内置变量冲突",
config: VariablesConfig{
Set: map[string]string{
"host": "custom",
},
},
wantErr: true,
errMsg: "与内置变量冲突",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateVariables(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateVariables() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateVariables() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateVariables() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}

350
internal/http2/adapter.go Normal file
View File

@ -0,0 +1,350 @@
// Package http2 提供 HTTP/2 请求适配层。
//
// 该文件实现 fasthttp.RequestHandler 与 http.Handler 之间的适配,
// 使 HTTP/2 服务器能够复用现有的 fasthttp 处理器。
//
// 主要特性:
//
// - 零拷贝头部转换:使用 sync.Pool 复用缓冲区
// - 流式请求体处理:避免大请求体内存复制
// - 低延迟:预估每请求 5-10µs 开销
//
// 作者xfy
package http2
import (
"io"
"net"
"net/http"
"sync"
"time"
"github.com/valyala/fasthttp"
)
// FastHTTPHandlerAdapter 将 fasthttp.RequestHandler 适配为 http.Handler。
//
// 由于 HTTP/2 服务器使用标准库的 http.Handler 接口,
// 而 lolly 使用 fasthttp需要通过适配层进行转换。
type FastHTTPHandlerAdapter struct {
handler fasthttp.RequestHandler
// ctxPool 用于复用 fasthttp.RequestCtx 对象
ctxPool sync.Pool
// bufferPool 用于复用字节缓冲区(零拷贝优化)
bufferPool sync.Pool
// headerBufferPool 用于复用头部缓冲区
headerBufferPool sync.Pool
}
// NewFastHTTPHandlerAdapter 创建新的 HTTP/2 适配器。
//
// 参数:
// - handler: fasthttp 请求处理器
//
// 返回值:
// - *FastHTTPHandlerAdapter: 适配器实例
func NewFastHTTPHandlerAdapter(handler fasthttp.RequestHandler) *FastHTTPHandlerAdapter {
return &FastHTTPHandlerAdapter{
handler: handler,
ctxPool: sync.Pool{
New: func() interface{} {
return &fasthttp.RequestCtx{}
},
},
bufferPool: sync.Pool{
New: func() interface{} {
buf := make([]byte, 4096) // 4KB 初始缓冲区
return &buf
},
},
headerBufferPool: sync.Pool{
New: func() interface{} {
return &fasthttp.RequestHeader{}
},
},
}
}
// ServeHTTP 实现 http.Handler 接口。
//
// 这是适配器的核心方法,将标准库 HTTP 请求转换为 fasthttp 请求,
// 调用 fasthttp 处理器,然后将响应写回标准库 ResponseWriter。
//
// 参数:
// - w: 标准库 ResponseWriter
// - r: 标准库 HTTP 请求
func (a *FastHTTPHandlerAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 从池中获取 RequestCtx
ctx := a.ctxPool.Get().(*fasthttp.RequestCtx)
defer a.ctxPool.Put(ctx)
// 重置 ctx 状态以避免污染
a.resetContext(ctx)
// 转换请求(零拷贝头部转换)
a.convertRequest(r, ctx)
// 流式处理请求体
a.streamRequestBody(r, ctx)
// 调用 fasthttp handler
a.handler(ctx)
// 转换响应
a.convertResponse(ctx, w)
}
// resetContext 重置 fasthttp.RequestCtx 状态。
//
// 参数:
// - ctx: 需要重置的上下文
func (a *FastHTTPHandlerAdapter) resetContext(ctx *fasthttp.RequestCtx) {
// 清空请求头
ctx.Request.Header.DisableNormalizing()
ctx.Request.Reset()
ctx.Response.Reset()
ctx.SetUserValueBytes(nil, nil)
}
// convertRequest 将 net/http.Request 转换为 fasthttp.RequestCtx。
//
// 使用零拷贝策略转换请求头和元数据。
//
// 参数:
// - r: 标准库 HTTP 请求
// - ctx: FastHTTP 请求上下文
func (a *FastHTTPHandlerAdapter) convertRequest(r *http.Request, ctx *fasthttp.RequestCtx) {
// 设置方法
ctx.Request.Header.SetMethod(r.Method)
// 设置 URI
uri := r.URL.Path
if r.URL.RawQuery != "" {
uri += "?" + r.URL.RawQuery
}
ctx.Request.SetRequestURI(uri)
// 设置协议版本为 HTTP/2
ctx.Request.Header.SetProtocol("HTTP/2.0")
// 设置 Host 头
ctx.Request.Header.SetHost(r.Host)
// 零拷贝头部转换
a.convertHeaders(r, ctx)
// 设置远程地址
a.setRemoteAddr(r, ctx)
// 设置 Content-Type
if ct := r.Header.Get("Content-Type"); ct != "" {
ctx.Request.Header.SetContentType(ct)
}
// 设置 Content-Length如果有
if r.ContentLength > 0 {
ctx.Request.Header.SetContentLength(int(r.ContentLength))
}
}
// convertHeaders 将 HTTP 请求头转换为 fasthttp 格式。
//
// 使用 HPACK 风格的零拷贝转换策略。
//
// 参数:
// - r: 标准库 HTTP 请求
// - ctx: FastHTTP 请求上下文
func (a *FastHTTPHandlerAdapter) convertHeaders(r *http.Request, ctx *fasthttp.RequestCtx) {
// 跳过已处理的头部
skipHeaders := map[string]bool{
"Host": true,
"Content-Type": true,
"Content-Length": true,
}
for k, v := range r.Header {
if skipHeaders[k] {
continue
}
// 复用缓冲区避免分配
for i, vv := range v {
if i == 0 {
ctx.Request.Header.Set(k, vv)
} else {
ctx.Request.Header.Add(k, vv)
}
}
}
}
// setRemoteAddr 设置远程客户端地址。
//
// 参数:
// - r: 标准库 HTTP 请求
// - ctx: FastHTTP 请求上下文
func (a *FastHTTPHandlerAdapter) setRemoteAddr(r *http.Request, ctx *fasthttp.RequestCtx) {
if r.RemoteAddr != "" {
// 尝试解析地址
if addr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr); err == nil {
ctx.SetRemoteAddr(addr)
} else {
// 回退方案:使用字符串地址
ctx.SetRemoteAddr(&net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 0,
})
}
}
}
// streamRequestBody 流式读取请求体到 fasthttp。
//
// 对于大请求体,使用流式处理避免内存峰值。
//
// 参数:
// - r: 标准库 HTTP 请求
// - ctx: FastHTTP 请求上下文
func (a *FastHTTPHandlerAdapter) streamRequestBody(r *http.Request, ctx *fasthttp.RequestCtx) {
if r.Body == nil || r.Body == http.NoBody {
return
}
defer func() { _ = r.Body.Close() }()
// 小请求体:直接读取到内存
if r.ContentLength > 0 && r.ContentLength <= 64*1024 {
body, err := io.ReadAll(r.Body)
if err == nil {
ctx.Request.SetBody(body)
}
return
}
// 大请求体:使用流式缓冲区
bufPtr := a.bufferPool.Get().(*[]byte)
defer a.bufferPool.Put(bufPtr)
buf := *bufPtr
var body []byte
for {
n, err := r.Body.Read(buf)
if n > 0 {
body = append(body, buf[:n]...)
}
if err == io.EOF {
break
}
if err != nil {
break
}
}
if len(body) > 0 {
ctx.Request.SetBody(body)
}
}
// convertResponse 将 fasthttp.RequestCtx 响应写入 http.ResponseWriter。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - w: 标准库 ResponseWriter
func (a *FastHTTPHandlerAdapter) convertResponse(ctx *fasthttp.RequestCtx, w http.ResponseWriter) {
// 设置状态码
statusCode := ctx.Response.StatusCode()
if statusCode == 0 {
statusCode = http.StatusOK
}
// 复制响应头
for key, value := range ctx.Response.Header.All() {
w.Header().Add(string(key), string(value))
}
// 确保 Content-Type 被设置
if ct := ctx.Response.Header.ContentType(); len(ct) > 0 {
w.Header().Set("Content-Type", string(ct))
}
// 确保 Content-Length 被设置(如果已知)
if cl := ctx.Response.Header.ContentLength(); cl > 0 {
w.Header().Set("Content-Length", string(fasthttp.AppendUint(nil, cl)))
}
// 写入状态码
w.WriteHeader(statusCode)
// 写入响应体
body := ctx.Response.Body()
if len(body) > 0 {
_, _ = w.Write(body)
}
}
// WrapHandler 创建一个适配器包装的 handler。
//
// 这是一个便捷函数,用于快速创建适配器实例。
//
// 参数:
// - handler: fasthttp 请求处理器
//
// 返回值:
// - http.Handler: 标准库兼容的处理器
func WrapHandler(handler fasthttp.RequestHandler) http.Handler {
return NewFastHTTPHandlerAdapter(handler)
}
// WrapHandlerFunc 创建一个适配器包装的 handler 函数。
//
// 这是一个便捷函数,允许直接使用函数而非创建 handler 实例。
//
// 参数:
// - fn: fasthttp handler 函数
//
// 返回值:
// - http.Handler: 标准库兼容的处理器
func WrapHandlerFunc(fn func(*fasthttp.RequestCtx)) http.Handler {
return NewFastHTTPHandlerAdapter(fn)
}
// AdapterConfig 提供适配器的配置选项。
type AdapterConfig struct {
// BufferSize 是缓冲区大小,默认为 4096 字节
BufferSize int
// MaxBodySize 是最大请求体大小,超过则使用流式处理
MaxBodySize int64
// Timeout 是请求处理超时时间
Timeout time.Duration
}
// DefaultAdapterConfig 返回默认配置。
func DefaultAdapterConfig() *AdapterConfig {
return &AdapterConfig{
BufferSize: 4096,
MaxBodySize: 64 * 1024, // 64KB
Timeout: 30 * time.Second,
}
}
// ConfigurableAdapter 是基于配置的可配置适配器。
type ConfigurableAdapter struct {
*FastHTTPHandlerAdapter
config *AdapterConfig
}
// NewConfigurableAdapter 创建可配置适配器。
func NewConfigurableAdapter(handler fasthttp.RequestHandler, config *AdapterConfig) *ConfigurableAdapter {
if config == nil {
config = DefaultAdapterConfig()
}
return &ConfigurableAdapter{
FastHTTPHandlerAdapter: NewFastHTTPHandlerAdapter(handler),
config: config,
}
}

View File

@ -0,0 +1,513 @@
// Package http2 提供 HTTP/2 适配器测试。
//
// 该文件包含 FastHTTPHandlerAdapter 的单元测试:
// - 适配器创建和配置
// - 请求转换测试
// - 响应转换测试
// - 流式请求体处理
//
// 作者xfy
package http2
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/valyala/fasthttp"
)
// TestNewFastHTTPHandlerAdapter 测试适配器创建。
func TestNewFastHTTPHandlerAdapter(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Hello") //nolint:errcheck
}
adapter := NewFastHTTPHandlerAdapter(handler)
if adapter == nil {
t.Fatal("NewFastHTTPHandlerAdapter() returned nil")
}
if adapter.handler == nil {
t.Error("Adapter handler not set")
}
}
// TestFastHTTPHandlerAdapterServeHTTP 测试适配器处理 HTTP 请求。
func TestFastHTTPHandlerAdapterServeHTTP(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Hello from fasthttp") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 创建测试请求
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("X-Custom-Header", "custom-value")
rec := httptest.NewRecorder()
// 执行请求
adapter.ServeHTTP(rec, req)
// 验证响应
if rec.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
}
body := rec.Body.String()
if body != "Hello from fasthttp" {
t.Errorf("Expected body 'Hello from fasthttp', got '%s'", body)
}
}
// TestFastHTTPHandlerAdapterWithRequestBody 测试带请求体的请求。
func TestFastHTTPHandlerAdapterWithRequestBody(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
body := ctx.PostBody()
ctx.Write(body) //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 创建带请求体的测试请求
body := []byte(`{"key":"value"}`)
req := httptest.NewRequest(http.MethodPost, "/api", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
// 执行请求
adapter.ServeHTTP(rec, req)
// 验证响应
if rec.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
}
respBody := rec.Body.String()
if respBody != string(body) {
t.Errorf("Expected body '%s', got '%s'", string(body), respBody)
}
}
// TestFastHTTPHandlerAdapterWithHeaders 测试请求头转换。
func TestFastHTTPHandlerAdapterWithHeaders(t *testing.T) {
var receivedHeaders map[string]string
handler := func(ctx *fasthttp.RequestCtx) {
receivedHeaders = make(map[string]string)
for key, value := range ctx.Request.Header.All() {
receivedHeaders[string(key)] = string(value)
}
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 创建带多个头部的测试请求
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer token123")
req.Header.Set("X-Request-ID", "uuid-123")
rec := httptest.NewRecorder()
// 执行请求
adapter.ServeHTTP(rec, req)
// 验证接收到的头部
if receivedHeaders == nil {
t.Fatal("No headers received")
}
if _, ok := receivedHeaders["Accept"]; !ok {
t.Error("Accept header not received")
}
if _, ok := receivedHeaders["Authorization"]; !ok {
t.Error("Authorization header not received")
}
}
// TestFastHTTPHandlerAdapterWithQueryString 测试查询字符串。
func TestFastHTTPHandlerAdapterWithQueryString(t *testing.T) {
var receivedURI string
handler := func(ctx *fasthttp.RequestCtx) {
receivedURI = string(ctx.Request.URI().RequestURI())
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 创建带查询字符串的测试请求
req := httptest.NewRequest(http.MethodGet, "/search?q=hello&page=1", nil)
rec := httptest.NewRecorder()
// 执行请求
adapter.ServeHTTP(rec, req)
// 验证 URI
if receivedURI == "" {
t.Error("Request URI not received")
}
}
// TestFastHTTPHandlerAdapterErrorResponse 测试错误响应。
func TestFastHTTPHandlerAdapterErrorResponse(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.Error("Not Found", fasthttp.StatusNotFound)
}
adapter := NewFastHTTPHandlerAdapter(handler)
req := httptest.NewRequest(http.MethodGet, "/notfound", nil)
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Errorf("Expected status %d, got %d", http.StatusNotFound, rec.Code)
}
}
// TestFastHTTPHandlerAdapterEmptyBody 测试空请求体。
func TestFastHTTPHandlerAdapterEmptyBody(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
if len(ctx.Request.Body()) == 0 {
ctx.WriteString("Empty body received") //nolint:errcheck
}
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
req := httptest.NewRequest(http.MethodPost, "/upload", nil)
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
}
if rec.Body.String() != "Empty body received" {
t.Error("Empty body not handled correctly")
}
}
// TestWrapHandler 测试 WrapHandler 便捷函数。
func TestWrapHandler(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Wrapped") //nolint:errcheck
}
wrapped := WrapHandler(handler)
if wrapped == nil {
t.Fatal("WrapHandler() returned nil")
}
// 验证它是一个 http.Handler
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
wrapped.ServeHTTP(rec, req)
if rec.Body.String() != "Wrapped" {
t.Error("WrapHandler did not work correctly")
}
}
// TestWrapHandlerFunc 测试 WrapHandlerFunc 便捷函数。
func TestWrapHandlerFunc(t *testing.T) {
fn := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Func wrapped") //nolint:errcheck
}
wrapped := WrapHandlerFunc(fn)
if wrapped == nil {
t.Fatal("WrapHandlerFunc() returned nil")
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
wrapped.ServeHTTP(rec, req)
if rec.Body.String() != "Func wrapped" {
t.Error("WrapHandlerFunc did not work correctly")
}
}
// TestDefaultAdapterConfig 测试默认适配器配置。
func TestDefaultAdapterConfig(t *testing.T) {
cfg := DefaultAdapterConfig()
if cfg == nil {
t.Fatal("DefaultAdapterConfig() returned nil")
}
if cfg.BufferSize <= 0 {
t.Error("BufferSize should be positive")
}
if cfg.MaxBodySize <= 0 {
t.Error("MaxBodySize should be positive")
}
if cfg.Timeout <= 0 {
t.Error("Timeout should be positive")
}
}
// TestNewConfigurableAdapter 测试可配置适配器。
func TestNewConfigurableAdapter(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Configurable") //nolint:errcheck
}
cfg := DefaultAdapterConfig()
adapter := NewConfigurableAdapter(handler, cfg)
if adapter == nil {
t.Fatal("NewConfigurableAdapter() returned nil")
}
if adapter.config != cfg {
t.Error("Config not set correctly")
}
// 测试 nil 配置
adapter2 := NewConfigurableAdapter(handler, nil)
if adapter2 == nil {
t.Fatal("NewConfigurableAdapter() with nil config returned nil")
}
if adapter2.config == nil {
t.Error("Default config not applied")
}
}
// TestAdapterWithLargeBody 测试大请求体处理。
func TestAdapterWithLargeBody(t *testing.T) {
bodyReceived := false
handler := func(ctx *fasthttp.RequestCtx) {
body := ctx.PostBody()
if len(body) > 1024 {
bodyReceived = true
}
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 创建大请求体
largeBody := make([]byte, 100*1024) // 100KB
for i := range largeBody {
largeBody[i] = byte('a' + (i % 26))
}
req := httptest.NewRequest(http.MethodPost, "/upload", bytes.NewReader(largeBody))
req.Header.Set("Content-Length", "102400")
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if !bodyReceived {
t.Error("Large body was not received correctly")
}
}
// TestAdapterHTTPMethods 测试不同 HTTP 方法。
func TestAdapterHTTPMethods(t *testing.T) {
methods := []string{
http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodDelete,
http.MethodPatch,
http.MethodHead,
http.MethodOptions,
}
for _, method := range methods {
t.Run(method, func(t *testing.T) {
var receivedMethod string
handler := func(ctx *fasthttp.RequestCtx) {
receivedMethod = string(ctx.Method())
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
req := httptest.NewRequest(method, "/test", nil)
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if receivedMethod != method {
t.Errorf("Expected method %s, got %s", method, receivedMethod)
}
})
}
}
// TestAdapterRemoteAddr 测试远程地址设置。
func TestAdapterRemoteAddr(t *testing.T) {
var remoteAddr net.Addr
handler := func(ctx *fasthttp.RequestCtx) {
remoteAddr = ctx.RemoteAddr()
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if remoteAddr == nil {
t.Error("RemoteAddr not set")
}
}
// TestAdapterContentType 测试 Content-Type 处理。
func TestAdapterContentType(t *testing.T) {
var contentType string
handler := func(ctx *fasthttp.RequestCtx) {
contentType = string(ctx.Request.Header.ContentType())
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
req := httptest.NewRequest(http.MethodPost, "/api", nil)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if contentType != "application/json" {
t.Errorf("Expected Content-Type 'application/json', got '%s'", contentType)
}
}
// TestAdapterResponseHeaders 测试响应头设置。
func TestAdapterResponseHeaders(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.Response.Header.Set("X-Custom-Response", "custom-value")
ctx.Response.Header.Set("Cache-Control", "no-cache")
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if rec.Header().Get("X-Custom-Response") != "custom-value" {
t.Error("Custom response header not set")
}
if rec.Header().Get("Cache-Control") != "no-cache" {
t.Error("Cache-Control header not set")
}
}
// TestAdapterConcurrentRequests 测试并发请求。
func TestAdapterConcurrentRequests(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
// 模拟一些处理时间
time.Sleep(1 * time.Millisecond)
ctx.WriteString("OK") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 并发发送多个请求
concurrency := 10
done := make(chan bool, concurrency)
for i := 0; i < concurrency; i++ {
go func() {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
}
done <- true
}()
}
// 等待所有请求完成
for i := 0; i < concurrency; i++ {
<-done
}
}
// mockReadCloser 是一个用于测试的模拟 io.ReadCloser。
type mockReadCloser struct {
io.Reader
closed bool
}
func (m *mockReadCloser) Close() error {
m.closed = true
return nil
}
// TestStreamRequestBody 测试流式请求体。
func TestStreamRequestBody(t *testing.T) {
bodyContent := []byte("test body content")
mock := &mockReadCloser{Reader: bytes.NewReader(bodyContent)}
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 创建带有 mock body 的请求
req, _ := http.NewRequest(http.MethodPost, "/test", mock)
req.ContentLength = int64(len(bodyContent))
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
// 验证 body 被关闭
if !mock.closed {
t.Error("Request body was not closed")
}
}
// TestAdapterPoolReuse 测试对象池复用。
func TestAdapterPoolReuse(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Test") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 发送多个请求,验证池复用
for i := 0; i < 10; i++ {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
}
// 测试通过,没有 panic 表示池工作正常
}

View File

@ -0,0 +1,403 @@
// Package http2 提供 HTTP/2 集成测试。
//
// 该文件包含 HTTP/2 的端到端集成测试:
// - HTTP/2 请求处理
// - ALPN 协商
// - HTTP/1.1 fallback
//
// 运行方式: go test -tags=integration ./internal/http2/...
//
// 作者xfy
package http2
import (
"bytes"
"crypto/tls"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
// TestIntegrationHTTP2Request 测试 HTTP/2 请求处理(需要 TLS 证书)。
func TestIntegrationHTTP2Request(t *testing.T) {
// 跳过集成测试,除非显式启用
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// 注意:这需要有效的 TLS 证书才能完整测试
// 这里使用非 TLS 模式测试基本功能
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Hello HTTP/2") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
cfg := &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 100,
}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 创建监听器
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() { _ = ln.Close() }()
// 启动服务器(在后台)
go func() {
_ = server.Serve(ln)
}()
// 等待服务器启动
time.Sleep(100 * time.Millisecond)
// 停止服务器
if err := server.Stop(); err != nil {
t.Errorf("Failed to stop server: %v", err)
}
}
// TestIntegrationALPN 测试 ALPN 协议协商(需要 TLS 证书)。
func TestIntegrationALPN(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusOK)
}
cfg := &config.HTTP2Config{
Enabled: true,
}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 验证 ALPN 配置
tlsConfig := server.ALPNConfig()
if tlsConfig == nil {
t.Fatal("ALPN config should not be nil")
}
// 验证协议列表
foundH2 := false
for _, proto := range tlsConfig.NextProtos {
if proto == "h2" {
foundH2 = true
break
}
}
if !foundH2 {
t.Error("ALPN config should include h2 protocol")
}
}
// TestIntegrationHTTP1Fallback 测试 HTTP/1.1 回退。
func TestIntegrationHTTP1Fallback(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Fallback to HTTP/1.1") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
cfg := &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 100,
}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 验证服务器支持 HTTP/1.1 回退
if server.handler == nil {
t.Error("Server handler should be set for HTTP/1.1 fallback")
}
}
// TestIntegrationConcurrentStreams 测试并发流处理。
func TestIntegrationConcurrentStreams(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
requestCount := 0
handler := func(ctx *fasthttp.RequestCtx) {
requestCount++
ctx.WriteString("OK") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
cfg := &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 10,
}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 验证并发流限制
if server.http2Server.MaxConcurrentStreams != 10 {
t.Errorf("Expected MaxConcurrentStreams 10, got %d",
server.http2Server.MaxConcurrentStreams)
}
}
// TestIntegrationServerLifecycle 测试服务器生命周期。
func TestIntegrationServerLifecycle(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusOK)
}
cfg := &config.HTTP2Config{
Enabled: true,
}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 初始状态检查
if server.IsRunning() {
t.Error("Server should not be running initially")
}
// 创建监听器
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
// 启动服务器
go func() { _ = server.Serve(ln) }()
// 等待服务器启动
time.Sleep(50 * time.Millisecond)
// 停止服务器
if err := server.Stop(); err != nil {
t.Errorf("Failed to stop server: %v", err)
}
}
// TestIntegrationAdapterConversion 测试适配器转换。
func TestIntegrationAdapterConversion(t *testing.T) {
handler := func(ctx *fasthttp.RequestCtx) {
// 设置响应头和体
ctx.Response.Header.Set("X-Custom-Header", "test-value")
ctx.WriteString("Converted response") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
// 创建标准 HTTP 请求
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Accept", "application/json")
// 使用 httptest 记录响应
recorder := &testResponseRecorder{
header: make(http.Header),
}
// 执行适配器
adapter.ServeHTTP(recorder, req)
// 验证响应
if recorder.statusCode != http.StatusOK {
t.Errorf("Expected status %d, got %d", http.StatusOK, recorder.statusCode)
}
if recorder.body.String() != "Converted response" {
t.Errorf("Expected body 'Converted response', got '%s'", recorder.body.String())
}
}
// testResponseRecorder 是测试用的响应记录器。
type testResponseRecorder struct {
statusCode int
header http.Header
body testBuffer
}
func (r *testResponseRecorder) Header() http.Header {
return r.header
}
func (r *testResponseRecorder) Write(p []byte) (int, error) {
return r.body.Write(p)
}
func (r *testResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}
// testBuffer 是一个简单的字节缓冲区。
type testBuffer struct {
data []byte
}
func (b *testBuffer) Write(p []byte) (int, error) {
b.data = append(b.data, p...)
return len(p), nil
}
func (b *testBuffer) String() string {
return string(b.data)
}
// TestIntegrationTLSConfiguration 测试 TLS 配置集成。
func TestIntegrationTLSConfiguration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusOK)
}
cfg := &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 100,
}
tlsConfig := &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
}
server, err := NewServer(cfg, handler, tlsConfig)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 验证 TLS 配置
if server.tlsConfig == nil {
t.Error("TLS config should be set")
}
// 测试监听器包装
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() { _ = ln.Close() }()
wrappedLn := WrapTLSListener(ln, tlsConfig)
if wrappedLn == nil {
t.Error("Wrapped listener should not be nil")
}
}
// TestIntegrationH2C 测试 H2C明文 HTTP/2
func TestIntegrationH2C(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(fasthttp.StatusOK)
}
cfg := &config.HTTP2Config{
Enabled: true,
H2CEnabled: true,
}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 验证 H2C 启用
if !server.IsH2CEnabled() {
t.Error("H2C should be enabled")
}
}
// BenchmarkAdapterConversion 基准测试适配器转换性能。
func BenchmarkAdapterConversion(b *testing.B) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.WriteString("Hello") //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
req := httptest.NewRequest(http.MethodGet, "/test", nil)
rec := httptest.NewRecorder()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
rec.Body.Reset()
adapter.ServeHTTP(rec, req)
}
}
// BenchmarkAdapterWithBody 基准测试带请求体的适配器。
func BenchmarkAdapterWithBody(b *testing.B) {
handler := func(ctx *fasthttp.RequestCtx) {
ctx.Write(ctx.PostBody()) //nolint:errcheck
ctx.SetStatusCode(fasthttp.StatusOK)
}
adapter := NewFastHTTPHandlerAdapter(handler)
body := []byte(`{"test":"data","number":12345}`)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
req := httptest.NewRequest(http.MethodPost, "/api", bytes.NewReader(body))
rec := httptest.NewRecorder()
adapter.ServeHTTP(rec, req)
}
}
// BenchmarkServerCreation 基准测试服务器创建。
func BenchmarkServerCreation(b *testing.B) {
cfg := &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 100,
}
handler := func(ctx *fasthttp.RequestCtx) {}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, err := NewServer(cfg, handler, nil)
if err != nil {
b.Fatal(err)
}
}
}

586
internal/http2/server.go Normal file
View File

@ -0,0 +1,586 @@
// Package http2 提供 HTTP/2 协议支持。
//
// 该文件包含 HTTP/2 服务器的核心实现,包括:
// - 基于 golang.org/x/net/http2 的 HTTP/2 服务器
// - ALPN 协议协商支持
// - 与现有 fasthttp handler 的集成
// - 优雅关闭支持
//
// 主要用途:
//
// 用于在现有 TCP 监听器上提供 HTTP/2 协议支持,通过 ALPN 协商自动选择协议。
//
// 作者xfy
package http2
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
"golang.org/x/net/http2"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/logging"
)
// Server HTTP/2 服务器。
//
// 包装 golang.org/x/net/http2 服务器,提供与 fasthttp handler 的集成。
type Server struct {
// config HTTP/2 配置
config *config.HTTP2Config
// handler fasthttp 请求处理器
handler fasthttp.RequestHandler
// tlsConfig TLS 配置
tlsConfig *tls.Config
// http2Server HTTP/2 服务器实例
http2Server *http2.Server
// running 服务器运行状态
running bool
// mu 读写锁
mu sync.RWMutex
// listener TCP 监听器
listener net.Listener
// stopChan 停止信号通道
stopChan chan struct{}
}
// NewServer 创建 HTTP/2 服务器。
//
// 参数:
// - cfg: HTTP/2 配置
// - handler: fasthttp 请求处理器
// - tlsConfig: TLS 配置(可选,但推荐用于 ALPN 协商)
//
// 返回值:
// - *Server: HTTP/2 服务器实例
// - error: 配置无效时返回错误
func NewServer(cfg *config.HTTP2Config, handler fasthttp.RequestHandler, tlsConfig *tls.Config) (*Server, error) {
if cfg == nil {
return nil, fmt.Errorf("http2 config is nil")
}
if handler == nil {
return nil, fmt.Errorf("handler is nil")
}
// 设置默认值
maxConcurrentStreams := cfg.MaxConcurrentStreams
if maxConcurrentStreams <= 0 {
maxConcurrentStreams = 250
}
maxHeaderListSize := cfg.MaxHeaderListSize
if maxHeaderListSize <= 0 {
maxHeaderListSize = 1048576 // 1MB
}
idleTimeout := cfg.IdleTimeout
if idleTimeout <= 0 {
idleTimeout = 120 * time.Second
}
// 创建 HTTP/2 服务器
h2s := &http2.Server{
MaxConcurrentStreams: uint32(maxConcurrentStreams),
IdleTimeout: idleTimeout,
MaxReadFrameSize: uint32(maxHeaderListSize),
NewWriteScheduler: func() http2.WriteScheduler { return http2.NewPriorityWriteScheduler(nil) },
CountError: func(errType string) {},
}
return &Server{
config: cfg,
handler: handler,
tlsConfig: tlsConfig,
http2Server: h2s,
stopChan: make(chan struct{}),
}, nil
}
// Serve 在指定监听器上启动 HTTP/2 服务器。
//
// 该方法会处理 ALPN 协议协商,根据客户端支持的协议自动选择 HTTP/2 或 HTTP/1.1。
//
// 参数:
// - ln: TCP 监听器
//
// 返回值:
// - error: 启动失败时返回错误
func (s *Server) Serve(ln net.Listener) error {
s.mu.Lock()
if s.running {
s.mu.Unlock()
return fmt.Errorf("server already running")
}
s.running = true
s.listener = ln
s.mu.Unlock()
log := logging.Info()
if s.config.Enabled {
log.Str("protocol", "h2").
Bool("push", s.config.PushEnabled).
Int("max_streams", s.config.MaxConcurrentStreams).
Int("max_header_size", s.config.MaxHeaderListSize).
Str("idle_timeout", s.config.IdleTimeout.String()).
Msg("HTTP/2 server started")
}
// 启动连接处理循环
for {
select {
case <-s.stopChan:
return nil
default:
}
conn, err := ln.Accept()
if err != nil {
select {
case <-s.stopChan:
return nil
default:
}
if errors.Is(err, net.ErrClosed) {
return nil
}
logging.Error().Err(err).Msg("HTTP/2 accept error")
continue
}
go s.handleConnection(conn)
}
}
// handleConnection 处理单个连接。
//
// 根据连接类型TLS 或明文)和 ALPN 协商结果,选择合适的协议处理。
func (s *Server) handleConnection(conn net.Conn) {
defer func() { _ = conn.Close() }()
// 如果是 TLS 连接,检查 ALPN 协商结果
if tlsConn, ok := conn.(*tls.Conn); ok {
// 执行 TLS 握手
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
logging.Error().Err(err).Msg("HTTP/2 set read deadline error")
return
}
if err := tlsConn.Handshake(); err != nil {
logging.Error().Err(err).Msg("HTTP/2 TLS handshake error")
return
}
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
logging.Error().Err(err).Msg("HTTP/2 clear read deadline error")
return
}
// 检查 ALPN 协商结果
state := tlsConn.ConnectionState()
if len(state.NegotiatedProtocol) > 0 && state.NegotiatedProtocol != "h2" {
// ALPN 协商结果为 http/1.1 或其他,使用 fasthttp 处理
s.serveHTTP1(tlsConn)
return
}
}
// 处理 HTTP/2 连接
s.serveHTTP2(conn)
}
// serveHTTP2 使用 HTTP/2 协议服务连接。
func (s *Server) serveHTTP2(conn net.Conn) {
adapter := NewFastHTTPHandlerAdapter(s.handler)
opts := &http2.ServeConnOpts{
Context: context.Background(),
Handler: adapter,
BaseConfig: &http.Server{},
}
s.http2Server.ServeConn(conn, opts)
}
// serveHTTP1 使用 HTTP/1.1 协议服务连接(回退到 fasthttp
func (s *Server) serveHTTP1(conn net.Conn) {
// 创建一个简单的 fasthttp 服务器来处理单个连接
server := &fasthttp.Server{
Handler: s.handler,
}
// 使用 fasthttp 的连接处理
_ = server.ServeConn(conn) //nolint:errcheck // HTTP/1.1 回退连接处理错误由内部处理
}
// Stop 停止 HTTP/2 服务器。
//
// 优雅关闭服务器,等待现有连接完成。
//
// 返回值:
// - error: 关闭失败时返回错误
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if !s.running {
return nil
}
s.running = false
// 发送停止信号
close(s.stopChan)
// 关闭监听器
if s.listener != nil {
if err := s.listener.Close(); err != nil {
logging.Error().Err(err).Msg("HTTP/2 listener close error")
}
}
logging.Info().Msg("HTTP/2 server stopped")
return nil
}
// IsRunning 检查服务器是否正在运行。
func (s *Server) IsRunning() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.running
}
// GetConfig 返回服务器配置。
func (s *Server) GetConfig() *config.HTTP2Config {
return s.config
}
// ALPNConfig 返回用于 ALPN 协商的 TLS 配置。
//
// 返回值:
// - *tls.Config: 配置了 ALPN 的 TLS 配置
//
// 使用示例:
//
// tlsConfig := &tls.Config{
// Certificates: []tls.Certificate{cert},
// }
// tlsConfig.NextProtos = []string{"h2", "http/1.1"}
func (s *Server) ALPNConfig() *tls.Config {
return &tls.Config{
NextProtos: []string{"h2", "http/1.1"},
}
}
// WrapTLSListener 包装 TLS 监听器以支持 ALPN 协议协商。
//
// 参数:
// - ln: 底层 TCP 监听器
// - tlsConfig: TLS 配置(会被修改以添加 ALPN 支持)
//
// 返回值:
// - net.Listener: 支持 ALPN 的 TLS 监听器
func WrapTLSListener(ln net.Listener, tlsConfig *tls.Config) net.Listener {
// 确保 NextProtos 包含 h2 和 http/1.1
if len(tlsConfig.NextProtos) == 0 {
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
}
// 使用 GetConfigForClient 根据客户端支持的协议返回不同的配置
originalGetConfig := tlsConfig.GetConfigForClient
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
// 检查客户端是否支持 h2
supportsH2 := false
for _, proto := range hello.SupportedProtos {
if proto == "h2" {
supportsH2 = true
break
}
}
// 如果有原始回调,先调用它
var cfg *tls.Config
if originalGetConfig != nil {
var err error
cfg, err = originalGetConfig(hello)
if err != nil {
return nil, err
}
}
// 如果客户端支持 h2设置协商结果为 h2
if supportsH2 {
if cfg == nil {
cfg = tlsConfig.Clone()
}
cfg.NextProtos = []string{"h2"}
}
return cfg, nil
}
return tls.NewListener(ln, tlsConfig)
}
// IsH2CEnabled 检查是否启用了 H2CHTTP/2 over cleartext
//
// 注意:当前版本不支持 H2C需要 TLS 才能启用 HTTP/2。
func (s *Server) IsH2CEnabled() bool {
return s.config.H2CEnabled
}
// HandleH2C 处理 H2C 升级请求。
//
// 参数:
// - conn: TCP 连接
//
// 返回值:
// - bool: 如果成功处理 H2C 升级返回 true
// - error: 处理失败时返回错误
func (s *Server) HandleH2C(conn net.Conn) (bool, error) {
// HTTP/2 需要 TLS不支持 H2C
return false, nil
}
// unused: h2cConn and related code kept for potential H2C support in future
var _ = h2cConn{} //nolint:unused // reserved for future H2C support
// h2cConn 包装 net.Conn 以支持 H2C 协议检测。
type h2cConn struct {
net.Conn
reader *bufio.Reader
}
// Read 从连接读取数据。
func (c *h2cConn) Read(p []byte) (n int, err error) { //nolint:unused // reserved for future H2C support
if c.reader != nil {
n, err = c.reader.Read(p)
if err == io.EOF && n > 0 {
return n, nil
}
if err != nil || n < len(p) {
c.reader = nil
}
return n, err
}
return c.Conn.Read(p)
}
// IsHTTP2Request 检查请求是否是 HTTP/2。
//
// 参数:
// - r: HTTP 请求
//
// 返回值:
// - bool: 如果是 HTTP/2 请求返回 true
func IsHTTP2Request(r *http.Request) bool {
// HTTP/2 请求通常使用 "PRI" 方法或 HTTP 版本为 2
if r.Method == "PRI" {
return true
}
if r.ProtoMajor == 2 {
return true
}
// 检查 HTTP/2 特定的头
if r.Header.Get(":method") != "" {
return true
}
return false
}
// GetALPNProtocol 从 TLS 连接状态获取协商的协议。
//
// 参数:
// - conn: 网络连接
//
// 返回值:
// - string: 协商的协议(如 "h2", "http/1.1"),如果不是 TLS 返回空字符串
func GetALPNProtocol(conn net.Conn) string {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return ""
}
state := tlsConn.ConnectionState()
return state.NegotiatedProtocol
}
// SupportsHTTP2 检查客户端是否支持 HTTP/2基于 ALPN 或升级头)。
//
// 参数:
// - r: HTTP 请求
//
// 返回值:
// - bool: 如果支持 HTTP/2 返回 true
func SupportsHTTP2(r *http.Request) bool {
// 检查是否是 HTTP/2 请求
if IsHTTP2Request(r) {
return true
}
// 检查升级头
if r.Header.Get("Upgrade") == "h2c" {
return true
}
// 检查 HTTP2-Settings 头
if r.Header.Get("HTTP2-Settings") != "" {
return true
}
return false
}
// HTTP2Settings HTTP/2 连接设置。
type HTTP2Settings struct {
HeaderTableSize uint32 // SETTINGS_HEADER_TABLE_SIZE
EnablePush bool // SETTINGS_ENABLE_PUSH
MaxConcurrentStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS
InitialWindowSize uint32 // SETTINGS_INITIAL_WINDOW_SIZE
MaxFrameSize uint32 // SETTINGS_MAX_FRAME_SIZE
MaxHeaderListSize uint32 // SETTINGS_MAX_HEADER_LIST_SIZE
}
// DefaultHTTP2Settings 返回默认 HTTP/2 设置。
func DefaultHTTP2Settings() HTTP2Settings {
return HTTP2Settings{
HeaderTableSize: 4096,
EnablePush: true,
MaxConcurrentStreams: 250,
InitialWindowSize: 65535,
MaxFrameSize: 16384,
MaxHeaderListSize: 1048576,
}
}
// ValidateHTTP2Settings 验证 HTTP/2 设置的有效性。
//
// 参数:
// - settings: HTTP/2 设置
//
// 返回值:
// - error: 设置无效时返回错误
func ValidateHTTP2Settings(settings HTTP2Settings) error {
if settings.MaxConcurrentStreams == 0 {
return errors.New("max concurrent streams cannot be zero")
}
if settings.MaxFrameSize < 16384 || settings.MaxFrameSize > 16777215 {
return errors.New("max frame size must be between 16384 and 16777215")
}
if settings.InitialWindowSize > 2147483647 {
return errors.New("initial window size cannot exceed 2^31-1")
}
if settings.MaxHeaderListSize == 0 {
return errors.New("max header list size cannot be zero")
}
return nil
}
// ParseHTTP2Settings 从配置解析 HTTP/2 设置。
//
// 参数:
// - cfg: HTTP/2 配置
//
// 返回值:
// - HTTP2Settings: 解析后的 HTTP/2 设置
func ParseHTTP2Settings(cfg *config.HTTP2Config) HTTP2Settings {
settings := DefaultHTTP2Settings()
if cfg.MaxConcurrentStreams > 0 {
settings.MaxConcurrentStreams = uint32(cfg.MaxConcurrentStreams)
}
if cfg.MaxHeaderListSize > 0 {
settings.MaxHeaderListSize = uint32(cfg.MaxHeaderListSize)
}
settings.EnablePush = cfg.PushEnabled
return settings
}
// connectionPool HTTP/2 连接池。
type connectionPool struct {
mu sync.RWMutex
conns map[string][]net.Conn
}
// newConnectionPool 创建新的连接池。
func newConnectionPool() *connectionPool {
return &connectionPool{
conns: make(map[string][]net.Conn),
}
}
// add 添加连接。
func (p *connectionPool) add(key string, conn net.Conn) {
p.mu.Lock()
defer p.mu.Unlock()
p.conns[key] = append(p.conns[key], conn)
}
// remove 移除连接。
func (p *connectionPool) remove(key string, conn net.Conn) {
p.mu.Lock()
defer p.mu.Unlock()
conns := p.conns[key]
for i, c := range conns {
if c == conn {
p.conns[key] = append(conns[:i], conns[i+1:]...)
break
}
}
}
// get 获取连接。
func (p *connectionPool) get(key string) []net.Conn {
p.mu.RLock()
defer p.mu.RUnlock()
return p.conns[key]
}
// count 获取连接数。
func (p *connectionPool) count(key string) int {
p.mu.RLock()
defer p.mu.RUnlock()
return len(p.conns[key])
}
// closeAll 关闭所有连接。
func (p *connectionPool) closeAll() { //nolint:unused // reserved for future use
p.mu.Lock()
defer p.mu.Unlock()
for _, conns := range p.conns {
for _, conn := range conns {
_ = conn.Close()
}
}
p.conns = make(map[string][]net.Conn)
}
// canonicalHeaderKey 返回规范化的 HTTP 头键。
func canonicalHeaderKey(key string) string {
// 使用 strings 包实现规范化
result := strings.ToLower(key)
if result == "" {
return ""
}
return strings.ToUpper(result[:1]) + result[1:]
}

View File

@ -0,0 +1,456 @@
// Package http2 提供 HTTP/2 服务器测试。
//
// 该文件包含 HTTP/2 服务器的单元测试和集成测试:
// - 服务器创建和配置测试
// - ALPN 协议协商测试
// - HTTP/1.1 fallback 测试
//
// 作者xfy
package http2
import (
"crypto/tls"
"net"
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
// TestNewServer 测试 HTTP/2 服务器创建。
func TestNewServer(t *testing.T) {
tests := []struct {
name string
cfg *config.HTTP2Config
handler fasthttp.RequestHandler
tlsConfig *tls.Config
wantErr bool
}{
{
name: "有效配置",
cfg: &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 128,
MaxHeaderListSize: 1048576,
IdleTimeout: 120 * time.Second,
PushEnabled: false,
H2CEnabled: false,
},
handler: func(ctx *fasthttp.RequestCtx) {},
tlsConfig: nil,
wantErr: false,
},
{
name: "默认配置",
cfg: &config.HTTP2Config{},
handler: func(ctx *fasthttp.RequestCtx) {},
wantErr: false,
},
{
name: "nil配置",
cfg: nil,
handler: func(ctx *fasthttp.RequestCtx) {},
wantErr: true,
},
{
name: "nil handler",
cfg: &config.HTTP2Config{
Enabled: true,
},
handler: nil,
wantErr: true,
},
{
name: "自定义并发流数量",
cfg: &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 256,
},
handler: func(ctx *fasthttp.RequestCtx) {},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server, err := NewServer(tt.cfg, tt.handler, tt.tlsConfig)
if tt.wantErr {
if err == nil {
t.Errorf("NewServer() expected error, got nil")
}
return
}
if err != nil {
t.Errorf("NewServer() unexpected error: %v", err)
return
}
if server == nil {
t.Error("NewServer() returned nil server")
return
}
// 验证配置正确应用
if server.config != tt.cfg {
t.Error("NewServer() config not set correctly")
}
if server.handler == nil {
t.Error("NewServer() handler not set")
}
})
}
}
// TestServerDefaultValues 测试服务器默认值。
func TestServerDefaultValues(t *testing.T) {
cfg := &config.HTTP2Config{
Enabled: true,
}
handler := func(ctx *fasthttp.RequestCtx) {}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
// 验证默认并发流数量
if server.http2Server.MaxConcurrentStreams == 0 {
t.Error("Expected default MaxConcurrentStreams to be set")
}
// 验证默认空闲超时
if server.http2Server.IdleTimeout == 0 {
t.Error("Expected default IdleTimeout to be set")
}
}
// TestServerIsRunning 测试服务器运行状态。
func TestServerIsRunning(t *testing.T) {
cfg := &config.HTTP2Config{Enabled: true}
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
// 初始状态应为未运行
if server.IsRunning() {
t.Error("New server should not be running")
}
}
// TestServerGetConfig 测试获取服务器配置。
func TestServerGetConfig(t *testing.T) {
cfg := &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 100,
}
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
gotCfg := server.GetConfig()
if gotCfg != cfg {
t.Error("GetConfig() returned wrong config")
}
}
// TestALPNConfig 测试 ALPN 配置。
func TestALPNConfig(t *testing.T) {
cfg := &config.HTTP2Config{Enabled: true}
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
tlsCfg := server.ALPNConfig()
if tlsCfg == nil {
t.Fatal("ALPNConfig() returned nil")
}
// 验证 ALPN 协议包含 h2 和 http/1.1
foundH2 := false
foundHTTP11 := false
for _, proto := range tlsCfg.NextProtos {
if proto == "h2" {
foundH2 = true
}
if proto == "http/1.1" {
foundHTTP11 = true
}
}
if !foundH2 {
t.Error("ALPN config missing 'h2' protocol")
}
if !foundHTTP11 {
t.Error("ALPN config missing 'http/1.1' protocol")
}
}
// TestWrapTLSListener 测试 TLS 监听器包装。
func TestWrapTLSListener(t *testing.T) {
// 创建测试监听器
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() { _ = ln.Close() }()
// 创建 TLS 配置
tlsConfig := &tls.Config{
NextProtos: []string{},
}
// 包装监听器
wrappedLn := WrapTLSListener(ln, tlsConfig)
if wrappedLn == nil {
t.Fatal("WrapTLSListener() returned nil")
}
// 验证 ALPN 协议已设置
if len(tlsConfig.NextProtos) == 0 {
t.Error("WrapTLSListener should set NextProtos")
}
}
// TestIsH2CEnabled 测试 H2C 启用检查。
func TestIsH2CEnabled(t *testing.T) {
tests := []struct {
name string
h2cEnabled bool
want bool
}{
{
name: "H2C 启用",
h2cEnabled: true,
want: true,
},
{
name: "H2C 禁用",
h2cEnabled: false,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.HTTP2Config{
Enabled: true,
H2CEnabled: tt.h2cEnabled,
}
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
if got := server.IsH2CEnabled(); got != tt.want {
t.Errorf("IsH2CEnabled() = %v, want %v", got, tt.want)
}
})
}
}
// TestIsHTTP2Request 测试 HTTP/2 请求检测。
func TestIsHTTP2Request(t *testing.T) {
tests := []struct {
name string
method string
major int
header map[string]string
want bool
}{
{
name: "PRI 方法",
method: "PRI",
want: true,
},
{
name: "HTTP/2 版本",
major: 2,
want: true,
},
{
name: "HTTP/1.1",
method: "GET",
major: 1,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 这里只测试基本的逻辑,完整测试需要创建 http.Request
// 在实际集成测试中会覆盖
})
}
}
// TestHTTP2Settings 测试 HTTP/2 设置。
func TestHTTP2Settings(t *testing.T) {
tests := []struct {
name string
settings HTTP2Settings
wantErr bool
}{
{
name: "默认设置",
settings: HTTP2Settings{
HeaderTableSize: 4096,
EnablePush: true,
MaxConcurrentStreams: 250,
InitialWindowSize: 65535,
MaxFrameSize: 16384,
MaxHeaderListSize: 1048576,
},
wantErr: false,
},
{
name: "零并发流",
settings: HTTP2Settings{
MaxConcurrentStreams: 0,
},
wantErr: true,
},
{
name: "无效帧大小",
settings: HTTP2Settings{
MaxConcurrentStreams: 100,
MaxFrameSize: 1024, // 小于最小值 16384
},
wantErr: true,
},
{
name: "帧大小过大",
settings: HTTP2Settings{
MaxConcurrentStreams: 100,
MaxFrameSize: 16777216, // 超过最大值 16777215
},
wantErr: true,
},
{
name: "零头部列表大小",
settings: HTTP2Settings{
MaxConcurrentStreams: 100,
MaxHeaderListSize: 0,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateHTTP2Settings(tt.settings)
if tt.wantErr {
if err == nil {
t.Errorf("ValidateHTTP2Settings() expected error, got nil")
}
return
}
if err != nil {
t.Errorf("ValidateHTTP2Settings() unexpected error: %v", err)
}
})
}
}
// TestDefaultHTTP2Settings 测试默认 HTTP/2 设置。
func TestDefaultHTTP2Settings(t *testing.T) {
settings := DefaultHTTP2Settings()
if settings.HeaderTableSize == 0 {
t.Error("Default HeaderTableSize should not be zero")
}
if settings.MaxConcurrentStreams == 0 {
t.Error("Default MaxConcurrentStreams should not be zero")
}
if settings.InitialWindowSize == 0 {
t.Error("Default InitialWindowSize should not be zero")
}
if settings.MaxFrameSize == 0 {
t.Error("Default MaxFrameSize should not be zero")
}
if settings.MaxHeaderListSize == 0 {
t.Error("Default MaxHeaderListSize should not be zero")
}
}
// TestParseHTTP2Settings 测试从配置解析 HTTP/2 设置。
func TestParseHTTP2Settings(t *testing.T) {
cfg := &config.HTTP2Config{
Enabled: true,
MaxConcurrentStreams: 200,
MaxHeaderListSize: 2097152, // 2MB
PushEnabled: true,
}
settings := ParseHTTP2Settings(cfg)
if settings.MaxConcurrentStreams != 200 {
t.Errorf("ParseHTTP2Settings() MaxConcurrentStreams = %d, want 200", settings.MaxConcurrentStreams)
}
if settings.MaxHeaderListSize != 2097152 {
t.Errorf("ParseHTTP2Settings() MaxHeaderListSize = %d, want 2097152", settings.MaxHeaderListSize)
}
if !settings.EnablePush {
t.Error("ParseHTTP2Settings() EnablePush should be true")
}
}
// TestConnectionPool 测试连接池。
func TestConnectionPool(t *testing.T) {
pool := newConnectionPool()
// 创建测试连接
ln1, _ := net.Listen("tcp", "127.0.0.1:0")
defer func() { _ = ln1.Close() }()
ln2, _ := net.Listen("tcp", "127.0.0.1:0")
defer func() { _ = ln2.Close() }()
// 测试添加连接
conn1, _ := net.Dial("tcp", ln1.Addr().String())
if conn1 != nil {
defer func() { _ = conn1.Close() }()
pool.add("key1", conn1)
// 测试获取连接
conns := pool.get("key1")
if len(conns) != 1 {
t.Errorf("Expected 1 connection, got %d", len(conns))
}
// 测试计数
if count := pool.count("key1"); count != 1 {
t.Errorf("Expected count 1, got %d", count)
}
// 测试移除连接
pool.remove("key1", conn1)
if count := pool.count("key1"); count != 0 {
t.Errorf("Expected count 0 after remove, got %d", count)
}
}
}
// TestCanonicalHeaderKey 测试规范化头部键。
func TestCanonicalHeaderKey(t *testing.T) {
tests := []struct {
input string
want string
}{
{"content-type", "Content-Type"},
{"CONTENT-TYPE", "Content-Type"},
{"Content-Type", "Content-Type"},
{"x-custom-header", "X-Custom-Header"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := canonicalHeaderKey(tt.input)
if got != tt.want {
t.Errorf("canonicalHeaderKey(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}

View File

@ -108,6 +108,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12, // 强制 TLS 1.2 最低版本
MaxVersion: tls.VersionTLS13,
NextProtos: []string{"h2", "http/1.1"}, // 启用 HTTP/2 ALPN 支持
}
// 应用 TLS 1.2 的加密套件