- Add GoDoc for Warning.String, ParseError.Error - Add GoDoc for ngxReqAPILayer.String, Phase.String, SocketState.String - Add GoDoc for ConflictError.Error - Add GoDoc for noopResolver methods (LookupHost, LookupHostWithCache, Refresh, Start, Stop, Stats) - Add GoDoc for load balancer Select methods (roundRobin, weightedRoundRobin, ipHash) - Add GoDoc for WithWSHeaders test utility - Include author attribution (xfy)
413 lines
8.1 KiB
Go
413 lines
8.1 KiB
Go
//go:build e2e
|
||
|
||
// Package testutil 提供 E2E 测试的工具函数。
|
||
//
|
||
// 包含 WebSocket 测试辅助工具。
|
||
//
|
||
// 作者:xfy
|
||
package testutil
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gorilla/websocket"
|
||
)
|
||
|
||
// WSClient WebSocket 测试客户端。
|
||
//
|
||
// 封装 gorilla/websocket,提供简单的测试接口。
|
||
type WSClient struct {
|
||
conn *websocket.Conn
|
||
url string
|
||
mu sync.Mutex
|
||
closed bool
|
||
closeChan chan struct{}
|
||
}
|
||
|
||
// WSOption WebSocket 客户端选项。
|
||
type WSOption func(*wsConfig)
|
||
|
||
type wsConfig struct {
|
||
headers http.Header
|
||
pingPeriod time.Duration
|
||
pongWait time.Duration
|
||
}
|
||
|
||
// WithHeaders 设置请求头。
|
||
// WithWSHeaders 设置 WebSocket 请求头。
|
||
func WithWSHeaders(headers http.Header) WSOption {
|
||
return func(c *wsConfig) {
|
||
c.headers = headers
|
||
}
|
||
}
|
||
|
||
// WithWSTimeout 设置超时时间。
|
||
func WithWSTimeout(pongWait, pingPeriod time.Duration) WSOption {
|
||
return func(c *wsConfig) {
|
||
c.pongWait = pongWait
|
||
c.pingPeriod = pingPeriod
|
||
}
|
||
}
|
||
|
||
// NewWSClient 创建 WebSocket 客户端。
|
||
//
|
||
// 参数:
|
||
// - ctx: 上下文
|
||
// - url: WebSocket URL(ws:// 或 wss://)
|
||
// - opts: 可选配置
|
||
//
|
||
// 返回 WebSocket 客户端实例。
|
||
func NewWSClient(ctx context.Context, url string, opts ...WSOption) (*WSClient, error) {
|
||
cfg := &wsConfig{
|
||
headers: http.Header{},
|
||
pongWait: 60 * time.Second,
|
||
pingPeriod: 54 * time.Second,
|
||
}
|
||
|
||
for _, opt := range opts {
|
||
opt(cfg)
|
||
}
|
||
|
||
dialer := websocket.DefaultDialer
|
||
conn, _, err := dialer.DialContext(ctx, url, cfg.headers)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("websocket dial failed: %w", err)
|
||
}
|
||
|
||
client := &WSClient{
|
||
conn: conn,
|
||
url: url,
|
||
closeChan: make(chan struct{}),
|
||
}
|
||
|
||
// 设置 pong 处理
|
||
conn.SetPongHandler(func(string) error {
|
||
return conn.SetReadDeadline(time.Now().Add(cfg.pongWait))
|
||
})
|
||
|
||
return client, nil
|
||
}
|
||
|
||
// Send 发送文本消息。
|
||
func (c *WSClient) Send(message string) error {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return fmt.Errorf("connection closed")
|
||
}
|
||
|
||
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||
return c.conn.WriteMessage(websocket.TextMessage, []byte(message))
|
||
}
|
||
|
||
// SendBinary 发送二进制消息。
|
||
func (c *WSClient) SendBinary(data []byte) error {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return fmt.Errorf("connection closed")
|
||
}
|
||
|
||
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||
return c.conn.WriteMessage(websocket.BinaryMessage, data)
|
||
}
|
||
|
||
// SendJSON 发送 JSON 消息。
|
||
func (c *WSClient) SendJSON(v interface{}) error {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return fmt.Errorf("connection closed")
|
||
}
|
||
|
||
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||
return c.conn.WriteJSON(v)
|
||
}
|
||
|
||
// Receive 接收文本消息。
|
||
//
|
||
// 返回消息内容和错误。
|
||
func (c *WSClient) Receive() (string, error) {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return "", fmt.Errorf("connection closed")
|
||
}
|
||
|
||
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||
messageType, data, err := c.conn.ReadMessage()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if messageType != websocket.TextMessage {
|
||
return "", fmt.Errorf("expected text message, got %d", messageType)
|
||
}
|
||
|
||
return string(data), nil
|
||
}
|
||
|
||
// ReceiveBinary 接收二进制消息。
|
||
func (c *WSClient) ReceiveBinary() ([]byte, error) {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return nil, fmt.Errorf("connection closed")
|
||
}
|
||
|
||
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||
messageType, data, err := c.conn.ReadMessage()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if messageType != websocket.BinaryMessage {
|
||
return nil, fmt.Errorf("expected binary message, got %d", messageType)
|
||
}
|
||
|
||
return data, nil
|
||
}
|
||
|
||
// ReceiveJSON 接收 JSON 消息。
|
||
func (c *WSClient) ReceiveJSON(v interface{}) error {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return fmt.Errorf("connection closed")
|
||
}
|
||
|
||
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||
return c.conn.ReadJSON(v)
|
||
}
|
||
|
||
// ReceiveWithTimeout 接收消息(带超时)。
|
||
func (c *WSClient) ReceiveWithTimeout(timeout time.Duration) (string, error) {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return "", fmt.Errorf("connection closed")
|
||
}
|
||
|
||
c.conn.SetReadDeadline(time.Now().Add(timeout))
|
||
messageType, data, err := c.conn.ReadMessage()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if messageType != websocket.TextMessage {
|
||
return "", fmt.Errorf("expected text message, got %d", messageType)
|
||
}
|
||
|
||
return string(data), nil
|
||
}
|
||
|
||
// ReceiveChan 返回消息通道。
|
||
//
|
||
// 在后台持续接收消息,通过通道返回。
|
||
func (c *WSClient) ReceiveChan() <-chan WSMessage {
|
||
ch := make(chan WSMessage, 10)
|
||
|
||
go func() {
|
||
defer close(ch)
|
||
for {
|
||
msg, err := c.Receive()
|
||
if err != nil {
|
||
return
|
||
}
|
||
ch <- WSMessage{Data: msg}
|
||
}
|
||
}()
|
||
|
||
return ch
|
||
}
|
||
|
||
// WSMessage WebSocket 消息。
|
||
type WSMessage struct {
|
||
Data string
|
||
Error error
|
||
}
|
||
|
||
// Close 关闭连接。
|
||
func (c *WSClient) Close() error {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if c.closed {
|
||
return nil
|
||
}
|
||
|
||
c.closed = true
|
||
close(c.closeChan)
|
||
|
||
// 发送关闭帧
|
||
err := c.conn.WriteMessage(websocket.CloseMessage,
|
||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||
if err != nil {
|
||
c.conn.Close()
|
||
return err
|
||
}
|
||
|
||
return c.conn.Close()
|
||
}
|
||
|
||
// IsClosed 检查连接是否已关闭。
|
||
func (c *WSClient) IsClosed() bool {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
return c.closed
|
||
}
|
||
|
||
// URL 返回连接 URL。
|
||
func (c *WSClient) URL() string {
|
||
return c.url
|
||
}
|
||
|
||
// CloseChan 返回关闭通道。
|
||
func (c *WSClient) CloseChan() <-chan struct{} {
|
||
return c.closeChan
|
||
}
|
||
|
||
// WSPool WebSocket 连接池。
|
||
//
|
||
// 管理多个 WebSocket 连接,用于并发测试。
|
||
type WSPool struct {
|
||
clients []*WSClient
|
||
mu sync.Mutex
|
||
}
|
||
|
||
// NewWSPool 创建 WebSocket 连接池。
|
||
//
|
||
// 参数:
|
||
// - ctx: 上下文
|
||
// - url: WebSocket URL
|
||
// - count: 连接数量
|
||
//
|
||
// 返回连接池实例。
|
||
func NewWSPool(ctx context.Context, url string, count int) (*WSPool, error) {
|
||
pool := &WSPool{
|
||
clients: make([]*WSClient, count),
|
||
}
|
||
|
||
for i := 0; i < count; i++ {
|
||
client, err := NewWSClient(ctx, url)
|
||
if err != nil {
|
||
pool.Close()
|
||
return nil, fmt.Errorf("failed to create client %d: %w", i, err)
|
||
}
|
||
pool.clients[i] = client
|
||
}
|
||
|
||
return pool, nil
|
||
}
|
||
|
||
// SendAll 向所有连接发送消息。
|
||
func (p *WSPool) SendAll(message string) error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
var lastErr error
|
||
for _, client := range p.clients {
|
||
if client != nil {
|
||
if err := client.Send(message); err != nil {
|
||
lastErr = err
|
||
}
|
||
}
|
||
}
|
||
return lastErr
|
||
}
|
||
|
||
// SendOne 向指定连接发送消息。
|
||
func (p *WSPool) SendOne(index int, message string) error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
if index < 0 || index >= len(p.clients) {
|
||
return fmt.Errorf("invalid index %d", index)
|
||
}
|
||
|
||
if p.clients[index] == nil {
|
||
return fmt.Errorf("client %d is nil", index)
|
||
}
|
||
|
||
return p.clients[index].Send(message)
|
||
}
|
||
|
||
// ReceiveAll 从所有连接接收消息。
|
||
//
|
||
// 返回每个连接收到的消息列表。
|
||
func (p *WSPool) ReceiveAll() ([]string, error) {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
messages := make([]string, len(p.clients))
|
||
var lastErr error
|
||
|
||
for i, client := range p.clients {
|
||
if client != nil {
|
||
msg, err := client.Receive()
|
||
if err != nil {
|
||
lastErr = err
|
||
messages[i] = ""
|
||
} else {
|
||
messages[i] = msg
|
||
}
|
||
}
|
||
}
|
||
|
||
return messages, lastErr
|
||
}
|
||
|
||
// Count 返回连接数量。
|
||
func (p *WSPool) Count() int {
|
||
return len(p.clients)
|
||
}
|
||
|
||
// Close 关闭所有连接。
|
||
func (p *WSPool) Close() error {
|
||
p.mu.Lock()
|
||
defer p.mu.Unlock()
|
||
|
||
var lastErr error
|
||
for _, client := range p.clients {
|
||
if client != nil {
|
||
if err := client.Close(); err != nil {
|
||
lastErr = err
|
||
}
|
||
}
|
||
}
|
||
return lastErr
|
||
}
|
||
|
||
// WSEchoServer WebSocket Echo 服务器配置。
|
||
type WSEchoServer struct {
|
||
Port int
|
||
Handler func(*websocket.Conn)
|
||
}
|
||
|
||
// NewWSEchoHandler 创建 Echo 处理器。
|
||
//
|
||
// 将收到的消息原样返回。
|
||
func NewWSEchoHandler() func(*websocket.Conn) {
|
||
return func(conn *websocket.Conn) {
|
||
defer conn.Close()
|
||
for {
|
||
messageType, data, err := conn.ReadMessage()
|
||
if err != nil {
|
||
return
|
||
}
|
||
conn.WriteMessage(messageType, data)
|
||
}
|
||
}
|
||
}
|