987 lines
22 KiB
Go
987 lines
22 KiB
Go
// Package http2 提供 HTTP/2 服务器测试。
|
||
//
|
||
// 该文件包含 HTTP/2 服务器的单元测试和集成测试:
|
||
// - 服务器创建和配置测试
|
||
// - ALPN 协议协商测试
|
||
// - HTTP/1.1 fallback 测试
|
||
//
|
||
// 作者:xfy
|
||
package http2
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"crypto/tls"
|
||
"errors"
|
||
"net"
|
||
"net/http"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/valyala/fasthttp"
|
||
"rua.plus/lolly/internal/config"
|
||
)
|
||
|
||
// TestNewServer 测试 HTTP/2 服务器创建。
|
||
func TestNewServer(t *testing.T) {
|
||
tests := []struct {
|
||
cfg *config.HTTP2Config
|
||
handler fasthttp.RequestHandler
|
||
tlsConfig *tls.Config
|
||
name string
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "有效配置",
|
||
cfg: &config.HTTP2Config{
|
||
Enabled: true,
|
||
MaxConcurrentStreams: 128,
|
||
MaxHeaderListSize: 1048576,
|
||
IdleTimeout: 120 * time.Second,
|
||
PushEnabled: false,
|
||
H2CEnabled: false,
|
||
},
|
||
handler: func(_ *fasthttp.RequestCtx) {},
|
||
tlsConfig: nil,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "默认配置",
|
||
cfg: &config.HTTP2Config{},
|
||
handler: func(_ *fasthttp.RequestCtx) {},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "nil配置",
|
||
cfg: nil,
|
||
handler: func(_ *fasthttp.RequestCtx) {},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "nil handler",
|
||
cfg: &config.HTTP2Config{
|
||
Enabled: true,
|
||
},
|
||
handler: nil,
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "自定义并发流数量",
|
||
cfg: &config.HTTP2Config{
|
||
Enabled: true,
|
||
MaxConcurrentStreams: 256,
|
||
},
|
||
handler: func(_ *fasthttp.RequestCtx) {},
|
||
wantErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
server, err := NewServer(tt.cfg, tt.handler, tt.tlsConfig)
|
||
if tt.wantErr {
|
||
if err == nil {
|
||
t.Errorf("NewServer() expected error, got nil")
|
||
}
|
||
return
|
||
}
|
||
if err != nil {
|
||
t.Errorf("NewServer() unexpected error: %v", err)
|
||
return
|
||
}
|
||
if server == nil {
|
||
t.Error("NewServer() returned nil server")
|
||
return
|
||
}
|
||
|
||
// 验证配置正确应用
|
||
if server.config != tt.cfg {
|
||
t.Error("NewServer() config not set correctly")
|
||
}
|
||
if server.handler == nil {
|
||
t.Error("NewServer() handler not set")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestServerDefaultValues 测试服务器默认值。
|
||
func TestServerDefaultValues(t *testing.T) {
|
||
cfg := &config.HTTP2Config{
|
||
Enabled: true,
|
||
}
|
||
handler := func(_ *fasthttp.RequestCtx) {}
|
||
|
||
server, err := NewServer(cfg, handler, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
// 验证默认并发流数量
|
||
if server.http2Server.MaxConcurrentStreams == 0 {
|
||
t.Error("Expected default MaxConcurrentStreams to be set")
|
||
}
|
||
|
||
// 验证默认空闲超时
|
||
if server.http2Server.IdleTimeout == 0 {
|
||
t.Error("Expected default IdleTimeout to be set")
|
||
}
|
||
}
|
||
|
||
// TestServerIsRunning 测试服务器运行状态。
|
||
func TestServerIsRunning(t *testing.T) {
|
||
cfg := &config.HTTP2Config{Enabled: true}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
// 初始状态应为未运行
|
||
if server.IsRunning() {
|
||
t.Error("New server should not be running")
|
||
}
|
||
}
|
||
|
||
// TestServerGetConfig 测试获取服务器配置。
|
||
func TestServerGetConfig(t *testing.T) {
|
||
cfg := &config.HTTP2Config{
|
||
Enabled: true,
|
||
MaxConcurrentStreams: 100,
|
||
}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
gotCfg := server.GetConfig()
|
||
if gotCfg != cfg {
|
||
t.Error("GetConfig() returned wrong config")
|
||
}
|
||
}
|
||
|
||
// TestALPNConfig 测试 ALPN 配置。
|
||
func TestALPNConfig(t *testing.T) {
|
||
cfg := &config.HTTP2Config{Enabled: true}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
tlsCfg := server.ALPNConfig()
|
||
if tlsCfg == nil {
|
||
t.Fatal("ALPNConfig() returned nil")
|
||
}
|
||
|
||
// 验证 ALPN 协议包含 h2 和 http/1.1
|
||
foundH2 := false
|
||
foundHTTP11 := false
|
||
for _, proto := range tlsCfg.NextProtos {
|
||
if proto == "h2" {
|
||
foundH2 = true
|
||
}
|
||
if proto == "http/1.1" {
|
||
foundHTTP11 = true
|
||
}
|
||
}
|
||
|
||
if !foundH2 {
|
||
t.Error("ALPN config missing 'h2' protocol")
|
||
}
|
||
if !foundHTTP11 {
|
||
t.Error("ALPN config missing 'http/1.1' protocol")
|
||
}
|
||
}
|
||
|
||
// TestWrapTLSListener 测试 TLS 监听器包装。
|
||
func TestWrapTLSListener(t *testing.T) {
|
||
// 创建测试监听器
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
// 创建 TLS 配置
|
||
tlsConfig := &tls.Config{
|
||
NextProtos: []string{},
|
||
}
|
||
|
||
// 包装监听器
|
||
wrappedLn := WrapTLSListener(ln, tlsConfig)
|
||
if wrappedLn == nil {
|
||
t.Fatal("WrapTLSListener() returned nil")
|
||
}
|
||
|
||
// 验证 ALPN 协议已设置
|
||
if len(tlsConfig.NextProtos) == 0 {
|
||
t.Error("WrapTLSListener should set NextProtos")
|
||
}
|
||
}
|
||
|
||
// TestIsH2CEnabled 测试 H2C 启用检查。
|
||
func TestIsH2CEnabled(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
h2cEnabled bool
|
||
want bool
|
||
}{
|
||
{
|
||
name: "H2C 启用",
|
||
h2cEnabled: true,
|
||
want: true,
|
||
},
|
||
{
|
||
name: "H2C 禁用",
|
||
h2cEnabled: false,
|
||
want: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
cfg := &config.HTTP2Config{
|
||
Enabled: true,
|
||
H2CEnabled: tt.h2cEnabled,
|
||
}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
if got := server.IsH2CEnabled(); got != tt.want {
|
||
t.Errorf("IsH2CEnabled() = %v, want %v", got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestIsHTTP2Request 测试 HTTP/2 请求检测。
|
||
func TestIsHTTP2Request(t *testing.T) {
|
||
tests := []struct {
|
||
header map[string]string
|
||
name string
|
||
method string
|
||
major int
|
||
want bool
|
||
}{
|
||
{
|
||
name: "PRI 方法",
|
||
method: "PRI",
|
||
want: true,
|
||
},
|
||
{
|
||
name: "HTTP/2 版本",
|
||
major: 2,
|
||
want: true,
|
||
},
|
||
{
|
||
name: "HTTP/1.1",
|
||
method: "GET",
|
||
major: 1,
|
||
want: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(_ *testing.T) {
|
||
// 这里只测试基本的逻辑,完整测试需要创建 http.Request
|
||
// 在实际集成测试中会覆盖
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestSettings 测试 HTTP/2 设置。
|
||
func TestSettings(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
settings Settings
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "默认设置",
|
||
settings: Settings{
|
||
HeaderTableSize: 4096,
|
||
EnablePush: true,
|
||
MaxConcurrentStreams: 250,
|
||
InitialWindowSize: 65535,
|
||
MaxFrameSize: 16384,
|
||
MaxHeaderListSize: 1048576,
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "零并发流",
|
||
settings: Settings{
|
||
MaxConcurrentStreams: 0,
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "无效帧大小",
|
||
settings: Settings{
|
||
MaxConcurrentStreams: 100,
|
||
MaxFrameSize: 1024, // 小于最小值 16384
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "帧大小过大",
|
||
settings: Settings{
|
||
MaxConcurrentStreams: 100,
|
||
MaxFrameSize: 16777216, // 超过最大值 16777215
|
||
},
|
||
wantErr: true,
|
||
},
|
||
{
|
||
name: "零头部列表大小",
|
||
settings: Settings{
|
||
MaxConcurrentStreams: 100,
|
||
MaxHeaderListSize: 0,
|
||
},
|
||
wantErr: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
err := ValidateSettings(tt.settings)
|
||
if tt.wantErr {
|
||
if err == nil {
|
||
t.Errorf("ValidateSettings() expected error, got nil")
|
||
}
|
||
return
|
||
}
|
||
if err != nil {
|
||
t.Errorf("ValidateSettings() unexpected error: %v", err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestDefaultSettings 测试默认 HTTP/2 设置。
|
||
func TestDefaultSettings(t *testing.T) {
|
||
settings := DefaultSettings()
|
||
|
||
if settings.HeaderTableSize == 0 {
|
||
t.Error("Default HeaderTableSize should not be zero")
|
||
}
|
||
if settings.MaxConcurrentStreams == 0 {
|
||
t.Error("Default MaxConcurrentStreams should not be zero")
|
||
}
|
||
if settings.InitialWindowSize == 0 {
|
||
t.Error("Default InitialWindowSize should not be zero")
|
||
}
|
||
if settings.MaxFrameSize == 0 {
|
||
t.Error("Default MaxFrameSize should not be zero")
|
||
}
|
||
if settings.MaxHeaderListSize == 0 {
|
||
t.Error("Default MaxHeaderListSize should not be zero")
|
||
}
|
||
}
|
||
|
||
// TestParseSettings 测试从配置解析 HTTP/2 设置。
|
||
func TestParseSettings(t *testing.T) {
|
||
cfg := &config.HTTP2Config{
|
||
Enabled: true,
|
||
MaxConcurrentStreams: 200,
|
||
MaxHeaderListSize: 2097152, // 2MB
|
||
PushEnabled: true,
|
||
}
|
||
|
||
settings := ParseSettings(cfg)
|
||
|
||
if settings.MaxConcurrentStreams != 200 {
|
||
t.Errorf("ParseSettings() MaxConcurrentStreams = %d, want 200", settings.MaxConcurrentStreams)
|
||
}
|
||
if settings.MaxHeaderListSize != 2097152 {
|
||
t.Errorf("ParseSettings() MaxHeaderListSize = %d, want 2097152", settings.MaxHeaderListSize)
|
||
}
|
||
if !settings.EnablePush {
|
||
t.Error("ParseSettings() EnablePush should be true")
|
||
}
|
||
}
|
||
|
||
// TestConnectionPool 测试连接池。
|
||
func TestConnectionPool(t *testing.T) {
|
||
pool := newConnectionPool()
|
||
|
||
// 创建测试连接
|
||
ln1, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener 1: %v", err)
|
||
}
|
||
defer func() {
|
||
if cerr := ln1.Close(); cerr != nil {
|
||
t.Logf("Failed to close listener 1: %v", cerr)
|
||
}
|
||
}()
|
||
|
||
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener 2: %v", err)
|
||
}
|
||
defer func() {
|
||
if cerr := ln2.Close(); cerr != nil {
|
||
t.Logf("Failed to close listener 2: %v", cerr)
|
||
}
|
||
}()
|
||
|
||
// 测试添加连接
|
||
conn1, err := net.Dial("tcp", ln1.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial listener 1: %v", err)
|
||
}
|
||
if conn1 != nil {
|
||
defer func() {
|
||
if err := conn1.Close(); err != nil {
|
||
t.Logf("Failed to close connection 1: %v", err)
|
||
}
|
||
}()
|
||
pool.add("key1", conn1)
|
||
|
||
// 测试获取连接
|
||
conns := pool.get("key1")
|
||
if len(conns) != 1 {
|
||
t.Errorf("Expected 1 connection, got %d", len(conns))
|
||
}
|
||
|
||
// 测试计数
|
||
if count := pool.count("key1"); count != 1 {
|
||
t.Errorf("Expected count 1, got %d", count)
|
||
}
|
||
|
||
// 测试移除连接
|
||
pool.remove("key1", conn1)
|
||
if count := pool.count("key1"); count != 0 {
|
||
t.Errorf("Expected count 0 after remove, got %d", count)
|
||
}
|
||
}
|
||
}
|
||
|
||
// TestCanonicalHeaderKey 测试规范化头部键。
|
||
func TestCanonicalHeaderKey(t *testing.T) {
|
||
tests := []struct {
|
||
input string
|
||
want string
|
||
}{
|
||
{"content-type", "Content-Type"},
|
||
{"CONTENT-TYPE", "Content-Type"},
|
||
{"Content-Type", "Content-Type"},
|
||
{"x-custom-header", "X-Custom-Header"},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.input, func(t *testing.T) {
|
||
got := canonicalHeaderKey(tt.input)
|
||
if got != tt.want {
|
||
t.Errorf("canonicalHeaderKey(%q) = %q, want %q", tt.input, got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestValidateSettings_InitialWindowSize 测试 InitialWindowSize 超限。
|
||
func TestValidateSettings_InitialWindowSize(t *testing.T) {
|
||
settings := Settings{
|
||
MaxConcurrentStreams: 100,
|
||
MaxFrameSize: 16384,
|
||
MaxHeaderListSize: 1048576,
|
||
InitialWindowSize: 2147483648, // 超过 2^31-1
|
||
}
|
||
|
||
err := ValidateSettings(settings)
|
||
if err == nil {
|
||
t.Error("ValidateSettings() expected error for InitialWindowSize > 2^31-1")
|
||
}
|
||
}
|
||
|
||
// TestServe_AcceptError 测试 Accept 错误处理。
|
||
func TestServe_AcceptError(t *testing.T) {
|
||
cfg := &config.HTTP2Config{Enabled: true}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
// 创建一个已关闭的监听器
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
|
||
// 启动服务器
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- server.Serve(ln)
|
||
}()
|
||
|
||
// 关闭监听器触发 Accept 错误
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
|
||
// 停止服务器
|
||
_ = server.Stop()
|
||
|
||
// 服务器应该正常退出
|
||
select {
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
t.Errorf("Serve() unexpected error: %v", err)
|
||
}
|
||
case <-time.After(2 * time.Second):
|
||
t.Error("Serve() did not exit in time")
|
||
}
|
||
}
|
||
|
||
// TestServe_AlreadyRunning 测试服务器重复启动。
|
||
func TestServe_AlreadyRunning(t *testing.T) {
|
||
cfg := &config.HTTP2Config{Enabled: true}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
// 启动服务器
|
||
go func() {
|
||
_ = server.Serve(ln)
|
||
}()
|
||
|
||
// 等待服务器启动
|
||
time.Sleep(50 * time.Millisecond)
|
||
|
||
// 尝试再次启动
|
||
err = server.Serve(ln)
|
||
if err == nil {
|
||
t.Error("Serve() should return error when already running")
|
||
}
|
||
|
||
// 停止服务器
|
||
_ = server.Stop()
|
||
}
|
||
|
||
// TestStop_GracefulShutdownTimeout 测试优雅关闭超时。
|
||
func TestStop_GracefulShutdownTimeout(t *testing.T) {
|
||
cfg := &config.HTTP2Config{
|
||
Enabled: true,
|
||
GracefulShutdownTimeout: 100 * time.Millisecond,
|
||
}
|
||
|
||
handler := func(ctx *fasthttp.RequestCtx) {
|
||
// 模拟长时间处理
|
||
time.Sleep(2 * time.Second)
|
||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||
}
|
||
|
||
server, err := NewServer(cfg, handler, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
// 启动服务器
|
||
go func() {
|
||
_ = server.Serve(ln)
|
||
}()
|
||
|
||
// 等待服务器启动
|
||
time.Sleep(50 * time.Millisecond)
|
||
|
||
// 停止服务器(应该超时)
|
||
start := time.Now()
|
||
_ = server.Stop()
|
||
elapsed := time.Since(start)
|
||
|
||
// 应该在超时后返回
|
||
if elapsed > 500*time.Millisecond {
|
||
t.Errorf("Stop() took too long: %v", elapsed)
|
||
}
|
||
}
|
||
|
||
// TestStop_NotRunning 测试停止未运行的服务器。
|
||
func TestStop_NotRunning(t *testing.T) {
|
||
cfg := &config.HTTP2Config{Enabled: true}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
// 停止未运行的服务器应该返回 nil
|
||
err = server.Stop()
|
||
if err != nil {
|
||
t.Errorf("Stop() on non-running server should return nil, got: %v", err)
|
||
}
|
||
}
|
||
|
||
// TestHandleH2C 测试 H2C 处理。
|
||
func TestHandleH2C(t *testing.T) {
|
||
cfg := &config.HTTP2Config{Enabled: true}
|
||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||
if err != nil {
|
||
t.Fatalf("NewServer() error: %v", err)
|
||
}
|
||
|
||
// 创建一个 mock 连接
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := conn.Close(); err != nil {
|
||
t.Logf("Failed to close connection: %v", err)
|
||
}
|
||
}()
|
||
|
||
// HandleH2C 应该返回 false(不支持 H2C)
|
||
handled, err := server.HandleH2C(conn)
|
||
if handled {
|
||
t.Error("HandleH2C() should return false (H2C not supported)")
|
||
}
|
||
if err != nil {
|
||
t.Errorf("HandleH2C() should return nil error, got: %v", err)
|
||
}
|
||
}
|
||
|
||
// TestH2CConnRead 测试 h2cConn.Read。
|
||
func TestH2CConnRead(t *testing.T) {
|
||
// 创建一个测试用的连接
|
||
server, client := net.Pipe()
|
||
defer func() {
|
||
_ = server.Close()
|
||
_ = client.Close()
|
||
}()
|
||
|
||
// 创建 h2cConn
|
||
h2c := &h2cConn{
|
||
Conn: server,
|
||
reader: nil,
|
||
}
|
||
|
||
// 测试无 reader 的读取
|
||
data := []byte("test data")
|
||
go func() {
|
||
_, _ = client.Write(data)
|
||
}()
|
||
|
||
buf := make([]byte, 100)
|
||
n, err := h2c.Read(buf)
|
||
if err != nil {
|
||
t.Errorf("h2cConn.Read() error: %v", err)
|
||
}
|
||
if n != len(data) {
|
||
t.Errorf("h2cConn.Read() n = %d, want %d", n, len(data))
|
||
}
|
||
}
|
||
|
||
// TestH2CConnRead_WithReader 测试 h2cConn.Read 带 reader。
|
||
func TestH2CConnRead_WithReader(t *testing.T) {
|
||
// 创建一个测试用的连接
|
||
server, client := net.Pipe()
|
||
defer func() {
|
||
_ = server.Close()
|
||
_ = client.Close()
|
||
}()
|
||
|
||
// 创建带 reader 的 h2cConn
|
||
reader := bufio.NewReader(bytes.NewReader([]byte("prefetched")))
|
||
h2c := &h2cConn{
|
||
Conn: server,
|
||
reader: reader,
|
||
}
|
||
|
||
buf := make([]byte, 100)
|
||
n, err := h2c.Read(buf)
|
||
if err != nil {
|
||
t.Errorf("h2cConn.Read() error: %v", err)
|
||
}
|
||
if n == 0 {
|
||
t.Error("h2cConn.Read() should read from reader")
|
||
}
|
||
}
|
||
|
||
// TestIsHTTP2Request 测试 HTTP/2 请求检测。
|
||
func TestIsHTTP2Request_Full(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
method string
|
||
major int
|
||
header map[string]string
|
||
want bool
|
||
}{
|
||
{
|
||
name: "PRI 方法",
|
||
method: "PRI",
|
||
major: 1,
|
||
want: true,
|
||
},
|
||
{
|
||
name: "HTTP/2 版本",
|
||
method: "GET",
|
||
major: 2,
|
||
want: true,
|
||
},
|
||
{
|
||
name: "HTTP/1.1",
|
||
method: "GET",
|
||
major: 1,
|
||
want: false,
|
||
},
|
||
{
|
||
name: "带 :method 头",
|
||
method: "GET",
|
||
major: 1,
|
||
header: map[string]string{":method": "GET"},
|
||
want: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
req, err := http.NewRequest(tt.method, "http://example.com/", nil)
|
||
if err != nil {
|
||
t.Fatalf("Failed to create request: %v", err)
|
||
}
|
||
|
||
req.ProtoMajor = tt.major
|
||
|
||
for k, v := range tt.header {
|
||
req.Header.Set(k, v)
|
||
}
|
||
|
||
got := IsHTTP2Request(req)
|
||
if got != tt.want {
|
||
t.Errorf("IsHTTP2Request() = %v, want %v", got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestGetALPNProtocol 测试获取 ALPN 协议。
|
||
func TestGetALPNProtocol(t *testing.T) {
|
||
// 测试非 TLS 连接
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := conn.Close(); err != nil {
|
||
t.Logf("Failed to close connection: %v", err)
|
||
}
|
||
}()
|
||
|
||
// 非 TLS 连接应返回空字符串
|
||
proto := GetALPNProtocol(conn)
|
||
if proto != "" {
|
||
t.Errorf("GetALPNProtocol() on non-TLS connection = %q, want empty", proto)
|
||
}
|
||
}
|
||
|
||
// TestSupportsHTTP2 测试 HTTP/2 支持检测。
|
||
func TestSupportsHTTP2(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
method string
|
||
major int
|
||
header map[string]string
|
||
want bool
|
||
}{
|
||
{
|
||
name: "HTTP/2 请求",
|
||
method: "GET",
|
||
major: 2,
|
||
want: true,
|
||
},
|
||
{
|
||
name: "H2C 升级头",
|
||
method: "GET",
|
||
major: 1,
|
||
header: map[string]string{"Upgrade": "h2c"},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "HTTP2-Settings 头",
|
||
method: "GET",
|
||
major: 1,
|
||
header: map[string]string{"HTTP2-Settings": "test"},
|
||
want: true,
|
||
},
|
||
{
|
||
name: "HTTP/1.1 无升级",
|
||
method: "GET",
|
||
major: 1,
|
||
want: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
req, err := http.NewRequest(tt.method, "http://example.com/", nil)
|
||
if err != nil {
|
||
t.Fatalf("Failed to create request: %v", err)
|
||
}
|
||
|
||
req.ProtoMajor = tt.major
|
||
|
||
for k, v := range tt.header {
|
||
req.Header.Set(k, v)
|
||
}
|
||
|
||
got := SupportsHTTP2(req)
|
||
if got != tt.want {
|
||
t.Errorf("SupportsHTTP2() = %v, want %v", got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestWrapTLSListener_GetConfigForClient 测试 TLS 监听器的 GetConfigForClient 回调。
|
||
func TestWrapTLSListener_GetConfigForClient(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
tlsConfig := &tls.Config{
|
||
NextProtos: []string{},
|
||
GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
|
||
return nil, nil
|
||
},
|
||
}
|
||
|
||
_ = WrapTLSListener(ln, tlsConfig)
|
||
}
|
||
|
||
// TestWrapTLSListener_GetConfigForClientError 测试 GetConfigForClient 返回错误。
|
||
func TestWrapTLSListener_GetConfigForClientError(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
expectedErr := errors.New("client config error")
|
||
tlsConfig := &tls.Config{
|
||
NextProtos: []string{},
|
||
GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
|
||
return nil, expectedErr
|
||
},
|
||
}
|
||
|
||
_ = WrapTLSListener(ln, tlsConfig)
|
||
}
|
||
|
||
// TestConnectionPool_CloseAll 测试连接池关闭所有连接。
|
||
func TestConnectionPool_CloseAll(t *testing.T) {
|
||
pool := newConnectionPool()
|
||
|
||
// 创建多个连接
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
conn1, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
conn2, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
|
||
pool.add("key1", conn1)
|
||
pool.add("key1", conn2)
|
||
|
||
// 关闭所有连接
|
||
pool.closeAll()
|
||
|
||
// 验证连接池已清空
|
||
if count := pool.count("key1"); count != 0 {
|
||
t.Errorf("Expected count 0 after closeAll, got %d", count)
|
||
}
|
||
}
|
||
|
||
// TestConnectionPool_RemoveNonExistent 测试移除不存在的连接。
|
||
func TestConnectionPool_RemoveNonExistent(t *testing.T) {
|
||
pool := newConnectionPool()
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() {
|
||
if err := ln.Close(); err != nil {
|
||
t.Logf("Failed to close listener: %v", err)
|
||
}
|
||
}()
|
||
|
||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer func() {
|
||
_ = conn.Close()
|
||
}()
|
||
|
||
// 移除不存在的 key/conn 组合不应 panic
|
||
pool.remove("nonexistent", conn)
|
||
pool.remove("key1", conn)
|
||
}
|