From 9f3524f641f2feb3f17131f0a0fdf03acb5328d4 Mon Sep 17 00:00:00 2001 From: xfy Date: Tue, 21 Apr 2026 08:12:33 +0800 Subject: [PATCH] =?UTF-8?q?test(http2,http3):=20=E6=B7=BB=E5=8A=A0=20HTTP/?= =?UTF-8?q?2=20=E5=92=8C=20HTTP/3=20=E6=9C=8D=E5=8A=A1=E5=99=A8=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.7 --- internal/http2/server_test.go | 505 +++++++++++++++++++++++++++++++++ internal/http3/adapter_test.go | 234 +++++++++++++++ internal/http3/server_test.go | 247 ++++++++++++++++ 3 files changed, 986 insertions(+) diff --git a/internal/http2/server_test.go b/internal/http2/server_test.go index b519f38..ade01b0 100644 --- a/internal/http2/server_test.go +++ b/internal/http2/server_test.go @@ -9,8 +9,12 @@ package http2 import ( + "bufio" + "bytes" "crypto/tls" + "errors" "net" + "net/http" "testing" "time" @@ -479,3 +483,504 @@ func TestCanonicalHeaderKey(t *testing.T) { }) } } + +// TestValidateSettings_InitialWindowSize 测试 InitialWindowSize 超限。 +func TestValidateSettings_InitialWindowSize(t *testing.T) { + settings := Settings{ + MaxConcurrentStreams: 100, + MaxFrameSize: 16384, + MaxHeaderListSize: 1048576, + InitialWindowSize: 2147483648, // 超过 2^31-1 + } + + err := ValidateSettings(settings) + if err == nil { + t.Error("ValidateSettings() expected error for InitialWindowSize > 2^31-1") + } +} + +// TestServe_AcceptError 测试 Accept 错误处理。 +func TestServe_AcceptError(t *testing.T) { + cfg := &config.HTTP2Config{Enabled: true} + server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil) + if err != nil { + t.Fatalf("NewServer() error: %v", err) + } + + // 创建一个已关闭的监听器 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + + // 启动服务器 + errCh := make(chan error, 1) + go func() { + errCh <- server.Serve(ln) + }() + + // 关闭监听器触发 Accept 错误 + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + + // 停止服务器 + _ = server.Stop() + + // 服务器应该正常退出 + select { + case err := <-errCh: + if err != nil { + t.Errorf("Serve() unexpected error: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("Serve() did not exit in time") + } +} + +// TestServe_AlreadyRunning 测试服务器重复启动。 +func TestServe_AlreadyRunning(t *testing.T) { + cfg := &config.HTTP2Config{Enabled: true} + server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil) + if err != nil { + t.Fatalf("NewServer() error: %v", err) + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + // 启动服务器 + go func() { + _ = server.Serve(ln) + }() + + // 等待服务器启动 + time.Sleep(50 * time.Millisecond) + + // 尝试再次启动 + err = server.Serve(ln) + if err == nil { + t.Error("Serve() should return error when already running") + } + + // 停止服务器 + _ = server.Stop() +} + +// TestStop_GracefulShutdownTimeout 测试优雅关闭超时。 +func TestStop_GracefulShutdownTimeout(t *testing.T) { + cfg := &config.HTTP2Config{ + Enabled: true, + GracefulShutdownTimeout: 100 * time.Millisecond, + } + + handler := func(ctx *fasthttp.RequestCtx) { + // 模拟长时间处理 + time.Sleep(2 * time.Second) + ctx.SetStatusCode(fasthttp.StatusOK) + } + + server, err := NewServer(cfg, handler, nil) + if err != nil { + t.Fatalf("NewServer() error: %v", err) + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + // 启动服务器 + go func() { + _ = server.Serve(ln) + }() + + // 等待服务器启动 + time.Sleep(50 * time.Millisecond) + + // 停止服务器(应该超时) + start := time.Now() + _ = server.Stop() + elapsed := time.Since(start) + + // 应该在超时后返回 + if elapsed > 500*time.Millisecond { + t.Errorf("Stop() took too long: %v", elapsed) + } +} + +// TestStop_NotRunning 测试停止未运行的服务器。 +func TestStop_NotRunning(t *testing.T) { + cfg := &config.HTTP2Config{Enabled: true} + server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil) + if err != nil { + t.Fatalf("NewServer() error: %v", err) + } + + // 停止未运行的服务器应该返回 nil + err = server.Stop() + if err != nil { + t.Errorf("Stop() on non-running server should return nil, got: %v", err) + } +} + +// TestHandleH2C 测试 H2C 处理。 +func TestHandleH2C(t *testing.T) { + cfg := &config.HTTP2Config{Enabled: true} + server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil) + if err != nil { + t.Fatalf("NewServer() error: %v", err) + } + + // 创建一个 mock 连接 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Logf("Failed to close connection: %v", err) + } + }() + + // HandleH2C 应该返回 false(不支持 H2C) + handled, err := server.HandleH2C(conn) + if handled { + t.Error("HandleH2C() should return false (H2C not supported)") + } + if err != nil { + t.Errorf("HandleH2C() should return nil error, got: %v", err) + } +} + +// TestH2CConnRead 测试 h2cConn.Read。 +func TestH2CConnRead(t *testing.T) { + // 创建一个测试用的连接 + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + // 创建 h2cConn + h2c := &h2cConn{ + Conn: server, + reader: nil, + } + + // 测试无 reader 的读取 + data := []byte("test data") + go func() { + _, _ = client.Write(data) + }() + + buf := make([]byte, 100) + n, err := h2c.Read(buf) + if err != nil { + t.Errorf("h2cConn.Read() error: %v", err) + } + if n != len(data) { + t.Errorf("h2cConn.Read() n = %d, want %d", n, len(data)) + } +} + +// TestH2CConnRead_WithReader 测试 h2cConn.Read 带 reader。 +func TestH2CConnRead_WithReader(t *testing.T) { + // 创建一个测试用的连接 + server, client := net.Pipe() + defer func() { + _ = server.Close() + _ = client.Close() + }() + + // 创建带 reader 的 h2cConn + reader := bufio.NewReader(bytes.NewReader([]byte("prefetched"))) + h2c := &h2cConn{ + Conn: server, + reader: reader, + } + + buf := make([]byte, 100) + n, err := h2c.Read(buf) + if err != nil { + t.Errorf("h2cConn.Read() error: %v", err) + } + if n == 0 { + t.Error("h2cConn.Read() should read from reader") + } +} + +// TestIsHTTP2Request 测试 HTTP/2 请求检测。 +func TestIsHTTP2Request_Full(t *testing.T) { + tests := []struct { + name string + method string + major int + header map[string]string + want bool + }{ + { + name: "PRI 方法", + method: "PRI", + major: 1, + want: true, + }, + { + name: "HTTP/2 版本", + method: "GET", + major: 2, + want: true, + }, + { + name: "HTTP/1.1", + method: "GET", + major: 1, + want: false, + }, + { + name: "带 :method 头", + method: "GET", + major: 1, + header: map[string]string{":method": "GET"}, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest(tt.method, "http://example.com/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + req.ProtoMajor = tt.major + + for k, v := range tt.header { + req.Header.Set(k, v) + } + + got := IsHTTP2Request(req) + if got != tt.want { + t.Errorf("IsHTTP2Request() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestGetALPNProtocol 测试获取 ALPN 协议。 +func TestGetALPNProtocol(t *testing.T) { + // 测试非 TLS 连接 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer func() { + if err := conn.Close(); err != nil { + t.Logf("Failed to close connection: %v", err) + } + }() + + // 非 TLS 连接应返回空字符串 + proto := GetALPNProtocol(conn) + if proto != "" { + t.Errorf("GetALPNProtocol() on non-TLS connection = %q, want empty", proto) + } +} + +// TestSupportsHTTP2 测试 HTTP/2 支持检测。 +func TestSupportsHTTP2(t *testing.T) { + tests := []struct { + name string + method string + major int + header map[string]string + want bool + }{ + { + name: "HTTP/2 请求", + method: "GET", + major: 2, + want: true, + }, + { + name: "H2C 升级头", + method: "GET", + major: 1, + header: map[string]string{"Upgrade": "h2c"}, + want: true, + }, + { + name: "HTTP2-Settings 头", + method: "GET", + major: 1, + header: map[string]string{"HTTP2-Settings": "test"}, + want: true, + }, + { + name: "HTTP/1.1 无升级", + method: "GET", + major: 1, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest(tt.method, "http://example.com/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + req.ProtoMajor = tt.major + + for k, v := range tt.header { + req.Header.Set(k, v) + } + + got := SupportsHTTP2(req) + if got != tt.want { + t.Errorf("SupportsHTTP2() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestWrapTLSListener_GetConfigForClient 测试 TLS 监听器的 GetConfigForClient 回调。 +func TestWrapTLSListener_GetConfigForClient(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + tlsConfig := &tls.Config{ + NextProtos: []string{}, + GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + }, + } + + _ = WrapTLSListener(ln, tlsConfig) +} + +// TestWrapTLSListener_GetConfigForClientError 测试 GetConfigForClient 返回错误。 +func TestWrapTLSListener_GetConfigForClientError(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + expectedErr := errors.New("client config error") + tlsConfig := &tls.Config{ + NextProtos: []string{}, + GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + return nil, expectedErr + }, + } + + _ = WrapTLSListener(ln, tlsConfig) +} + +// TestConnectionPool_CloseAll 测试连接池关闭所有连接。 +func TestConnectionPool_CloseAll(t *testing.T) { + pool := newConnectionPool() + + // 创建多个连接 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + conn1, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + conn2, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + + pool.add("key1", conn1) + pool.add("key1", conn2) + + // 关闭所有连接 + pool.closeAll() + + // 验证连接池已清空 + if count := pool.count("key1"); count != 0 { + t.Errorf("Expected count 0 after closeAll, got %d", count) + } +} + +// TestConnectionPool_RemoveNonExistent 测试移除不存在的连接。 +func TestConnectionPool_RemoveNonExistent(t *testing.T) { + pool := newConnectionPool() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer func() { + if err := ln.Close(); err != nil { + t.Logf("Failed to close listener: %v", err) + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer func() { + _ = conn.Close() + }() + + // 移除不存在的 key/conn 组合不应 panic + pool.remove("nonexistent", conn) + pool.remove("key1", conn) +} diff --git a/internal/http3/adapter_test.go b/internal/http3/adapter_test.go index 2ddc3ca..8f61b92 100644 --- a/internal/http3/adapter_test.go +++ b/internal/http3/adapter_test.go @@ -377,3 +377,237 @@ func (m *mockResponseWriter) Write(data []byte) (int, error) { } return len(data), nil } + +// TestStreamRequestBody_LargeBody 测试大请求体的流式处理 +func TestStreamRequestBody_LargeBody(t *testing.T) { + adapter := NewAdapter() + + // 创建大于 64KB 的请求体 + largeBody := make([]byte, 100*1024) // 100KB + for i := range largeBody { + largeBody[i] = byte(i % 256) + } + + req := &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + Body: io.NopCloser(bytes.NewReader(largeBody)), + ContentLength: int64(len(largeBody)), + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + adapter.convertRequest(req, ctx) + + if !bytes.Equal(ctx.Request.Body(), largeBody) { + t.Errorf("Large body mismatch: expected %d bytes, got %d bytes", len(largeBody), len(ctx.Request.Body())) + } +} + +// TestStreamRequestBody_UnknownContentLength 测试未知内容长度的请求体 +func TestStreamRequestBody_UnknownContentLength(t *testing.T) { + adapter := NewAdapter() + + bodyContent := []byte("test body with unknown length") + req := &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + Body: io.NopCloser(bytes.NewReader(bodyContent)), + ContentLength: -1, // 未知长度 + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + adapter.convertRequest(req, ctx) + + if !bytes.Equal(ctx.Request.Body(), bodyContent) { + t.Errorf("Body mismatch: expected %s, got %s", bodyContent, ctx.Request.Body()) + } +} + +// TestStreamRequestBody_NoBody 测试无请求体的情况 +func TestStreamRequestBody_NoBody(t *testing.T) { + adapter := NewAdapter() + + req := &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + Body: nil, + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + adapter.convertRequest(req, ctx) + + if len(ctx.Request.Body()) != 0 { + t.Errorf("Expected empty body, got %d bytes", len(ctx.Request.Body())) + } +} + +// TestStreamRequestBody_NoBodyConstant 测试 http.NoBody 的情况 +func TestStreamRequestBody_NoBodyConstant(t *testing.T) { + adapter := NewAdapter() + + req := &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + Body: http.NoBody, + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + adapter.convertRequest(req, ctx) + + if len(ctx.Request.Body()) != 0 { + t.Errorf("Expected empty body with http.NoBody, got %d bytes", len(ctx.Request.Body())) + } +} + +// TestConvertRequest_EmptyRemoteAddr 测试空远程地址 +func TestConvertRequest_EmptyRemoteAddr(t *testing.T) { + adapter := NewAdapter() + + req := &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + RemoteAddr: "", + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + adapter.convertRequest(req, ctx) + + // 空远程地址不应该导致错误 + // fasthttp 会设置默认值,所以只需验证不会 panic +} + +// TestConvertRequest_InvalidRemoteAddr 测试无效远程地址 +func TestConvertRequest_InvalidRemoteAddr(t *testing.T) { + adapter := NewAdapter() + + req := &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + RemoteAddr: "invalid-address", + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + adapter.convertRequest(req, ctx) + + // 无效远程地址应该被忽略,不应导致错误 + // ctx.RemoteAddr() 可能为 nil 或默认值 +} + +// TestWrap_TypeAssertionFailure 测试类型断言失败的情况 +func TestWrap_TypeAssertionFailure(t *testing.T) { + adapter := NewAdapter() + + // 创建一个会触发类型断言的 handler + handler := func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + ctx.SetBodyString("OK") + } + + httpHandler := adapter.Wrap(handler) + + // 执行请求 + req := &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + } + + rw := &mockResponseWriter{} + httpHandler.ServeHTTP(rw, req) + + // 应该正常处理 + if rw.status != 200 { + t.Errorf("Expected status 200, got %d", rw.status) + } +} + +// TestConvertResponse_EmptyBody 测试空响应体 +func TestConvertResponse_EmptyBody(t *testing.T) { + adapter := NewAdapter() + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + ctx.SetStatusCode(204) // No Content + + rw := &mockResponseWriter{} + + adapter.convertResponse(ctx, rw) + + if rw.status != 204 { + t.Errorf("Expected status 204, got %d", rw.status) + } + + if len(rw.body) != 0 { + t.Errorf("Expected empty body, got %d bytes", len(rw.body)) + } +} + +// TestConvertRequest_Protocol 测试协议版本设置 +func TestConvertRequest_Protocol(t *testing.T) { + adapter := NewAdapter() + + req := &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + adapter.convertRequest(req, ctx) + + protocol := string(ctx.Request.Header.Protocol()) + if protocol != "HTTP/3" { + t.Errorf("Expected protocol HTTP/3, got %s", protocol) + } +} + +// errorReader 是一个会返回错误的 io.Reader +type errorReader struct{} + +func (e *errorReader) Read(_ []byte) (n int, err error) { + return 0, io.ErrUnexpectedEOF +} + +// TestStreamRequestBody_ReadError 测试读取请求体时的错误处理 +func TestStreamRequestBody_ReadError(t *testing.T) { + adapter := NewAdapter() + + req := &http.Request{ + Method: "POST", + URL: &url.URL{Path: "/test"}, + Host: "localhost", + Body: io.NopCloser(&errorReader{}), + ContentLength: -1, // 强制使用流式读取 + } + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + // 这不应该 panic,应该优雅处理错误 + adapter.convertRequest(req, ctx) + + // 错误读取时,body 应该为空 + if len(ctx.Request.Body()) != 0 { + t.Errorf("Expected empty body on read error, got %d bytes", len(ctx.Request.Body())) + } +} diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go index 6ccbfc9..445e7d0 100644 --- a/internal/http3/server_test.go +++ b/internal/http3/server_test.go @@ -11,7 +11,14 @@ package http3 import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" "testing" "time" @@ -448,3 +455,243 @@ func TestGetAltSvcHeader_NilConfig(t *testing.T) { t.Error("Expected non-empty Alt-Svc header with valid config") } } + +// TestStart_AlreadyRunning 测试启动已运行的服务器 +func TestStart_AlreadyRunning(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: ":0", // 使用随机端口避免冲突 + MaxStreams: 100, + } + handler := func(_ *fasthttp.RequestCtx) {} + + cert := generateTestCertificate(t) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + server, err := NewServer(cfg, handler, tlsConfig) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // 启动服务器 + err = server.Start() + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + defer server.Stop() + + // 再次启动应该失败 + err = server.Start() + if err == nil { + t.Error("Expected error when starting already running server") + } + if err.Error() != "server already running" { + t.Errorf("Expected 'server already running' error, got: %v", err) + } +} + +// TestStart_InvalidListenAddress 测试无效监听地址 +func TestStart_InvalidListenAddress(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: "invalid:address:format", // 无效地址 + MaxStreams: 100, + } + handler := func(_ *fasthttp.RequestCtx) {} + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{}, + } + + server, err := NewServer(cfg, handler, tlsConfig) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + err = server.Start() + if err == nil { + t.Error("Expected error for invalid listen address") + server.Stop() + } +} + +// TestStop_RunningServer 测试停止运行中的服务器 +func TestStop_RunningServer(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: ":0", + MaxStreams: 100, + } + handler := func(_ *fasthttp.RequestCtx) {} + + cert := generateTestCertificate(t) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + server, err := NewServer(cfg, handler, tlsConfig) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // 启动服务器 + err = server.Start() + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + if !server.IsRunning() { + t.Error("Server should be running after Start()") + } + + // 停止服务器 + err = server.Stop() + if err != nil { + t.Errorf("Unexpected error stopping server: %v", err) + } + + if server.IsRunning() { + t.Error("Server should not be running after Stop()") + } +} + +// TestGracefulStop_RunningServer 测试优雅停止运行中的服务器 +func TestGracefulStop_RunningServer(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: ":0", + MaxStreams: 100, + } + handler := func(_ *fasthttp.RequestCtx) {} + + cert := generateTestCertificate(t) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + server, err := NewServer(cfg, handler, tlsConfig) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // 启动服务器 + err = server.Start() + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + + if !server.IsRunning() { + t.Error("Server should be running after Start()") + } + + // 优雅停止服务器 + err = server.GracefulStop(5 * time.Second) + if err != nil { + t.Errorf("Unexpected error graceful stopping server: %v", err) + } + + if server.IsRunning() { + t.Error("Server should not be running after GracefulStop()") + } +} + +// TestStart_Enable0RTT 测试启用 0-RTT 时的警告日志 +func TestStart_Enable0RTT(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: ":0", + Enable0RTT: true, + MaxStreams: 100, + } + handler := func(_ *fasthttp.RequestCtx) {} + + cert := generateTestCertificate(t) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + server, err := NewServer(cfg, handler, tlsConfig) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + err = server.Start() + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + defer server.Stop() + + // 验证服务器已启动 + if !server.IsRunning() { + t.Error("Server should be running") + } +} + +// TestStart_DefaultValues 测试默认值设置 +func TestStart_DefaultValues(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: ":0", + MaxStreams: 0, // 使用默认值 + } + handler := func(_ *fasthttp.RequestCtx) {} + + cert := generateTestCertificate(t) + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + server, err := NewServer(cfg, handler, tlsConfig) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + err = server.Start() + if err != nil { + t.Fatalf("Failed to start server: %v", err) + } + defer server.Stop() + + if !server.IsRunning() { + t.Error("Server should be running") + } +} + +// generateTestCertificate 生成用于测试的自签名证书 +func generateTestCertificate(t *testing.T) tls.Certificate { + t.Helper() + + // 使用 RSA 密钥生成自签名证书 + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("Failed to generate RSA key: %v", err) + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("Failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("Failed to load certificate: %v", err) + } + + return cert +}