lolly/internal/proxy/websocket_test.go
xfy 27e00b84a8 fix(proxy,handler,server,stream,ratelimit): fix resource leaks and functional bugs
- proxy/proxy.go: decrement connection count on dangerous path rejection
  (line 724) to prevent connection count leak
- handler/sendfile_linux.go: return *os.File from getSocketFile and let
  linuxSendfile close it, fixing EBADF from deferred close in getSocketFd
- proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse
  and wrap targetConn with bufferedConn to consume pre-buffered frame data,
  preventing first-frame loss
- server/pool.go: use non-blocking send after starting new worker to avoid
  deadlock when queue is full
- stream/stream.go: check stopCh on non-timeout UDP read errors to prevent
  infinite loop and shutdown deadlock
- middleware/ratelimit: replace select-based close guard with sync.Once in
  StopCleanup to prevent double-close panic
2026-06-11 16:35:10 +08:00

785 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

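// Package proxy provides tests for the WebSocket proxy functionality.
//
// This file tests the features of the WebSocket proxy module, including:
// - bridge creation
// - bridge shutdown
// - nil connection handling
// - connection-closed error detection
// - dialing the target address
// - URL parsing
// - data copying
// - bidirectional data forwarding
// - timeout error handling
// - concurrent connections
// - large message forwarding
//
// Note on goroutine leak detection:
// The fasthttp library uses background worker goroutines, which makes it incompatible with goleak.
// To check for leaks, run manually: go test -race ./internal/proxy/...
//
// Author: xfy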
package proxy
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/valyala/fasthttp"
)
// TestNewWebSocketBridge checks that NewWebSocketBridge returns a non-nil
// bridge that stores the two connections it was given and is not marked
// closed.
func TestNewWebSocketBridge(t *testing.T) {
	client, _ := net.Pipe()
	target, _ := net.Pipe()
	// Cleanups run last-in-first-out, so the target end is closed first.
	t.Cleanup(func() { _ = client.Close() })
	t.Cleanup(func() { _ = target.Close() })

	b := NewWebSocketBridge(client, target)
	if b == nil {
		t.Fatal("Expected non-nil bridge")
	}

	if stored := b.clientConn; stored != client {
		t.Error("Expected clientConn to be set")
	}
	if stored := b.targetConn; stored != target {
		t.Error("Expected targetConn to be set")
	}
	if b.closed {
		t.Error("Expected closed to be false initially")
	}
}
// TestWebSocketBridge_Close checks that Close shuts down both bridged
// connections and that a second Close call also returns nil.
func TestWebSocketBridge_Close(t *testing.T) {
	clientConn, clientPeer := net.Pipe()
	targetConn, targetPeer := net.Pipe()
	// Close the peer ends as well so nothing outlives the test.
	defer func() { _ = clientPeer.Close() }()
	defer func() { _ = targetPeer.Close() }()

	bridge := NewWebSocketBridge(clientConn, targetConn)

	// Close the bridge.
	if err := bridge.Close(); err != nil {
		t.Errorf("Expected nil error, got: %v", err)
	}

	// A write on a net.Pipe end returns io.ErrClosedPipe once the opposite end
	// has been closed. The write deadline keeps the test from blocking forever
	// if Close left a connection open; a deadline error then counts as failure.
	deadline := time.Now().Add(time.Second)
	_ = clientPeer.SetWriteDeadline(deadline)
	_ = targetPeer.SetWriteDeadline(deadline)

	if _, err := clientPeer.Write([]byte("test")); !errors.Is(err, io.ErrClosedPipe) {
		t.Error("Expected error writing to closed connection")
	}
	if _, err := targetPeer.Write([]byte("test")); !errors.Is(err, io.ErrClosedPipe) {
		t.Errorf("Expected error writing to closed target connection, got: %v", err)
	}

	// Closing again must be safe.
	if err := bridge.Close(); err != nil {
		t.Errorf("Expected nil error on double close, got: %v", err)
	}
}
// TestWebSocketBridge_Close_NilConnections checks that Close returns nil for a
// bridge whose client and target connections are both nil.
func TestWebSocketBridge_Close_NilConnections(t *testing.T) {
	// The zero value has nil connections and closed == false.
	empty := &WebSocketBridge{}

	if err := empty.Close(); err != nil {
		t.Errorf("Expected nil error for nil connections, got: %v", err)
	}
}
// TestIsConnectionClosedError is a table-driven check of
// isConnectionClosedError: nil and an arbitrary error must report false,
// io.EOF must report true.
func TestIsConnectionClosedError(t *testing.T) {
	type testCase struct {
		err      error
		name     string
		expected bool
	}

	cases := []testCase{
		{name: "nil error", err: nil, expected: false},
		{name: "EOF", err: io.EOF, expected: true},
		{name: "other error", err: errors.New("some other error"), expected: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := isConnectionClosedError(tc.err); got != tc.expected {
				t.Errorf("isConnectionClosedError(%v) = %v, expected %v", tc.err, got, tc.expected)
			}
		})
	}
}
// TestDialTarget_InvalidAddress checks that dialTarget returns an error for an
// http target on 127.0.0.1 port 1, where the test expects the connection
// attempt to fail.
func TestDialTarget_InvalidAddress(t *testing.T) {
	const target = "http://127.0.0.1:1"

	if _, err := dialTarget(target, 100*time.Millisecond); err == nil {
		t.Error("Expected error for invalid address")
	}
}
// TestDialTarget_HTTPS checks the error path of dialTarget for an https
// target: the connection to 127.0.0.1 port 1 is expected to fail and the
// failure must be returned as an error.
func TestDialTarget_HTTPS(t *testing.T) {
	const target = "https://127.0.0.1:1"

	if _, err := dialTarget(target, 100*time.Millisecond); err == nil {
		t.Error("Expected error for invalid HTTPS address")
	}
}
// mockNetError is a test double for a network error. Its Timeout method always
// reports true and its Temporary method always reports false.
type mockNetError struct {
	msg string // text returned by Error
}

// Compile-time check that the mock keeps satisfying net.Error.
var _ net.Error = (*mockNetError)(nil)

// Error returns the configured message.
func (e *mockNetError) Error() string { return e.msg }

// Timeout always reports true.
func (e *mockNetError) Timeout() bool { return true }

// Temporary always reports false.
func (e *mockNetError) Temporary() bool { return false }
// TestIsConnectionClosedError_Timeout 测试超时错误
func TestIsConnectionClosedError_Timeout(t *testing.T) {
timeoutErr := &mockNetError{msg: "timeout"}
result := isConnectionClosedError(timeoutErr)
if !result {
t.Error("Expected timeout error to be treated as closed connection error")
}
}
// TestWebSocketBridge_Bridge checks bidirectional forwarding: bytes written on
// the client peer must arrive on the target peer and vice versa, and Bridge is
// expected to return nil once both peers have been closed.
func TestWebSocketBridge_Bridge(t *testing.T) {
	// In-memory pipes: client1/target1 go to the bridge, client2/target2 are
	// the ends driven by the test.
	client1, client2 := net.Pipe()
	target1, target2 := net.Pipe()
	defer func() { _ = client2.Close() }()
	defer func() { _ = target2.Close() }()

	// Bound all I/O on the test-side ends so a bridge that fails to forward
	// makes the test fail instead of blocking until the global test timeout.
	deadline := time.Now().Add(2 * time.Second)
	_ = client2.SetDeadline(deadline)
	_ = target2.SetDeadline(deadline)

	bridge := NewWebSocketBridge(client1, target1)

	// Run the bridge in the background and capture its result.
	errCh := make(chan error, 1)
	go func() {
		errCh <- bridge.Bridge()
	}()

	// Client -> target. net.Pipe writes block until read, so write in the
	// background; a failed write shows up as a read failure below.
	testData := []byte("hello from client")
	go func() {
		_, _ = client2.Write(testData)
	}()

	buf := make([]byte, 1024)
	n, err := target2.Read(buf)
	if err != nil {
		t.Fatalf("Failed to read from target: %v", err)
	}
	if string(buf[:n]) != string(testData) {
		t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
	}

	// Target -> client.
	testData2 := []byte("hello from target")
	go func() {
		_, _ = target2.Write(testData2)
	}()

	buf2 := make([]byte, 1024)
	n, err = client2.Read(buf2)
	if err != nil {
		t.Fatalf("Failed to read from client: %v", err)
	}
	if string(buf2[:n]) != string(testData2) {
		t.Errorf("Expected %q, got %q", string(testData2), string(buf2[:n]))
	}

	// Closing both peers ends the bridge.
	_ = client2.Close()
	_ = target2.Close()

	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("Bridge returned error: %v", err)
		}
	case <-time.After(1 * time.Second):
		t.Error("Bridge did not complete in time")
	}
}
// TestDialTarget_URLParsing feeds dialTarget several URL shapes (http, https,
// and a URL with a path). Every case targets 127.0.0.1 port 1, so each one is
// expected to end in a connection error.
func TestDialTarget_URLParsing(t *testing.T) {
	type testCase struct {
		name        string
		url         string
		expectError bool
	}

	cases := []testCase{
		// The connection attempt is expected to fail in all of these.
		{name: "http URL with invalid port", url: "http://127.0.0.1:1", expectError: true},
		{name: "https URL with invalid port", url: "https://127.0.0.1:1", expectError: true},
		{name: "URL with path", url: "http://127.0.0.1:1/ws", expectError: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := dialTarget(tc.url, 10*time.Millisecond)

			switch {
			case tc.expectError && err == nil:
				t.Error("Expected error, got nil")
			case !tc.expectError && err != nil:
				t.Errorf("Unexpected error: %v", err)
			}
		})
	}
}
// TestCopyData checks that copyData forwards bytes from its source connection
// to its destination connection and returns once the peers are closed, either
// with nil or with an error whose text mentions "closed".
func TestCopyData(t *testing.T) {
	// src1/dst1 are handed to copyData; src2/dst2 are driven by the test.
	src1, src2 := net.Pipe()
	dst1, dst2 := net.Pipe()
	defer func() { _ = src2.Close() }()
	defer func() { _ = dst2.Close() }()

	// Bound the test-side I/O so a copyData that never reads or writes makes
	// the test fail rather than hang.
	deadline := time.Now().Add(2 * time.Second)
	_ = src2.SetDeadline(deadline)
	_ = dst2.SetDeadline(deadline)

	bridge := &WebSocketBridge{}

	// Run the copy in the background and capture its result.
	errCh := make(chan error, 1)
	go func() {
		errCh <- bridge.copyData(dst1, src1, "test")
	}()

	// Send data; without a successful write the read below can never succeed.
	testData := []byte("test data")
	if _, err := src2.Write(testData); err != nil {
		t.Fatalf("Failed to write: %v", err)
	}

	// Receive the forwarded data.
	buf := make([]byte, 1024)
	n, err := dst2.Read(buf)
	if err != nil {
		t.Fatalf("Failed to read: %v", err)
	}
	if string(buf[:n]) != string(testData) {
		t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
	}

	// Close both peers so copyData terminates.
	_ = src2.Close()
	_ = dst2.Close()

	select {
	case err := <-errCh:
		// A connection-closed condition is acceptable here.
		if err != nil && !strings.Contains(err.Error(), "closed") {
			t.Errorf("copyData returned unexpected error: %v", err)
		}
	case <-time.After(1 * time.Second):
		t.Error("copyData did not complete in time")
	}
}
// TestBuildWebSocketUpgradeRequest is a table-driven check that the text
// produced by buildWebSocketUpgradeRequest contains the expected request line
// and forwarding headers for several path/query/host combinations.
func TestBuildWebSocketUpgradeRequest(t *testing.T) {
	type scenario struct {
		name         string
		path         string
		query        string
		host         string
		targetHost   string
		wantContains []string
	}

	scenarios := []scenario{
		{
			name:       "basic request",
			path:       "/ws",
			query:      "",
			host:       "client.example.com",
			targetHost: "backend.example.com:8080",
			wantContains: []string{
				"GET /ws HTTP/1.1",
				"Host: backend.example.com:8080",
				"X-Forwarded-For:",
				"X-Real-IP:",
				"X-Forwarded-Host: client.example.com",
			},
		},
		{
			name:       "request with query",
			path:       "/ws",
			query:      "token=abc123",
			host:       "client.example.com",
			targetHost: "backend.example.com",
			wantContains: []string{
				"GET /ws?token=abc123 HTTP/1.1",
				"Host: backend.example.com",
			},
		},
		{
			name:       "empty path defaults to slash",
			path:       "",
			query:      "",
			host:       "client.example.com",
			targetHost: "backend.example.com",
			wantContains: []string{
				"GET / HTTP/1.1",
			},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			// Build the incoming client request described by the scenario.
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.SetRequestURI(sc.path)
			if len(sc.query) > 0 {
				reqCtx.QueryArgs().Parse(sc.query)
			}
			reqCtx.Request.Header.SetHost(sc.host)

			out := buildWebSocketUpgradeRequest(reqCtx, sc.targetHost, nil)

			for _, fragment := range sc.wantContains {
				if strings.Contains(out, fragment) {
					continue
				}
				t.Errorf("buildWebSocketUpgradeRequest() missing %q in output:\n%s", fragment, out)
			}
		})
	}
}
// TestBuildWebSocketUpgradeRequest_WithHeaders checks that the WebSocket
// handshake headers set on the incoming request appear as "Name: value" lines
// in the generated upgrade request.
func TestBuildWebSocketUpgradeRequest_WithHeaders(t *testing.T) {
	handshake := []struct{ key, value string }{
		{"Upgrade", "websocket"},
		{"Connection", "Upgrade"},
		{"Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="},
		{"Sec-WebSocket-Version", "13"},
		{"Sec-WebSocket-Protocol", "chat"},
	}

	reqCtx := &fasthttp.RequestCtx{}
	reqCtx.Request.SetRequestURI("/ws")
	for _, h := range handshake {
		reqCtx.Request.Header.Set(h.key, h.value)
	}

	out := buildWebSocketUpgradeRequest(reqCtx, "backend.example.com", nil)

	// Every handshake header must have been copied into the output.
	for _, h := range handshake {
		expected := h.key + ": " + h.value
		if strings.Contains(out, expected) {
			continue
		}
		t.Errorf("Missing expected header %q in:\n%s", expected, out)
	}
}
// TestBuildWebSocketUpgradeRequest_TLSProto checks the X-Forwarded-Proto
// header emitted by buildWebSocketUpgradeRequest. Only the non-TLS case is
// covered.
func TestBuildWebSocketUpgradeRequest_TLSProto(t *testing.T) {
	cases := []struct {
		name      string
		wantProto string
		isTLS     bool
	}{
		{name: "non-TLS connection", isTLS: false, wantProto: "X-Forwarded-Proto: http"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			reqCtx := &fasthttp.RequestCtx{}
			reqCtx.Request.SetRequestURI("/ws")

			// NOTE: a zero-value fasthttp.RequestCtx is expected to report
			// IsTLS() == false, and this test has no way to simulate a TLS
			// connection, so the isTLS field is not applied to the context.
			out := buildWebSocketUpgradeRequest(reqCtx, "backend.example.com", nil)

			if strings.Contains(out, tc.wantProto) {
				return
			}
			t.Errorf("Missing %q in:\n%s", tc.wantProto, out)
		})
	}
}
// TestExtractHost is a table-driven check of extractHost: an explicit port is
// kept as-is, and a missing port is filled in as 80 for http and 443 for
// https.
func TestExtractHost(t *testing.T) {
	cases := []struct{ name, url, expected string }{
		{"http with port", "http://example.com:8080", "example.com:8080"},
		{"https with port", "https://example.com:8443", "example.com:8443"},
		{"http without port", "http://example.com", "example.com:80"},
		{"https without port", "https://example.com", "example.com:443"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := extractHost(tc.url); got != tc.expected {
				t.Errorf("extractHost(%q) = %q, want %q", tc.url, got, tc.expected)
			}
		})
	}
}
// TestWriteUpgradeResponse checks that writeUpgradeResponse serialises a
// 101 Switching Protocols response, including its Upgrade and Connection
// headers, onto the connection and returns nil.
func TestWriteUpgradeResponse(t *testing.T) {
	// conn1 is written by the code under test; conn2 is read by the test.
	conn1, conn2 := net.Pipe()
	defer func() { _ = conn2.Close() }()

	// The response to serialise.
	resp := &http.Response{
		ProtoMajor: 1,
		ProtoMinor: 1,
		Status:     "101 Switching Protocols",
		StatusCode: 101,
		Header: http.Header{
			"Upgrade":    []string{"websocket"},
			"Connection": []string{"Upgrade"},
		},
	}

	// Write in the background; closing conn1 afterwards signals EOF to the
	// reader below.
	errCh := make(chan error, 1)
	go func() {
		errCh <- writeUpgradeResponse(conn1, resp)
		_ = conn1.Close()
	}()

	// Read until EOF instead of issuing a single Read: a Read on net.Pipe
	// returns data from only one Write, so a single Read would silently
	// depend on the writer emitting everything in one call. The deadline
	// keeps a stuck writer from hanging the test.
	_ = conn2.SetReadDeadline(time.Now().Add(time.Second))
	raw, err := io.ReadAll(conn2)
	if err != nil {
		t.Fatalf("Failed to read response: %v", err)
	}
	response := string(raw)

	// Verify the response format.
	expectedParts := []string{
		"HTTP/1.1 101 Switching Protocols",
		"Upgrade: websocket",
		"Connection: Upgrade",
	}
	for _, expected := range expectedParts {
		if !strings.Contains(response, expected) {
			t.Errorf("Missing %q in response:\n%s", expected, response)
		}
	}

	// Collect the writer's result.
	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("writeUpgradeResponse returned error: %v", err)
		}
	case <-time.After(1 * time.Second):
		t.Error("writeUpgradeResponse did not complete in time")
	}
}
// TestReadWebSocketUpgradeResponse checks that readWebSocketUpgradeResponse
// parses a 101 response written by the peer and exposes its status code and
// Upgrade header.
func TestReadWebSocketUpgradeResponse(t *testing.T) {
	readSide, writeSide := net.Pipe()
	defer func() { _ = readSide.Close() }()

	const response = "HTTP/1.1 101 Switching Protocols\r\n" +
		"Upgrade: websocket\r\n" +
		"Connection: Upgrade\r\n" +
		"\r\n"

	// The peer sends the response and hangs up; net.Pipe writes block until
	// read, so this has to happen on another goroutine.
	go func() {
		_, _ = writeSide.Write([]byte(response))
		_ = writeSide.Close()
	}()

	resp, _, err := readWebSocketUpgradeResponse(readSide, 1*time.Second)
	if err != nil {
		t.Fatalf("readWebSocketUpgradeResponse failed: %v", err)
	}

	if got := resp.StatusCode; got != 101 {
		t.Errorf("Expected status 101, got %d", got)
	}
	if got := resp.Header.Get("Upgrade"); got != "websocket" {
		t.Errorf("Expected Upgrade: websocket, got %q", got)
	}
}
// TestReadWebSocketUpgradeResponse_Timeout checks that
// readWebSocketUpgradeResponse returns an error when the peer never sends
// anything within the given timeout.
func TestReadWebSocketUpgradeResponse_Timeout(t *testing.T) {
	// Nothing is ever written to the peer end.
	silent, peer := net.Pipe()
	defer func() { _ = silent.Close() }()
	defer func() { _ = peer.Close() }()

	const shortTimeout = 10 * time.Millisecond

	if _, _, err := readWebSocketUpgradeResponse(silent, shortTimeout); err == nil {
		t.Error("Expected timeout error, got nil")
	}
}
// TestDialTarget_TLS checks that dialTarget returns an error for an https
// target on 127.0.0.1 port 1. It exercises the same input as
// TestDialTarget_HTTPS.
func TestDialTarget_TLS(t *testing.T) {
	const target = "https://127.0.0.1:1"

	if _, err := dialTarget(target, 100*time.Millisecond); err == nil {
		t.Error("Expected error for invalid HTTPS address")
	}
}
// TestIsConnectionClosedError_ClosedConn checks that the error returned by
// reading from a locally closed TCP connection is classified as a
// closed-connection error.
func TestIsConnectionClosedError_ClosedConn(t *testing.T) {
	// Fail with a clear message if the loopback listener or the dial cannot be
	// set up, rather than dereferencing a nil listener or connection.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer func() { _ = ln.Close() }()

	conn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	// Close immediately so the read below hits a closed connection.
	_ = conn.Close()

	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Error("Expected error reading from closed connection")
	}

	// The resulting error must be recognised as a connection-closed error.
	if !isConnectionClosedError(err) {
		t.Errorf("Expected closed connection error, got: %v", err)
	}
}
// TestWebSocketBridge_LargeMessage checks that a 100 KiB payload written on
// the client peer arrives intact on the target peer, and that Bridge returns
// nil after both peers are closed.
func TestWebSocketBridge_LargeMessage(t *testing.T) {
	// client1/target1 go to the bridge; client2/target2 are driven by the test.
	client1, client2 := net.Pipe()
	target1, target2 := net.Pipe()
	defer func() { _ = client2.Close() }()
	defer func() { _ = target2.Close() }()

	// Bound the test-side I/O so a stalled bridge fails the test instead of
	// hanging it.
	deadline := time.Now().Add(2 * time.Second)
	_ = client2.SetDeadline(deadline)
	_ = target2.SetDeadline(deadline)

	bridge := NewWebSocketBridge(client1, target1)

	// Run the bridge in the background and capture its result.
	errCh := make(chan error, 1)
	go func() {
		errCh <- bridge.Bridge()
	}()

	// More than 64 KiB of data with a repeating 0..255 byte pattern.
	largeData := make([]byte, 100*1024) // 100KB
	for i := range largeData {
		largeData[i] = byte(i % 256)
	}

	// The client sends the payload; a failed write surfaces as a read failure
	// on the target side.
	go func() {
		_, _ = client2.Write(largeData)
	}()

	// The payload arrives in several chunks; io.ReadFull keeps reading until
	// exactly len(largeData) bytes have been received.
	received := make([]byte, len(largeData))
	if _, err := io.ReadFull(target2, received); err != nil {
		t.Fatalf("Failed to read large message: %v", err)
	}

	// Verify integrity, reporting only the first mismatch.
	for i := range largeData {
		if received[i] != largeData[i] {
			t.Errorf("Data mismatch at byte %d: got %d, want %d", i, received[i], largeData[i])
			break
		}
	}

	// Closing both peers ends the bridge.
	_ = client2.Close()
	_ = target2.Close()

	select {
	case err := <-errCh:
		if err != nil {
			t.Errorf("Bridge returned error: %v", err)
		}
	case <-time.After(2 * time.Second):
		t.Error("Bridge did not complete in time")
	}
}
// TestWebSocketBridge_Concurrent 测试并发桥接
func TestWebSocketBridge_Concurrent(t *testing.T) {
const numBridges = 10
var wg sync.WaitGroup
errCh := make(chan error, numBridges)
for i := range numBridges {
wg.Add(1)
go func(id int) {
defer wg.Done()
// 创建管道连接
client1, client2 := net.Pipe()
target1, target2 := net.Pipe()
defer func() { _ = client2.Close() }()
defer func() { _ = target2.Close() }()
bridge := NewWebSocketBridge(client1, target1)
// 启动桥接
done := make(chan error, 1)
go func() {
done <- bridge.Bridge()
}()
// 发送测试数据
testData := []byte("concurrent test data")
go func() {
_, _ = client2.Write(testData)
}()
// 接收数据
buf := make([]byte, 1024)
n, err := target2.Read(buf)
if err != nil {
errCh <- fmt.Errorf("bridge %d: read error: %v", id, err)
return
}
if string(buf[:n]) != string(testData) {
errCh <- fmt.Errorf("bridge %d: data mismatch", id)
return
}
// 关闭连接
_ = client2.Close()
_ = target2.Close()
// 等待桥接完成
select {
case err := <-done:
if err != nil {
errCh <- fmt.Errorf("bridge %d: %v", id, err)
}
case <-time.After(1 * time.Second):
errCh <- fmt.Errorf("bridge %d: timeout", id)
}
}(i)
}
wg.Wait()
close(errCh)
// 检查错误
for err := range errCh {
if err != nil {
t.Errorf("Concurrent bridge error: %v", err)
}
}
}
// TestCopyData_WriteError checks how copyData behaves when its destination is
// already closed: it must return promptly, either with nil or with an error
// whose text mentions "closed".
func TestCopyData_WriteError(t *testing.T) {
	src1, src2 := net.Pipe()
	dst1, dst2 := net.Pipe()
	// Make sure the source end handed to copyData is released on every path.
	defer func() { _ = src1.Close() }()

	bridge := &WebSocketBridge{}

	// Close the destination up front so any write to it fails.
	_ = dst1.Close()
	_ = dst2.Close()

	// Run the copy in the background and capture its result.
	errCh := make(chan error, 1)
	go func() {
		errCh <- bridge.copyData(dst1, src1, "test")
	}()

	// net.Pipe writes block until the other side reads. The deadline keeps the
	// test from hanging forever if copyData never reads from its source.
	_ = src2.SetWriteDeadline(time.Now().Add(time.Second))
	// Send data; forwarding it should trigger the write error in copyData.
	_, _ = src2.Write([]byte("test data"))
	_ = src2.Close()

	select {
	case err := <-errCh:
		// Writing to a closed connection is treated as a connection-closed
		// condition, so nil or a "closed" error are both acceptable.
		if err != nil && !strings.Contains(err.Error(), "closed") {
			t.Errorf("copyData returned unexpected error: %v", err)
		}
	case <-time.After(1 * time.Second):
		t.Error("copyData did not complete in time")
	}
}