package stream import ( "net" "sync" "sync/atomic" "testing" "time" ) func TestNewServer(t *testing.T) { s := NewServer() if s == nil { t.Error("Expected non-nil server") } if s.listeners == nil { t.Error("Expected initialized listeners map") } if s.upstreams == nil { t.Error("Expected initialized upstreams map") } } func TestAddUpstream(t *testing.T) { s := NewServer() targets := []TargetSpec{ {Addr: "localhost:8001", Weight: 1}, {Addr: "localhost:8002", Weight: 2}, } hcSpec := HealthCheckSpec{ Enabled: false, Interval: 10 * time.Second, Timeout: 5 * time.Second, } err := s.AddUpstream("test", targets, "round_robin", hcSpec) if err != nil { t.Errorf("AddUpstream failed: %v", err) } if len(s.upstreams) != 1 { t.Errorf("Expected 1 upstream, got %d", len(s.upstreams)) } up := s.upstreams["test"] if up == nil { t.Error("Expected non-nil upstream") } if len(up.targets) != 2 { t.Errorf("Expected 2 targets, got %d", len(up.targets)) } } func TestRoundRobinBalancer(t *testing.T) { targets := []*Target{ {addr: "localhost:8001"}, {addr: "localhost:8002"}, {addr: "localhost:8003"}, } for _, target := range targets { target.healthy.Store(true) } rr := newRoundRobin() // 测试轮询 results := make(map[string]int) for i := 0; i < 6; i++ { selected := rr.Select(targets) if selected == nil { t.Error("Expected non-nil target") continue } results[selected.addr]++ } // 每个目标应该被选中 2 次 for _, target := range targets { if results[target.addr] != 2 { t.Errorf("Expected %s to be selected 2 times, got %d", target.addr, results[target.addr]) } } } func TestLeastConnBalancer(t *testing.T) { targets := []*Target{ {addr: "localhost:8001", conns: 5}, {addr: "localhost:8002", conns: 2}, {addr: "localhost:8003", conns: 8}, } for _, t := range targets { t.healthy.Store(true) } lc := newLeastConn() selected := lc.Select(targets) if selected == nil { t.Error("Expected non-nil target") } else if selected.addr != "localhost:8002" { t.Errorf("Expected localhost:8002 (least connections), got %s", selected.addr) } } func TestBalancerNoHealthyTargets(t *testing.T) { targets := []*Target{ {addr: "localhost:8001"}, {addr: "localhost:8002"}, } // 不设置 healthy,默认为 false rr := newRoundRobin() selected := rr.Select(targets) if selected != nil { t.Error("Expected nil for no healthy targets") } lc := newLeastConn() selected = lc.Select(targets) if selected != nil { t.Error("Expected nil for no healthy targets") } } func TestServerStats(t *testing.T) { s := NewServer() stats := s.Stats() if stats.Connections != 0 { t.Errorf("Expected 0 connections, got %d", stats.Connections) } if stats.Listeners != 0 { t.Errorf("Expected 0 listeners, got %d", stats.Listeners) } } func TestUpstreamSelect(t *testing.T) { u := &Upstream{ targets: []*Target{ {addr: "localhost:8001"}, {addr: "localhost:8002"}, }, balancer: newRoundRobin(), } for _, t := range u.targets { t.healthy.Store(true) } selected := u.Select() if selected == nil { t.Error("Expected non-nil target") } } func TestHealthChecker(t *testing.T) { u := &Upstream{ targets: []*Target{ {addr: "localhost:99999"}, // 不存在的端口 }, } hc := &HealthChecker{ upstream: u, interval: 1 * time.Second, timeout: 100 * time.Millisecond, stopCh: make(chan struct{}), } // 执行一次检查 hc.check() // 目标应该被标记为不健康 if u.targets[0].healthy.Load() { t.Error("Expected target to be marked unhealthy") } } func TestUDPListener(t *testing.T) { addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") if err != nil { t.Fatalf("Failed to resolve UDP address: %v", err) } conn, err := net.ListenUDP("udp", addr) if err != nil { t.Fatalf("Failed to listen UDP: %v", err) } defer conn.Close() ul := &udpListener{conn: conn} // 测试 Addr if ul.Addr() == nil { t.Error("Expected non-nil address") } // 测试 Close if err := ul.Close(); err != nil { t.Errorf("Close failed: %v", err) } // 测试 Accept(应该返回 io.EOF) _, err = ul.Accept() if err == nil { t.Error("Expected error from Accept") } } func TestConcurrentConnections(t *testing.T) { s := NewServer() targets := []TargetSpec{ {Addr: "localhost:8001", Weight: 1}, } s.AddUpstream("test", targets, "round_robin", HealthCheckSpec{}) // 并发增加连接数 var wg sync.WaitGroup for i := 0; i < 100; i++ { wg.Add(1) go func() { defer wg.Done() atomic.AddInt64(&s.connCount, 1) }() } wg.Wait() if s.connCount != 100 { t.Errorf("Expected 100 connections, got %d", s.connCount) } }