From 0790c5a9e4e176f2c588969a148890614d328a27 Mon Sep 17 00:00:00 2001 From: xfy Date: Thu, 23 Apr 2026 14:51:11 +0800 Subject: [PATCH] =?UTF-8?q?test(e2e/testutil):=20=E6=89=A9=E5=B1=95?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E5=B7=A5=E5=85=B7=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加配置生成、常量定义、测试设置、SSL 和 WebSocket 工具函数。 重构 container.go 支持函数式选项模式配置容器。 Co-Authored-By: Claude Opus 4.7 --- internal/e2e/testutil/concurrent.go | 8 +- internal/e2e/testutil/config.go | 585 ++++++++++++++++++++++++++++ internal/e2e/testutil/constants.go | 70 ++++ internal/e2e/testutil/container.go | 511 +++++++++++++++++++++++- internal/e2e/testutil/setup.go | 209 ++++++++++ internal/e2e/testutil/ssl.go | 98 +++++ internal/e2e/testutil/websocket.go | 411 +++++++++++++++++++ 7 files changed, 1870 insertions(+), 22 deletions(-) create mode 100644 internal/e2e/testutil/config.go create mode 100644 internal/e2e/testutil/constants.go create mode 100644 internal/e2e/testutil/setup.go create mode 100644 internal/e2e/testutil/ssl.go create mode 100644 internal/e2e/testutil/websocket.go diff --git a/internal/e2e/testutil/concurrent.go b/internal/e2e/testutil/concurrent.go index 88ab442..3e2ee7a 100644 --- a/internal/e2e/testutil/concurrent.go +++ b/internal/e2e/testutil/concurrent.go @@ -18,6 +18,7 @@ type ConcurrentRequestConfig struct { Count int Timeout time.Duration ExpectCode int + Client *http.Client // 可选的自定义客户端 } // ConcurrentRequestResult 并发请求结果。 @@ -35,8 +36,11 @@ func RunConcurrentRequests(cfg ConcurrentRequestConfig) []ConcurrentRequestResul results := make([]ConcurrentRequestResult, cfg.Count) var wg sync.WaitGroup - client := &http.Client{ - Timeout: cfg.Timeout, + client := cfg.Client + if client == nil { + client = &http.Client{ + Timeout: cfg.Timeout, + } } for i := 0; i < cfg.Count; i++ { diff --git a/internal/e2e/testutil/config.go b/internal/e2e/testutil/config.go new file mode 100644 index 0000000..de7126d --- /dev/null +++ b/internal/e2e/testutil/config.go @@ -0,0 +1,585 @@ +//go:build e2e + +// Package testutil 提供 E2E 测试的工具函数。 +// +// 包含动态配置生成器,支持编程方式生成 YAML 配置文件。 +// +// 作者:xfy +package testutil + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "gopkg.in/yaml.v3" + + "rua.plus/lolly/internal/config" +) + +// ConfigBuilder 动态配置构建器。 +// +// 支持编程方式生成 YAML 配置,用于 E2E 测试场景。 +// 提供链式调用方式,方便组合不同配置。 +// +// 使用示例: +// +// cfg := testutil.NewConfigBuilder(). +// WithServer(":8080"). +// WithProxy("/api/", targets). +// WithSSL(certPath, keyPath). +// Build() +type ConfigBuilder struct { + cfg *config.Config +} + +// NewConfigBuilder 创建配置构建器。 +func NewConfigBuilder() *ConfigBuilder { + return &ConfigBuilder{ + cfg: &config.Config{ + Servers: []config.ServerConfig{}, + }, + } +} + +// WithServer 添加服务器配置。 +// +// 参数: +// - listen: 监听地址,如 ":8080" +// +// 返回构建器以支持链式调用。 +func (b *ConfigBuilder) WithServer(listen string) *ConfigBuilder { + b.cfg.Servers = append(b.cfg.Servers, config.ServerConfig{ + Listen: listen, + }) + return b +} + +// WithServerConfig 添加完整服务器配置。 +func (b *ConfigBuilder) WithServerConfig(server config.ServerConfig) *ConfigBuilder { + b.cfg.Servers = append(b.cfg.Servers, server) + return b +} + +// ProxyTargetOption 代理目标选项。 +type ProxyTargetOption func(*config.ProxyTarget) + +// WithWeight 设置权重。 +func WithWeight(weight int) ProxyTargetOption { + return func(t *config.ProxyTarget) { + t.Weight = weight + } +} + +// WithMaxConns 设置最大连接数。 +func WithMaxConns(maxConns int) ProxyTargetOption { + return func(t *config.ProxyTarget) { + t.MaxConns = maxConns + } +} + +// WithMaxFails 设置最大失败次数。 +func WithMaxFails(maxFails int, failTimeout time.Duration) ProxyTargetOption { + return func(t *config.ProxyTarget) { + t.MaxFails = maxFails + t.FailTimeout = failTimeout + } +} + +// WithBackup 设置为备份服务器。 +func WithBackup() ProxyTargetOption { + return func(t *config.ProxyTarget) { + t.Backup = true + } +} + +// ProxyOption 代理配置选项。 +type ProxyOption func(*config.ProxyConfig) + +// ProxyConfig 代理配置类型别名。 +type ProxyConfig = config.ProxyConfig + +// WithLoadBalance 设置负载均衡算法。 +func WithLoadBalance(algorithm string) ProxyOption { + return func(p *config.ProxyConfig) { + p.LoadBalance = algorithm + } +} + +// WithHealthCheck 设置健康检查。 +func WithHealthCheck(path string, interval, timeout time.Duration) ProxyOption { + return func(p *config.ProxyConfig) { + p.HealthCheck = config.HealthCheckConfig{ + Path: path, + Interval: interval, + Timeout: timeout, + } + } +} + +// WithProxyTimeout 设置代理超时。 +func WithProxyTimeout(connect, read, write time.Duration) ProxyOption { + return func(p *config.ProxyConfig) { + p.Timeout = config.ProxyTimeout{ + Connect: connect, + Read: read, + Write: write, + } + } +} + +// WithProxyHeaders 设置代理头部。 +func WithProxyHeaders(setRequest, setResponse map[string]string) ProxyOption { + return func(p *config.ProxyConfig) { + p.Headers = config.ProxyHeaders{ + SetRequest: setRequest, + SetResponse: setResponse, + } + } +} + +// WithProxyCache 设置代理缓存。 +func WithProxyCache(maxAge time.Duration, cacheLock bool) ProxyOption { + return func(p *config.ProxyConfig) { + p.Cache = config.ProxyCacheConfig{ + Enabled: true, + MaxAge: maxAge, + CacheLock: cacheLock, + } + } +} + +// WithProxySSL 设置上游 SSL。 +func WithProxySSL(serverName string, insecureSkipVerify bool) ProxyOption { + return func(p *config.ProxyConfig) { + p.ProxySSL = &config.ProxySSLConfig{ + Enabled: true, + ServerName: serverName, + InsecureSkipVerify: insecureSkipVerify, + } + } +} + +// WithProxyBuffering 设置代理缓冲。 +func WithProxyBuffering(mode string, bufferSize int) ProxyOption { + return func(p *config.ProxyConfig) { + p.Buffering = &config.ProxyBufferingConfig{ + Mode: mode, + BufferSize: bufferSize, + } + } +} + +// WithProxyNextUpstream 设置故障转移。 +func WithProxyNextUpstream(tries int, httpCodes []int) ProxyOption { + return func(p *config.ProxyConfig) { + p.NextUpstream = config.NextUpstreamConfig{ + Tries: tries, + HTTPCodes: httpCodes, + } + } +} + +// WithProxy 添加代理配置。 +// +// 参数: +// - path: 代理路径前缀 +// - urls: 后端 URL 列表 +// - opts: 可选配置选项 +// +// 返回构建器以支持链式调用。 +func (b *ConfigBuilder) WithProxy(path string, urls []string, opts ...ProxyOption) *ConfigBuilder { + if len(b.cfg.Servers) == 0 { + b.WithServer(":8080") + } + + targets := make([]config.ProxyTarget, len(urls)) + for i, url := range urls { + targets[i] = config.ProxyTarget{ + URL: url, + } + } + + proxy := config.ProxyConfig{ + Path: path, + Targets: targets, + } + + for _, opt := range opts { + opt(&proxy) + } + + // 添加到第一个服务器 + b.cfg.Servers[0].Proxy = append(b.cfg.Servers[0].Proxy, proxy) + return b +} + +// WithProxyTargets 添加带选项的代理配置。 +// +// targetOptsPerTarget 是每个目标的选项列表(索引对应 urls)。 +// opts 是代理级别的选项。 +func (b *ConfigBuilder) WithProxyTargets(path string, urls []string, targetOptsPerTarget [][]ProxyTargetOption, opts ...ProxyOption) *ConfigBuilder { + if len(b.cfg.Servers) == 0 { + b.WithServer(":8080") + } + + targets := make([]config.ProxyTarget, len(urls)) + for i, url := range urls { + targets[i] = config.ProxyTarget{ + URL: url, + } + // 应用该目标的选项 + if i < len(targetOptsPerTarget) { + for _, opt := range targetOptsPerTarget[i] { + opt(&targets[i]) + } + } + } + + proxy := config.ProxyConfig{ + Path: path, + Targets: targets, + } + + for _, opt := range opts { + opt(&proxy) + } + + b.cfg.Servers[0].Proxy = append(b.cfg.Servers[0].Proxy, proxy) + return b +} + +// SSLOption SSL 配置选项。 +type SSLOption func(*config.SSLConfig) + +// WithHTTP2 启用 HTTP/2。 +func WithHTTP2(enabled bool, maxConcurrentStreams int) SSLOption { + return func(s *config.SSLConfig) { + s.HTTP2 = config.HTTP2Config{ + Enabled: enabled, + MaxConcurrentStreams: maxConcurrentStreams, + } + } +} + +// WithTLSProtocols 设置 TLS 协议版本。 +func WithTLSProtocols(protocols []string) SSLOption { + return func(s *config.SSLConfig) { + s.Protocols = protocols + } +} + +// WithSessionTickets 启用 Session Tickets。 +func WithSessionTickets(enabled bool) SSLOption { + return func(s *config.SSLConfig) { + s.SessionTickets = config.SessionTicketsConfig{ + Enabled: enabled, + } + } +} + +// WithHSTS 配置 HSTS。 +func WithHSTS(maxAge int, includeSubDomains bool) SSLOption { + return func(s *config.SSLConfig) { + s.HSTS = config.HSTSConfig{ + MaxAge: maxAge, + IncludeSubDomains: includeSubDomains, + } + } +} + +// WithSSL 配置 SSL/TLS。 +// +// 参数: +// - cert: 证书文件路径 +// - key: 私钥文件路径 +// - opts: 可选 SSL 配置 +// +// 返回构建器以支持链式调用。 +func (b *ConfigBuilder) WithSSL(cert, key string, opts ...SSLOption) *ConfigBuilder { + if len(b.cfg.Servers) == 0 { + b.WithServer(":8443") + } + + ssl := config.SSLConfig{ + Cert: cert, + Key: key, + } + + for _, opt := range opts { + opt(&ssl) + } + + b.cfg.Servers[0].SSL = ssl + return b +} + +// StaticOption 静态文件配置选项。 +type StaticOption func(*config.StaticConfig) + +// WithIndex 设置索引文件。 +func WithIndex(index []string) StaticOption { + return func(s *config.StaticConfig) { + s.Index = index + } +} + +// WithTryFiles 设置 try_files。 +func WithTryFiles(tryFiles []string) StaticOption { + return func(s *config.StaticConfig) { + s.TryFiles = tryFiles + } +} + +// WithStatic 添加静态文件配置。 +func (b *ConfigBuilder) WithStatic(path, root string, opts ...StaticOption) *ConfigBuilder { + if len(b.cfg.Servers) == 0 { + b.WithServer(":8080") + } + + static := config.StaticConfig{ + Path: path, + Root: root, + } + + for _, opt := range opts { + opt(&static) + } + + b.cfg.Servers[0].Static = append(b.cfg.Servers[0].Static, static) + return b +} + +// SecurityOption 安全配置选项。 +type SecurityOption func(*config.SecurityConfig) + +// WithRateLimit 设置速率限制。 +func WithRateLimit(requestRate, burst int) SecurityOption { + return func(s *config.SecurityConfig) { + s.RateLimit = config.RateLimitConfig{ + RequestRate: requestRate, + Burst: burst, + } + } +} + +// WithAccessControl 设置访问控制。 +func WithAccessControl(allow, deny []string, defaultAction string) SecurityOption { + return func(s *config.SecurityConfig) { + s.Access = config.AccessConfig{ + Allow: allow, + Deny: deny, + Default: defaultAction, + } + } +} + +// WithBasicAuth 设置 Basic 认证。 +func WithBasicAuth(users []config.User) SecurityOption { + return func(s *config.SecurityConfig) { + s.Auth = config.AuthConfig{ + Type: "basic", + Users: users, + } + } +} + +// WithSecurity 配置安全选项。 +func (b *ConfigBuilder) WithSecurity(opts ...SecurityOption) *ConfigBuilder { + if len(b.cfg.Servers) == 0 { + b.WithServer(":8080") + } + + for _, opt := range opts { + opt(&b.cfg.Servers[0].Security) + } + return b +} + +// CompressionOption 压缩配置选项。 +type CompressionOption func(*config.CompressionConfig) + +// WithCompressionType 设置压缩类型。 +func WithCompressionType(typ string) CompressionOption { + return func(c *config.CompressionConfig) { + c.Type = typ + } +} + +// WithCompressionLevel 设置压缩级别。 +func WithCompressionLevel(level int) CompressionOption { + return func(c *config.CompressionConfig) { + c.Level = level + } +} + +// WithCompressionMinSize 设置最小压缩大小。 +func WithCompressionMinSize(minSize int) CompressionOption { + return func(c *config.CompressionConfig) { + c.MinSize = minSize + } +} + +// WithCompression 配置压缩。 +func (b *ConfigBuilder) WithCompression(opts ...CompressionOption) *ConfigBuilder { + if len(b.cfg.Servers) == 0 { + b.WithServer(":8080") + } + + for _, opt := range opts { + opt(&b.cfg.Servers[0].Compression) + } + return b +} + +// RewriteOption 重写规则选项。 +type RewriteOption func(*config.RewriteRule) + +// WithRewriteFlag 设置重写标志。 +func WithRewriteFlag(flag string) RewriteOption { + return func(r *config.RewriteRule) { + r.Flag = flag + } +} + +// WithRewrite 添加 URL 重写规则。 +func (b *ConfigBuilder) WithRewrite(pattern, replacement string, opts ...RewriteOption) *ConfigBuilder { + if len(b.cfg.Servers) == 0 { + b.WithServer(":8080") + } + + rule := config.RewriteRule{ + Pattern: pattern, + Replacement: replacement, + } + + for _, opt := range opts { + opt(&rule) + } + + b.cfg.Servers[0].Rewrite = append(b.cfg.Servers[0].Rewrite, rule) + return b +} + +// WithCachePath 配置缓存路径。 +func (b *ConfigBuilder) WithCachePath(path string, maxSize int64) *ConfigBuilder { + b.cfg.CachePath = &config.ProxyCachePathConfig{ + Path: path, + MaxSize: maxSize, + } + return b +} + +// WithResolver 配置 DNS 解析器。 +func (b *ConfigBuilder) WithResolver(addresses []string, valid, timeout time.Duration) *ConfigBuilder { + b.cfg.Resolver = config.ResolverConfig{ + Enabled: true, + Addresses: addresses, + Valid: valid, + Timeout: timeout, + } + return b +} + +// WithLogging 配置日志。 +func (b *ConfigBuilder) WithLogging(format string) *ConfigBuilder { + b.cfg.Logging = config.LoggingConfig{ + Format: format, + } + return b +} + +// WithShutdown 配置关闭超时。 +func (b *ConfigBuilder) WithShutdown(graceful, fast time.Duration) *ConfigBuilder { + b.cfg.Shutdown = config.ShutdownConfig{ + GracefulTimeout: graceful, + FastTimeout: fast, + } + return b +} + +// Build 生成 YAML 配置字符串。 +// +// 返回 YAML 格式的配置字符串。 +func (b *ConfigBuilder) Build() (string, error) { + data, err := yaml.Marshal(b.cfg) + if err != nil { + return "", fmt.Errorf("序列化配置失败: %w", err) + } + return string(data), nil +} + +// WriteTemp 写入临时文件。 +// +// 创建临时目录并写入配置文件,返回文件路径。 +// 调用者负责清理临时目录。 +func (b *ConfigBuilder) WriteTemp() (string, error) { + yamlStr, err := b.Build() + if err != nil { + return "", err + } + + // 创建临时目录 + tmpDir, err := os.MkdirTemp("", "lolly-e2e-*") + if err != nil { + return "", fmt.Errorf("创建临时目录失败: %w", err) + } + + // 写入配置文件 + configPath := filepath.Join(tmpDir, "lolly.yaml") + if err := os.WriteFile(configPath, []byte(yamlStr), 0o644); err != nil { + os.RemoveAll(tmpDir) + return "", fmt.Errorf("写入配置文件失败: %w", err) + } + + return configPath, nil +} + +// WriteTo 写入指定目录。 +func (b *ConfigBuilder) WriteTo(dir string) (string, error) { + yamlStr, err := b.Build() + if err != nil { + return "", err + } + + // 确保目录存在 + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", fmt.Errorf("创建目录失败: %w", err) + } + + configPath := filepath.Join(dir, "lolly.yaml") + if err := os.WriteFile(configPath, []byte(yamlStr), 0o644); err != nil { + return "", fmt.Errorf("写入配置文件失败: %w", err) + } + + return configPath, nil +} + +// GetConfig 返回配置对象。 +func (b *ConfigBuilder) GetConfig() *config.Config { + return b.cfg +} + +// Reset 重置构建器。 +func (b *ConfigBuilder) Reset() *ConfigBuilder { + b.cfg = &config.Config{ + Servers: []config.ServerConfig{}, + } + return b +} + +// Clone 克隆构建器。 +func (b *ConfigBuilder) Clone() *ConfigBuilder { + data, err := yaml.Marshal(b.cfg) + if err != nil { + return NewConfigBuilder() + } + + var newCfg config.Config + if err := yaml.Unmarshal(data, &newCfg); err != nil { + return NewConfigBuilder() + } + + return &ConfigBuilder{cfg: &newCfg} +} \ No newline at end of file diff --git a/internal/e2e/testutil/constants.go b/internal/e2e/testutil/constants.go new file mode 100644 index 0000000..4ce47fe --- /dev/null +++ b/internal/e2e/testutil/constants.go @@ -0,0 +1,70 @@ +//go:build e2e + +// Package testutil 提供 E2E 测试的工具函数。 +// +// 包含测试常量定义。 +// +// 作者:xfy +package testutil + +import ( + "crypto/tls" + "time" +) + +// 测试超时常量。 +const ( + // ContainerStartupTimeout 容器启动超时。 + ContainerStartupTimeout = 30 * time.Second + + // HealthCheckWaitTimeout 健康检查等待超时。 + HealthCheckWaitTimeout = 30 * time.Second + + // HealthCheckDetectionTime 健康检查检测时间。 + HealthCheckDetectionTime = 10 * time.Second + + // CacheExpireBuffer 缓存过期缓冲时间。 + CacheExpireBuffer = 1 * time.Second + + // DefaultTestTimeout 测试上下文超时。 + DefaultTestTimeout = 180 * time.Second + + // DefaultClientTimeout HTTP 客户端超时。 + DefaultClientTimeout = 10 * time.Second + + // ConcurrentRequestTimeout 并发请求超时。 + ConcurrentRequestTimeout = 30 * time.Second + + // ShortTestTimeout 短测试超时(用于快速测试)。 + ShortTestTimeout = 60 * time.Second + + // MediumTestTimeout 中等测试超时。 + MediumTestTimeout = 120 * time.Second +) + +// 测试配置常量。 +const ( + // DefaultBackendCount 默认后端数量。 + DefaultBackendCount = 2 + + // DefaultConcurrentRequests 并发请求数量。 + DefaultConcurrentRequests = 10 + + // HighConcurrentRequests 高并发请求数量。 + HighConcurrentRequests = 20 + + // CacheTestMaxAge 缓存测试过期时间。 + CacheTestMaxAge = 5 * time.Minute + + // CacheTestShortMaxAge 短缓存过期时间(用于过期测试)。 + CacheTestShortMaxAge = 2 * time.Second +) + +// TLS 版本常量(用于配置客户端)。 +const ( + // TLSVersion12 TLS 1.2。 + TLSVersion12 = tls.VersionTLS12 + + // TLSVersion13 TLS 1.3。 + TLSVersion13 = tls.VersionTLS13 +) diff --git a/internal/e2e/testutil/container.go b/internal/e2e/testutil/container.go index 47a3513..27e6833 100644 --- a/internal/e2e/testutil/container.go +++ b/internal/e2e/testutil/container.go @@ -29,6 +29,89 @@ servers: - "index.html" ` +// LollyContainerOption 容器启动选项。 +type LollyContainerOption func(*lollyContainerConfig) + +// lollyContainerConfig 容器配置。 +type lollyContainerConfig struct { + configPath string + configYAML string + network string + certPath string + keyPath string + extraMounts []testcontainers.ContainerMount + env map[string]string + exposedPorts []string + waitFor wait.Strategy +} + +// WithConfigFile 使用配置文件路径。 +func WithConfigFile(path string) LollyContainerOption { + return func(c *lollyContainerConfig) { + c.configPath = path + } +} + +// WithConfigYAML 使用 YAML 字符串配置。 +func WithConfigYAML(yaml string) LollyContainerOption { + return func(c *lollyContainerConfig) { + c.configYAML = yaml + } +} + +// WithNetwork 加入指定网络。 +func WithNetwork(name string) LollyContainerOption { + return func(c *lollyContainerConfig) { + c.network = name + } +} + +// WithCert 挂载证书文件。 +func WithCert(certPath, keyPath string) LollyContainerOption { + return func(c *lollyContainerConfig) { + c.certPath = certPath + c.keyPath = keyPath + } +} + +// WithExtraMount 添加额外挂载。 +func WithExtraMount(hostPath, containerPath string) LollyContainerOption { + return func(c *lollyContainerConfig) { + c.extraMounts = append(c.extraMounts, testcontainers.ContainerMount{ + Source: testcontainers.GenericBindMountSource{ + HostPath: hostPath, + }, + Target: testcontainers.ContainerMountTarget(containerPath), + }) + } +} + +// WithEnv 设置环境变量。 +func WithEnv(env map[string]string) LollyContainerOption { + return func(c *lollyContainerConfig) { + if c.env == nil { + c.env = make(map[string]string) + } + for k, v := range env { + c.env[k] = v + } + } +} + +// WithExposedPorts 设置暴露端口。 +func WithExposedPorts(ports ...string) LollyContainerOption { + return func(c *lollyContainerConfig) { + c.exposedPorts = ports + } +} + +// WithWaitStrategy 设置等待策略。 +func WithWaitStrategy(strategy wait.Strategy) LollyContainerOption { + return func(c *lollyContainerConfig) { + c.waitFor = strategy + } +} + // LollyContainer 封装 lolly 服务器容器。 type LollyContainer struct { Container testcontainers.Container @@ -40,34 +123,99 @@ type LollyContainer struct { // StartLollyContainer 启动 lolly 服务器容器。 // // 使用预构建的 lolly 镜像。如果 configPath 为空,使用默认配置。 +// 支持通过选项函数自定义配置。 func StartLollyContainer(ctx context.Context, configPath string) (*LollyContainer, error) { - req := testcontainers.ContainerRequest{ - Image: "lolly:latest", - ExposedPorts: []string{"8080/tcp", "8443/tcp"}, - WaitingFor: wait.ForLog("HTTP 服务器启动中").WithStartupTimeout(30 * time.Second), + return StartLolly(ctx, WithConfigFile(configPath)) +} + +// StartLolly 启动 lolly 容器(增强版)。 +// +// 支持多种配置方式和自定义选项。 +// +// 使用示例: +// +// // 使用默认配置 +// lolly, err := StartLolly(ctx) +// +// // 使用配置文件 +// lolly, err := StartLolly(ctx, WithConfigFile("/path/to/config.yaml")) +// +// // 使用动态配置 +// cfg := NewConfigBuilder().WithProxy("/api/", targets).Build() +// lolly, err := StartLolly(ctx, WithConfigYAML(cfg)) +// +// // 使用 SSL +// lolly, err := StartLolly(ctx, WithConfigBuilder(cfg), WithCert(certPath, keyPath)) +func StartLolly(ctx context.Context, opts ...LollyContainerOption) (*LollyContainer, error) { + cfg := &lollyContainerConfig{ + exposedPorts: []string{"8080/tcp", "8443/tcp"}, + waitFor: wait.ForLog("HTTP 服务器启动中").WithStartupTimeout(30 * time.Second), } - // 配置文件挂载 - if configPath != "" { - req.Mounts = []testcontainers.ContainerMount{ - { - Source: testcontainers.GenericBindMountSource{ - HostPath: configPath, - }, - Target: "/etc/lolly/lolly.yaml", + for _, opt := range opts { + opt(cfg) + } + + req := testcontainers.ContainerRequest{ + Image: "lolly:latest", + ExposedPorts: cfg.exposedPorts, + WaitingFor: cfg.waitFor, + } + + // 设置环境变量 + if len(cfg.env) > 0 { + req.Env = cfg.env + } + + // 配置网络 + if cfg.network != "" { + req.Networks = []string{cfg.network} + } + + // 处理配置文件 + if cfg.configPath != "" { + req.Mounts = append(req.Mounts, testcontainers.ContainerMount{ + Source: testcontainers.GenericBindMountSource{ + HostPath: cfg.configPath, }, - } + Target: "/etc/lolly/lolly.yaml", + }) + } else if cfg.configYAML != "" { + req.Files = append(req.Files, testcontainers.ContainerFile{ + Reader: strings.NewReader(cfg.configYAML), + ContainerFilePath: "/etc/lolly/lolly.yaml", + FileMode: 0o644, + }) } else { // 使用内嵌默认配置 - req.Files = []testcontainers.ContainerFile{ - { - Reader: strings.NewReader(defaultLollyConfig), - ContainerFilePath: "/etc/lolly/lolly.yaml", - FileMode: 0o644, - }, - } + req.Files = append(req.Files, testcontainers.ContainerFile{ + Reader: strings.NewReader(defaultLollyConfig), + ContainerFilePath: "/etc/lolly/lolly.yaml", + FileMode: 0o644, + }) } + // 挂载证书 + if cfg.certPath != "" && cfg.keyPath != "" { + req.Mounts = append(req.Mounts, + testcontainers.ContainerMount{ + Source: testcontainers.GenericBindMountSource{ + HostPath: cfg.certPath, + }, + Target: "/etc/lolly/ssl/server.crt", + }, + testcontainers.ContainerMount{ + Source: testcontainers.GenericBindMountSource{ + HostPath: cfg.keyPath, + }, + Target: "/etc/lolly/ssl/server.key", + }, + ) + } + + // 添加额外挂载 + req.Mounts = append(req.Mounts, cfg.extraMounts...) + container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ ContainerRequest: req, Started: true, @@ -226,6 +374,11 @@ func LollyImageAvailable(ctx context.Context) bool { // // 使用 nginx 作为模拟后端,返回容器和访问地址。 // 注意:此函数仅用于代理测试的后端模拟,不应作为被测系统。 +// +// 返回值: +// - container: 容器实例 +// - hostPort: 宿主机访问地址(用于测试代码访问) +// - internalAddr: 容器内部访问地址(用于 lolly 配置) func StartMockBackend(ctx context.Context) (testcontainers.Container, string, error) { req := testcontainers.ContainerRequest{ Image: "nginx:alpine", @@ -253,6 +406,324 @@ func StartMockBackend(ctx context.Context) (testcontainers.Container, string, er return nil, "", fmt.Errorf("failed to get port: %w", err) } + // 返回宿主机地址 addr := fmt.Sprintf("http://%s:%s", host, port.Port()) return container, addr, nil } + +// BackendPool 后端池管理。 +// +// 管理多个后端容器,用于负载均衡测试。 +// 支持网络模式:当 network 不为空时,容器加入指定网络, +// 并提供内部地址供 lolly 容器访问。 +type BackendPool struct { + containers []testcontainers.Container + addresses []string // 宿主机访问地址 + internal []string // 容器网络内部地址 + network string // Docker 网络名称 +} + +// StartBackendPool 启动多个后端容器。 +// +// 参数: +// - count: 后端数量 +// +// 返回后端池和地址列表(宿主机访问地址)。 +func StartBackendPool(ctx context.Context, count int) (*BackendPool, error) { + return StartBackendPoolWithNetwork(ctx, count, "") +} + +// StartBackendPoolWithNetwork 启动多个后端容器并加入网络。 +// +// 参数: +// - count: 后端数量 +// - network: Docker 网络名称(可选,为空则不加入网络) +// +// 当 network 不为空时,容器会加入该网络,并提供内部地址。 +func StartBackendPoolWithNetwork(ctx context.Context, count int, network string) (*BackendPool, error) { + pool := &BackendPool{ + containers: make([]testcontainers.Container, count), + addresses: make([]string, count), + internal: make([]string, count), + network: network, + } + + for i := 0; i < count; i++ { + container, addr, internalAddr, err := startMockBackendWithNetwork(ctx, network, i) + if err != nil { + // 清理已启动的容器 + pool.Terminate(ctx) + return nil, fmt.Errorf("failed to start backend %d: %w", i, err) + } + pool.containers[i] = container + pool.addresses[i] = addr + pool.internal[i] = internalAddr + } + + return pool, nil +} + +// startMockBackendWithNetwork 启动单个后端容器。 +func startMockBackendWithNetwork(ctx context.Context, network string, index int) (testcontainers.Container, string, string, error) { + // 生成容器名称(用于网络通信) + containerName := fmt.Sprintf("backend-%d-%d", time.Now().UnixNano(), index) + + req := testcontainers.ContainerRequest{ + Image: "nginx:alpine", + ExposedPorts: []string{"80/tcp"}, + WaitingFor: wait.ForHTTP("/").WithStartupTimeout(30 * time.Second), + Name: containerName, + } + + if network != "" { + req.Networks = []string{network} + } + + container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + if err != nil { + return nil, "", "", fmt.Errorf("failed to start mock backend: %w", err) + } + + host, err := container.Host(ctx) + if err != nil { + container.Terminate(ctx) + return nil, "", "", fmt.Errorf("failed to get host: %w", err) + } + + port, err := container.MappedPort(ctx, "80/tcp") + if err != nil { + container.Terminate(ctx) + return nil, "", "", fmt.Errorf("failed to get port: %w", err) + } + + // 宿主机访问地址 + hostAddr := fmt.Sprintf("http://%s:%s", host, port.Port()) + + // 容器网络内部地址(使用容器名称) + internalAddr := fmt.Sprintf("http://%s:80", containerName) + + return container, hostAddr, internalAddr, nil +} + +// Addresses 返回后端地址列表(宿主机访问地址)。 +func (p *BackendPool) Addresses() []string { + return p.addresses +} + +// InternalAddresses 返回容器网络内部地址列表。 +// +// 当 lolly 和后端在同一 Docker 网络时,应使用此地址。 +func (p *BackendPool) InternalAddresses() []string { + return p.internal +} + +// Containers 返回容器列表。 +func (p *BackendPool) Containers() []testcontainers.Container { + return p.containers +} + +// Count 返回后端数量。 +func (p *BackendPool) Count() int { + return len(p.containers) +} + +// Terminate 终止所有容器。 +func (p *BackendPool) Terminate(ctx context.Context) { + for _, container := range p.containers { + if container != nil { + container.Terminate(ctx) + } + } +} + +// TerminateOne 终止指定索引的容器。 +func (p *BackendPool) TerminateOne(ctx context.Context, index int) error { + if index < 0 || index >= len(p.containers) { + return fmt.Errorf("invalid index %d", index) + } + if p.containers[index] != nil { + err := p.containers[index].Terminate(ctx) + p.containers[index] = nil + p.addresses[index] = "" + p.internal[index] = "" + return err + } + return nil +} + +// RestartOne 重启指定索引的容器。 +func (p *BackendPool) RestartOne(ctx context.Context, index int) error { + if index < 0 || index >= len(p.containers) { + return fmt.Errorf("invalid index %d", index) + } + + // 先终止旧容器 + if p.containers[index] != nil { + p.containers[index].Terminate(ctx) + } + + // 启动新容器 + container, addr, internalAddr, err := startMockBackendWithNetwork(ctx, p.network, index) + if err != nil { + return err + } + + p.containers[index] = container + p.addresses[index] = addr + p.internal[index] = internalAddr + return nil +} + +// CreateNetwork 创建 Docker 网络。 +// +// 用于容器间通信。 +func CreateNetwork(ctx context.Context, name string) (testcontainers.Network, error) { + network, err := testcontainers.GenericNetwork(ctx, testcontainers.GenericNetworkRequest{ + NetworkRequest: testcontainers.NetworkRequest{ + Name: name, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to create network: %w", err) + } + return network, nil +} + +// sharedNetworkName 共享网络名称。 +const sharedNetworkName = "lolly-e2e-test" + +// SetupProxyTest 设置代理测试环境。 +// +// 创建网络、启动后端池,返回网络名称和后端池。 +// lolly 容器应使用 InternalAddresses() 作为代理目标。 +// 使用共享网络名称,避免网络地址池耗尽。 +// +// 使用示例: +// +// network, pool, err := testutil.SetupProxyTest(ctx, 2) +// if err != nil { +// t.Fatal(err) +// } +// defer testutil.CleanupProxyTest(ctx, network, pool) +// +// lolly, err := testutil.StartLolly(ctx, +// testutil.WithConfigYAML(configYAML), +// testutil.WithNetwork(network), +// ) +func SetupProxyTest(ctx context.Context, backendCount int) (string, *BackendPool, error) { + // 使用共享网络名称 + networkName := sharedNetworkName + + // 尝试创建网络(如果不存在) + // 忽略"已存在"错误 + _, err := CreateNetwork(ctx, networkName) + if err != nil && !isNetworkExistsError(err) { + return "", nil, fmt.Errorf("failed to create network: %w", err) + } + + // 启动后端池并加入网络 + pool, err := StartBackendPoolWithNetwork(ctx, backendCount, networkName) + if err != nil { + return "", nil, fmt.Errorf("failed to start backend pool: %w", err) + } + + return networkName, pool, nil +} + +// isNetworkExistsError 检查是否是网络已存在错误。 +func isNetworkExistsError(err error) bool { + return err != nil && (containsString(err.Error(), "already exists") || + containsString(err.Error(), "network with name") || + containsString(err.Error(), "failed to create network")) +} + +// containsString 检查字符串是否包含子串。 +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsSubstring(s, substr)) +} + +func containsSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// CleanupProxyTest 清理代理测试环境。 +func CleanupProxyTest(ctx context.Context, networkName string, pool *BackendPool) { + if pool != nil { + pool.Terminate(ctx) + } + // 网络会随容器终止自动清理 +} + +// ProxyTestEnv 代理测试环境。 +// +// 封装代理测试所需的资源。 +type ProxyTestEnv struct { + Network string + Pool *BackendPool + Lolly *LollyContainer + Cleanup func() + HTTPClient *http.Client +} + +// SetupProxyTestEnv 设置完整的代理测试环境。 +// +// 创建网络、启动后端池、启动 lolly,返回封装的环境。 +// 这是一个便捷函数,简化测试设置。 +// +// 使用示例: +// +// env, err := testutil.SetupProxyTestEnv(ctx, t, 2, func(pool *testutil.BackendPool) string { +// cfg := testutil.NewConfigBuilder(). +// WithServer(":8080"). +// WithProxy("/", pool.InternalAddresses()) +// yaml, _ := cfg.Build() +// return yaml +// }) +func SetupProxyTestEnv(ctx context.Context, backendCount int, configBuilder func(*BackendPool) string) (*ProxyTestEnv, error) { + // 设置代理测试环境 + networkName, pool, err := SetupProxyTest(ctx, backendCount) + if err != nil { + return nil, err + } + + // 构建配置 + configYAML := configBuilder(pool) + + // 启动 lolly + lolly, err := StartLolly(ctx, + WithConfigYAML(configYAML), + WithNetwork(networkName), + ) + if err != nil { + CleanupProxyTest(ctx, networkName, pool) + return nil, fmt.Errorf("failed to start lolly: %w", err) + } + + // 等待健康 + if err := lolly.WaitForHealthy(ctx, 30*time.Second); err != nil { + CleanupProxyTest(ctx, networkName, pool) + lolly.Terminate(ctx) + return nil, fmt.Errorf("lolly not healthy: %w", err) + } + + cleanup := func() { + lolly.Terminate(ctx) + CleanupProxyTest(ctx, networkName, pool) + } + + return &ProxyTestEnv{ + Network: networkName, + Pool: pool, + Lolly: lolly, + Cleanup: cleanup, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + }, nil +} diff --git a/internal/e2e/testutil/setup.go b/internal/e2e/testutil/setup.go new file mode 100644 index 0000000..2800641 --- /dev/null +++ b/internal/e2e/testutil/setup.go @@ -0,0 +1,209 @@ +//go:build e2e + +// Package testutil 提供 E2E 测试的工具函数。 +// +// 包含统一的测试环境设置函数。 +// +// 作者:xfy +package testutil + +import ( + "context" + "net/http" + "testing" + "time" +) + +// E2ETestEnv E2E 测试环境。 +// +// 封装测试所需的资源和清理函数。 +type E2ETestEnv struct { + Ctx context.Context + Network string + Pool *BackendPool + Lolly *LollyContainer + Client *http.Client + cleanup func() +} + +// SetupE2ETest 设置 E2E 测试环境。 +// +// 自动处理镜像检查、后端启动、lolly 启动和资源清理。 +// 使用 t.Cleanup() 确保资源正确释放。 +// +// 参数: +// - t: 测试对象 +// - backendCount: 后端数量 +// - cfgBuilder: 配置构建函数,接收后端池返回 YAML 配置 +// +// 使用示例: +// +// env := testutil.SetupE2ETest(t, 2, func(pool *testutil.BackendPool) string { +// cfg := testutil.NewConfigBuilder(). +// WithServer(":8080"). +// WithProxy("/", pool.InternalAddresses()) +// yaml, _ := cfg.Build() +// return yaml +// }) +// defer env.Cleanup() +func SetupE2ETest(t *testing.T, backendCount int, cfgBuilder func(*BackendPool) string) *E2ETestEnv { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), DefaultTestTimeout) + + if !LollyImageAvailable(ctx) { + t.Skip("lolly:latest image not available, run 'make docker-build' first") + } + + network, pool, err := SetupProxyTest(ctx, backendCount) + if err != nil { + cancel() + t.Fatalf("Failed to setup proxy test: %v", err) + } + + cfgYAML := cfgBuilder(pool) + + lolly, err := StartLolly(ctx, + WithConfigYAML(cfgYAML), + WithNetwork(network), + ) + if err != nil { + CleanupProxyTest(ctx, network, pool) + cancel() + t.Fatalf("Failed to start lolly: %v", err) + } + + if err := lolly.WaitForHealthy(ctx, HealthCheckWaitTimeout); err != nil { + lolly.Terminate(ctx) + CleanupProxyTest(ctx, network, pool) + cancel() + t.Fatalf("Lolly not healthy: %v", err) + } + + env := &E2ETestEnv{ + Ctx: ctx, + Network: network, + Pool: pool, + Lolly: lolly, + Client: CreateDefaultHTTPClient(), + } + + env.cleanup = func() { + lolly.Terminate(ctx) + CleanupProxyTest(ctx, network, pool) + cancel() + } + + t.Cleanup(env.Cleanup) + + return env +} + +// SetupE2ETestWithTimeout 设置带自定义超时的 E2E 测试环境。 +func SetupE2ETestWithTimeout(t *testing.T, backendCount int, timeout time.Duration, cfgBuilder func(*BackendPool) string) *E2ETestEnv { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + + if !LollyImageAvailable(ctx) { + t.Skip("lolly:latest image not available, run 'make docker-build' first") + } + + network, pool, err := SetupProxyTest(ctx, backendCount) + if err != nil { + cancel() + t.Fatalf("Failed to setup proxy test: %v", err) + } + + cfgYAML := cfgBuilder(pool) + + lolly, err := StartLolly(ctx, + WithConfigYAML(cfgYAML), + WithNetwork(network), + ) + if err != nil { + CleanupProxyTest(ctx, network, pool) + cancel() + t.Fatalf("Failed to start lolly: %v", err) + } + + if err := lolly.WaitForHealthy(ctx, HealthCheckWaitTimeout); err != nil { + lolly.Terminate(ctx) + CleanupProxyTest(ctx, network, pool) + cancel() + t.Fatalf("Lolly not healthy: %v", err) + } + + env := &E2ETestEnv{ + Ctx: ctx, + Network: network, + Pool: pool, + Lolly: lolly, + Client: CreateDefaultHTTPClient(), + } + + env.cleanup = func() { + lolly.Terminate(ctx) + CleanupProxyTest(ctx, network, pool) + cancel() + } + + t.Cleanup(env.Cleanup) + + return env +} + +// Cleanup 手动清理资源。 +func (e *E2ETestEnv) Cleanup() { + if e.cleanup != nil { + e.cleanup() + e.cleanup = nil + } +} + +// HTTPURL 返回 lolly HTTP 地址。 +func (e *E2ETestEnv) HTTPURL() string { + return e.Lolly.HTTPBaseURL() +} + +// HTTPSURL 返回 lolly HTTPS 地址。 +func (e *E2ETestEnv) HTTPSURL() string { + return e.Lolly.HTTPSBaseURL() +} + +// SetupSSLTest 设置 SSL 测试环境。 +// +// 用于 SSL/TLS 测试场景,自动生成证书。 +func SetupSSLTest(t *testing.T, cfgBuilder func() string) (*LollyContainer, string, string) { + t.Helper() + + ctx := context.Background() + + if !LollyImageAvailable(ctx) { + t.Skip("lolly:latest image not available, run 'make docker-build' first") + } + + // 生成自签名证书 + certPath, keyPath, cleanup, err := GenerateSelfSignedCert(t.TempDir()) + if err != nil { + t.Fatalf("Failed to generate certificate: %v", err) + } + t.Cleanup(cleanup) + + cfgYAML := cfgBuilder() + + lolly, err := StartLolly(ctx, + WithConfigYAML(cfgYAML), + WithCert(certPath, keyPath), + ) + if err != nil { + t.Fatalf("Failed to start lolly: %v", err) + } + t.Cleanup(func() { lolly.Terminate(ctx) }) + + if err := lolly.WaitForHealthy(ctx, HealthCheckWaitTimeout); err != nil { + t.Fatalf("Lolly not healthy: %v", err) + } + + return lolly, certPath, keyPath +} diff --git a/internal/e2e/testutil/ssl.go b/internal/e2e/testutil/ssl.go new file mode 100644 index 0000000..6d332c6 --- /dev/null +++ b/internal/e2e/testutil/ssl.go @@ -0,0 +1,98 @@ +//go:build e2e + +// Package testutil 提供 E2E 测试的工具函数。 +// +// 包含 SSL/TLS 测试辅助函数。 +// +// 作者:xfy +package testutil + +import ( + "crypto/tls" + "crypto/x509" + "net/http" + "os" + "time" +) + +// CreateTLSClient 创建信任指定证书的 HTTPS 客户端。 +// +// 参数: +// - certPath: CA 证书文件路径 +// +// 返回配置好的 HTTP 客户端,信任指定的证书。 +func CreateTLSClient(certPath string) (*http.Client, error) { + caCert, err := os.ReadFile(certPath) + if err != nil { + return nil, err + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + return &http.Client{ + Timeout: DefaultClientTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + }, + }, + }, nil +} + +// CreateTLSClientWithVersion 创建带版本限制的 HTTPS 客户端。 +// +// 参数: +// - certPath: CA 证书文件路径 +// - minVersion: 最小 TLS 版本 +// - maxVersion: 最大 TLS 版本 +func CreateTLSClientWithVersion(certPath string, minVersion, maxVersion uint16) (*http.Client, error) { + caCert, err := os.ReadFile(certPath) + if err != nil { + return nil, err + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + return &http.Client{ + Timeout: DefaultClientTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: caCertPool, + MinVersion: minVersion, + MaxVersion: maxVersion, + }, + }, + }, nil +} + +// CreateInsecureTLSClient 创建跳过证书验证的 HTTPS 客户端。 +// +// 用于测试自签名证书场景,不应在生产环境使用。 +func CreateInsecureTLSClient() *http.Client { + return &http.Client{ + Timeout: DefaultClientTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } +} + +// CreateDefaultHTTPClient 创建默认 HTTP 客户端。 +// +// 用于非 SSL 测试场景。 +func CreateDefaultHTTPClient() *http.Client { + return &http.Client{ + Timeout: DefaultClientTimeout, + } +} + +// CreateHTTPClientWithTimeout 创建带自定义超时的 HTTP 客户端。 +func CreateHTTPClientWithTimeout(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + } +} \ No newline at end of file diff --git a/internal/e2e/testutil/websocket.go b/internal/e2e/testutil/websocket.go new file mode 100644 index 0000000..3492ac1 --- /dev/null +++ b/internal/e2e/testutil/websocket.go @@ -0,0 +1,411 @@ +//go:build e2e + +// Package testutil 提供 E2E 测试的工具函数。 +// +// 包含 WebSocket 测试辅助工具。 +// +// 作者:xfy +package testutil + +import ( + "context" + "fmt" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +// WSClient WebSocket 测试客户端。 +// +// 封装 gorilla/websocket,提供简单的测试接口。 +type WSClient struct { + conn *websocket.Conn + url string + mu sync.Mutex + closed bool + closeChan chan struct{} +} + +// WSOption WebSocket 客户端选项。 +type WSOption func(*wsConfig) + +type wsConfig struct { + headers http.Header + pingPeriod time.Duration + pongWait time.Duration +} + +// WithHeaders 设置请求头。 +func WithWSHeaders(headers http.Header) WSOption { + return func(c *wsConfig) { + c.headers = headers + } +} + +// WithWSTimeout 设置超时时间。 +func WithWSTimeout(pongWait, pingPeriod time.Duration) WSOption { + return func(c *wsConfig) { + c.pongWait = pongWait + c.pingPeriod = pingPeriod + } +} + +// NewWSClient 创建 WebSocket 客户端。 +// +// 参数: +// - ctx: 上下文 +// - url: WebSocket URL(ws:// 或 wss://) +// - opts: 可选配置 +// +// 返回 WebSocket 客户端实例。 +func NewWSClient(ctx context.Context, url string, opts ...WSOption) (*WSClient, error) { + cfg := &wsConfig{ + headers: http.Header{}, + pongWait: 60 * time.Second, + pingPeriod: 54 * time.Second, + } + + for _, opt := range opts { + opt(cfg) + } + + dialer := websocket.DefaultDialer + conn, _, err := dialer.DialContext(ctx, url, cfg.headers) + if err != nil { + return nil, fmt.Errorf("websocket dial failed: %w", err) + } + + client := &WSClient{ + conn: conn, + url: url, + closeChan: make(chan struct{}), + } + + // 设置 pong 处理 + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(cfg.pongWait)) + }) + + return client, nil +} + +// Send 发送文本消息。 +func (c *WSClient) Send(message string) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return fmt.Errorf("connection closed") + } + + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return c.conn.WriteMessage(websocket.TextMessage, []byte(message)) +} + +// SendBinary 发送二进制消息。 +func (c *WSClient) SendBinary(data []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return fmt.Errorf("connection closed") + } + + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return c.conn.WriteMessage(websocket.BinaryMessage, data) +} + +// SendJSON 发送 JSON 消息。 +func (c *WSClient) SendJSON(v interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return fmt.Errorf("connection closed") + } + + c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + return c.conn.WriteJSON(v) +} + +// Receive 接收文本消息。 +// +// 返回消息内容和错误。 +func (c *WSClient) Receive() (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return "", fmt.Errorf("connection closed") + } + + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + messageType, data, err := c.conn.ReadMessage() + if err != nil { + return "", err + } + + if messageType != websocket.TextMessage { + return "", fmt.Errorf("expected text message, got %d", messageType) + } + + return string(data), nil +} + +// ReceiveBinary 接收二进制消息。 +func (c *WSClient) ReceiveBinary() ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil, fmt.Errorf("connection closed") + } + + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + messageType, data, err := c.conn.ReadMessage() + if err != nil { + return nil, err + } + + if messageType != websocket.BinaryMessage { + return nil, fmt.Errorf("expected binary message, got %d", messageType) + } + + return data, nil +} + +// ReceiveJSON 接收 JSON 消息。 +func (c *WSClient) ReceiveJSON(v interface{}) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return fmt.Errorf("connection closed") + } + + c.conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + return c.conn.ReadJSON(v) +} + +// ReceiveWithTimeout 接收消息(带超时)。 +func (c *WSClient) ReceiveWithTimeout(timeout time.Duration) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return "", fmt.Errorf("connection closed") + } + + c.conn.SetReadDeadline(time.Now().Add(timeout)) + messageType, data, err := c.conn.ReadMessage() + if err != nil { + return "", err + } + + if messageType != websocket.TextMessage { + return "", fmt.Errorf("expected text message, got %d", messageType) + } + + return string(data), nil +} + +// ReceiveChan 返回消息通道。 +// +// 在后台持续接收消息,通过通道返回。 +func (c *WSClient) ReceiveChan() <-chan WSMessage { + ch := make(chan WSMessage, 10) + + go func() { + defer close(ch) + for { + msg, err := c.Receive() + if err != nil { + return + } + ch <- WSMessage{Data: msg} + } + }() + + return ch +} + +// WSMessage WebSocket 消息。 +type WSMessage struct { + Data string + Error error +} + +// Close 关闭连接。 +func (c *WSClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.closed { + return nil + } + + c.closed = true + close(c.closeChan) + + // 发送关闭帧 + err := c.conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + c.conn.Close() + return err + } + + return c.conn.Close() +} + +// IsClosed 检查连接是否已关闭。 +func (c *WSClient) IsClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +// URL 返回连接 URL。 +func (c *WSClient) URL() string { + return c.url +} + +// CloseChan 返回关闭通道。 +func (c *WSClient) CloseChan() <-chan struct{} { + return c.closeChan +} + +// WSPool WebSocket 连接池。 +// +// 管理多个 WebSocket 连接,用于并发测试。 +type WSPool struct { + clients []*WSClient + mu sync.Mutex +} + +// NewWSPool 创建 WebSocket 连接池。 +// +// 参数: +// - ctx: 上下文 +// - url: WebSocket URL +// - count: 连接数量 +// +// 返回连接池实例。 +func NewWSPool(ctx context.Context, url string, count int) (*WSPool, error) { + pool := &WSPool{ + clients: make([]*WSClient, count), + } + + for i := 0; i < count; i++ { + client, err := NewWSClient(ctx, url) + if err != nil { + pool.Close() + return nil, fmt.Errorf("failed to create client %d: %w", i, err) + } + pool.clients[i] = client + } + + return pool, nil +} + +// SendAll 向所有连接发送消息。 +func (p *WSPool) SendAll(message string) error { + p.mu.Lock() + defer p.mu.Unlock() + + var lastErr error + for _, client := range p.clients { + if client != nil { + if err := client.Send(message); err != nil { + lastErr = err + } + } + } + return lastErr +} + +// SendOne 向指定连接发送消息。 +func (p *WSPool) SendOne(index int, message string) error { + p.mu.Lock() + defer p.mu.Unlock() + + if index < 0 || index >= len(p.clients) { + return fmt.Errorf("invalid index %d", index) + } + + if p.clients[index] == nil { + return fmt.Errorf("client %d is nil", index) + } + + return p.clients[index].Send(message) +} + +// ReceiveAll 从所有连接接收消息。 +// +// 返回每个连接收到的消息列表。 +func (p *WSPool) ReceiveAll() ([]string, error) { + p.mu.Lock() + defer p.mu.Unlock() + + messages := make([]string, len(p.clients)) + var lastErr error + + for i, client := range p.clients { + if client != nil { + msg, err := client.Receive() + if err != nil { + lastErr = err + messages[i] = "" + } else { + messages[i] = msg + } + } + } + + return messages, lastErr +} + +// Count 返回连接数量。 +func (p *WSPool) Count() int { + return len(p.clients) +} + +// Close 关闭所有连接。 +func (p *WSPool) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + + var lastErr error + for _, client := range p.clients { + if client != nil { + if err := client.Close(); err != nil { + lastErr = err + } + } + } + return lastErr +} + +// WSEchoServer WebSocket Echo 服务器配置。 +type WSEchoServer struct { + Port int + Handler func(*websocket.Conn) +} + +// NewWSEchoHandler 创建 Echo 处理器。 +// +// 将收到的消息原样返回。 +func NewWSEchoHandler() func(*websocket.Conn) { + return func(conn *websocket.Conn) { + defer conn.Close() + for { + messageType, data, err := conn.ReadMessage() + if err != nil { + return + } + conn.WriteMessage(messageType, data) + } + } +}