提取硬编码字符串为命名常量: - upstreamCache = "CACHE" - protoHTTPS = "https" ProxyWebSocket → WebSocket 适配 variable.Context 重命名 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
314 lines
7.3 KiB
Go
314 lines
7.3 KiB
Go
// Package proxy 提供 WebSocket 代理功能的测试。
|
||
//
|
||
// 该文件测试 WebSocket 代理模块的各项功能,包括:
|
||
// - 桥接器创建
|
||
// - 桥接器关闭
|
||
// - 空连接处理
|
||
// - 连接关闭错误检测
|
||
// - 目标地址拨号
|
||
// - URL 解析
|
||
// - 数据复制
|
||
// - 双向数据转发
|
||
// - 超时错误处理
|
||
//
|
||
// 作者:xfy
|
||
package proxy
|
||
|
||
import (
|
||
"errors"
|
||
"io"
|
||
"net"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
// TestNewWebSocketBridge 测试桥接器创建
|
||
func TestNewWebSocketBridge(t *testing.T) {
|
||
clientConn, _ := net.Pipe()
|
||
targetConn, _ := net.Pipe()
|
||
defer func() { _ = clientConn.Close() }()
|
||
defer func() { _ = targetConn.Close() }()
|
||
|
||
bridge := NewWebSocketBridge(clientConn, targetConn)
|
||
if bridge == nil {
|
||
t.Fatal("Expected non-nil bridge")
|
||
}
|
||
if bridge.clientConn != clientConn {
|
||
t.Error("Expected clientConn to be set")
|
||
}
|
||
if bridge.targetConn != targetConn {
|
||
t.Error("Expected targetConn to be set")
|
||
}
|
||
if bridge.closed {
|
||
t.Error("Expected closed to be false initially")
|
||
}
|
||
}
|
||
|
||
// TestWebSocketBridge_Close 测试关闭桥接器
|
||
func TestWebSocketBridge_Close(t *testing.T) {
|
||
clientConn, client2 := net.Pipe()
|
||
targetConn, target2 := net.Pipe()
|
||
|
||
bridge := NewWebSocketBridge(clientConn, targetConn)
|
||
|
||
// 关闭桥接器
|
||
err := bridge.Close()
|
||
if err != nil {
|
||
t.Errorf("Expected nil error, got: %v", err)
|
||
}
|
||
|
||
// 验证连接已关闭 - 写入应该失败
|
||
_, err = client2.Write([]byte("test"))
|
||
if err == nil {
|
||
t.Error("Expected error writing to closed connection")
|
||
}
|
||
|
||
_ = target2
|
||
|
||
// 重复关闭应该安全
|
||
err = bridge.Close()
|
||
if err != nil {
|
||
t.Errorf("Expected nil error on double close, got: %v", err)
|
||
}
|
||
}
|
||
|
||
// TestWebSocketBridge_Close_NilConnections 测试空连接的关闭
|
||
func TestWebSocketBridge_Close_NilConnections(t *testing.T) {
|
||
bridge := &WebSocketBridge{
|
||
clientConn: nil,
|
||
targetConn: nil,
|
||
closed: false,
|
||
}
|
||
|
||
err := bridge.Close()
|
||
if err != nil {
|
||
t.Errorf("Expected nil error for nil connections, got: %v", err)
|
||
}
|
||
}
|
||
|
||
// TestIsConnectionClosedError 测试连接关闭错误检测
|
||
func TestIsConnectionClosedError(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
err error
|
||
expected bool
|
||
}{
|
||
{
|
||
name: "nil error",
|
||
err: nil,
|
||
expected: false,
|
||
},
|
||
{
|
||
name: "EOF",
|
||
err: io.EOF,
|
||
expected: true,
|
||
},
|
||
{
|
||
name: "other error",
|
||
err: errors.New("some other error"),
|
||
expected: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result := isConnectionClosedError(tt.err)
|
||
if result != tt.expected {
|
||
t.Errorf("isConnectionClosedError(%v) = %v, expected %v", tt.err, result, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestExtractHost is a placeholder for testing host extraction from a URL.
//
// No extractHost helper is exercised yet. The test skips explicitly rather
// than having an empty body, so `go test -v` reports SKIP instead of a
// misleading PASS until real cases are written.
func TestExtractHost(t *testing.T) {
	t.Skip("extractHost is not exercised yet; add cases once the helper exists")
}
|
||
|
||
// TestDialTarget_InvalidAddress 测试无效地址的拨号
|
||
func TestDialTarget_InvalidAddress(t *testing.T) {
|
||
// 测试连接到无效端口
|
||
_, err := dialTarget("http://127.0.0.1:1", 100*time.Millisecond)
|
||
if err == nil {
|
||
t.Error("Expected error for invalid address")
|
||
}
|
||
}
|
||
|
||
// TestDialTarget_HTTPS 测试 HTTPS 连接(会失败,但验证错误处理)
|
||
func TestDialTarget_HTTPS(t *testing.T) {
|
||
// 测试 HTTPS 连接到无效端口
|
||
_, err := dialTarget("https://127.0.0.1:1", 100*time.Millisecond)
|
||
if err == nil {
|
||
t.Error("Expected error for invalid HTTPS address")
|
||
}
|
||
}
|
||
|
||
// mockNetError is a minimal net.Error implementation used to simulate a
// network timeout: Timeout always reports true and Temporary always false.
type mockNetError struct {
	msg string // text returned by Error
}

// Compile-time check that *mockNetError satisfies net.Error, so the mock
// cannot silently drift from the interface it stands in for.
var _ net.Error = (*mockNetError)(nil)

// Error returns the configured message.
func (e *mockNetError) Error() string { return e.msg }

// Timeout always reports true, marking the error as a timeout.
func (e *mockNetError) Timeout() bool { return true }

// Temporary always reports false. net.Error.Temporary is deprecated but is
// still part of the interface, so it must be implemented.
func (e *mockNetError) Temporary() bool { return false }
|
||
|
||
// TestIsConnectionClosedError_Timeout 测试超时错误
|
||
func TestIsConnectionClosedError_Timeout(t *testing.T) {
|
||
timeoutErr := &mockNetError{msg: "timeout"}
|
||
result := isConnectionClosedError(timeoutErr)
|
||
if !result {
|
||
t.Error("Expected timeout error to be treated as closed connection error")
|
||
}
|
||
}
|
||
|
||
// TestWebSocketBridge_Bridge 测试双向数据转发
|
||
func TestWebSocketBridge_Bridge(t *testing.T) {
|
||
// 创建管道连接
|
||
client1, client2 := net.Pipe()
|
||
target1, target2 := net.Pipe()
|
||
defer func() { _ = client2.Close() }()
|
||
defer func() { _ = target2.Close() }()
|
||
|
||
bridge := NewWebSocketBridge(client1, target1)
|
||
|
||
// 启动桥接(在 goroutine 中)
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- bridge.Bridge()
|
||
}()
|
||
|
||
// 发送数据从客户端到后端
|
||
testData := []byte("hello from client")
|
||
go func() {
|
||
_, _ = client2.Write(testData)
|
||
}()
|
||
|
||
// 在后端读取数据
|
||
buf := make([]byte, 1024)
|
||
n, err := target2.Read(buf)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read from target: %v", err)
|
||
}
|
||
if string(buf[:n]) != string(testData) {
|
||
t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
|
||
}
|
||
|
||
// 发送数据从后端到客户端
|
||
testData2 := []byte("hello from target")
|
||
go func() {
|
||
_, _ = target2.Write(testData2)
|
||
}()
|
||
|
||
// 在客户端读取数据
|
||
buf2 := make([]byte, 1024)
|
||
n, err = client2.Read(buf2)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read from client: %v", err)
|
||
}
|
||
if string(buf2[:n]) != string(testData2) {
|
||
t.Errorf("Expected %q, got %q", string(testData2), string(buf2[:n]))
|
||
}
|
||
|
||
// 关闭连接以结束桥接
|
||
_ = client2.Close()
|
||
_ = target2.Close()
|
||
|
||
// 等待桥接完成
|
||
select {
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
t.Errorf("Bridge returned error: %v", err)
|
||
}
|
||
case <-time.After(1 * time.Second):
|
||
t.Error("Bridge did not complete in time")
|
||
}
|
||
}
|
||
|
||
// TestDialTarget_URLParsing 测试 URL 解析
|
||
func TestDialTarget_URLParsing(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
url string
|
||
expectError bool
|
||
}{
|
||
{
|
||
name: "http URL with invalid port",
|
||
url: "http://127.0.0.1:1",
|
||
expectError: true, // 连接会失败
|
||
},
|
||
{
|
||
name: "https URL with invalid port",
|
||
url: "https://127.0.0.1:1",
|
||
expectError: true, // 连接会失败
|
||
},
|
||
{
|
||
name: "URL with path",
|
||
url: "http://127.0.0.1:1/ws",
|
||
expectError: true, // 连接会失败
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
_, err := dialTarget(tt.url, 10*time.Millisecond)
|
||
if tt.expectError {
|
||
if err == nil {
|
||
t.Error("Expected error, got nil")
|
||
}
|
||
} else {
|
||
if err != nil {
|
||
t.Errorf("Unexpected error: %v", err)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestCopyData 测试数据复制
|
||
func TestCopyData(t *testing.T) {
|
||
// 创建管道连接
|
||
src1, src2 := net.Pipe()
|
||
dst1, dst2 := net.Pipe()
|
||
defer func() { _ = src2.Close() }()
|
||
defer func() { _ = dst2.Close() }()
|
||
|
||
bridge := &WebSocketBridge{}
|
||
|
||
// 启动数据复制
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- bridge.copyData(dst1, src1, "test")
|
||
}()
|
||
|
||
// 发送数据
|
||
testData := []byte("test data")
|
||
_, _ = src2.Write(testData)
|
||
|
||
// 接收数据
|
||
buf := make([]byte, 1024)
|
||
n, err := dst2.Read(buf)
|
||
if err != nil {
|
||
t.Fatalf("Failed to read: %v", err)
|
||
}
|
||
if string(buf[:n]) != string(testData) {
|
||
t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
|
||
}
|
||
|
||
// 关闭连接
|
||
_ = src2.Close()
|
||
_ = dst2.Close()
|
||
|
||
// 等待复制完成
|
||
select {
|
||
case err := <-errCh:
|
||
// 连接关闭错误应返回 nil
|
||
if err != nil && !strings.Contains(err.Error(), "closed") {
|
||
t.Errorf("copyData returned unexpected error: %v", err)
|
||
}
|
||
case <-time.After(1 * time.Second):
|
||
t.Error("copyData did not complete in time")
|
||
}
|
||
}
|