547 lines
12 KiB
Go
547 lines
12 KiB
Go
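// Package stream provides tests for the TCP/UDP stream proxy functionality.
//
// This file tests the features of the stream proxy module, including:
//   - server creation and initialization
//   - upstream configuration and load balancing
//   - TCP and UDP listening
//   - health checks
//   - connection statistics
//
// Author: xfy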
package stream
|
||
|
||
import (
|
||
"net"
|
||
"sync"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
func TestNewServer(t *testing.T) {
|
||
s := NewServer()
|
||
if s == nil {
|
||
t.Fatal("Expected non-nil server")
|
||
}
|
||
if s.listeners == nil {
|
||
t.Error("Expected initialized listeners map")
|
||
}
|
||
if s.upstreams == nil {
|
||
t.Error("Expected initialized upstreams map")
|
||
}
|
||
}
|
||
|
||
func TestAddUpstream(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
targets := []TargetSpec{
|
||
{Addr: "localhost:8001", Weight: 1},
|
||
{Addr: "localhost:8002", Weight: 2},
|
||
}
|
||
|
||
hcSpec := HealthCheckSpec{
|
||
Enabled: false,
|
||
Interval: 10 * time.Second,
|
||
Timeout: 5 * time.Second,
|
||
}
|
||
|
||
err := s.AddUpstream("test", targets, "round_robin", hcSpec)
|
||
if err != nil {
|
||
t.Errorf("AddUpstream failed: %v", err)
|
||
}
|
||
|
||
if len(s.upstreams) != 1 {
|
||
t.Errorf("Expected 1 upstream, got %d", len(s.upstreams))
|
||
}
|
||
|
||
up := s.upstreams["test"]
|
||
if up == nil {
|
||
t.Fatal("Expected non-nil upstream")
|
||
}
|
||
if len(up.targets) != 2 {
|
||
t.Errorf("Expected 2 targets, got %d", len(up.targets))
|
||
}
|
||
}
|
||
|
||
func TestRoundRobinBalancer(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
{addr: "localhost:8003"},
|
||
}
|
||
for _, target := range targets {
|
||
target.healthy.Store(true)
|
||
}
|
||
|
||
rr := newRoundRobin()
|
||
|
||
// 测试轮询
|
||
results := make(map[string]int)
|
||
for range 6 {
|
||
selected := rr.Select(targets)
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
continue
|
||
}
|
||
results[selected.addr]++
|
||
}
|
||
|
||
// 每个目标应该被选中 2 次
|
||
for _, target := range targets {
|
||
if results[target.addr] != 2 {
|
||
t.Errorf("Expected %s to be selected 2 times, got %d", target.addr, results[target.addr])
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestLeastConnBalancer(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001", conns: 5},
|
||
{addr: "localhost:8002", conns: 2},
|
||
{addr: "localhost:8003", conns: 8},
|
||
}
|
||
for _, t := range targets {
|
||
t.healthy.Store(true)
|
||
}
|
||
|
||
lc := newLeastConn()
|
||
selected := lc.Select(targets)
|
||
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
} else if selected.addr != "localhost:8002" {
|
||
t.Errorf("Expected localhost:8002 (least connections), got %s", selected.addr)
|
||
}
|
||
}
|
||
|
||
func TestBalancerNoHealthyTargets(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
}
|
||
// 不设置 healthy,默认为 false
|
||
|
||
rr := newRoundRobin()
|
||
selected := rr.Select(targets)
|
||
if selected != nil {
|
||
t.Error("Expected nil for no healthy targets")
|
||
}
|
||
|
||
lc := newLeastConn()
|
||
selected = lc.Select(targets)
|
||
if selected != nil {
|
||
t.Error("Expected nil for no healthy targets")
|
||
}
|
||
}
|
||
|
||
func TestWeightedRoundRobinBalancer(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001", weight: 3},
|
||
{addr: "localhost:8002", weight: 1},
|
||
}
|
||
for _, target := range targets {
|
||
target.healthy.Store(true)
|
||
}
|
||
|
||
wrr := newWeightedRoundRobin()
|
||
|
||
// 测试加权分布:3:1 比例
|
||
results := make(map[string]int)
|
||
for range 8 {
|
||
selected := wrr.Select(targets)
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
continue
|
||
}
|
||
results[selected.addr]++
|
||
}
|
||
|
||
// localhost:8001 应被选中 6 次,localhost:8002 应被选中 2 次
|
||
if results["localhost:8001"] != 6 {
|
||
t.Errorf("Expected localhost:8001 to be selected 6 times, got %d", results["localhost:8001"])
|
||
}
|
||
if results["localhost:8002"] != 2 {
|
||
t.Errorf("Expected localhost:8002 to be selected 2 times, got %d", results["localhost:8002"])
|
||
}
|
||
}
|
||
|
||
func TestIPHashBalancer(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
{addr: "localhost:8003"},
|
||
}
|
||
for _, target := range targets {
|
||
target.healthy.Store(true)
|
||
}
|
||
|
||
ih := newIPHash()
|
||
|
||
// 相同 IP 应始终选择同一目标
|
||
ip1 := "192.168.1.1"
|
||
selected1 := ih.(*ipHash).SelectByIP(targets, ip1)
|
||
selected2 := ih.(*ipHash).SelectByIP(targets, ip1)
|
||
|
||
if selected1 != selected2 {
|
||
t.Error("Same IP should select same target")
|
||
}
|
||
|
||
// 不同 IP 可能选择不同目标
|
||
ip2 := "10.0.0.1"
|
||
selected3 := ih.(*ipHash).SelectByIP(targets, ip2)
|
||
// 验证返回非空
|
||
if selected3 == nil {
|
||
t.Error("Expected non-nil target for different IP")
|
||
}
|
||
}
|
||
|
||
func TestUpstreamSelect(t *testing.T) {
|
||
u := &Upstream{
|
||
targets: []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
},
|
||
balancer: newRoundRobin(),
|
||
}
|
||
for _, t := range u.targets {
|
||
t.healthy.Store(true)
|
||
}
|
||
|
||
selected := u.Select()
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
}
|
||
}
|
||
|
||
func TestTargetHealthy(t *testing.T) {
|
||
target := &Target{
|
||
addr: "localhost:8001",
|
||
weight: 1,
|
||
}
|
||
|
||
// 初始状态应该是不健康(默认为 false)
|
||
if target.healthy.Load() {
|
||
t.Error("新目标应该默认为不健康")
|
||
}
|
||
|
||
// 设置为健康
|
||
target.healthy.Store(true)
|
||
if !target.healthy.Load() {
|
||
t.Error("目标应该被标记为健康")
|
||
}
|
||
|
||
// 设置为不健康
|
||
target.healthy.Store(false)
|
||
if target.healthy.Load() {
|
||
t.Error("目标应该被标记为不健康")
|
||
}
|
||
}
|
||
|
||
func TestConcurrentConnections(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
targets := []TargetSpec{
|
||
{Addr: "localhost:8001", Weight: 1},
|
||
}
|
||
_ = s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
|
||
|
||
// 并发增加连接数
|
||
var wg sync.WaitGroup
|
||
for range 100 {
|
||
wg.Go(func() {
|
||
atomic.AddInt64(&s.connCount, 1)
|
||
})
|
||
}
|
||
wg.Wait()
|
||
|
||
if s.connCount != 100 {
|
||
t.Errorf("Expected 100 connections, got %d", s.connCount)
|
||
}
|
||
}
|
||
|
||
func TestUDPServerInvalidUpstream(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
// 尝试监听不存在的上游配置
|
||
err := s.ListenUDP("127.0.0.1:0", "non_existent", 0)
|
||
if err == nil {
|
||
t.Error("Expected error for non-existent upstream")
|
||
}
|
||
}
|
||
|
||
func TestUDPSessionKey(t *testing.T) {
|
||
addr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:1234")
|
||
addr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:5678")
|
||
addr3, _ := net.ResolveUDPAddr("udp", "127.0.0.1:1234")
|
||
|
||
key1 := sessionKey(addr1)
|
||
key2 := sessionKey(addr2)
|
||
key3 := sessionKey(addr3)
|
||
|
||
if key1 == key2 {
|
||
t.Error("Different addresses should have different keys")
|
||
}
|
||
|
||
if key1 != key3 {
|
||
t.Error("Same addresses should have same keys")
|
||
}
|
||
}
|
||
|
||
func TestNewUDPServer(t *testing.T) {
|
||
// 创建 UDP 连接
|
||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||
defer func() { _ = conn.Close() }()
|
||
|
||
// 创建上游
|
||
upstream := &Upstream{
|
||
targets: []*Target{{addr: "127.0.0.1:19002"}},
|
||
balancer: newRoundRobin(),
|
||
}
|
||
|
||
// 测试默认超时
|
||
srv := newUDPServer(conn, upstream, 0)
|
||
if srv.timeout != 60*time.Second {
|
||
t.Errorf("Expected default timeout 60s, got %v", srv.timeout)
|
||
}
|
||
|
||
// 测试自定义超时
|
||
srv2 := newUDPServer(conn, upstream, 30*time.Second)
|
||
if srv2.timeout != 30*time.Second {
|
||
t.Errorf("Expected timeout 30s, got %v", srv2.timeout)
|
||
}
|
||
}
|
||
|
||
func TestRoundRobinBalancerWithSingleTarget(t *testing.T) {
|
||
rb := newRoundRobin()
|
||
targets := []*Target{
|
||
{addr: "backend1:8080"},
|
||
}
|
||
targets[0].healthy.Store(true)
|
||
|
||
// 测试单个健康目标
|
||
for range 5 {
|
||
target := rb.Select(targets)
|
||
if target == nil {
|
||
t.Error("Expected non-nil target")
|
||
continue
|
||
}
|
||
if target.addr != "backend1:8080" {
|
||
t.Errorf("Expected backend1:8080, got %s", target.addr)
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestLeastConnBalancerWithTie(t *testing.T) {
|
||
lc := newLeastConn()
|
||
targets := []*Target{
|
||
{addr: "backend1:8080", conns: 5},
|
||
{addr: "backend2:8080", conns: 5},
|
||
{addr: "backend3:8080", conns: 5},
|
||
}
|
||
for _, t := range targets {
|
||
t.healthy.Store(true)
|
||
}
|
||
|
||
// 当连接数相同时,应该选择第一个
|
||
selected := lc.Select(targets)
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
}
|
||
}
|
||
|
||
func TestAddUpstreamWithLeastConn(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
targets := []TargetSpec{
|
||
{Addr: "localhost:8001", Weight: 1},
|
||
{Addr: "localhost:8002", Weight: 2},
|
||
}
|
||
|
||
err := s.AddUpstream("least_conn_test", targets, "least_conn", HealthCheckSpec{})
|
||
if err != nil {
|
||
t.Errorf("AddUpstream failed: %v", err)
|
||
}
|
||
|
||
up := s.upstreams["least_conn_test"]
|
||
if up == nil {
|
||
t.Fatal("Expected non-nil upstream")
|
||
}
|
||
|
||
// 验证使用的是最少连接均衡器
|
||
_, ok := up.balancer.(*leastConn)
|
||
if !ok {
|
||
t.Error("Expected leastConn balancer")
|
||
}
|
||
}
|
||
|
||
func TestUpstreamSelectNoHealthy(t *testing.T) {
|
||
u := &Upstream{
|
||
targets: []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
},
|
||
balancer: newRoundRobin(),
|
||
}
|
||
// 不设置 healthy,默认为 false
|
||
|
||
selected := u.Select()
|
||
if selected != nil {
|
||
t.Error("Expected nil for no healthy targets")
|
||
}
|
||
}
|
||
|
||
func TestCleanupExpiredSessions(t *testing.T) {
|
||
// 创建 UDP 连接
|
||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||
defer func() { _ = conn.Close() }()
|
||
|
||
// 创建上游
|
||
upstream := &Upstream{
|
||
targets: []*Target{{addr: "127.0.0.1:19005"}},
|
||
balancer: newRoundRobin(),
|
||
}
|
||
upstream.targets[0].healthy.Store(true)
|
||
|
||
// 创建 UDP 服务器,设置很短的超时时间
|
||
srv := newUDPServer(conn, upstream, 1*time.Millisecond)
|
||
|
||
// 创建模拟会话
|
||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12345")
|
||
session := &udpSession{
|
||
clientAddr: clientAddr,
|
||
lastActive: time.Now().Add(-1 * time.Hour), // 很久以前的活动
|
||
}
|
||
srv.sessions[sessionKey(clientAddr)] = session
|
||
|
||
// 执行清理
|
||
srv.cleanupExpiredSessions()
|
||
|
||
// 验证会话已被清理
|
||
srv.mu.RLock()
|
||
if len(srv.sessions) != 0 {
|
||
t.Errorf("Expected 0 sessions after cleanup, got %d", len(srv.sessions))
|
||
}
|
||
srv.mu.RUnlock()
|
||
}
|
||
|
||
func TestUDPSessionOperations(t *testing.T) {
|
||
// 创建 UDP 连接
|
||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||
defer func() { _ = conn.Close() }()
|
||
|
||
// 创建目标服务器(用于模拟连接)
|
||
targetAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:19007")
|
||
targetConn, _ := net.ListenUDP("udp", targetAddr)
|
||
defer func() { _ = targetConn.Close() }()
|
||
|
||
// 创建上游
|
||
upstream := &Upstream{
|
||
targets: []*Target{{addr: "127.0.0.1:19007"}},
|
||
balancer: newRoundRobin(),
|
||
}
|
||
upstream.targets[0].healthy.Store(true)
|
||
|
||
// 创建 UDP 服务器
|
||
srv := newUDPServer(conn, upstream, 1*time.Minute)
|
||
|
||
// 测试 getSession - 不存在的会话
|
||
clientAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:12347")
|
||
session := srv.getSession(clientAddr)
|
||
if session != nil {
|
||
t.Error("Expected nil for non-existent session")
|
||
}
|
||
|
||
// 创建模拟会话
|
||
testSession := &udpSession{
|
||
clientAddr: clientAddr,
|
||
lastActive: time.Now(),
|
||
srv: srv,
|
||
}
|
||
srv.sessions[sessionKey(clientAddr)] = testSession
|
||
|
||
// 测试 getSession - 存在的会话
|
||
session = srv.getSession(clientAddr)
|
||
if session == nil {
|
||
t.Error("Expected non-nil session")
|
||
}
|
||
|
||
// 测试 removeSession
|
||
srv.removeSession(clientAddr)
|
||
session = srv.getSession(clientAddr)
|
||
if session != nil {
|
||
t.Error("Expected nil after removeSession")
|
||
}
|
||
}
|
||
|
||
func TestUDPSessionClose(_ *testing.T) {
|
||
// 创建两个 UDP 连接用于测试
|
||
udpAddr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
conn1, _ := net.ListenUDP("udp", udpAddr1)
|
||
|
||
udpAddr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
conn2, _ := net.ListenUDP("udp", udpAddr2)
|
||
|
||
// 创建会话
|
||
session := &udpSession{
|
||
clientAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12348},
|
||
targetConn: conn2,
|
||
lastActive: time.Now(),
|
||
}
|
||
|
||
// 测试 close - 应该能正常关闭
|
||
session.close()
|
||
|
||
// 第二次调用 close 不应该出错(使用 sync.Once)
|
||
session.close()
|
||
|
||
_ = conn1.Close()
|
||
}
|
||
|
||
func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
|
||
// 启动一个临时的 TCP 服务器
|
||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to create listener: %v", err)
|
||
}
|
||
defer func() { _ = listener.Close() }()
|
||
|
||
// 在后台运行服务器
|
||
go func() {
|
||
for {
|
||
conn, err := listener.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
_ = conn.Close()
|
||
}
|
||
}()
|
||
|
||
addr := listener.Addr().String()
|
||
|
||
u := &Upstream{
|
||
targets: []*Target{
|
||
{addr: addr},
|
||
},
|
||
}
|
||
// 初始设置为不健康
|
||
u.targets[0].healthy.Store(false)
|
||
|
||
hc := &HealthChecker{
|
||
upstream: u,
|
||
interval: 1 * time.Second,
|
||
timeout: 100 * time.Millisecond,
|
||
stopCh: make(chan struct{}),
|
||
}
|
||
|
||
// 执行健康检查
|
||
hc.check()
|
||
|
||
// 目标应该被标记为健康(因为服务器在运行)
|
||
if !u.targets[0].healthy.Load() {
|
||
t.Error("Expected target to be marked healthy")
|
||
}
|
||
}
|