test(http2,http3): 添加 HTTP/2 和 HTTP/3 服务器测试

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-21 08:12:33 +08:00
parent cfb27a9b9d
commit 9f3524f641
3 changed files with 986 additions and 0 deletions

View File

@ -9,8 +9,12 @@
package http2
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"net"
"net/http"
"testing"
"time"
@ -479,3 +483,504 @@ func TestCanonicalHeaderKey(t *testing.T) {
})
}
}
// TestValidateSettings_InitialWindowSize 测试 InitialWindowSize 超限。
func TestValidateSettings_InitialWindowSize(t *testing.T) {
settings := Settings{
MaxConcurrentStreams: 100,
MaxFrameSize: 16384,
MaxHeaderListSize: 1048576,
InitialWindowSize: 2147483648, // 超过 2^31-1
}
err := ValidateSettings(settings)
if err == nil {
t.Error("ValidateSettings() expected error for InitialWindowSize > 2^31-1")
}
}
// TestServe_AcceptError 测试 Accept 错误处理。
func TestServe_AcceptError(t *testing.T) {
cfg := &config.HTTP2Config{Enabled: true}
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
// 创建一个已关闭的监听器
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
// 启动服务器
errCh := make(chan error, 1)
go func() {
errCh <- server.Serve(ln)
}()
// 关闭监听器触发 Accept 错误
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
// 停止服务器
_ = server.Stop()
// 服务器应该正常退出
select {
case err := <-errCh:
if err != nil {
t.Errorf("Serve() unexpected error: %v", err)
}
case <-time.After(2 * time.Second):
t.Error("Serve() did not exit in time")
}
}
// TestServe_AlreadyRunning 测试服务器重复启动。
func TestServe_AlreadyRunning(t *testing.T) {
cfg := &config.HTTP2Config{Enabled: true}
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
// 启动服务器
go func() {
_ = server.Serve(ln)
}()
// 等待服务器启动
time.Sleep(50 * time.Millisecond)
// 尝试再次启动
err = server.Serve(ln)
if err == nil {
t.Error("Serve() should return error when already running")
}
// 停止服务器
_ = server.Stop()
}
// TestStop_GracefulShutdownTimeout 测试优雅关闭超时。
func TestStop_GracefulShutdownTimeout(t *testing.T) {
cfg := &config.HTTP2Config{
Enabled: true,
GracefulShutdownTimeout: 100 * time.Millisecond,
}
handler := func(ctx *fasthttp.RequestCtx) {
// 模拟长时间处理
time.Sleep(2 * time.Second)
ctx.SetStatusCode(fasthttp.StatusOK)
}
server, err := NewServer(cfg, handler, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
// 启动服务器
go func() {
_ = server.Serve(ln)
}()
// 等待服务器启动
time.Sleep(50 * time.Millisecond)
// 停止服务器(应该超时)
start := time.Now()
_ = server.Stop()
elapsed := time.Since(start)
// 应该在超时后返回
if elapsed > 500*time.Millisecond {
t.Errorf("Stop() took too long: %v", elapsed)
}
}
// TestStop_NotRunning 测试停止未运行的服务器。
func TestStop_NotRunning(t *testing.T) {
cfg := &config.HTTP2Config{Enabled: true}
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
// 停止未运行的服务器应该返回 nil
err = server.Stop()
if err != nil {
t.Errorf("Stop() on non-running server should return nil, got: %v", err)
}
}
// TestHandleH2C 测试 H2C 处理。
func TestHandleH2C(t *testing.T) {
cfg := &config.HTTP2Config{Enabled: true}
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
if err != nil {
t.Fatalf("NewServer() error: %v", err)
}
// 创建一个 mock 连接
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Logf("Failed to close connection: %v", err)
}
}()
// HandleH2C 应该返回 false不支持 H2C
handled, err := server.HandleH2C(conn)
if handled {
t.Error("HandleH2C() should return false (H2C not supported)")
}
if err != nil {
t.Errorf("HandleH2C() should return nil error, got: %v", err)
}
}
// TestH2CConnRead 测试 h2cConn.Read。
func TestH2CConnRead(t *testing.T) {
// 创建一个测试用的连接
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
// 创建 h2cConn
h2c := &h2cConn{
Conn: server,
reader: nil,
}
// 测试无 reader 的读取
data := []byte("test data")
go func() {
_, _ = client.Write(data)
}()
buf := make([]byte, 100)
n, err := h2c.Read(buf)
if err != nil {
t.Errorf("h2cConn.Read() error: %v", err)
}
if n != len(data) {
t.Errorf("h2cConn.Read() n = %d, want %d", n, len(data))
}
}
// TestH2CConnRead_WithReader 测试 h2cConn.Read 带 reader。
func TestH2CConnRead_WithReader(t *testing.T) {
// 创建一个测试用的连接
server, client := net.Pipe()
defer func() {
_ = server.Close()
_ = client.Close()
}()
// 创建带 reader 的 h2cConn
reader := bufio.NewReader(bytes.NewReader([]byte("prefetched")))
h2c := &h2cConn{
Conn: server,
reader: reader,
}
buf := make([]byte, 100)
n, err := h2c.Read(buf)
if err != nil {
t.Errorf("h2cConn.Read() error: %v", err)
}
if n == 0 {
t.Error("h2cConn.Read() should read from reader")
}
}
// TestIsHTTP2Request 测试 HTTP/2 请求检测。
func TestIsHTTP2Request_Full(t *testing.T) {
tests := []struct {
name string
method string
major int
header map[string]string
want bool
}{
{
name: "PRI 方法",
method: "PRI",
major: 1,
want: true,
},
{
name: "HTTP/2 版本",
method: "GET",
major: 2,
want: true,
},
{
name: "HTTP/1.1",
method: "GET",
major: 1,
want: false,
},
{
name: "带 :method 头",
method: "GET",
major: 1,
header: map[string]string{":method": "GET"},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(tt.method, "http://example.com/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.ProtoMajor = tt.major
for k, v := range tt.header {
req.Header.Set(k, v)
}
got := IsHTTP2Request(req)
if got != tt.want {
t.Errorf("IsHTTP2Request() = %v, want %v", got, tt.want)
}
})
}
}
// TestGetALPNProtocol 测试获取 ALPN 协议。
func TestGetALPNProtocol(t *testing.T) {
// 测试非 TLS 连接
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Logf("Failed to close connection: %v", err)
}
}()
// 非 TLS 连接应返回空字符串
proto := GetALPNProtocol(conn)
if proto != "" {
t.Errorf("GetALPNProtocol() on non-TLS connection = %q, want empty", proto)
}
}
// TestSupportsHTTP2 测试 HTTP/2 支持检测。
func TestSupportsHTTP2(t *testing.T) {
tests := []struct {
name string
method string
major int
header map[string]string
want bool
}{
{
name: "HTTP/2 请求",
method: "GET",
major: 2,
want: true,
},
{
name: "H2C 升级头",
method: "GET",
major: 1,
header: map[string]string{"Upgrade": "h2c"},
want: true,
},
{
name: "HTTP2-Settings 头",
method: "GET",
major: 1,
header: map[string]string{"HTTP2-Settings": "test"},
want: true,
},
{
name: "HTTP/1.1 无升级",
method: "GET",
major: 1,
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(tt.method, "http://example.com/", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.ProtoMajor = tt.major
for k, v := range tt.header {
req.Header.Set(k, v)
}
got := SupportsHTTP2(req)
if got != tt.want {
t.Errorf("SupportsHTTP2() = %v, want %v", got, tt.want)
}
})
}
}
// TestWrapTLSListener_GetConfigForClient 测试 TLS 监听器的 GetConfigForClient 回调。
func TestWrapTLSListener_GetConfigForClient(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
tlsConfig := &tls.Config{
NextProtos: []string{},
GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
return nil, nil
},
}
_ = WrapTLSListener(ln, tlsConfig)
}
// TestWrapTLSListener_GetConfigForClientError 测试 GetConfigForClient 返回错误。
func TestWrapTLSListener_GetConfigForClientError(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
expectedErr := errors.New("client config error")
tlsConfig := &tls.Config{
NextProtos: []string{},
GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
return nil, expectedErr
},
}
_ = WrapTLSListener(ln, tlsConfig)
}
// TestConnectionPool_CloseAll 测试连接池关闭所有连接。
func TestConnectionPool_CloseAll(t *testing.T) {
pool := newConnectionPool()
// 创建多个连接
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
conn1, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
conn2, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
pool.add("key1", conn1)
pool.add("key1", conn2)
// 关闭所有连接
pool.closeAll()
// 验证连接池已清空
if count := pool.count("key1"); count != 0 {
t.Errorf("Expected count 0 after closeAll, got %d", count)
}
}
// TestConnectionPool_RemoveNonExistent 测试移除不存在的连接。
func TestConnectionPool_RemoveNonExistent(t *testing.T) {
pool := newConnectionPool()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() {
if err := ln.Close(); err != nil {
t.Logf("Failed to close listener: %v", err)
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer func() {
_ = conn.Close()
}()
// 移除不存在的 key/conn 组合不应 panic
pool.remove("nonexistent", conn)
pool.remove("key1", conn)
}

View File

@ -377,3 +377,237 @@ func (m *mockResponseWriter) Write(data []byte) (int, error) {
}
return len(data), nil
}
// TestStreamRequestBody_LargeBody 测试大请求体的流式处理
func TestStreamRequestBody_LargeBody(t *testing.T) {
adapter := NewAdapter()
// 创建大于 64KB 的请求体
largeBody := make([]byte, 100*1024) // 100KB
for i := range largeBody {
largeBody[i] = byte(i % 256)
}
req := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/test"},
Host: "localhost",
Body: io.NopCloser(bytes.NewReader(largeBody)),
ContentLength: int64(len(largeBody)),
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
adapter.convertRequest(req, ctx)
if !bytes.Equal(ctx.Request.Body(), largeBody) {
t.Errorf("Large body mismatch: expected %d bytes, got %d bytes", len(largeBody), len(ctx.Request.Body()))
}
}
// TestStreamRequestBody_UnknownContentLength 测试未知内容长度的请求体
func TestStreamRequestBody_UnknownContentLength(t *testing.T) {
adapter := NewAdapter()
bodyContent := []byte("test body with unknown length")
req := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/test"},
Host: "localhost",
Body: io.NopCloser(bytes.NewReader(bodyContent)),
ContentLength: -1, // 未知长度
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
adapter.convertRequest(req, ctx)
if !bytes.Equal(ctx.Request.Body(), bodyContent) {
t.Errorf("Body mismatch: expected %s, got %s", bodyContent, ctx.Request.Body())
}
}
// TestStreamRequestBody_NoBody 测试无请求体的情况
func TestStreamRequestBody_NoBody(t *testing.T) {
adapter := NewAdapter()
req := &http.Request{
Method: "GET",
URL: &url.URL{Path: "/test"},
Host: "localhost",
Body: nil,
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
adapter.convertRequest(req, ctx)
if len(ctx.Request.Body()) != 0 {
t.Errorf("Expected empty body, got %d bytes", len(ctx.Request.Body()))
}
}
// TestStreamRequestBody_NoBodyConstant 测试 http.NoBody 的情况
func TestStreamRequestBody_NoBodyConstant(t *testing.T) {
adapter := NewAdapter()
req := &http.Request{
Method: "GET",
URL: &url.URL{Path: "/test"},
Host: "localhost",
Body: http.NoBody,
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
adapter.convertRequest(req, ctx)
if len(ctx.Request.Body()) != 0 {
t.Errorf("Expected empty body with http.NoBody, got %d bytes", len(ctx.Request.Body()))
}
}
// TestConvertRequest_EmptyRemoteAddr 测试空远程地址
func TestConvertRequest_EmptyRemoteAddr(t *testing.T) {
adapter := NewAdapter()
req := &http.Request{
Method: "GET",
URL: &url.URL{Path: "/test"},
Host: "localhost",
RemoteAddr: "",
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
adapter.convertRequest(req, ctx)
// 空远程地址不应该导致错误
// fasthttp 会设置默认值,所以只需验证不会 panic
}
// TestConvertRequest_InvalidRemoteAddr 测试无效远程地址
func TestConvertRequest_InvalidRemoteAddr(t *testing.T) {
adapter := NewAdapter()
req := &http.Request{
Method: "GET",
URL: &url.URL{Path: "/test"},
Host: "localhost",
RemoteAddr: "invalid-address",
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
adapter.convertRequest(req, ctx)
// 无效远程地址应该被忽略,不应导致错误
// ctx.RemoteAddr() 可能为 nil 或默认值
}
// TestWrap_TypeAssertionFailure 测试类型断言失败的情况
func TestWrap_TypeAssertionFailure(t *testing.T) {
adapter := NewAdapter()
// 创建一个会触发类型断言的 handler
handler := func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("OK")
}
httpHandler := adapter.Wrap(handler)
// 执行请求
req := &http.Request{
Method: "GET",
URL: &url.URL{Path: "/test"},
Host: "localhost",
}
rw := &mockResponseWriter{}
httpHandler.ServeHTTP(rw, req)
// 应该正常处理
if rw.status != 200 {
t.Errorf("Expected status 200, got %d", rw.status)
}
}
// TestConvertResponse_EmptyBody 测试空响应体
func TestConvertResponse_EmptyBody(t *testing.T) {
adapter := NewAdapter()
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
ctx.SetStatusCode(204) // No Content
rw := &mockResponseWriter{}
adapter.convertResponse(ctx, rw)
if rw.status != 204 {
t.Errorf("Expected status 204, got %d", rw.status)
}
if len(rw.body) != 0 {
t.Errorf("Expected empty body, got %d bytes", len(rw.body))
}
}
// TestConvertRequest_Protocol 测试协议版本设置
func TestConvertRequest_Protocol(t *testing.T) {
adapter := NewAdapter()
req := &http.Request{
Method: "GET",
URL: &url.URL{Path: "/test"},
Host: "localhost",
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
adapter.convertRequest(req, ctx)
protocol := string(ctx.Request.Header.Protocol())
if protocol != "HTTP/3" {
t.Errorf("Expected protocol HTTP/3, got %s", protocol)
}
}
// errorReader 是一个会返回错误的 io.Reader
type errorReader struct{}
func (e *errorReader) Read(_ []byte) (n int, err error) {
return 0, io.ErrUnexpectedEOF
}
// TestStreamRequestBody_ReadError 测试读取请求体时的错误处理
func TestStreamRequestBody_ReadError(t *testing.T) {
adapter := NewAdapter()
req := &http.Request{
Method: "POST",
URL: &url.URL{Path: "/test"},
Host: "localhost",
Body: io.NopCloser(&errorReader{}),
ContentLength: -1, // 强制使用流式读取
}
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
// 这不应该 panic应该优雅处理错误
adapter.convertRequest(req, ctx)
// 错误读取时body 应该为空
if len(ctx.Request.Body()) != 0 {
t.Errorf("Expected empty body on read error, got %d bytes", len(ctx.Request.Body()))
}
}

View File

@ -11,7 +11,14 @@
package http3
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"testing"
"time"
@ -448,3 +455,243 @@ func TestGetAltSvcHeader_NilConfig(t *testing.T) {
t.Error("Expected non-empty Alt-Svc header with valid config")
}
}
// TestStart_AlreadyRunning 测试启动已运行的服务器
func TestStart_AlreadyRunning(t *testing.T) {
cfg := &config.HTTP3Config{
Enabled: true,
Listen: ":0", // 使用随机端口避免冲突
MaxStreams: 100,
}
handler := func(_ *fasthttp.RequestCtx) {}
cert := generateTestCertificate(t)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
server, err := NewServer(cfg, handler, tlsConfig)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 启动服务器
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer server.Stop()
// 再次启动应该失败
err = server.Start()
if err == nil {
t.Error("Expected error when starting already running server")
}
if err.Error() != "server already running" {
t.Errorf("Expected 'server already running' error, got: %v", err)
}
}
// TestStart_InvalidListenAddress 测试无效监听地址
func TestStart_InvalidListenAddress(t *testing.T) {
cfg := &config.HTTP3Config{
Enabled: true,
Listen: "invalid:address:format", // 无效地址
MaxStreams: 100,
}
handler := func(_ *fasthttp.RequestCtx) {}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{},
}
server, err := NewServer(cfg, handler, tlsConfig)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
err = server.Start()
if err == nil {
t.Error("Expected error for invalid listen address")
server.Stop()
}
}
// TestStop_RunningServer 测试停止运行中的服务器
func TestStop_RunningServer(t *testing.T) {
cfg := &config.HTTP3Config{
Enabled: true,
Listen: ":0",
MaxStreams: 100,
}
handler := func(_ *fasthttp.RequestCtx) {}
cert := generateTestCertificate(t)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
server, err := NewServer(cfg, handler, tlsConfig)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 启动服务器
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
if !server.IsRunning() {
t.Error("Server should be running after Start()")
}
// 停止服务器
err = server.Stop()
if err != nil {
t.Errorf("Unexpected error stopping server: %v", err)
}
if server.IsRunning() {
t.Error("Server should not be running after Stop()")
}
}
// TestGracefulStop_RunningServer 测试优雅停止运行中的服务器
func TestGracefulStop_RunningServer(t *testing.T) {
cfg := &config.HTTP3Config{
Enabled: true,
Listen: ":0",
MaxStreams: 100,
}
handler := func(_ *fasthttp.RequestCtx) {}
cert := generateTestCertificate(t)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
server, err := NewServer(cfg, handler, tlsConfig)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
// 启动服务器
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
if !server.IsRunning() {
t.Error("Server should be running after Start()")
}
// 优雅停止服务器
err = server.GracefulStop(5 * time.Second)
if err != nil {
t.Errorf("Unexpected error graceful stopping server: %v", err)
}
if server.IsRunning() {
t.Error("Server should not be running after GracefulStop()")
}
}
// TestStart_Enable0RTT 测试启用 0-RTT 时的警告日志
func TestStart_Enable0RTT(t *testing.T) {
cfg := &config.HTTP3Config{
Enabled: true,
Listen: ":0",
Enable0RTT: true,
MaxStreams: 100,
}
handler := func(_ *fasthttp.RequestCtx) {}
cert := generateTestCertificate(t)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
server, err := NewServer(cfg, handler, tlsConfig)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer server.Stop()
// 验证服务器已启动
if !server.IsRunning() {
t.Error("Server should be running")
}
}
// TestStart_DefaultValues 测试默认值设置
func TestStart_DefaultValues(t *testing.T) {
cfg := &config.HTTP3Config{
Enabled: true,
Listen: ":0",
MaxStreams: 0, // 使用默认值
}
handler := func(_ *fasthttp.RequestCtx) {}
cert := generateTestCertificate(t)
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
server, err := NewServer(cfg, handler, tlsConfig)
if err != nil {
t.Fatalf("Failed to create server: %v", err)
}
err = server.Start()
if err != nil {
t.Fatalf("Failed to start server: %v", err)
}
defer server.Stop()
if !server.IsRunning() {
t.Error("Server should be running")
}
}
// generateTestCertificate 生成用于测试的自签名证书
func generateTestCertificate(t *testing.T) tls.Certificate {
t.Helper()
// 使用 RSA 密钥生成自签名证书
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "localhost"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("Failed to generate RSA key: %v", err)
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
if err != nil {
t.Fatalf("Failed to create certificate: %v", err)
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
t.Fatalf("Failed to load certificate: %v", err)
}
return cert
}