使用空白标识符忽略测试辅助函数中 Close、ReadFrom、Set 等返回值, 与主代码风格保持一致。 Co-Authored-By: Claude <noreply@anthropic.com>
301 lines
7.0 KiB
Go
301 lines
7.0 KiB
Go
package proxy
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestNewWebSocketBridge 测试桥接器创建
|
|
func TestNewWebSocketBridge(t *testing.T) {
|
|
clientConn, _ := net.Pipe()
|
|
targetConn, _ := net.Pipe()
|
|
defer func() { _ = clientConn.Close() }()
|
|
defer func() { _ = targetConn.Close() }()
|
|
|
|
bridge := NewWebSocketBridge(clientConn, targetConn)
|
|
|
|
if bridge == nil {
|
|
t.Error("Expected non-nil bridge")
|
|
}
|
|
if bridge.clientConn != clientConn {
|
|
t.Error("Expected clientConn to be set")
|
|
}
|
|
if bridge.targetConn != targetConn {
|
|
t.Error("Expected targetConn to be set")
|
|
}
|
|
if bridge.closed != false {
|
|
t.Error("Expected closed to be false initially")
|
|
}
|
|
}
|
|
|
|
// TestWebSocketBridge_Close 测试关闭桥接器
|
|
func TestWebSocketBridge_Close(t *testing.T) {
|
|
clientConn, client2 := net.Pipe()
|
|
targetConn, target2 := net.Pipe()
|
|
|
|
bridge := NewWebSocketBridge(clientConn, targetConn)
|
|
|
|
// 关闭桥接器
|
|
err := bridge.Close()
|
|
if err != nil {
|
|
t.Errorf("Expected nil error, got: %v", err)
|
|
}
|
|
|
|
// 验证连接已关闭 - 写入应该失败
|
|
_, err = client2.Write([]byte("test"))
|
|
if err == nil {
|
|
t.Error("Expected error writing to closed connection")
|
|
}
|
|
|
|
_ = target2
|
|
|
|
// 重复关闭应该安全
|
|
err = bridge.Close()
|
|
if err != nil {
|
|
t.Errorf("Expected nil error on double close, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestWebSocketBridge_Close_NilConnections 测试空连接的关闭
|
|
func TestWebSocketBridge_Close_NilConnections(t *testing.T) {
|
|
bridge := &WebSocketBridge{
|
|
clientConn: nil,
|
|
targetConn: nil,
|
|
closed: false,
|
|
}
|
|
|
|
err := bridge.Close()
|
|
if err != nil {
|
|
t.Errorf("Expected nil error for nil connections, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestIsConnectionClosedError 测试连接关闭错误检测
|
|
func TestIsConnectionClosedError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "nil error",
|
|
err: nil,
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "EOF",
|
|
err: io.EOF,
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "other error",
|
|
err: errors.New("some other error"),
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := isConnectionClosedError(tt.err)
|
|
if result != tt.expected {
|
|
t.Errorf("isConnectionClosedError(%v) = %v, expected %v", tt.err, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestExtractHost is a placeholder for testing host extraction from a URL.
// No extractHost helper is exercised yet. The test skips explicitly so the gap
// is reported instead of showing up as a vacuous pass.
func TestExtractHost(t *testing.T) {
	t.Skip("extractHost is not covered yet; placeholder test")
}
|
|
|
|
// TestDialTarget_InvalidAddress 测试无效地址的拨号
|
|
func TestDialTarget_InvalidAddress(t *testing.T) {
|
|
// 测试连接到无效端口
|
|
_, err := dialTarget("http://127.0.0.1:1", 100*time.Millisecond)
|
|
if err == nil {
|
|
t.Error("Expected error for invalid address")
|
|
}
|
|
}
|
|
|
|
// TestDialTarget_HTTPS 测试 HTTPS 连接(会失败,但验证错误处理)
|
|
func TestDialTarget_HTTPS(t *testing.T) {
|
|
// 测试 HTTPS 连接到无效端口
|
|
_, err := dialTarget("https://127.0.0.1:1", 100*time.Millisecond)
|
|
if err == nil {
|
|
t.Error("Expected error for invalid HTTPS address")
|
|
}
|
|
}
|
|
|
|
// mockNetError is a test double for a network error. Timeout always reports
// true and Temporary always reports false; Error returns msg.
type mockNetError struct {
	msg string
}

// Compile-time check that mockNetError satisfies net.Error, so tests using it
// exercise the same interface that real network errors implement.
var _ net.Error = (*mockNetError)(nil)

func (e *mockNetError) Error() string   { return e.msg }
func (e *mockNetError) Timeout() bool   { return true }
func (e *mockNetError) Temporary() bool { return false }
|
|
|
|
// TestIsConnectionClosedError_Timeout 测试超时错误
|
|
func TestIsConnectionClosedError_Timeout(t *testing.T) {
|
|
timeoutErr := &mockNetError{msg: "timeout"}
|
|
result := isConnectionClosedError(timeoutErr)
|
|
if !result {
|
|
t.Error("Expected timeout error to be treated as closed connection error")
|
|
}
|
|
}
|
|
|
|
// TestWebSocketBridge_Bridge 测试双向数据转发
|
|
func TestWebSocketBridge_Bridge(t *testing.T) {
|
|
// 创建管道连接
|
|
client1, client2 := net.Pipe()
|
|
target1, target2 := net.Pipe()
|
|
defer func() { _ = client2.Close() }()
|
|
defer func() { _ = target2.Close() }()
|
|
|
|
bridge := NewWebSocketBridge(client1, target1)
|
|
|
|
// 启动桥接(在 goroutine 中)
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- bridge.Bridge()
|
|
}()
|
|
|
|
// 发送数据从客户端到后端
|
|
testData := []byte("hello from client")
|
|
go func() {
|
|
_, _ = client2.Write(testData)
|
|
}()
|
|
|
|
// 在后端读取数据
|
|
buf := make([]byte, 1024)
|
|
n, err := target2.Read(buf)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read from target: %v", err)
|
|
}
|
|
if string(buf[:n]) != string(testData) {
|
|
t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
|
|
}
|
|
|
|
// 发送数据从后端到客户端
|
|
testData2 := []byte("hello from target")
|
|
go func() {
|
|
_, _ = target2.Write(testData2)
|
|
}()
|
|
|
|
// 在客户端读取数据
|
|
buf2 := make([]byte, 1024)
|
|
n, err = client2.Read(buf2)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read from client: %v", err)
|
|
}
|
|
if string(buf2[:n]) != string(testData2) {
|
|
t.Errorf("Expected %q, got %q", string(testData2), string(buf2[:n]))
|
|
}
|
|
|
|
// 关闭连接以结束桥接
|
|
_ = client2.Close()
|
|
_ = target2.Close()
|
|
|
|
// 等待桥接完成
|
|
select {
|
|
case err := <-errCh:
|
|
if err != nil {
|
|
t.Errorf("Bridge returned error: %v", err)
|
|
}
|
|
case <-time.After(1 * time.Second):
|
|
t.Error("Bridge did not complete in time")
|
|
}
|
|
}
|
|
|
|
// TestDialTarget_URLParsing 测试 URL 解析
|
|
func TestDialTarget_URLParsing(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
url string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "http URL with invalid port",
|
|
url: "http://127.0.0.1:1",
|
|
expectError: true, // 连接会失败
|
|
},
|
|
{
|
|
name: "https URL with invalid port",
|
|
url: "https://127.0.0.1:1",
|
|
expectError: true, // 连接会失败
|
|
},
|
|
{
|
|
name: "URL with path",
|
|
url: "http://127.0.0.1:1/ws",
|
|
expectError: true, // 连接会失败
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := dialTarget(tt.url, 10*time.Millisecond)
|
|
if tt.expectError {
|
|
if err == nil {
|
|
t.Error("Expected error, got nil")
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestCopyData 测试数据复制
|
|
func TestCopyData(t *testing.T) {
|
|
// 创建管道连接
|
|
src1, src2 := net.Pipe()
|
|
dst1, dst2 := net.Pipe()
|
|
defer func() { _ = src2.Close() }()
|
|
defer func() { _ = dst2.Close() }()
|
|
|
|
bridge := &WebSocketBridge{}
|
|
|
|
// 启动数据复制
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- bridge.copyData(dst1, src1, "test")
|
|
}()
|
|
|
|
// 发送数据
|
|
testData := []byte("test data")
|
|
_, _ = src2.Write(testData)
|
|
|
|
// 接收数据
|
|
buf := make([]byte, 1024)
|
|
n, err := dst2.Read(buf)
|
|
if err != nil {
|
|
t.Fatalf("Failed to read: %v", err)
|
|
}
|
|
if string(buf[:n]) != string(testData) {
|
|
t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
|
|
}
|
|
|
|
// 关闭连接
|
|
_ = src2.Close()
|
|
_ = dst2.Close()
|
|
|
|
// 等待复制完成
|
|
select {
|
|
case err := <-errCh:
|
|
// 连接关闭错误应返回 nil
|
|
if err != nil && !strings.Contains(err.Error(), "closed") {
|
|
t.Errorf("copyData returned unexpected error: %v", err)
|
|
}
|
|
case <-time.After(1 * time.Second):
|
|
t.Error("copyData did not complete in time")
|
|
}
|
|
}
|