test(e2e/testutil): 扩展测试工具包

添加配置生成、常量定义、测试设置、SSL 和 WebSocket 工具函数。
重构 container.go 支持函数式选项模式配置容器。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-23 14:51:11 +08:00
parent 66060928d1
commit 0790c5a9e4
7 changed files with 1870 additions and 22 deletions

View File

@ -18,6 +18,7 @@ type ConcurrentRequestConfig struct {
Count int
Timeout time.Duration
ExpectCode int
Client *http.Client // 可选的自定义客户端
}
// ConcurrentRequestResult 并发请求结果。
@ -35,8 +36,11 @@ func RunConcurrentRequests(cfg ConcurrentRequestConfig) []ConcurrentRequestResul
results := make([]ConcurrentRequestResult, cfg.Count)
var wg sync.WaitGroup
client := &http.Client{
Timeout: cfg.Timeout,
client := cfg.Client
if client == nil {
client = &http.Client{
Timeout: cfg.Timeout,
}
}
for i := 0; i < cfg.Count; i++ {

View File

@ -0,0 +1,585 @@
//go:build e2e
// Package testutil 提供 E2E 测试的工具函数。
//
// 包含动态配置生成器,支持编程方式生成 YAML 配置文件。
//
// 作者xfy
package testutil
import (
"fmt"
"os"
"path/filepath"
"time"
"gopkg.in/yaml.v3"
"rua.plus/lolly/internal/config"
)
// ConfigBuilder 动态配置构建器。
//
// 支持编程方式生成 YAML 配置,用于 E2E 测试场景。
// 提供链式调用方式,方便组合不同配置。
//
// 使用示例:
//
// cfg := testutil.NewConfigBuilder().
// WithServer(":8080").
// WithProxy("/api/", targets).
// WithSSL(certPath, keyPath).
// Build()
type ConfigBuilder struct {
cfg *config.Config
}
// NewConfigBuilder 创建配置构建器。
func NewConfigBuilder() *ConfigBuilder {
return &ConfigBuilder{
cfg: &config.Config{
Servers: []config.ServerConfig{},
},
}
}
// WithServer 添加服务器配置。
//
// 参数:
// - listen: 监听地址,如 ":8080"
//
// 返回构建器以支持链式调用。
func (b *ConfigBuilder) WithServer(listen string) *ConfigBuilder {
b.cfg.Servers = append(b.cfg.Servers, config.ServerConfig{
Listen: listen,
})
return b
}
// WithServerConfig 添加完整服务器配置。
func (b *ConfigBuilder) WithServerConfig(server config.ServerConfig) *ConfigBuilder {
b.cfg.Servers = append(b.cfg.Servers, server)
return b
}
// ProxyTargetOption 代理目标选项。
type ProxyTargetOption func(*config.ProxyTarget)
// WithWeight 设置权重。
func WithWeight(weight int) ProxyTargetOption {
return func(t *config.ProxyTarget) {
t.Weight = weight
}
}
// WithMaxConns 设置最大连接数。
func WithMaxConns(maxConns int) ProxyTargetOption {
return func(t *config.ProxyTarget) {
t.MaxConns = maxConns
}
}
// WithMaxFails 设置最大失败次数。
func WithMaxFails(maxFails int, failTimeout time.Duration) ProxyTargetOption {
return func(t *config.ProxyTarget) {
t.MaxFails = maxFails
t.FailTimeout = failTimeout
}
}
// WithBackup 设置为备份服务器。
func WithBackup() ProxyTargetOption {
return func(t *config.ProxyTarget) {
t.Backup = true
}
}
// ProxyOption 代理配置选项。
type ProxyOption func(*config.ProxyConfig)
// ProxyConfig 代理配置类型别名。
type ProxyConfig = config.ProxyConfig
// WithLoadBalance 设置负载均衡算法。
func WithLoadBalance(algorithm string) ProxyOption {
return func(p *config.ProxyConfig) {
p.LoadBalance = algorithm
}
}
// WithHealthCheck 设置健康检查。
func WithHealthCheck(path string, interval, timeout time.Duration) ProxyOption {
return func(p *config.ProxyConfig) {
p.HealthCheck = config.HealthCheckConfig{
Path: path,
Interval: interval,
Timeout: timeout,
}
}
}
// WithProxyTimeout 设置代理超时。
func WithProxyTimeout(connect, read, write time.Duration) ProxyOption {
return func(p *config.ProxyConfig) {
p.Timeout = config.ProxyTimeout{
Connect: connect,
Read: read,
Write: write,
}
}
}
// WithProxyHeaders 设置代理头部。
func WithProxyHeaders(setRequest, setResponse map[string]string) ProxyOption {
return func(p *config.ProxyConfig) {
p.Headers = config.ProxyHeaders{
SetRequest: setRequest,
SetResponse: setResponse,
}
}
}
// WithProxyCache 设置代理缓存。
func WithProxyCache(maxAge time.Duration, cacheLock bool) ProxyOption {
return func(p *config.ProxyConfig) {
p.Cache = config.ProxyCacheConfig{
Enabled: true,
MaxAge: maxAge,
CacheLock: cacheLock,
}
}
}
// WithProxySSL 设置上游 SSL。
func WithProxySSL(serverName string, insecureSkipVerify bool) ProxyOption {
return func(p *config.ProxyConfig) {
p.ProxySSL = &config.ProxySSLConfig{
Enabled: true,
ServerName: serverName,
InsecureSkipVerify: insecureSkipVerify,
}
}
}
// WithProxyBuffering 设置代理缓冲。
func WithProxyBuffering(mode string, bufferSize int) ProxyOption {
return func(p *config.ProxyConfig) {
p.Buffering = &config.ProxyBufferingConfig{
Mode: mode,
BufferSize: bufferSize,
}
}
}
// WithProxyNextUpstream 设置故障转移。
func WithProxyNextUpstream(tries int, httpCodes []int) ProxyOption {
return func(p *config.ProxyConfig) {
p.NextUpstream = config.NextUpstreamConfig{
Tries: tries,
HTTPCodes: httpCodes,
}
}
}
// WithProxy 添加代理配置。
//
// 参数:
// - path: 代理路径前缀
// - urls: 后端 URL 列表
// - opts: 可选配置选项
//
// 返回构建器以支持链式调用。
func (b *ConfigBuilder) WithProxy(path string, urls []string, opts ...ProxyOption) *ConfigBuilder {
if len(b.cfg.Servers) == 0 {
b.WithServer(":8080")
}
targets := make([]config.ProxyTarget, len(urls))
for i, url := range urls {
targets[i] = config.ProxyTarget{
URL: url,
}
}
proxy := config.ProxyConfig{
Path: path,
Targets: targets,
}
for _, opt := range opts {
opt(&proxy)
}
// 添加到第一个服务器
b.cfg.Servers[0].Proxy = append(b.cfg.Servers[0].Proxy, proxy)
return b
}
// WithProxyTargets 添加带选项的代理配置。
//
// targetOptsPerTarget 是每个目标的选项列表(索引对应 urls
// opts 是代理级别的选项。
func (b *ConfigBuilder) WithProxyTargets(path string, urls []string, targetOptsPerTarget [][]ProxyTargetOption, opts ...ProxyOption) *ConfigBuilder {
if len(b.cfg.Servers) == 0 {
b.WithServer(":8080")
}
targets := make([]config.ProxyTarget, len(urls))
for i, url := range urls {
targets[i] = config.ProxyTarget{
URL: url,
}
// 应用该目标的选项
if i < len(targetOptsPerTarget) {
for _, opt := range targetOptsPerTarget[i] {
opt(&targets[i])
}
}
}
proxy := config.ProxyConfig{
Path: path,
Targets: targets,
}
for _, opt := range opts {
opt(&proxy)
}
b.cfg.Servers[0].Proxy = append(b.cfg.Servers[0].Proxy, proxy)
return b
}
// SSLOption SSL 配置选项。
type SSLOption func(*config.SSLConfig)
// WithHTTP2 启用 HTTP/2。
func WithHTTP2(enabled bool, maxConcurrentStreams int) SSLOption {
return func(s *config.SSLConfig) {
s.HTTP2 = config.HTTP2Config{
Enabled: enabled,
MaxConcurrentStreams: maxConcurrentStreams,
}
}
}
// WithTLSProtocols 设置 TLS 协议版本。
func WithTLSProtocols(protocols []string) SSLOption {
return func(s *config.SSLConfig) {
s.Protocols = protocols
}
}
// WithSessionTickets 启用 Session Tickets。
func WithSessionTickets(enabled bool) SSLOption {
return func(s *config.SSLConfig) {
s.SessionTickets = config.SessionTicketsConfig{
Enabled: enabled,
}
}
}
// WithHSTS 配置 HSTS。
func WithHSTS(maxAge int, includeSubDomains bool) SSLOption {
return func(s *config.SSLConfig) {
s.HSTS = config.HSTSConfig{
MaxAge: maxAge,
IncludeSubDomains: includeSubDomains,
}
}
}
// WithSSL 配置 SSL/TLS。
//
// 参数:
// - cert: 证书文件路径
// - key: 私钥文件路径
// - opts: 可选 SSL 配置
//
// 返回构建器以支持链式调用。
func (b *ConfigBuilder) WithSSL(cert, key string, opts ...SSLOption) *ConfigBuilder {
if len(b.cfg.Servers) == 0 {
b.WithServer(":8443")
}
ssl := config.SSLConfig{
Cert: cert,
Key: key,
}
for _, opt := range opts {
opt(&ssl)
}
b.cfg.Servers[0].SSL = ssl
return b
}
// StaticOption 静态文件配置选项。
type StaticOption func(*config.StaticConfig)
// WithIndex 设置索引文件。
func WithIndex(index []string) StaticOption {
return func(s *config.StaticConfig) {
s.Index = index
}
}
// WithTryFiles 设置 try_files。
func WithTryFiles(tryFiles []string) StaticOption {
return func(s *config.StaticConfig) {
s.TryFiles = tryFiles
}
}
// WithStatic 添加静态文件配置。
func (b *ConfigBuilder) WithStatic(path, root string, opts ...StaticOption) *ConfigBuilder {
if len(b.cfg.Servers) == 0 {
b.WithServer(":8080")
}
static := config.StaticConfig{
Path: path,
Root: root,
}
for _, opt := range opts {
opt(&static)
}
b.cfg.Servers[0].Static = append(b.cfg.Servers[0].Static, static)
return b
}
// SecurityOption 安全配置选项。
type SecurityOption func(*config.SecurityConfig)
// WithRateLimit 设置速率限制。
func WithRateLimit(requestRate, burst int) SecurityOption {
return func(s *config.SecurityConfig) {
s.RateLimit = config.RateLimitConfig{
RequestRate: requestRate,
Burst: burst,
}
}
}
// WithAccessControl 设置访问控制。
func WithAccessControl(allow, deny []string, defaultAction string) SecurityOption {
return func(s *config.SecurityConfig) {
s.Access = config.AccessConfig{
Allow: allow,
Deny: deny,
Default: defaultAction,
}
}
}
// WithBasicAuth 设置 Basic 认证。
func WithBasicAuth(users []config.User) SecurityOption {
return func(s *config.SecurityConfig) {
s.Auth = config.AuthConfig{
Type: "basic",
Users: users,
}
}
}
// WithSecurity 配置安全选项。
func (b *ConfigBuilder) WithSecurity(opts ...SecurityOption) *ConfigBuilder {
if len(b.cfg.Servers) == 0 {
b.WithServer(":8080")
}
for _, opt := range opts {
opt(&b.cfg.Servers[0].Security)
}
return b
}
// CompressionOption 压缩配置选项。
type CompressionOption func(*config.CompressionConfig)
// WithCompressionType 设置压缩类型。
func WithCompressionType(typ string) CompressionOption {
return func(c *config.CompressionConfig) {
c.Type = typ
}
}
// WithCompressionLevel 设置压缩级别。
func WithCompressionLevel(level int) CompressionOption {
return func(c *config.CompressionConfig) {
c.Level = level
}
}
// WithCompressionMinSize 设置最小压缩大小。
func WithCompressionMinSize(minSize int) CompressionOption {
return func(c *config.CompressionConfig) {
c.MinSize = minSize
}
}
// WithCompression 配置压缩。
func (b *ConfigBuilder) WithCompression(opts ...CompressionOption) *ConfigBuilder {
if len(b.cfg.Servers) == 0 {
b.WithServer(":8080")
}
for _, opt := range opts {
opt(&b.cfg.Servers[0].Compression)
}
return b
}
// RewriteOption 重写规则选项。
type RewriteOption func(*config.RewriteRule)
// WithRewriteFlag 设置重写标志。
func WithRewriteFlag(flag string) RewriteOption {
return func(r *config.RewriteRule) {
r.Flag = flag
}
}
// WithRewrite 添加 URL 重写规则。
func (b *ConfigBuilder) WithRewrite(pattern, replacement string, opts ...RewriteOption) *ConfigBuilder {
if len(b.cfg.Servers) == 0 {
b.WithServer(":8080")
}
rule := config.RewriteRule{
Pattern: pattern,
Replacement: replacement,
}
for _, opt := range opts {
opt(&rule)
}
b.cfg.Servers[0].Rewrite = append(b.cfg.Servers[0].Rewrite, rule)
return b
}
// WithCachePath 配置缓存路径。
func (b *ConfigBuilder) WithCachePath(path string, maxSize int64) *ConfigBuilder {
b.cfg.CachePath = &config.ProxyCachePathConfig{
Path: path,
MaxSize: maxSize,
}
return b
}
// WithResolver 配置 DNS 解析器。
func (b *ConfigBuilder) WithResolver(addresses []string, valid, timeout time.Duration) *ConfigBuilder {
b.cfg.Resolver = config.ResolverConfig{
Enabled: true,
Addresses: addresses,
Valid: valid,
Timeout: timeout,
}
return b
}
// WithLogging 配置日志。
func (b *ConfigBuilder) WithLogging(format string) *ConfigBuilder {
b.cfg.Logging = config.LoggingConfig{
Format: format,
}
return b
}
// WithShutdown 配置关闭超时。
func (b *ConfigBuilder) WithShutdown(graceful, fast time.Duration) *ConfigBuilder {
b.cfg.Shutdown = config.ShutdownConfig{
GracefulTimeout: graceful,
FastTimeout: fast,
}
return b
}
// Build 生成 YAML 配置字符串。
//
// 返回 YAML 格式的配置字符串。
func (b *ConfigBuilder) Build() (string, error) {
data, err := yaml.Marshal(b.cfg)
if err != nil {
return "", fmt.Errorf("序列化配置失败: %w", err)
}
return string(data), nil
}
// WriteTemp 写入临时文件。
//
// 创建临时目录并写入配置文件,返回文件路径。
// 调用者负责清理临时目录。
func (b *ConfigBuilder) WriteTemp() (string, error) {
yamlStr, err := b.Build()
if err != nil {
return "", err
}
// 创建临时目录
tmpDir, err := os.MkdirTemp("", "lolly-e2e-*")
if err != nil {
return "", fmt.Errorf("创建临时目录失败: %w", err)
}
// 写入配置文件
configPath := filepath.Join(tmpDir, "lolly.yaml")
if err := os.WriteFile(configPath, []byte(yamlStr), 0o644); err != nil {
os.RemoveAll(tmpDir)
return "", fmt.Errorf("写入配置文件失败: %w", err)
}
return configPath, nil
}
// WriteTo 写入指定目录。
func (b *ConfigBuilder) WriteTo(dir string) (string, error) {
yamlStr, err := b.Build()
if err != nil {
return "", err
}
// 确保目录存在
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", fmt.Errorf("创建目录失败: %w", err)
}
configPath := filepath.Join(dir, "lolly.yaml")
if err := os.WriteFile(configPath, []byte(yamlStr), 0o644); err != nil {
return "", fmt.Errorf("写入配置文件失败: %w", err)
}
return configPath, nil
}
// GetConfig 返回配置对象。
func (b *ConfigBuilder) GetConfig() *config.Config {
return b.cfg
}
// Reset 重置构建器。
func (b *ConfigBuilder) Reset() *ConfigBuilder {
b.cfg = &config.Config{
Servers: []config.ServerConfig{},
}
return b
}
// Clone 克隆构建器。
func (b *ConfigBuilder) Clone() *ConfigBuilder {
data, err := yaml.Marshal(b.cfg)
if err != nil {
return NewConfigBuilder()
}
var newCfg config.Config
if err := yaml.Unmarshal(data, &newCfg); err != nil {
return NewConfigBuilder()
}
return &ConfigBuilder{cfg: &newCfg}
}

View File

@ -0,0 +1,70 @@
//go:build e2e
// Package testutil 提供 E2E 测试的工具函数。
//
// 包含测试常量定义。
//
// 作者xfy
package testutil
import (
"crypto/tls"
"time"
)
// 测试超时常量。
const (
// ContainerStartupTimeout 容器启动超时。
ContainerStartupTimeout = 30 * time.Second
// HealthCheckWaitTimeout 健康检查等待超时。
HealthCheckWaitTimeout = 30 * time.Second
// HealthCheckDetectionTime 健康检查检测时间。
HealthCheckDetectionTime = 10 * time.Second
// CacheExpireBuffer 缓存过期缓冲时间。
CacheExpireBuffer = 1 * time.Second
// DefaultTestTimeout 测试上下文超时。
DefaultTestTimeout = 180 * time.Second
// DefaultClientTimeout HTTP 客户端超时。
DefaultClientTimeout = 10 * time.Second
// ConcurrentRequestTimeout 并发请求超时。
ConcurrentRequestTimeout = 30 * time.Second
// ShortTestTimeout 短测试超时(用于快速测试)。
ShortTestTimeout = 60 * time.Second
// MediumTestTimeout 中等测试超时。
MediumTestTimeout = 120 * time.Second
)
// 测试配置常量。
const (
// DefaultBackendCount 默认后端数量。
DefaultBackendCount = 2
// DefaultConcurrentRequests 并发请求数量。
DefaultConcurrentRequests = 10
// HighConcurrentRequests 高并发请求数量。
HighConcurrentRequests = 20
// CacheTestMaxAge 缓存测试过期时间。
CacheTestMaxAge = 5 * time.Minute
// CacheTestShortMaxAge 短缓存过期时间(用于过期测试)。
CacheTestShortMaxAge = 2 * time.Second
)
// TLS 版本常量(用于配置客户端)。
const (
// TLSVersion12 TLS 1.2。
TLSVersion12 = tls.VersionTLS12
// TLSVersion13 TLS 1.3。
TLSVersion13 = tls.VersionTLS13
)

View File

@ -29,6 +29,89 @@ servers:
- "index.html"
`
// LollyContainerOption 容器启动选项。
type LollyContainerOption func(*lollyContainerConfig)
// lollyContainerConfig 容器配置。
type lollyContainerConfig struct {
configPath string
configYAML string
network string
certPath string
keyPath string
extraMounts []testcontainers.ContainerMount
env map[string]string
exposedPorts []string
waitFor wait.Strategy
}
// WithConfigFile 使用配置文件路径。
func WithConfigFile(path string) LollyContainerOption {
return func(c *lollyContainerConfig) {
c.configPath = path
}
}
// WithConfigYAML 使用 YAML 字符串配置。
func WithConfigYAML(yaml string) LollyContainerOption {
return func(c *lollyContainerConfig) {
c.configYAML = yaml
}
}
// WithNetwork 加入指定网络。
func WithNetwork(name string) LollyContainerOption {
return func(c *lollyContainerConfig) {
c.network = name
}
}
// WithCert 挂载证书文件。
func WithCert(certPath, keyPath string) LollyContainerOption {
return func(c *lollyContainerConfig) {
c.certPath = certPath
c.keyPath = keyPath
}
}
// WithExtraMount 添加额外挂载。
func WithExtraMount(hostPath, containerPath string) LollyContainerOption {
return func(c *lollyContainerConfig) {
c.extraMounts = append(c.extraMounts, testcontainers.ContainerMount{
Source: testcontainers.GenericBindMountSource{
HostPath: hostPath,
},
Target: testcontainers.ContainerMountTarget(containerPath),
})
}
}
// WithEnv 设置环境变量。
func WithEnv(env map[string]string) LollyContainerOption {
return func(c *lollyContainerConfig) {
if c.env == nil {
c.env = make(map[string]string)
}
for k, v := range env {
c.env[k] = v
}
}
}
// WithExposedPorts 设置暴露端口。
func WithExposedPorts(ports ...string) LollyContainerOption {
return func(c *lollyContainerConfig) {
c.exposedPorts = ports
}
}
// WithWaitStrategy 设置等待策略。
func WithWaitStrategy(strategy wait.Strategy) LollyContainerOption {
return func(c *lollyContainerConfig) {
c.waitFor = strategy
}
}
// LollyContainer 封装 lolly 服务器容器。
type LollyContainer struct {
Container testcontainers.Container
@ -40,34 +123,99 @@ type LollyContainer struct {
// StartLollyContainer 启动 lolly 服务器容器。
//
// 使用预构建的 lolly 镜像。如果 configPath 为空,使用默认配置。
// 支持通过选项函数自定义配置。
func StartLollyContainer(ctx context.Context, configPath string) (*LollyContainer, error) {
req := testcontainers.ContainerRequest{
Image: "lolly:latest",
ExposedPorts: []string{"8080/tcp", "8443/tcp"},
WaitingFor: wait.ForLog("HTTP 服务器启动中").WithStartupTimeout(30 * time.Second),
return StartLolly(ctx, WithConfigFile(configPath))
}
// StartLolly 启动 lolly 容器(增强版)。
//
// 支持多种配置方式和自定义选项。
//
// 使用示例:
//
// // 使用默认配置
// lolly, err := StartLolly(ctx)
//
// // 使用配置文件
// lolly, err := StartLolly(ctx, WithConfigFile("/path/to/config.yaml"))
//
// // 使用动态配置
// cfg := NewConfigBuilder().WithProxy("/api/", targets).Build()
// lolly, err := StartLolly(ctx, WithConfigYAML(cfg))
//
// // 使用 SSL
// lolly, err := StartLolly(ctx, WithConfigBuilder(cfg), WithCert(certPath, keyPath))
func StartLolly(ctx context.Context, opts ...LollyContainerOption) (*LollyContainer, error) {
cfg := &lollyContainerConfig{
exposedPorts: []string{"8080/tcp", "8443/tcp"},
waitFor: wait.ForLog("HTTP 服务器启动中").WithStartupTimeout(30 * time.Second),
}
// 配置文件挂载
if configPath != "" {
req.Mounts = []testcontainers.ContainerMount{
{
Source: testcontainers.GenericBindMountSource{
HostPath: configPath,
},
Target: "/etc/lolly/lolly.yaml",
for _, opt := range opts {
opt(cfg)
}
req := testcontainers.ContainerRequest{
Image: "lolly:latest",
ExposedPorts: cfg.exposedPorts,
WaitingFor: cfg.waitFor,
}
// 设置环境变量
if len(cfg.env) > 0 {
req.Env = cfg.env
}
// 配置网络
if cfg.network != "" {
req.Networks = []string{cfg.network}
}
// 处理配置文件
if cfg.configPath != "" {
req.Mounts = append(req.Mounts, testcontainers.ContainerMount{
Source: testcontainers.GenericBindMountSource{
HostPath: cfg.configPath,
},
}
Target: "/etc/lolly/lolly.yaml",
})
} else if cfg.configYAML != "" {
req.Files = append(req.Files, testcontainers.ContainerFile{
Reader: strings.NewReader(cfg.configYAML),
ContainerFilePath: "/etc/lolly/lolly.yaml",
FileMode: 0o644,
})
} else {
// 使用内嵌默认配置
req.Files = []testcontainers.ContainerFile{
{
Reader: strings.NewReader(defaultLollyConfig),
ContainerFilePath: "/etc/lolly/lolly.yaml",
FileMode: 0o644,
},
}
req.Files = append(req.Files, testcontainers.ContainerFile{
Reader: strings.NewReader(defaultLollyConfig),
ContainerFilePath: "/etc/lolly/lolly.yaml",
FileMode: 0o644,
})
}
// 挂载证书
if cfg.certPath != "" && cfg.keyPath != "" {
req.Mounts = append(req.Mounts,
testcontainers.ContainerMount{
Source: testcontainers.GenericBindMountSource{
HostPath: cfg.certPath,
},
Target: "/etc/lolly/ssl/server.crt",
},
testcontainers.ContainerMount{
Source: testcontainers.GenericBindMountSource{
HostPath: cfg.keyPath,
},
Target: "/etc/lolly/ssl/server.key",
},
)
}
// 添加额外挂载
req.Mounts = append(req.Mounts, cfg.extraMounts...)
container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
@ -226,6 +374,11 @@ func LollyImageAvailable(ctx context.Context) bool {
//
// 使用 nginx 作为模拟后端,返回容器和访问地址。
// 注意:此函数仅用于代理测试的后端模拟,不应作为被测系统。
//
// 返回值:
// - container: 容器实例
// - hostPort: 宿主机访问地址(用于测试代码访问)
// - internalAddr: 容器内部访问地址(用于 lolly 配置)
func StartMockBackend(ctx context.Context) (testcontainers.Container, string, error) {
req := testcontainers.ContainerRequest{
Image: "nginx:alpine",
@ -253,6 +406,324 @@ func StartMockBackend(ctx context.Context) (testcontainers.Container, string, er
return nil, "", fmt.Errorf("failed to get port: %w", err)
}
// 返回宿主机地址
addr := fmt.Sprintf("http://%s:%s", host, port.Port())
return container, addr, nil
}
// BackendPool 后端池管理。
//
// 管理多个后端容器,用于负载均衡测试。
// 支持网络模式:当 network 不为空时,容器加入指定网络,
// 并提供内部地址供 lolly 容器访问。
type BackendPool struct {
containers []testcontainers.Container
addresses []string // 宿主机访问地址
internal []string // 容器网络内部地址
network string // Docker 网络名称
}
// StartBackendPool 启动多个后端容器。
//
// 参数:
// - count: 后端数量
//
// 返回后端池和地址列表(宿主机访问地址)。
func StartBackendPool(ctx context.Context, count int) (*BackendPool, error) {
return StartBackendPoolWithNetwork(ctx, count, "")
}
// StartBackendPoolWithNetwork 启动多个后端容器并加入网络。
//
// 参数:
// - count: 后端数量
// - network: Docker 网络名称(可选,为空则不加入网络)
//
// 当 network 不为空时,容器会加入该网络,并提供内部地址。
func StartBackendPoolWithNetwork(ctx context.Context, count int, network string) (*BackendPool, error) {
pool := &BackendPool{
containers: make([]testcontainers.Container, count),
addresses: make([]string, count),
internal: make([]string, count),
network: network,
}
for i := 0; i < count; i++ {
container, addr, internalAddr, err := startMockBackendWithNetwork(ctx, network, i)
if err != nil {
// 清理已启动的容器
pool.Terminate(ctx)
return nil, fmt.Errorf("failed to start backend %d: %w", i, err)
}
pool.containers[i] = container
pool.addresses[i] = addr
pool.internal[i] = internalAddr
}
return pool, nil
}
// startMockBackendWithNetwork 启动单个后端容器。
func startMockBackendWithNetwork(ctx context.Context, network string, index int) (testcontainers.Container, string, string, error) {
// 生成容器名称(用于网络通信)
containerName := fmt.Sprintf("backend-%d-%d", time.Now().UnixNano(), index)
req := testcontainers.ContainerRequest{
Image: "nginx:alpine",
ExposedPorts: []string{"80/tcp"},
WaitingFor: wait.ForHTTP("/").WithStartupTimeout(30 * time.Second),
Name: containerName,
}
if network != "" {
req.Networks = []string{network}
}
container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
return nil, "", "", fmt.Errorf("failed to start mock backend: %w", err)
}
host, err := container.Host(ctx)
if err != nil {
container.Terminate(ctx)
return nil, "", "", fmt.Errorf("failed to get host: %w", err)
}
port, err := container.MappedPort(ctx, "80/tcp")
if err != nil {
container.Terminate(ctx)
return nil, "", "", fmt.Errorf("failed to get port: %w", err)
}
// 宿主机访问地址
hostAddr := fmt.Sprintf("http://%s:%s", host, port.Port())
// 容器网络内部地址(使用容器名称)
internalAddr := fmt.Sprintf("http://%s:80", containerName)
return container, hostAddr, internalAddr, nil
}
// Addresses 返回后端地址列表(宿主机访问地址)。
func (p *BackendPool) Addresses() []string {
return p.addresses
}
// InternalAddresses 返回容器网络内部地址列表。
//
// 当 lolly 和后端在同一 Docker 网络时,应使用此地址。
func (p *BackendPool) InternalAddresses() []string {
return p.internal
}
// Containers 返回容器列表。
func (p *BackendPool) Containers() []testcontainers.Container {
return p.containers
}
// Count 返回后端数量。
func (p *BackendPool) Count() int {
return len(p.containers)
}
// Terminate 终止所有容器。
func (p *BackendPool) Terminate(ctx context.Context) {
for _, container := range p.containers {
if container != nil {
container.Terminate(ctx)
}
}
}
// TerminateOne 终止指定索引的容器。
func (p *BackendPool) TerminateOne(ctx context.Context, index int) error {
if index < 0 || index >= len(p.containers) {
return fmt.Errorf("invalid index %d", index)
}
if p.containers[index] != nil {
err := p.containers[index].Terminate(ctx)
p.containers[index] = nil
p.addresses[index] = ""
p.internal[index] = ""
return err
}
return nil
}
// RestartOne 重启指定索引的容器。
func (p *BackendPool) RestartOne(ctx context.Context, index int) error {
if index < 0 || index >= len(p.containers) {
return fmt.Errorf("invalid index %d", index)
}
// 先终止旧容器
if p.containers[index] != nil {
p.containers[index].Terminate(ctx)
}
// 启动新容器
container, addr, internalAddr, err := startMockBackendWithNetwork(ctx, p.network, index)
if err != nil {
return err
}
p.containers[index] = container
p.addresses[index] = addr
p.internal[index] = internalAddr
return nil
}
// CreateNetwork 创建 Docker 网络。
//
// 用于容器间通信。
func CreateNetwork(ctx context.Context, name string) (testcontainers.Network, error) {
network, err := testcontainers.GenericNetwork(ctx, testcontainers.GenericNetworkRequest{
NetworkRequest: testcontainers.NetworkRequest{
Name: name,
},
})
if err != nil {
return nil, fmt.Errorf("failed to create network: %w", err)
}
return network, nil
}
// sharedNetworkName 共享网络名称。
const sharedNetworkName = "lolly-e2e-test"
// SetupProxyTest 设置代理测试环境。
//
// 创建网络、启动后端池,返回网络名称和后端池。
// lolly 容器应使用 InternalAddresses() 作为代理目标。
// 使用共享网络名称,避免网络地址池耗尽。
//
// 使用示例:
//
// network, pool, err := testutil.SetupProxyTest(ctx, 2)
// if err != nil {
// t.Fatal(err)
// }
// defer testutil.CleanupProxyTest(ctx, network, pool)
//
// lolly, err := testutil.StartLolly(ctx,
// testutil.WithConfigYAML(configYAML),
// testutil.WithNetwork(network),
// )
func SetupProxyTest(ctx context.Context, backendCount int) (string, *BackendPool, error) {
// 使用共享网络名称
networkName := sharedNetworkName
// 尝试创建网络(如果不存在)
// 忽略"已存在"错误
_, err := CreateNetwork(ctx, networkName)
if err != nil && !isNetworkExistsError(err) {
return "", nil, fmt.Errorf("failed to create network: %w", err)
}
// 启动后端池并加入网络
pool, err := StartBackendPoolWithNetwork(ctx, backendCount, networkName)
if err != nil {
return "", nil, fmt.Errorf("failed to start backend pool: %w", err)
}
return networkName, pool, nil
}
// isNetworkExistsError 检查是否是网络已存在错误。
func isNetworkExistsError(err error) bool {
return err != nil && (containsString(err.Error(), "already exists") ||
containsString(err.Error(), "network with name") ||
containsString(err.Error(), "failed to create network"))
}
// containsString 检查字符串是否包含子串。
func containsString(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr))
}
func containsSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// CleanupProxyTest 清理代理测试环境。
func CleanupProxyTest(ctx context.Context, networkName string, pool *BackendPool) {
if pool != nil {
pool.Terminate(ctx)
}
// 网络会随容器终止自动清理
}
// ProxyTestEnv 代理测试环境。
//
// 封装代理测试所需的资源。
type ProxyTestEnv struct {
Network string
Pool *BackendPool
Lolly *LollyContainer
Cleanup func()
HTTPClient *http.Client
}
// SetupProxyTestEnv 设置完整的代理测试环境。
//
// 创建网络、启动后端池、启动 lolly返回封装的环境。
// 这是一个便捷函数,简化测试设置。
//
// 使用示例:
//
// env, err := testutil.SetupProxyTestEnv(ctx, t, 2, func(pool *testutil.BackendPool) string {
// cfg := testutil.NewConfigBuilder().
// WithServer(":8080").
// WithProxy("/", pool.InternalAddresses())
// yaml, _ := cfg.Build()
// return yaml
// })
func SetupProxyTestEnv(ctx context.Context, backendCount int, configBuilder func(*BackendPool) string) (*ProxyTestEnv, error) {
// 设置代理测试环境
networkName, pool, err := SetupProxyTest(ctx, backendCount)
if err != nil {
return nil, err
}
// 构建配置
configYAML := configBuilder(pool)
// 启动 lolly
lolly, err := StartLolly(ctx,
WithConfigYAML(configYAML),
WithNetwork(networkName),
)
if err != nil {
CleanupProxyTest(ctx, networkName, pool)
return nil, fmt.Errorf("failed to start lolly: %w", err)
}
// 等待健康
if err := lolly.WaitForHealthy(ctx, 30*time.Second); err != nil {
CleanupProxyTest(ctx, networkName, pool)
lolly.Terminate(ctx)
return nil, fmt.Errorf("lolly not healthy: %w", err)
}
cleanup := func() {
lolly.Terminate(ctx)
CleanupProxyTest(ctx, networkName, pool)
}
return &ProxyTestEnv{
Network: networkName,
Pool: pool,
Lolly: lolly,
Cleanup: cleanup,
HTTPClient: &http.Client{Timeout: 10 * time.Second},
}, nil
}

View File

@ -0,0 +1,209 @@
//go:build e2e
// Package testutil 提供 E2E 测试的工具函数。
//
// 包含统一的测试环境设置函数。
//
// 作者xfy
package testutil
import (
"context"
"net/http"
"testing"
"time"
)
// E2ETestEnv E2E 测试环境。
//
// 封装测试所需的资源和清理函数。
type E2ETestEnv struct {
Ctx context.Context
Network string
Pool *BackendPool
Lolly *LollyContainer
Client *http.Client
cleanup func()
}
// SetupE2ETest 设置 E2E 测试环境。
//
// 自动处理镜像检查、后端启动、lolly 启动和资源清理。
// 使用 t.Cleanup() 确保资源正确释放。
//
// 参数:
// - t: 测试对象
// - backendCount: 后端数量
// - cfgBuilder: 配置构建函数,接收后端池返回 YAML 配置
//
// 使用示例:
//
// env := testutil.SetupE2ETest(t, 2, func(pool *testutil.BackendPool) string {
// cfg := testutil.NewConfigBuilder().
// WithServer(":8080").
// WithProxy("/", pool.InternalAddresses())
// yaml, _ := cfg.Build()
// return yaml
// })
// defer env.Cleanup()
func SetupE2ETest(t *testing.T, backendCount int, cfgBuilder func(*BackendPool) string) *E2ETestEnv {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), DefaultTestTimeout)
if !LollyImageAvailable(ctx) {
t.Skip("lolly:latest image not available, run 'make docker-build' first")
}
network, pool, err := SetupProxyTest(ctx, backendCount)
if err != nil {
cancel()
t.Fatalf("Failed to setup proxy test: %v", err)
}
cfgYAML := cfgBuilder(pool)
lolly, err := StartLolly(ctx,
WithConfigYAML(cfgYAML),
WithNetwork(network),
)
if err != nil {
CleanupProxyTest(ctx, network, pool)
cancel()
t.Fatalf("Failed to start lolly: %v", err)
}
if err := lolly.WaitForHealthy(ctx, HealthCheckWaitTimeout); err != nil {
lolly.Terminate(ctx)
CleanupProxyTest(ctx, network, pool)
cancel()
t.Fatalf("Lolly not healthy: %v", err)
}
env := &E2ETestEnv{
Ctx: ctx,
Network: network,
Pool: pool,
Lolly: lolly,
Client: CreateDefaultHTTPClient(),
}
env.cleanup = func() {
lolly.Terminate(ctx)
CleanupProxyTest(ctx, network, pool)
cancel()
}
t.Cleanup(env.Cleanup)
return env
}
// SetupE2ETestWithTimeout 设置带自定义超时的 E2E 测试环境。
func SetupE2ETestWithTimeout(t *testing.T, backendCount int, timeout time.Duration, cfgBuilder func(*BackendPool) string) *E2ETestEnv {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
if !LollyImageAvailable(ctx) {
t.Skip("lolly:latest image not available, run 'make docker-build' first")
}
network, pool, err := SetupProxyTest(ctx, backendCount)
if err != nil {
cancel()
t.Fatalf("Failed to setup proxy test: %v", err)
}
cfgYAML := cfgBuilder(pool)
lolly, err := StartLolly(ctx,
WithConfigYAML(cfgYAML),
WithNetwork(network),
)
if err != nil {
CleanupProxyTest(ctx, network, pool)
cancel()
t.Fatalf("Failed to start lolly: %v", err)
}
if err := lolly.WaitForHealthy(ctx, HealthCheckWaitTimeout); err != nil {
lolly.Terminate(ctx)
CleanupProxyTest(ctx, network, pool)
cancel()
t.Fatalf("Lolly not healthy: %v", err)
}
env := &E2ETestEnv{
Ctx: ctx,
Network: network,
Pool: pool,
Lolly: lolly,
Client: CreateDefaultHTTPClient(),
}
env.cleanup = func() {
lolly.Terminate(ctx)
CleanupProxyTest(ctx, network, pool)
cancel()
}
t.Cleanup(env.Cleanup)
return env
}
// Cleanup 手动清理资源。
func (e *E2ETestEnv) Cleanup() {
if e.cleanup != nil {
e.cleanup()
e.cleanup = nil
}
}
// HTTPURL 返回 lolly HTTP 地址。
func (e *E2ETestEnv) HTTPURL() string {
return e.Lolly.HTTPBaseURL()
}
// HTTPSURL 返回 lolly HTTPS 地址。
func (e *E2ETestEnv) HTTPSURL() string {
return e.Lolly.HTTPSBaseURL()
}
// SetupSSLTest 设置 SSL 测试环境。
//
// 用于 SSL/TLS 测试场景,自动生成证书。
func SetupSSLTest(t *testing.T, cfgBuilder func() string) (*LollyContainer, string, string) {
t.Helper()
ctx := context.Background()
if !LollyImageAvailable(ctx) {
t.Skip("lolly:latest image not available, run 'make docker-build' first")
}
// 生成自签名证书
certPath, keyPath, cleanup, err := GenerateSelfSignedCert(t.TempDir())
if err != nil {
t.Fatalf("Failed to generate certificate: %v", err)
}
t.Cleanup(cleanup)
cfgYAML := cfgBuilder()
lolly, err := StartLolly(ctx,
WithConfigYAML(cfgYAML),
WithCert(certPath, keyPath),
)
if err != nil {
t.Fatalf("Failed to start lolly: %v", err)
}
t.Cleanup(func() { lolly.Terminate(ctx) })
if err := lolly.WaitForHealthy(ctx, HealthCheckWaitTimeout); err != nil {
t.Fatalf("Lolly not healthy: %v", err)
}
return lolly, certPath, keyPath
}

View File

@ -0,0 +1,98 @@
//go:build e2e
// Package testutil 提供 E2E 测试的工具函数。
//
// 包含 SSL/TLS 测试辅助函数。
//
// 作者xfy
package testutil
import (
"crypto/tls"
"crypto/x509"
"net/http"
"os"
"time"
)
// CreateTLSClient 创建信任指定证书的 HTTPS 客户端。
//
// 参数:
// - certPath: CA 证书文件路径
//
// 返回配置好的 HTTP 客户端,信任指定的证书。
func CreateTLSClient(certPath string) (*http.Client, error) {
caCert, err := os.ReadFile(certPath)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
return &http.Client{
Timeout: DefaultClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
},
},
}, nil
}
// CreateTLSClientWithVersion 创建带版本限制的 HTTPS 客户端。
//
// 参数:
// - certPath: CA 证书文件路径
// - minVersion: 最小 TLS 版本
// - maxVersion: 最大 TLS 版本
func CreateTLSClientWithVersion(certPath string, minVersion, maxVersion uint16) (*http.Client, error) {
caCert, err := os.ReadFile(certPath)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
return &http.Client{
Timeout: DefaultClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
MinVersion: minVersion,
MaxVersion: maxVersion,
},
},
}, nil
}
// CreateInsecureTLSClient 创建跳过证书验证的 HTTPS 客户端。
//
// 用于测试自签名证书场景,不应在生产环境使用。
func CreateInsecureTLSClient() *http.Client {
return &http.Client{
Timeout: DefaultClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
}
// CreateDefaultHTTPClient 创建默认 HTTP 客户端。
//
// 用于非 SSL 测试场景。
func CreateDefaultHTTPClient() *http.Client {
return &http.Client{
Timeout: DefaultClientTimeout,
}
}
// CreateHTTPClientWithTimeout 创建带自定义超时的 HTTP 客户端。
func CreateHTTPClientWithTimeout(timeout time.Duration) *http.Client {
return &http.Client{
Timeout: timeout,
}
}

View File

@ -0,0 +1,411 @@
//go:build e2e
// Package testutil 提供 E2E 测试的工具函数。
//
// 包含 WebSocket 测试辅助工具。
//
// 作者xfy
package testutil
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
)
// WSClient WebSocket 测试客户端。
//
// 封装 gorilla/websocket提供简单的测试接口。
type WSClient struct {
conn *websocket.Conn
url string
mu sync.Mutex
closed bool
closeChan chan struct{}
}
// WSOption WebSocket 客户端选项。
type WSOption func(*wsConfig)
type wsConfig struct {
headers http.Header
pingPeriod time.Duration
pongWait time.Duration
}
// WithHeaders 设置请求头。
func WithWSHeaders(headers http.Header) WSOption {
return func(c *wsConfig) {
c.headers = headers
}
}
// WithWSTimeout 设置超时时间。
func WithWSTimeout(pongWait, pingPeriod time.Duration) WSOption {
return func(c *wsConfig) {
c.pongWait = pongWait
c.pingPeriod = pingPeriod
}
}
// NewWSClient 创建 WebSocket 客户端。
//
// 参数:
// - ctx: 上下文
// - url: WebSocket URLws:// 或 wss://
// - opts: 可选配置
//
// 返回 WebSocket 客户端实例。
func NewWSClient(ctx context.Context, url string, opts ...WSOption) (*WSClient, error) {
cfg := &wsConfig{
headers: http.Header{},
pongWait: 60 * time.Second,
pingPeriod: 54 * time.Second,
}
for _, opt := range opts {
opt(cfg)
}
dialer := websocket.DefaultDialer
conn, _, err := dialer.DialContext(ctx, url, cfg.headers)
if err != nil {
return nil, fmt.Errorf("websocket dial failed: %w", err)
}
client := &WSClient{
conn: conn,
url: url,
closeChan: make(chan struct{}),
}
// 设置 pong 处理
conn.SetPongHandler(func(string) error {
return conn.SetReadDeadline(time.Now().Add(cfg.pongWait))
})
return client, nil
}
// Send 发送文本消息。
func (c *WSClient) Send(message string) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return fmt.Errorf("connection closed")
}
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return c.conn.WriteMessage(websocket.TextMessage, []byte(message))
}
// SendBinary 发送二进制消息。
func (c *WSClient) SendBinary(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return fmt.Errorf("connection closed")
}
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return c.conn.WriteMessage(websocket.BinaryMessage, data)
}
// SendJSON 发送 JSON 消息。
func (c *WSClient) SendJSON(v interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return fmt.Errorf("connection closed")
}
c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return c.conn.WriteJSON(v)
}
// Receive 接收文本消息。
//
// 返回消息内容和错误。
func (c *WSClient) Receive() (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return "", fmt.Errorf("connection closed")
}
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
messageType, data, err := c.conn.ReadMessage()
if err != nil {
return "", err
}
if messageType != websocket.TextMessage {
return "", fmt.Errorf("expected text message, got %d", messageType)
}
return string(data), nil
}
// ReceiveBinary 接收二进制消息。
func (c *WSClient) ReceiveBinary() ([]byte, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil, fmt.Errorf("connection closed")
}
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
messageType, data, err := c.conn.ReadMessage()
if err != nil {
return nil, err
}
if messageType != websocket.BinaryMessage {
return nil, fmt.Errorf("expected binary message, got %d", messageType)
}
return data, nil
}
// ReceiveJSON 接收 JSON 消息。
func (c *WSClient) ReceiveJSON(v interface{}) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return fmt.Errorf("connection closed")
}
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
return c.conn.ReadJSON(v)
}
// ReceiveWithTimeout 接收消息(带超时)。
func (c *WSClient) ReceiveWithTimeout(timeout time.Duration) (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return "", fmt.Errorf("connection closed")
}
c.conn.SetReadDeadline(time.Now().Add(timeout))
messageType, data, err := c.conn.ReadMessage()
if err != nil {
return "", err
}
if messageType != websocket.TextMessage {
return "", fmt.Errorf("expected text message, got %d", messageType)
}
return string(data), nil
}
// ReceiveChan 返回消息通道。
//
// 在后台持续接收消息,通过通道返回。
func (c *WSClient) ReceiveChan() <-chan WSMessage {
ch := make(chan WSMessage, 10)
go func() {
defer close(ch)
for {
msg, err := c.Receive()
if err != nil {
return
}
ch <- WSMessage{Data: msg}
}
}()
return ch
}
// WSMessage WebSocket 消息。
type WSMessage struct {
Data string
Error error
}
// Close 关闭连接。
func (c *WSClient) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.closed = true
close(c.closeChan)
// 发送关闭帧
err := c.conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
c.conn.Close()
return err
}
return c.conn.Close()
}
// IsClosed 检查连接是否已关闭。
func (c *WSClient) IsClosed() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.closed
}
// URL 返回连接 URL。
func (c *WSClient) URL() string {
return c.url
}
// CloseChan 返回关闭通道。
func (c *WSClient) CloseChan() <-chan struct{} {
return c.closeChan
}
// WSPool WebSocket 连接池。
//
// 管理多个 WebSocket 连接,用于并发测试。
type WSPool struct {
clients []*WSClient
mu sync.Mutex
}
// NewWSPool 创建 WebSocket 连接池。
//
// 参数:
// - ctx: 上下文
// - url: WebSocket URL
// - count: 连接数量
//
// 返回连接池实例。
func NewWSPool(ctx context.Context, url string, count int) (*WSPool, error) {
pool := &WSPool{
clients: make([]*WSClient, count),
}
for i := 0; i < count; i++ {
client, err := NewWSClient(ctx, url)
if err != nil {
pool.Close()
return nil, fmt.Errorf("failed to create client %d: %w", i, err)
}
pool.clients[i] = client
}
return pool, nil
}
// SendAll 向所有连接发送消息。
func (p *WSPool) SendAll(message string) error {
p.mu.Lock()
defer p.mu.Unlock()
var lastErr error
for _, client := range p.clients {
if client != nil {
if err := client.Send(message); err != nil {
lastErr = err
}
}
}
return lastErr
}
// SendOne 向指定连接发送消息。
func (p *WSPool) SendOne(index int, message string) error {
p.mu.Lock()
defer p.mu.Unlock()
if index < 0 || index >= len(p.clients) {
return fmt.Errorf("invalid index %d", index)
}
if p.clients[index] == nil {
return fmt.Errorf("client %d is nil", index)
}
return p.clients[index].Send(message)
}
// ReceiveAll 从所有连接接收消息。
//
// 返回每个连接收到的消息列表。
func (p *WSPool) ReceiveAll() ([]string, error) {
p.mu.Lock()
defer p.mu.Unlock()
messages := make([]string, len(p.clients))
var lastErr error
for i, client := range p.clients {
if client != nil {
msg, err := client.Receive()
if err != nil {
lastErr = err
messages[i] = ""
} else {
messages[i] = msg
}
}
}
return messages, lastErr
}
// Count 返回连接数量。
func (p *WSPool) Count() int {
return len(p.clients)
}
// Close 关闭所有连接。
func (p *WSPool) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
var lastErr error
for _, client := range p.clients {
if client != nil {
if err := client.Close(); err != nil {
lastErr = err
}
}
}
return lastErr
}
// WSEchoServer WebSocket Echo 服务器配置。
type WSEchoServer struct {
Port int
Handler func(*websocket.Conn)
}
// NewWSEchoHandler 创建 Echo 处理器。
//
// 将收到的消息原样返回。
func NewWSEchoHandler() func(*websocket.Conn) {
return func(conn *websocket.Conn) {
defer conn.Close()
for {
messageType, data, err := conn.ReadMessage()
if err != nil {
return
}
conn.WriteMessage(messageType, data)
}
}
}