lolly/internal/proxy/proxy_low_coverage_test.go
xfy 27e00b84a8 fix(proxy,handler,server,stream,ratelimit): fix resource leaks and functional bugs
- proxy/proxy.go: decrement connection count on dangerous path rejection
  (line 724) to prevent connection count leak
- handler/sendfile_linux.go: return *os.File from getSocketFile and let
  linuxSendfile close it, fixing EBADF from deferred close in getSocketFd
- proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse
  and wrap targetConn with bufferedConn to consume pre-buffered frame data,
  preventing first-frame loss
- server/pool.go: use non-blocking send after starting new worker to avoid
  deadlock when queue is full
- stream/stream.go: check stopCh on non-timeout UDP read errors to prevent
  infinite loop and shutdown deadlock
- middleware/ratelimit: replace select-based close guard with sync.Once in
  StopCleanup to prevent double-close panic
2026-06-11 16:35:10 +08:00

1534 lines
41 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package proxy provides supplementary tests for functions with low coverage.
//
// This file focuses on raising test coverage for the following functions:
// - proxyDebugLog (0%)
// - ServeHTTP (47.3%)
// - selectTarget (46.7%)
// - backgroundRefresh (41.9%)
// - selectByLua (39.1%)
// - WebSocket (15.4%)
// - dialTarget (46.7%)
// - DNS-related functions (0%)
//
// Author: xfy
package proxy
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/resolver"
"rua.plus/lolly/internal/testutil"
)
// TestProxyDebugLog calls proxyDebugLog with several key/value argument shapes.
// There are no assertions; the test only fails if a call panics.
func TestProxyDebugLog(t *testing.T) {
	// Each case is a message plus the variadic key/value list passed along.
	cases := []struct {
		name string
		msg  string
		kv   []any
	}{
		{
			name: "各种 kv 类型",
			msg:  "测试消息",
			kv: []any{
				"str_key", "字符串值",
				"int_key", 42,
				"bool_key", true,
				"iface_key", []string{"a", "b"},
			},
		},
		{
			// The first key is an int rather than a string.
			name: "非字符串 key 跳过",
			msg:  "非字符串key",
			kv:   []any{123, "value", "valid_key", "value"},
		},
		{
			// The trailing key has no matching value.
			name: "奇数 kv 参数",
			msg:  "奇数参数",
			kv:   []any{"key1", "val1", "key2"},
		},
		{name: "空消息", msg: ""},
		{name: "空 kv", msg: "无参数"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			proxyDebugLog(tc.msg, tc.kv...)
		})
	}
}
// TestServeHTTP_WithRealBackend 测试使用真实后端的 ServeHTTP 完整流程。
func TestServeHTTP_WithRealBackend(t *testing.T) {
backend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("backend response")
ctx.Response.Header.Set("X-Backend", "true")
},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() { _ = backend.Serve(ln) }()
time.Sleep(20 * time.Millisecond)
backendAddr := "http://" + ln.Addr().String()
t.Run("GET 请求转发", func(t *testing.T) {
cfg := testutil.NewTestProxyConfig("/")
targets := []*loadbalance.Target{{URL: backendAddr}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/test")
p.ServeHTTP(ctx)
assert.Equal(t, 200, ctx.Response.StatusCode())
assert.Equal(t, "backend response", string(ctx.Response.Body()))
})
t.Run("POST 请求转发", func(t *testing.T) {
cfg := testutil.NewTestProxyConfig("/")
targets := []*loadbalance.Target{{URL: backendAddr}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtxWithBody("POST", "/submit", "request body")
p.ServeHTTP(ctx)
assert.Equal(t, 200, ctx.Response.StatusCode())
})
t.Run("PUT 请求转发", func(t *testing.T) {
cfg := testutil.NewTestProxyConfig("/")
targets := []*loadbalance.Target{{URL: backendAddr}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtxWithBody("PUT", "/resource", "put body")
p.ServeHTTP(ctx)
assert.Equal(t, 200, ctx.Response.StatusCode())
})
}
// TestServeHTTP_ConnectionRefused proxies to 127.0.0.1:1 and expects ServeHTTP
// to answer with 502 or 504.
func TestServeHTTP_ConnectionRefused(t *testing.T) {
	t.Run("返回 502 Bad Gateway", func(t *testing.T) {
		const shortTimeout = 100 * time.Millisecond

		proxyCfg := &config.ProxyConfig{
			Path:        "/",
			LoadBalance: "round_robin",
			Timeout: config.ProxyTimeout{
				Connect: shortTimeout,
				Read:    shortTimeout,
				Write:   shortTimeout,
			},
		}
		upstreams := []*loadbalance.Target{{URL: "http://127.0.0.1:1"}}
		upstreams[0].Healthy.Store(true)

		px, err := NewProxy(proxyCfg, upstreams, nil, nil)
		require.NoError(t, err)

		reqCtx := testutil.NewRequestCtx("GET", "/test")
		px.ServeHTTP(reqCtx)

		status := reqCtx.Response.StatusCode()
		assert.True(t, status == 502 || status == 504,
			"expected 502 or 504, got %d", status)
	})
}
// TestServeHTTP_Timeout 测试请求超时场景。
func TestServeHTTP_Timeout(t *testing.T) {
// 创建一个延迟很大的后端
backend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
time.Sleep(5 * time.Second)
ctx.SetStatusCode(200)
},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() { _ = backend.Serve(ln) }()
time.Sleep(20 * time.Millisecond)
backendAddr := "http://" + ln.Addr().String()
cfg := &config.ProxyConfig{
Path: "/",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 50 * time.Millisecond, Read: 50 * time.Millisecond, Write: 50 * time.Millisecond},
}
targets := []*loadbalance.Target{{URL: backendAddr}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/slow")
p.ServeHTTP(ctx)
assert.True(t, ctx.Response.StatusCode() == 504 || ctx.Response.StatusCode() == 502,
"expected timeout error (504 or 502), got %d", ctx.Response.StatusCode())
}
// TestServeHTTP_NextUpstreamFailover 测试故障转移到下一个后端。
func TestServeHTTP_NextUpstreamFailover(t *testing.T) {
// 创建健康的后端
healthyBackend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("healthy backend")
},
}
healthyLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer healthyLn.Close()
go func() { _ = healthyBackend.Serve(healthyLn) }()
time.Sleep(20 * time.Millisecond)
healthyAddr := "http://" + healthyLn.Addr().String()
cfg := &config.ProxyConfig{
Path: "/",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond, Read: 100 * time.Millisecond, Write: 100 * time.Millisecond},
NextUpstream: config.NextUpstreamConfig{
Tries: 3,
HTTPCodes: []int{502, 503, 504},
},
}
// 第一个目标不可达,第二个健康
targets := []*loadbalance.Target{
{URL: "http://127.0.0.1:1"},
{URL: healthyAddr},
}
targets[0].Healthy.Store(true)
targets[1].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/test")
p.ServeHTTP(ctx)
assert.Equal(t, 200, ctx.Response.StatusCode())
assert.Equal(t, "healthy backend", string(ctx.Response.Body()))
}
// TestServeHTTP_RetryOnHTTPError 测试后端返回错误状态码时重试。
func TestServeHTTP_RetryOnHTTPError(t *testing.T) {
// 坏后端:返回 502
badBackend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(502)
ctx.SetBodyString("bad gateway")
},
}
badLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer badLn.Close()
go func() { _ = badBackend.Serve(badLn) }()
// 好后端:返回 200
goodBackend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("ok")
},
}
goodLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer goodLn.Close()
go func() { _ = goodBackend.Serve(goodLn) }()
time.Sleep(20 * time.Millisecond)
cfg := &config.ProxyConfig{
Path: "/",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 1 * time.Second, Read: 1 * time.Second, Write: 1 * time.Second},
NextUpstream: config.NextUpstreamConfig{
Tries: 3,
HTTPCodes: []int{502},
},
}
targets := []*loadbalance.Target{
{URL: "http://" + badLn.Addr().String()},
{URL: "http://" + goodLn.Addr().String()},
}
targets[0].Healthy.Store(true)
targets[1].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/test")
p.ServeHTTP(ctx)
assert.Equal(t, 200, ctx.Response.StatusCode())
}
// TestServeHTTP_XAccelRedirect 测试 X-Accel-Redirect 内部重定向。
// /internal/ 和 /admin/ 前缀的路径不做内部重定向,直接返回原始响应。
func TestServeHTTP_XAccelRedirect(t *testing.T) {
t.Run("/internal/ 前缀不做重定向", func(t *testing.T) {
backend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.Response.Header.Set("X-Accel-Redirect", "/internal/secret")
ctx.SetStatusCode(200)
ctx.SetBodyString("raw response")
},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() { _ = backend.Serve(ln) }()
time.Sleep(20 * time.Millisecond)
cfg := testutil.NewTestProxyConfig("/")
targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/api/redirect")
p.ServeHTTP(ctx)
assert.Equal(t, 200, ctx.Response.StatusCode())
assert.Equal(t, "raw response", string(ctx.Response.Body()))
})
t.Run("非 /internal/ 路径做内部重定向", func(t *testing.T) {
backend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.Response.Header.Set("X-Accel-Redirect", "/other/path")
ctx.SetStatusCode(200)
},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() { _ = backend.Serve(ln) }()
time.Sleep(20 * time.Millisecond)
cfg := testutil.NewTestProxyConfig("/")
targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/api/redirect")
p.ServeHTTP(ctx)
assert.Equal(t, "/other/path", string(ctx.Request.URI().Path()))
})
}
// TestServeHTTP_ProxyURI checks that a target's ProxyURI replaces the request
// path: the client asks for /v1/api and the backend must see a path that
// contains /v2/api.
func TestServeHTTP_ProxyURI(t *testing.T) {
	// The handler runs on a server goroutine while the assertion runs on the
	// test goroutine, so the observed path is published through an atomic.Value
	// instead of a plain shared string.
	var receivedPath atomic.Value

	backend := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			receivedPath.Store(string(ctx.Path()))
			ctx.SetStatusCode(200)
			ctx.SetBodyString("ok")
		},
	}
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer ln.Close()
	go func() { _ = backend.Serve(ln) }()
	time.Sleep(20 * time.Millisecond)

	backendAddr := "http://" + ln.Addr().String()
	cfg := testutil.NewTestProxyConfig("/")
	targets := []*loadbalance.Target{{URL: backendAddr, ProxyURI: "/v2/api"}}
	targets[0].Healthy.Store(true)

	p, err := NewProxy(cfg, targets, nil, nil)
	require.NoError(t, err)

	ctx := testutil.NewRequestCtx("GET", "/v1/api")
	p.ServeHTTP(ctx)

	assert.Equal(t, 200, ctx.Response.StatusCode())
	// Load returns nil if the handler never ran; the checked type assertion
	// then yields "" and the Contains assertion fails cleanly.
	got, _ := receivedPath.Load().(string)
	assert.Contains(t, got, "/v2/api")
}
// TestServeHTTP_SuspiciousPath 测试危险字符路径被拒绝。
func TestServeHTTP_SuspiciousPath(t *testing.T) {
cfg := testutil.NewTestProxyConfig("/")
targets := []*loadbalance.Target{{URL: "http://localhost:8080", ProxyURI: "/test@path"}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/test")
p.ServeHTTP(ctx)
assert.Equal(t, 502, ctx.Response.StatusCode())
}
// TestSelectTarget_Random 测试随机负载均衡算法。
func TestSelectTarget_Random(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "random",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{
{URL: "http://backend1:8080"},
{URL: "http://backend2:8080"},
{URL: "http://backend3:8080"},
}
for _, target := range targets {
target.Healthy.Store(true)
}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/api/test")
selected := p.selectTarget(ctx)
require.NotNil(t, selected)
assert.Contains(t, []string{
"http://backend1:8080",
"http://backend2:8080",
"http://backend3:8080",
}, selected.URL)
}
// TestSelectTarget_LuaSuccess runs selectTarget with a Lua balancer script that
// calls set_current_peer(1) and expects the first target to be returned.
//
// NOTE(review): the comment on TestSelectByLua_ValidScript in this file says
// lua.NewEngine(nil) does not initialise the ngx table, in which case this test
// may be passing through the fallback path rather than the script — confirm.
func TestSelectTarget_LuaSuccess(t *testing.T) {
	// Script that selects the first target.
	luaScript := `
local balancer = require("ngx.balancer")
balancer.set_current_peer(1)
`
	// Create the script under t.TempDir so it is removed automatically, even
	// when an assertion aborts the test early.
	tmpFile, err := os.CreateTemp(t.TempDir(), "balancer_*.lua")
	require.NoError(t, err)
	// Close before asserting so the handle is released on every path, and so a
	// failed close is reported instead of being silently dropped.
	_, writeErr := tmpFile.WriteString(luaScript)
	closeErr := tmpFile.Close()
	require.NoError(t, writeErr)
	require.NoError(t, closeErr)

	cfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		BalancerByLua: config.BalancerByLuaConfig{
			Enabled: true,
			Script:  tmpFile.Name(),
		},
		Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
	}
	targets := []*loadbalance.Target{
		{URL: "http://backend1:8080"},
		{URL: "http://backend2:8080"},
	}
	for _, target := range targets {
		target.Healthy.Store(true)
	}

	luaEngine, err := lua.NewEngine(nil)
	require.NoError(t, err)
	p, err := NewProxy(cfg, targets, nil, luaEngine)
	require.NoError(t, err)

	ctx := testutil.NewRequestCtx("GET", "/api/test")
	selected := p.selectTarget(ctx)
	require.NotNil(t, selected)
	assert.Equal(t, "http://backend1:8080", selected.URL)
}
// TestSelectTarget_LuaFallback runs selectTarget with a Lua script that never
// calls set_current_peer and a "round_robin" fallback, and expects a non-nil
// target to be returned.
func TestSelectTarget_LuaFallback(t *testing.T) {
	// Script consisting only of a comment: it never selects a peer.
	luaScript := `-- 不调用 set_current_peer`

	// Create the script under t.TempDir so it is removed automatically, even
	// when an assertion aborts the test early.
	tmpFile, err := os.CreateTemp(t.TempDir(), "balancer_noop_*.lua")
	require.NoError(t, err)
	// Close before asserting so the handle is released on every path, and so a
	// failed close is reported instead of being silently dropped.
	_, writeErr := tmpFile.WriteString(luaScript)
	closeErr := tmpFile.Close()
	require.NoError(t, writeErr)
	require.NoError(t, closeErr)

	cfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		BalancerByLua: config.BalancerByLuaConfig{
			Enabled:  true,
			Script:   tmpFile.Name(),
			Fallback: "round_robin",
		},
		Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
	}
	targets := []*loadbalance.Target{
		{URL: "http://backend1:8080"},
		{URL: "http://backend2:8080"},
	}
	for _, target := range targets {
		target.Healthy.Store(true)
	}

	luaEngine, err := lua.NewEngine(nil)
	require.NoError(t, err)
	p, err := NewProxy(cfg, targets, nil, luaEngine)
	require.NoError(t, err)

	ctx := testutil.NewRequestCtx("GET", "/api/test")
	selected := p.selectTarget(ctx)
	require.NotNil(t, selected)
}
// TestSelectByLua_ValidScript exercises a syntactically valid Lua balancer
// script through selectTarget.
//
// Note (from the original author): lua.NewEngine(nil) does not initialise the
// ngx global table, so selectByLua returns an error, but selectTarget falls
// back to the fallback algorithm automatically. The test therefore goes
// through selectTarget and only requires a non-nil result.
func TestSelectByLua_ValidScript(t *testing.T) {
	luaScript := `
local balancer = require("ngx.balancer")
balancer.set_current_peer(2)
`
	// Create the script under t.TempDir so it is removed automatically, even
	// when an assertion aborts the test early.
	tmpFile, err := os.CreateTemp(t.TempDir(), "lua_valid_*.lua")
	require.NoError(t, err)
	// Close before asserting so the handle is released on every path, and so a
	// failed close is reported instead of being silently dropped.
	_, writeErr := tmpFile.WriteString(luaScript)
	closeErr := tmpFile.Close()
	require.NoError(t, writeErr)
	require.NoError(t, closeErr)

	cfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		BalancerByLua: config.BalancerByLuaConfig{
			Enabled:  true,
			Script:   tmpFile.Name(),
			Fallback: "round_robin",
		},
		Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
	}
	targets := []*loadbalance.Target{
		{URL: "http://backend1:8080"},
		{URL: "http://backend2:8080"},
		{URL: "http://backend3:8080"},
	}
	for _, target := range targets {
		target.Healthy.Store(true)
	}

	luaEngine, err := lua.NewEngine(nil)
	require.NoError(t, err)
	p, err := NewProxy(cfg, targets, nil, luaEngine)
	require.NoError(t, err)

	ctx := testutil.NewRequestCtx("GET", "/api/test")
	// selectTarget is expected to fall back after the Lua error and still
	// return a usable target.
	selected := p.selectTarget(ctx)
	require.NotNil(t, selected)
}
// TestSelectByLua_ScriptNotSelecting calls selectByLua directly with a script
// that does nothing and expects an error. Per the original note, the error
// arises because lua.NewEngine(nil) does not initialise the ngx global table.
func TestSelectByLua_ScriptNotSelecting(t *testing.T) {
	luaScript := `-- do nothing`

	// Create the script under t.TempDir so it is removed automatically, even
	// when an assertion aborts the test early.
	tmpFile, err := os.CreateTemp(t.TempDir(), "lua_nope_*.lua")
	require.NoError(t, err)
	// Close before asserting so the handle is released on every path, and so a
	// failed close is reported instead of being silently dropped.
	_, writeErr := tmpFile.WriteString(luaScript)
	closeErr := tmpFile.Close()
	require.NoError(t, writeErr)
	require.NoError(t, closeErr)

	cfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		BalancerByLua: config.BalancerByLuaConfig{
			Enabled:  true,
			Script:   tmpFile.Name(),
			Fallback: "round_robin",
		},
		Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
	}
	targets := []*loadbalance.Target{
		{URL: "http://backend1:8080"},
	}
	targets[0].Healthy.Store(true)

	luaEngine, err := lua.NewEngine(nil)
	require.NoError(t, err)
	p, err := NewProxy(cfg, targets, nil, luaEngine)
	require.NoError(t, err)

	ctx := testutil.NewRequestCtx("GET", "/api/test")
	// Bypass selectTarget's fallback so the Lua error itself is observed.
	_, err = p.selectByLua(ctx, targets)
	require.Error(t, err)
}
// TestBackgroundRefresh_WithCacheEntry 测试有缓存条目时的后台刷新。
func TestBackgroundRefresh_WithCacheEntry(t *testing.T) {
backend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("refreshed content")
ctx.Response.Header.Set("Content-Type", "text/plain")
ctx.Response.Header.Set("ETag", "\"new-etag\"")
},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() { _ = backend.Serve(ln) }()
time.Sleep(20 * time.Millisecond)
backendAddr := "http://" + ln.Addr().String()
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
Revalidate: true,
},
}
targets := []*loadbalance.Target{{URL: backendAddr}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/api/test")
hashKey, origKey := p.buildCacheKeyHash(ctx)
p.cache.Set(hashKey, origKey, []byte("old content"), map[string]string{
"Content-Type": "text/plain",
"Last-Modified": "Mon, 01 Jan 2024 00:00:00 GMT",
"ETag": "\"old-etag\"",
}, 200, 10*time.Second)
// backgroundRefresh 期望请求 URI 包含完整目标 URL
reqCopy := fasthttp.AcquireRequest()
ctx.Request.CopyTo(reqCopy)
reqCopy.SetRequestURI(backendAddr + "/api/test")
p.backgroundRefresh(reqCopy, targets[0], hashKey, origKey)
fasthttp.ReleaseRequest(reqCopy)
entry, ok, _ := p.cache.Get(hashKey, origKey)
require.True(t, ok)
assert.Equal(t, "refreshed content", string(entry.Data))
}
// TestBackgroundRefresh_RequestError 测试后台刷新请求失败时释放缓存锁。
func TestBackgroundRefresh_RequestError(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{{URL: "http://127.0.0.1:1"}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/api/test")
hashKey := uint64(54321)
p.backgroundRefresh(&ctx.Request, targets[0], hashKey, "GET:/api/test")
}
// TestDialTarget_Success checks that dialTarget returns a connection when a
// listener is accepting on the target address.
func TestDialTarget_Success(t *testing.T) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	// Accept a single connection and close it straight away.
	go func() {
		accepted, _ := listener.Accept()
		if accepted == nil {
			return
		}
		_ = accepted.Close()
	}()

	targetURL := "http://" + listener.Addr().String()
	dialed, err := dialTarget(targetURL, 1*time.Second)
	require.NoError(t, err)
	require.NotNil(t, dialed)
	_ = dialed.Close()
}
// TestDialTarget_Timeout dials 10.255.255.1:9999 with a 50ms timeout and
// expects an error whose message contains "failed to connect".
func TestDialTarget_Timeout(t *testing.T) {
	const (
		targetURL   = "http://10.255.255.1:9999"
		dialTimeout = 50 * time.Millisecond
	)

	_, dialErr := dialTarget(targetURL, dialTimeout)
	require.Error(t, dialErr)
	assert.Contains(t, dialErr.Error(), "failed to connect")
}
// TestWebSocket_UpgradeRejected covers a backend that refuses the WebSocket
// upgrade. It uses a real TCP connection and checks that
// readWebSocketUpgradeResponse hands back the non-101 status (400) without an
// error.
func TestWebSocket_UpgradeRejected(t *testing.T) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	// Fake backend: read one request, reply 400, hang up.
	go func() {
		serverSide, acceptErr := listener.Accept()
		if acceptErr != nil {
			return
		}
		_, _ = http.ReadRequest(bufio.NewReader(serverSide))
		_, _ = serverSide.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
		_ = serverSide.Close()
	}()
	time.Sleep(20 * time.Millisecond)

	clientSide, err := net.Dial("tcp", listener.Addr().String())
	require.NoError(t, err)
	defer clientSide.Close()

	// Send a request so the fake backend produces its reply.
	request := []byte("GET /ws HTTP/1.1\r\nHost: localhost\r\n\r\n")
	_, _ = clientSide.Write(request)

	upgradeResp, _, err := readWebSocketUpgradeResponse(clientSide, 1*time.Second)
	require.NoError(t, err)
	assert.Equal(t, 400, upgradeResp.StatusCode)
}
// TestWebSocket_BackendSuccess checks that WebSocket returns an error
// mentioning "hijack". Per the original note, the context built by testutil
// does not support Hijack, so despite its name this test covers the
// hijack-failure path.
func TestWebSocket_BackendSuccess(t *testing.T) {
	upgradeHeaders := map[string]string{
		"Upgrade":               "websocket",
		"Connection":            "Upgrade",
		"Sec-WebSocket-Key":     "dGhlIHNhbXBsZSBub25jZQ==",
		"Sec-WebSocket-Version": "13",
	}
	reqCtx := testutil.NewRequestCtxWithHeader("GET", "/ws", upgradeHeaders)

	upstream := &loadbalance.Target{URL: "http://127.0.0.1:1"}
	upstream.Healthy.Store(true)

	wsErr := WebSocket(reqCtx, upstream, 2*time.Second, nil)
	require.Error(t, wsErr)
	assert.Contains(t, wsErr.Error(), "hijack")
}
// TestBuildWebSocketUpgradeRequest_HeaderConfig checks that the ProxyHeaders
// switches control whether X-Forwarded-Host and X-Forwarded-Proto appear in
// the upgrade request built for the backend.
func TestBuildWebSocketUpgradeRequest_HeaderConfig(t *testing.T) {
	// newClientCtx builds a fresh request context for /ws from client.example.com.
	newClientCtx := func() *fasthttp.RequestCtx {
		rc := &fasthttp.RequestCtx{}
		rc.Request.SetRequestURI("/ws")
		rc.Request.Header.SetHost("client.example.com")
		return rc
	}

	t.Run("禁用 X-Forwarded-Host", func(t *testing.T) {
		disabled := false
		headers := &config.ProxyHeaders{SetForwardedHost: &disabled}

		built := buildWebSocketUpgradeRequest(newClientCtx(), "backend:8080", headers)

		assert.NotContains(t, built, "X-Forwarded-Host")
	})

	t.Run("禁用 X-Forwarded-Proto", func(t *testing.T) {
		disabled := false
		headers := &config.ProxyHeaders{SetForwardedProto: &disabled}

		built := buildWebSocketUpgradeRequest(newClientCtx(), "backend:8080", headers)

		assert.NotContains(t, built, "X-Forwarded-Proto")
	})

	t.Run("启用所有头", func(t *testing.T) {
		enabled := true
		headers := &config.ProxyHeaders{
			SetForwardedHost:  &enabled,
			SetForwardedProto: &enabled,
		}

		built := buildWebSocketUpgradeRequest(newClientCtx(), "backend:8080", headers)

		assert.Contains(t, built, "X-Forwarded-Host: client.example.com")
		assert.Contains(t, built, "X-Forwarded-Proto: http")
	})
}
// TestDNS_StartIdempotent calls Start twice and expects the resolver's start
// counter to read exactly one.
func TestDNS_StartIdempotent(t *testing.T) {
	proxyCfg := &config.ProxyConfig{
		Path:        "/api",
		LoadBalance: "round_robin",
		Timeout:     config.ProxyTimeout{Connect: 5 * time.Second},
	}
	upstreams := []*loadbalance.Target{{URL: "http://localhost:8080"}}

	px, err := NewProxy(proxyCfg, upstreams, nil, nil)
	require.NoError(t, err)

	fakeResolver := &mockResolver{}
	px.SetResolver(fakeResolver)

	// Both calls must succeed; only the first should reach the resolver.
	require.NoError(t, px.Start())
	require.NoError(t, px.Start())

	assert.Equal(t, 1, fakeResolver.startCalls)
}
// TestDNS_RefreshDNS_LookupError 测试 DNS 刷新时查找失败。
func TestDNS_RefreshDNS_LookupError(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
mr := &mockResolver{
lookupError: errors.New("DNS lookup failed"),
}
p.SetResolver(mr)
p.refreshDNS()
}
// TestDNS_RefreshDNS_NoResolver 测试没有解析器时不执行刷新。
func TestDNS_RefreshDNS_NoResolver(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
p.refreshDNS()
}
// TestDNS_UpdateHostClientAddr 测试更新 HostClient 地址。
func TestDNS_UpdateHostClientAddr(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://example.com:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
p.updateHostClientAddr(targets[0], "10.0.0.1")
client := p.getClient("http://example.com:8080")
require.NotNil(t, client)
assert.Equal(t, "10.0.0.1:8080", client.Addr)
}
// TestDNS_UpdateHostClientAddr_DefaultPort 测试无端口时使用默认端口。
func TestDNS_UpdateHostClientAddr_DefaultPort(t *testing.T) {
tests := []struct {
name string
url string
ip string
expected string
}{
{"HTTP 默认端口", "http://example.com", "10.0.0.1", "10.0.0.1:80"},
{"HTTPS 默认端口", "https://example.com", "10.0.0.2", "10.0.0.2:443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: tt.url}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
p.updateHostClientAddr(targets[0], tt.ip)
client := p.getClient(tt.url)
require.NotNil(t, client)
assert.Equal(t, tt.expected, client.Addr)
})
}
}
// TestDNS_GetResolverTTL 测试 TTL 获取。
func TestDNS_GetResolverTTL(t *testing.T) {
t.Run("无解析器返回 0", func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ttl := p.getResolverTTL()
assert.Equal(t, time.Duration(0), ttl)
})
t.Run("有解析器返回 30s", func(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
p.SetResolver(&mockResolver{})
ttl := p.getResolverTTL()
assert.Equal(t, 30*time.Second, ttl)
})
}
// TestDNS_RefreshDNS_Success 测试 DNS 刷新成功。
func TestDNS_RefreshDNS_Success(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://example.com:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
mr := &mockResolver{
lookupResults: map[string][]string{
"example.com": {"10.0.0.1"},
},
}
p.SetResolver(mr)
p.refreshDNS()
client := p.getClient("http://example.com:8080")
require.NotNil(t, client)
assert.Equal(t, "10.0.0.1:8080", client.Addr)
}
// TestDNS_StartResolverFails 测试解析器启动失败。
func TestDNS_StartResolverFails(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
}
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
mr := &mockResolver{
startErr: errors.New("resolver start failed"),
}
p.SetResolver(mr)
err = p.Start()
require.Error(t, err)
assert.Contains(t, err.Error(), "resolver start failed")
}
// mockTargetResolver is a simplified mock described by the original author as
// implementing the resolver.Resolver interface. Both lookup methods delegate
// to lookupFunc; the remaining methods are no-ops.
type mockTargetResolver struct {
	// lookupFunc backs LookupHost and LookupHostWithCache. It is called
	// unconditionally, so it must be set before either method is used.
	lookupFunc func(ctx context.Context, host string) ([]string, error)
}

// LookupHost forwards to lookupFunc.
func (m *mockTargetResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
	addrs, err := m.lookupFunc(ctx, host)
	return addrs, err
}

// LookupHostWithCache forwards to lookupFunc; the mock does no caching.
func (m *mockTargetResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) {
	addrs, err := m.lookupFunc(ctx, host)
	return addrs, err
}

// Refresh does nothing and returns nil.
func (m *mockTargetResolver) Refresh(_ string) error {
	return nil
}

// Start does nothing and returns nil.
func (m *mockTargetResolver) Start() error {
	return nil
}

// Stop does nothing and returns nil.
func (m *mockTargetResolver) Stop() error {
	return nil
}

// Stats returns a zero-valued resolver.Stats.
func (m *mockTargetResolver) Stats() resolver.Stats {
	var zero resolver.Stats
	return zero
}
// TestServeHTTP_CacheStaleWhileRevalidate 测试缓存过期时的后台刷新。
func TestServeHTTP_CacheStaleWhileRevalidate(t *testing.T) {
backend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
ctx.SetStatusCode(200)
ctx.SetBodyString("fresh content")
ctx.Response.Header.Set("Content-Type", "text/plain")
},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() { _ = backend.Serve(ln) }()
time.Sleep(20 * time.Millisecond)
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/api/cached")
hashKey, origKey := p.buildCacheKeyHash(ctx)
headers := map[string]string{"Content-Type": "text/plain"}
p.cache.Set(hashKey, origKey, []byte("stale content"), headers, 200, -1*time.Second)
p.ServeHTTP(ctx)
assert.Equal(t, 200, ctx.Response.StatusCode())
body := string(ctx.Response.Body())
assert.True(t, body == "stale content" || body == "fresh content",
"expected stale or fresh content, got: %s", body)
}
// TestServeHTTP_AllUpstreamsFailed configures two targets on 127.0.0.1 ports 1
// and 2 and expects a 502 or 504 when every attempt fails.
func TestServeHTTP_AllUpstreamsFailed(t *testing.T) {
	const tinyTimeout = 50 * time.Millisecond

	proxyCfg := &config.ProxyConfig{
		Path:        "/",
		LoadBalance: "round_robin",
		Timeout: config.ProxyTimeout{
			Connect: tinyTimeout,
			Read:    tinyTimeout,
			Write:   tinyTimeout,
		},
		NextUpstream: config.NextUpstreamConfig{
			Tries:     2,
			HTTPCodes: []int{502},
		},
	}
	upstreams := []*loadbalance.Target{
		{URL: "http://127.0.0.1:1"},
		{URL: "http://127.0.0.1:2"},
	}
	for i := range upstreams {
		upstreams[i].Healthy.Store(true)
	}

	px, err := NewProxy(proxyCfg, upstreams, nil, nil)
	require.NoError(t, err)

	reqCtx := testutil.NewRequestCtx("GET", "/test")
	px.ServeHTTP(reqCtx)

	status := reqCtx.Response.StatusCode()
	assert.True(t, status == 502 || status == 504,
		"expected 502 or 504, got %d", status)
}
// TestWriteUpgradeResponse_WriteError checks that writeUpgradeResponse reports
// an error when the other end of the pipe has already been closed.
func TestWriteUpgradeResponse_WriteError(t *testing.T) {
	writeEnd, peerEnd := net.Pipe()
	// Close the peer first so any write on writeEnd fails.
	_ = peerEnd.Close()

	switching := &http.Response{
		Status:     "101 Switching Protocols",
		StatusCode: 101,
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header: http.Header{
			"Upgrade": []string{"websocket"},
		},
	}

	writeErr := writeUpgradeResponse(writeEnd, switching)
	assert.Error(t, writeErr)
	_ = writeEnd.Close()
}
// TestIsConnectionClosedError_ClosedConnString checks that
// isConnectionClosedError recognises the error returned by reading from a TCP
// connection that was already closed locally.
func TestIsConnectionClosedError_ClosedConnString(t *testing.T) {
	// Produce a genuine read error from the net package by dialling a
	// listener and closing both ends before reading.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	client, err := net.Dial("tcp", listener.Addr().String())
	require.NoError(t, err)
	_ = client.Close()
	_ = listener.Close()

	scratch := make([]byte, 1)
	_, readErr := client.Read(scratch)
	require.Error(t, readErr)
	assert.True(t, isConnectionClosedError(readErr))
}
// TestCopyData_ReadError runs copyData with a source pipe that is already
// closed and checks that it returns within one second. The returned error is
// deliberately ignored.
func TestCopyData_ReadError(t *testing.T) {
	srcNear, srcFar := net.Pipe()
	dstNear, dstFar := net.Pipe()
	// Close both ends of the source so the first read fails.
	_ = srcNear.Close()
	_ = srcFar.Close()

	br := &WebSocketBridge{}
	done := make(chan error, 1)
	go func() {
		done <- br.copyData(dstNear, srcNear, "test-direction")
	}()

	_ = dstNear.Close()
	_ = dstFar.Close()

	select {
	case <-done:
		// Completed; the result itself is not checked.
	case <-time.After(1 * time.Second):
		t.Error("copyData did not complete in time")
	}
}
// TestServeHTTP_CacheStoreAndHit 测试缓存存储和命中完整流程。
func TestServeHTTP_CacheStoreAndHit(t *testing.T) {
var requestCount atomic.Int32
backend := &fasthttp.Server{
Handler: func(ctx *fasthttp.RequestCtx) {
count := requestCount.Add(1)
ctx.SetStatusCode(200)
ctx.SetBodyString(fmt.Sprintf("response-%d", count))
ctx.Response.Header.Set("Content-Type", "text/plain")
},
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
go func() { _ = backend.Serve(ln) }()
time.Sleep(20 * time.Millisecond)
backendAddr := "http://" + ln.Addr().String()
cfg := &config.ProxyConfig{
Path: "/",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 1 * time.Second, Read: 1 * time.Second, Write: 1 * time.Second},
Cache: config.ProxyCacheConfig{
Enabled: true,
MaxAge: 10 * time.Second,
},
}
targets := []*loadbalance.Target{{URL: backendAddr}}
targets[0].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx1 := testutil.NewRequestCtx("GET", "/cacheable")
p.ServeHTTP(ctx1)
assert.Equal(t, 200, ctx1.Response.StatusCode())
assert.Equal(t, "response-1", string(ctx1.Response.Body()))
ctx2 := testutil.NewRequestCtx("GET", "/cacheable")
p.ServeHTTP(ctx2)
assert.Equal(t, 200, ctx2.Response.StatusCode())
body := string(ctx2.Response.Body())
assert.True(t, body == "response-1" || body == "response-2",
"expected cached or fresh response, got: %s", body)
}
// TestServeHTTP_ConnectionClosed closes the backend listener before the
// request is proxied and expects ServeHTTP to answer 502 or 504.
func TestServeHTTP_ConnectionClosed(t *testing.T) {
	backend := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetStatusCode(200)
			ctx.SetBodyString("ok")
		},
	}
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	// Safety net: release the listener even if a require below aborts the test
	// before the explicit Close. Closing twice only yields an ignored error.
	defer func() { _ = ln.Close() }()
	go func() { _ = backend.Serve(ln) }()
	time.Sleep(20 * time.Millisecond)

	backendAddr := "http://" + ln.Addr().String()
	cfg := &config.ProxyConfig{
		Path:        "/",
		LoadBalance: "round_robin",
		Timeout: config.ProxyTimeout{
			Connect: 1 * time.Second,
			Read:    1 * time.Second,
			Write:   1 * time.Second,
		},
		NextUpstream: config.NextUpstreamConfig{
			Tries:     2,
			HTTPCodes: []int{502},
		},
	}
	targets := []*loadbalance.Target{
		{URL: backendAddr},
	}
	targets[0].Healthy.Store(true)

	p, err := NewProxy(cfg, targets, nil, nil)
	require.NoError(t, err)

	// Shut the backend down so the only target no longer accepts connections.
	_ = ln.Close()

	ctx := testutil.NewRequestCtx("GET", "/test")
	p.ServeHTTP(ctx)

	status := ctx.Response.StatusCode()
	assert.True(t, status == 502 || status == 504,
		"expected 502 or 504, got %d", status)
}
// TestReadWebSocketUpgradeResponse_ReadError expects an error when the upgrade
// response is read from a pipe whose peer has already been closed.
func TestReadWebSocketUpgradeResponse_ReadError(t *testing.T) {
	local, remote := net.Pipe()
	defer func() { _ = local.Close() }()

	// Hang up the peer first so no response bytes will ever arrive.
	_ = remote.Close()

	_, _, readErr := readWebSocketUpgradeResponse(local, 100*time.Millisecond)
	assert.Error(t, readErr)
}
// TestIsConnectionClosedError_NilNetError 测试普通错误不包含关闭字符串时不被识别。
func TestIsConnectionClosedError_NilNetError(t *testing.T) {
err := errors.New("some random error")
assert.False(t, isConnectionClosedError(err))
}
// TestBridge_NonClosedErrors runs Bridge over two pipes whose peers are closed
// at different moments and checks only that Bridge returns within two seconds.
func TestBridge_NonClosedErrors(t *testing.T) {
	brokenLocal, brokenRemote := net.Pipe()
	liveLocal, liveRemote := net.Pipe()
	defer func() {
		for _, c := range []net.Conn{liveLocal, liveRemote} {
			_ = c.Close()
		}
	}()

	b := NewWebSocketBridge(brokenLocal, liveLocal)

	// The first pipe loses its peer before bridging starts.
	_ = brokenRemote.Close()

	result := make(chan error, 1)
	go func() { result <- b.Bridge() }()

	// Let the bridge run briefly, then drop the second peer as well.
	time.Sleep(50 * time.Millisecond)
	_ = liveRemote.Close()

	select {
	case <-result:
		// The returned error value is deliberately not inspected.
	case <-time.After(2 * time.Second):
		t.Error("Bridge did not complete in time")
	}
}
// TestServeHTTP_IgnoresEmptyTargetURL 测试跳过空 URL 的目标。
func TestServeHTTP_IgnoresEmptyTargetURL(t *testing.T) {
cfg := &config.ProxyConfig{
Path: "/api",
LoadBalance: "round_robin",
Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond},
}
targets := []*loadbalance.Target{
{URL: ""},
{URL: "http://127.0.0.1:1"},
}
targets[1].Healthy.Store(true)
p, err := NewProxy(cfg, targets, nil, nil)
require.NoError(t, err)
ctx := testutil.NewRequestCtx("GET", "/api/test")
p.ServeHTTP(ctx)
assert.True(t, ctx.Response.StatusCode() == 502 || ctx.Response.StatusCode() == 504)
}
// TestServeHTTP_WithQueryParams verifies that the query string of the incoming
// request is present in the URI the backend receives.
func TestServeHTTP_WithQueryParams(t *testing.T) {
	// The handler runs on a server goroutine while the assertions run on the
	// test goroutine, so the observed URI is passed through an atomic.Value
	// rather than a plain shared string.
	var receivedURI atomic.Value
	backend := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			receivedURI.Store(string(ctx.RequestURI()))
			ctx.SetStatusCode(200)
			ctx.SetBodyString("ok")
		},
	}
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer ln.Close()
	go func() { _ = backend.Serve(ln) }()
	time.Sleep(20 * time.Millisecond)

	backendAddr := "http://" + ln.Addr().String()
	cfg := testutil.NewTestProxyConfig("/")
	targets := []*loadbalance.Target{{URL: backendAddr}}
	targets[0].Healthy.Store(true)
	p, err := NewProxy(cfg, targets, nil, nil)
	require.NoError(t, err)

	ctx := testutil.NewRequestCtx("GET", "/search?q=test&limit=10")
	p.ServeHTTP(ctx)
	assert.Equal(t, 200, ctx.Response.StatusCode())

	// Stays "" when the backend was never reached, which makes the Contains
	// assertions below fail just as the original empty string did.
	got, _ := receivedURI.Load().(string)
	assert.Contains(t, got, "q=test")
	assert.Contains(t, got, "limit=10")
}
// TestWebSocket_ReadResponseError points WebSocket at a backend that accepts a
// single connection, reads the request and hangs up without replying; the call
// must return an error.
func TestWebSocket_ReadResponseError(t *testing.T) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	// Fake backend: consume one request, then close without a response.
	hangUpAfterRequest := func() {
		peer, acceptErr := listener.Accept()
		if acceptErr != nil {
			return
		}
		defer func() { _ = peer.Close() }()
		_, _ = http.ReadRequest(bufio.NewReader(peer))
	}
	go hangUpAfterRequest()
	time.Sleep(20 * time.Millisecond)

	upgradeHeaders := map[string]string{
		"Upgrade":    "websocket",
		"Connection": "Upgrade",
	}
	reqCtx := testutil.NewRequestCtxWithHeader("GET", "/ws", upgradeHeaders)

	target := &loadbalance.Target{URL: "http://" + listener.Addr().String()}
	target.Healthy.Store(true)

	err = WebSocket(reqCtx, target, 1*time.Second, nil)
	require.Error(t, err)
}
// TestCopyData_WriteErrorNonClosed feeds data into the source pipe after the
// destination end handed to copyData has been closed, and checks that copyData
// returns within one second. The returned error value is not inspected.
func TestCopyData_WriteErrorNonClosed(t *testing.T) {
	src1, src2 := net.Pipe()
	dst1, dst2 := net.Pipe()
	// Close every pipe end on exit. Closing src1 also unblocks copyData if it
	// is still waiting on a read when the test gives up.
	defer func() {
		_ = src1.Close()
		_ = src2.Close()
		_ = dst1.Close()
		_ = dst2.Close()
	}()

	bridge := &WebSocketBridge{}
	errCh := make(chan error, 1)
	go func() {
		errCh <- bridge.copyData(dst1, src1, "test-dir")
	}()

	_ = dst1.Close()

	// net.Pipe is synchronous: Write blocks until the peer reads. Bound it with
	// a deadline so the test cannot hang before reaching the select below if
	// copyData returns without ever reading from src1.
	_ = src2.SetWriteDeadline(time.Now().Add(1 * time.Second))
	_, _ = src2.Write([]byte("data"))
	_ = src2.Close()

	select {
	case err := <-errCh:
		_ = err
	case <-time.After(1 * time.Second):
		t.Error("copyData did not complete")
	}
}
// TestWebSocket_HijackFails calls WebSocket with a zero-value RequestCtx and a
// target on 127.0.0.1:1 and requires that an error is returned.
func TestWebSocket_HijackFails(t *testing.T) {
	emptyCtx := new(fasthttp.RequestCtx)

	upstream := &loadbalance.Target{URL: "http://127.0.0.1:1"}
	upstream.Healthy.Store(true)

	wsErr := WebSocket(emptyCtx, upstream, 100*time.Millisecond, nil)
	require.Error(t, wsErr)
}
// TestServeHTTP_RedirectRewrite has the backend answer 301 with a Location that
// points at the backend's own address, and asserts that the Location returned
// through the proxy contains the Host sent by the client instead.
func TestServeHTTP_RedirectRewrite(t *testing.T) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	redirecting := func(rc *fasthttp.RequestCtx) {
		rc.SetStatusCode(301)
		rc.Response.Header.Set("Location", "http://"+listener.Addr().String()+"/new-path")
	}
	upstream := &fasthttp.Server{Handler: redirecting}
	go func() { _ = upstream.Serve(listener) }()
	time.Sleep(20 * time.Millisecond)

	timeouts := config.ProxyTimeout{
		Connect: time.Second,
		Read:    time.Second,
		Write:   time.Second,
	}
	cfg := &config.ProxyConfig{
		Path:            "/",
		LoadBalance:     "round_robin",
		Timeout:         timeouts,
		RedirectRewrite: &config.RedirectRewriteConfig{Mode: "default"},
	}

	target := &loadbalance.Target{URL: "http://" + listener.Addr().String()}
	target.Healthy.Store(true)
	p, err := NewProxy(cfg, []*loadbalance.Target{target}, nil, nil)
	require.NoError(t, err)

	clientHeaders := map[string]string{"Host": "frontend.example.com"}
	reqCtx := testutil.NewRequestCtxWithHeader("GET", "/old-path", clientHeaders)
	p.ServeHTTP(reqCtx)

	assert.Equal(t, 301, reqCtx.Response.StatusCode())
	rewritten := string(reqCtx.Response.Header.Peek("Location"))
	assert.Contains(t, rewritten, "frontend.example.com")
}
// TestBuildWebSocketUpgradeRequest_Origin checks that the Origin and
// Sec-WebSocket-Extensions headers of the client request appear in the upgrade
// request built for the backend.
func TestBuildWebSocketUpgradeRequest_Origin(t *testing.T) {
	forwarded := []struct{ name, value string }{
		{"Origin", "http://client.example.com"},
		{"Sec-WebSocket-Extensions", "permessage-deflate"},
	}

	reqCtx := new(fasthttp.RequestCtx)
	reqCtx.Request.SetRequestURI("/ws")
	for _, h := range forwarded {
		reqCtx.Request.Header.Set(h.name, h.value)
	}

	built := buildWebSocketUpgradeRequest(reqCtx, "backend:8080", nil)

	// Each header must show up as a "Name: value" line fragment.
	for _, h := range forwarded {
		assert.Contains(t, built, h.name+": "+h.value)
	}
}
// TestCopyData_ReaderError calls copyData with a source connection that was
// closed locally before the copy starts and checks only that the call returns
// within one second.
func TestCopyData_ReaderError(t *testing.T) {
	closedSrc, _ := net.Pipe()
	dstNear, dstFar := net.Pipe()
	// Deferred calls run in reverse order: dstNear is closed first, then dstFar.
	defer func() { _ = dstFar.Close() }()
	defer func() { _ = dstNear.Close() }()

	_ = closedSrc.Close()

	b := &WebSocketBridge{}
	done := make(chan error, 1)
	go func() { done <- b.copyData(dstNear, closedSrc, "test") }()

	select {
	case <-done:
		// The returned error value is deliberately not inspected.
	case <-time.After(1 * time.Second):
		t.Error("copyData did not complete in time")
	}
}
// TestIsConnectionClosedError_RegularError checks that an ordinary error is not
// reported as a connection-closed error.
func TestIsConnectionClosedError_RegularError(t *testing.T) {
	assert.False(t, isConnectionClosedError(errors.New("random error")))
}
// TestIsConnectionClosedError_Nil checks that a nil error is not reported as a
// connection-closed error.
func TestIsConnectionClosedError_Nil(t *testing.T) {
	var none error
	assert.False(t, isConnectionClosedError(none))
}
// TestIsConnectionClosedError_EOF 测试 EOF 被识别为关闭错误。
func TestIsConnectionClosedError_EOF(t *testing.T) {
assert.True(t, isConnectionClosedError(io.EOF))
}