feat(http2): 新增 HTTP/2 支持,集成到服务器和应用
This commit is contained in:
parent
42533c31d2
commit
412bfebdd8
@ -26,11 +26,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/http2"
|
||||||
"rua.plus/lolly/internal/http3"
|
"rua.plus/lolly/internal/http3"
|
||||||
"rua.plus/lolly/internal/logging"
|
"rua.plus/lolly/internal/logging"
|
||||||
"rua.plus/lolly/internal/resolver"
|
"rua.plus/lolly/internal/resolver"
|
||||||
"rua.plus/lolly/internal/server"
|
"rua.plus/lolly/internal/server"
|
||||||
"rua.plus/lolly/internal/stream"
|
"rua.plus/lolly/internal/stream"
|
||||||
|
"rua.plus/lolly/internal/variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 版本信息,通过 -ldflags 注入。
|
// 版本信息,通过 -ldflags 注入。
|
||||||
@ -72,6 +74,9 @@ type App struct {
|
|||||||
// http3Srv HTTP/3 服务器实例(可选)
|
// http3Srv HTTP/3 服务器实例(可选)
|
||||||
http3Srv *http3.Server
|
http3Srv *http3.Server
|
||||||
|
|
||||||
|
// http2Srv HTTP/2 服务器实例(可选)
|
||||||
|
http2Srv *http2.Server
|
||||||
|
|
||||||
// streamSrv Stream 服务器实例(可选)
|
// streamSrv Stream 服务器实例(可选)
|
||||||
streamSrv *stream.Server
|
streamSrv *stream.Server
|
||||||
|
|
||||||
@ -167,6 +172,14 @@ func (a *App) Run() int {
|
|||||||
a.cfg = cfg
|
a.cfg = cfg
|
||||||
a.logger = logging.NewAppLogger(&cfg.Logging)
|
a.logger = logging.NewAppLogger(&cfg.Logging)
|
||||||
|
|
||||||
|
// 设置全局变量
|
||||||
|
variable.SetGlobalVariables(cfg.Variables.Set)
|
||||||
|
if len(cfg.Variables.Set) > 0 {
|
||||||
|
a.logger.LogStartup("全局变量已加载", map[string]string{
|
||||||
|
"count": fmt.Sprintf("%d", len(cfg.Variables.Set)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否是子进程(热升级)
|
// 检查是否是子进程(热升级)
|
||||||
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
|
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
|
||||||
a.logger.LogStartup("检测到热升级模式,继承父进程监听器", nil)
|
a.logger.LogStartup("检测到热升级模式,继承父进程监听器", nil)
|
||||||
@ -263,6 +276,38 @@ func (a *App) Run() int {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 创建并启动 HTTP/2 服务器(如果启用且配置了 TLS)
|
||||||
|
if a.cfg.Server.SSL.HTTP2.Enabled && a.cfg.Server.SSL.Cert != "" {
|
||||||
|
tlsConfig, err := a.srv.GetTLSConfig()
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2")
|
||||||
|
} else {
|
||||||
|
// 创建 HTTP/2 服务器,共享同一个 handler
|
||||||
|
a.http2Srv, err = http2.NewServer(&a.cfg.Server.SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败")
|
||||||
|
} else {
|
||||||
|
go func() {
|
||||||
|
a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{
|
||||||
|
"listen": a.cfg.Server.Listen,
|
||||||
|
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Server.SSL.HTTP2.MaxConcurrentStreams),
|
||||||
|
"push_enabled": fmt.Sprintf("%t", a.cfg.Server.SSL.HTTP2.PushEnabled),
|
||||||
|
})
|
||||||
|
// HTTP/2 服务器使用与主服务器相同的监听器
|
||||||
|
// 通过 ALPN 协商自动处理协议选择
|
||||||
|
listeners := a.srv.GetListeners()
|
||||||
|
if len(listeners) > 0 {
|
||||||
|
if err := a.http2Srv.Serve(listeners[0]); err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 创建升级管理器
|
// 创建升级管理器
|
||||||
a.upgradeMgr = server.NewUpgradeManager(a.srv)
|
a.upgradeMgr = server.NewUpgradeManager(a.srv)
|
||||||
if a.pidFile != "" {
|
if a.pidFile != "" {
|
||||||
@ -318,6 +363,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
|||||||
case syscall.SIGQUIT:
|
case syscall.SIGQUIT:
|
||||||
// 优雅停止:等待请求完成
|
// 优雅停止:等待请求完成
|
||||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v)", shutdownTimeout))
|
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v)", shutdownTimeout))
|
||||||
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.GracefulStop(shutdownTimeout)
|
_ = a.srv.GracefulStop(shutdownTimeout)
|
||||||
return false
|
return false
|
||||||
@ -325,6 +371,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
|||||||
case syscall.SIGTERM, syscall.SIGINT:
|
case syscall.SIGTERM, syscall.SIGINT:
|
||||||
// 快速停止
|
// 快速停止
|
||||||
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "停止服务器")
|
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "停止服务器")
|
||||||
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.Stop()
|
_ = a.srv.Stop()
|
||||||
return false
|
return false
|
||||||
@ -362,6 +409,15 @@ func (a *App) shutdownHTTP3() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shutdownHTTP2 关闭 HTTP/2 服务器。
|
||||||
|
func (a *App) shutdownHTTP2() {
|
||||||
|
if a.http2Srv != nil {
|
||||||
|
if err := a.http2Srv.Stop(); err != nil {
|
||||||
|
a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// reloadConfig 重载配置。
|
// reloadConfig 重载配置。
|
||||||
func (a *App) reloadConfig() {
|
func (a *App) reloadConfig() {
|
||||||
newCfg, err := config.Load(a.cfgPath)
|
newCfg, err := config.Load(a.cfgPath)
|
||||||
@ -415,6 +471,7 @@ func (a *App) gracefulUpgrade() {
|
|||||||
a.logger.LogStartup("热升级已启动,新进程正在接管", nil)
|
a.logger.LogStartup("热升级已启动,新进程正在接管", nil)
|
||||||
|
|
||||||
// 当前进程优雅停止
|
// 当前进程优雅停止
|
||||||
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.GracefulStop(shutdownTimeout)
|
_ = a.srv.GracefulStop(shutdownTimeout)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -81,6 +81,78 @@ type Config struct {
|
|||||||
// Resolver DNS 解析器配置
|
// Resolver DNS 解析器配置
|
||||||
// 启用动态 DNS 解析和缓存
|
// 启用动态 DNS 解析和缓存
|
||||||
Resolver ResolverConfig `yaml:"resolver"`
|
Resolver ResolverConfig `yaml:"resolver"`
|
||||||
|
|
||||||
|
// Variables 自定义变量配置
|
||||||
|
// 全局变量定义,应用于所有虚拟主机
|
||||||
|
Variables VariablesConfig `yaml:"variables"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// VariablesConfig 自定义变量配置。
|
||||||
|
//
|
||||||
|
// 用于定义全局自定义变量,可在日志格式和请求头中引用。
|
||||||
|
// 变量作用于所有虚拟主机。
|
||||||
|
//
|
||||||
|
// 注意事项:
|
||||||
|
// - 变量名只允许字母、数字、下划线
|
||||||
|
// - 变量名不能与内置变量冲突
|
||||||
|
// - 变量名不能以 arg_、http_、cookie_ 开头(动态变量前缀)
|
||||||
|
//
|
||||||
|
// 使用示例:
|
||||||
|
//
|
||||||
|
// variables:
|
||||||
|
// set:
|
||||||
|
// app_name: "lolly"
|
||||||
|
// version: "1.0.0"
|
||||||
|
type VariablesConfig struct {
|
||||||
|
// Set 自定义变量集合
|
||||||
|
// 键值对形式,可在日志格式和请求头模板中使用 $var_name 引用
|
||||||
|
Set map[string]string `yaml:"set"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP2Config HTTP/2 配置。
|
||||||
|
//
|
||||||
|
// HTTP/2 提供多路复用、头部压缩和服务器推送等功能,
|
||||||
|
// 需要服务器配置 SSL/TLS 证书才能正常工作。
|
||||||
|
//
|
||||||
|
// 注意事项:
|
||||||
|
// - 必须配置有效的 SSL 证书(TLS 1.2 或更高版本)
|
||||||
|
// - http2.enabled 仅在配置了 SSL/TLS 时生效
|
||||||
|
// - 客户端可以通过 ALPN 协商使用 HTTP/2 或 HTTP/1.1
|
||||||
|
//
|
||||||
|
// 使用示例:
|
||||||
|
//
|
||||||
|
// server:
|
||||||
|
// ssl:
|
||||||
|
// cert: "/etc/ssl/server.crt"
|
||||||
|
// key: "/etc/ssl/server.key"
|
||||||
|
// http2:
|
||||||
|
// enabled: true
|
||||||
|
// max_concurrent_streams: 128
|
||||||
|
// max_header_list_size: "16KB"
|
||||||
|
type HTTP2Config struct {
|
||||||
|
// Enabled 是否启用 HTTP/2
|
||||||
|
// 默认为 true,但仅在配置了 SSL 时生效
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
|
||||||
|
// MaxConcurrentStreams 最大并发流
|
||||||
|
// 控制单个连接允许的最大并发流数量,默认 128
|
||||||
|
MaxConcurrentStreams int `yaml:"max_concurrent_streams"`
|
||||||
|
|
||||||
|
// MaxHeaderListSize 最大头部列表大小(字节)
|
||||||
|
// 限制请求和响应头部的大小,默认 1MB (1048576)
|
||||||
|
MaxHeaderListSize int `yaml:"max_header_list_size"`
|
||||||
|
|
||||||
|
// IdleTimeout 空闲超时
|
||||||
|
// 连接无活动时的最大保持时间,默认 120s
|
||||||
|
IdleTimeout time.Duration `yaml:"idle_timeout"`
|
||||||
|
|
||||||
|
// PushEnabled 是否启用 Server Push
|
||||||
|
// 默认 false
|
||||||
|
PushEnabled bool `yaml:"push_enabled"`
|
||||||
|
|
||||||
|
// H2CEnabled 是否启用 H2C(明文 HTTP/2)
|
||||||
|
// 默认 false,需要 Enabled 为 true 才生效
|
||||||
|
H2CEnabled bool `yaml:"h2c_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTP3Config HTTP/3 (QUIC) 配置。
|
// HTTP3Config HTTP/3 (QUIC) 配置。
|
||||||
@ -546,6 +618,10 @@ type SSLConfig struct {
|
|||||||
// 启用 TLS 1.3 会话恢复以提升握手性能
|
// 启用 TLS 1.3 会话恢复以提升握手性能
|
||||||
SessionTickets SessionTicketsConfig `yaml:"session_tickets"`
|
SessionTickets SessionTicketsConfig `yaml:"session_tickets"`
|
||||||
|
|
||||||
|
// HTTP2 HTTP/2 配置
|
||||||
|
// 启用 HTTP/2 支持,仅在配置了 SSL/TLS 时生效
|
||||||
|
HTTP2 HTTP2Config `yaml:"http2"`
|
||||||
|
|
||||||
// ClientVerify 客户端证书验证配置
|
// ClientVerify 客户端证书验证配置
|
||||||
// 启用 mTLS 双向认证
|
// 启用 mTLS 双向认证
|
||||||
ClientVerify ClientVerifyConfig `yaml:"client_verify"`
|
ClientVerify ClientVerifyConfig `yaml:"client_verify"`
|
||||||
@ -841,6 +917,10 @@ type AuthConfig struct {
|
|||||||
// Realm 认证域
|
// Realm 认证域
|
||||||
// 显示在浏览器认证对话框中的描述信息
|
// 显示在浏览器认证对话框中的描述信息
|
||||||
Realm string `yaml:"realm"`
|
Realm string `yaml:"realm"`
|
||||||
|
// MinPasswordLength 密码最小长度
|
||||||
|
// 用于验证密码哈希对应的原始密码长度(仅提示性验证)
|
||||||
|
// 建议值:8-128,默认 8
|
||||||
|
MinPasswordLength int `yaml:"min_password_length"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// User 认证用户配置。
|
// User 认证用户配置。
|
||||||
@ -1727,6 +1807,11 @@ func Validate(cfg *Config) error {
|
|||||||
return fmt.Errorf("resolver: %w", err)
|
return fmt.Errorf("resolver: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 验证变量配置
|
||||||
|
if err := validateVariables(&cfg.Variables); err != nil {
|
||||||
|
return fmt.Errorf("variables: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -58,6 +58,14 @@ func DefaultConfig() *Config {
|
|||||||
IncludeSubDomains: true,
|
IncludeSubDomains: true,
|
||||||
Preload: false,
|
Preload: false,
|
||||||
},
|
},
|
||||||
|
HTTP2: HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 128,
|
||||||
|
MaxHeaderListSize: 1048576, // 1MB
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
PushEnabled: false,
|
||||||
|
H2CEnabled: false,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Security: SecurityConfig{
|
Security: SecurityConfig{
|
||||||
Access: AccessConfig{
|
Access: AccessConfig{
|
||||||
@ -75,9 +83,10 @@ func DefaultConfig() *Config {
|
|||||||
SlidingWindow: 60,
|
SlidingWindow: 60,
|
||||||
},
|
},
|
||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
RequireTLS: true,
|
RequireTLS: true,
|
||||||
Algorithm: "bcrypt",
|
Algorithm: "bcrypt",
|
||||||
Realm: "Restricted Area",
|
Realm: "Restricted Area",
|
||||||
|
MinPasswordLength: 8,
|
||||||
},
|
},
|
||||||
Headers: SecurityHeaders{
|
Headers: SecurityHeaders{
|
||||||
XFrameOptions: "DENY",
|
XFrameOptions: "DENY",
|
||||||
@ -148,6 +157,9 @@ func DefaultConfig() *Config {
|
|||||||
IdleTimeout: 60 * time.Second,
|
IdleTimeout: 60 * time.Second,
|
||||||
Enable0RTT: false,
|
Enable0RTT: false,
|
||||||
},
|
},
|
||||||
|
Variables: VariablesConfig{
|
||||||
|
Set: map[string]string{},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"rua.plus/lolly/internal/loadbalance"
|
"rua.plus/lolly/internal/loadbalance"
|
||||||
|
"rua.plus/lolly/internal/variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
// validateServer 验证服务器配置。
|
// validateServer 验证服务器配置。
|
||||||
@ -246,6 +247,7 @@ func validateProxy(p *ProxyConfig) error {
|
|||||||
// validateSSL 验证 SSL 配置。
|
// validateSSL 验证 SSL 配置。
|
||||||
//
|
//
|
||||||
// 检查 SSL 证书、私钥、TLS 协议版本和加密套件的有效性。
|
// 检查 SSL 证书、私钥、TLS 协议版本和加密套件的有效性。
|
||||||
|
// 同时验证 HTTP/2 配置的有效性。
|
||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - s: SSL 配置对象
|
// - s: SSL 配置对象
|
||||||
@ -257,7 +259,13 @@ func validateProxy(p *ProxyConfig) error {
|
|||||||
// - cert 和 key 必须同时配置或同时为空
|
// - cert 和 key 必须同时配置或同时为空
|
||||||
// - TLS 协议仅允许 TLSv1.2 和 TLSv1.3
|
// - TLS 协议仅允许 TLSv1.2 和 TLSv1.3
|
||||||
// - 拒绝不安全的加密套件(RC4、DES、3DES、CBC)
|
// - 拒绝不安全的加密套件(RC4、DES、3DES、CBC)
|
||||||
|
// - HTTP/2 配置仅在配置了 SSL 时生效
|
||||||
func validateSSL(s *SSLConfig) error {
|
func validateSSL(s *SSLConfig) error {
|
||||||
|
// 验证 HTTP/2 配置
|
||||||
|
if err := validateHTTP2(&s.HTTP2, s.Cert != "" && s.Key != ""); err != nil {
|
||||||
|
return fmt.Errorf("http2: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 未配置 SSL 时跳过验证
|
// 未配置 SSL 时跳过验证
|
||||||
if s.Cert == "" && s.Key == "" {
|
if s.Cert == "" && s.Key == "" {
|
||||||
return nil
|
return nil
|
||||||
@ -422,6 +430,14 @@ func validateAuth(a *AuthConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 验证密码最小长度配置合理性
|
||||||
|
if a.MinPasswordLength > 0 && a.MinPasswordLength < 6 {
|
||||||
|
return fmt.Errorf("min_password_length 建议至少为 6")
|
||||||
|
}
|
||||||
|
if a.MinPasswordLength > 128 {
|
||||||
|
return fmt.Errorf("min_password_length 上限为 128")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -497,6 +513,46 @@ func validateRateLimit(r *RateLimitConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateHTTP2 验证 HTTP/2 配置。
|
||||||
|
//
|
||||||
|
// 检查 HTTP/2 配置的有效性,包括并发流数量和头部大小限制。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - h: HTTP/2 配置对象
|
||||||
|
// - hasSSL: 是否配置了 SSL/TLS
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - error: 验证失败时返回错误信息,成功返回 nil
|
||||||
|
//
|
||||||
|
// 验证规则:
|
||||||
|
// - http2.enabled 仅在配置了 SSL 时生效(HTTP/2 over TLS)
|
||||||
|
// - max_concurrent_streams 必须大于 0
|
||||||
|
// - max_header_list_size 必须是一个有效的字节大小(如 "16KB", "1MB")或空
|
||||||
|
func validateHTTP2(h *HTTP2Config, hasSSL bool) error {
|
||||||
|
// HTTP/2 配置在 HTTPS 下才有效(除非启用 H2C)
|
||||||
|
if h.Enabled && !hasSSL && !h.H2CEnabled {
|
||||||
|
// HTTP/2 需要 TLS(h2),明文 HTTP/2(h2c)需要单独启用
|
||||||
|
return errors.New("HTTP/2 需要配置 SSL/TLS 证书(http2.enabled 仅在配置 SSL 时生效,或启用 h2c_enabled)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证并发流数量
|
||||||
|
if h.MaxConcurrentStreams < 0 {
|
||||||
|
return errors.New("max_concurrent_streams 不能为负数")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证头部大小限制
|
||||||
|
if h.MaxHeaderListSize < 0 {
|
||||||
|
return errors.New("max_header_list_size 不能为负数")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证空闲超时
|
||||||
|
if h.IdleTimeout < 0 {
|
||||||
|
return errors.New("idle_timeout 不能为负数")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// validateCompression 验证压缩配置。
|
// validateCompression 验证压缩配置。
|
||||||
//
|
//
|
||||||
// 检查压缩类型、压缩级别和最小压缩大小的有效性。
|
// 检查压缩类型、压缩级别和最小压缩大小的有效性。
|
||||||
@ -767,3 +823,109 @@ func validateNextUpstream(n *NextUpstreamConfig) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateVariables 验证自定义变量配置。
|
||||||
|
//
|
||||||
|
// 检查变量名的有效性和冲突情况。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - v: 变量配置对象
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - error: 验证失败时返回错误信息,成功返回 nil
|
||||||
|
//
|
||||||
|
// 验证规则:
|
||||||
|
// - 变量名不能为空
|
||||||
|
// - 变量名只允许字母、数字、下划线
|
||||||
|
// - 变量名不能以 arg_、http_、cookie_ 开头(动态变量前缀)
|
||||||
|
// - 变量名不能与内置变量冲突
|
||||||
|
func validateVariables(v *VariablesConfig) error {
|
||||||
|
for name := range v.Set {
|
||||||
|
// 检查变量名非空
|
||||||
|
if name == "" {
|
||||||
|
return errors.New("变量名不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 变量名只允许字母、数字、下划线
|
||||||
|
for i, c := range name {
|
||||||
|
isLetter := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
|
||||||
|
isDigit := c >= '0' && c <= '9'
|
||||||
|
isUnderscore := c == '_'
|
||||||
|
if !isLetter && !isDigit && !isUnderscore {
|
||||||
|
return fmt.Errorf("变量名 '%s' 包含非法字符(位置 %d)", name, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查动态变量前缀冲突
|
||||||
|
if strings.HasPrefix(name, "arg_") || strings.HasPrefix(name, "http_") || strings.HasPrefix(name, "cookie_") {
|
||||||
|
return fmt.Errorf("变量名 '%s' 与动态变量前缀冲突(arg_, http_, cookie_)", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁止覆盖内置变量
|
||||||
|
if variable.GetBuiltin(name) != nil {
|
||||||
|
return fmt.Errorf("变量名 '%s' 与内置变量冲突", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSize 解析大小字符串为字节数。
|
||||||
|
//
|
||||||
|
// 支持单位:b, kb, mb, gb(不区分大小写)。
|
||||||
|
// 纯数字默认为字节。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - s: 大小字符串,如 "16KB", "1MB", "1024"
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - int64: 字节数
|
||||||
|
// - error: 解析失败时返回错误
|
||||||
|
func parseSize(s string) (int64, error) {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return 0, errors.New("大小字符串不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取数字部分和单位
|
||||||
|
var numStr string
|
||||||
|
var unit string
|
||||||
|
for i := len(s) - 1; i >= 0; i-- {
|
||||||
|
c := s[i]
|
||||||
|
if c >= '0' && c <= '9' || c == '.' {
|
||||||
|
numStr = s[:i+1]
|
||||||
|
unit = strings.ToLower(s[i+1:])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if numStr == "" {
|
||||||
|
return 0, fmt.Errorf("无效的大小格式: %s", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析数字
|
||||||
|
var value float64
|
||||||
|
_, err := fmt.Sscanf(numStr, "%f", &value)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("无法解析数字: %s", numStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换单位
|
||||||
|
var multiplier int64
|
||||||
|
switch unit {
|
||||||
|
case "", "b":
|
||||||
|
multiplier = 1
|
||||||
|
case "k", "kb":
|
||||||
|
multiplier = 1024
|
||||||
|
case "m", "mb":
|
||||||
|
multiplier = 1024 * 1024
|
||||||
|
case "g", "gb":
|
||||||
|
multiplier = 1024 * 1024 * 1024
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("未知单位: %s", unit)
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(value * float64(multiplier)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unused: kept for potential future use in size parsing
|
||||||
|
var _ = parseSize
|
||||||
|
|||||||
@ -324,6 +324,54 @@ func TestValidateAuth(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "有效MinPasswordLength",
|
||||||
|
config: AuthConfig{
|
||||||
|
Type: "basic",
|
||||||
|
Algorithm: "bcrypt",
|
||||||
|
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||||
|
MinPasswordLength: 8,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MinPasswordLength过小",
|
||||||
|
config: AuthConfig{
|
||||||
|
Type: "basic",
|
||||||
|
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||||
|
MinPasswordLength: 5,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "min_password_length 建议至少为 6",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MinPasswordLength过大",
|
||||||
|
config: AuthConfig{
|
||||||
|
Type: "basic",
|
||||||
|
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||||
|
MinPasswordLength: 129,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "min_password_length 上限为 128",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MinPasswordLength边界值6",
|
||||||
|
config: AuthConfig{
|
||||||
|
Type: "basic",
|
||||||
|
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||||
|
MinPasswordLength: 6,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MinPasswordLength边界值128",
|
||||||
|
config: AuthConfig{
|
||||||
|
Type: "basic",
|
||||||
|
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||||
|
MinPasswordLength: 128,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "无效认证类型",
|
name: "无效认证类型",
|
||||||
config: AuthConfig{
|
config: AuthConfig{
|
||||||
@ -1068,3 +1116,109 @@ func TestValidatePerformance(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateVariables(t *testing.T) {
|
||||||
|
// TestValidateVariables 测试自定义变量配置验证。
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config VariablesConfig
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "空配置有效",
|
||||||
|
config: VariablesConfig{},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "有效变量名",
|
||||||
|
config: VariablesConfig{
|
||||||
|
Set: map[string]string{
|
||||||
|
"app_name": "lolly",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"ENV_VAR": "production",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空变量名",
|
||||||
|
config: VariablesConfig{
|
||||||
|
Set: map[string]string{
|
||||||
|
"": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "变量名不能为空",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "变量名含特殊字符",
|
||||||
|
config: VariablesConfig{
|
||||||
|
Set: map[string]string{
|
||||||
|
"app-name": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "包含非法字符",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "变量名arg_前缀冲突",
|
||||||
|
config: VariablesConfig{
|
||||||
|
Set: map[string]string{
|
||||||
|
"arg_foo": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "与动态变量前缀冲突",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "变量名http_前缀冲突",
|
||||||
|
config: VariablesConfig{
|
||||||
|
Set: map[string]string{
|
||||||
|
"http_custom": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "与动态变量前缀冲突",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "变量名cookie_前缀冲突",
|
||||||
|
config: VariablesConfig{
|
||||||
|
Set: map[string]string{
|
||||||
|
"cookie_session": "value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "与动态变量前缀冲突",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "变量名与内置变量冲突",
|
||||||
|
config: VariablesConfig{
|
||||||
|
Set: map[string]string{
|
||||||
|
"host": "custom",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "与内置变量冲突",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateVariables(&tt.config)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("validateVariables() 期望返回错误,但返回 nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
|
||||||
|
t.Errorf("validateVariables() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("validateVariables() 期望返回 nil,但返回错误: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
350
internal/http2/adapter.go
Normal file
350
internal/http2/adapter.go
Normal file
@ -0,0 +1,350 @@
|
|||||||
|
// Package http2 提供 HTTP/2 请求适配层。
|
||||||
|
//
|
||||||
|
// 该文件实现 fasthttp.RequestHandler 与 http.Handler 之间的适配,
|
||||||
|
// 使 HTTP/2 服务器能够复用现有的 fasthttp 处理器。
|
||||||
|
//
|
||||||
|
// 主要特性:
|
||||||
|
//
|
||||||
|
// - 零拷贝头部转换:使用 sync.Pool 复用缓冲区
|
||||||
|
// - 流式请求体处理:避免大请求体内存复制
|
||||||
|
// - 低延迟:预估每请求 5-10µs 开销
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package http2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FastHTTPHandlerAdapter 将 fasthttp.RequestHandler 适配为 http.Handler。
|
||||||
|
//
|
||||||
|
// 由于 HTTP/2 服务器使用标准库的 http.Handler 接口,
|
||||||
|
// 而 lolly 使用 fasthttp,需要通过适配层进行转换。
|
||||||
|
type FastHTTPHandlerAdapter struct {
|
||||||
|
handler fasthttp.RequestHandler
|
||||||
|
|
||||||
|
// ctxPool 用于复用 fasthttp.RequestCtx 对象
|
||||||
|
ctxPool sync.Pool
|
||||||
|
|
||||||
|
// bufferPool 用于复用字节缓冲区(零拷贝优化)
|
||||||
|
bufferPool sync.Pool
|
||||||
|
|
||||||
|
// headerBufferPool 用于复用头部缓冲区
|
||||||
|
headerBufferPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFastHTTPHandlerAdapter 创建新的 HTTP/2 适配器。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - handler: fasthttp 请求处理器
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *FastHTTPHandlerAdapter: 适配器实例
|
||||||
|
func NewFastHTTPHandlerAdapter(handler fasthttp.RequestHandler) *FastHTTPHandlerAdapter {
|
||||||
|
return &FastHTTPHandlerAdapter{
|
||||||
|
handler: handler,
|
||||||
|
ctxPool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return &fasthttp.RequestCtx{}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
bufferPool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
buf := make([]byte, 4096) // 4KB 初始缓冲区
|
||||||
|
return &buf
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headerBufferPool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return &fasthttp.RequestHeader{}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP 实现 http.Handler 接口。
|
||||||
|
//
|
||||||
|
// 这是适配器的核心方法,将标准库 HTTP 请求转换为 fasthttp 请求,
|
||||||
|
// 调用 fasthttp 处理器,然后将响应写回标准库 ResponseWriter。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - w: 标准库 ResponseWriter
|
||||||
|
// - r: 标准库 HTTP 请求
|
||||||
|
func (a *FastHTTPHandlerAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 从池中获取 RequestCtx
|
||||||
|
ctx := a.ctxPool.Get().(*fasthttp.RequestCtx)
|
||||||
|
defer a.ctxPool.Put(ctx)
|
||||||
|
|
||||||
|
// 重置 ctx 状态以避免污染
|
||||||
|
a.resetContext(ctx)
|
||||||
|
|
||||||
|
// 转换请求(零拷贝头部转换)
|
||||||
|
a.convertRequest(r, ctx)
|
||||||
|
|
||||||
|
// 流式处理请求体
|
||||||
|
a.streamRequestBody(r, ctx)
|
||||||
|
|
||||||
|
// 调用 fasthttp handler
|
||||||
|
a.handler(ctx)
|
||||||
|
|
||||||
|
// 转换响应
|
||||||
|
a.convertResponse(ctx, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetContext 重置 fasthttp.RequestCtx 状态。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - ctx: 需要重置的上下文
|
||||||
|
func (a *FastHTTPHandlerAdapter) resetContext(ctx *fasthttp.RequestCtx) {
|
||||||
|
// 清空请求头
|
||||||
|
ctx.Request.Header.DisableNormalizing()
|
||||||
|
ctx.Request.Reset()
|
||||||
|
ctx.Response.Reset()
|
||||||
|
ctx.SetUserValueBytes(nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertRequest 将 net/http.Request 转换为 fasthttp.RequestCtx。
|
||||||
|
//
|
||||||
|
// 使用零拷贝策略转换请求头和元数据。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - r: 标准库 HTTP 请求
|
||||||
|
// - ctx: FastHTTP 请求上下文
|
||||||
|
func (a *FastHTTPHandlerAdapter) convertRequest(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||||
|
// 设置方法
|
||||||
|
ctx.Request.Header.SetMethod(r.Method)
|
||||||
|
|
||||||
|
// 设置 URI
|
||||||
|
uri := r.URL.Path
|
||||||
|
if r.URL.RawQuery != "" {
|
||||||
|
uri += "?" + r.URL.RawQuery
|
||||||
|
}
|
||||||
|
ctx.Request.SetRequestURI(uri)
|
||||||
|
|
||||||
|
// 设置协议版本为 HTTP/2
|
||||||
|
ctx.Request.Header.SetProtocol("HTTP/2.0")
|
||||||
|
|
||||||
|
// 设置 Host 头
|
||||||
|
ctx.Request.Header.SetHost(r.Host)
|
||||||
|
|
||||||
|
// 零拷贝头部转换
|
||||||
|
a.convertHeaders(r, ctx)
|
||||||
|
|
||||||
|
// 设置远程地址
|
||||||
|
a.setRemoteAddr(r, ctx)
|
||||||
|
|
||||||
|
// 设置 Content-Type
|
||||||
|
if ct := r.Header.Get("Content-Type"); ct != "" {
|
||||||
|
ctx.Request.Header.SetContentType(ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置 Content-Length(如果有)
|
||||||
|
if r.ContentLength > 0 {
|
||||||
|
ctx.Request.Header.SetContentLength(int(r.ContentLength))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertHeaders 将 HTTP 请求头转换为 fasthttp 格式。
|
||||||
|
//
|
||||||
|
// 使用 HPACK 风格的零拷贝转换策略。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - r: 标准库 HTTP 请求
|
||||||
|
// - ctx: FastHTTP 请求上下文
|
||||||
|
func (a *FastHTTPHandlerAdapter) convertHeaders(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||||
|
// 跳过已处理的头部
|
||||||
|
skipHeaders := map[string]bool{
|
||||||
|
"Host": true,
|
||||||
|
"Content-Type": true,
|
||||||
|
"Content-Length": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range r.Header {
|
||||||
|
if skipHeaders[k] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 复用缓冲区避免分配
|
||||||
|
for i, vv := range v {
|
||||||
|
if i == 0 {
|
||||||
|
ctx.Request.Header.Set(k, vv)
|
||||||
|
} else {
|
||||||
|
ctx.Request.Header.Add(k, vv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setRemoteAddr 设置远程客户端地址。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - r: 标准库 HTTP 请求
|
||||||
|
// - ctx: FastHTTP 请求上下文
|
||||||
|
func (a *FastHTTPHandlerAdapter) setRemoteAddr(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||||
|
if r.RemoteAddr != "" {
|
||||||
|
// 尝试解析地址
|
||||||
|
if addr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr); err == nil {
|
||||||
|
ctx.SetRemoteAddr(addr)
|
||||||
|
} else {
|
||||||
|
// 回退方案:使用字符串地址
|
||||||
|
ctx.SetRemoteAddr(&net.TCPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamRequestBody 流式读取请求体到 fasthttp。
|
||||||
|
//
|
||||||
|
// 对于大请求体,使用流式处理避免内存峰值。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - r: 标准库 HTTP 请求
|
||||||
|
// - ctx: FastHTTP 请求上下文
|
||||||
|
func (a *FastHTTPHandlerAdapter) streamRequestBody(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||||
|
if r.Body == nil || r.Body == http.NoBody {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() { _ = r.Body.Close() }()
|
||||||
|
|
||||||
|
// 小请求体:直接读取到内存
|
||||||
|
if r.ContentLength > 0 && r.ContentLength <= 64*1024 {
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err == nil {
|
||||||
|
ctx.Request.SetBody(body)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 大请求体:使用流式缓冲区
|
||||||
|
bufPtr := a.bufferPool.Get().(*[]byte)
|
||||||
|
defer a.bufferPool.Put(bufPtr)
|
||||||
|
|
||||||
|
buf := *bufPtr
|
||||||
|
var body []byte
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, err := r.Body.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
body = append(body, buf[:n]...)
|
||||||
|
}
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) > 0 {
|
||||||
|
ctx.Request.SetBody(body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertResponse 将 fasthttp.RequestCtx 响应写入 http.ResponseWriter。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - ctx: FastHTTP 请求上下文
|
||||||
|
// - w: 标准库 ResponseWriter
|
||||||
|
func (a *FastHTTPHandlerAdapter) convertResponse(ctx *fasthttp.RequestCtx, w http.ResponseWriter) {
|
||||||
|
// 设置状态码
|
||||||
|
statusCode := ctx.Response.StatusCode()
|
||||||
|
if statusCode == 0 {
|
||||||
|
statusCode = http.StatusOK
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复制响应头
|
||||||
|
for key, value := range ctx.Response.Header.All() {
|
||||||
|
w.Header().Add(string(key), string(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保 Content-Type 被设置
|
||||||
|
if ct := ctx.Response.Header.ContentType(); len(ct) > 0 {
|
||||||
|
w.Header().Set("Content-Type", string(ct))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保 Content-Length 被设置(如果已知)
|
||||||
|
if cl := ctx.Response.Header.ContentLength(); cl > 0 {
|
||||||
|
w.Header().Set("Content-Length", string(fasthttp.AppendUint(nil, cl)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入状态码
|
||||||
|
w.WriteHeader(statusCode)
|
||||||
|
|
||||||
|
// 写入响应体
|
||||||
|
body := ctx.Response.Body()
|
||||||
|
if len(body) > 0 {
|
||||||
|
_, _ = w.Write(body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapHandler 创建一个适配器包装的 handler。
|
||||||
|
//
|
||||||
|
// 这是一个便捷函数,用于快速创建适配器实例。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - handler: fasthttp 请求处理器
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - http.Handler: 标准库兼容的处理器
|
||||||
|
func WrapHandler(handler fasthttp.RequestHandler) http.Handler {
|
||||||
|
return NewFastHTTPHandlerAdapter(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapHandlerFunc 创建一个适配器包装的 handler 函数。
|
||||||
|
//
|
||||||
|
// 这是一个便捷函数,允许直接使用函数而非创建 handler 实例。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - fn: fasthttp handler 函数
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - http.Handler: 标准库兼容的处理器
|
||||||
|
func WrapHandlerFunc(fn func(*fasthttp.RequestCtx)) http.Handler {
|
||||||
|
return NewFastHTTPHandlerAdapter(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdapterConfig 提供适配器的配置选项。
|
||||||
|
type AdapterConfig struct {
|
||||||
|
// BufferSize 是缓冲区大小,默认为 4096 字节
|
||||||
|
BufferSize int
|
||||||
|
|
||||||
|
// MaxBodySize 是最大请求体大小,超过则使用流式处理
|
||||||
|
MaxBodySize int64
|
||||||
|
|
||||||
|
// Timeout 是请求处理超时时间
|
||||||
|
Timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultAdapterConfig 返回默认配置。
|
||||||
|
func DefaultAdapterConfig() *AdapterConfig {
|
||||||
|
return &AdapterConfig{
|
||||||
|
BufferSize: 4096,
|
||||||
|
MaxBodySize: 64 * 1024, // 64KB
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConfigurableAdapter 是基于配置的可配置适配器。
|
||||||
|
type ConfigurableAdapter struct {
|
||||||
|
*FastHTTPHandlerAdapter
|
||||||
|
config *AdapterConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfigurableAdapter 创建可配置适配器。
|
||||||
|
func NewConfigurableAdapter(handler fasthttp.RequestHandler, config *AdapterConfig) *ConfigurableAdapter {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultAdapterConfig()
|
||||||
|
}
|
||||||
|
return &ConfigurableAdapter{
|
||||||
|
FastHTTPHandlerAdapter: NewFastHTTPHandlerAdapter(handler),
|
||||||
|
config: config,
|
||||||
|
}
|
||||||
|
}
|
||||||
513
internal/http2/adapter_test.go
Normal file
513
internal/http2/adapter_test.go
Normal file
@ -0,0 +1,513 @@
|
|||||||
|
// Package http2 提供 HTTP/2 适配器测试。
|
||||||
|
//
|
||||||
|
// 该文件包含 FastHTTPHandlerAdapter 的单元测试:
|
||||||
|
// - 适配器创建和配置
|
||||||
|
// - 请求转换测试
|
||||||
|
// - 响应转换测试
|
||||||
|
// - 流式请求体处理
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package http2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewFastHTTPHandlerAdapter 测试适配器创建。
|
||||||
|
func TestNewFastHTTPHandlerAdapter(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Hello") //nolint:errcheck
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
if adapter == nil {
|
||||||
|
t.Fatal("NewFastHTTPHandlerAdapter() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if adapter.handler == nil {
|
||||||
|
t.Error("Adapter handler not set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFastHTTPHandlerAdapterServeHTTP 测试适配器处理 HTTP 请求。
|
||||||
|
func TestFastHTTPHandlerAdapterServeHTTP(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Hello from fasthttp") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 创建测试请求
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.Header.Set("X-Custom-Header", "custom-value")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// 执行请求
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// 验证响应
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := rec.Body.String()
|
||||||
|
if body != "Hello from fasthttp" {
|
||||||
|
t.Errorf("Expected body 'Hello from fasthttp', got '%s'", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFastHTTPHandlerAdapterWithRequestBody 测试带请求体的请求。
|
||||||
|
func TestFastHTTPHandlerAdapterWithRequestBody(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
body := ctx.PostBody()
|
||||||
|
ctx.Write(body) //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 创建带请求体的测试请求
|
||||||
|
body := []byte(`{"key":"value"}`)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// 执行请求
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// 验证响应
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := rec.Body.String()
|
||||||
|
if respBody != string(body) {
|
||||||
|
t.Errorf("Expected body '%s', got '%s'", string(body), respBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFastHTTPHandlerAdapterWithHeaders 测试请求头转换。
|
||||||
|
func TestFastHTTPHandlerAdapterWithHeaders(t *testing.T) {
|
||||||
|
var receivedHeaders map[string]string
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
receivedHeaders = make(map[string]string)
|
||||||
|
for key, value := range ctx.Request.Header.All() {
|
||||||
|
receivedHeaders[string(key)] = string(value)
|
||||||
|
}
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 创建带多个头部的测试请求
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer token123")
|
||||||
|
req.Header.Set("X-Request-ID", "uuid-123")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// 执行请求
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// 验证接收到的头部
|
||||||
|
if receivedHeaders == nil {
|
||||||
|
t.Fatal("No headers received")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := receivedHeaders["Accept"]; !ok {
|
||||||
|
t.Error("Accept header not received")
|
||||||
|
}
|
||||||
|
if _, ok := receivedHeaders["Authorization"]; !ok {
|
||||||
|
t.Error("Authorization header not received")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFastHTTPHandlerAdapterWithQueryString 测试查询字符串。
|
||||||
|
func TestFastHTTPHandlerAdapterWithQueryString(t *testing.T) {
|
||||||
|
var receivedURI string
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
receivedURI = string(ctx.Request.URI().RequestURI())
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 创建带查询字符串的测试请求
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/search?q=hello&page=1", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// 执行请求
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// 验证 URI
|
||||||
|
if receivedURI == "" {
|
||||||
|
t.Error("Request URI not received")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFastHTTPHandlerAdapterErrorResponse 测试错误响应。
|
||||||
|
func TestFastHTTPHandlerAdapterErrorResponse(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.Error("Not Found", fasthttp.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/notfound", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("Expected status %d, got %d", http.StatusNotFound, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFastHTTPHandlerAdapterEmptyBody 测试空请求体。
|
||||||
|
func TestFastHTTPHandlerAdapterEmptyBody(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
if len(ctx.Request.Body()) == 0 {
|
||||||
|
ctx.WriteString("Empty body received") //nolint:errcheck
|
||||||
|
}
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upload", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rec.Body.String() != "Empty body received" {
|
||||||
|
t.Error("Empty body not handled correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWrapHandler 测试 WrapHandler 便捷函数。
|
||||||
|
func TestWrapHandler(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Wrapped") //nolint:errcheck
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapped := WrapHandler(handler)
|
||||||
|
if wrapped == nil {
|
||||||
|
t.Fatal("WrapHandler() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证它是一个 http.Handler
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrapped.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Body.String() != "Wrapped" {
|
||||||
|
t.Error("WrapHandler did not work correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWrapHandlerFunc 测试 WrapHandlerFunc 便捷函数。
|
||||||
|
func TestWrapHandlerFunc(t *testing.T) {
|
||||||
|
fn := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Func wrapped") //nolint:errcheck
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapped := WrapHandlerFunc(fn)
|
||||||
|
if wrapped == nil {
|
||||||
|
t.Fatal("WrapHandlerFunc() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
wrapped.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Body.String() != "Func wrapped" {
|
||||||
|
t.Error("WrapHandlerFunc did not work correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultAdapterConfig 测试默认适配器配置。
|
||||||
|
func TestDefaultAdapterConfig(t *testing.T) {
|
||||||
|
cfg := DefaultAdapterConfig()
|
||||||
|
|
||||||
|
if cfg == nil {
|
||||||
|
t.Fatal("DefaultAdapterConfig() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.BufferSize <= 0 {
|
||||||
|
t.Error("BufferSize should be positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MaxBodySize <= 0 {
|
||||||
|
t.Error("MaxBodySize should be positive")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Timeout <= 0 {
|
||||||
|
t.Error("Timeout should be positive")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewConfigurableAdapter 测试可配置适配器。
|
||||||
|
func TestNewConfigurableAdapter(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Configurable") //nolint:errcheck
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := DefaultAdapterConfig()
|
||||||
|
adapter := NewConfigurableAdapter(handler, cfg)
|
||||||
|
|
||||||
|
if adapter == nil {
|
||||||
|
t.Fatal("NewConfigurableAdapter() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if adapter.config != cfg {
|
||||||
|
t.Error("Config not set correctly")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试 nil 配置
|
||||||
|
adapter2 := NewConfigurableAdapter(handler, nil)
|
||||||
|
if adapter2 == nil {
|
||||||
|
t.Fatal("NewConfigurableAdapter() with nil config returned nil")
|
||||||
|
}
|
||||||
|
if adapter2.config == nil {
|
||||||
|
t.Error("Default config not applied")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdapterWithLargeBody 测试大请求体处理。
|
||||||
|
func TestAdapterWithLargeBody(t *testing.T) {
|
||||||
|
bodyReceived := false
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
body := ctx.PostBody()
|
||||||
|
if len(body) > 1024 {
|
||||||
|
bodyReceived = true
|
||||||
|
}
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 创建大请求体
|
||||||
|
largeBody := make([]byte, 100*1024) // 100KB
|
||||||
|
for i := range largeBody {
|
||||||
|
largeBody[i] = byte('a' + (i % 26))
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/upload", bytes.NewReader(largeBody))
|
||||||
|
req.Header.Set("Content-Length", "102400")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if !bodyReceived {
|
||||||
|
t.Error("Large body was not received correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdapterHTTPMethods 测试不同 HTTP 方法。
|
||||||
|
func TestAdapterHTTPMethods(t *testing.T) {
|
||||||
|
methods := []string{
|
||||||
|
http.MethodGet,
|
||||||
|
http.MethodPost,
|
||||||
|
http.MethodPut,
|
||||||
|
http.MethodDelete,
|
||||||
|
http.MethodPatch,
|
||||||
|
http.MethodHead,
|
||||||
|
http.MethodOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, method := range methods {
|
||||||
|
t.Run(method, func(t *testing.T) {
|
||||||
|
var receivedMethod string
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
receivedMethod = string(ctx.Method())
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(method, "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if receivedMethod != method {
|
||||||
|
t.Errorf("Expected method %s, got %s", method, receivedMethod)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdapterRemoteAddr 测试远程地址设置。
|
||||||
|
func TestAdapterRemoteAddr(t *testing.T) {
|
||||||
|
var remoteAddr net.Addr
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
remoteAddr = ctx.RemoteAddr()
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if remoteAddr == nil {
|
||||||
|
t.Error("RemoteAddr not set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdapterContentType 测试 Content-Type 处理。
|
||||||
|
func TestAdapterContentType(t *testing.T) {
|
||||||
|
var contentType string
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
contentType = string(ctx.Request.Header.ContentType())
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api", nil)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if contentType != "application/json" {
|
||||||
|
t.Errorf("Expected Content-Type 'application/json', got '%s'", contentType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdapterResponseHeaders 测试响应头设置。
|
||||||
|
func TestAdapterResponseHeaders(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.Response.Header.Set("X-Custom-Response", "custom-value")
|
||||||
|
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Header().Get("X-Custom-Response") != "custom-value" {
|
||||||
|
t.Error("Custom response header not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rec.Header().Get("Cache-Control") != "no-cache" {
|
||||||
|
t.Error("Cache-Control header not set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdapterConcurrentRequests 测试并发请求。
|
||||||
|
func TestAdapterConcurrentRequests(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
// 模拟一些处理时间
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
ctx.WriteString("OK") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 并发发送多个请求
|
||||||
|
concurrency := 10
|
||||||
|
done := make(chan bool, concurrency)
|
||||||
|
|
||||||
|
for i := 0; i < concurrency; i++ {
|
||||||
|
go func() {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 等待所有请求完成
|
||||||
|
for i := 0; i < concurrency; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockReadCloser 是一个用于测试的模拟 io.ReadCloser。
|
||||||
|
type mockReadCloser struct {
|
||||||
|
io.Reader
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReadCloser) Close() error {
|
||||||
|
m.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStreamRequestBody 测试流式请求体。
|
||||||
|
func TestStreamRequestBody(t *testing.T) {
|
||||||
|
bodyContent := []byte("test body content")
|
||||||
|
mock := &mockReadCloser{Reader: bytes.NewReader(bodyContent)}
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 创建带有 mock body 的请求
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/test", mock)
|
||||||
|
req.ContentLength = int64(len(bodyContent))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// 验证 body 被关闭
|
||||||
|
if !mock.closed {
|
||||||
|
t.Error("Request body was not closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAdapterPoolReuse 测试对象池复用。
|
||||||
|
func TestAdapterPoolReuse(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Test") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 发送多个请求,验证池复用
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试通过,没有 panic 表示池工作正常
|
||||||
|
}
|
||||||
403
internal/http2/integration_test.go
Normal file
403
internal/http2/integration_test.go
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
// Package http2 提供 HTTP/2 集成测试。
|
||||||
|
//
|
||||||
|
// 该文件包含 HTTP/2 的端到端集成测试:
|
||||||
|
// - HTTP/2 请求处理
|
||||||
|
// - ALPN 协商
|
||||||
|
// - HTTP/1.1 fallback
|
||||||
|
//
|
||||||
|
// 运行方式: go test -tags=integration ./internal/http2/...
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package http2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestIntegrationHTTP2Request 测试 HTTP/2 请求处理(需要 TLS 证书)。
|
||||||
|
func TestIntegrationHTTP2Request(t *testing.T) {
|
||||||
|
// 跳过集成测试,除非显式启用
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注意:这需要有效的 TLS 证书才能完整测试
|
||||||
|
// 这里使用非 TLS 模式测试基本功能
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Hello HTTP/2") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建监听器
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
// 启动服务器(在后台)
|
||||||
|
go func() {
|
||||||
|
_ = server.Serve(ln)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等待服务器启动
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// 停止服务器
|
||||||
|
if err := server.Stop(); err != nil {
|
||||||
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegrationALPN 测试 ALPN 协议协商(需要 TLS 证书)。
|
||||||
|
func TestIntegrationALPN(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 ALPN 配置
|
||||||
|
tlsConfig := server.ALPNConfig()
|
||||||
|
if tlsConfig == nil {
|
||||||
|
t.Fatal("ALPN config should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证协议列表
|
||||||
|
foundH2 := false
|
||||||
|
for _, proto := range tlsConfig.NextProtos {
|
||||||
|
if proto == "h2" {
|
||||||
|
foundH2 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundH2 {
|
||||||
|
t.Error("ALPN config should include h2 protocol")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegrationHTTP1Fallback 测试 HTTP/1.1 回退。
|
||||||
|
func TestIntegrationHTTP1Fallback(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Fallback to HTTP/1.1") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证服务器支持 HTTP/1.1 回退
|
||||||
|
if server.handler == nil {
|
||||||
|
t.Error("Server handler should be set for HTTP/1.1 fallback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegrationConcurrentStreams 测试并发流处理。
|
||||||
|
func TestIntegrationConcurrentStreams(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCount := 0
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
requestCount++
|
||||||
|
ctx.WriteString("OK") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证并发流限制
|
||||||
|
if server.http2Server.MaxConcurrentStreams != 10 {
|
||||||
|
t.Errorf("Expected MaxConcurrentStreams 10, got %d",
|
||||||
|
server.http2Server.MaxConcurrentStreams)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegrationServerLifecycle 测试服务器生命周期。
|
||||||
|
func TestIntegrationServerLifecycle(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始状态检查
|
||||||
|
if server.IsRunning() {
|
||||||
|
t.Error("Server should not be running initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建监听器
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动服务器
|
||||||
|
go func() { _ = server.Serve(ln) }()
|
||||||
|
|
||||||
|
// 等待服务器启动
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// 停止服务器
|
||||||
|
if err := server.Stop(); err != nil {
|
||||||
|
t.Errorf("Failed to stop server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegrationAdapterConversion 测试适配器转换。
|
||||||
|
func TestIntegrationAdapterConversion(t *testing.T) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
// 设置响应头和体
|
||||||
|
ctx.Response.Header.Set("X-Custom-Header", "test-value")
|
||||||
|
ctx.WriteString("Converted response") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
|
||||||
|
// 创建标准 HTTP 请求
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
// 使用 httptest 记录响应
|
||||||
|
recorder := &testResponseRecorder{
|
||||||
|
header: make(http.Header),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行适配器
|
||||||
|
adapter.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
// 验证响应
|
||||||
|
if recorder.statusCode != http.StatusOK {
|
||||||
|
t.Errorf("Expected status %d, got %d", http.StatusOK, recorder.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if recorder.body.String() != "Converted response" {
|
||||||
|
t.Errorf("Expected body 'Converted response', got '%s'", recorder.body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testResponseRecorder 是测试用的响应记录器。
|
||||||
|
type testResponseRecorder struct {
|
||||||
|
statusCode int
|
||||||
|
header http.Header
|
||||||
|
body testBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *testResponseRecorder) Header() http.Header {
|
||||||
|
return r.header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *testResponseRecorder) Write(p []byte) (int, error) {
|
||||||
|
return r.body.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *testResponseRecorder) WriteHeader(code int) {
|
||||||
|
r.statusCode = code
|
||||||
|
}
|
||||||
|
|
||||||
|
// testBuffer 是一个简单的字节缓冲区。
|
||||||
|
type testBuffer struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBuffer) Write(p []byte) (int, error) {
|
||||||
|
b.data = append(b.data, p...)
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *testBuffer) String() string {
|
||||||
|
return string(b.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegrationTLSConfiguration 测试 TLS 配置集成。
|
||||||
|
func TestIntegrationTLSConfiguration(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 TLS 配置
|
||||||
|
if server.tlsConfig == nil {
|
||||||
|
t.Error("TLS config should be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试监听器包装
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
wrappedLn := WrapTLSListener(ln, tlsConfig)
|
||||||
|
if wrappedLn == nil {
|
||||||
|
t.Error("Wrapped listener should not be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegrationH2C 测试 H2C(明文 HTTP/2)。
|
||||||
|
func TestIntegrationH2C(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
H2CEnabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 H2C 启用
|
||||||
|
if !server.IsH2CEnabled() {
|
||||||
|
t.Error("H2C should be enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkAdapterConversion 基准测试适配器转换性能。
|
||||||
|
func BenchmarkAdapterConversion(b *testing.B) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.WriteString("Hello") //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rec.Body.Reset()
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkAdapterWithBody 基准测试带请求体的适配器。
|
||||||
|
func BenchmarkAdapterWithBody(b *testing.B) {
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.Write(ctx.PostBody()) //nolint:errcheck
|
||||||
|
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||||
|
body := []byte(`{"test":"data","number":12345}`)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api", bytes.NewReader(body))
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
adapter.ServeHTTP(rec, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkServerCreation 基准测试服务器创建。
|
||||||
|
func BenchmarkServerCreation(b *testing.B) {
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
}
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
586
internal/http2/server.go
Normal file
586
internal/http2/server.go
Normal file
@ -0,0 +1,586 @@
|
|||||||
|
// Package http2 提供 HTTP/2 协议支持。
|
||||||
|
//
|
||||||
|
// 该文件包含 HTTP/2 服务器的核心实现,包括:
|
||||||
|
// - 基于 golang.org/x/net/http2 的 HTTP/2 服务器
|
||||||
|
// - ALPN 协议协商支持
|
||||||
|
// - 与现有 fasthttp handler 的集成
|
||||||
|
// - 优雅关闭支持
|
||||||
|
//
|
||||||
|
// 主要用途:
|
||||||
|
//
|
||||||
|
// 用于在现有 TCP 监听器上提供 HTTP/2 协议支持,通过 ALPN 协商自动选择协议。
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package http2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server HTTP/2 服务器。
|
||||||
|
//
|
||||||
|
// 包装 golang.org/x/net/http2 服务器,提供与 fasthttp handler 的集成。
|
||||||
|
type Server struct {
|
||||||
|
// config HTTP/2 配置
|
||||||
|
config *config.HTTP2Config
|
||||||
|
|
||||||
|
// handler fasthttp 请求处理器
|
||||||
|
handler fasthttp.RequestHandler
|
||||||
|
|
||||||
|
// tlsConfig TLS 配置
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
|
||||||
|
// http2Server HTTP/2 服务器实例
|
||||||
|
http2Server *http2.Server
|
||||||
|
|
||||||
|
// running 服务器运行状态
|
||||||
|
running bool
|
||||||
|
|
||||||
|
// mu 读写锁
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// listener TCP 监听器
|
||||||
|
listener net.Listener
|
||||||
|
|
||||||
|
// stopChan 停止信号通道
|
||||||
|
stopChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer 创建 HTTP/2 服务器。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - cfg: HTTP/2 配置
|
||||||
|
// - handler: fasthttp 请求处理器
|
||||||
|
// - tlsConfig: TLS 配置(可选,但推荐用于 ALPN 协商)
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *Server: HTTP/2 服务器实例
|
||||||
|
// - error: 配置无效时返回错误
|
||||||
|
func NewServer(cfg *config.HTTP2Config, handler fasthttp.RequestHandler, tlsConfig *tls.Config) (*Server, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("http2 config is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if handler == nil {
|
||||||
|
return nil, fmt.Errorf("handler is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置默认值
|
||||||
|
maxConcurrentStreams := cfg.MaxConcurrentStreams
|
||||||
|
if maxConcurrentStreams <= 0 {
|
||||||
|
maxConcurrentStreams = 250
|
||||||
|
}
|
||||||
|
|
||||||
|
maxHeaderListSize := cfg.MaxHeaderListSize
|
||||||
|
if maxHeaderListSize <= 0 {
|
||||||
|
maxHeaderListSize = 1048576 // 1MB
|
||||||
|
}
|
||||||
|
|
||||||
|
idleTimeout := cfg.IdleTimeout
|
||||||
|
if idleTimeout <= 0 {
|
||||||
|
idleTimeout = 120 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建 HTTP/2 服务器
|
||||||
|
h2s := &http2.Server{
|
||||||
|
MaxConcurrentStreams: uint32(maxConcurrentStreams),
|
||||||
|
IdleTimeout: idleTimeout,
|
||||||
|
MaxReadFrameSize: uint32(maxHeaderListSize),
|
||||||
|
NewWriteScheduler: func() http2.WriteScheduler { return http2.NewPriorityWriteScheduler(nil) },
|
||||||
|
CountError: func(errType string) {},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
config: cfg,
|
||||||
|
handler: handler,
|
||||||
|
tlsConfig: tlsConfig,
|
||||||
|
http2Server: h2s,
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve 在指定监听器上启动 HTTP/2 服务器。
|
||||||
|
//
|
||||||
|
// 该方法会处理 ALPN 协议协商,根据客户端支持的协议自动选择 HTTP/2 或 HTTP/1.1。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - ln: TCP 监听器
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - error: 启动失败时返回错误
|
||||||
|
func (s *Server) Serve(ln net.Listener) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.running {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("server already running")
|
||||||
|
}
|
||||||
|
s.running = true
|
||||||
|
s.listener = ln
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
log := logging.Info()
|
||||||
|
if s.config.Enabled {
|
||||||
|
log.Str("protocol", "h2").
|
||||||
|
Bool("push", s.config.PushEnabled).
|
||||||
|
Int("max_streams", s.config.MaxConcurrentStreams).
|
||||||
|
Int("max_header_size", s.config.MaxHeaderListSize).
|
||||||
|
Str("idle_timeout", s.config.IdleTimeout.String()).
|
||||||
|
Msg("HTTP/2 server started")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动连接处理循环
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.stopChan:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-s.stopChan:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logging.Error().Err(err).Msg("HTTP/2 accept error")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.handleConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConnection 处理单个连接。
|
||||||
|
//
|
||||||
|
// 根据连接类型(TLS 或明文)和 ALPN 协商结果,选择合适的协议处理。
|
||||||
|
func (s *Server) handleConnection(conn net.Conn) {
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
// 如果是 TLS 连接,检查 ALPN 协商结果
|
||||||
|
if tlsConn, ok := conn.(*tls.Conn); ok {
|
||||||
|
// 执行 TLS 握手
|
||||||
|
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("HTTP/2 set read deadline error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("HTTP/2 TLS handshake error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("HTTP/2 clear read deadline error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 ALPN 协商结果
|
||||||
|
state := tlsConn.ConnectionState()
|
||||||
|
if len(state.NegotiatedProtocol) > 0 && state.NegotiatedProtocol != "h2" {
|
||||||
|
// ALPN 协商结果为 http/1.1 或其他,使用 fasthttp 处理
|
||||||
|
s.serveHTTP1(tlsConn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理 HTTP/2 连接
|
||||||
|
s.serveHTTP2(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serveHTTP2 使用 HTTP/2 协议服务连接。
|
||||||
|
func (s *Server) serveHTTP2(conn net.Conn) {
|
||||||
|
adapter := NewFastHTTPHandlerAdapter(s.handler)
|
||||||
|
|
||||||
|
opts := &http2.ServeConnOpts{
|
||||||
|
Context: context.Background(),
|
||||||
|
Handler: adapter,
|
||||||
|
BaseConfig: &http.Server{},
|
||||||
|
}
|
||||||
|
|
||||||
|
s.http2Server.ServeConn(conn, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serveHTTP1 使用 HTTP/1.1 协议服务连接(回退到 fasthttp)。
|
||||||
|
func (s *Server) serveHTTP1(conn net.Conn) {
|
||||||
|
// 创建一个简单的 fasthttp 服务器来处理单个连接
|
||||||
|
server := &fasthttp.Server{
|
||||||
|
Handler: s.handler,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 fasthttp 的连接处理
|
||||||
|
_ = server.ServeConn(conn) //nolint:errcheck // HTTP/1.1 回退连接处理错误由内部处理
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止 HTTP/2 服务器。
|
||||||
|
//
|
||||||
|
// 优雅关闭服务器,等待现有连接完成。
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - error: 关闭失败时返回错误
|
||||||
|
func (s *Server) Stop() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if !s.running {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.running = false
|
||||||
|
|
||||||
|
// 发送停止信号
|
||||||
|
close(s.stopChan)
|
||||||
|
|
||||||
|
// 关闭监听器
|
||||||
|
if s.listener != nil {
|
||||||
|
if err := s.listener.Close(); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("HTTP/2 listener close error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logging.Info().Msg("HTTP/2 server stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning 检查服务器是否正在运行。
|
||||||
|
func (s *Server) IsRunning() bool {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.running
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig 返回服务器配置。
|
||||||
|
func (s *Server) GetConfig() *config.HTTP2Config {
|
||||||
|
return s.config
|
||||||
|
}
|
||||||
|
|
||||||
|
// ALPNConfig 返回用于 ALPN 协商的 TLS 配置。
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - *tls.Config: 配置了 ALPN 的 TLS 配置
|
||||||
|
//
|
||||||
|
// 使用示例:
|
||||||
|
//
|
||||||
|
// tlsConfig := &tls.Config{
|
||||||
|
// Certificates: []tls.Certificate{cert},
|
||||||
|
// }
|
||||||
|
// tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||||
|
func (s *Server) ALPNConfig() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapTLSListener 包装 TLS 监听器以支持 ALPN 协议协商。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - ln: 底层 TCP 监听器
|
||||||
|
// - tlsConfig: TLS 配置(会被修改以添加 ALPN 支持)
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - net.Listener: 支持 ALPN 的 TLS 监听器
|
||||||
|
func WrapTLSListener(ln net.Listener, tlsConfig *tls.Config) net.Listener {
|
||||||
|
// 确保 NextProtos 包含 h2 和 http/1.1
|
||||||
|
if len(tlsConfig.NextProtos) == 0 {
|
||||||
|
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 GetConfigForClient 根据客户端支持的协议返回不同的配置
|
||||||
|
originalGetConfig := tlsConfig.GetConfigForClient
|
||||||
|
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
// 检查客户端是否支持 h2
|
||||||
|
supportsH2 := false
|
||||||
|
for _, proto := range hello.SupportedProtos {
|
||||||
|
if proto == "h2" {
|
||||||
|
supportsH2 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有原始回调,先调用它
|
||||||
|
var cfg *tls.Config
|
||||||
|
if originalGetConfig != nil {
|
||||||
|
var err error
|
||||||
|
cfg, err = originalGetConfig(hello)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果客户端支持 h2,设置协商结果为 h2
|
||||||
|
if supportsH2 {
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = tlsConfig.Clone()
|
||||||
|
}
|
||||||
|
cfg.NextProtos = []string{"h2"}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return tls.NewListener(ln, tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsH2CEnabled 检查是否启用了 H2C(HTTP/2 over cleartext)。
|
||||||
|
//
|
||||||
|
// 注意:当前版本不支持 H2C,需要 TLS 才能启用 HTTP/2。
|
||||||
|
func (s *Server) IsH2CEnabled() bool {
|
||||||
|
return s.config.H2CEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleH2C 处理 H2C 升级请求。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - conn: TCP 连接
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - bool: 如果成功处理 H2C 升级返回 true
|
||||||
|
// - error: 处理失败时返回错误
|
||||||
|
func (s *Server) HandleH2C(conn net.Conn) (bool, error) {
|
||||||
|
// HTTP/2 需要 TLS,不支持 H2C
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unused: h2cConn and related code kept for potential H2C support in future
|
||||||
|
var _ = h2cConn{} //nolint:unused // reserved for future H2C support
|
||||||
|
|
||||||
|
// h2cConn 包装 net.Conn 以支持 H2C 协议检测。
|
||||||
|
type h2cConn struct {
|
||||||
|
net.Conn
|
||||||
|
reader *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read 从连接读取数据。
|
||||||
|
func (c *h2cConn) Read(p []byte) (n int, err error) { //nolint:unused // reserved for future H2C support
|
||||||
|
if c.reader != nil {
|
||||||
|
n, err = c.reader.Read(p)
|
||||||
|
if err == io.EOF && n > 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
if err != nil || n < len(p) {
|
||||||
|
c.reader = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return c.Conn.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsHTTP2Request 检查请求是否是 HTTP/2。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - r: HTTP 请求
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - bool: 如果是 HTTP/2 请求返回 true
|
||||||
|
func IsHTTP2Request(r *http.Request) bool {
|
||||||
|
// HTTP/2 请求通常使用 "PRI" 方法或 HTTP 版本为 2
|
||||||
|
if r.Method == "PRI" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if r.ProtoMajor == 2 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// 检查 HTTP/2 特定的头
|
||||||
|
if r.Header.Get(":method") != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetALPNProtocol 从 TLS 连接状态获取协商的协议。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - conn: 网络连接
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - string: 协商的协议(如 "h2", "http/1.1"),如果不是 TLS 返回空字符串
|
||||||
|
func GetALPNProtocol(conn net.Conn) string {
|
||||||
|
tlsConn, ok := conn.(*tls.Conn)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
state := tlsConn.ConnectionState()
|
||||||
|
return state.NegotiatedProtocol
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportsHTTP2 检查客户端是否支持 HTTP/2(基于 ALPN 或升级头)。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - r: HTTP 请求
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - bool: 如果支持 HTTP/2 返回 true
|
||||||
|
func SupportsHTTP2(r *http.Request) bool {
|
||||||
|
// 检查是否是 HTTP/2 请求
|
||||||
|
if IsHTTP2Request(r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查升级头
|
||||||
|
if r.Header.Get("Upgrade") == "h2c" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 HTTP2-Settings 头
|
||||||
|
if r.Header.Get("HTTP2-Settings") != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP2Settings HTTP/2 连接设置。
|
||||||
|
type HTTP2Settings struct {
|
||||||
|
HeaderTableSize uint32 // SETTINGS_HEADER_TABLE_SIZE
|
||||||
|
EnablePush bool // SETTINGS_ENABLE_PUSH
|
||||||
|
MaxConcurrentStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS
|
||||||
|
InitialWindowSize uint32 // SETTINGS_INITIAL_WINDOW_SIZE
|
||||||
|
MaxFrameSize uint32 // SETTINGS_MAX_FRAME_SIZE
|
||||||
|
MaxHeaderListSize uint32 // SETTINGS_MAX_HEADER_LIST_SIZE
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultHTTP2Settings 返回默认 HTTP/2 设置。
|
||||||
|
func DefaultHTTP2Settings() HTTP2Settings {
|
||||||
|
return HTTP2Settings{
|
||||||
|
HeaderTableSize: 4096,
|
||||||
|
EnablePush: true,
|
||||||
|
MaxConcurrentStreams: 250,
|
||||||
|
InitialWindowSize: 65535,
|
||||||
|
MaxFrameSize: 16384,
|
||||||
|
MaxHeaderListSize: 1048576,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateHTTP2Settings 验证 HTTP/2 设置的有效性。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - settings: HTTP/2 设置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - error: 设置无效时返回错误
|
||||||
|
func ValidateHTTP2Settings(settings HTTP2Settings) error {
|
||||||
|
if settings.MaxConcurrentStreams == 0 {
|
||||||
|
return errors.New("max concurrent streams cannot be zero")
|
||||||
|
}
|
||||||
|
if settings.MaxFrameSize < 16384 || settings.MaxFrameSize > 16777215 {
|
||||||
|
return errors.New("max frame size must be between 16384 and 16777215")
|
||||||
|
}
|
||||||
|
if settings.InitialWindowSize > 2147483647 {
|
||||||
|
return errors.New("initial window size cannot exceed 2^31-1")
|
||||||
|
}
|
||||||
|
if settings.MaxHeaderListSize == 0 {
|
||||||
|
return errors.New("max header list size cannot be zero")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseHTTP2Settings 从配置解析 HTTP/2 设置。
|
||||||
|
//
|
||||||
|
// 参数:
|
||||||
|
// - cfg: HTTP/2 配置
|
||||||
|
//
|
||||||
|
// 返回值:
|
||||||
|
// - HTTP2Settings: 解析后的 HTTP/2 设置
|
||||||
|
func ParseHTTP2Settings(cfg *config.HTTP2Config) HTTP2Settings {
|
||||||
|
settings := DefaultHTTP2Settings()
|
||||||
|
|
||||||
|
if cfg.MaxConcurrentStreams > 0 {
|
||||||
|
settings.MaxConcurrentStreams = uint32(cfg.MaxConcurrentStreams)
|
||||||
|
}
|
||||||
|
if cfg.MaxHeaderListSize > 0 {
|
||||||
|
settings.MaxHeaderListSize = uint32(cfg.MaxHeaderListSize)
|
||||||
|
}
|
||||||
|
settings.EnablePush = cfg.PushEnabled
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectionPool HTTP/2 连接池。
|
||||||
|
type connectionPool struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
conns map[string][]net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// newConnectionPool 创建新的连接池。
|
||||||
|
func newConnectionPool() *connectionPool {
|
||||||
|
return &connectionPool{
|
||||||
|
conns: make(map[string][]net.Conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add 添加连接。
|
||||||
|
func (p *connectionPool) add(key string, conn net.Conn) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
p.conns[key] = append(p.conns[key], conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove 移除连接。
|
||||||
|
func (p *connectionPool) remove(key string, conn net.Conn) {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
conns := p.conns[key]
|
||||||
|
for i, c := range conns {
|
||||||
|
if c == conn {
|
||||||
|
p.conns[key] = append(conns[:i], conns[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get 获取连接。
|
||||||
|
func (p *connectionPool) get(key string) []net.Conn {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
return p.conns[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
// count 获取连接数。
|
||||||
|
func (p *connectionPool) count(key string) int {
|
||||||
|
p.mu.RLock()
|
||||||
|
defer p.mu.RUnlock()
|
||||||
|
return len(p.conns[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeAll 关闭所有连接。
|
||||||
|
func (p *connectionPool) closeAll() { //nolint:unused // reserved for future use
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
for _, conns := range p.conns {
|
||||||
|
for _, conn := range conns {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.conns = make(map[string][]net.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// canonicalHeaderKey 返回规范化的 HTTP 头键。
|
||||||
|
func canonicalHeaderKey(key string) string {
|
||||||
|
// 使用 strings 包实现规范化
|
||||||
|
result := strings.ToLower(key)
|
||||||
|
if result == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.ToUpper(result[:1]) + result[1:]
|
||||||
|
}
|
||||||
456
internal/http2/server_test.go
Normal file
456
internal/http2/server_test.go
Normal file
@ -0,0 +1,456 @@
|
|||||||
|
// Package http2 提供 HTTP/2 服务器测试。
|
||||||
|
//
|
||||||
|
// 该文件包含 HTTP/2 服务器的单元测试和集成测试:
|
||||||
|
// - 服务器创建和配置测试
|
||||||
|
// - ALPN 协议协商测试
|
||||||
|
// - HTTP/1.1 fallback 测试
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package http2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewServer 测试 HTTP/2 服务器创建。
|
||||||
|
func TestNewServer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg *config.HTTP2Config
|
||||||
|
handler fasthttp.RequestHandler
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "有效配置",
|
||||||
|
cfg: &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 128,
|
||||||
|
MaxHeaderListSize: 1048576,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
PushEnabled: false,
|
||||||
|
H2CEnabled: false,
|
||||||
|
},
|
||||||
|
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||||
|
tlsConfig: nil,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "默认配置",
|
||||||
|
cfg: &config.HTTP2Config{},
|
||||||
|
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil配置",
|
||||||
|
cfg: nil,
|
||||||
|
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil handler",
|
||||||
|
cfg: &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
handler: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "自定义并发流数量",
|
||||||
|
cfg: &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 256,
|
||||||
|
},
|
||||||
|
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server, err := NewServer(tt.cfg, tt.handler, tt.tlsConfig)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("NewServer() expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewServer() unexpected error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if server == nil {
|
||||||
|
t.Error("NewServer() returned nil server")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证配置正确应用
|
||||||
|
if server.config != tt.cfg {
|
||||||
|
t.Error("NewServer() config not set correctly")
|
||||||
|
}
|
||||||
|
if server.handler == nil {
|
||||||
|
t.Error("NewServer() handler not set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerDefaultValues 测试服务器默认值。
|
||||||
|
func TestServerDefaultValues(t *testing.T) {
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {}
|
||||||
|
|
||||||
|
server, err := NewServer(cfg, handler, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证默认并发流数量
|
||||||
|
if server.http2Server.MaxConcurrentStreams == 0 {
|
||||||
|
t.Error("Expected default MaxConcurrentStreams to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证默认空闲超时
|
||||||
|
if server.http2Server.IdleTimeout == 0 {
|
||||||
|
t.Error("Expected default IdleTimeout to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerIsRunning 测试服务器运行状态。
|
||||||
|
func TestServerIsRunning(t *testing.T) {
|
||||||
|
cfg := &config.HTTP2Config{Enabled: true}
|
||||||
|
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始状态应为未运行
|
||||||
|
if server.IsRunning() {
|
||||||
|
t.Error("New server should not be running")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerGetConfig 测试获取服务器配置。
|
||||||
|
func TestServerGetConfig(t *testing.T) {
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
}
|
||||||
|
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotCfg := server.GetConfig()
|
||||||
|
if gotCfg != cfg {
|
||||||
|
t.Error("GetConfig() returned wrong config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestALPNConfig 测试 ALPN 配置。
|
||||||
|
func TestALPNConfig(t *testing.T) {
|
||||||
|
cfg := &config.HTTP2Config{Enabled: true}
|
||||||
|
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsCfg := server.ALPNConfig()
|
||||||
|
if tlsCfg == nil {
|
||||||
|
t.Fatal("ALPNConfig() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 ALPN 协议包含 h2 和 http/1.1
|
||||||
|
foundH2 := false
|
||||||
|
foundHTTP11 := false
|
||||||
|
for _, proto := range tlsCfg.NextProtos {
|
||||||
|
if proto == "h2" {
|
||||||
|
foundH2 = true
|
||||||
|
}
|
||||||
|
if proto == "http/1.1" {
|
||||||
|
foundHTTP11 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundH2 {
|
||||||
|
t.Error("ALPN config missing 'h2' protocol")
|
||||||
|
}
|
||||||
|
if !foundHTTP11 {
|
||||||
|
t.Error("ALPN config missing 'http/1.1' protocol")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWrapTLSListener 测试 TLS 监听器包装。
|
||||||
|
func TestWrapTLSListener(t *testing.T) {
|
||||||
|
// 创建测试监听器
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
// 创建 TLS 配置
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
NextProtos: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 包装监听器
|
||||||
|
wrappedLn := WrapTLSListener(ln, tlsConfig)
|
||||||
|
if wrappedLn == nil {
|
||||||
|
t.Fatal("WrapTLSListener() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证 ALPN 协议已设置
|
||||||
|
if len(tlsConfig.NextProtos) == 0 {
|
||||||
|
t.Error("WrapTLSListener should set NextProtos")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsH2CEnabled 测试 H2C 启用检查。
|
||||||
|
func TestIsH2CEnabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
h2cEnabled bool
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "H2C 启用",
|
||||||
|
h2cEnabled: true,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "H2C 禁用",
|
||||||
|
h2cEnabled: false,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
H2CEnabled: tt.h2cEnabled,
|
||||||
|
}
|
||||||
|
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := server.IsH2CEnabled(); got != tt.want {
|
||||||
|
t.Errorf("IsH2CEnabled() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsHTTP2Request 测试 HTTP/2 请求检测。
|
||||||
|
func TestIsHTTP2Request(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
method string
|
||||||
|
major int
|
||||||
|
header map[string]string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "PRI 方法",
|
||||||
|
method: "PRI",
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP/2 版本",
|
||||||
|
major: 2,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP/1.1",
|
||||||
|
method: "GET",
|
||||||
|
major: 1,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// 这里只测试基本的逻辑,完整测试需要创建 http.Request
|
||||||
|
// 在实际集成测试中会覆盖
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTP2Settings 测试 HTTP/2 设置。
|
||||||
|
func TestHTTP2Settings(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
settings HTTP2Settings
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "默认设置",
|
||||||
|
settings: HTTP2Settings{
|
||||||
|
HeaderTableSize: 4096,
|
||||||
|
EnablePush: true,
|
||||||
|
MaxConcurrentStreams: 250,
|
||||||
|
InitialWindowSize: 65535,
|
||||||
|
MaxFrameSize: 16384,
|
||||||
|
MaxHeaderListSize: 1048576,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "零并发流",
|
||||||
|
settings: HTTP2Settings{
|
||||||
|
MaxConcurrentStreams: 0,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效帧大小",
|
||||||
|
settings: HTTP2Settings{
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
MaxFrameSize: 1024, // 小于最小值 16384
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "帧大小过大",
|
||||||
|
settings: HTTP2Settings{
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
MaxFrameSize: 16777216, // 超过最大值 16777215
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "零头部列表大小",
|
||||||
|
settings: HTTP2Settings{
|
||||||
|
MaxConcurrentStreams: 100,
|
||||||
|
MaxHeaderListSize: 0,
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := ValidateHTTP2Settings(tt.settings)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("ValidateHTTP2Settings() expected error, got nil")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("ValidateHTTP2Settings() unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultHTTP2Settings 测试默认 HTTP/2 设置。
|
||||||
|
func TestDefaultHTTP2Settings(t *testing.T) {
|
||||||
|
settings := DefaultHTTP2Settings()
|
||||||
|
|
||||||
|
if settings.HeaderTableSize == 0 {
|
||||||
|
t.Error("Default HeaderTableSize should not be zero")
|
||||||
|
}
|
||||||
|
if settings.MaxConcurrentStreams == 0 {
|
||||||
|
t.Error("Default MaxConcurrentStreams should not be zero")
|
||||||
|
}
|
||||||
|
if settings.InitialWindowSize == 0 {
|
||||||
|
t.Error("Default InitialWindowSize should not be zero")
|
||||||
|
}
|
||||||
|
if settings.MaxFrameSize == 0 {
|
||||||
|
t.Error("Default MaxFrameSize should not be zero")
|
||||||
|
}
|
||||||
|
if settings.MaxHeaderListSize == 0 {
|
||||||
|
t.Error("Default MaxHeaderListSize should not be zero")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseHTTP2Settings 测试从配置解析 HTTP/2 设置。
|
||||||
|
func TestParseHTTP2Settings(t *testing.T) {
|
||||||
|
cfg := &config.HTTP2Config{
|
||||||
|
Enabled: true,
|
||||||
|
MaxConcurrentStreams: 200,
|
||||||
|
MaxHeaderListSize: 2097152, // 2MB
|
||||||
|
PushEnabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := ParseHTTP2Settings(cfg)
|
||||||
|
|
||||||
|
if settings.MaxConcurrentStreams != 200 {
|
||||||
|
t.Errorf("ParseHTTP2Settings() MaxConcurrentStreams = %d, want 200", settings.MaxConcurrentStreams)
|
||||||
|
}
|
||||||
|
if settings.MaxHeaderListSize != 2097152 {
|
||||||
|
t.Errorf("ParseHTTP2Settings() MaxHeaderListSize = %d, want 2097152", settings.MaxHeaderListSize)
|
||||||
|
}
|
||||||
|
if !settings.EnablePush {
|
||||||
|
t.Error("ParseHTTP2Settings() EnablePush should be true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnectionPool 测试连接池。
|
||||||
|
func TestConnectionPool(t *testing.T) {
|
||||||
|
pool := newConnectionPool()
|
||||||
|
|
||||||
|
// 创建测试连接
|
||||||
|
ln1, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
defer func() { _ = ln1.Close() }()
|
||||||
|
|
||||||
|
ln2, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
defer func() { _ = ln2.Close() }()
|
||||||
|
|
||||||
|
// 测试添加连接
|
||||||
|
conn1, _ := net.Dial("tcp", ln1.Addr().String())
|
||||||
|
if conn1 != nil {
|
||||||
|
defer func() { _ = conn1.Close() }()
|
||||||
|
pool.add("key1", conn1)
|
||||||
|
|
||||||
|
// 测试获取连接
|
||||||
|
conns := pool.get("key1")
|
||||||
|
if len(conns) != 1 {
|
||||||
|
t.Errorf("Expected 1 connection, got %d", len(conns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试计数
|
||||||
|
if count := pool.count("key1"); count != 1 {
|
||||||
|
t.Errorf("Expected count 1, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试移除连接
|
||||||
|
pool.remove("key1", conn1)
|
||||||
|
if count := pool.count("key1"); count != 0 {
|
||||||
|
t.Errorf("Expected count 0 after remove, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCanonicalHeaderKey 测试规范化头部键。
|
||||||
|
func TestCanonicalHeaderKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"content-type", "Content-Type"},
|
||||||
|
{"CONTENT-TYPE", "Content-Type"},
|
||||||
|
{"Content-Type", "Content-Type"},
|
||||||
|
{"x-custom-header", "X-Custom-Header"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got := canonicalHeaderKey(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("canonicalHeaderKey(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -108,6 +108,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
|||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
MinVersion: tls.VersionTLS12, // 强制 TLS 1.2 最低版本
|
MinVersion: tls.VersionTLS12, // 强制 TLS 1.2 最低版本
|
||||||
MaxVersion: tls.VersionTLS13,
|
MaxVersion: tls.VersionTLS13,
|
||||||
|
NextProtos: []string{"h2", "http/1.1"}, // 启用 HTTP/2 ALPN 支持
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用 TLS 1.2 的加密套件
|
// 应用 TLS 1.2 的加密套件
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user