test: 添加各模块覆盖率补充测试
- middleware/security: access 中间件覆盖率测试 - proxy: proxy 核心功能覆盖率测试 - server: server 扩展功能测试 - stream: stream 处理覆盖率测试 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
c82e6dcdb7
commit
5f5717d6a4
340
internal/middleware/security/access_coverage_test.go
Normal file
340
internal/middleware/security/access_coverage_test.go
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
// Package security 提供访问控制覆盖测试。
|
||||||
|
//
|
||||||
|
// 该文件补充测试 access.go 中未覆盖的方法:
|
||||||
|
// - Name() 方法
|
||||||
|
// - Process() 完整处理链(允许/拒绝路径)
|
||||||
|
// - getClientIP() 通过 Process 间接测试
|
||||||
|
// - Close() 方法
|
||||||
|
// - actionToString 边缘情况
|
||||||
|
// - trustedProxies 相关逻辑
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAccessControlName 测试 Name 方法
|
||||||
|
func TestAccessControlName(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Default: "allow",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
name := ac.Name()
|
||||||
|
if name != "access_control" {
|
||||||
|
t.Errorf("Name() = %q, want 'access_control'", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlProcess_AllowPath 测试 Process 允许路径
|
||||||
|
func TestAccessControlProcess_AllowPath(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Default: "allow",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
_, _ = ctx.WriteString("allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := ac.Process(nextHandler)
|
||||||
|
if handler == nil {
|
||||||
|
t.Fatal("Process() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("Process() should call next handler when access allowed")
|
||||||
|
}
|
||||||
|
if string(ctx.Response.Body()) != "allowed" {
|
||||||
|
t.Errorf("Process() body = %q, want 'allowed'", string(ctx.Response.Body()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlProcess_DenyPath 测试 Process 拒绝路径
|
||||||
|
func TestAccessControlProcess_DenyPath(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Default: "deny",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := ac.Process(nextHandler)
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if called {
|
||||||
|
t.Error("Process() should NOT call next handler when access denied")
|
||||||
|
}
|
||||||
|
if ctx.Response.StatusCode() != fasthttp.StatusForbidden {
|
||||||
|
t.Errorf("Process() status = %d, want 403", ctx.Response.StatusCode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlProcess_ExplicitAllow 测试显式允许列表
|
||||||
|
func TestAccessControlProcess_ExplicitAllow(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Allow: []string{"127.0.0.1"},
|
||||||
|
Default: "deny",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := ac.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345})
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("Process() should call next handler for allowed IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlProcess_ExplicitDeny 测试显式拒绝列表
|
||||||
|
func TestAccessControlProcess_ExplicitDeny(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Deny: []string{"10.0.0.1"},
|
||||||
|
Default: "allow",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := ac.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345})
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if called {
|
||||||
|
t.Error("Process() should NOT call next handler for denied IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlProcess_TrustedProxies_XFF 测试可信代理 XFF 解析
|
||||||
|
func TestAccessControlProcess_TrustedProxies_XFF(t *testing.T) {
|
||||||
|
// 配置可信代理,10.0.0.0/8 为可信代理段
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
TrustedProxies: []string{"10.0.0.0/8"},
|
||||||
|
Default: "deny",
|
||||||
|
Allow: []string{"192.168.1.0/24"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := ac.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
// 请求来自可信代理 10.0.0.1,XFF 中包含真实客户端 192.168.1.100
|
||||||
|
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 12345})
|
||||||
|
ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("Process() should allow real client IP behind trusted proxy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlProcess_TrustedProxies_UntrustedSource 测试不可信来源不解析 XFF
|
||||||
|
func TestAccessControlProcess_TrustedProxies_UntrustedSource(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
TrustedProxies: []string{"10.0.0.0/8"},
|
||||||
|
Default: "deny",
|
||||||
|
Allow: []string{"192.168.1.0/24"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := ac.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
// 请求来自不可信地址,即使 XFF 包含允许列表 IP 也不应解析
|
||||||
|
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("203.0.113.1"), Port: 12345})
|
||||||
|
ctx.Request.Header.Set("X-Forwarded-For", "192.168.1.100")
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if called {
|
||||||
|
t.Error("Process() should not trust XFF from untrusted source")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlProcess_TrustedProxies_XRealIP 测试可信代理 X-Real-IP
|
||||||
|
func TestAccessControlProcess_TrustedProxies_XRealIP(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
TrustedProxies: []string{"10.0.0.0/8"},
|
||||||
|
Default: "deny",
|
||||||
|
Allow: []string{"192.168.1.0/24"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
called := false
|
||||||
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
called = true
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := ac.Process(nextHandler)
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("10.0.0.50"), Port: 12345})
|
||||||
|
ctx.Request.Header.Set("X-Real-IP", "192.168.1.50")
|
||||||
|
handler(ctx)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("Process() should use X-Real-IP from trusted proxy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAccessControlClose 测试 Close 方法
|
||||||
|
func TestAccessControlClose(t *testing.T) {
|
||||||
|
// 无 GeoIP 的 Close
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Default: "allow",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ac.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Close() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestActionToString 测试 actionToString 边缘情况
|
||||||
|
func TestActionToString(t *testing.T) {
|
||||||
|
// 测试 ActionAllow
|
||||||
|
result := actionToString(ActionAllow)
|
||||||
|
if result != "allow" {
|
||||||
|
t.Errorf("actionToString(ActionAllow) = %q, want 'allow'", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试 ActionDeny
|
||||||
|
result = actionToString(ActionDeny)
|
||||||
|
if result != "deny" {
|
||||||
|
t.Errorf("actionToString(ActionDeny) = %q, want 'deny'", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试未知值
|
||||||
|
result = actionToString(Action(999))
|
||||||
|
if result != "unknown" {
|
||||||
|
t.Errorf("actionToString(999) = %q, want 'unknown'", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetStatsWithEmpty 测试 GetStats 空列表
|
||||||
|
func TestGetStatsWithEmpty(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Default: "allow",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := ac.GetStats()
|
||||||
|
if stats.AllowCount != 0 {
|
||||||
|
t.Errorf("GetStats().AllowCount = %d, want 0", stats.AllowCount)
|
||||||
|
}
|
||||||
|
if stats.DenyCount != 0 {
|
||||||
|
t.Errorf("GetStats().DenyCount = %d, want 0", stats.DenyCount)
|
||||||
|
}
|
||||||
|
if stats.Default != "allow" {
|
||||||
|
t.Errorf("GetStats().Default = %q, want 'allow'", stats.Default)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSetDefaultValidCases 测试 SetDefault 所有有效值
|
||||||
|
func TestSetDefaultValidCases(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Default: "allow",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换为 deny
|
||||||
|
err = ac.SetDefault("deny")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("SetDefault('deny') error: %v", err)
|
||||||
|
}
|
||||||
|
stats := ac.GetStats()
|
||||||
|
if stats.Default != "deny" {
|
||||||
|
t.Errorf("After SetDefault('deny'), Default = %q, want 'deny'", stats.Default)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换回 allow
|
||||||
|
err = ac.SetDefault("allow")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("SetDefault('allow') error: %v", err)
|
||||||
|
}
|
||||||
|
stats = ac.GetStats()
|
||||||
|
if stats.Default != "allow" {
|
||||||
|
t.Errorf("After SetDefault('allow'), Default = %q, want 'allow'", stats.Default)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 大小写不敏感
|
||||||
|
err = ac.SetDefault("DENY")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("SetDefault('DENY') error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpdateDenyListError 测试 UpdateDenyList 错误路径
|
||||||
|
func TestUpdateDenyListError(t *testing.T) {
|
||||||
|
ac, err := NewAccessControl(&config.AccessConfig{
|
||||||
|
Default: "allow",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAccessControl() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ac.UpdateDenyList([]string{"not-an-ip"})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("UpdateDenyList() should return error for invalid CIDR")
|
||||||
|
}
|
||||||
|
}
|
||||||
720
internal/proxy/proxy_coverage_test.go
Normal file
720
internal/proxy/proxy_coverage_test.go
Normal file
@ -0,0 +1,720 @@
|
|||||||
|
// Package proxy 提供反向代理覆盖测试,补充 proxy.go 中未覆盖的方法。
|
||||||
|
//
|
||||||
|
// 该文件测试代理模块的以下功能:
|
||||||
|
// - selectTargetExcluding 排除已失败目标的选择
|
||||||
|
// - extractHashKey 哈希键提取
|
||||||
|
// - buildCacheKeyHash / buildCacheKeyHashValue 缓存键计算
|
||||||
|
// - writeCachedResponse 缓存响应写入
|
||||||
|
// - GetCache / GetCacheStats 缓存访问
|
||||||
|
// - getCacheDuration 不同状态码的缓存时间
|
||||||
|
// - redirect_rewrite 相关功能
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"github.com/valyala/fasthttp/fasthttputil"
|
||||||
|
"rua.plus/lolly/internal/cache"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/loadbalance"
|
||||||
|
"rua.plus/lolly/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSelectTargetExcluding 测试排除失败目标的目标选择
|
||||||
|
func TestSelectTargetExcluding(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://backend1:8080"},
|
||||||
|
{URL: "http://backend2:8080"},
|
||||||
|
{URL: "http://backend3:8080"},
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
target.Healthy.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
// 排除第一个目标,应该选择第二个
|
||||||
|
excluded := []*loadbalance.Target{targets[0]}
|
||||||
|
selected := p.selectTargetExcluding(ctx, excluded)
|
||||||
|
if selected == nil {
|
||||||
|
t.Fatal("selectTargetExcluding() returned nil")
|
||||||
|
}
|
||||||
|
if selected.URL == "http://backend1:8080" {
|
||||||
|
t.Error("selectTargetExcluding() should not select excluded target")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 排除所有目标,应该返回 nil
|
||||||
|
allExcluded := []*loadbalance.Target{targets[0], targets[1], targets[2]}
|
||||||
|
selected = p.selectTargetExcluding(ctx, allExcluded)
|
||||||
|
if selected != nil {
|
||||||
|
t.Error("selectTargetExcluding() should return nil when all excluded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 空目标列表
|
||||||
|
p2, _ := NewProxy(cfg, []*loadbalance.Target{{URL: "http://a:1"}}, nil, nil)
|
||||||
|
p2.targets = nil
|
||||||
|
selected = p2.selectTargetExcluding(ctx, nil)
|
||||||
|
if selected != nil {
|
||||||
|
t.Error("selectTargetExcluding() should return nil for empty targets")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectTargetExcluding_IPHash 测试 IP Hash 排除选择
|
||||||
|
func TestSelectTargetExcluding_IPHash(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "ip_hash",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://backend1:8080"},
|
||||||
|
{URL: "http://backend2:8080"},
|
||||||
|
{URL: "http://backend3:8080"},
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
target.Healthy.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{
|
||||||
|
"X-Forwarded-For": "192.168.1.1",
|
||||||
|
})
|
||||||
|
|
||||||
|
// 获取第一次选择
|
||||||
|
first := p.selectTarget(ctx)
|
||||||
|
if first == nil {
|
||||||
|
t.Fatal("selectTarget() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 排除第一次选择的目标
|
||||||
|
excluded := []*loadbalance.Target{first}
|
||||||
|
second := p.selectTargetExcluding(ctx, excluded)
|
||||||
|
if second == nil {
|
||||||
|
t.Fatal("selectTargetExcluding() returned nil")
|
||||||
|
}
|
||||||
|
if second.URL == first.URL {
|
||||||
|
t.Errorf("selectTargetExcluding() should not select same target %s", first.URL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectTargetExcluding_ConsistentHash 测试一致性哈希排除选择
|
||||||
|
func TestSelectTargetExcluding_ConsistentHash(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "consistent_hash",
|
||||||
|
HashKey: "uri",
|
||||||
|
VirtualNodes: 150,
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://backend1:8080"},
|
||||||
|
{URL: "http://backend2:8080"},
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
target.Healthy.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
// 排除一个目标,应该还能选到另一个
|
||||||
|
excluded := []*loadbalance.Target{targets[0]}
|
||||||
|
selected := p.selectTargetExcluding(ctx, excluded)
|
||||||
|
if selected == nil {
|
||||||
|
t.Error("selectTargetExcluding() should return remaining target")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractHashKey 测试哈希键提取
|
||||||
|
func TestExtractHashKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hashKey string
|
||||||
|
headers map[string]string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ip hash key",
|
||||||
|
hashKey: "ip",
|
||||||
|
headers: map[string]string{"X-Forwarded-For": "10.0.0.1"},
|
||||||
|
expected: "10.0.0.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty hash key defaults to ip",
|
||||||
|
hashKey: "",
|
||||||
|
headers: map[string]string{"X-Forwarded-For": "10.0.0.2"},
|
||||||
|
expected: "10.0.0.2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uri hash key",
|
||||||
|
hashKey: "uri",
|
||||||
|
headers: nil,
|
||||||
|
expected: "/api/test",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "header hash key - found",
|
||||||
|
hashKey: "header:X-Custom-ID",
|
||||||
|
headers: map[string]string{"X-Custom-ID": "abc123"},
|
||||||
|
expected: "abc123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "header hash key - fallback to ip",
|
||||||
|
hashKey: "header:X-Missing",
|
||||||
|
headers: map[string]string{"X-Forwarded-For": "10.0.0.3"},
|
||||||
|
expected: "10.0.0.3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown hash key defaults to ip",
|
||||||
|
hashKey: "unknown",
|
||||||
|
headers: map[string]string{"X-Forwarded-For": "10.0.0.4"},
|
||||||
|
expected: "10.0.0.4",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "consistent_hash",
|
||||||
|
HashKey: tt.hashKey,
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", tt.headers)
|
||||||
|
result := p.extractHashKey(ctx, tt.hashKey)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractHashKey() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildCacheKeyHash 测试缓存键哈希计算
|
||||||
|
func TestBuildCacheKeyHash(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||||||
|
if hashKey == 0 {
|
||||||
|
t.Error("buildCacheKeyHash() should return non-zero hash")
|
||||||
|
}
|
||||||
|
if origKey == "" {
|
||||||
|
t.Error("buildCacheKeyHash() should return non-empty origKey")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 相同请求应产生相同哈希
|
||||||
|
ctx2 := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
hashKey2, _ := p.buildCacheKeyHash(ctx2)
|
||||||
|
if hashKey != hashKey2 {
|
||||||
|
t.Error("Same request should produce same hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不同请求应产生不同哈希
|
||||||
|
ctx3 := testutil.NewRequestCtx("POST", "/api/other")
|
||||||
|
hashKey3, _ := p.buildCacheKeyHash(ctx3)
|
||||||
|
if hashKey == hashKey3 {
|
||||||
|
t.Error("Different request should produce different hash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildCacheKeyHashValue 测试零分配缓存键哈希
|
||||||
|
func TestBuildCacheKeyHashValue(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
hashValue := p.buildCacheKeyHashValue(ctx)
|
||||||
|
if hashValue == 0 {
|
||||||
|
t.Error("buildCacheKeyHashValue() should return non-zero hash")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应该与 buildCacheKeyHash 结果一致
|
||||||
|
hashKey, _ := p.buildCacheKeyHash(ctx)
|
||||||
|
if hashValue != hashKey {
|
||||||
|
t.Error("buildCacheKeyHashValue() should match buildCacheKeyHash()")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteCachedResponse 测试缓存响应写入
|
||||||
|
func TestWriteCachedResponse(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 手动创建一个 Response 用于验证 writeCachedResponse 写入正确
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
entry := &cache.ProxyCacheEntry{
|
||||||
|
Data: []byte("cached body"),
|
||||||
|
Headers: map[string]string{"Content-Type": "text/html", "X-Cached": "true"},
|
||||||
|
Status: 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
p.writeCachedResponse(ctx, entry)
|
||||||
|
|
||||||
|
if ctx.Response.StatusCode() != 200 {
|
||||||
|
t.Errorf("writeCachedResponse() status = %d, want 200", ctx.Response.StatusCode())
|
||||||
|
}
|
||||||
|
if string(ctx.Response.Body()) != "cached body" {
|
||||||
|
t.Errorf("writeCachedResponse() body = %q, want %q", string(ctx.Response.Body()), "cached body")
|
||||||
|
}
|
||||||
|
ct := string(ctx.Response.Header.Peek("Content-Type"))
|
||||||
|
if ct != "text/html" {
|
||||||
|
t.Errorf("writeCachedResponse() Content-Type = %q, want %q", ct, "text/html")
|
||||||
|
}
|
||||||
|
xc := string(ctx.Response.Header.Peek("X-Cache"))
|
||||||
|
if xc != "HIT" {
|
||||||
|
t.Errorf("writeCachedResponse() X-Cache = %q, want HIT", xc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCache 测试 GetCache 方法
|
||||||
|
func TestGetCache(t *testing.T) {
|
||||||
|
// 启用缓存时
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
Cache: config.ProxyCacheConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxAge: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c := p.GetCache()
|
||||||
|
if c == nil {
|
||||||
|
t.Error("GetCache() should return non-nil when cache enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用缓存时
|
||||||
|
cfg2 := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
p2, _ := NewProxy(cfg2, targets, nil, nil)
|
||||||
|
c2 := p2.GetCache()
|
||||||
|
if c2 != nil {
|
||||||
|
t.Error("GetCache() should return nil when cache disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCacheStats 测试 GetCacheStats 方法
|
||||||
|
func TestGetCacheStats(t *testing.T) {
|
||||||
|
// 启用缓存时
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
Cache: config.ProxyCacheConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxAge: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := p.GetCacheStats()
|
||||||
|
if stats == nil {
|
||||||
|
t.Error("GetCacheStats() should return non-nil when cache enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 禁用缓存时
|
||||||
|
cfg2 := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
}
|
||||||
|
p2, _ := NewProxy(cfg2, targets, nil, nil)
|
||||||
|
stats2 := p2.GetCacheStats()
|
||||||
|
if stats2 != nil {
|
||||||
|
t.Error("GetCacheStats() should return nil when cache disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCacheDuration 测试不同状态码的缓存时间计算
|
||||||
|
func TestGetCacheDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cacheValid *config.ProxyCacheValidConfig
|
||||||
|
maxAge time.Duration
|
||||||
|
statusCode int
|
||||||
|
expected time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no CacheValid config uses MaxAge",
|
||||||
|
maxAge: 5 * time.Minute,
|
||||||
|
statusCode: 200,
|
||||||
|
expected: 5 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2xx with CacheValid.OK set",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
OK: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
statusCode: 200,
|
||||||
|
expected: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2xx with CacheValid.OK=0 inherits MaxAge",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
OK: 0,
|
||||||
|
},
|
||||||
|
maxAge: 3 * time.Minute,
|
||||||
|
statusCode: 201,
|
||||||
|
expected: 3 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "301 redirect",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
Redirect: 1 * time.Hour,
|
||||||
|
},
|
||||||
|
statusCode: 301,
|
||||||
|
expected: 1 * time.Hour,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "302 redirect",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
Redirect: 30 * time.Minute,
|
||||||
|
},
|
||||||
|
statusCode: 302,
|
||||||
|
expected: 30 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "302 with zero Redirect means no cache",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
Redirect: 0,
|
||||||
|
},
|
||||||
|
statusCode: 302,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "404",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
NotFound: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
statusCode: 404,
|
||||||
|
expected: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "4xx client error",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
ClientError: 30 * time.Second,
|
||||||
|
},
|
||||||
|
statusCode: 400,
|
||||||
|
expected: 30 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "5xx server error",
|
||||||
|
cacheValid: &config.ProxyCacheValidConfig{
|
||||||
|
ServerError: 0,
|
||||||
|
},
|
||||||
|
statusCode: 500,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other status code",
|
||||||
|
statusCode: 100,
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
Cache: config.ProxyCacheConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxAge: tt.maxAge,
|
||||||
|
},
|
||||||
|
CacheValid: tt.cacheValid,
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := p.getCacheDuration(tt.statusCode)
|
||||||
|
if duration != tt.expected {
|
||||||
|
t.Errorf("getCacheDuration(%d) = %v, want %v", tt.statusCode, duration, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBackgroundRefresh 测试后台缓存刷新(标记为 skip 因为需要真实网络)
|
||||||
|
func TestBackgroundRefresh(t *testing.T) {
|
||||||
|
t.Skip("skipping: requires real network connection and is timing-sensitive")
|
||||||
|
ln := fasthttputil.NewInmemoryListener()
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s := &fasthttp.Server{
|
||||||
|
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetStatusCode(200)
|
||||||
|
ctx.SetBodyString("refreshed")
|
||||||
|
ctx.Response.Header.Set("Content-Type", "text/plain")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = s.Serve(ln)
|
||||||
|
}()
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
addr := ln.Addr().String()
|
||||||
|
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
Cache: config.ProxyCacheConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxAge: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://" + addr},
|
||||||
|
}
|
||||||
|
targets[0].Healthy.Store(true)
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
|
||||||
|
// 设置缓存锁
|
||||||
|
hashKey := uint64(12345)
|
||||||
|
p.cache.AcquireLock(hashKey)
|
||||||
|
|
||||||
|
// 调用后台刷新(它会执行实际请求来刷新缓存)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test")
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等待完成
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 完成
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("backgroundRefresh() timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBackgroundRefresh_NoClient 测试后台刷新时客户端不存在的情况
|
||||||
|
func TestBackgroundRefresh_NoClient(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
Cache: config.ProxyCacheConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxAge: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://nonexistent:9999"},
|
||||||
|
}
|
||||||
|
targets[0].Healthy.Store(true)
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除客户端
|
||||||
|
p.mu.Lock()
|
||||||
|
delete(p.clients, targets[0].URL)
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
hashKey := uint64(99999)
|
||||||
|
p.cache.AcquireLock(hashKey)
|
||||||
|
|
||||||
|
// 应该不会 panic,直接返回
|
||||||
|
p.backgroundRefresh(ctx, targets[0], hashKey, "/api/test")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServeHTTP_CacheHit 测试缓存命中路径
|
||||||
|
func TestServeHTTP_CacheHit(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
Cache: config.ProxyCacheConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxAge: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
targets[0].Healthy.Store(true)
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 预填充缓存
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/cached")
|
||||||
|
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||||||
|
p.cache.Set(hashKey, origKey, []byte("cached!"), map[string]string{
|
||||||
|
"Content-Type": "text/plain",
|
||||||
|
}, 200, 10*time.Second)
|
||||||
|
|
||||||
|
// 执行请求
|
||||||
|
p.ServeHTTP(ctx)
|
||||||
|
|
||||||
|
// 应该返回缓存的响应
|
||||||
|
if ctx.Response.StatusCode() != 200 {
|
||||||
|
t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode())
|
||||||
|
}
|
||||||
|
if string(ctx.Response.Body()) != "cached!" {
|
||||||
|
t.Errorf("ServeHTTP() body = %q, want %q", string(ctx.Response.Body()), "cached!")
|
||||||
|
}
|
||||||
|
xc := string(ctx.Response.Header.Peek("X-Cache"))
|
||||||
|
if xc != "HIT" {
|
||||||
|
t.Errorf("ServeHTTP() X-Cache = %q, want HIT", xc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServeHTTP_ClientNil 测试客户端为 nil 时的行为
|
||||||
|
func TestServeHTTP_ClientNil(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
NextUpstream: config.NextUpstreamConfig{
|
||||||
|
Tries: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{
|
||||||
|
{URL: "http://backend1:8080"},
|
||||||
|
{URL: "http://backend2:8080"},
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
target.Healthy.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除所有客户端
|
||||||
|
p.mu.Lock()
|
||||||
|
p.clients = make(map[string]*fasthttp.HostClient)
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
p.ServeHTTP(ctx)
|
||||||
|
|
||||||
|
// 所有客户端都不存在,应该返回 502
|
||||||
|
if ctx.Response.StatusCode() != fasthttp.StatusBadGateway {
|
||||||
|
t.Errorf("ServeHTTP() status = %d, want 502", ctx.Response.StatusCode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServeHTTP_WithRedirectRewrite 测试带 redirect_rewrite 的缓存命中
|
||||||
|
func TestServeHTTP_WithRedirectRewrite_CacheHit(t *testing.T) {
|
||||||
|
cfg := &config.ProxyConfig{
|
||||||
|
Path: "/api",
|
||||||
|
LoadBalance: "round_robin",
|
||||||
|
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||||
|
Cache: config.ProxyCacheConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxAge: 10 * time.Second,
|
||||||
|
},
|
||||||
|
RedirectRewrite: &config.RedirectRewriteConfig{
|
||||||
|
Mode: "off", // 关闭改写
|
||||||
|
},
|
||||||
|
}
|
||||||
|
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||||
|
targets[0].Healthy.Store(true)
|
||||||
|
|
||||||
|
p, err := NewProxy(cfg, targets, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewProxy() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := testutil.NewRequestCtx("GET", "/api/test")
|
||||||
|
hashKey, origKey := p.buildCacheKeyHash(ctx)
|
||||||
|
p.cache.Set(hashKey, origKey, []byte("ok"), map[string]string{
|
||||||
|
"Content-Type": "text/plain",
|
||||||
|
}, 200, 10*time.Second)
|
||||||
|
|
||||||
|
p.ServeHTTP(ctx)
|
||||||
|
|
||||||
|
if ctx.Response.StatusCode() != 200 {
|
||||||
|
t.Errorf("ServeHTTP() status = %d, want 200", ctx.Response.StatusCode())
|
||||||
|
}
|
||||||
|
}
|
||||||
465
internal/server/server_extended_test.go
Normal file
465
internal/server/server_extended_test.go
Normal file
@ -0,0 +1,465 @@
|
|||||||
|
// Package server 提供 HTTP 服务器的核心实现测试补充。
|
||||||
|
//
|
||||||
|
// 该文件补充以下测试覆盖:
|
||||||
|
// - 多模式启动逻辑测试(single/vhost/multi_server/auto)
|
||||||
|
// - 多服务器模式 shutdownServers 函数测试
|
||||||
|
// - 监听器创建测试(TCP/Unix socket)
|
||||||
|
// - StopWithTimeout 超时行为测试
|
||||||
|
// - GracefulStop 超时行为测试
|
||||||
|
// - 中间件链错误路径测试
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
"rua.plus/lolly/internal/config"
|
||||||
|
"rua.plus/lolly/internal/resolver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestServer_GetMode_Single 测试单服务器模式
|
||||||
|
func TestServer_GetMode_Single(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Mode: config.ServerModeSingle,
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s.config.GetMode() != config.ServerModeSingle {
|
||||||
|
t.Errorf("Expected mode single, got %s", s.config.GetMode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_GetMode_VHost 测试虚拟主机模式
|
||||||
|
func TestServer_GetMode_VHost(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Mode: config.ServerModeVHost,
|
||||||
|
Servers: []config.ServerConfig{
|
||||||
|
{Listen: ":0", Name: "host1.example.com"},
|
||||||
|
{Listen: ":0", Name: "host2.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s.config.GetMode() != config.ServerModeVHost {
|
||||||
|
t.Errorf("Expected mode vhost, got %s", s.config.GetMode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_GetMode_MultiServer 测试多服务器模式
|
||||||
|
func TestServer_GetMode_MultiServer(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Mode: config.ServerModeMultiServer,
|
||||||
|
Servers: []config.ServerConfig{
|
||||||
|
{Listen: ":8080", Name: "server1"},
|
||||||
|
{Listen: ":8081", Name: "server2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s.config.GetMode() != config.ServerModeMultiServer {
|
||||||
|
t.Errorf("Expected mode multi_server, got %s", s.config.GetMode())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_GetMode_Auto 测试自动模式
|
||||||
|
func TestServer_GetMode_Auto(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Mode: config.ServerModeAuto,
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
mode := s.config.GetMode()
|
||||||
|
if mode != config.ServerModeSingle {
|
||||||
|
t.Errorf("Expected auto to resolve to single, got %s", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestShutdownServers_Empty 测试空服务器列表关闭
|
||||||
|
func TestShutdownServers_Empty(t *testing.T) {
|
||||||
|
err := shutdownServers(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected nil error for empty servers, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestShutdownServers_NilServer 测试含 nil 的服务器列表关闭
|
||||||
|
func TestShutdownServers_NilServer(t *testing.T) {
|
||||||
|
servers := []*fasthttp.Server{nil, nil}
|
||||||
|
err := shutdownServers(nil, servers)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected nil error with nil servers, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestShutdownServers_Timeout 测试关闭超时
|
||||||
|
func TestShutdownServers_Timeout(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fastSrv := &fasthttp.Server{
|
||||||
|
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||||
|
select {}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = fastSrv.Serve(ln)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err = shutdownServers(ctx, []*fasthttp.Server{fastSrv})
|
||||||
|
// context 超时后 shutdownServers 会返回 ctx.Err()
|
||||||
|
if err != nil && err != context.DeadlineExceeded {
|
||||||
|
t.Errorf("Expected context.DeadlineExceeded or nil, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = fastSrv.Shutdown()
|
||||||
|
_ = ln.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStopWithTimeout_DefaultTimeout 测试零超时使用默认值
|
||||||
|
func TestStopWithTimeout_DefaultTimeout(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
err := s.StopWithTimeout(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("StopWithTimeout(0) should succeed, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.StopWithTimeout(-1 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("StopWithTimeout(-1s) should succeed, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStopWithTimeout_MultiServerMode 测试多服务器模式停止
|
||||||
|
func TestStopWithTimeout_MultiServerMode(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Mode: config.ServerModeMultiServer,
|
||||||
|
Servers: []config.ServerConfig{
|
||||||
|
{Listen: ":0", Name: "server1"},
|
||||||
|
{Listen: ":0", Name: "server2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
err := s.StopWithTimeout(1 * time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("StopWithTimeout on non-started multi-server should succeed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGracefulStop_Timeout 测试优雅停止超时
|
||||||
|
func TestGracefulStop_Timeout(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
s.fastServer = &fasthttp.Server{
|
||||||
|
Handler: func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetBodyString("ok")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s.running = true
|
||||||
|
|
||||||
|
err := s.GracefulStop(100 * time.Millisecond)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("GracefulStop should succeed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_SetUpgradeManager 测试设置升级管理器
|
||||||
|
func TestServer_SetUpgradeManager(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
mgr := NewUpgradeManager(s)
|
||||||
|
|
||||||
|
s.SetUpgradeManager(mgr)
|
||||||
|
if s.upgradeManager != mgr {
|
||||||
|
t.Error("upgradeManager not set correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockResolver 用于测试的 mock DNS 解析器
|
||||||
|
type mockResolver struct{}
|
||||||
|
|
||||||
|
func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
|
||||||
|
return []string{"127.0.0.1"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) {
|
||||||
|
return []string{"127.0.0.1"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResolver) Refresh(host string) error { return nil }
|
||||||
|
func (m *mockResolver) Start() error { return nil }
|
||||||
|
func (m *mockResolver) Stop() error { return nil }
|
||||||
|
func (m *mockResolver) Stats() resolver.Stats { return resolver.Stats{} }
|
||||||
|
|
||||||
|
// TestServer_Resolver 测试 DNS 解析器设置
|
||||||
|
func TestServer_Resolver(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
if s.GetResolver() != nil {
|
||||||
|
t.Error("Expected nil resolver initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRes := &mockResolver{}
|
||||||
|
s.SetResolver(mockRes)
|
||||||
|
|
||||||
|
if s.GetResolver() == nil {
|
||||||
|
t.Error("Resolver not set correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCreateListener_TCP 测试 TCP 监听器创建
|
||||||
|
func TestCreateListener_TCP(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
serverCfg := &cfg.Servers[0]
|
||||||
|
|
||||||
|
ln, err := s.createListener(serverCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("createListener failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
|
if ln == nil {
|
||||||
|
t.Fatal("Expected non-nil listener")
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := ln.Addr().(*net.TCPAddr)
|
||||||
|
if addr.Port == 0 {
|
||||||
|
t.Error("Expected non-zero port")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCreateListener_InvalidTCP 测试无效 TCP 监听地址
|
||||||
|
func TestCreateListener_InvalidTCP(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: "invalid:address:format",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
_, err := s.createListener(&cfg.Servers[0])
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid listen address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestListenerManagement 测试监听器管理
|
||||||
|
func TestListenerManagement(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
ln1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener 1: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = ln1.Close() }()
|
||||||
|
|
||||||
|
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener 2: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = ln2.Close() }()
|
||||||
|
|
||||||
|
s.SetListeners([]net.Listener{ln1, ln2})
|
||||||
|
|
||||||
|
listeners := s.GetListeners()
|
||||||
|
if len(listeners) != 2 {
|
||||||
|
t.Errorf("Expected 2 listeners, got %d", len(listeners))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStart_WithGoroutinePoolAndFileCache 测试同时启用 GoroutinePool 和 FileCache
|
||||||
|
func TestStart_WithGoroutinePoolAndFileCache(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
Performance: config.PerformanceConfig{
|
||||||
|
GoroutinePool: config.GoroutinePoolConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxWorkers: 50,
|
||||||
|
MinWorkers: 5,
|
||||||
|
},
|
||||||
|
FileCache: config.FileCacheConfig{
|
||||||
|
MaxEntries: 500,
|
||||||
|
MaxSize: 50 * 1024 * 1024,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("New() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.config.Performance.GoroutinePool.Enabled != true {
|
||||||
|
t.Error("GoroutinePool should be enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_GetHandler_NilThenSet 测试 handler 的 nil 到设置
|
||||||
|
func TestServer_GetHandler_NilThenSet(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
if s.GetHandler() != nil {
|
||||||
|
t.Error("Expected nil handler initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
testHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetBodyString("test handler response")
|
||||||
|
}
|
||||||
|
s.handler = testHandler
|
||||||
|
|
||||||
|
got := s.GetHandler()
|
||||||
|
if got == nil {
|
||||||
|
t.Error("Expected non-nil handler after setting")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
got(ctx)
|
||||||
|
if string(ctx.Response.Body()) != "test handler response" {
|
||||||
|
t.Errorf("Handler response = %q, want %q", string(ctx.Response.Body()), "test handler response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_TrackStats_Concurrent 测试并发统计追踪
|
||||||
|
func TestServer_TrackStats_Concurrent(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
|
||||||
|
handler := func(ctx *fasthttp.RequestCtx) {
|
||||||
|
ctx.SetBodyString("ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedHandler := s.trackStats(handler)
|
||||||
|
|
||||||
|
const numGoroutines = 100
|
||||||
|
done := make(chan bool, numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func() {
|
||||||
|
ctx := &fasthttp.RequestCtx{}
|
||||||
|
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||||
|
wrappedHandler(ctx)
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.requests.Load() != int64(numGoroutines) {
|
||||||
|
t.Errorf("Expected %d requests, got %d", numGoroutines, s.requests.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildMiddlewareChain_BodyLimit 测试请求体限制中间件
|
||||||
|
func TestBuildMiddlewareChain_BodyLimit(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Logging: config.LoggingConfig{},
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
Proxy: []config.ProxyConfig{{
|
||||||
|
Path: "/api/",
|
||||||
|
ClientMaxBodySize: "1MB",
|
||||||
|
}},
|
||||||
|
ClientMaxBodySize: "10MB",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
chain, err := s.buildMiddlewareChain(&cfg.Servers[0])
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("buildMiddlewareChain failed: %v", err)
|
||||||
|
}
|
||||||
|
if chain == nil {
|
||||||
|
t.Error("Expected non-nil chain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBuildMiddlewareChain_BodyLimit_Invalid 测试无效的请求体限制
|
||||||
|
func TestBuildMiddlewareChain_BodyLimit_Invalid(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
Logging: config.LoggingConfig{},
|
||||||
|
Servers: []config.ServerConfig{{
|
||||||
|
Listen: ":0",
|
||||||
|
ClientMaxBodySize: "invalid_size",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := New(cfg)
|
||||||
|
_, err := s.buildMiddlewareChain(&cfg.Servers[0])
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error for invalid body limit size")
|
||||||
|
}
|
||||||
|
}
|
||||||
554
internal/stream/stream_coverage_test.go
Normal file
554
internal/stream/stream_coverage_test.go
Normal file
@ -0,0 +1,554 @@
|
|||||||
|
// Package stream 提供流代理覆盖测试。
|
||||||
|
//
|
||||||
|
// 该文件补充测试 stream.go 中未覆盖的方法:
|
||||||
|
// - ipHash.Select() (空 IP)
|
||||||
|
// - handleConnection() 连接处理
|
||||||
|
// - getOrCreateSession() 会话创建
|
||||||
|
// - handleBackendResponse() 后端响应处理
|
||||||
|
// - Stats 完整统计
|
||||||
|
//
|
||||||
|
// 作者:xfy
|
||||||
|
package stream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestIPHashSelect 测试 ipHash 的 Select 方法(空字符串 IP)
|
||||||
|
func TestIPHashSelect(t *testing.T) {
|
||||||
|
targets := []*Target{
|
||||||
|
{addr: "localhost:8001"},
|
||||||
|
{addr: "localhost:8002"},
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
target.healthy.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
ih := newIPHash()
|
||||||
|
|
||||||
|
// Select() 使用空字符串作为 IP
|
||||||
|
selected := ih.Select(targets)
|
||||||
|
if selected == nil {
|
||||||
|
t.Error("Select() with empty IP should return a target")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 多次调用应返回相同目标(确定性哈希)
|
||||||
|
selected2 := ih.Select(targets)
|
||||||
|
if selected != selected2 {
|
||||||
|
t.Error("Select() with same empty IP should be consistent")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 无健康目标时应返回 nil
|
||||||
|
for _, target := range targets {
|
||||||
|
target.healthy.Store(false)
|
||||||
|
}
|
||||||
|
selected = ih.Select(targets)
|
||||||
|
if selected != nil {
|
||||||
|
t.Error("Select() with no healthy targets should return nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSelectByIPNoHealthy 测试 SelectByIP 无健康目标
|
||||||
|
func TestSelectByIPNoHealthy(t *testing.T) {
|
||||||
|
targets := []*Target{
|
||||||
|
{addr: "localhost:8001"},
|
||||||
|
{addr: "localhost:8002"},
|
||||||
|
}
|
||||||
|
|
||||||
|
ih := newIPHash()
|
||||||
|
selected := ih.(*ipHash).SelectByIP(targets, "192.168.1.1")
|
||||||
|
if selected != nil {
|
||||||
|
t.Error("SelectByIP() with no healthy targets should return nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWeightedRoundRobinZeroWeight 测试零权重处理
|
||||||
|
func TestWeightedRoundRobinZeroWeight(t *testing.T) {
|
||||||
|
targets := []*Target{
|
||||||
|
{addr: "localhost:8001", weight: 0},
|
||||||
|
{addr: "localhost:8002", weight: -1},
|
||||||
|
}
|
||||||
|
for _, target := range targets {
|
||||||
|
target.healthy.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
wrr := newWeightedRoundRobin().(*weightedRoundRobin)
|
||||||
|
|
||||||
|
// 权重为 0 或负数应视为权重 1
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
selected := wrr.Select(targets)
|
||||||
|
if selected == nil {
|
||||||
|
t.Error("Select() should return target with zero/negative weight")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleConnection 测试 handleConnection 方法
|
||||||
|
func TestHandleConnection(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
// 添加上游配置
|
||||||
|
targets := []TargetSpec{
|
||||||
|
{Addr: "127.0.0.1:29001", Weight: 1},
|
||||||
|
}
|
||||||
|
_ = s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
|
||||||
|
s.upstreams["test"].targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
// 创建模拟客户端连接(不会实际建立连接,测试无上游路径)
|
||||||
|
s.mu.Lock()
|
||||||
|
// 设置上游为空,测试无上游配置路径
|
||||||
|
s.upstreams = make(map[string]*Upstream)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// 创建一对连接
|
||||||
|
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = serverLn.Close() }()
|
||||||
|
|
||||||
|
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to dial: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = clientConn.Close() }()
|
||||||
|
|
||||||
|
serverConn, err := serverLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to accept: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = serverConn.Close() }()
|
||||||
|
|
||||||
|
// 测试无上游配置的 handleConnection
|
||||||
|
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleConnection_NoHealthyTarget 测试无健康目标路径
|
||||||
|
func TestHandleConnection_NoHealthyTarget(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
// 添加不健康的上游
|
||||||
|
targets := []TargetSpec{
|
||||||
|
{Addr: "127.0.0.1:29002", Weight: 1},
|
||||||
|
}
|
||||||
|
_ = s.AddUpstream("test2", targets, "round_robin", HealthCheckSpec{})
|
||||||
|
// 目标不健康(默认 false)
|
||||||
|
|
||||||
|
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = serverLn.Close() }()
|
||||||
|
|
||||||
|
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to dial: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = clientConn.Close() }()
|
||||||
|
|
||||||
|
serverConn, err := serverLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to accept: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = serverConn.Close() }()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 完成
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("handleConnection() timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleConnection_DialFail 测试连接目标失败路径
|
||||||
|
func TestHandleConnection_DialFail(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
// 添加上游,目标不可达
|
||||||
|
targets := []TargetSpec{
|
||||||
|
{Addr: "127.0.0.1:29999", Weight: 1},
|
||||||
|
}
|
||||||
|
_ = s.AddUpstream("test3", targets, "round_robin", HealthCheckSpec{})
|
||||||
|
s.upstreams["test3"].targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
serverLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = serverLn.Close() }()
|
||||||
|
|
||||||
|
clientConn, err := net.Dial("tcp", serverLn.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to dial: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = clientConn.Close() }()
|
||||||
|
|
||||||
|
serverConn, err := serverLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to accept: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = serverConn.Close() }()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 完成 - 连接目标失败后应标记为不健康
|
||||||
|
if s.upstreams["test3"].targets[0].healthy.Load() {
|
||||||
|
t.Error("Target should be marked unhealthy after dial failure")
|
||||||
|
}
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
t.Fatal("handleConnection() timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetOrCreateSession 测试 getOrCreateSession 方法
|
||||||
|
func TestGetOrCreateSession(t *testing.T) {
|
||||||
|
// 创建 UDP 连接
|
||||||
|
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
// 创建上游
|
||||||
|
upstream := &Upstream{
|
||||||
|
targets: []*Target{{addr: "127.0.0.1:29003"}},
|
||||||
|
balancer: newRoundRobin(),
|
||||||
|
}
|
||||||
|
upstream.targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||||
|
|
||||||
|
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29010")
|
||||||
|
|
||||||
|
// 第一次调用 - 应该创建新会话(但由于后端不可达,应该失败)
|
||||||
|
session, err := srv.getOrCreateSession(clientAddr)
|
||||||
|
if err != nil {
|
||||||
|
// 预期失败,因为后端不可达
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
t.Error("getOrCreateSession() should return a session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetOrCreateSession_DoubleCheck 测试双重检查锁定
|
||||||
|
func TestGetOrCreateSession_DoubleCheck(t *testing.T) {
|
||||||
|
// 创建 UDP 连接
|
||||||
|
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
// 创建上游
|
||||||
|
upstream := &Upstream{
|
||||||
|
targets: []*Target{{addr: "127.0.0.1:29004"}},
|
||||||
|
balancer: newRoundRobin(),
|
||||||
|
}
|
||||||
|
upstream.targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||||
|
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29011")
|
||||||
|
|
||||||
|
// 手动创建一个会话来测试双重检查
|
||||||
|
srv.mu.Lock()
|
||||||
|
testSession := &udpSession{
|
||||||
|
clientAddr: clientAddr,
|
||||||
|
lastActive: time.Now(),
|
||||||
|
srv: srv,
|
||||||
|
}
|
||||||
|
srv.sessions[sessionKey(clientAddr)] = testSession
|
||||||
|
srv.mu.Unlock()
|
||||||
|
|
||||||
|
// 再次获取应该返回现有会话
|
||||||
|
session, err := srv.getOrCreateSession(clientAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("getOrCreateSession() should not error for existing session: %v", err)
|
||||||
|
}
|
||||||
|
if session != testSession {
|
||||||
|
t.Error("getOrCreateSession() should return existing session")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetOrCreateSession_NoHealthyTarget 测试无健康目标
|
||||||
|
func TestGetOrCreateSession_NoHealthyTarget(t *testing.T) {
|
||||||
|
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
upstream := &Upstream{
|
||||||
|
targets: []*Target{{addr: "127.0.0.1:29005"}},
|
||||||
|
balancer: newRoundRobin(),
|
||||||
|
}
|
||||||
|
// 不设置 healthy,默认为 false
|
||||||
|
|
||||||
|
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||||
|
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29012")
|
||||||
|
|
||||||
|
session, err := srv.getOrCreateSession(clientAddr)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("getOrCreateSession() should return error when no healthy target")
|
||||||
|
}
|
||||||
|
if session != nil {
|
||||||
|
t.Error("getOrCreateSession() should return nil session on error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetOrCreateSession_InvalidTargetAddr 测试无效目标地址
|
||||||
|
func TestGetOrCreateSession_InvalidTargetAddr(t *testing.T) {
|
||||||
|
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
upstream := &Upstream{
|
||||||
|
targets: []*Target{{addr: "invalid-address"}},
|
||||||
|
balancer: newRoundRobin(),
|
||||||
|
}
|
||||||
|
upstream.targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||||||
|
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29013")
|
||||||
|
|
||||||
|
session, err := srv.getOrCreateSession(clientAddr)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("getOrCreateSession() should return error for invalid target address")
|
||||||
|
}
|
||||||
|
if session != nil {
|
||||||
|
t.Error("getOrCreateSession() should return nil session on error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleBackendResponse 测试 handleBackendResponse 超时清理路径
|
||||||
|
func TestHandleBackendResponse(t *testing.T) {
|
||||||
|
// 创建 UDP 连接(服务端)
|
||||||
|
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
// 创建上游
|
||||||
|
upstream := &Upstream{
|
||||||
|
targets: []*Target{{addr: "127.0.0.1:29006"}},
|
||||||
|
balancer: newRoundRobin(),
|
||||||
|
}
|
||||||
|
upstream.targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
srv := newUDPServer(conn, upstream, 50*time.Millisecond) // 短超时
|
||||||
|
|
||||||
|
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29014")
|
||||||
|
|
||||||
|
// 创建目标连接(监听器用于接收目标连接)
|
||||||
|
targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29006")
|
||||||
|
targetConn, _ := net.ListenUDP("udp", targetAddr)
|
||||||
|
defer func() { _ = targetConn.Close() }()
|
||||||
|
|
||||||
|
// 创建会话时需要先添加到 WaitGroup(handleBackendResponse 会调用 Done)
|
||||||
|
srv.wg.Add(1)
|
||||||
|
session := &udpSession{
|
||||||
|
clientAddr: clientAddr,
|
||||||
|
targetConn: targetConn,
|
||||||
|
lastActive: time.Now().Add(-2 * time.Hour), // 很久以前
|
||||||
|
srv: srv,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加会话到服务器
|
||||||
|
srv.sessions[sessionKey(clientAddr)] = session
|
||||||
|
|
||||||
|
// 启动后端响应处理
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
session.handleBackendResponse()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 应该因为超时而清理会话
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("handleBackendResponse() timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleBackendResponse_ErrorPath 测试 handleBackendResponse 错误路径
|
||||||
|
func TestHandleBackendResponse_ErrorPath(t *testing.T) {
|
||||||
|
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
upstream := &Upstream{
|
||||||
|
targets: []*Target{{addr: "127.0.0.1:29007"}},
|
||||||
|
balancer: newRoundRobin(),
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := newUDPServer(conn, upstream, 10*time.Millisecond)
|
||||||
|
|
||||||
|
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29015")
|
||||||
|
|
||||||
|
// 创建一个已关闭的连接作为 targetConn
|
||||||
|
targetUDPAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:29007")
|
||||||
|
targetConn, err := net.DialUDP("udp", nil, targetUDPAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create target connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
session := &udpSession{
|
||||||
|
clientAddr: clientAddr,
|
||||||
|
targetConn: targetConn,
|
||||||
|
lastActive: time.Now(),
|
||||||
|
srv: srv,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建会话时需要先添加到 WaitGroup(handleBackendResponse 会调用 Done)
|
||||||
|
srv.wg.Add(1)
|
||||||
|
srv.sessions[sessionKey(clientAddr)] = session
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
session.handleBackendResponse()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 完成
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("handleBackendResponse() timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServe_InvalidUpstream 测试 serve 方法无效上游路径
|
||||||
|
func TestServe_InvalidUpstream(t *testing.T) {
|
||||||
|
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||||
|
conn, _ := net.ListenUDP("udp", udpAddr)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
upstream := &Upstream{
|
||||||
|
targets: []*Target{},
|
||||||
|
balancer: newRoundRobin(),
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := newUDPServer(conn, upstream, 50*time.Millisecond)
|
||||||
|
|
||||||
|
// 启动 serve
|
||||||
|
go srv.serve()
|
||||||
|
|
||||||
|
// 立即停止
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
srv.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServerStop 测试 Server.Stop 方法
|
||||||
|
func TestServerStop(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
// 添加上游
|
||||||
|
targets := []TargetSpec{
|
||||||
|
{Addr: "127.0.0.1:29008", Weight: 1},
|
||||||
|
}
|
||||||
|
hcSpec := HealthCheckSpec{
|
||||||
|
Enabled: true,
|
||||||
|
Interval: 1 * time.Second,
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
}
|
||||||
|
_ = s.AddUpstream("stop_test", targets, "round_robin", hcSpec)
|
||||||
|
|
||||||
|
// 监听 TCP
|
||||||
|
err := s.ListenTCP("127.0.0.1:29009")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenTCP failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 监听 UDP
|
||||||
|
err = s.ListenUDP("127.0.0.1:29010", "stop_test", 1*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动
|
||||||
|
err = s.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Start failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// 停止
|
||||||
|
err = s.Stop()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Stop() error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStatsComplete 测试 Stats 完整统计
|
||||||
|
func TestStatsComplete(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
// 添加 TCP 监听
|
||||||
|
err := s.ListenTCP("127.0.0.1:29020")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenTCP failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加上游
|
||||||
|
targets := []TargetSpec{{Addr: "127.0.0.1:29021", Weight: 1}}
|
||||||
|
_ = s.AddUpstream("stats_test", targets, "round_robin", HealthCheckSpec{})
|
||||||
|
|
||||||
|
// 添加 UDP 监听
|
||||||
|
err = s.ListenUDP("127.0.0.1:29022", "stats_test", 1*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenUDP failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := s.Stats()
|
||||||
|
if stats.Listeners != 2 {
|
||||||
|
t.Errorf("Stats().Listeners = %d, want 2 (1 TCP + 1 UDP)", stats.Listeners)
|
||||||
|
}
|
||||||
|
if stats.Upstreams != 1 {
|
||||||
|
t.Errorf("Stats().Upstreams = %d, want 1", stats.Upstreams)
|
||||||
|
}
|
||||||
|
if stats.Connections != 0 {
|
||||||
|
t.Errorf("Stats().Connections = %d, want 0", stats.Connections)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAcceptLoop_Error 测试 acceptLoop 错误处理路径
|
||||||
|
func TestAcceptLoop_Error(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
s.running.Store(true)
|
||||||
|
|
||||||
|
// 创建一个立即关闭的监听器
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create listener: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 在另一个 goroutine 中关闭监听器
|
||||||
|
go func() {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
_ = ln.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.acceptLoop("test", ln)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 完成
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
s.running.Store(false)
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user