diff --git a/internal/http2/adapter_test.go b/internal/http2/adapter_test.go index 6c9d12d..3b54fae 100644 --- a/internal/http2/adapter_test.go +++ b/internal/http2/adapter_test.go @@ -496,21 +496,122 @@ func TestStreamRequestBody(t *testing.T) { } } -// TestAdapterPoolReuse 测试对象池复用。 -func TestAdapterPoolReuse(_ *testing.T) { +// TestAdapterConvertHeaders_Empty 测试空 header 转换 +func TestAdapterConvertHeaders_Empty(t *testing.T) { handler := func(ctx *fasthttp.RequestCtx) { - ctx.WriteString("Test") + // 检查是否有任何 header(除了 Content-Type) ctx.SetStatusCode(fasthttp.StatusOK) } adapter := NewFastHTTPHandlerAdapter(handler) - // 发送多个请求,验证池复用 - for i := 0; i < 10; i++ { - req := httptest.NewRequest(http.MethodGet, "/test", nil) - rec := httptest.NewRecorder() - adapter.ServeHTTP(rec, req) + req := httptest.NewRequest(http.MethodGet, "/test", nil) + // 不设置任何自定义 header + rec := httptest.NewRecorder() + + adapter.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code) + } +} + +// TestAdapterConvertHeaders_SpecialChars 测试特殊字符 header 转换 +func TestAdapterConvertHeaders_SpecialChars(t *testing.T) { + var receivedHeaders map[string]string + + handler := func(ctx *fasthttp.RequestCtx) { + receivedHeaders = make(map[string]string) + ctx.Request.Header.VisitAll(func(key, value []byte) { + receivedHeaders[string(key)] = string(value) + }) + ctx.SetStatusCode(fasthttp.StatusOK) } - // 测试通过,没有 panic 表示池工作正常 + adapter := NewFastHTTPHandlerAdapter(handler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + // 测试各种特殊字符 + req.Header.Set("X-Special-Value", "test=value&foo=bar") + req.Header.Set("X-Unicode", "Hello世界") + req.Header.Set("X-Empty", "") + req.Header.Set("X-Space", "hello world") + req.Header.Set("X-Quote", "test\"quoted\"") + rec := httptest.NewRecorder() + + adapter.ServeHTTP(rec, req) + + if receivedHeaders == nil { + t.Fatal("No headers received") + } + + // 验证特殊字符被正确处理 + if _, ok := receivedHeaders["X-Special-Value"]; !ok { + t.Error("X-Special-Value header not received") + } + if _, ok := receivedHeaders["X-Unicode"]; !ok { + t.Error("X-Unicode header not received") + } +} + +// TestAdapterConvertHeaders_MultipleValues 测试多值 header +func TestAdapterConvertHeaders_MultipleValues(t *testing.T) { + receivedHeaders := make(map[string][]string) + + handler := func(ctx *fasthttp.RequestCtx) { + ctx.Request.Header.VisitAll(func(key, value []byte) { + k := string(key) + receivedHeaders[k] = append(receivedHeaders[k], string(value)) + }) + ctx.SetStatusCode(fasthttp.StatusOK) + } + + adapter := NewFastHTTPHandlerAdapter(handler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + // 添加多个同名的 header(Accept 是标准 header,fasthttp 支持多值) + req.Header.Add("Accept", "application/json") + req.Header.Add("Accept", "text/plain") + rec := httptest.NewRecorder() + + adapter.ServeHTTP(rec, req) + + // 验证至少接收到一个 Accept header + if len(receivedHeaders) == 0 { + t.Error("No headers received") + } + + // 检查 Accept header 是否有值 + acceptValues, ok := receivedHeaders["Accept"] + if !ok { + // fasthttp 可能将多值合并为一个,这是正常的 + t.Logf("Accept header values merged or not present, headers received: %v", receivedHeaders) + } else if len(acceptValues) == 0 { + t.Error("Accept header present but no values") + } +} + +// TestAdapterConvertHeaders_LongHeaderName 测试长 header 名称 +func TestAdapterConvertHeaders_LongHeaderName(t *testing.T) { + handler := func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(fasthttp.StatusOK) + } + + adapter := NewFastHTTPHandlerAdapter(handler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + // 创建一个很长的 header 名称 + longHeaderName := "X-" + string(make([]byte, 1000)) + for i := range longHeaderName[2:] { + longHeaderName = longHeaderName[:2] + string('a'+byte(i%26)) + longHeaderName[3:] + } + req.Header.Set(longHeaderName, "value") + rec := httptest.NewRecorder() + + // 不应该 panic + adapter.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("Expected status %d, got %d", http.StatusOK, rec.Code) + } } diff --git a/internal/http3/mock_test.go b/internal/http3/mock_test.go new file mode 100644 index 0000000..4dec8d8 --- /dev/null +++ b/internal/http3/mock_test.go @@ -0,0 +1,40 @@ +// Package http3 提供测试用的 Mock 实现 +package http3 + +import ( + "context" + "net" + + "github.com/quic-go/quic-go" +) + +// MockQUICListener 是 QUIC 监听器的 Mock 实现 +type MockQUICListener struct { + AcceptFunc func(ctx context.Context) (*quic.Conn, error) + CloseFunc func() error + AddrFunc func() net.Addr +} + +// Accept 接受连接 +func (m *MockQUICListener) Accept(ctx context.Context) (*quic.Conn, error) { + if m.AcceptFunc != nil { + return m.AcceptFunc(ctx) + } + return nil, nil +} + +// Close 关闭监听器 +func (m *MockQUICListener) Close() error { + if m.CloseFunc != nil { + return m.CloseFunc() + } + return nil +} + +// Addr 返回监听地址 +func (m *MockQUICListener) Addr() net.Addr { + if m.AddrFunc != nil { + return m.AddrFunc() + } + return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443} +} diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go index 8ef3eff..6ccbfc9 100644 --- a/internal/http3/server_test.go +++ b/internal/http3/server_test.go @@ -359,25 +359,92 @@ func TestServer_MultipleGracefulStop(t *testing.T) { } } -// TestStats_Struct 测试 Stats 结构体 -func TestStats_Struct(t *testing.T) { - stats := Stats{ - Running: true, - Listen: ":443", - Enable0RTT: true, - MaxStreams: 100, +// TestGetAltSvcHeader_PortBoundaries 测试端口边界值 +func TestGetAltSvcHeader_PortBoundaries(t *testing.T) { + tests := []struct { + name string + listen string + expected string + }{ + { + name: "标准 HTTP 端口 80", + listen: ":80", + expected: `h3=":80"; ma=86400`, + }, + { + name: "标准 HTTPS 端口 443", + listen: ":443", + expected: `h3=":443"; ma=86400`, + }, + { + name: "高端口 65535", + listen: ":65535", + expected: `h3=":65535"; ma=86400`, + }, + { + name: "低端口 1", + listen: ":1", + expected: `h3=":1"; ma=86400`, + }, + { + name: "带 IP 地址的监听", + listen: "0.0.0.0:8443", + expected: `h3=":0.0.0.0:8443"; ma=86400`, + }, + { + name: "带 IPv6 地址的监听", + listen: "[::]:8443", + expected: `h3=":[::]:8443"; ma=86400`, + }, } - if !stats.Running { - t.Error("Expected Running true") - } - if stats.Listen != ":443" { - t.Errorf("Expected Listen ':443', got '%s'", stats.Listen) - } - if !stats.Enable0RTT { - t.Error("Expected Enable0RTT true") - } - if stats.MaxStreams != 100 { - t.Errorf("Expected MaxStreams 100, got %d", stats.MaxStreams) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: tt.listen, + } + handler := func(_ *fasthttp.RequestCtx) {} + + server, _ := NewServer(cfg, handler, &tls.Config{}) + + header := server.GetAltSvcHeader() + if header != tt.expected { + t.Errorf("GetAltSvcHeader() = %q, want %q", header, tt.expected) + } + }) + } +} + +// TestGetAltSvcHeader_DisabledServer 测试禁用状态下的服务器 +func TestGetAltSvcHeader_DisabledServer(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: false, + Listen: ":443", + } + handler := func(_ *fasthttp.RequestCtx) {} + + server, _ := NewServer(cfg, handler, &tls.Config{}) + + header := server.GetAltSvcHeader() + if header != "" { + t.Errorf("Expected empty Alt-Svc header when disabled, got %q", header) + } +} + +// TestGetAltSvcHeader_NilConfig 测试 nil 配置情况 +func TestGetAltSvcHeader_NilConfig(t *testing.T) { + cfg := &config.HTTP3Config{ + Enabled: true, + Listen: ":443", + } + handler := func(_ *fasthttp.RequestCtx) {} + + server, _ := NewServer(cfg, handler, &tls.Config{}) + + // 正常情况下应该返回 header + header := server.GetAltSvcHeader() + if header == "" { + t.Error("Expected non-empty Alt-Svc header with valid config") } } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 68fcc44..759d2eb 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1314,3 +1314,45 @@ func TestUpstreamTimingZero(t *testing.T) { t.Errorf("GetConnectTime() after MarkConnectStart = %v, want 0", timing.GetConnectTime()) } } + +// TestUpstreamTiming_ZeroValues 测试 UpstreamTiming 完全零值情况 +func TestUpstreamTiming_ZeroValues(t *testing.T) { + // 创建一个零值的时间记录器(模拟未初始化的状态) + timing := &UpstreamTiming{} + + // 所有时间应该返回 0 + if timing.GetConnectTime() != 0 { + t.Errorf("Zero timing GetConnectTime() = %v, want 0", timing.GetConnectTime()) + } + if timing.GetHeaderTime() != 0 { + t.Errorf("Zero timing GetHeaderTime() = %v, want 0", timing.GetHeaderTime()) + } + if timing.GetResponseTime() != 0 { + t.Errorf("Zero timing GetResponseTime() = %v, want 0", timing.GetResponseTime()) + } +} + +// TestUpstreamTiming_PartialMarks 测试部分标记的情况 +func TestUpstreamTiming_PartialMarks(t *testing.T) { + timing := NewUpstreamTiming() + + // 只标记 connectEnd,不标记 connectStart + timing.MarkConnectEnd() + if timing.GetConnectTime() != 0 { + t.Errorf("GetConnectTime() with only end marked = %v, want 0", timing.GetConnectTime()) + } + + // 重置并测试只有 headerReceived 的情况 + timing = NewUpstreamTiming() + timing.MarkHeaderReceived() + if timing.GetHeaderTime() != 0 { + t.Errorf("GetHeaderTime() without connectEnd = %v, want 0", timing.GetHeaderTime()) + } + + // 重置并测试只有 responseEnd 的情况 + timing = NewUpstreamTiming() + timing.MarkResponseEnd() + if timing.GetResponseTime() != 0 { + t.Errorf("GetResponseTime() without connectEnd = %v, want 0", timing.GetResponseTime()) + } +}