lolly/internal/http2/integration_tls_test.go

375 lines
9.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package http2 提供 HTTP/2 TLS 连接集成测试。
//
// 该文件测试 HTTP/2 服务器的 TLS 相关功能:
// - TLS 握手成功/失败
// - ALPN 协商 h2/http1.1
// - HTTP/1.1 回退路径
//
// 作者: xfy
package http2
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"net"
"sync"
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
// generateTestCert creates a short-lived self-signed certificate for tests.
//
// It generates a 2048-bit RSA key and a certificate that is valid for one
// hour and names "localhost" and 127.0.0.1. The returned pool contains that
// certificate, so a client may use it as RootCAs to verify the server.
// Any failure aborts the calling test via t.Fatalf.
func generateTestCert(t *testing.T) (tls.Certificate, *x509.CertPool) {
	t.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("Failed to generate private key: %v", err)
	}

	now := time.Now()
	template := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Test Org"},
		},
		NotBefore:             now,
		NotAfter:              now.Add(time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	// The template is used as its own parent, which makes the certificate
	// self-signed.
	certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
	if err != nil {
		t.Fatalf("Failed to create certificate: %v", err)
	}

	cert := tls.Certificate{
		Certificate: [][]byte{certDER},
		PrivateKey:  privateKey,
	}

	// certDER is raw DER, not PEM. AppendCertsFromPEM would find no PEM block
	// and leave the pool empty, so parse the certificate and add it directly.
	parsed, err := x509.ParseCertificate(certDER)
	if err != nil {
		t.Fatalf("Failed to parse certificate: %v", err)
	}
	certPool := x509.NewCertPool()
	certPool.AddCert(parsed)

	return cert, certPool
}
// TestTLSHandshakeSuccess verifies that a TLS client offering only "h2" can
// complete a handshake with the server over an in-memory pipe and that ALPN
// settles on "h2".
//
// After the assertion both TLS endpoints are closed and the test waits up to
// two seconds for the server's connWg to drain; either outcome of that wait
// is only logged, it never fails the test.
func TestTLSHandshakeSuccess(t *testing.T) {
	serverCert, _ := generateTestCert(t)

	respond := func(ctx *fasthttp.RequestCtx) {
		ctx.WriteString("Hello HTTP/2")
		ctx.SetStatusCode(fasthttp.StatusOK)
	}
	h2Config := &config.HTTP2Config{
		Enabled:              true,
		MaxConcurrentStreams: 100,
	}
	serverTLS := &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		NextProtos:   []string{"h2", "http/1.1"},
	}

	srv, err := NewServer(h2Config, respond, serverTLS)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}

	// An in-memory pipe stands in for a real network connection.
	// Deferred calls run last-in-first-out, so the server side closes first.
	serverSide, clientSide := net.Pipe()
	defer clientSide.Close()
	defer serverSide.Close()

	// Wrap the server end of the pipe in TLS.
	tlsServerConn := tls.Server(serverSide, serverTLS)

	// Per the original author's note, handleConnection calls connWg.Done, so
	// the counter is incremented before the goroutine starts.
	srv.connWg.Add(1)
	go srv.handleConnection(tlsServerConn)

	// The client skips certificate verification and offers h2 only.
	tlsClientConn := tls.Client(clientSide, &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := tlsClientConn.HandshakeContext(ctx); err != nil {
		t.Fatalf("TLS handshake failed: %v", err)
	}

	// Check which protocol ALPN selected.
	if proto := tlsClientConn.ConnectionState().NegotiatedProtocol; proto != "h2" {
		t.Errorf("Expected negotiated protocol 'h2', got '%s'", proto)
	}

	// Tear down both TLS endpoints.
	_ = tlsClientConn.Close()
	_ = tlsServerConn.Close()

	// Wait (bounded) for the connection handler to finish.
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		srv.connWg.Wait()
	}()

	select {
	case <-time.After(2 * time.Second):
		t.Log("Timeout waiting for connection handling")
	case <-finished:
		t.Log("Connection handling completed")
	}
}
// TestTLSHandshakeFailure feeds non-TLS bytes to a TLS-wrapped server
// connection and then waits for the connection handler to finish.
//
// The client never performs a handshake: it writes a fixed garbage payload,
// sleeps 200ms, and closes both ends. The test then waits up to two seconds
// for the server's connWg to drain; either outcome is only logged, so the
// test itself asserts nothing beyond successful server construction.
func TestTLSHandshakeFailure(t *testing.T) {
	serverCert, _ := generateTestCert(t)

	respond := func(ctx *fasthttp.RequestCtx) {
		ctx.WriteString("Hello")
	}
	serverTLS := &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		NextProtos:   []string{"h2"},
	}

	srv, err := NewServer(&config.HTTP2Config{Enabled: true}, respond, serverTLS)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}

	serverSide, clientSide := net.Pipe()

	// Wrap the server end of the pipe in TLS.
	tlsServerConn := tls.Server(serverSide, serverTLS)

	// Per the original author's note, handleConnection calls connWg.Done, so
	// the counter is incremented before the goroutine starts.
	srv.connWg.Add(1)
	go srv.handleConnection(tlsServerConn)

	// Send bytes that are not a TLS record instead of handshaking.
	garbage := []byte("INVALID DATA NOT TLS")
	_, _ = clientSide.Write(garbage)

	// Give the server side a moment to process the payload.
	time.Sleep(200 * time.Millisecond)

	// Tear down both ends.
	_ = clientSide.Close()
	_ = tlsServerConn.Close()

	// Wait (bounded) for the connection handler to finish.
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		srv.connWg.Wait()
	}()

	select {
	case <-time.After(2 * time.Second):
		t.Log("Timeout waiting for connection handling")
	case <-finished:
		t.Log("Connection handling completed after handshake failure")
	}
}
// NOTE(review): a leftover comment here described TestALPNNegotiationH2
// ("ALPN negotiation selects h2"), but no such test is defined at this point
// in the file — confirm whether it was removed or lives elsewhere.

// TestALPNHTTP11Fallback verifies that ALPN falls back to "http/1.1" when the
// server offers both "h2" and "http/1.1" but the client offers only
// "http/1.1".
//
// After the assertion both TLS endpoints are closed and, like the sibling
// handshake tests, the test waits up to two seconds for the server's connWg
// to drain so the handler goroutine does not outlive the test. Either outcome
// of that wait is only logged.
func TestALPNHTTP11Fallback(t *testing.T) {
	cert, _ := generateTestCert(t)

	handler := func(ctx *fasthttp.RequestCtx) {
		ctx.WriteString("HTTP/1.1 response")
		ctx.SetStatusCode(fasthttp.StatusOK)
	}
	cfg := &config.HTTP2Config{
		Enabled: true,
	}
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		NextProtos:   []string{"h2", "http/1.1"},
	}

	server, err := NewServer(cfg, handler, tlsConfig)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}

	serverConn, clientConn := net.Pipe()
	defer func() {
		_ = serverConn.Close()
		_ = clientConn.Close()
	}()

	tlsServerConn := tls.Server(serverConn, tlsConfig)

	// Per the original author's note, handleConnection calls connWg.Done, so
	// the counter is incremented before the goroutine starts.
	server.connWg.Add(1)
	go func() {
		server.handleConnection(tlsServerConn)
	}()

	// The client supports HTTP/1.1 only.
	tlsClientConfig := &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"http/1.1"},
	}
	tlsClientConn := tls.Client(clientConn, tlsClientConfig)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := tlsClientConn.HandshakeContext(ctx); err != nil {
		t.Fatalf("TLS handshake failed: %v", err)
	}

	// The negotiated protocol must be http/1.1.
	state := tlsClientConn.ConnectionState()
	if state.NegotiatedProtocol != "http/1.1" {
		t.Errorf("Expected negotiated protocol 'http/1.1', got '%s'", state.NegotiatedProtocol)
	}

	_ = tlsClientConn.Close()
	_ = tlsServerConn.Close()

	// Wait (bounded) for the handler goroutine, matching the other handshake
	// tests; previously this test returned without waiting for it.
	done := make(chan struct{})
	go func() {
		server.connWg.Wait()
		close(done)
	}()

	select {
	case <-done:
		t.Log("Connection handling completed")
	case <-time.After(2 * time.Second):
		t.Log("Timeout waiting for connection handling")
	}
}
// NOTE(review): leftover comments here described TestTLSListenerWrapper
// ("TLS listener wrapping") and TestTLSListenerExistingProtos ("NextProtos
// already set"), but neither test is defined at this point in the file —
// confirm whether they were removed or live elsewhere.

// TestServeHTTP1Fallback drives serveHTTP1 directly over an in-memory pipe:
// it writes a minimal HTTP/1.1 GET request and expects to read back a
// non-empty response within two seconds.
//
// The server is built without a TLS config (nil). Only the first Read of up
// to 1024 bytes is inspected; the response content itself is not checked.
func TestServeHTTP1Fallback(t *testing.T) {
	respond := func(ctx *fasthttp.RequestCtx) {
		ctx.WriteString("HTTP/1.1 response")
		ctx.SetStatusCode(fasthttp.StatusOK)
		ctx.Response.Header.Set("X-Test", "value")
	}

	srv, err := NewServer(&config.HTTP2Config{Enabled: true}, respond, nil)
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}

	// Deferred calls run last-in-first-out, so the server side closes first.
	serverSide, clientSide := net.Pipe()
	defer clientSide.Close()
	defer serverSide.Close()

	// Serve the server end of the pipe in the background.
	var serving sync.WaitGroup
	serving.Go(func() {
		srv.serveHTTP1(serverSide)
	})

	// Send a minimal HTTP/1.1 request.
	const rawRequest = "GET /test HTTP/1.1\r\nHost: localhost\r\n\r\n"
	_, _ = clientSide.Write([]byte(rawRequest))

	// Read whatever the server sends first, bounded by a deadline.
	_ = clientSide.SetReadDeadline(time.Now().Add(2 * time.Second))
	reply := make([]byte, 1024)
	n, err := clientSide.Read(reply)
	if err != nil {
		t.Fatalf("Failed to read response: %v", err)
	}
	if n == 0 {
		t.Error("Expected non-empty response")
	}

	// Close the client end and wait for serveHTTP1 to return.
	_ = clientSide.Close()
	serving.Wait()
}
// TestConnectionPoolOperations 测试连接池操作。
func TestConnectionPoolOperations(t *testing.T) {
pool := newConnectionPool()
// 创建模拟连接
conn1 := &mockTestConn{remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}}
conn2 := &mockTestConn{remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12346}}
// 添加连接
pool.add("client1", conn1)
pool.add("client1", conn2)
// 验证连接数
if count := len(pool.conns["client1"]); count != 2 {
t.Errorf("Expected 2 connections, got %d", count)
}
// 获取连接
conns := pool.conns["client1"]
if len(conns) != 2 {
t.Errorf("Expected 2 connections, got %d", len(conns))
}
// 移除连接
pool.remove("client1", conn1)
if count := len(pool.conns["client1"]); count != 1 {
t.Errorf("Expected 1 connection after removal, got %d", count)
}
// 关闭所有连接
pool.closeAll()
if count := len(pool.conns["client1"]); count != 0 {
t.Errorf("Expected 0 connections after closeAll, got %d", count)
}
}
// mockTestConn is a do-nothing connection used by the tests in this file.
// Every operation succeeds immediately without transferring any data.
type mockTestConn struct {
	// remoteAddr is reported by both LocalAddr and RemoteAddr.
	remoteAddr net.Addr
}

// Read consumes nothing and reports zero bytes with no error.
func (c *mockTestConn) Read(_ []byte) (int, error) {
	return 0, nil
}

// Write discards its input and reports zero bytes with no error.
func (c *mockTestConn) Write(_ []byte) (int, error) {
	return 0, nil
}

// Close always succeeds.
func (c *mockTestConn) Close() error {
	return nil
}

// LocalAddr returns the configured remoteAddr; the mock has no separate
// local address.
func (c *mockTestConn) LocalAddr() net.Addr {
	return c.remoteAddr
}

// RemoteAddr returns the configured remoteAddr.
func (c *mockTestConn) RemoteAddr() net.Addr {
	return c.remoteAddr
}

// SetDeadline ignores the deadline and always succeeds.
func (c *mockTestConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and always succeeds.
func (c *mockTestConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and always succeeds.
func (c *mockTestConn) SetWriteDeadline(_ time.Time) error {
	return nil
}