主要变更: - WebSocket 代理支持 (internal/proxy/websocket.go) - OCSP stapling 实现 (internal/ssl/ocsp.go) - 监控状态端点 (internal/server/status.go) - 新增 nginx 模块文档 (19-24) - UDP 代理超时配置支持 - 多模块代码注释完善和功能增强 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
349 lines
7.3 KiB
Go
349 lines
7.3 KiB
Go
package stream
|
||
|
||
import (
|
||
"net"
|
||
"sync"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
func TestNewServer(t *testing.T) {
|
||
s := NewServer()
|
||
if s == nil {
|
||
t.Error("Expected non-nil server")
|
||
}
|
||
if s.listeners == nil {
|
||
t.Error("Expected initialized listeners map")
|
||
}
|
||
if s.upstreams == nil {
|
||
t.Error("Expected initialized upstreams map")
|
||
}
|
||
}
|
||
|
||
func TestAddUpstream(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
targets := []TargetSpec{
|
||
{Addr: "localhost:8001", Weight: 1},
|
||
{Addr: "localhost:8002", Weight: 2},
|
||
}
|
||
|
||
hcSpec := HealthCheckSpec{
|
||
Enabled: false,
|
||
Interval: 10 * time.Second,
|
||
Timeout: 5 * time.Second,
|
||
}
|
||
|
||
err := s.AddUpstream("test", targets, "round_robin", hcSpec)
|
||
if err != nil {
|
||
t.Errorf("AddUpstream failed: %v", err)
|
||
}
|
||
|
||
if len(s.upstreams) != 1 {
|
||
t.Errorf("Expected 1 upstream, got %d", len(s.upstreams))
|
||
}
|
||
|
||
up := s.upstreams["test"]
|
||
if up == nil {
|
||
t.Error("Expected non-nil upstream")
|
||
}
|
||
if len(up.targets) != 2 {
|
||
t.Errorf("Expected 2 targets, got %d", len(up.targets))
|
||
}
|
||
}
|
||
|
||
func TestRoundRobinBalancer(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
{addr: "localhost:8003"},
|
||
}
|
||
for _, target := range targets {
|
||
target.healthy.Store(true)
|
||
}
|
||
|
||
rr := newRoundRobin()
|
||
|
||
// 测试轮询
|
||
results := make(map[string]int)
|
||
for i := 0; i < 6; i++ {
|
||
selected := rr.Select(targets)
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
continue
|
||
}
|
||
results[selected.addr]++
|
||
}
|
||
|
||
// 每个目标应该被选中 2 次
|
||
for _, target := range targets {
|
||
if results[target.addr] != 2 {
|
||
t.Errorf("Expected %s to be selected 2 times, got %d", target.addr, results[target.addr])
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestLeastConnBalancer(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001", conns: 5},
|
||
{addr: "localhost:8002", conns: 2},
|
||
{addr: "localhost:8003", conns: 8},
|
||
}
|
||
for _, t := range targets {
|
||
t.healthy.Store(true)
|
||
}
|
||
|
||
lc := newLeastConn()
|
||
selected := lc.Select(targets)
|
||
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
} else if selected.addr != "localhost:8002" {
|
||
t.Errorf("Expected localhost:8002 (least connections), got %s", selected.addr)
|
||
}
|
||
}
|
||
|
||
func TestBalancerNoHealthyTargets(t *testing.T) {
|
||
targets := []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
}
|
||
// 不设置 healthy,默认为 false
|
||
|
||
rr := newRoundRobin()
|
||
selected := rr.Select(targets)
|
||
if selected != nil {
|
||
t.Error("Expected nil for no healthy targets")
|
||
}
|
||
|
||
lc := newLeastConn()
|
||
selected = lc.Select(targets)
|
||
if selected != nil {
|
||
t.Error("Expected nil for no healthy targets")
|
||
}
|
||
}
|
||
|
||
func TestServerStats(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
stats := s.Stats()
|
||
if stats.Connections != 0 {
|
||
t.Errorf("Expected 0 connections, got %d", stats.Connections)
|
||
}
|
||
if stats.Listeners != 0 {
|
||
t.Errorf("Expected 0 listeners, got %d", stats.Listeners)
|
||
}
|
||
}
|
||
|
||
func TestUpstreamSelect(t *testing.T) {
|
||
u := &Upstream{
|
||
targets: []*Target{
|
||
{addr: "localhost:8001"},
|
||
{addr: "localhost:8002"},
|
||
},
|
||
balancer: newRoundRobin(),
|
||
}
|
||
for _, t := range u.targets {
|
||
t.healthy.Store(true)
|
||
}
|
||
|
||
selected := u.Select()
|
||
if selected == nil {
|
||
t.Error("Expected non-nil target")
|
||
}
|
||
}
|
||
|
||
func TestHealthChecker(t *testing.T) {
|
||
u := &Upstream{
|
||
targets: []*Target{
|
||
{addr: "localhost:99999"}, // 不存在的端口
|
||
},
|
||
}
|
||
|
||
hc := &HealthChecker{
|
||
upstream: u,
|
||
interval: 1 * time.Second,
|
||
timeout: 100 * time.Millisecond,
|
||
stopCh: make(chan struct{}),
|
||
}
|
||
|
||
// 执行一次检查
|
||
hc.check()
|
||
|
||
// 目标应该被标记为不健康
|
||
if u.targets[0].healthy.Load() {
|
||
t.Error("Expected target to be marked unhealthy")
|
||
}
|
||
}
|
||
|
||
func TestUDPListener(t *testing.T) {
|
||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to resolve UDP address: %v", err)
|
||
}
|
||
|
||
conn, err := net.ListenUDP("udp", addr)
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen UDP: %v", err)
|
||
}
|
||
defer conn.Close()
|
||
|
||
ul := &udpListener{conn: conn}
|
||
|
||
// 测试 Addr
|
||
if ul.Addr() == nil {
|
||
t.Error("Expected non-nil address")
|
||
}
|
||
|
||
// 测试 Close
|
||
if err := ul.Close(); err != nil {
|
||
t.Errorf("Close failed: %v", err)
|
||
}
|
||
|
||
// 测试 Accept(应该返回 io.EOF)
|
||
_, err = ul.Accept()
|
||
if err == nil {
|
||
t.Error("Expected error from Accept")
|
||
}
|
||
}
|
||
|
||
func TestConcurrentConnections(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
targets := []TargetSpec{
|
||
{Addr: "localhost:8001", Weight: 1},
|
||
}
|
||
s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{})
|
||
|
||
// 并发增加连接数
|
||
var wg sync.WaitGroup
|
||
for i := 0; i < 100; i++ {
|
||
wg.Add(1)
|
||
go func() {
|
||
defer wg.Done()
|
||
atomic.AddInt64(&s.connCount, 1)
|
||
}()
|
||
}
|
||
wg.Wait()
|
||
|
||
if s.connCount != 100 {
|
||
t.Errorf("Expected 100 connections, got %d", s.connCount)
|
||
}
|
||
}
|
||
|
||
func TestUDPServer(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
// 添加 UDP 上游配置
|
||
targets := []TargetSpec{
|
||
{Addr: "127.0.0.1:0", Weight: 1},
|
||
}
|
||
err := s.AddUpstream("udp_test", targets, "round_robin", HealthCheckSpec{})
|
||
if err != nil {
|
||
t.Fatalf("AddUpstream failed: %v", err)
|
||
}
|
||
|
||
// 测试 UDP 监听(使用 :0 让系统分配端口)
|
||
err = s.ListenUDP("127.0.0.1:0", "udp_test", 1*time.Second)
|
||
if err != nil {
|
||
t.Fatalf("ListenUDP failed: %v", err)
|
||
}
|
||
|
||
// 验证 UDP 服务器已创建
|
||
s.mu.RLock()
|
||
if len(s.udpServers) != 1 {
|
||
t.Errorf("Expected 1 UDP server, got %d", len(s.udpServers))
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
// 测试 Stats 包含 UDP 监听器
|
||
stats := s.Stats()
|
||
if stats.Listeners != 1 {
|
||
t.Errorf("Expected 1 listener in stats, got %d", stats.Listeners)
|
||
}
|
||
}
|
||
|
||
func TestUDPServerInvalidUpstream(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
// 尝试监听不存在的上游配置
|
||
err := s.ListenUDP("127.0.0.1:0", "non_existent", 0)
|
||
if err == nil {
|
||
t.Error("Expected error for non-existent upstream")
|
||
}
|
||
}
|
||
|
||
func TestUDPServerStartAndStop(t *testing.T) {
|
||
s := NewServer()
|
||
|
||
// 添加上游
|
||
targets := []TargetSpec{
|
||
{Addr: "127.0.0.1:19001", Weight: 1},
|
||
}
|
||
s.AddUpstream("udp_stop_test", targets, "round_robin", HealthCheckSpec{})
|
||
|
||
// 监听 UDP
|
||
err := s.ListenUDP("127.0.0.1:19000", "udp_stop_test", 500*time.Millisecond)
|
||
if err != nil {
|
||
t.Fatalf("ListenUDP failed: %v", err)
|
||
}
|
||
|
||
// 启动服务器
|
||
err = s.Start()
|
||
if err != nil {
|
||
t.Fatalf("Start failed: %v", err)
|
||
}
|
||
|
||
// 给服务器一点时间启动
|
||
time.Sleep(50 * time.Millisecond)
|
||
|
||
// 停止服务器
|
||
err = s.Stop()
|
||
if err != nil {
|
||
t.Errorf("Stop failed: %v", err)
|
||
}
|
||
}
|
||
|
||
func TestUDPSessionKey(t *testing.T) {
|
||
addr1, _ := net.ResolveUDPAddr("udp", "127.0.0.1:1234")
|
||
addr2, _ := net.ResolveUDPAddr("udp", "127.0.0.1:5678")
|
||
addr3, _ := net.ResolveUDPAddr("udp", "127.0.0.1:1234")
|
||
|
||
key1 := sessionKey(addr1)
|
||
key2 := sessionKey(addr2)
|
||
key3 := sessionKey(addr3)
|
||
|
||
if key1 == key2 {
|
||
t.Error("Different addresses should have different keys")
|
||
}
|
||
|
||
if key1 != key3 {
|
||
t.Error("Same addresses should have same keys")
|
||
}
|
||
}
|
||
|
||
func TestNewUDPServer(t *testing.T) {
|
||
// 创建 UDP 连接
|
||
udpAddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||
conn, _ := net.ListenUDP("udp", udpAddr)
|
||
defer conn.Close()
|
||
|
||
// 创建上游
|
||
upstream := &Upstream{
|
||
targets: []*Target{{addr: "127.0.0.1:19002"}},
|
||
balancer: newRoundRobin(),
|
||
}
|
||
|
||
// 测试默认超时
|
||
srv := newUDPServer(conn, upstream, 0)
|
||
if srv.timeout != 60*time.Second {
|
||
t.Errorf("Expected default timeout 60s, got %v", srv.timeout)
|
||
}
|
||
|
||
// 测试自定义超时
|
||
srv2 := newUDPServer(conn, upstream, 30*time.Second)
|
||
if srv2.timeout != 30*time.Second {
|
||
t.Errorf("Expected timeout 30s, got %v", srv2.timeout)
|
||
}
|
||
} |