lolly/internal/stream/stream_test.go

547 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package stream 提供 TCP/UDP 流代理功能的测试。
//
// 该文件测试流代理模块的各项功能,包括:
// - 服务器创建和初始化
// - 上游配置和负载均衡
// - TCP 和 UDP 监听
// - 健康检查
// - 连接统计
//
// 作者: xfy
package stream
import (
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestNewServer checks that NewServer returns a non-nil server whose
// listeners and upstreams collections are already allocated.
func TestNewServer(t *testing.T) {
	srv := NewServer()
	if srv == nil {
		t.Fatal("Expected non-nil server")
	}

	// Each entry pairs an "is initialized" condition with the message to
	// report when it does not hold; checked in order so both can fail.
	checks := []struct {
		ok  bool
		msg string
	}{
		{srv.listeners != nil, "Expected initialized listeners map"},
		{srv.upstreams != nil, "Expected initialized upstreams map"},
	}
	for _, c := range checks {
		if !c.ok {
			t.Error(c.msg)
		}
	}
}
// TestAddUpstream registers a two-target round-robin upstream (health checks
// disabled) and verifies it is stored under its name with both targets.
func TestAddUpstream(t *testing.T) {
	srv := NewServer()

	specs := []TargetSpec{
		{Addr: "localhost:8001", Weight: 1},
		{Addr: "localhost:8002", Weight: 2},
	}
	hc := HealthCheckSpec{Enabled: false, Interval: 10 * time.Second, Timeout: 5 * time.Second}

	if err := srv.AddUpstream("test", specs, "round_robin", hc); err != nil {
		t.Errorf("AddUpstream failed: %v", err)
	}
	if n := len(srv.upstreams); n != 1 {
		t.Errorf("Expected 1 upstream, got %d", n)
	}

	up := srv.upstreams["test"]
	if up == nil {
		t.Fatal("Expected non-nil upstream")
	}
	if n := len(up.targets); n != 2 {
		t.Errorf("Expected 2 targets, got %d", n)
	}
}
// TestRoundRobinBalancer verifies that six picks over three healthy targets
// hit each target exactly twice.
func TestRoundRobinBalancer(t *testing.T) {
	addrs := []string{"localhost:8001", "localhost:8002", "localhost:8003"}
	targets := make([]*Target, 0, len(addrs))
	for _, a := range addrs {
		tg := &Target{addr: a}
		tg.healthy.Store(true)
		targets = append(targets, tg)
	}

	balancer := newRoundRobin()

	// Two full rotations over three targets.
	counts := map[string]int{}
	for pick := 0; pick < 6; pick++ {
		chosen := balancer.Select(targets)
		if chosen == nil {
			t.Error("Expected non-nil target")
			continue
		}
		counts[chosen.addr]++
	}

	for _, tg := range targets {
		if got := counts[tg.addr]; got != 2 {
			t.Errorf("Expected %s to be selected 2 times, got %d", tg.addr, got)
		}
	}
}
// TestLeastConnBalancer verifies that the least-connections balancer picks the
// healthy target with the smallest connection count (localhost:8002 here).
func TestLeastConnBalancer(t *testing.T) {
	targets := []*Target{
		{addr: "localhost:8001", conns: 5},
		{addr: "localhost:8002", conns: 2},
		{addr: "localhost:8003", conns: 8},
	}
	// Targets default to unhealthy, so mark them all eligible. The loop
	// variable is named target so it does not shadow the *testing.T.
	for _, target := range targets {
		target.healthy.Store(true)
	}

	selected := newLeastConn().Select(targets)
	if selected == nil {
		t.Fatal("Expected non-nil target")
	}
	if selected.addr != "localhost:8002" {
		t.Errorf("Expected localhost:8002 (least connections), got %s", selected.addr)
	}
}
// TestBalancerNoHealthyTargets verifies that both the round-robin and the
// least-connections balancer return nil when no target is healthy.
func TestBalancerNoHealthyTargets(t *testing.T) {
	// healthy is never stored, so every target keeps its default (false).
	targets := []*Target{{addr: "localhost:8001"}, {addr: "localhost:8002"}}

	if got := newRoundRobin().Select(targets); got != nil {
		t.Error("Expected nil for no healthy targets")
	}
	if got := newLeastConn().Select(targets); got != nil {
		t.Error("Expected nil for no healthy targets")
	}
}
// TestWeightedRoundRobinBalancer verifies that with weights 3:1 over eight
// picks the heavier target is chosen 6 times and the lighter one 2 times.
func TestWeightedRoundRobinBalancer(t *testing.T) {
	targets := []*Target{
		{addr: "localhost:8001", weight: 3},
		{addr: "localhost:8002", weight: 1},
	}
	for i := range targets {
		targets[i].healthy.Store(true)
	}

	wrr := newWeightedRoundRobin()

	counts := make(map[string]int, len(targets))
	for pick := 0; pick < 8; pick++ {
		chosen := wrr.Select(targets)
		if chosen == nil {
			t.Error("Expected non-nil target")
			continue
		}
		counts[chosen.addr]++
	}

	// Checked in a fixed order so failure output is deterministic.
	expectations := []struct {
		addr string
		want int
	}{
		{"localhost:8001", 6},
		{"localhost:8002", 2},
	}
	for _, e := range expectations {
		if got := counts[e.addr]; got != e.want {
			t.Errorf("Expected %s to be selected %d times, got %d", e.addr, e.want, got)
		}
	}
}
// TestIPHashBalancer verifies that the IP-hash balancer maps the same client
// IP to the same target and returns a target for any other IP.
func TestIPHashBalancer(t *testing.T) {
	targets := []*Target{
		{addr: "localhost:8001"},
		{addr: "localhost:8002"},
		{addr: "localhost:8003"},
	}
	for _, target := range targets {
		target.healthy.Store(true)
	}

	// Check the assertion so a changed constructor fails the test instead of
	// panicking it.
	raw := newIPHash()
	ih, ok := raw.(*ipHash)
	if !ok {
		t.Fatalf("Expected newIPHash to return *ipHash, got %T", raw)
	}

	// Same IP must always select the same target. Require a non-nil first
	// pick, otherwise nil == nil would make the stability check pass vacuously.
	const ip1 = "192.168.1.1"
	first := ih.SelectByIP(targets, ip1)
	if first == nil {
		t.Fatalf("Expected non-nil target for IP %s", ip1)
	}
	if again := ih.SelectByIP(targets, ip1); again != first {
		t.Error("Same IP should select same target")
	}

	// A different IP may land elsewhere; only require that some target is returned.
	if other := ih.SelectByIP(targets, "10.0.0.1"); other == nil {
		t.Error("Expected non-nil target for different IP")
	}
}
// TestUpstreamSelect verifies that an Upstream with healthy targets and a
// round-robin balancer returns a target from Select.
func TestUpstreamSelect(t *testing.T) {
	u := &Upstream{
		targets: []*Target{
			{addr: "localhost:8001"},
			{addr: "localhost:8002"},
		},
		balancer: newRoundRobin(),
	}
	// Named target (not t) so the *testing.T is not shadowed.
	for _, target := range u.targets {
		target.healthy.Store(true)
	}

	if selected := u.Select(); selected == nil {
		t.Error("Expected non-nil target")
	}
}
// TestTargetHealthy exercises a Target's healthy flag: it starts false and
// then reflects each subsequent Store.
func TestTargetHealthy(t *testing.T) {
	target := &Target{addr: "localhost:8001", weight: 1}

	// A fresh target defaults to unhealthy (zero value false).
	if target.healthy.Load() {
		t.Error("新目标应该默认为不健康")
	}

	// Flip to healthy, then back to unhealthy, checking each transition.
	steps := []struct {
		value bool
		msg   string
	}{
		{true, "目标应该被标记为健康"},
		{false, "目标应该被标记为不健康"},
	}
	for _, s := range steps {
		target.healthy.Store(s.value)
		if target.healthy.Load() != s.value {
			t.Error(s.msg)
		}
	}
}
// TestConcurrentConnections increments the server's connection counter from
// 100 goroutines and verifies no increment is lost.
func TestConcurrentConnections(t *testing.T) {
	s := NewServer()
	targets := []TargetSpec{
		{Addr: "localhost:8001", Weight: 1},
	}
	// Do not discard the setup error; a broken setup should fail loudly.
	if err := s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{}); err != nil {
		t.Fatalf("AddUpstream failed: %v", err)
	}

	const workers = 100
	var wg sync.WaitGroup
	for range workers {
		wg.Go(func() {
			atomic.AddInt64(&s.connCount, 1)
		})
	}
	wg.Wait()

	// Read through the atomic API too, matching how the counter is written.
	if got := atomic.LoadInt64(&s.connCount); got != workers {
		t.Errorf("Expected %d connections, got %d", workers, got)
	}
}
// TestUDPServerInvalidUpstream verifies that ListenUDP reports an error when
// asked to proxy to an upstream name that was never registered.
func TestUDPServerInvalidUpstream(t *testing.T) {
	srv := NewServer()

	// "non_existent" was never added via AddUpstream.
	if err := srv.ListenUDP("127.0.0.1:0", "non_existent", 0); err == nil {
		t.Error("Expected error for non-existent upstream")
	}
}
// TestUDPSessionKey verifies that sessionKey distinguishes addresses that
// differ only by port and is stable for equal addresses.
func TestUDPSessionKey(t *testing.T) {
	// resolve fails the test on a resolution error instead of passing a nil
	// address into sessionKey.
	resolve := func(s string) *net.UDPAddr {
		t.Helper()
		addr, err := net.ResolveUDPAddr("udp", s)
		if err != nil {
			t.Fatalf("ResolveUDPAddr(%q) failed: %v", s, err)
		}
		return addr
	}

	key1 := sessionKey(resolve("127.0.0.1:1234"))
	key2 := sessionKey(resolve("127.0.0.1:5678"))
	key3 := sessionKey(resolve("127.0.0.1:1234"))

	if key1 == key2 {
		t.Error("Different addresses should have different keys")
	}
	if key1 != key3 {
		t.Error("Same addresses should have same keys")
	}
}
// TestNewUDPServer verifies that newUDPServer uses a 60s session timeout when
// given 0 and keeps an explicitly supplied timeout.
func TestNewUDPServer(t *testing.T) {
	// Create a local UDP socket. Fail early on errors so a nil conn never
	// reaches newUDPServer.
	udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr failed: %v", err)
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		t.Fatalf("ListenUDP failed: %v", err)
	}
	defer func() { _ = conn.Close() }()

	upstream := &Upstream{
		targets:  []*Target{{addr: "127.0.0.1:19002"}},
		balancer: newRoundRobin(),
	}

	// A zero timeout selects the default.
	srv := newUDPServer(conn, upstream, 0)
	if srv.timeout != 60*time.Second {
		t.Errorf("Expected default timeout 60s, got %v", srv.timeout)
	}

	// A non-zero timeout is kept as given.
	srv2 := newUDPServer(conn, upstream, 30*time.Second)
	if srv2.timeout != 30*time.Second {
		t.Errorf("Expected timeout 30s, got %v", srv2.timeout)
	}
}
// TestRoundRobinBalancerWithSingleTarget verifies that with only one healthy
// target, round-robin returns that target on every pick.
func TestRoundRobinBalancerWithSingleTarget(t *testing.T) {
	rb := newRoundRobin()

	const want = "backend1:8080"
	only := &Target{addr: want}
	only.healthy.Store(true)
	targets := []*Target{only}

	for pick := 0; pick < 5; pick++ {
		got := rb.Select(targets)
		switch {
		case got == nil:
			t.Error("Expected non-nil target")
		case got.addr != want:
			t.Errorf("Expected backend1:8080, got %s", got.addr)
		}
	}
}
// TestLeastConnBalancerWithTie verifies that when all healthy targets have the
// same connection count, the least-connections balancer still returns one of
// them.
func TestLeastConnBalancerWithTie(t *testing.T) {
	lc := newLeastConn()
	targets := []*Target{
		{addr: "backend1:8080", conns: 5},
		{addr: "backend2:8080", conns: 5},
		{addr: "backend3:8080", conns: 5},
	}
	// Named target (not t) so the *testing.T is not shadowed.
	for _, target := range targets {
		target.healthy.Store(true)
	}

	selected := lc.Select(targets)
	if selected == nil {
		t.Fatal("Expected non-nil target")
	}

	// Any tied target is a valid least-connections pick.
	// NOTE(review): the original comment expected the first target to win the
	// tie; that depends on leastConn's tie-breaking and is not asserted here.
	// Confirm before tightening.
	found := false
	for _, target := range targets {
		if target == selected {
			found = true
			break
		}
	}
	if !found {
		t.Errorf("Expected one of the tied targets, got %s", selected.addr)
	}
}
// TestAddUpstreamWithLeastConn verifies that registering an upstream with the
// "least_conn" strategy installs a *leastConn balancer.
func TestAddUpstreamWithLeastConn(t *testing.T) {
	const name = "least_conn_test"

	srv := NewServer()
	specs := []TargetSpec{
		{Addr: "localhost:8001", Weight: 1},
		{Addr: "localhost:8002", Weight: 2},
	}
	if err := srv.AddUpstream(name, specs, "least_conn", HealthCheckSpec{}); err != nil {
		t.Errorf("AddUpstream failed: %v", err)
	}

	up := srv.upstreams[name]
	if up == nil {
		t.Fatal("Expected non-nil upstream")
	}

	if _, isLeastConn := up.balancer.(*leastConn); !isLeastConn {
		t.Error("Expected leastConn balancer")
	}
}
// TestUpstreamSelectNoHealthy verifies that Upstream.Select returns nil when
// none of its targets is healthy.
func TestUpstreamSelectNoHealthy(t *testing.T) {
	// healthy is never stored, so both targets keep the default (false).
	u := &Upstream{
		targets:  []*Target{{addr: "localhost:8001"}, {addr: "localhost:8002"}},
		balancer: newRoundRobin(),
	}

	if got := u.Select(); got != nil {
		t.Error("Expected nil for no healthy targets")
	}
}
// TestCleanupExpiredSessions seeds a session whose last activity is an hour
// old, runs cleanupExpiredSessions on a server with a 1ms timeout, and
// verifies the session was removed.
func TestCleanupExpiredSessions(t *testing.T) {
	// Create a local UDP socket for the server.
	udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr failed: %v", err)
	}
	conn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		t.Fatalf("ListenUDP failed: %v", err)
	}
	defer func() { _ = conn.Close() }()

	upstream := &Upstream{
		targets:  []*Target{{addr: "127.0.0.1:19005"}},
		balancer: newRoundRobin(),
	}
	upstream.targets[0].healthy.Store(true)

	// A very short timeout, so any idle session counts as expired.
	srv := newUDPServer(conn, upstream, 1*time.Millisecond)

	clientAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:12345")
	if err != nil {
		t.Fatalf("ResolveUDPAddr failed: %v", err)
	}
	session := &udpSession{
		clientAddr: clientAddr,
		lastActive: time.Now().Add(-1 * time.Hour), // long since inactive
	}

	// Seed under the same lock used below to read sessions.
	srv.mu.Lock()
	srv.sessions[sessionKey(clientAddr)] = session
	srv.mu.Unlock()

	srv.cleanupExpiredSessions()

	// Snapshot under the read lock and report after releasing it.
	srv.mu.RLock()
	remaining := len(srv.sessions)
	srv.mu.RUnlock()
	if remaining != 0 {
		t.Errorf("Expected 0 sessions after cleanup, got %d", remaining)
	}
}
// TestUDPSessionOperations exercises getSession and removeSession: a lookup
// misses before a session is stored, hits after, and misses again once the
// session is removed.
func TestUDPSessionOperations(t *testing.T) {
	// listen opens a loopback UDP socket on an ephemeral port, failing the
	// test on error and closing the socket when the test ends.
	listen := func() *net.UDPConn {
		t.Helper()
		addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
		if err != nil {
			t.Fatalf("ResolveUDPAddr failed: %v", err)
		}
		c, err := net.ListenUDP("udp", addr)
		if err != nil {
			t.Fatalf("ListenUDP failed: %v", err)
		}
		t.Cleanup(func() { _ = c.Close() })
		return c
	}

	conn := listen()
	// The fake upstream uses an ephemeral port rather than a fixed one, so the
	// test cannot collide with whatever is already bound on the host.
	targetConn := listen()

	upstream := &Upstream{
		targets:  []*Target{{addr: targetConn.LocalAddr().String()}},
		balancer: newRoundRobin(),
	}
	upstream.targets[0].healthy.Store(true)

	srv := newUDPServer(conn, upstream, 1*time.Minute)

	// getSession on an unknown client must miss.
	clientAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:12347")
	if err != nil {
		t.Fatalf("ResolveUDPAddr failed: %v", err)
	}
	if session := srv.getSession(clientAddr); session != nil {
		t.Error("Expected nil for non-existent session")
	}

	// Seed a session directly, under the server lock.
	testSession := &udpSession{
		clientAddr: clientAddr,
		lastActive: time.Now(),
		srv:        srv,
	}
	srv.mu.Lock()
	srv.sessions[sessionKey(clientAddr)] = testSession
	srv.mu.Unlock()

	// getSession must now find it.
	if session := srv.getSession(clientAddr); session == nil {
		t.Error("Expected non-nil session")
	}

	// After removeSession it must be gone again.
	srv.removeSession(clientAddr)
	if session := srv.getSession(clientAddr); session != nil {
		t.Error("Expected nil after removeSession")
	}
}
// TestUDPSessionClose verifies that udpSession.close can be called twice on a
// session backed by a real UDP socket without panicking.
func TestUDPSessionClose(t *testing.T) {
	// Fail on setup errors instead of silently handing the session a nil conn.
	udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("ResolveUDPAddr failed: %v", err)
	}
	targetConn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		t.Fatalf("ListenUDP failed: %v", err)
	}
	// Safety net in case close() leaves the socket open. Closing an
	// already-closed conn only returns an error, which is ignored.
	defer func() { _ = targetConn.Close() }()

	session := &udpSession{
		clientAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12348},
		targetConn: targetConn,
		lastActive: time.Now(),
	}

	// The first close should complete normally.
	session.close()
	// A second close must be harmless. The original test comment says close
	// is guarded by sync.Once; that is not visible here.
	session.close()
}
func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
// 启动一个临时的 TCP 服务器
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Failed to create listener: %v", err)
}
defer func() { _ = listener.Close() }()
// 在后台运行服务器
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
_ = conn.Close()
}
}()
addr := listener.Addr().String()
u := &Upstream{
targets: []*Target{
{addr: addr},
},
}
// 初始设置为不健康
u.targets[0].healthy.Store(false)
hc := &HealthChecker{
upstream: u,
interval: 1 * time.Second,
timeout: 100 * time.Millisecond,
stopCh: make(chan struct{}),
}
// 执行健康检查
hc.check()
// 目标应该被标记为健康(因为服务器在运行)
if !u.targets[0].healthy.Load() {
t.Error("Expected target to be marked healthy")
}
}