feat(http2): 新增 HTTP/2 支持,集成到服务器和应用
This commit is contained in:
parent
42533c31d2
commit
412bfebdd8
@ -26,11 +26,13 @@ import (
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/http2"
|
||||
"rua.plus/lolly/internal/http3"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/resolver"
|
||||
"rua.plus/lolly/internal/server"
|
||||
"rua.plus/lolly/internal/stream"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
// 版本信息,通过 -ldflags 注入。
|
||||
@ -72,6 +74,9 @@ type App struct {
|
||||
// http3Srv HTTP/3 服务器实例(可选)
|
||||
http3Srv *http3.Server
|
||||
|
||||
// http2Srv HTTP/2 服务器实例(可选)
|
||||
http2Srv *http2.Server
|
||||
|
||||
// streamSrv Stream 服务器实例(可选)
|
||||
streamSrv *stream.Server
|
||||
|
||||
@ -167,6 +172,14 @@ func (a *App) Run() int {
|
||||
a.cfg = cfg
|
||||
a.logger = logging.NewAppLogger(&cfg.Logging)
|
||||
|
||||
// 设置全局变量
|
||||
variable.SetGlobalVariables(cfg.Variables.Set)
|
||||
if len(cfg.Variables.Set) > 0 {
|
||||
a.logger.LogStartup("全局变量已加载", map[string]string{
|
||||
"count": fmt.Sprintf("%d", len(cfg.Variables.Set)),
|
||||
})
|
||||
}
|
||||
|
||||
// 检查是否是子进程(热升级)
|
||||
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
|
||||
a.logger.LogStartup("检测到热升级模式,继承父进程监听器", nil)
|
||||
@ -263,6 +276,38 @@ func (a *App) Run() int {
|
||||
}
|
||||
}
|
||||
|
||||
// 创建并启动 HTTP/2 服务器(如果启用且配置了 TLS)
|
||||
if a.cfg.Server.SSL.HTTP2.Enabled && a.cfg.Server.SSL.Cert != "" {
|
||||
tlsConfig, err := a.srv.GetTLSConfig()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2")
|
||||
} else {
|
||||
// 创建 HTTP/2 服务器,共享同一个 handler
|
||||
a.http2Srv, err = http2.NewServer(&a.cfg.Server.SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败")
|
||||
} else {
|
||||
go func() {
|
||||
a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{
|
||||
"listen": a.cfg.Server.Listen,
|
||||
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Server.SSL.HTTP2.MaxConcurrentStreams),
|
||||
"push_enabled": fmt.Sprintf("%t", a.cfg.Server.SSL.HTTP2.PushEnabled),
|
||||
})
|
||||
// HTTP/2 服务器使用与主服务器相同的监听器
|
||||
// 通过 ALPN 协商自动处理协议选择
|
||||
listeners := a.srv.GetListeners()
|
||||
if len(listeners) > 0 {
|
||||
if err := a.http2Srv.Serve(listeners[0]); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败")
|
||||
}
|
||||
} else {
|
||||
a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建升级管理器
|
||||
a.upgradeMgr = server.NewUpgradeManager(a.srv)
|
||||
if a.pidFile != "" {
|
||||
@ -318,6 +363,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
case syscall.SIGQUIT:
|
||||
// 优雅停止:等待请求完成
|
||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v)", shutdownTimeout))
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(shutdownTimeout)
|
||||
return false
|
||||
@ -325,6 +371,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
case syscall.SIGTERM, syscall.SIGINT:
|
||||
// 快速停止
|
||||
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "停止服务器")
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.Stop()
|
||||
return false
|
||||
@ -362,6 +409,15 @@ func (a *App) shutdownHTTP3() {
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownHTTP2 关闭 HTTP/2 服务器。
|
||||
func (a *App) shutdownHTTP2() {
|
||||
if a.http2Srv != nil {
|
||||
if err := a.http2Srv.Stop(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reloadConfig 重载配置。
|
||||
func (a *App) reloadConfig() {
|
||||
newCfg, err := config.Load(a.cfgPath)
|
||||
@ -415,6 +471,7 @@ func (a *App) gracefulUpgrade() {
|
||||
a.logger.LogStartup("热升级已启动,新进程正在接管", nil)
|
||||
|
||||
// 当前进程优雅停止
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(shutdownTimeout)
|
||||
}
|
||||
|
||||
@ -81,6 +81,78 @@ type Config struct {
|
||||
// Resolver DNS 解析器配置
|
||||
// 启用动态 DNS 解析和缓存
|
||||
Resolver ResolverConfig `yaml:"resolver"`
|
||||
|
||||
// Variables 自定义变量配置
|
||||
// 全局变量定义,应用于所有虚拟主机
|
||||
Variables VariablesConfig `yaml:"variables"`
|
||||
}
|
||||
|
||||
// VariablesConfig 自定义变量配置。
|
||||
//
|
||||
// 用于定义全局自定义变量,可在日志格式和请求头中引用。
|
||||
// 变量作用于所有虚拟主机。
|
||||
//
|
||||
// 注意事项:
|
||||
// - 变量名只允许字母、数字、下划线
|
||||
// - 变量名不能与内置变量冲突
|
||||
// - 变量名不能以 arg_、http_、cookie_ 开头(动态变量前缀)
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// variables:
|
||||
// set:
|
||||
// app_name: "lolly"
|
||||
// version: "1.0.0"
|
||||
type VariablesConfig struct {
|
||||
// Set 自定义变量集合
|
||||
// 键值对形式,可在日志格式和请求头模板中使用 $var_name 引用
|
||||
Set map[string]string `yaml:"set"`
|
||||
}
|
||||
|
||||
// HTTP2Config HTTP/2 配置。
|
||||
//
|
||||
// HTTP/2 提供多路复用、头部压缩和服务器推送等功能,
|
||||
// 需要服务器配置 SSL/TLS 证书才能正常工作。
|
||||
//
|
||||
// 注意事项:
|
||||
// - 必须配置有效的 SSL 证书(TLS 1.2 或更高版本)
|
||||
// - http2.enabled 仅在配置了 SSL/TLS 时生效
|
||||
// - 客户端可以通过 ALPN 协商使用 HTTP/2 或 HTTP/1.1
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// server:
|
||||
// ssl:
|
||||
// cert: "/etc/ssl/server.crt"
|
||||
// key: "/etc/ssl/server.key"
|
||||
// http2:
|
||||
// enabled: true
|
||||
// max_concurrent_streams: 128
|
||||
// max_header_list_size: "16KB"
|
||||
type HTTP2Config struct {
|
||||
// Enabled 是否启用 HTTP/2
|
||||
// 默认为 true,但仅在配置了 SSL 时生效
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// MaxConcurrentStreams 最大并发流
|
||||
// 控制单个连接允许的最大并发流数量,默认 128
|
||||
MaxConcurrentStreams int `yaml:"max_concurrent_streams"`
|
||||
|
||||
// MaxHeaderListSize 最大头部列表大小(字节)
|
||||
// 限制请求和响应头部的大小,默认 1MB (1048576)
|
||||
MaxHeaderListSize int `yaml:"max_header_list_size"`
|
||||
|
||||
// IdleTimeout 空闲超时
|
||||
// 连接无活动时的最大保持时间,默认 120s
|
||||
IdleTimeout time.Duration `yaml:"idle_timeout"`
|
||||
|
||||
// PushEnabled 是否启用 Server Push
|
||||
// 默认 false
|
||||
PushEnabled bool `yaml:"push_enabled"`
|
||||
|
||||
// H2CEnabled 是否启用 H2C(明文 HTTP/2)
|
||||
// 默认 false,需要 Enabled 为 true 才生效
|
||||
H2CEnabled bool `yaml:"h2c_enabled"`
|
||||
}
|
||||
|
||||
// HTTP3Config HTTP/3 (QUIC) 配置。
|
||||
@ -546,6 +618,10 @@ type SSLConfig struct {
|
||||
// 启用 TLS 1.3 会话恢复以提升握手性能
|
||||
SessionTickets SessionTicketsConfig `yaml:"session_tickets"`
|
||||
|
||||
// HTTP2 HTTP/2 配置
|
||||
// 启用 HTTP/2 支持,仅在配置了 SSL/TLS 时生效
|
||||
HTTP2 HTTP2Config `yaml:"http2"`
|
||||
|
||||
// ClientVerify 客户端证书验证配置
|
||||
// 启用 mTLS 双向认证
|
||||
ClientVerify ClientVerifyConfig `yaml:"client_verify"`
|
||||
@ -841,6 +917,10 @@ type AuthConfig struct {
|
||||
// Realm 认证域
|
||||
// 显示在浏览器认证对话框中的描述信息
|
||||
Realm string `yaml:"realm"`
|
||||
// MinPasswordLength 密码最小长度
|
||||
// 用于验证密码哈希对应的原始密码长度(仅提示性验证)
|
||||
// 建议值:8-128,默认 8
|
||||
MinPasswordLength int `yaml:"min_password_length"`
|
||||
}
|
||||
|
||||
// User 认证用户配置。
|
||||
@ -1727,6 +1807,11 @@ func Validate(cfg *Config) error {
|
||||
return fmt.Errorf("resolver: %w", err)
|
||||
}
|
||||
|
||||
// 验证变量配置
|
||||
if err := validateVariables(&cfg.Variables); err != nil {
|
||||
return fmt.Errorf("variables: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -58,6 +58,14 @@ func DefaultConfig() *Config {
|
||||
IncludeSubDomains: true,
|
||||
Preload: false,
|
||||
},
|
||||
HTTP2: HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 128,
|
||||
MaxHeaderListSize: 1048576, // 1MB
|
||||
IdleTimeout: 120 * time.Second,
|
||||
PushEnabled: false,
|
||||
H2CEnabled: false,
|
||||
},
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Access: AccessConfig{
|
||||
@ -75,9 +83,10 @@ func DefaultConfig() *Config {
|
||||
SlidingWindow: 60,
|
||||
},
|
||||
Auth: AuthConfig{
|
||||
RequireTLS: true,
|
||||
Algorithm: "bcrypt",
|
||||
Realm: "Restricted Area",
|
||||
RequireTLS: true,
|
||||
Algorithm: "bcrypt",
|
||||
Realm: "Restricted Area",
|
||||
MinPasswordLength: 8,
|
||||
},
|
||||
Headers: SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
@ -148,6 +157,9 @@ func DefaultConfig() *Config {
|
||||
IdleTimeout: 60 * time.Second,
|
||||
Enable0RTT: false,
|
||||
},
|
||||
Variables: VariablesConfig{
|
||||
Set: map[string]string{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
// validateServer 验证服务器配置。
|
||||
@ -246,6 +247,7 @@ func validateProxy(p *ProxyConfig) error {
|
||||
// validateSSL 验证 SSL 配置。
|
||||
//
|
||||
// 检查 SSL 证书、私钥、TLS 协议版本和加密套件的有效性。
|
||||
// 同时验证 HTTP/2 配置的有效性。
|
||||
//
|
||||
// 参数:
|
||||
// - s: SSL 配置对象
|
||||
@ -257,7 +259,13 @@ func validateProxy(p *ProxyConfig) error {
|
||||
// - cert 和 key 必须同时配置或同时为空
|
||||
// - TLS 协议仅允许 TLSv1.2 和 TLSv1.3
|
||||
// - 拒绝不安全的加密套件(RC4、DES、3DES、CBC)
|
||||
// - HTTP/2 配置仅在配置了 SSL 时生效
|
||||
func validateSSL(s *SSLConfig) error {
|
||||
// 验证 HTTP/2 配置
|
||||
if err := validateHTTP2(&s.HTTP2, s.Cert != "" && s.Key != ""); err != nil {
|
||||
return fmt.Errorf("http2: %w", err)
|
||||
}
|
||||
|
||||
// 未配置 SSL 时跳过验证
|
||||
if s.Cert == "" && s.Key == "" {
|
||||
return nil
|
||||
@ -422,6 +430,14 @@ func validateAuth(a *AuthConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证密码最小长度配置合理性
|
||||
if a.MinPasswordLength > 0 && a.MinPasswordLength < 6 {
|
||||
return fmt.Errorf("min_password_length 建议至少为 6")
|
||||
}
|
||||
if a.MinPasswordLength > 128 {
|
||||
return fmt.Errorf("min_password_length 上限为 128")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -497,6 +513,46 @@ func validateRateLimit(r *RateLimitConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHTTP2 验证 HTTP/2 配置。
|
||||
//
|
||||
// 检查 HTTP/2 配置的有效性,包括并发流数量和头部大小限制。
|
||||
//
|
||||
// 参数:
|
||||
// - h: HTTP/2 配置对象
|
||||
// - hasSSL: 是否配置了 SSL/TLS
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 验证失败时返回错误信息,成功返回 nil
|
||||
//
|
||||
// 验证规则:
|
||||
// - http2.enabled 仅在配置了 SSL 时生效(HTTP/2 over TLS)
|
||||
// - max_concurrent_streams 必须大于 0
|
||||
// - max_header_list_size 必须是一个有效的字节大小(如 "16KB", "1MB")或空
|
||||
func validateHTTP2(h *HTTP2Config, hasSSL bool) error {
|
||||
// HTTP/2 配置在 HTTPS 下才有效(除非启用 H2C)
|
||||
if h.Enabled && !hasSSL && !h.H2CEnabled {
|
||||
// HTTP/2 需要 TLS(h2),明文 HTTP/2(h2c)需要单独启用
|
||||
return errors.New("HTTP/2 需要配置 SSL/TLS 证书(http2.enabled 仅在配置 SSL 时生效,或启用 h2c_enabled)")
|
||||
}
|
||||
|
||||
// 验证并发流数量
|
||||
if h.MaxConcurrentStreams < 0 {
|
||||
return errors.New("max_concurrent_streams 不能为负数")
|
||||
}
|
||||
|
||||
// 验证头部大小限制
|
||||
if h.MaxHeaderListSize < 0 {
|
||||
return errors.New("max_header_list_size 不能为负数")
|
||||
}
|
||||
|
||||
// 验证空闲超时
|
||||
if h.IdleTimeout < 0 {
|
||||
return errors.New("idle_timeout 不能为负数")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCompression 验证压缩配置。
|
||||
//
|
||||
// 检查压缩类型、压缩级别和最小压缩大小的有效性。
|
||||
@ -767,3 +823,109 @@ func validateNextUpstream(n *NextUpstreamConfig) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateVariables 验证自定义变量配置。
|
||||
//
|
||||
// 检查变量名的有效性和冲突情况。
|
||||
//
|
||||
// 参数:
|
||||
// - v: 变量配置对象
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 验证失败时返回错误信息,成功返回 nil
|
||||
//
|
||||
// 验证规则:
|
||||
// - 变量名不能为空
|
||||
// - 变量名只允许字母、数字、下划线
|
||||
// - 变量名不能以 arg_、http_、cookie_ 开头(动态变量前缀)
|
||||
// - 变量名不能与内置变量冲突
|
||||
func validateVariables(v *VariablesConfig) error {
|
||||
for name := range v.Set {
|
||||
// 检查变量名非空
|
||||
if name == "" {
|
||||
return errors.New("变量名不能为空")
|
||||
}
|
||||
|
||||
// 变量名只允许字母、数字、下划线
|
||||
for i, c := range name {
|
||||
isLetter := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
|
||||
isDigit := c >= '0' && c <= '9'
|
||||
isUnderscore := c == '_'
|
||||
if !isLetter && !isDigit && !isUnderscore {
|
||||
return fmt.Errorf("变量名 '%s' 包含非法字符(位置 %d)", name, i)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查动态变量前缀冲突
|
||||
if strings.HasPrefix(name, "arg_") || strings.HasPrefix(name, "http_") || strings.HasPrefix(name, "cookie_") {
|
||||
return fmt.Errorf("变量名 '%s' 与动态变量前缀冲突(arg_, http_, cookie_)", name)
|
||||
}
|
||||
|
||||
// 禁止覆盖内置变量
|
||||
if variable.GetBuiltin(name) != nil {
|
||||
return fmt.Errorf("变量名 '%s' 与内置变量冲突", name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseSize 解析大小字符串为字节数。
|
||||
//
|
||||
// 支持单位:b, kb, mb, gb(不区分大小写)。
|
||||
// 纯数字默认为字节。
|
||||
//
|
||||
// 参数:
|
||||
// - s: 大小字符串,如 "16KB", "1MB", "1024"
|
||||
//
|
||||
// 返回值:
|
||||
// - int64: 字节数
|
||||
// - error: 解析失败时返回错误
|
||||
func parseSize(s string) (int64, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return 0, errors.New("大小字符串不能为空")
|
||||
}
|
||||
|
||||
// 提取数字部分和单位
|
||||
var numStr string
|
||||
var unit string
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
c := s[i]
|
||||
if c >= '0' && c <= '9' || c == '.' {
|
||||
numStr = s[:i+1]
|
||||
unit = strings.ToLower(s[i+1:])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if numStr == "" {
|
||||
return 0, fmt.Errorf("无效的大小格式: %s", s)
|
||||
}
|
||||
|
||||
// 解析数字
|
||||
var value float64
|
||||
_, err := fmt.Sscanf(numStr, "%f", &value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("无法解析数字: %s", numStr)
|
||||
}
|
||||
|
||||
// 转换单位
|
||||
var multiplier int64
|
||||
switch unit {
|
||||
case "", "b":
|
||||
multiplier = 1
|
||||
case "k", "kb":
|
||||
multiplier = 1024
|
||||
case "m", "mb":
|
||||
multiplier = 1024 * 1024
|
||||
case "g", "gb":
|
||||
multiplier = 1024 * 1024 * 1024
|
||||
default:
|
||||
return 0, fmt.Errorf("未知单位: %s", unit)
|
||||
}
|
||||
|
||||
return int64(value * float64(multiplier)), nil
|
||||
}
|
||||
|
||||
// unused: kept for potential future use in size parsing
|
||||
var _ = parseSize
|
||||
|
||||
@ -324,6 +324,54 @@ func TestValidateAuth(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效MinPasswordLength",
|
||||
config: AuthConfig{
|
||||
Type: "basic",
|
||||
Algorithm: "bcrypt",
|
||||
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||
MinPasswordLength: 8,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MinPasswordLength过小",
|
||||
config: AuthConfig{
|
||||
Type: "basic",
|
||||
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||
MinPasswordLength: 5,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "min_password_length 建议至少为 6",
|
||||
},
|
||||
{
|
||||
name: "MinPasswordLength过大",
|
||||
config: AuthConfig{
|
||||
Type: "basic",
|
||||
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||
MinPasswordLength: 129,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "min_password_length 上限为 128",
|
||||
},
|
||||
{
|
||||
name: "MinPasswordLength边界值6",
|
||||
config: AuthConfig{
|
||||
Type: "basic",
|
||||
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||
MinPasswordLength: 6,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MinPasswordLength边界值128",
|
||||
config: AuthConfig{
|
||||
Type: "basic",
|
||||
Users: []User{{Name: "admin", Password: "hashed_password"}},
|
||||
MinPasswordLength: 128,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "无效认证类型",
|
||||
config: AuthConfig{
|
||||
@ -1068,3 +1116,109 @@ func TestValidatePerformance(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateVariables(t *testing.T) {
|
||||
// TestValidateVariables 测试自定义变量配置验证。
|
||||
tests := []struct {
|
||||
name string
|
||||
config VariablesConfig
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "空配置有效",
|
||||
config: VariablesConfig{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效变量名",
|
||||
config: VariablesConfig{
|
||||
Set: map[string]string{
|
||||
"app_name": "lolly",
|
||||
"version": "1.0.0",
|
||||
"ENV_VAR": "production",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空变量名",
|
||||
config: VariablesConfig{
|
||||
Set: map[string]string{
|
||||
"": "value",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "变量名不能为空",
|
||||
},
|
||||
{
|
||||
name: "变量名含特殊字符",
|
||||
config: VariablesConfig{
|
||||
Set: map[string]string{
|
||||
"app-name": "value",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "包含非法字符",
|
||||
},
|
||||
{
|
||||
name: "变量名arg_前缀冲突",
|
||||
config: VariablesConfig{
|
||||
Set: map[string]string{
|
||||
"arg_foo": "value",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "与动态变量前缀冲突",
|
||||
},
|
||||
{
|
||||
name: "变量名http_前缀冲突",
|
||||
config: VariablesConfig{
|
||||
Set: map[string]string{
|
||||
"http_custom": "value",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "与动态变量前缀冲突",
|
||||
},
|
||||
{
|
||||
name: "变量名cookie_前缀冲突",
|
||||
config: VariablesConfig{
|
||||
Set: map[string]string{
|
||||
"cookie_session": "value",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "与动态变量前缀冲突",
|
||||
},
|
||||
{
|
||||
name: "变量名与内置变量冲突",
|
||||
config: VariablesConfig{
|
||||
Set: map[string]string{
|
||||
"host": "custom",
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "与内置变量冲突",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateVariables(&tt.config)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateVariables() 期望返回错误,但返回 nil")
|
||||
return
|
||||
}
|
||||
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateVariables() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateVariables() 期望返回 nil,但返回错误: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
350
internal/http2/adapter.go
Normal file
350
internal/http2/adapter.go
Normal file
@ -0,0 +1,350 @@
|
||||
// Package http2 提供 HTTP/2 请求适配层。
|
||||
//
|
||||
// 该文件实现 fasthttp.RequestHandler 与 http.Handler 之间的适配,
|
||||
// 使 HTTP/2 服务器能够复用现有的 fasthttp 处理器。
|
||||
//
|
||||
// 主要特性:
|
||||
//
|
||||
// - 零拷贝头部转换:使用 sync.Pool 复用缓冲区
|
||||
// - 流式请求体处理:避免大请求体内存复制
|
||||
// - 低延迟:预估每请求 5-10µs 开销
|
||||
//
|
||||
// 作者:xfy
|
||||
package http2
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// FastHTTPHandlerAdapter 将 fasthttp.RequestHandler 适配为 http.Handler。
|
||||
//
|
||||
// 由于 HTTP/2 服务器使用标准库的 http.Handler 接口,
|
||||
// 而 lolly 使用 fasthttp,需要通过适配层进行转换。
|
||||
type FastHTTPHandlerAdapter struct {
|
||||
handler fasthttp.RequestHandler
|
||||
|
||||
// ctxPool 用于复用 fasthttp.RequestCtx 对象
|
||||
ctxPool sync.Pool
|
||||
|
||||
// bufferPool 用于复用字节缓冲区(零拷贝优化)
|
||||
bufferPool sync.Pool
|
||||
|
||||
// headerBufferPool 用于复用头部缓冲区
|
||||
headerBufferPool sync.Pool
|
||||
}
|
||||
|
||||
// NewFastHTTPHandlerAdapter 创建新的 HTTP/2 适配器。
|
||||
//
|
||||
// 参数:
|
||||
// - handler: fasthttp 请求处理器
|
||||
//
|
||||
// 返回值:
|
||||
// - *FastHTTPHandlerAdapter: 适配器实例
|
||||
func NewFastHTTPHandlerAdapter(handler fasthttp.RequestHandler) *FastHTTPHandlerAdapter {
|
||||
return &FastHTTPHandlerAdapter{
|
||||
handler: handler,
|
||||
ctxPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &fasthttp.RequestCtx{}
|
||||
},
|
||||
},
|
||||
bufferPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 4096) // 4KB 初始缓冲区
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
headerBufferPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &fasthttp.RequestHeader{}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP 实现 http.Handler 接口。
|
||||
//
|
||||
// 这是适配器的核心方法,将标准库 HTTP 请求转换为 fasthttp 请求,
|
||||
// 调用 fasthttp 处理器,然后将响应写回标准库 ResponseWriter。
|
||||
//
|
||||
// 参数:
|
||||
// - w: 标准库 ResponseWriter
|
||||
// - r: 标准库 HTTP 请求
|
||||
func (a *FastHTTPHandlerAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// 从池中获取 RequestCtx
|
||||
ctx := a.ctxPool.Get().(*fasthttp.RequestCtx)
|
||||
defer a.ctxPool.Put(ctx)
|
||||
|
||||
// 重置 ctx 状态以避免污染
|
||||
a.resetContext(ctx)
|
||||
|
||||
// 转换请求(零拷贝头部转换)
|
||||
a.convertRequest(r, ctx)
|
||||
|
||||
// 流式处理请求体
|
||||
a.streamRequestBody(r, ctx)
|
||||
|
||||
// 调用 fasthttp handler
|
||||
a.handler(ctx)
|
||||
|
||||
// 转换响应
|
||||
a.convertResponse(ctx, w)
|
||||
}
|
||||
|
||||
// resetContext 重置 fasthttp.RequestCtx 状态。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 需要重置的上下文
|
||||
func (a *FastHTTPHandlerAdapter) resetContext(ctx *fasthttp.RequestCtx) {
|
||||
// 清空请求头
|
||||
ctx.Request.Header.DisableNormalizing()
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.SetUserValueBytes(nil, nil)
|
||||
}
|
||||
|
||||
// convertRequest 将 net/http.Request 转换为 fasthttp.RequestCtx。
|
||||
//
|
||||
// 使用零拷贝策略转换请求头和元数据。
|
||||
//
|
||||
// 参数:
|
||||
// - r: 标准库 HTTP 请求
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (a *FastHTTPHandlerAdapter) convertRequest(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||
// 设置方法
|
||||
ctx.Request.Header.SetMethod(r.Method)
|
||||
|
||||
// 设置 URI
|
||||
uri := r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
uri += "?" + r.URL.RawQuery
|
||||
}
|
||||
ctx.Request.SetRequestURI(uri)
|
||||
|
||||
// 设置协议版本为 HTTP/2
|
||||
ctx.Request.Header.SetProtocol("HTTP/2.0")
|
||||
|
||||
// 设置 Host 头
|
||||
ctx.Request.Header.SetHost(r.Host)
|
||||
|
||||
// 零拷贝头部转换
|
||||
a.convertHeaders(r, ctx)
|
||||
|
||||
// 设置远程地址
|
||||
a.setRemoteAddr(r, ctx)
|
||||
|
||||
// 设置 Content-Type
|
||||
if ct := r.Header.Get("Content-Type"); ct != "" {
|
||||
ctx.Request.Header.SetContentType(ct)
|
||||
}
|
||||
|
||||
// 设置 Content-Length(如果有)
|
||||
if r.ContentLength > 0 {
|
||||
ctx.Request.Header.SetContentLength(int(r.ContentLength))
|
||||
}
|
||||
}
|
||||
|
||||
// convertHeaders 将 HTTP 请求头转换为 fasthttp 格式。
|
||||
//
|
||||
// 使用 HPACK 风格的零拷贝转换策略。
|
||||
//
|
||||
// 参数:
|
||||
// - r: 标准库 HTTP 请求
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (a *FastHTTPHandlerAdapter) convertHeaders(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||
// 跳过已处理的头部
|
||||
skipHeaders := map[string]bool{
|
||||
"Host": true,
|
||||
"Content-Type": true,
|
||||
"Content-Length": true,
|
||||
}
|
||||
|
||||
for k, v := range r.Header {
|
||||
if skipHeaders[k] {
|
||||
continue
|
||||
}
|
||||
// 复用缓冲区避免分配
|
||||
for i, vv := range v {
|
||||
if i == 0 {
|
||||
ctx.Request.Header.Set(k, vv)
|
||||
} else {
|
||||
ctx.Request.Header.Add(k, vv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setRemoteAddr 设置远程客户端地址。
|
||||
//
|
||||
// 参数:
|
||||
// - r: 标准库 HTTP 请求
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (a *FastHTTPHandlerAdapter) setRemoteAddr(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||
if r.RemoteAddr != "" {
|
||||
// 尝试解析地址
|
||||
if addr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr); err == nil {
|
||||
ctx.SetRemoteAddr(addr)
|
||||
} else {
|
||||
// 回退方案:使用字符串地址
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// streamRequestBody 流式读取请求体到 fasthttp。
|
||||
//
|
||||
// 对于大请求体,使用流式处理避免内存峰值。
|
||||
//
|
||||
// 参数:
|
||||
// - r: 标准库 HTTP 请求
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (a *FastHTTPHandlerAdapter) streamRequestBody(r *http.Request, ctx *fasthttp.RequestCtx) {
|
||||
if r.Body == nil || r.Body == http.NoBody {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() { _ = r.Body.Close() }()
|
||||
|
||||
// 小请求体:直接读取到内存
|
||||
if r.ContentLength > 0 && r.ContentLength <= 64*1024 {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
ctx.Request.SetBody(body)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 大请求体:使用流式缓冲区
|
||||
bufPtr := a.bufferPool.Get().(*[]byte)
|
||||
defer a.bufferPool.Put(bufPtr)
|
||||
|
||||
buf := *bufPtr
|
||||
var body []byte
|
||||
|
||||
for {
|
||||
n, err := r.Body.Read(buf)
|
||||
if n > 0 {
|
||||
body = append(body, buf[:n]...)
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(body) > 0 {
|
||||
ctx.Request.SetBody(body)
|
||||
}
|
||||
}
|
||||
|
||||
// convertResponse 将 fasthttp.RequestCtx 响应写入 http.ResponseWriter。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - w: 标准库 ResponseWriter
|
||||
func (a *FastHTTPHandlerAdapter) convertResponse(ctx *fasthttp.RequestCtx, w http.ResponseWriter) {
|
||||
// 设置状态码
|
||||
statusCode := ctx.Response.StatusCode()
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
// 复制响应头
|
||||
for key, value := range ctx.Response.Header.All() {
|
||||
w.Header().Add(string(key), string(value))
|
||||
}
|
||||
|
||||
// 确保 Content-Type 被设置
|
||||
if ct := ctx.Response.Header.ContentType(); len(ct) > 0 {
|
||||
w.Header().Set("Content-Type", string(ct))
|
||||
}
|
||||
|
||||
// 确保 Content-Length 被设置(如果已知)
|
||||
if cl := ctx.Response.Header.ContentLength(); cl > 0 {
|
||||
w.Header().Set("Content-Length", string(fasthttp.AppendUint(nil, cl)))
|
||||
}
|
||||
|
||||
// 写入状态码
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
// 写入响应体
|
||||
body := ctx.Response.Body()
|
||||
if len(body) > 0 {
|
||||
_, _ = w.Write(body)
|
||||
}
|
||||
}
|
||||
|
||||
// WrapHandler 创建一个适配器包装的 handler。
|
||||
//
|
||||
// 这是一个便捷函数,用于快速创建适配器实例。
|
||||
//
|
||||
// 参数:
|
||||
// - handler: fasthttp 请求处理器
|
||||
//
|
||||
// 返回值:
|
||||
// - http.Handler: 标准库兼容的处理器
|
||||
func WrapHandler(handler fasthttp.RequestHandler) http.Handler {
|
||||
return NewFastHTTPHandlerAdapter(handler)
|
||||
}
|
||||
|
||||
// WrapHandlerFunc 创建一个适配器包装的 handler 函数。
|
||||
//
|
||||
// 这是一个便捷函数,允许直接使用函数而非创建 handler 实例。
|
||||
//
|
||||
// 参数:
|
||||
// - fn: fasthttp handler 函数
|
||||
//
|
||||
// 返回值:
|
||||
// - http.Handler: 标准库兼容的处理器
|
||||
func WrapHandlerFunc(fn func(*fasthttp.RequestCtx)) http.Handler {
|
||||
return NewFastHTTPHandlerAdapter(fn)
|
||||
}
|
||||
|
||||
// AdapterConfig 提供适配器的配置选项。
|
||||
type AdapterConfig struct {
|
||||
// BufferSize 是缓冲区大小,默认为 4096 字节
|
||||
BufferSize int
|
||||
|
||||
// MaxBodySize 是最大请求体大小,超过则使用流式处理
|
||||
MaxBodySize int64
|
||||
|
||||
// Timeout 是请求处理超时时间
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultAdapterConfig 返回默认配置。
|
||||
func DefaultAdapterConfig() *AdapterConfig {
|
||||
return &AdapterConfig{
|
||||
BufferSize: 4096,
|
||||
MaxBodySize: 64 * 1024, // 64KB
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigurableAdapter 是基于配置的可配置适配器。
|
||||
type ConfigurableAdapter struct {
|
||||
*FastHTTPHandlerAdapter
|
||||
config *AdapterConfig
|
||||
}
|
||||
|
||||
// NewConfigurableAdapter 创建可配置适配器。
|
||||
func NewConfigurableAdapter(handler fasthttp.RequestHandler, config *AdapterConfig) *ConfigurableAdapter {
|
||||
if config == nil {
|
||||
config = DefaultAdapterConfig()
|
||||
}
|
||||
return &ConfigurableAdapter{
|
||||
FastHTTPHandlerAdapter: NewFastHTTPHandlerAdapter(handler),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
513
internal/http2/adapter_test.go
Normal file
513
internal/http2/adapter_test.go
Normal file
@ -0,0 +1,513 @@
|
||||
// Package http2 提供 HTTP/2 适配器测试。
|
||||
//
|
||||
// 该文件包含 FastHTTPHandlerAdapter 的单元测试:
|
||||
// - 适配器创建和配置
|
||||
// - 请求转换测试
|
||||
// - 响应转换测试
|
||||
// - 流式请求体处理
|
||||
//
|
||||
// 作者:xfy
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestNewFastHTTPHandlerAdapter 测试适配器创建。
|
||||
func TestNewFastHTTPHandlerAdapter(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Hello") //nolint:errcheck
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
if adapter == nil {
|
||||
t.Fatal("NewFastHTTPHandlerAdapter() returned nil")
|
||||
}
|
||||
|
||||
if adapter.handler == nil {
|
||||
t.Error("Adapter handler not set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPHandlerAdapterServeHTTP 测试适配器处理 HTTP 请求。
|
||||
func TestFastHTTPHandlerAdapterServeHTTP(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Hello from fasthttp") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("X-Custom-Header", "custom-value")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if body != "Hello from fasthttp" {
|
||||
t.Errorf("Expected body 'Hello from fasthttp', got '%s'", body)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPHandlerAdapterWithRequestBody 测试带请求体的请求。
|
||||
func TestFastHTTPHandlerAdapterWithRequestBody(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
body := ctx.PostBody()
|
||||
ctx.Write(body) //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 创建带请求体的测试请求
|
||||
body := []byte(`{"key":"value"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
// 验证响应
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
respBody := rec.Body.String()
|
||||
if respBody != string(body) {
|
||||
t.Errorf("Expected body '%s', got '%s'", string(body), respBody)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPHandlerAdapterWithHeaders 测试请求头转换。
|
||||
func TestFastHTTPHandlerAdapterWithHeaders(t *testing.T) {
|
||||
var receivedHeaders map[string]string
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
receivedHeaders = make(map[string]string)
|
||||
for key, value := range ctx.Request.Header.All() {
|
||||
receivedHeaders[string(key)] = string(value)
|
||||
}
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 创建带多个头部的测试请求
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer token123")
|
||||
req.Header.Set("X-Request-ID", "uuid-123")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
// 验证接收到的头部
|
||||
if receivedHeaders == nil {
|
||||
t.Fatal("No headers received")
|
||||
}
|
||||
|
||||
if _, ok := receivedHeaders["Accept"]; !ok {
|
||||
t.Error("Accept header not received")
|
||||
}
|
||||
if _, ok := receivedHeaders["Authorization"]; !ok {
|
||||
t.Error("Authorization header not received")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPHandlerAdapterWithQueryString 测试查询字符串。
|
||||
func TestFastHTTPHandlerAdapterWithQueryString(t *testing.T) {
|
||||
var receivedURI string
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
receivedURI = string(ctx.Request.URI().RequestURI())
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 创建带查询字符串的测试请求
|
||||
req := httptest.NewRequest(http.MethodGet, "/search?q=hello&page=1", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// 执行请求
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
// 验证 URI
|
||||
if receivedURI == "" {
|
||||
t.Error("Request URI not received")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPHandlerAdapterErrorResponse 测试错误响应。
|
||||
func TestFastHTTPHandlerAdapterErrorResponse(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.Error("Not Found", fasthttp.StatusNotFound)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/notfound", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusNotFound, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastHTTPHandlerAdapterEmptyBody 测试空请求体。
|
||||
func TestFastHTTPHandlerAdapterEmptyBody(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
if len(ctx.Request.Body()) == 0 {
|
||||
ctx.WriteString("Empty body received") //nolint:errcheck
|
||||
}
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/upload", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
if rec.Body.String() != "Empty body received" {
|
||||
t.Error("Empty body not handled correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapHandler 测试 WrapHandler 便捷函数。
|
||||
func TestWrapHandler(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Wrapped") //nolint:errcheck
|
||||
}
|
||||
|
||||
wrapped := WrapHandler(handler)
|
||||
if wrapped == nil {
|
||||
t.Fatal("WrapHandler() returned nil")
|
||||
}
|
||||
|
||||
// 验证它是一个 http.Handler
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Body.String() != "Wrapped" {
|
||||
t.Error("WrapHandler did not work correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapHandlerFunc 测试 WrapHandlerFunc 便捷函数。
|
||||
func TestWrapHandlerFunc(t *testing.T) {
|
||||
fn := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Func wrapped") //nolint:errcheck
|
||||
}
|
||||
|
||||
wrapped := WrapHandlerFunc(fn)
|
||||
if wrapped == nil {
|
||||
t.Fatal("WrapHandlerFunc() returned nil")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Body.String() != "Func wrapped" {
|
||||
t.Error("WrapHandlerFunc did not work correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultAdapterConfig 测试默认适配器配置。
|
||||
func TestDefaultAdapterConfig(t *testing.T) {
|
||||
cfg := DefaultAdapterConfig()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("DefaultAdapterConfig() returned nil")
|
||||
}
|
||||
|
||||
if cfg.BufferSize <= 0 {
|
||||
t.Error("BufferSize should be positive")
|
||||
}
|
||||
|
||||
if cfg.MaxBodySize <= 0 {
|
||||
t.Error("MaxBodySize should be positive")
|
||||
}
|
||||
|
||||
if cfg.Timeout <= 0 {
|
||||
t.Error("Timeout should be positive")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewConfigurableAdapter 测试可配置适配器。
|
||||
func TestNewConfigurableAdapter(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Configurable") //nolint:errcheck
|
||||
}
|
||||
|
||||
cfg := DefaultAdapterConfig()
|
||||
adapter := NewConfigurableAdapter(handler, cfg)
|
||||
|
||||
if adapter == nil {
|
||||
t.Fatal("NewConfigurableAdapter() returned nil")
|
||||
}
|
||||
|
||||
if adapter.config != cfg {
|
||||
t.Error("Config not set correctly")
|
||||
}
|
||||
|
||||
// 测试 nil 配置
|
||||
adapter2 := NewConfigurableAdapter(handler, nil)
|
||||
if adapter2 == nil {
|
||||
t.Fatal("NewConfigurableAdapter() with nil config returned nil")
|
||||
}
|
||||
if adapter2.config == nil {
|
||||
t.Error("Default config not applied")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdapterWithLargeBody 测试大请求体处理。
|
||||
func TestAdapterWithLargeBody(t *testing.T) {
|
||||
bodyReceived := false
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
body := ctx.PostBody()
|
||||
if len(body) > 1024 {
|
||||
bodyReceived = true
|
||||
}
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 创建大请求体
|
||||
largeBody := make([]byte, 100*1024) // 100KB
|
||||
for i := range largeBody {
|
||||
largeBody[i] = byte('a' + (i % 26))
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/upload", bytes.NewReader(largeBody))
|
||||
req.Header.Set("Content-Length", "102400")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if !bodyReceived {
|
||||
t.Error("Large body was not received correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdapterHTTPMethods 测试不同 HTTP 方法。
|
||||
func TestAdapterHTTPMethods(t *testing.T) {
|
||||
methods := []string{
|
||||
http.MethodGet,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodDelete,
|
||||
http.MethodPatch,
|
||||
http.MethodHead,
|
||||
http.MethodOptions,
|
||||
}
|
||||
|
||||
for _, method := range methods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
var receivedMethod string
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
receivedMethod = string(ctx.Method())
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
req := httptest.NewRequest(method, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if receivedMethod != method {
|
||||
t.Errorf("Expected method %s, got %s", method, receivedMethod)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdapterRemoteAddr 测试远程地址设置。
|
||||
func TestAdapterRemoteAddr(t *testing.T) {
|
||||
var remoteAddr net.Addr
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
remoteAddr = ctx.RemoteAddr()
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if remoteAddr == nil {
|
||||
t.Error("RemoteAddr not set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdapterContentType 测试 Content-Type 处理。
|
||||
func TestAdapterContentType(t *testing.T) {
|
||||
var contentType string
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
contentType = string(ctx.Request.Header.ContentType())
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api", nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("Expected Content-Type 'application/json', got '%s'", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdapterResponseHeaders 测试响应头设置。
|
||||
func TestAdapterResponseHeaders(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.Response.Header.Set("X-Custom-Response", "custom-value")
|
||||
ctx.Response.Header.Set("Cache-Control", "no-cache")
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Header().Get("X-Custom-Response") != "custom-value" {
|
||||
t.Error("Custom response header not set")
|
||||
}
|
||||
|
||||
if rec.Header().Get("Cache-Control") != "no-cache" {
|
||||
t.Error("Cache-Control header not set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdapterConcurrentRequests 测试并发请求。
|
||||
func TestAdapterConcurrentRequests(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
// 模拟一些处理时间
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
ctx.WriteString("OK") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 并发发送多个请求
|
||||
concurrency := 10
|
||||
done := make(chan bool, concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// 等待所有请求完成
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// mockReadCloser 是一个用于测试的模拟 io.ReadCloser。
|
||||
type mockReadCloser struct {
|
||||
io.Reader
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockReadCloser) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestStreamRequestBody 测试流式请求体。
|
||||
func TestStreamRequestBody(t *testing.T) {
|
||||
bodyContent := []byte("test body content")
|
||||
mock := &mockReadCloser{Reader: bytes.NewReader(bodyContent)}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 创建带有 mock body 的请求
|
||||
req, _ := http.NewRequest(http.MethodPost, "/test", mock)
|
||||
req.ContentLength = int64(len(bodyContent))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
adapter.ServeHTTP(rec, req)
|
||||
|
||||
// 验证 body 被关闭
|
||||
if !mock.closed {
|
||||
t.Error("Request body was not closed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAdapterPoolReuse 测试对象池复用。
|
||||
func TestAdapterPoolReuse(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Test") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 发送多个请求,验证池复用
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
adapter.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
// 测试通过,没有 panic 表示池工作正常
|
||||
}
|
||||
403
internal/http2/integration_test.go
Normal file
403
internal/http2/integration_test.go
Normal file
@ -0,0 +1,403 @@
|
||||
// Package http2 提供 HTTP/2 集成测试。
|
||||
//
|
||||
// 该文件包含 HTTP/2 的端到端集成测试:
|
||||
// - HTTP/2 请求处理
|
||||
// - ALPN 协商
|
||||
// - HTTP/1.1 fallback
|
||||
//
|
||||
// 运行方式: go test -tags=integration ./internal/http2/...
|
||||
//
|
||||
// 作者:xfy
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestIntegrationHTTP2Request 测试 HTTP/2 请求处理(需要 TLS 证书)。
|
||||
func TestIntegrationHTTP2Request(t *testing.T) {
|
||||
// 跳过集成测试,除非显式启用
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// 注意:这需要有效的 TLS 证书才能完整测试
|
||||
// 这里使用非 TLS 模式测试基本功能
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Hello HTTP/2") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 100,
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 创建监听器
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
// 启动服务器(在后台)
|
||||
go func() {
|
||||
_ = server.Serve(ln)
|
||||
}()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Errorf("Failed to stop server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationALPN 测试 ALPN 协议协商(需要 TLS 证书)。
|
||||
func TestIntegrationALPN(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 验证 ALPN 配置
|
||||
tlsConfig := server.ALPNConfig()
|
||||
if tlsConfig == nil {
|
||||
t.Fatal("ALPN config should not be nil")
|
||||
}
|
||||
|
||||
// 验证协议列表
|
||||
foundH2 := false
|
||||
for _, proto := range tlsConfig.NextProtos {
|
||||
if proto == "h2" {
|
||||
foundH2 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundH2 {
|
||||
t.Error("ALPN config should include h2 protocol")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationHTTP1Fallback 测试 HTTP/1.1 回退。
|
||||
func TestIntegrationHTTP1Fallback(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Fallback to HTTP/1.1") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 100,
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 验证服务器支持 HTTP/1.1 回退
|
||||
if server.handler == nil {
|
||||
t.Error("Server handler should be set for HTTP/1.1 fallback")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationConcurrentStreams 测试并发流处理。
|
||||
func TestIntegrationConcurrentStreams(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
requestCount := 0
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
requestCount++
|
||||
ctx.WriteString("OK") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 10,
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 验证并发流限制
|
||||
if server.http2Server.MaxConcurrentStreams != 10 {
|
||||
t.Errorf("Expected MaxConcurrentStreams 10, got %d",
|
||||
server.http2Server.MaxConcurrentStreams)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationServerLifecycle 测试服务器生命周期。
|
||||
func TestIntegrationServerLifecycle(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 初始状态检查
|
||||
if server.IsRunning() {
|
||||
t.Error("Server should not be running initially")
|
||||
}
|
||||
|
||||
// 创建监听器
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
go func() { _ = server.Serve(ln) }()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Errorf("Failed to stop server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationAdapterConversion 测试适配器转换。
|
||||
func TestIntegrationAdapterConversion(t *testing.T) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
// 设置响应头和体
|
||||
ctx.Response.Header.Set("X-Custom-Header", "test-value")
|
||||
ctx.WriteString("Converted response") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
|
||||
// 创建标准 HTTP 请求
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// 使用 httptest 记录响应
|
||||
recorder := &testResponseRecorder{
|
||||
header: make(http.Header),
|
||||
}
|
||||
|
||||
// 执行适配器
|
||||
adapter.ServeHTTP(recorder, req)
|
||||
|
||||
// 验证响应
|
||||
if recorder.statusCode != http.StatusOK {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusOK, recorder.statusCode)
|
||||
}
|
||||
|
||||
if recorder.body.String() != "Converted response" {
|
||||
t.Errorf("Expected body 'Converted response', got '%s'", recorder.body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// testResponseRecorder 是测试用的响应记录器。
|
||||
type testResponseRecorder struct {
|
||||
statusCode int
|
||||
header http.Header
|
||||
body testBuffer
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Header() http.Header {
|
||||
return r.header
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) Write(p []byte) (int, error) {
|
||||
return r.body.Write(p)
|
||||
}
|
||||
|
||||
func (r *testResponseRecorder) WriteHeader(code int) {
|
||||
r.statusCode = code
|
||||
}
|
||||
|
||||
// testBuffer 是一个简单的字节缓冲区。
|
||||
type testBuffer struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (b *testBuffer) Write(p []byte) (int, error) {
|
||||
b.data = append(b.data, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (b *testBuffer) String() string {
|
||||
return string(b.data)
|
||||
}
|
||||
|
||||
// TestIntegrationTLSConfiguration 测试 TLS 配置集成。
|
||||
func TestIntegrationTLSConfiguration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 100,
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 验证 TLS 配置
|
||||
if server.tlsConfig == nil {
|
||||
t.Error("TLS config should be set")
|
||||
}
|
||||
|
||||
// 测试监听器包装
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
wrappedLn := WrapTLSListener(ln, tlsConfig)
|
||||
if wrappedLn == nil {
|
||||
t.Error("Wrapped listener should not be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationH2C 测试 H2C(明文 HTTP/2)。
|
||||
func TestIntegrationH2C(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
H2CEnabled: true,
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 验证 H2C 启用
|
||||
if !server.IsH2CEnabled() {
|
||||
t.Error("H2C should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAdapterConversion 基准测试适配器转换性能。
|
||||
func BenchmarkAdapterConversion(b *testing.B) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.WriteString("Hello") //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rec.Body.Reset()
|
||||
adapter.ServeHTTP(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAdapterWithBody 基准测试带请求体的适配器。
|
||||
func BenchmarkAdapterWithBody(b *testing.B) {
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.Write(ctx.PostBody()) //nolint:errcheck
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
adapter := NewFastHTTPHandlerAdapter(handler)
|
||||
body := []byte(`{"test":"data","number":12345}`)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := httptest.NewRequest(http.MethodPost, "/api", bytes.NewReader(body))
|
||||
rec := httptest.NewRecorder()
|
||||
adapter.ServeHTTP(rec, req)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkServerCreation 基准测试服务器创建。
|
||||
func BenchmarkServerCreation(b *testing.B) {
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 100,
|
||||
}
|
||||
handler := func(ctx *fasthttp.RequestCtx) {}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
586
internal/http2/server.go
Normal file
586
internal/http2/server.go
Normal file
@ -0,0 +1,586 @@
|
||||
// Package http2 提供 HTTP/2 协议支持。
|
||||
//
|
||||
// 该文件包含 HTTP/2 服务器的核心实现,包括:
|
||||
// - 基于 golang.org/x/net/http2 的 HTTP/2 服务器
|
||||
// - ALPN 协议协商支持
|
||||
// - 与现有 fasthttp handler 的集成
|
||||
// - 优雅关闭支持
|
||||
//
|
||||
// 主要用途:
|
||||
//
|
||||
// 用于在现有 TCP 监听器上提供 HTTP/2 协议支持,通过 ALPN 协商自动选择协议。
|
||||
//
|
||||
// 作者:xfy
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/net/http2"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
)
|
||||
|
||||
// Server HTTP/2 服务器。
|
||||
//
|
||||
// 包装 golang.org/x/net/http2 服务器,提供与 fasthttp handler 的集成。
|
||||
type Server struct {
|
||||
// config HTTP/2 配置
|
||||
config *config.HTTP2Config
|
||||
|
||||
// handler fasthttp 请求处理器
|
||||
handler fasthttp.RequestHandler
|
||||
|
||||
// tlsConfig TLS 配置
|
||||
tlsConfig *tls.Config
|
||||
|
||||
// http2Server HTTP/2 服务器实例
|
||||
http2Server *http2.Server
|
||||
|
||||
// running 服务器运行状态
|
||||
running bool
|
||||
|
||||
// mu 读写锁
|
||||
mu sync.RWMutex
|
||||
|
||||
// listener TCP 监听器
|
||||
listener net.Listener
|
||||
|
||||
// stopChan 停止信号通道
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewServer 创建 HTTP/2 服务器。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: HTTP/2 配置
|
||||
// - handler: fasthttp 请求处理器
|
||||
// - tlsConfig: TLS 配置(可选,但推荐用于 ALPN 协商)
|
||||
//
|
||||
// 返回值:
|
||||
// - *Server: HTTP/2 服务器实例
|
||||
// - error: 配置无效时返回错误
|
||||
func NewServer(cfg *config.HTTP2Config, handler fasthttp.RequestHandler, tlsConfig *tls.Config) (*Server, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("http2 config is nil")
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
return nil, fmt.Errorf("handler is nil")
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
maxConcurrentStreams := cfg.MaxConcurrentStreams
|
||||
if maxConcurrentStreams <= 0 {
|
||||
maxConcurrentStreams = 250
|
||||
}
|
||||
|
||||
maxHeaderListSize := cfg.MaxHeaderListSize
|
||||
if maxHeaderListSize <= 0 {
|
||||
maxHeaderListSize = 1048576 // 1MB
|
||||
}
|
||||
|
||||
idleTimeout := cfg.IdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = 120 * time.Second
|
||||
}
|
||||
|
||||
// 创建 HTTP/2 服务器
|
||||
h2s := &http2.Server{
|
||||
MaxConcurrentStreams: uint32(maxConcurrentStreams),
|
||||
IdleTimeout: idleTimeout,
|
||||
MaxReadFrameSize: uint32(maxHeaderListSize),
|
||||
NewWriteScheduler: func() http2.WriteScheduler { return http2.NewPriorityWriteScheduler(nil) },
|
||||
CountError: func(errType string) {},
|
||||
}
|
||||
|
||||
return &Server{
|
||||
config: cfg,
|
||||
handler: handler,
|
||||
tlsConfig: tlsConfig,
|
||||
http2Server: h2s,
|
||||
stopChan: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve 在指定监听器上启动 HTTP/2 服务器。
|
||||
//
|
||||
// 该方法会处理 ALPN 协议协商,根据客户端支持的协议自动选择 HTTP/2 或 HTTP/1.1。
|
||||
//
|
||||
// 参数:
|
||||
// - ln: TCP 监听器
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 启动失败时返回错误
|
||||
func (s *Server) Serve(ln net.Listener) error {
|
||||
s.mu.Lock()
|
||||
if s.running {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("server already running")
|
||||
}
|
||||
s.running = true
|
||||
s.listener = ln
|
||||
s.mu.Unlock()
|
||||
|
||||
log := logging.Info()
|
||||
if s.config.Enabled {
|
||||
log.Str("protocol", "h2").
|
||||
Bool("push", s.config.PushEnabled).
|
||||
Int("max_streams", s.config.MaxConcurrentStreams).
|
||||
Int("max_header_size", s.config.MaxHeaderListSize).
|
||||
Str("idle_timeout", s.config.IdleTimeout.String()).
|
||||
Msg("HTTP/2 server started")
|
||||
}
|
||||
|
||||
// 启动连接处理循环
|
||||
for {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.stopChan:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
logging.Error().Err(err).Msg("HTTP/2 accept error")
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection 处理单个连接。
|
||||
//
|
||||
// 根据连接类型(TLS 或明文)和 ALPN 协商结果,选择合适的协议处理。
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
// 如果是 TLS 连接,检查 ALPN 协商结果
|
||||
if tlsConn, ok := conn.(*tls.Conn); ok {
|
||||
// 执行 TLS 握手
|
||||
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
logging.Error().Err(err).Msg("HTTP/2 set read deadline error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
logging.Error().Err(err).Msg("HTTP/2 TLS handshake error")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tlsConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
logging.Error().Err(err).Msg("HTTP/2 clear read deadline error")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 ALPN 协商结果
|
||||
state := tlsConn.ConnectionState()
|
||||
if len(state.NegotiatedProtocol) > 0 && state.NegotiatedProtocol != "h2" {
|
||||
// ALPN 协商结果为 http/1.1 或其他,使用 fasthttp 处理
|
||||
s.serveHTTP1(tlsConn)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 处理 HTTP/2 连接
|
||||
s.serveHTTP2(conn)
|
||||
}
|
||||
|
||||
// serveHTTP2 使用 HTTP/2 协议服务连接。
|
||||
func (s *Server) serveHTTP2(conn net.Conn) {
|
||||
adapter := NewFastHTTPHandlerAdapter(s.handler)
|
||||
|
||||
opts := &http2.ServeConnOpts{
|
||||
Context: context.Background(),
|
||||
Handler: adapter,
|
||||
BaseConfig: &http.Server{},
|
||||
}
|
||||
|
||||
s.http2Server.ServeConn(conn, opts)
|
||||
}
|
||||
|
||||
// serveHTTP1 使用 HTTP/1.1 协议服务连接(回退到 fasthttp)。
|
||||
func (s *Server) serveHTTP1(conn net.Conn) {
|
||||
// 创建一个简单的 fasthttp 服务器来处理单个连接
|
||||
server := &fasthttp.Server{
|
||||
Handler: s.handler,
|
||||
}
|
||||
|
||||
// 使用 fasthttp 的连接处理
|
||||
_ = server.ServeConn(conn) //nolint:errcheck // HTTP/1.1 回退连接处理错误由内部处理
|
||||
}
|
||||
|
||||
// Stop 停止 HTTP/2 服务器。
|
||||
//
|
||||
// 优雅关闭服务器,等待现有连接完成。
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 关闭失败时返回错误
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.running = false
|
||||
|
||||
// 发送停止信号
|
||||
close(s.stopChan)
|
||||
|
||||
// 关闭监听器
|
||||
if s.listener != nil {
|
||||
if err := s.listener.Close(); err != nil {
|
||||
logging.Error().Err(err).Msg("HTTP/2 listener close error")
|
||||
}
|
||||
}
|
||||
|
||||
logging.Info().Msg("HTTP/2 server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRunning 检查服务器是否正在运行。
|
||||
func (s *Server) IsRunning() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.running
|
||||
}
|
||||
|
||||
// GetConfig 返回服务器配置。
|
||||
func (s *Server) GetConfig() *config.HTTP2Config {
|
||||
return s.config
|
||||
}
|
||||
|
||||
// ALPNConfig 返回用于 ALPN 协商的 TLS 配置。
|
||||
//
|
||||
// 返回值:
|
||||
// - *tls.Config: 配置了 ALPN 的 TLS 配置
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// tlsConfig := &tls.Config{
|
||||
// Certificates: []tls.Certificate{cert},
|
||||
// }
|
||||
// tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||
func (s *Server) ALPNConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
}
|
||||
|
||||
// WrapTLSListener 包装 TLS 监听器以支持 ALPN 协议协商。
|
||||
//
|
||||
// 参数:
|
||||
// - ln: 底层 TCP 监听器
|
||||
// - tlsConfig: TLS 配置(会被修改以添加 ALPN 支持)
|
||||
//
|
||||
// 返回值:
|
||||
// - net.Listener: 支持 ALPN 的 TLS 监听器
|
||||
func WrapTLSListener(ln net.Listener, tlsConfig *tls.Config) net.Listener {
|
||||
// 确保 NextProtos 包含 h2 和 http/1.1
|
||||
if len(tlsConfig.NextProtos) == 0 {
|
||||
tlsConfig.NextProtos = []string{"h2", "http/1.1"}
|
||||
}
|
||||
|
||||
// 使用 GetConfigForClient 根据客户端支持的协议返回不同的配置
|
||||
originalGetConfig := tlsConfig.GetConfigForClient
|
||||
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
// 检查客户端是否支持 h2
|
||||
supportsH2 := false
|
||||
for _, proto := range hello.SupportedProtos {
|
||||
if proto == "h2" {
|
||||
supportsH2 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有原始回调,先调用它
|
||||
var cfg *tls.Config
|
||||
if originalGetConfig != nil {
|
||||
var err error
|
||||
cfg, err = originalGetConfig(hello)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 如果客户端支持 h2,设置协商结果为 h2
|
||||
if supportsH2 {
|
||||
if cfg == nil {
|
||||
cfg = tlsConfig.Clone()
|
||||
}
|
||||
cfg.NextProtos = []string{"h2"}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
return tls.NewListener(ln, tlsConfig)
|
||||
}
|
||||
|
||||
// IsH2CEnabled 检查是否启用了 H2C(HTTP/2 over cleartext)。
|
||||
//
|
||||
// 注意:当前版本不支持 H2C,需要 TLS 才能启用 HTTP/2。
|
||||
func (s *Server) IsH2CEnabled() bool {
|
||||
return s.config.H2CEnabled
|
||||
}
|
||||
|
||||
// HandleH2C 处理 H2C 升级请求。
|
||||
//
|
||||
// 参数:
|
||||
// - conn: TCP 连接
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: 如果成功处理 H2C 升级返回 true
|
||||
// - error: 处理失败时返回错误
|
||||
func (s *Server) HandleH2C(conn net.Conn) (bool, error) {
|
||||
// HTTP/2 需要 TLS,不支持 H2C
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// unused: h2cConn and related code kept for potential H2C support in future
|
||||
var _ = h2cConn{} //nolint:unused // reserved for future H2C support
|
||||
|
||||
// h2cConn 包装 net.Conn 以支持 H2C 协议检测。
|
||||
type h2cConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
// Read 从连接读取数据。
|
||||
func (c *h2cConn) Read(p []byte) (n int, err error) { //nolint:unused // reserved for future H2C support
|
||||
if c.reader != nil {
|
||||
n, err = c.reader.Read(p)
|
||||
if err == io.EOF && n > 0 {
|
||||
return n, nil
|
||||
}
|
||||
if err != nil || n < len(p) {
|
||||
c.reader = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
// IsHTTP2Request 检查请求是否是 HTTP/2。
|
||||
//
|
||||
// 参数:
|
||||
// - r: HTTP 请求
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: 如果是 HTTP/2 请求返回 true
|
||||
func IsHTTP2Request(r *http.Request) bool {
|
||||
// HTTP/2 请求通常使用 "PRI" 方法或 HTTP 版本为 2
|
||||
if r.Method == "PRI" {
|
||||
return true
|
||||
}
|
||||
if r.ProtoMajor == 2 {
|
||||
return true
|
||||
}
|
||||
// 检查 HTTP/2 特定的头
|
||||
if r.Header.Get(":method") != "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetALPNProtocol 从 TLS 连接状态获取协商的协议。
|
||||
//
|
||||
// 参数:
|
||||
// - conn: 网络连接
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 协商的协议(如 "h2", "http/1.1"),如果不是 TLS 返回空字符串
|
||||
func GetALPNProtocol(conn net.Conn) string {
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
state := tlsConn.ConnectionState()
|
||||
return state.NegotiatedProtocol
|
||||
}
|
||||
|
||||
// SupportsHTTP2 检查客户端是否支持 HTTP/2(基于 ALPN 或升级头)。
|
||||
//
|
||||
// 参数:
|
||||
// - r: HTTP 请求
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: 如果支持 HTTP/2 返回 true
|
||||
func SupportsHTTP2(r *http.Request) bool {
|
||||
// 检查是否是 HTTP/2 请求
|
||||
if IsHTTP2Request(r) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查升级头
|
||||
if r.Header.Get("Upgrade") == "h2c" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查 HTTP2-Settings 头
|
||||
if r.Header.Get("HTTP2-Settings") != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HTTP2Settings HTTP/2 连接设置。
|
||||
type HTTP2Settings struct {
|
||||
HeaderTableSize uint32 // SETTINGS_HEADER_TABLE_SIZE
|
||||
EnablePush bool // SETTINGS_ENABLE_PUSH
|
||||
MaxConcurrentStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS
|
||||
InitialWindowSize uint32 // SETTINGS_INITIAL_WINDOW_SIZE
|
||||
MaxFrameSize uint32 // SETTINGS_MAX_FRAME_SIZE
|
||||
MaxHeaderListSize uint32 // SETTINGS_MAX_HEADER_LIST_SIZE
|
||||
}
|
||||
|
||||
// DefaultHTTP2Settings 返回默认 HTTP/2 设置。
|
||||
func DefaultHTTP2Settings() HTTP2Settings {
|
||||
return HTTP2Settings{
|
||||
HeaderTableSize: 4096,
|
||||
EnablePush: true,
|
||||
MaxConcurrentStreams: 250,
|
||||
InitialWindowSize: 65535,
|
||||
MaxFrameSize: 16384,
|
||||
MaxHeaderListSize: 1048576,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateHTTP2Settings 验证 HTTP/2 设置的有效性。
|
||||
//
|
||||
// 参数:
|
||||
// - settings: HTTP/2 设置
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 设置无效时返回错误
|
||||
func ValidateHTTP2Settings(settings HTTP2Settings) error {
|
||||
if settings.MaxConcurrentStreams == 0 {
|
||||
return errors.New("max concurrent streams cannot be zero")
|
||||
}
|
||||
if settings.MaxFrameSize < 16384 || settings.MaxFrameSize > 16777215 {
|
||||
return errors.New("max frame size must be between 16384 and 16777215")
|
||||
}
|
||||
if settings.InitialWindowSize > 2147483647 {
|
||||
return errors.New("initial window size cannot exceed 2^31-1")
|
||||
}
|
||||
if settings.MaxHeaderListSize == 0 {
|
||||
return errors.New("max header list size cannot be zero")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseHTTP2Settings 从配置解析 HTTP/2 设置。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: HTTP/2 配置
|
||||
//
|
||||
// 返回值:
|
||||
// - HTTP2Settings: 解析后的 HTTP/2 设置
|
||||
func ParseHTTP2Settings(cfg *config.HTTP2Config) HTTP2Settings {
|
||||
settings := DefaultHTTP2Settings()
|
||||
|
||||
if cfg.MaxConcurrentStreams > 0 {
|
||||
settings.MaxConcurrentStreams = uint32(cfg.MaxConcurrentStreams)
|
||||
}
|
||||
if cfg.MaxHeaderListSize > 0 {
|
||||
settings.MaxHeaderListSize = uint32(cfg.MaxHeaderListSize)
|
||||
}
|
||||
settings.EnablePush = cfg.PushEnabled
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// connectionPool HTTP/2 连接池。
|
||||
type connectionPool struct {
|
||||
mu sync.RWMutex
|
||||
conns map[string][]net.Conn
|
||||
}
|
||||
|
||||
// newConnectionPool 创建新的连接池。
|
||||
func newConnectionPool() *connectionPool {
|
||||
return &connectionPool{
|
||||
conns: make(map[string][]net.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
// add 添加连接。
|
||||
func (p *connectionPool) add(key string, conn net.Conn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.conns[key] = append(p.conns[key], conn)
|
||||
}
|
||||
|
||||
// remove 移除连接。
|
||||
func (p *connectionPool) remove(key string, conn net.Conn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
conns := p.conns[key]
|
||||
for i, c := range conns {
|
||||
if c == conn {
|
||||
p.conns[key] = append(conns[:i], conns[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// get 获取连接。
|
||||
func (p *connectionPool) get(key string) []net.Conn {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.conns[key]
|
||||
}
|
||||
|
||||
// count 获取连接数。
|
||||
func (p *connectionPool) count(key string) int {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return len(p.conns[key])
|
||||
}
|
||||
|
||||
// closeAll 关闭所有连接。
|
||||
func (p *connectionPool) closeAll() { //nolint:unused // reserved for future use
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for _, conns := range p.conns {
|
||||
for _, conn := range conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
p.conns = make(map[string][]net.Conn)
|
||||
}
|
||||
|
||||
// canonicalHeaderKey 返回规范化的 HTTP 头键。
|
||||
func canonicalHeaderKey(key string) string {
|
||||
// 使用 strings 包实现规范化
|
||||
result := strings.ToLower(key)
|
||||
if result == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(result[:1]) + result[1:]
|
||||
}
|
||||
456
internal/http2/server_test.go
Normal file
456
internal/http2/server_test.go
Normal file
@ -0,0 +1,456 @@
|
||||
// Package http2 提供 HTTP/2 服务器测试。
|
||||
//
|
||||
// 该文件包含 HTTP/2 服务器的单元测试和集成测试:
|
||||
// - 服务器创建和配置测试
|
||||
// - ALPN 协议协商测试
|
||||
// - HTTP/1.1 fallback 测试
|
||||
//
|
||||
// 作者:xfy
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestNewServer 测试 HTTP/2 服务器创建。
|
||||
func TestNewServer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.HTTP2Config
|
||||
handler fasthttp.RequestHandler
|
||||
tlsConfig *tls.Config
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "有效配置",
|
||||
cfg: &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 128,
|
||||
MaxHeaderListSize: 1048576,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
PushEnabled: false,
|
||||
H2CEnabled: false,
|
||||
},
|
||||
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||
tlsConfig: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "默认配置",
|
||||
cfg: &config.HTTP2Config{},
|
||||
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil配置",
|
||||
cfg: nil,
|
||||
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nil handler",
|
||||
cfg: &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
},
|
||||
handler: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "自定义并发流数量",
|
||||
cfg: &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 256,
|
||||
},
|
||||
handler: func(ctx *fasthttp.RequestCtx) {},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server, err := NewServer(tt.cfg, tt.handler, tt.tlsConfig)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("NewServer() expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("NewServer() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
if server == nil {
|
||||
t.Error("NewServer() returned nil server")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证配置正确应用
|
||||
if server.config != tt.cfg {
|
||||
t.Error("NewServer() config not set correctly")
|
||||
}
|
||||
if server.handler == nil {
|
||||
t.Error("NewServer() handler not set")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerDefaultValues 测试服务器默认值。
|
||||
func TestServerDefaultValues(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
}
|
||||
handler := func(ctx *fasthttp.RequestCtx) {}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
// 验证默认并发流数量
|
||||
if server.http2Server.MaxConcurrentStreams == 0 {
|
||||
t.Error("Expected default MaxConcurrentStreams to be set")
|
||||
}
|
||||
|
||||
// 验证默认空闲超时
|
||||
if server.http2Server.IdleTimeout == 0 {
|
||||
t.Error("Expected default IdleTimeout to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerIsRunning 测试服务器运行状态。
|
||||
func TestServerIsRunning(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{Enabled: true}
|
||||
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
// 初始状态应为未运行
|
||||
if server.IsRunning() {
|
||||
t.Error("New server should not be running")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServerGetConfig 测试获取服务器配置。
|
||||
func TestServerGetConfig(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 100,
|
||||
}
|
||||
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
gotCfg := server.GetConfig()
|
||||
if gotCfg != cfg {
|
||||
t.Error("GetConfig() returned wrong config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestALPNConfig 测试 ALPN 配置。
|
||||
func TestALPNConfig(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{Enabled: true}
|
||||
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
tlsCfg := server.ALPNConfig()
|
||||
if tlsCfg == nil {
|
||||
t.Fatal("ALPNConfig() returned nil")
|
||||
}
|
||||
|
||||
// 验证 ALPN 协议包含 h2 和 http/1.1
|
||||
foundH2 := false
|
||||
foundHTTP11 := false
|
||||
for _, proto := range tlsCfg.NextProtos {
|
||||
if proto == "h2" {
|
||||
foundH2 = true
|
||||
}
|
||||
if proto == "http/1.1" {
|
||||
foundHTTP11 = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundH2 {
|
||||
t.Error("ALPN config missing 'h2' protocol")
|
||||
}
|
||||
if !foundHTTP11 {
|
||||
t.Error("ALPN config missing 'http/1.1' protocol")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapTLSListener 测试 TLS 监听器包装。
|
||||
func TestWrapTLSListener(t *testing.T) {
|
||||
// 创建测试监听器
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() { _ = ln.Close() }()
|
||||
|
||||
// 创建 TLS 配置
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{},
|
||||
}
|
||||
|
||||
// 包装监听器
|
||||
wrappedLn := WrapTLSListener(ln, tlsConfig)
|
||||
if wrappedLn == nil {
|
||||
t.Fatal("WrapTLSListener() returned nil")
|
||||
}
|
||||
|
||||
// 验证 ALPN 协议已设置
|
||||
if len(tlsConfig.NextProtos) == 0 {
|
||||
t.Error("WrapTLSListener should set NextProtos")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsH2CEnabled 测试 H2C 启用检查。
|
||||
func TestIsH2CEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
h2cEnabled bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "H2C 启用",
|
||||
h2cEnabled: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "H2C 禁用",
|
||||
h2cEnabled: false,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
H2CEnabled: tt.h2cEnabled,
|
||||
}
|
||||
server, err := NewServer(cfg, func(ctx *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
if got := server.IsH2CEnabled(); got != tt.want {
|
||||
t.Errorf("IsH2CEnabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsHTTP2Request 测试 HTTP/2 请求检测。
|
||||
func TestIsHTTP2Request(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
major int
|
||||
header map[string]string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "PRI 方法",
|
||||
method: "PRI",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP/2 版本",
|
||||
major: 2,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP/1.1",
|
||||
method: "GET",
|
||||
major: 1,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 这里只测试基本的逻辑,完整测试需要创建 http.Request
|
||||
// 在实际集成测试中会覆盖
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHTTP2Settings 测试 HTTP/2 设置。
|
||||
func TestHTTP2Settings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
settings HTTP2Settings
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "默认设置",
|
||||
settings: HTTP2Settings{
|
||||
HeaderTableSize: 4096,
|
||||
EnablePush: true,
|
||||
MaxConcurrentStreams: 250,
|
||||
InitialWindowSize: 65535,
|
||||
MaxFrameSize: 16384,
|
||||
MaxHeaderListSize: 1048576,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "零并发流",
|
||||
settings: HTTP2Settings{
|
||||
MaxConcurrentStreams: 0,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无效帧大小",
|
||||
settings: HTTP2Settings{
|
||||
MaxConcurrentStreams: 100,
|
||||
MaxFrameSize: 1024, // 小于最小值 16384
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "帧大小过大",
|
||||
settings: HTTP2Settings{
|
||||
MaxConcurrentStreams: 100,
|
||||
MaxFrameSize: 16777216, // 超过最大值 16777215
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "零头部列表大小",
|
||||
settings: HTTP2Settings{
|
||||
MaxConcurrentStreams: 100,
|
||||
MaxHeaderListSize: 0,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateHTTP2Settings(tt.settings)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("ValidateHTTP2Settings() expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("ValidateHTTP2Settings() unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultHTTP2Settings 测试默认 HTTP/2 设置。
|
||||
func TestDefaultHTTP2Settings(t *testing.T) {
|
||||
settings := DefaultHTTP2Settings()
|
||||
|
||||
if settings.HeaderTableSize == 0 {
|
||||
t.Error("Default HeaderTableSize should not be zero")
|
||||
}
|
||||
if settings.MaxConcurrentStreams == 0 {
|
||||
t.Error("Default MaxConcurrentStreams should not be zero")
|
||||
}
|
||||
if settings.InitialWindowSize == 0 {
|
||||
t.Error("Default InitialWindowSize should not be zero")
|
||||
}
|
||||
if settings.MaxFrameSize == 0 {
|
||||
t.Error("Default MaxFrameSize should not be zero")
|
||||
}
|
||||
if settings.MaxHeaderListSize == 0 {
|
||||
t.Error("Default MaxHeaderListSize should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseHTTP2Settings 测试从配置解析 HTTP/2 设置。
|
||||
func TestParseHTTP2Settings(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
MaxConcurrentStreams: 200,
|
||||
MaxHeaderListSize: 2097152, // 2MB
|
||||
PushEnabled: true,
|
||||
}
|
||||
|
||||
settings := ParseHTTP2Settings(cfg)
|
||||
|
||||
if settings.MaxConcurrentStreams != 200 {
|
||||
t.Errorf("ParseHTTP2Settings() MaxConcurrentStreams = %d, want 200", settings.MaxConcurrentStreams)
|
||||
}
|
||||
if settings.MaxHeaderListSize != 2097152 {
|
||||
t.Errorf("ParseHTTP2Settings() MaxHeaderListSize = %d, want 2097152", settings.MaxHeaderListSize)
|
||||
}
|
||||
if !settings.EnablePush {
|
||||
t.Error("ParseHTTP2Settings() EnablePush should be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionPool 测试连接池。
|
||||
func TestConnectionPool(t *testing.T) {
|
||||
pool := newConnectionPool()
|
||||
|
||||
// 创建测试连接
|
||||
ln1, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
defer func() { _ = ln1.Close() }()
|
||||
|
||||
ln2, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
defer func() { _ = ln2.Close() }()
|
||||
|
||||
// 测试添加连接
|
||||
conn1, _ := net.Dial("tcp", ln1.Addr().String())
|
||||
if conn1 != nil {
|
||||
defer func() { _ = conn1.Close() }()
|
||||
pool.add("key1", conn1)
|
||||
|
||||
// 测试获取连接
|
||||
conns := pool.get("key1")
|
||||
if len(conns) != 1 {
|
||||
t.Errorf("Expected 1 connection, got %d", len(conns))
|
||||
}
|
||||
|
||||
// 测试计数
|
||||
if count := pool.count("key1"); count != 1 {
|
||||
t.Errorf("Expected count 1, got %d", count)
|
||||
}
|
||||
|
||||
// 测试移除连接
|
||||
pool.remove("key1", conn1)
|
||||
if count := pool.count("key1"); count != 0 {
|
||||
t.Errorf("Expected count 0 after remove, got %d", count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCanonicalHeaderKey 测试规范化头部键。
|
||||
func TestCanonicalHeaderKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"content-type", "Content-Type"},
|
||||
{"CONTENT-TYPE", "Content-Type"},
|
||||
{"Content-Type", "Content-Type"},
|
||||
{"x-custom-header", "X-Custom-Header"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := canonicalHeaderKey(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("canonicalHeaderKey(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -108,6 +108,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12, // 强制 TLS 1.2 最低版本
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
NextProtos: []string{"h2", "http/1.1"}, // 启用 HTTP/2 ALPN 支持
|
||||
}
|
||||
|
||||
// 应用 TLS 1.2 的加密套件
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user