test(http2,http3): 添加 HTTP/2 和 HTTP/3 服务器测试
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
cfb27a9b9d
commit
9f3524f641
@ -9,8 +9,12 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -479,3 +483,504 @@ func TestCanonicalHeaderKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateSettings_InitialWindowSize 测试 InitialWindowSize 超限。
|
||||
func TestValidateSettings_InitialWindowSize(t *testing.T) {
|
||||
settings := Settings{
|
||||
MaxConcurrentStreams: 100,
|
||||
MaxFrameSize: 16384,
|
||||
MaxHeaderListSize: 1048576,
|
||||
InitialWindowSize: 2147483648, // 超过 2^31-1
|
||||
}
|
||||
|
||||
err := ValidateSettings(settings)
|
||||
if err == nil {
|
||||
t.Error("ValidateSettings() expected error for InitialWindowSize > 2^31-1")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServe_AcceptError 测试 Accept 错误处理。
|
||||
func TestServe_AcceptError(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{Enabled: true}
|
||||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
// 创建一个已关闭的监听器
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
|
||||
// 关闭监听器触发 Accept 错误
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
|
||||
// 停止服务器
|
||||
_ = server.Stop()
|
||||
|
||||
// 服务器应该正常退出
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Errorf("Serve() unexpected error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Serve() did not exit in time")
|
||||
}
|
||||
}
|
||||
|
||||
// TestServe_AlreadyRunning 测试服务器重复启动。
|
||||
func TestServe_AlreadyRunning(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{Enabled: true}
|
||||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动服务器
|
||||
go func() {
|
||||
_ = server.Serve(ln)
|
||||
}()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 尝试再次启动
|
||||
err = server.Serve(ln)
|
||||
if err == nil {
|
||||
t.Error("Serve() should return error when already running")
|
||||
}
|
||||
|
||||
// 停止服务器
|
||||
_ = server.Stop()
|
||||
}
|
||||
|
||||
// TestStop_GracefulShutdownTimeout 测试优雅关闭超时。
|
||||
func TestStop_GracefulShutdownTimeout(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{
|
||||
Enabled: true,
|
||||
GracefulShutdownTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
// 模拟长时间处理
|
||||
time.Sleep(2 * time.Second)
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动服务器
|
||||
go func() {
|
||||
_ = server.Serve(ln)
|
||||
}()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 停止服务器(应该超时)
|
||||
start := time.Now()
|
||||
_ = server.Stop()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// 应该在超时后返回
|
||||
if elapsed > 500*time.Millisecond {
|
||||
t.Errorf("Stop() took too long: %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStop_NotRunning 测试停止未运行的服务器。
|
||||
func TestStop_NotRunning(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{Enabled: true}
|
||||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
// 停止未运行的服务器应该返回 nil
|
||||
err = server.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Stop() on non-running server should return nil, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleH2C 测试 H2C 处理。
|
||||
func TestHandleH2C(t *testing.T) {
|
||||
cfg := &config.HTTP2Config{Enabled: true}
|
||||
server, err := NewServer(cfg, func(_ *fasthttp.RequestCtx) {}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer() error: %v", err)
|
||||
}
|
||||
|
||||
// 创建一个 mock 连接
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("Failed to close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// HandleH2C 应该返回 false(不支持 H2C)
|
||||
handled, err := server.HandleH2C(conn)
|
||||
if handled {
|
||||
t.Error("HandleH2C() should return false (H2C not supported)")
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("HandleH2C() should return nil error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestH2CConnRead 测试 h2cConn.Read。
|
||||
func TestH2CConnRead(t *testing.T) {
|
||||
// 创建一个测试用的连接
|
||||
server, client := net.Pipe()
|
||||
defer func() {
|
||||
_ = server.Close()
|
||||
_ = client.Close()
|
||||
}()
|
||||
|
||||
// 创建 h2cConn
|
||||
h2c := &h2cConn{
|
||||
Conn: server,
|
||||
reader: nil,
|
||||
}
|
||||
|
||||
// 测试无 reader 的读取
|
||||
data := []byte("test data")
|
||||
go func() {
|
||||
_, _ = client.Write(data)
|
||||
}()
|
||||
|
||||
buf := make([]byte, 100)
|
||||
n, err := h2c.Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("h2cConn.Read() error: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("h2cConn.Read() n = %d, want %d", n, len(data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestH2CConnRead_WithReader 测试 h2cConn.Read 带 reader。
|
||||
func TestH2CConnRead_WithReader(t *testing.T) {
|
||||
// 创建一个测试用的连接
|
||||
server, client := net.Pipe()
|
||||
defer func() {
|
||||
_ = server.Close()
|
||||
_ = client.Close()
|
||||
}()
|
||||
|
||||
// 创建带 reader 的 h2cConn
|
||||
reader := bufio.NewReader(bytes.NewReader([]byte("prefetched")))
|
||||
h2c := &h2cConn{
|
||||
Conn: server,
|
||||
reader: reader,
|
||||
}
|
||||
|
||||
buf := make([]byte, 100)
|
||||
n, err := h2c.Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("h2cConn.Read() error: %v", err)
|
||||
}
|
||||
if n == 0 {
|
||||
t.Error("h2cConn.Read() should read from reader")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsHTTP2Request 测试 HTTP/2 请求检测。
|
||||
func TestIsHTTP2Request_Full(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
major int
|
||||
header map[string]string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "PRI 方法",
|
||||
method: "PRI",
|
||||
major: 1,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP/2 版本",
|
||||
method: "GET",
|
||||
major: 2,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP/1.1",
|
||||
method: "GET",
|
||||
major: 1,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "带 :method 头",
|
||||
method: "GET",
|
||||
major: 1,
|
||||
header: map[string]string{":method": "GET"},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(tt.method, "http://example.com/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.ProtoMajor = tt.major
|
||||
|
||||
for k, v := range tt.header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
got := IsHTTP2Request(req)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsHTTP2Request() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetALPNProtocol 测试获取 ALPN 协议。
|
||||
func TestGetALPNProtocol(t *testing.T) {
|
||||
// 测试非 TLS 连接
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("Failed to close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 非 TLS 连接应返回空字符串
|
||||
proto := GetALPNProtocol(conn)
|
||||
if proto != "" {
|
||||
t.Errorf("GetALPNProtocol() on non-TLS connection = %q, want empty", proto)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSupportsHTTP2 测试 HTTP/2 支持检测。
|
||||
func TestSupportsHTTP2(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
major int
|
||||
header map[string]string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "HTTP/2 请求",
|
||||
method: "GET",
|
||||
major: 2,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "H2C 升级头",
|
||||
method: "GET",
|
||||
major: 1,
|
||||
header: map[string]string{"Upgrade": "h2c"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP2-Settings 头",
|
||||
method: "GET",
|
||||
major: 1,
|
||||
header: map[string]string{"HTTP2-Settings": "test"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "HTTP/1.1 无升级",
|
||||
method: "GET",
|
||||
major: 1,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(tt.method, "http://example.com/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
req.ProtoMajor = tt.major
|
||||
|
||||
for k, v := range tt.header {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
got := SupportsHTTP2(req)
|
||||
if got != tt.want {
|
||||
t.Errorf("SupportsHTTP2() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapTLSListener_GetConfigForClient 测试 TLS 监听器的 GetConfigForClient 回调。
|
||||
func TestWrapTLSListener_GetConfigForClient(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{},
|
||||
GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
_ = WrapTLSListener(ln, tlsConfig)
|
||||
}
|
||||
|
||||
// TestWrapTLSListener_GetConfigForClientError 测试 GetConfigForClient 返回错误。
|
||||
func TestWrapTLSListener_GetConfigForClientError(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
expectedErr := errors.New("client config error")
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{},
|
||||
GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
return nil, expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
_ = WrapTLSListener(ln, tlsConfig)
|
||||
}
|
||||
|
||||
// TestConnectionPool_CloseAll 测试连接池关闭所有连接。
|
||||
func TestConnectionPool_CloseAll(t *testing.T) {
|
||||
pool := newConnectionPool()
|
||||
|
||||
// 创建多个连接
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn1, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
conn2, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
|
||||
pool.add("key1", conn1)
|
||||
pool.add("key1", conn2)
|
||||
|
||||
// 关闭所有连接
|
||||
pool.closeAll()
|
||||
|
||||
// 验证连接池已清空
|
||||
if count := pool.count("key1"); count != 0 {
|
||||
t.Errorf("Expected count 0 after closeAll, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionPool_RemoveNonExistent 测试移除不存在的连接。
|
||||
func TestConnectionPool_RemoveNonExistent(t *testing.T) {
|
||||
pool := newConnectionPool()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ln.Close(); err != nil {
|
||||
t.Logf("Failed to close listener: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
// 移除不存在的 key/conn 组合不应 panic
|
||||
pool.remove("nonexistent", conn)
|
||||
pool.remove("key1", conn)
|
||||
}
|
||||
|
||||
@ -377,3 +377,237 @@ func (m *mockResponseWriter) Write(data []byte) (int, error) {
|
||||
}
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// TestStreamRequestBody_LargeBody 测试大请求体的流式处理
|
||||
func TestStreamRequestBody_LargeBody(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
// 创建大于 64KB 的请求体
|
||||
largeBody := make([]byte, 100*1024) // 100KB
|
||||
for i := range largeBody {
|
||||
largeBody[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
Body: io.NopCloser(bytes.NewReader(largeBody)),
|
||||
ContentLength: int64(len(largeBody)),
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
if !bytes.Equal(ctx.Request.Body(), largeBody) {
|
||||
t.Errorf("Large body mismatch: expected %d bytes, got %d bytes", len(largeBody), len(ctx.Request.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamRequestBody_UnknownContentLength 测试未知内容长度的请求体
|
||||
func TestStreamRequestBody_UnknownContentLength(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
bodyContent := []byte("test body with unknown length")
|
||||
req := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
Body: io.NopCloser(bytes.NewReader(bodyContent)),
|
||||
ContentLength: -1, // 未知长度
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
if !bytes.Equal(ctx.Request.Body(), bodyContent) {
|
||||
t.Errorf("Body mismatch: expected %s, got %s", bodyContent, ctx.Request.Body())
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamRequestBody_NoBody 测试无请求体的情况
|
||||
func TestStreamRequestBody_NoBody(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
Body: nil,
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
if len(ctx.Request.Body()) != 0 {
|
||||
t.Errorf("Expected empty body, got %d bytes", len(ctx.Request.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestStreamRequestBody_NoBodyConstant 测试 http.NoBody 的情况
|
||||
func TestStreamRequestBody_NoBodyConstant(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
Body: http.NoBody,
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
if len(ctx.Request.Body()) != 0 {
|
||||
t.Errorf("Expected empty body with http.NoBody, got %d bytes", len(ctx.Request.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertRequest_EmptyRemoteAddr 测试空远程地址
|
||||
func TestConvertRequest_EmptyRemoteAddr(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
RemoteAddr: "",
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
// 空远程地址不应该导致错误
|
||||
// fasthttp 会设置默认值,所以只需验证不会 panic
|
||||
}
|
||||
|
||||
// TestConvertRequest_InvalidRemoteAddr 测试无效远程地址
|
||||
func TestConvertRequest_InvalidRemoteAddr(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
RemoteAddr: "invalid-address",
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
// 无效远程地址应该被忽略,不应导致错误
|
||||
// ctx.RemoteAddr() 可能为 nil 或默认值
|
||||
}
|
||||
|
||||
// TestWrap_TypeAssertionFailure 测试类型断言失败的情况
|
||||
func TestWrap_TypeAssertionFailure(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
// 创建一个会触发类型断言的 handler
|
||||
handler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetBodyString("OK")
|
||||
}
|
||||
|
||||
httpHandler := adapter.Wrap(handler)
|
||||
|
||||
// 执行请求
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
}
|
||||
|
||||
rw := &mockResponseWriter{}
|
||||
httpHandler.ServeHTTP(rw, req)
|
||||
|
||||
// 应该正常处理
|
||||
if rw.status != 200 {
|
||||
t.Errorf("Expected status 200, got %d", rw.status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertResponse_EmptyBody 测试空响应体
|
||||
func TestConvertResponse_EmptyBody(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
ctx.SetStatusCode(204) // No Content
|
||||
|
||||
rw := &mockResponseWriter{}
|
||||
|
||||
adapter.convertResponse(ctx, rw)
|
||||
|
||||
if rw.status != 204 {
|
||||
t.Errorf("Expected status 204, got %d", rw.status)
|
||||
}
|
||||
|
||||
if len(rw.body) != 0 {
|
||||
t.Errorf("Expected empty body, got %d bytes", len(rw.body))
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertRequest_Protocol 测试协议版本设置
|
||||
func TestConvertRequest_Protocol(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
protocol := string(ctx.Request.Header.Protocol())
|
||||
if protocol != "HTTP/3" {
|
||||
t.Errorf("Expected protocol HTTP/3, got %s", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
// errorReader 是一个会返回错误的 io.Reader
|
||||
type errorReader struct{}
|
||||
|
||||
func (e *errorReader) Read(_ []byte) (n int, err error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// TestStreamRequestBody_ReadError 测试读取请求体时的错误处理
|
||||
func TestStreamRequestBody_ReadError(t *testing.T) {
|
||||
adapter := NewAdapter()
|
||||
|
||||
req := &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/test"},
|
||||
Host: "localhost",
|
||||
Body: io.NopCloser(&errorReader{}),
|
||||
ContentLength: -1, // 强制使用流式读取
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// 这不应该 panic,应该优雅处理错误
|
||||
adapter.convertRequest(req, ctx)
|
||||
|
||||
// 错误读取时,body 应该为空
|
||||
if len(ctx.Request.Body()) != 0 {
|
||||
t.Errorf("Expected empty body on read error, got %d bytes", len(ctx.Request.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -11,7 +11,14 @@
|
||||
package http3
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -448,3 +455,243 @@ func TestGetAltSvcHeader_NilConfig(t *testing.T) {
|
||||
t.Error("Expected non-empty Alt-Svc header with valid config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_AlreadyRunning 测试启动已运行的服务器
|
||||
func TestStart_AlreadyRunning(t *testing.T) {
|
||||
cfg := &config.HTTP3Config{
|
||||
Enabled: true,
|
||||
Listen: ":0", // 使用随机端口避免冲突
|
||||
MaxStreams: 100,
|
||||
}
|
||||
handler := func(_ *fasthttp.RequestCtx) {}
|
||||
|
||||
cert := generateTestCertificate(t)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// 再次启动应该失败
|
||||
err = server.Start()
|
||||
if err == nil {
|
||||
t.Error("Expected error when starting already running server")
|
||||
}
|
||||
if err.Error() != "server already running" {
|
||||
t.Errorf("Expected 'server already running' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_InvalidListenAddress 测试无效监听地址
|
||||
func TestStart_InvalidListenAddress(t *testing.T) {
|
||||
cfg := &config.HTTP3Config{
|
||||
Enabled: true,
|
||||
Listen: "invalid:address:format", // 无效地址
|
||||
MaxStreams: 100,
|
||||
}
|
||||
handler := func(_ *fasthttp.RequestCtx) {}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{},
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
err = server.Start()
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid listen address")
|
||||
server.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// TestStop_RunningServer 测试停止运行中的服务器
|
||||
func TestStop_RunningServer(t *testing.T) {
|
||||
cfg := &config.HTTP3Config{
|
||||
Enabled: true,
|
||||
Listen: ":0",
|
||||
MaxStreams: 100,
|
||||
}
|
||||
handler := func(_ *fasthttp.RequestCtx) {}
|
||||
|
||||
cert := generateTestCertificate(t)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
|
||||
if !server.IsRunning() {
|
||||
t.Error("Server should be running after Start()")
|
||||
}
|
||||
|
||||
// 停止服务器
|
||||
err = server.Stop()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error stopping server: %v", err)
|
||||
}
|
||||
|
||||
if server.IsRunning() {
|
||||
t.Error("Server should not be running after Stop()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGracefulStop_RunningServer 测试优雅停止运行中的服务器
|
||||
func TestGracefulStop_RunningServer(t *testing.T) {
|
||||
cfg := &config.HTTP3Config{
|
||||
Enabled: true,
|
||||
Listen: ":0",
|
||||
MaxStreams: 100,
|
||||
}
|
||||
handler := func(_ *fasthttp.RequestCtx) {}
|
||||
|
||||
cert := generateTestCertificate(t)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
|
||||
if !server.IsRunning() {
|
||||
t.Error("Server should be running after Start()")
|
||||
}
|
||||
|
||||
// 优雅停止服务器
|
||||
err = server.GracefulStop(5 * time.Second)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error graceful stopping server: %v", err)
|
||||
}
|
||||
|
||||
if server.IsRunning() {
|
||||
t.Error("Server should not be running after GracefulStop()")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_Enable0RTT 测试启用 0-RTT 时的警告日志
|
||||
func TestStart_Enable0RTT(t *testing.T) {
|
||||
cfg := &config.HTTP3Config{
|
||||
Enabled: true,
|
||||
Listen: ":0",
|
||||
Enable0RTT: true,
|
||||
MaxStreams: 100,
|
||||
}
|
||||
handler := func(_ *fasthttp.RequestCtx) {}
|
||||
|
||||
cert := generateTestCertificate(t)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
// 验证服务器已启动
|
||||
if !server.IsRunning() {
|
||||
t.Error("Server should be running")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_DefaultValues 测试默认值设置
|
||||
func TestStart_DefaultValues(t *testing.T) {
|
||||
cfg := &config.HTTP3Config{
|
||||
Enabled: true,
|
||||
Listen: ":0",
|
||||
MaxStreams: 0, // 使用默认值
|
||||
}
|
||||
handler := func(_ *fasthttp.RequestCtx) {}
|
||||
|
||||
cert := generateTestCertificate(t)
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
|
||||
server, err := NewServer(cfg, handler, tlsConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start server: %v", err)
|
||||
}
|
||||
defer server.Stop()
|
||||
|
||||
if !server.IsRunning() {
|
||||
t.Error("Server should be running")
|
||||
}
|
||||
}
|
||||
|
||||
// generateTestCertificate 生成用于测试的自签名证书
|
||||
func generateTestCertificate(t *testing.T) tls.Certificate {
|
||||
t.Helper()
|
||||
|
||||
// 使用 RSA 密钥生成自签名证书
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
}
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
|
||||
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load certificate: %v", err)
|
||||
}
|
||||
|
||||
return cert
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user