- proxy/proxy.go: decrement connection count on dangerous path rejection (line 724) to prevent connection count leak - handler/sendfile_linux.go: return *os.File from getSocketFile and let linuxSendfile close it, fixing EBADF from deferred close in getSocketFd - proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse and wrap targetConn with bufferedConn to consume pre-buffered frame data, preventing first-frame loss - server/pool.go: use non-blocking send after starting new worker to avoid deadlock when queue is full - stream/stream.go: check stopCh on non-timeout UDP read errors to prevent infinite loop and shutdown deadlock - middleware/ratelimit: replace select-based close guard with sync.Once in StopCleanup to prevent double-close panic
1534 lines
41 KiB
Go
1534 lines
41 KiB
Go
// Package proxy provides supplementary tests for functions with low coverage.
//
// This file focuses on raising test coverage for the following functions:
//   - proxyDebugLog (0%)
//   - ServeHTTP (47.3%)
//   - selectTarget (46.7%)
//   - backgroundRefresh (41.9%)
//   - selectByLua (39.1%)
//   - WebSocket (15.4%)
//   - dialTarget (46.7%)
//   - DNS-related functions (0%)
//
// Author: xfy
package proxy
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/stretchr/testify/require"
|
||
"github.com/valyala/fasthttp"
|
||
"rua.plus/lolly/internal/config"
|
||
"rua.plus/lolly/internal/loadbalance"
|
||
"rua.plus/lolly/internal/lua"
|
||
"rua.plus/lolly/internal/resolver"
|
||
"rua.plus/lolly/internal/testutil"
|
||
)
|
||
|
||
// TestProxyDebugLog 测试 proxyDebugLog 不会 panic,覆盖各类型分支。
|
||
func TestProxyDebugLog(t *testing.T) {
|
||
t.Run("各种 kv 类型", func(t *testing.T) {
|
||
proxyDebugLog("测试消息",
|
||
"str_key", "字符串值",
|
||
"int_key", 42,
|
||
"bool_key", true,
|
||
"iface_key", []string{"a", "b"},
|
||
)
|
||
})
|
||
|
||
t.Run("非字符串 key 跳过", func(t *testing.T) {
|
||
proxyDebugLog("非字符串key",
|
||
123, "value",
|
||
"valid_key", "value",
|
||
)
|
||
})
|
||
|
||
t.Run("奇数 kv 参数", func(t *testing.T) {
|
||
proxyDebugLog("奇数参数",
|
||
"key1", "val1",
|
||
"key2",
|
||
)
|
||
})
|
||
|
||
t.Run("空消息", func(t *testing.T) {
|
||
proxyDebugLog("")
|
||
})
|
||
|
||
t.Run("空 kv", func(t *testing.T) {
|
||
proxyDebugLog("无参数")
|
||
})
|
||
}
|
||
|
||
// TestServeHTTP_WithRealBackend 测试使用真实后端的 ServeHTTP 完整流程。
|
||
func TestServeHTTP_WithRealBackend(t *testing.T) {
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("backend response")
|
||
ctx.Response.Header.Set("X-Backend", "true")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
backendAddr := "http://" + ln.Addr().String()
|
||
|
||
t.Run("GET 请求转发", func(t *testing.T) {
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: backendAddr}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
assert.Equal(t, "backend response", string(ctx.Response.Body()))
|
||
})
|
||
|
||
t.Run("POST 请求转发", func(t *testing.T) {
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: backendAddr}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtxWithBody("POST", "/submit", "request body")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
})
|
||
|
||
t.Run("PUT 请求转发", func(t *testing.T) {
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: backendAddr}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtxWithBody("PUT", "/resource", "put body")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
})
|
||
}
|
||
|
||
// TestServeHTTP_ConnectionRefused 测试连接被拒绝时的错误处理。
|
||
func TestServeHTTP_ConnectionRefused(t *testing.T) {
|
||
t.Run("返回 502 Bad Gateway", func(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond, Read: 100 * time.Millisecond, Write: 100 * time.Millisecond},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{{URL: "http://127.0.0.1:1"}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.True(t, ctx.Response.StatusCode() == 502 || ctx.Response.StatusCode() == 504,
|
||
"expected 502 or 504, got %d", ctx.Response.StatusCode())
|
||
})
|
||
}
|
||
|
||
// TestServeHTTP_Timeout 测试请求超时场景。
|
||
func TestServeHTTP_Timeout(t *testing.T) {
|
||
// 创建一个延迟很大的后端
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
time.Sleep(5 * time.Second)
|
||
ctx.SetStatusCode(200)
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
backendAddr := "http://" + ln.Addr().String()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 50 * time.Millisecond, Read: 50 * time.Millisecond, Write: 50 * time.Millisecond},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{{URL: backendAddr}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/slow")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.True(t, ctx.Response.StatusCode() == 504 || ctx.Response.StatusCode() == 502,
|
||
"expected timeout error (504 or 502), got %d", ctx.Response.StatusCode())
|
||
}
|
||
|
||
// TestServeHTTP_NextUpstreamFailover 测试故障转移到下一个后端。
|
||
func TestServeHTTP_NextUpstreamFailover(t *testing.T) {
|
||
// 创建健康的后端
|
||
healthyBackend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("healthy backend")
|
||
},
|
||
}
|
||
|
||
healthyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer healthyLn.Close()
|
||
|
||
go func() { _ = healthyBackend.Serve(healthyLn) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
healthyAddr := "http://" + healthyLn.Addr().String()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond, Read: 100 * time.Millisecond, Write: 100 * time.Millisecond},
|
||
NextUpstream: config.NextUpstreamConfig{
|
||
Tries: 3,
|
||
HTTPCodes: []int{502, 503, 504},
|
||
},
|
||
}
|
||
|
||
// 第一个目标不可达,第二个健康
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://127.0.0.1:1"},
|
||
{URL: healthyAddr},
|
||
}
|
||
targets[0].Healthy.Store(true)
|
||
targets[1].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
assert.Equal(t, "healthy backend", string(ctx.Response.Body()))
|
||
}
|
||
|
||
// TestServeHTTP_RetryOnHTTPError 测试后端返回错误状态码时重试。
|
||
func TestServeHTTP_RetryOnHTTPError(t *testing.T) {
|
||
// 坏后端:返回 502
|
||
badBackend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(502)
|
||
ctx.SetBodyString("bad gateway")
|
||
},
|
||
}
|
||
badLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer badLn.Close()
|
||
go func() { _ = badBackend.Serve(badLn) }()
|
||
|
||
// 好后端:返回 200
|
||
goodBackend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("ok")
|
||
},
|
||
}
|
||
goodLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer goodLn.Close()
|
||
go func() { _ = goodBackend.Serve(goodLn) }()
|
||
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 1 * time.Second, Read: 1 * time.Second, Write: 1 * time.Second},
|
||
NextUpstream: config.NextUpstreamConfig{
|
||
Tries: 3,
|
||
HTTPCodes: []int{502},
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://" + badLn.Addr().String()},
|
||
{URL: "http://" + goodLn.Addr().String()},
|
||
}
|
||
targets[0].Healthy.Store(true)
|
||
targets[1].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
}
|
||
|
||
// TestServeHTTP_XAccelRedirect 测试 X-Accel-Redirect 内部重定向。
|
||
// /internal/ 和 /admin/ 前缀的路径不做内部重定向,直接返回原始响应。
|
||
func TestServeHTTP_XAccelRedirect(t *testing.T) {
|
||
t.Run("/internal/ 前缀不做重定向", func(t *testing.T) {
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.Response.Header.Set("X-Accel-Redirect", "/internal/secret")
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("raw response")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/redirect")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
assert.Equal(t, "raw response", string(ctx.Response.Body()))
|
||
})
|
||
|
||
t.Run("非 /internal/ 路径做内部重定向", func(t *testing.T) {
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.Response.Header.Set("X-Accel-Redirect", "/other/path")
|
||
ctx.SetStatusCode(200)
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/redirect")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, "/other/path", string(ctx.Request.URI().Path()))
|
||
})
|
||
}
|
||
|
||
// TestServeHTTP_ProxyURI 测试 ProxyURI 路径替换。
|
||
func TestServeHTTP_ProxyURI(t *testing.T) {
|
||
var receivedPath string
|
||
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
receivedPath = string(ctx.Path())
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("ok")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
backendAddr := "http://" + ln.Addr().String()
|
||
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: backendAddr, ProxyURI: "/v2/api"}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/v1/api")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
assert.Contains(t, receivedPath, "/v2/api")
|
||
}
|
||
|
||
// TestServeHTTP_SuspiciousPath 测试危险字符路径被拒绝。
|
||
func TestServeHTTP_SuspiciousPath(t *testing.T) {
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080", ProxyURI: "/test@path"}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 502, ctx.Response.StatusCode())
|
||
}
|
||
|
||
// TestSelectTarget_Random 测试随机负载均衡算法。
|
||
func TestSelectTarget_Random(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "random",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://backend1:8080"},
|
||
{URL: "http://backend2:8080"},
|
||
{URL: "http://backend3:8080"},
|
||
}
|
||
for _, target := range targets {
|
||
target.Healthy.Store(true)
|
||
}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
selected := p.selectTarget(ctx)
|
||
|
||
require.NotNil(t, selected)
|
||
assert.Contains(t, []string{
|
||
"http://backend1:8080",
|
||
"http://backend2:8080",
|
||
"http://backend3:8080",
|
||
}, selected.URL)
|
||
}
|
||
|
||
// TestSelectTarget_LuaSuccess 测试 Lua balancer 成功选择目标。
|
||
func TestSelectTarget_LuaSuccess(t *testing.T) {
|
||
// 创建 Lua 脚本,选择第一个目标
|
||
tmpFile, err := os.CreateTemp("", "balancer_*.lua")
|
||
require.NoError(t, err)
|
||
defer os.Remove(tmpFile.Name())
|
||
|
||
luaScript := `
|
||
local balancer = require("ngx.balancer")
|
||
balancer.set_current_peer(1)
|
||
`
|
||
_, err = tmpFile.WriteString(luaScript)
|
||
require.NoError(t, err)
|
||
tmpFile.Close()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
BalancerByLua: config.BalancerByLuaConfig{
|
||
Enabled: true,
|
||
Script: tmpFile.Name(),
|
||
},
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://backend1:8080"},
|
||
{URL: "http://backend2:8080"},
|
||
}
|
||
for _, target := range targets {
|
||
target.Healthy.Store(true)
|
||
}
|
||
|
||
luaEngine, err := lua.NewEngine(nil)
|
||
require.NoError(t, err)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, luaEngine)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
selected := p.selectTarget(ctx)
|
||
|
||
require.NotNil(t, selected)
|
||
assert.Equal(t, "http://backend1:8080", selected.URL)
|
||
}
|
||
|
||
// TestSelectTarget_LuaFallback 测试 Lua balancer 失败时回退到 fallback。
|
||
func TestSelectTarget_LuaFallback(t *testing.T) {
|
||
// 创建一个不调用 set_current_peer 的脚本
|
||
tmpFile, err := os.CreateTemp("", "balancer_noop_*.lua")
|
||
require.NoError(t, err)
|
||
defer os.Remove(tmpFile.Name())
|
||
|
||
luaScript := `-- 不调用 set_current_peer`
|
||
_, err = tmpFile.WriteString(luaScript)
|
||
require.NoError(t, err)
|
||
tmpFile.Close()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
BalancerByLua: config.BalancerByLuaConfig{
|
||
Enabled: true,
|
||
Script: tmpFile.Name(),
|
||
Fallback: "round_robin",
|
||
},
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://backend1:8080"},
|
||
{URL: "http://backend2:8080"},
|
||
}
|
||
for _, target := range targets {
|
||
target.Healthy.Store(true)
|
||
}
|
||
|
||
luaEngine, err := lua.NewEngine(nil)
|
||
require.NoError(t, err)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, luaEngine)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
selected := p.selectTarget(ctx)
|
||
|
||
require.NotNil(t, selected)
|
||
}
|
||
|
||
// TestSelectByLua_ValidScript 测试有效的 Lua balancer 脚本。
|
||
// 注意:lua.NewEngine(nil) 不初始化 ngx 全局表,selectByLua 会返回错误,
|
||
// 但 selectTarget 会自动回退到 fallback 算法。这里直接测试 selectTarget。
|
||
func TestSelectByLua_ValidScript(t *testing.T) {
|
||
tmpFile, err := os.CreateTemp("", "lua_valid_*.lua")
|
||
require.NoError(t, err)
|
||
defer os.Remove(tmpFile.Name())
|
||
|
||
luaScript := `
|
||
local balancer = require("ngx.balancer")
|
||
balancer.set_current_peer(2)
|
||
`
|
||
_, err = tmpFile.WriteString(luaScript)
|
||
require.NoError(t, err)
|
||
tmpFile.Close()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
BalancerByLua: config.BalancerByLuaConfig{
|
||
Enabled: true,
|
||
Script: tmpFile.Name(),
|
||
Fallback: "round_robin",
|
||
},
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://backend1:8080"},
|
||
{URL: "http://backend2:8080"},
|
||
{URL: "http://backend3:8080"},
|
||
}
|
||
for _, target := range targets {
|
||
target.Healthy.Store(true)
|
||
}
|
||
|
||
luaEngine, err := lua.NewEngine(nil)
|
||
require.NoError(t, err)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, luaEngine)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
|
||
// selectTarget 会因 Lua 错误自动回退到 fallback,仍然返回有效目标
|
||
selected := p.selectTarget(ctx)
|
||
require.NotNil(t, selected)
|
||
}
|
||
|
||
// TestSelectByLua_ScriptNotSelecting 测试 Lua 引擎未初始化 ngx 表时返回错误。
|
||
func TestSelectByLua_ScriptNotSelecting(t *testing.T) {
|
||
tmpFile, err := os.CreateTemp("", "lua_nope_*.lua")
|
||
require.NoError(t, err)
|
||
defer os.Remove(tmpFile.Name())
|
||
|
||
luaScript := `-- do nothing`
|
||
_, err = tmpFile.WriteString(luaScript)
|
||
require.NoError(t, err)
|
||
tmpFile.Close()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
BalancerByLua: config.BalancerByLuaConfig{
|
||
Enabled: true,
|
||
Script: tmpFile.Name(),
|
||
Fallback: "round_robin",
|
||
},
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://backend1:8080"},
|
||
}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
luaEngine, err := lua.NewEngine(nil)
|
||
require.NoError(t, err)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, luaEngine)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
|
||
// lua.NewEngine(nil) 未初始化 ngx 全局表,selectByLua 会返回错误
|
||
_, err = p.selectByLua(ctx, targets)
|
||
require.Error(t, err)
|
||
}
|
||
|
||
// TestBackgroundRefresh_WithCacheEntry 测试有缓存条目时的后台刷新。
|
||
func TestBackgroundRefresh_WithCacheEntry(t *testing.T) {
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("refreshed content")
|
||
ctx.Response.Header.Set("Content-Type", "text/plain")
|
||
ctx.Response.Header.Set("ETag", "\"new-etag\"")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
backendAddr := "http://" + ln.Addr().String()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
Cache: config.ProxyCacheConfig{
|
||
Enabled: true,
|
||
MaxAge: 10 * time.Second,
|
||
Revalidate: true,
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{{URL: backendAddr}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||
|
||
p.cache.Set(hashKey, origKey, []byte("old content"), map[string]string{
|
||
"Content-Type": "text/plain",
|
||
"Last-Modified": "Mon, 01 Jan 2024 00:00:00 GMT",
|
||
"ETag": "\"old-etag\"",
|
||
}, 200, 10*time.Second)
|
||
|
||
// backgroundRefresh 期望请求 URI 包含完整目标 URL
|
||
reqCopy := fasthttp.AcquireRequest()
|
||
ctx.Request.CopyTo(reqCopy)
|
||
reqCopy.SetRequestURI(backendAddr + "/api/test")
|
||
|
||
p.backgroundRefresh(reqCopy, targets[0], hashKey, origKey)
|
||
fasthttp.ReleaseRequest(reqCopy)
|
||
|
||
entry, ok, _ := p.cache.Get(hashKey, origKey)
|
||
require.True(t, ok)
|
||
assert.Equal(t, "refreshed content", string(entry.Data))
|
||
}
|
||
|
||
// TestBackgroundRefresh_RequestError 测试后台刷新请求失败时释放缓存锁。
|
||
func TestBackgroundRefresh_RequestError(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
Cache: config.ProxyCacheConfig{
|
||
Enabled: true,
|
||
MaxAge: 10 * time.Second,
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{{URL: "http://127.0.0.1:1"}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
hashKey := uint64(54321)
|
||
|
||
p.backgroundRefresh(&ctx.Request, targets[0], hashKey, "GET:/api/test")
|
||
}
|
||
|
||
// TestDialTarget_Success 测试成功建立 TCP 连接。
|
||
func TestDialTarget_Success(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() {
|
||
conn, _ := ln.Accept()
|
||
if conn != nil {
|
||
_ = conn.Close()
|
||
}
|
||
}()
|
||
|
||
conn, err := dialTarget("http://"+ln.Addr().String(), 1*time.Second)
|
||
require.NoError(t, err)
|
||
require.NotNil(t, conn)
|
||
_ = conn.Close()
|
||
}
|
||
|
||
// TestDialTarget_Timeout 测试连接超时。
|
||
func TestDialTarget_Timeout(t *testing.T) {
|
||
_, err := dialTarget("http://10.255.255.1:9999", 50*time.Millisecond)
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "failed to connect")
|
||
}
|
||
|
||
// TestWebSocket_UpgradeRejected 测试后端拒绝 WebSocket 升级。
|
||
// 使用真实 TCP 连接测试 readWebSocketUpgradeResponse 返回非 101 状态码。
|
||
func TestWebSocket_UpgradeRejected(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() {
|
||
conn, acceptErr := ln.Accept()
|
||
if acceptErr != nil {
|
||
return
|
||
}
|
||
reader := bufio.NewReader(conn)
|
||
_, _ = http.ReadRequest(reader)
|
||
_, _ = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||
_ = conn.Close()
|
||
}()
|
||
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||
require.NoError(t, err)
|
||
defer conn.Close()
|
||
|
||
// 发送一个请求让服务端触发
|
||
_, _ = conn.Write([]byte("GET /ws HTTP/1.1\r\nHost: localhost\r\n\r\n"))
|
||
|
||
resp, _, err := readWebSocketUpgradeResponse(conn, 1*time.Second)
|
||
require.NoError(t, err)
|
||
assert.Equal(t, 400, resp.StatusCode)
|
||
}
|
||
|
||
// TestWebSocket_BackendSuccess 测试 WebSocket 函数因 Hijack 失败而返回错误。
|
||
// testutil 创建的 ctx 不支持 Hijack,验证此错误路径。
|
||
func TestWebSocket_BackendSuccess(t *testing.T) {
|
||
ctx := testutil.NewRequestCtxWithHeader("GET", "/ws", map[string]string{
|
||
"Upgrade": "websocket",
|
||
"Connection": "Upgrade",
|
||
"Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
|
||
"Sec-WebSocket-Version": "13",
|
||
})
|
||
|
||
target := &loadbalance.Target{URL: "http://127.0.0.1:1"}
|
||
target.Healthy.Store(true)
|
||
|
||
err := WebSocket(ctx, target, 2*time.Second, nil)
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "hijack")
|
||
}
|
||
|
||
// TestBuildWebSocketUpgradeRequest_HeaderConfig 测试 Headers 配置控制。
|
||
func TestBuildWebSocketUpgradeRequest_HeaderConfig(t *testing.T) {
|
||
t.Run("禁用 X-Forwarded-Host", func(t *testing.T) {
|
||
falseVal := false
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/ws")
|
||
ctx.Request.Header.SetHost("client.example.com")
|
||
|
||
result := buildWebSocketUpgradeRequest(ctx, "backend:8080", &config.ProxyHeaders{
|
||
SetForwardedHost: &falseVal,
|
||
})
|
||
|
||
assert.NotContains(t, result, "X-Forwarded-Host")
|
||
})
|
||
|
||
t.Run("禁用 X-Forwarded-Proto", func(t *testing.T) {
|
||
falseVal := false
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/ws")
|
||
ctx.Request.Header.SetHost("client.example.com")
|
||
|
||
result := buildWebSocketUpgradeRequest(ctx, "backend:8080", &config.ProxyHeaders{
|
||
SetForwardedProto: &falseVal,
|
||
})
|
||
|
||
assert.NotContains(t, result, "X-Forwarded-Proto")
|
||
})
|
||
|
||
t.Run("启用所有头", func(t *testing.T) {
|
||
trueVal := true
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/ws")
|
||
ctx.Request.Header.SetHost("client.example.com")
|
||
|
||
result := buildWebSocketUpgradeRequest(ctx, "backend:8080", &config.ProxyHeaders{
|
||
SetForwardedHost: &trueVal,
|
||
SetForwardedProto: &trueVal,
|
||
})
|
||
|
||
assert.Contains(t, result, "X-Forwarded-Host: client.example.com")
|
||
assert.Contains(t, result, "X-Forwarded-Proto: http")
|
||
})
|
||
}
|
||
|
||
// TestDNS_StartIdempotent 测试 Start 方法是幂等的。
|
||
func TestDNS_StartIdempotent(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
mr := &mockResolver{}
|
||
p.SetResolver(mr)
|
||
|
||
err = p.Start()
|
||
require.NoError(t, err)
|
||
|
||
err = p.Start()
|
||
require.NoError(t, err)
|
||
|
||
assert.Equal(t, 1, mr.startCalls)
|
||
}
|
||
|
||
// TestDNS_RefreshDNS_LookupError 测试 DNS 刷新时查找失败。
|
||
func TestDNS_RefreshDNS_LookupError(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
mr := &mockResolver{
|
||
lookupError: errors.New("DNS lookup failed"),
|
||
}
|
||
p.SetResolver(mr)
|
||
|
||
p.refreshDNS()
|
||
}
|
||
|
||
// TestDNS_RefreshDNS_NoResolver 测试没有解析器时不执行刷新。
|
||
func TestDNS_RefreshDNS_NoResolver(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
p.refreshDNS()
|
||
}
|
||
|
||
// TestDNS_UpdateHostClientAddr 测试更新 HostClient 地址。
|
||
func TestDNS_UpdateHostClientAddr(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://example.com:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
p.updateHostClientAddr(targets[0], "10.0.0.1")
|
||
|
||
client := p.getClient("http://example.com:8080")
|
||
require.NotNil(t, client)
|
||
assert.Equal(t, "10.0.0.1:8080", client.Addr)
|
||
}
|
||
|
||
// TestDNS_UpdateHostClientAddr_DefaultPort 测试无端口时使用默认端口。
|
||
func TestDNS_UpdateHostClientAddr_DefaultPort(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
url string
|
||
ip string
|
||
expected string
|
||
}{
|
||
{"HTTP 默认端口", "http://example.com", "10.0.0.1", "10.0.0.1:80"},
|
||
{"HTTPS 默认端口", "https://example.com", "10.0.0.2", "10.0.0.2:443"},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: tt.url}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
p.updateHostClientAddr(targets[0], tt.ip)
|
||
|
||
client := p.getClient(tt.url)
|
||
require.NotNil(t, client)
|
||
assert.Equal(t, tt.expected, client.Addr)
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestDNS_GetResolverTTL 测试 TTL 获取。
|
||
func TestDNS_GetResolverTTL(t *testing.T) {
|
||
t.Run("无解析器返回 0", func(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ttl := p.getResolverTTL()
|
||
assert.Equal(t, time.Duration(0), ttl)
|
||
})
|
||
|
||
t.Run("有解析器返回 30s", func(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
p.SetResolver(&mockResolver{})
|
||
|
||
ttl := p.getResolverTTL()
|
||
assert.Equal(t, 30*time.Second, ttl)
|
||
})
|
||
}
|
||
|
||
// TestDNS_RefreshDNS_Success 测试 DNS 刷新成功。
|
||
func TestDNS_RefreshDNS_Success(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://example.com:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
mr := &mockResolver{
|
||
lookupResults: map[string][]string{
|
||
"example.com": {"10.0.0.1"},
|
||
},
|
||
}
|
||
p.SetResolver(mr)
|
||
|
||
p.refreshDNS()
|
||
|
||
client := p.getClient("http://example.com:8080")
|
||
require.NotNil(t, client)
|
||
assert.Equal(t, "10.0.0.1:8080", client.Addr)
|
||
}
|
||
|
||
// TestDNS_StartResolverFails 测试解析器启动失败。
|
||
func TestDNS_StartResolverFails(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
}
|
||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
mr := &mockResolver{
|
||
startErr: errors.New("resolver start failed"),
|
||
}
|
||
p.SetResolver(mr)
|
||
|
||
err = p.Start()
|
||
require.Error(t, err)
|
||
assert.Contains(t, err.Error(), "resolver start failed")
|
||
}
|
||
|
||
// mockTargetResolver 实现 resolver.Resolver 接口的简化 mock。
|
||
type mockTargetResolver struct {
|
||
lookupFunc func(ctx context.Context, host string) ([]string, error)
|
||
}
|
||
|
||
func (m *mockTargetResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||
return m.lookupFunc(ctx, host)
|
||
}
|
||
|
||
func (m *mockTargetResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) {
|
||
return m.lookupFunc(ctx, host)
|
||
}
|
||
|
||
func (m *mockTargetResolver) Refresh(host string) error { return nil }
|
||
|
||
func (m *mockTargetResolver) Start() error { return nil }
|
||
|
||
func (m *mockTargetResolver) Stop() error { return nil }
|
||
|
||
func (m *mockTargetResolver) Stats() resolver.Stats { return resolver.Stats{} }
|
||
|
||
// TestServeHTTP_CacheStaleWhileRevalidate 测试缓存过期时的后台刷新。
|
||
func TestServeHTTP_CacheStaleWhileRevalidate(t *testing.T) {
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("fresh content")
|
||
ctx.Response.Header.Set("Content-Type", "text/plain")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||
Cache: config.ProxyCacheConfig{
|
||
Enabled: true,
|
||
MaxAge: 10 * time.Second,
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/cached")
|
||
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||
|
||
headers := map[string]string{"Content-Type": "text/plain"}
|
||
p.cache.Set(hashKey, origKey, []byte("stale content"), headers, 200, -1*time.Second)
|
||
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
body := string(ctx.Response.Body())
|
||
assert.True(t, body == "stale content" || body == "fresh content",
|
||
"expected stale or fresh content, got: %s", body)
|
||
}
|
||
|
||
// TestServeHTTP_AllUpstreamsFailed 测试所有上游都失败时返回 502。
|
||
func TestServeHTTP_AllUpstreamsFailed(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 50 * time.Millisecond, Read: 50 * time.Millisecond, Write: 50 * time.Millisecond},
|
||
NextUpstream: config.NextUpstreamConfig{
|
||
Tries: 2,
|
||
HTTPCodes: []int{502},
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: "http://127.0.0.1:1"},
|
||
{URL: "http://127.0.0.1:2"},
|
||
}
|
||
for _, target := range targets {
|
||
target.Healthy.Store(true)
|
||
}
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.True(t, ctx.Response.StatusCode() == 502 || ctx.Response.StatusCode() == 504,
|
||
"expected 502 or 504, got %d", ctx.Response.StatusCode())
|
||
}
|
||
|
||
// TestWriteUpgradeResponse_WriteError 测试写入升级响应失败。
|
||
func TestWriteUpgradeResponse_WriteError(t *testing.T) {
|
||
conn1, conn2 := net.Pipe()
|
||
_ = conn2.Close()
|
||
|
||
resp := &http.Response{
|
||
ProtoMajor: 1,
|
||
ProtoMinor: 1,
|
||
Status: "101 Switching Protocols",
|
||
StatusCode: 101,
|
||
Header: http.Header{
|
||
"Upgrade": []string{"websocket"},
|
||
},
|
||
}
|
||
|
||
err := writeUpgradeResponse(conn1, resp)
|
||
assert.Error(t, err)
|
||
_ = conn1.Close()
|
||
}
|
||
|
||
// TestIsConnectionClosedError_ClosedConnString 测试包含 "use of closed" 的 net.Error。
|
||
func TestIsConnectionClosedError_ClosedConnString(t *testing.T) {
|
||
// 构造一个同时是 net.Error 且包含关闭字符串的错误
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
|
||
conn, err := net.Dial("tcp", ln.Addr().String())
|
||
require.NoError(t, err)
|
||
|
||
_ = conn.Close()
|
||
_ = ln.Close()
|
||
|
||
_, readErr := conn.Read(make([]byte, 1))
|
||
require.Error(t, readErr)
|
||
assert.True(t, isConnectionClosedError(readErr))
|
||
}
|
||
|
||
// TestCopyData_ReadError 测试读取端错误处理。
|
||
func TestCopyData_ReadError(t *testing.T) {
|
||
src1, src2 := net.Pipe()
|
||
dst1, dst2 := net.Pipe()
|
||
|
||
_ = src1.Close()
|
||
_ = src2.Close()
|
||
|
||
bridge := &WebSocketBridge{}
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- bridge.copyData(dst1, src1, "test-direction")
|
||
}()
|
||
|
||
_ = dst1.Close()
|
||
_ = dst2.Close()
|
||
|
||
select {
|
||
case err := <-errCh:
|
||
_ = err
|
||
case <-time.After(1 * time.Second):
|
||
t.Error("copyData did not complete in time")
|
||
}
|
||
}
|
||
|
||
// TestServeHTTP_CacheStoreAndHit 测试缓存存储和命中完整流程。
|
||
func TestServeHTTP_CacheStoreAndHit(t *testing.T) {
|
||
var requestCount atomic.Int32
|
||
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
count := requestCount.Add(1)
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString(fmt.Sprintf("response-%d", count))
|
||
ctx.Response.Header.Set("Content-Type", "text/plain")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
backendAddr := "http://" + ln.Addr().String()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 1 * time.Second, Read: 1 * time.Second, Write: 1 * time.Second},
|
||
Cache: config.ProxyCacheConfig{
|
||
Enabled: true,
|
||
MaxAge: 10 * time.Second,
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{{URL: backendAddr}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx1 := testutil.NewRequestCtx("GET", "/cacheable")
|
||
p.ServeHTTP(ctx1)
|
||
assert.Equal(t, 200, ctx1.Response.StatusCode())
|
||
assert.Equal(t, "response-1", string(ctx1.Response.Body()))
|
||
|
||
ctx2 := testutil.NewRequestCtx("GET", "/cacheable")
|
||
p.ServeHTTP(ctx2)
|
||
assert.Equal(t, 200, ctx2.Response.StatusCode())
|
||
|
||
body := string(ctx2.Response.Body())
|
||
assert.True(t, body == "response-1" || body == "response-2",
|
||
"expected cached or fresh response, got: %s", body)
|
||
}
|
||
|
||
// TestServeHTTP_ConnectionClosed 测试连接关闭错误。
|
||
func TestServeHTTP_ConnectionClosed(t *testing.T) {
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("ok")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
backendAddr := "http://" + ln.Addr().String()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 1 * time.Second, Read: 1 * time.Second, Write: 1 * time.Second},
|
||
NextUpstream: config.NextUpstreamConfig{
|
||
Tries: 2,
|
||
HTTPCodes: []int{502},
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: backendAddr},
|
||
}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ln.Close()
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.True(t, ctx.Response.StatusCode() == 502 || ctx.Response.StatusCode() == 504)
|
||
}
|
||
|
||
// TestReadWebSocketUpgradeResponse_ReadError 测试读取升级响应失败。
|
||
func TestReadWebSocketUpgradeResponse_ReadError(t *testing.T) {
|
||
conn1, conn2 := net.Pipe()
|
||
_ = conn2.Close()
|
||
|
||
_, _, err := readWebSocketUpgradeResponse(conn1, 100*time.Millisecond)
|
||
assert.Error(t, err)
|
||
_ = conn1.Close()
|
||
}
|
||
|
||
// TestIsConnectionClosedError_NilNetError 测试普通错误不包含关闭字符串时不被识别。
|
||
func TestIsConnectionClosedError_NilNetError(t *testing.T) {
|
||
err := errors.New("some random error")
|
||
assert.False(t, isConnectionClosedError(err))
|
||
}
|
||
|
||
// TestBridge_NonClosedErrors 测试桥接返回非关闭错误。
|
||
func TestBridge_NonClosedErrors(t *testing.T) {
|
||
errConn1, errConn2 := net.Pipe()
|
||
normalConn1, normalConn2 := net.Pipe()
|
||
defer func() {
|
||
_ = normalConn1.Close()
|
||
_ = normalConn2.Close()
|
||
}()
|
||
|
||
bridge := NewWebSocketBridge(errConn1, normalConn1)
|
||
|
||
_ = errConn2.Close()
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- bridge.Bridge()
|
||
}()
|
||
|
||
time.Sleep(50 * time.Millisecond)
|
||
_ = normalConn2.Close()
|
||
|
||
select {
|
||
case err := <-errCh:
|
||
_ = err
|
||
case <-time.After(2 * time.Second):
|
||
t.Error("Bridge did not complete in time")
|
||
}
|
||
}
|
||
|
||
// TestServeHTTP_IgnoresEmptyTargetURL 测试跳过空 URL 的目标。
|
||
func TestServeHTTP_IgnoresEmptyTargetURL(t *testing.T) {
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/api",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 100 * time.Millisecond},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{
|
||
{URL: ""},
|
||
{URL: "http://127.0.0.1:1"},
|
||
}
|
||
targets[1].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.True(t, ctx.Response.StatusCode() == 502 || ctx.Response.StatusCode() == 504)
|
||
}
|
||
|
||
// TestServeHTTP_WithQueryParams 测试带查询参数的请求。
|
||
func TestServeHTTP_WithQueryParams(t *testing.T) {
|
||
var receivedURI string
|
||
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
receivedURI = string(ctx.RequestURI())
|
||
ctx.SetStatusCode(200)
|
||
ctx.SetBodyString("ok")
|
||
},
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
backendAddr := "http://" + ln.Addr().String()
|
||
|
||
cfg := testutil.NewTestProxyConfig("/")
|
||
targets := []*loadbalance.Target{{URL: backendAddr}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtx("GET", "/search?q=test&limit=10")
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 200, ctx.Response.StatusCode())
|
||
assert.Contains(t, receivedURI, "q=test")
|
||
assert.Contains(t, receivedURI, "limit=10")
|
||
}
|
||
|
||
// TestWebSocket_ReadResponseError 测试读取 WebSocket 升级响应失败。
|
||
func TestWebSocket_ReadResponseError(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
defer ln.Close()
|
||
|
||
go func() {
|
||
conn, acceptErr := ln.Accept()
|
||
if acceptErr != nil {
|
||
return
|
||
}
|
||
reader := bufio.NewReader(conn)
|
||
_, _ = http.ReadRequest(reader)
|
||
_ = conn.Close()
|
||
}()
|
||
|
||
time.Sleep(20 * time.Millisecond)
|
||
|
||
ctx := testutil.NewRequestCtxWithHeader("GET", "/ws", map[string]string{
|
||
"Upgrade": "websocket",
|
||
"Connection": "Upgrade",
|
||
})
|
||
|
||
target := &loadbalance.Target{URL: "http://" + ln.Addr().String()}
|
||
target.Healthy.Store(true)
|
||
|
||
err = WebSocket(ctx, target, 1*time.Second, nil)
|
||
require.Error(t, err)
|
||
}
|
||
|
||
// TestCopyData_WriteErrorNonClosed 测试写入错误(非关闭类)。
|
||
func TestCopyData_WriteErrorNonClosed(t *testing.T) {
|
||
src1, src2 := net.Pipe()
|
||
dst1, _ := net.Pipe()
|
||
|
||
bridge := &WebSocketBridge{}
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- bridge.copyData(dst1, src1, "test-dir")
|
||
}()
|
||
|
||
_ = dst1.Close()
|
||
_, _ = src2.Write([]byte("data"))
|
||
_ = src2.Close()
|
||
|
||
select {
|
||
case err := <-errCh:
|
||
_ = err
|
||
case <-time.After(1 * time.Second):
|
||
t.Error("copyData did not complete")
|
||
}
|
||
}
|
||
|
||
// TestWebSocket_HijackFails 测试 Hijack 失败场景。
|
||
func TestWebSocket_HijackFails(t *testing.T) {
|
||
ctx := &fasthttp.RequestCtx{}
|
||
|
||
target := &loadbalance.Target{URL: "http://127.0.0.1:1"}
|
||
target.Healthy.Store(true)
|
||
|
||
err := WebSocket(ctx, target, 100*time.Millisecond, nil)
|
||
require.Error(t, err)
|
||
}
|
||
|
||
// TestServeHTTP_RedirectRewrite 测试重定向改写。
|
||
func TestServeHTTP_RedirectRewrite(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
require.NoError(t, err)
|
||
|
||
backend := &fasthttp.Server{
|
||
Handler: func(ctx *fasthttp.RequestCtx) {
|
||
ctx.SetStatusCode(301)
|
||
ctx.Response.Header.Set("Location", "http://"+ln.Addr().String()+"/new-path")
|
||
},
|
||
}
|
||
|
||
go func() { _ = backend.Serve(ln) }()
|
||
time.Sleep(20 * time.Millisecond)
|
||
defer ln.Close()
|
||
|
||
cfg := &config.ProxyConfig{
|
||
Path: "/",
|
||
LoadBalance: "round_robin",
|
||
Timeout: config.ProxyTimeout{Connect: 1 * time.Second, Read: 1 * time.Second, Write: 1 * time.Second},
|
||
RedirectRewrite: &config.RedirectRewriteConfig{
|
||
Mode: "default",
|
||
},
|
||
}
|
||
|
||
targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}}
|
||
targets[0].Healthy.Store(true)
|
||
|
||
p, err := NewProxy(cfg, targets, nil, nil)
|
||
require.NoError(t, err)
|
||
|
||
ctx := testutil.NewRequestCtxWithHeader("GET", "/old-path", map[string]string{
|
||
"Host": "frontend.example.com",
|
||
})
|
||
p.ServeHTTP(ctx)
|
||
|
||
assert.Equal(t, 301, ctx.Response.StatusCode())
|
||
location := string(ctx.Response.Header.Peek("Location"))
|
||
assert.Contains(t, location, "frontend.example.com")
|
||
}
|
||
|
||
// TestBuildWebSocketUpgradeRequest_Origin 测试 Origin 头复制。
|
||
func TestBuildWebSocketUpgradeRequest_Origin(t *testing.T) {
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Request.SetRequestURI("/ws")
|
||
ctx.Request.Header.Set("Origin", "http://client.example.com")
|
||
ctx.Request.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate")
|
||
|
||
result := buildWebSocketUpgradeRequest(ctx, "backend:8080", nil)
|
||
|
||
assert.Contains(t, result, "Origin: http://client.example.com")
|
||
assert.Contains(t, result, "Sec-WebSocket-Extensions: permessage-deflate")
|
||
}
|
||
|
||
// TestCopyData_ReaderError 测试读取端错误(非关闭类)。
|
||
func TestCopyData_ReaderError(t *testing.T) {
|
||
src1, _ := net.Pipe()
|
||
dst1, dst2 := net.Pipe()
|
||
defer func() { _ = dst2.Close() }()
|
||
|
||
_ = src1.Close()
|
||
|
||
bridge := &WebSocketBridge{}
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- bridge.copyData(dst1, src1, "test")
|
||
}()
|
||
|
||
select {
|
||
case err := <-errCh:
|
||
_ = err
|
||
case <-time.After(1 * time.Second):
|
||
t.Error("copyData did not complete in time")
|
||
}
|
||
_ = dst1.Close()
|
||
}
|
||
|
||
// TestIsConnectionClosedError_RegularError 测试普通错误不被识别为关闭错误。
|
||
func TestIsConnectionClosedError_RegularError(t *testing.T) {
|
||
err := errors.New("random error")
|
||
assert.False(t, isConnectionClosedError(err))
|
||
}
|
||
|
||
// TestIsConnectionClosedError_Nil 测试 nil 不被识别为关闭错误。
|
||
func TestIsConnectionClosedError_Nil(t *testing.T) {
|
||
assert.False(t, isConnectionClosedError(nil))
|
||
}
|
||
|
||
// TestIsConnectionClosedError_EOF 测试 EOF 被识别为关闭错误。
|
||
func TestIsConnectionClosedError_EOF(t *testing.T) {
|
||
assert.True(t, isConnectionClosedError(io.EOF))
|
||
}
|