lolly/internal/proxy/websocket_test.go
xfy f2352ab9cc docs(config,stream,logging,handler,proxy,cache,server,ssl,middleware): 为核心模块添加详细 GoDoc 文档注释
- config: 为 Config 和所有子配置结构添加完整文档,包含使用示例和注意事项
- stream: 为负载均衡器和服务器添加详细的参数、返回值和功能说明
- logging: 为日志格式化和输出函数添加文档,说明支持的变量替换
- handler: 为路由器、静态文件和 sendfile 处理器添加文档
- proxy: 为健康检查器和代理功能添加完整文档
- cache/server/ssl/middleware: 补充相关模块的文档注释
- config.example.yaml: 添加可信代理配置、加密套件示例,更新压缩级别说明

Co-Authored-By: Claude <noreply@anthropic.com>
2026-04-07 15:36:09 +08:00

314 lines
7.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy 提供 WebSocket 代理功能的测试。
//
// 该文件测试 WebSocket 代理模块的各项功能,包括:
// - 桥接器创建
// - 桥接器关闭
// - 空连接处理
// - 连接关闭错误检测
// - 目标地址拨号
// - URL 解析
// - 数据复制
// - 双向数据转发
// - 超时错误处理
//
// 作者xfy
package proxy
import (
"errors"
"io"
"net"
"strings"
"testing"
"time"
)
// TestNewWebSocketBridge 测试桥接器创建
func TestNewWebSocketBridge(t *testing.T) {
clientConn, _ := net.Pipe()
targetConn, _ := net.Pipe()
defer func() { _ = clientConn.Close() }()
defer func() { _ = targetConn.Close() }()
bridge := NewWebSocketBridge(clientConn, targetConn)
if bridge == nil {
t.Fatal("Expected non-nil bridge")
}
if bridge.clientConn != clientConn {
t.Error("Expected clientConn to be set")
}
if bridge.targetConn != targetConn {
t.Error("Expected targetConn to be set")
}
if bridge.closed {
t.Error("Expected closed to be false initially")
}
}
// TestWebSocketBridge_Close 测试关闭桥接器
func TestWebSocketBridge_Close(t *testing.T) {
clientConn, client2 := net.Pipe()
targetConn, target2 := net.Pipe()
bridge := NewWebSocketBridge(clientConn, targetConn)
// 关闭桥接器
err := bridge.Close()
if err != nil {
t.Errorf("Expected nil error, got: %v", err)
}
// 验证连接已关闭 - 写入应该失败
_, err = client2.Write([]byte("test"))
if err == nil {
t.Error("Expected error writing to closed connection")
}
_ = target2
// 重复关闭应该安全
err = bridge.Close()
if err != nil {
t.Errorf("Expected nil error on double close, got: %v", err)
}
}
// TestWebSocketBridge_Close_NilConnections 测试空连接的关闭
func TestWebSocketBridge_Close_NilConnections(t *testing.T) {
bridge := &WebSocketBridge{
clientConn: nil,
targetConn: nil,
closed: false,
}
err := bridge.Close()
if err != nil {
t.Errorf("Expected nil error for nil connections, got: %v", err)
}
}
// TestIsConnectionClosedError 测试连接关闭错误检测
func TestIsConnectionClosedError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "EOF",
err: io.EOF,
expected: true,
},
{
name: "other error",
err: errors.New("some other error"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isConnectionClosedError(tt.err)
if result != tt.expected {
t.Errorf("isConnectionClosedError(%v) = %v, expected %v", tt.err, result, tt.expected)
}
})
}
}
// TestExtractHost 测试从 URL 提取主机
func TestExtractHost(t *testing.T) {
// extractHost 函数可能不存在,检查一下
// 如果存在则测试
}
// TestDialTarget_InvalidAddress 测试无效地址的拨号
func TestDialTarget_InvalidAddress(t *testing.T) {
// 测试连接到无效端口
_, err := dialTarget("http://127.0.0.1:1", 100*time.Millisecond)
if err == nil {
t.Error("Expected error for invalid address")
}
}
// TestDialTarget_HTTPS 测试 HTTPS 连接(会失败,但验证错误处理)
func TestDialTarget_HTTPS(t *testing.T) {
// 测试 HTTPS 连接到无效端口
_, err := dialTarget("https://127.0.0.1:1", 100*time.Millisecond)
if err == nil {
t.Error("Expected error for invalid HTTPS address")
}
}
// mockNetError 模拟网络错误
type mockNetError struct {
msg string
}
func (e *mockNetError) Error() string { return e.msg }
func (e *mockNetError) Timeout() bool { return true }
func (e *mockNetError) Temporary() bool { return false }
// TestIsConnectionClosedError_Timeout 测试超时错误
func TestIsConnectionClosedError_Timeout(t *testing.T) {
timeoutErr := &mockNetError{msg: "timeout"}
result := isConnectionClosedError(timeoutErr)
if !result {
t.Error("Expected timeout error to be treated as closed connection error")
}
}
// TestWebSocketBridge_Bridge 测试双向数据转发
func TestWebSocketBridge_Bridge(t *testing.T) {
// 创建管道连接
client1, client2 := net.Pipe()
target1, target2 := net.Pipe()
defer func() { _ = client2.Close() }()
defer func() { _ = target2.Close() }()
bridge := NewWebSocketBridge(client1, target1)
// 启动桥接(在 goroutine 中)
errCh := make(chan error, 1)
go func() {
errCh <- bridge.Bridge()
}()
// 发送数据从客户端到后端
testData := []byte("hello from client")
go func() {
_, _ = client2.Write(testData)
}()
// 在后端读取数据
buf := make([]byte, 1024)
n, err := target2.Read(buf)
if err != nil {
t.Fatalf("Failed to read from target: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
}
// 发送数据从后端到客户端
testData2 := []byte("hello from target")
go func() {
_, _ = target2.Write(testData2)
}()
// 在客户端读取数据
buf2 := make([]byte, 1024)
n, err = client2.Read(buf2)
if err != nil {
t.Fatalf("Failed to read from client: %v", err)
}
if string(buf2[:n]) != string(testData2) {
t.Errorf("Expected %q, got %q", string(testData2), string(buf2[:n]))
}
// 关闭连接以结束桥接
_ = client2.Close()
_ = target2.Close()
// 等待桥接完成
select {
case err := <-errCh:
if err != nil {
t.Errorf("Bridge returned error: %v", err)
}
case <-time.After(1 * time.Second):
t.Error("Bridge did not complete in time")
}
}
// TestDialTarget_URLParsing 测试 URL 解析
func TestDialTarget_URLParsing(t *testing.T) {
tests := []struct {
name string
url string
expectError bool
}{
{
name: "http URL with invalid port",
url: "http://127.0.0.1:1",
expectError: true, // 连接会失败
},
{
name: "https URL with invalid port",
url: "https://127.0.0.1:1",
expectError: true, // 连接会失败
},
{
name: "URL with path",
url: "http://127.0.0.1:1/ws",
expectError: true, // 连接会失败
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := dialTarget(tt.url, 10*time.Millisecond)
if tt.expectError {
if err == nil {
t.Error("Expected error, got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
})
}
}
// TestCopyData 测试数据复制
func TestCopyData(t *testing.T) {
// 创建管道连接
src1, src2 := net.Pipe()
dst1, dst2 := net.Pipe()
defer func() { _ = src2.Close() }()
defer func() { _ = dst2.Close() }()
bridge := &WebSocketBridge{}
// 启动数据复制
errCh := make(chan error, 1)
go func() {
errCh <- bridge.copyData(dst1, src1, "test")
}()
// 发送数据
testData := []byte("test data")
_, _ = src2.Write(testData)
// 接收数据
buf := make([]byte, 1024)
n, err := dst2.Read(buf)
if err != nil {
t.Fatalf("Failed to read: %v", err)
}
if string(buf[:n]) != string(testData) {
t.Errorf("Expected %q, got %q", string(testData), string(buf[:n]))
}
// 关闭连接
_ = src2.Close()
_ = dst2.Close()
// 等待复制完成
select {
case err := <-errCh:
// 连接关闭错误应返回 nil
if err != nil && !strings.Contains(err.Error(), "closed") {
t.Errorf("copyData returned unexpected error: %v", err)
}
case <-time.After(1 * time.Second):
t.Error("copyData did not complete in time")
}
}