From 7204432ca098cd27bab8ea14ac7ed4af65de7502 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 10 Jun 2026 13:45:35 +0800 Subject: [PATCH] fix(stream): correct upstream selection and add graceful shutdown - Fix handleConnection to use addr parameter for direct upstream map lookup instead of always selecting the first upstream - Add Server.Stop() for graceful shutdown with listener closing, UDP server cleanup, health checker termination, and goroutine joining - Add shutdownStream() to App and call it in SIGTERM/SIGQUIT/SIGUSR2 signal handlers to prevent goroutine and port leaks on shutdown --- internal/app/app.go | 3 + internal/app/app_common.go | 6 + internal/stream/server_coverage_test.go | 154 +++++++++++++++++++++--- internal/stream/stream.go | 93 ++++++++++---- internal/stream/stream_coverage_test.go | 4 +- 5 files changed, 219 insertions(+), 41 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 230d15f..8ee20d7 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -123,6 +123,7 @@ func (a *App) handleSignal(sig os.Signal) bool { timeout = 30 * time.Second } a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout)) + a.shutdownStream() a.shutdownHTTP2() a.shutdownHTTP3() _ = a.srv.GracefulStop(timeout) @@ -139,6 +140,7 @@ func (a *App) handleSignal(sig os.Signal) bool { } else { a.logger.LogSignal(sigName(sigTyped), "Stopping server") } + a.shutdownStream() a.shutdownHTTP2() a.shutdownHTTP3() _ = a.srv.StopWithTimeout(timeout) @@ -329,6 +331,7 @@ func (a *App) gracefulUpgrade() { if timeout <= 0 { timeout = 30 * time.Second } + a.shutdownStream() a.shutdownHTTP2() a.shutdownHTTP3() _ = a.srv.GracefulStop(timeout) diff --git a/internal/app/app_common.go b/internal/app/app_common.go index 4867911..0d788b5 100644 --- a/internal/app/app_common.go +++ b/internal/app/app_common.go @@ -252,6 +252,12 @@ func (a *App) shutdownHTTP2() { } // reopenLogs reinitializes the logger from current config. 
+func (a *App) shutdownStream() { + if a.streamSrv != nil { + a.streamSrv.Stop() + } +} + func (a *App) reopenLogs() { if a.cfg != nil { logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format) diff --git a/internal/stream/server_coverage_test.go b/internal/stream/server_coverage_test.go index 8fc419a..6461669 100644 --- a/internal/stream/server_coverage_test.go +++ b/internal/stream/server_coverage_test.go @@ -12,6 +12,7 @@ package stream import ( + "bytes" "fmt" "io" "net" @@ -63,7 +64,7 @@ func TestStart_NoListeners(t *testing.T) { require.NoError(t, err) assert.True(t, s.running.Load()) - s.running.Store(false) + s.Stop() } func TestStart_WithTCPListeners(t *testing.T) { @@ -81,12 +82,7 @@ func TestStart_WithTCPListeners(t *testing.T) { require.NoError(t, err) assert.True(t, s.running.Load()) - s.running.Store(false) - s.mu.RLock() - for _, ln := range s.listeners { - _ = ln.Close() - } - s.mu.RUnlock() + s.Stop() } func TestStart_AcceptConnections(t *testing.T) { @@ -115,7 +111,7 @@ func TestStart_AcceptConnections(t *testing.T) { proxyAddr := ln.Addr().String() s.mu.Lock() - s.listeners[proxyAddr] = ln + s.listeners["test"] = ln s.mu.Unlock() err = s.Start() @@ -136,12 +132,7 @@ func TestStart_AcceptConnections(t *testing.T) { _ = clientConn.Close() - s.running.Store(false) - s.mu.RLock() - for _, l := range s.listeners { - _ = l.Close() - } - s.mu.RUnlock() + s.Stop() _ = backendLn.Close() } @@ -637,3 +628,138 @@ func TestStartCleanupTicker_StopsOnSignal(t *testing.T) { t.Fatal("startCleanupTicker did not stop after signal") } } + +func TestHandleConnection_MultipleUpstreams(t *testing.T) { + s := NewServer() + + backend1, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + backend2, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer backend1.Close() + defer backend2.Close() + + go func() { + conn, err := backend1.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := 
conn.Read(buf) + _, _ = conn.Write(append([]byte("backend1:"), buf[:n]...)) + }() + + go func() { + conn, err := backend2.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, _ := conn.Read(buf) + _, _ = conn.Write(append([]byte("backend2:"), buf[:n]...)) + }() + + _ = s.AddUpstream("upstream1", []TargetSpec{{Addr: backend1.Addr().String()}}, "round_robin", HealthCheckSpec{}) + _ = s.AddUpstream("upstream2", []TargetSpec{{Addr: backend2.Addr().String()}}, "round_robin", HealthCheckSpec{}) + s.upstreams["upstream1"].targets[0].healthy.Store(true) + s.upstreams["upstream2"].targets[0].healthy.Store(true) + + ln1, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + ln2, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + s.mu.Lock() + s.listeners["upstream1"] = ln1 + s.listeners["upstream2"] = ln2 + s.mu.Unlock() + + err = s.Start() + require.NoError(t, err) + defer s.Stop() + + conn1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn1.Close() + + _, err = conn1.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 1024) + _ = conn1.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn1.Read(buf) + require.NoError(t, err) + assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend1:")), "should route to backend1, got: %s", string(buf[:n])) + + conn2, err := net.DialTimeout("tcp", ln2.Addr().String(), 2*time.Second) + require.NoError(t, err) + defer conn2.Close() + + _, err = conn2.Write([]byte("world")) + require.NoError(t, err) + + _ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err = conn2.Read(buf) + require.NoError(t, err) + assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend2:")), "should route to backend2, got: %s", string(buf[:n])) +} + +func TestStop(t *testing.T) { + s := NewServer() + + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer 
backendLn.Close() + + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + _, _ = io.Copy(conn, conn) + _ = conn.Close() + } + }() + + _ = s.AddUpstream("test", []TargetSpec{{Addr: backendLn.Addr().String()}}, "round_robin", HealthCheckSpec{}) + s.upstreams["test"].targets[0].healthy.Store(true) + + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + s.mu.Lock() + s.listeners["test"] = ln + s.mu.Unlock() + + err = s.Start() + require.NoError(t, err) + + clientConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + require.NoError(t, err) + _ = clientConn.Close() + + s.Stop() + + assert.False(t, s.running.Load()) + + s.mu.RLock() + assert.Empty(t, s.listeners) + s.mu.RUnlock() +} + +func TestStop_Idempotent(t *testing.T) { + s := NewServer() + s.Stop() + s.Stop() +} + +func TestStart_DoubleStart(t *testing.T) { + s := NewServer() + err := s.Start() + require.NoError(t, err) + err = s.Start() + assert.Error(t, err) + s.Stop() +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 4bb906f..ae77e7a 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -29,6 +29,7 @@ package stream import ( + "fmt" "hash/fnv" "io" "net" @@ -286,18 +287,14 @@ func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target { // Server TCP/UDP Stream 代理服务器。 type Server struct { - // listeners TCP 监听器映射,按 upstream 名称索引 - listeners map[string]net.Listener - // udpServers UDP 服务器映射 + listeners map[string]net.Listener udpServers map[string]*udpServer - // upstreams 上游配置映射 - upstreams map[string]*Upstream - // connCount 当前连接数 - connCount atomic.Int64 - // mu 读写锁,保护并发访问 - mu sync.RWMutex - // running 运行状态标志 - running atomic.Bool + upstreams map[string]*Upstream + connCount atomic.Int64 + mu sync.RWMutex + running atomic.Bool + wg sync.WaitGroup + stopCh chan struct{} } // Upstream Stream 上游配置。 @@ -393,6 +390,7 @@ func NewServer() *Server { listeners: 
make(map[string]net.Listener), udpServers: make(map[string]*udpServer), upstreams: make(map[string]*Upstream), + stopCh: make(chan struct{}), } } @@ -491,25 +489,70 @@ func (s *Server) ListenUDP(addr string, upstreamName string, timeout time.Durati // Start 启动 Stream 服务器。 func (s *Server) Start() error { - s.running.Store(true) + if !s.running.CompareAndSwap(false, true) { + return fmt.Errorf("stream server already running") + } s.mu.RLock() defer s.mu.RUnlock() - // 启动 TCP 监听器 for addr, listener := range s.listeners { - go s.acceptLoop(addr, listener) + s.wg.Add(1) + go func(a string, ln net.Listener) { + defer s.wg.Done() + s.acceptLoop(a, ln) + }(addr, listener) } - // 启动 UDP 服务器 for _, udpSrv := range s.udpServers { - go udpSrv.serve() - go udpSrv.startCleanupTicker() + s.wg.Add(1) + go func(u *udpServer) { + defer s.wg.Done() + u.serve() + }(udpSrv) + s.wg.Add(1) + go func(u *udpServer) { + defer s.wg.Done() + u.startCleanupTicker() + }(udpSrv) } return nil } +// Stop shuts the stream server down: it closes stopCh, every TCP listener, every UDP server and each health checker's stop channel, waits for the goroutines launched by Start, then clears the listener and UDP server maps. It is a no-op when the server is not running. stopCh is not recreated here, so restarting a stopped Server is not supported (a second Stop would close stopCh again and panic). 
+func (s *Server) Stop() { + if !s.running.CompareAndSwap(true, false) { + return + } + + close(s.stopCh) + + s.mu.Lock() + for _, ln := range s.listeners { + _ = ln.Close() + } + for _, udpSrv := range s.udpServers { + close(udpSrv.stopCh) + if udpSrv.conn != nil { + _ = udpSrv.conn.Close() + } + } + for _, upstream := range s.upstreams { + if upstream.healthChk != nil && upstream.healthChk.stopCh != nil { + close(upstream.healthChk.stopCh) + } + } + s.mu.Unlock() + + s.wg.Wait() + + s.mu.Lock() + s.listeners = make(map[string]net.Listener) + s.udpServers = make(map[string]*udpServer) + s.mu.Unlock() +} + // acceptLoop 接受连接循环。 // // 在单独的 goroutine 中运行,持续接受 TCP 连接。 @@ -523,7 +566,12 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) { conn, err := listener.Accept() if err != nil { if !s.running.Load() { - return // 正常关闭 + return + } + select { + case <-s.stopCh: + return + default: } continue } @@ -544,19 +592,14 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) { // 参数: // - clientConn: 客户端连接 // - addr: 监听地址 -func (s *Server) handleConnection(clientConn net.Conn, _ string) { +func (s *Server) handleConnection(clientConn net.Conn, addr string) { defer func() { _ = clientConn.Close() s.connCount.Add(-1) }() s.mu.RLock() - // 根据监听地址找到对应 upstream(简化:用第一个) - var upstream *Upstream - for _, up := range s.upstreams { - upstream = up - break - } + upstream := s.upstreams[addr] s.mu.RUnlock() if upstream == nil { diff --git a/internal/stream/stream_coverage_test.go b/internal/stream/stream_coverage_test.go index 0f20ed3..f02d8a1 100644 --- a/internal/stream/stream_coverage_test.go +++ b/internal/stream/stream_coverage_test.go @@ -157,7 +157,7 @@ func TestHandleConnection_NoHealthyTarget(t *testing.T) { done := make(chan struct{}) go func() { - s.handleConnection(clientConn, "127.0.0.1:0") + s.handleConnection(clientConn, "test2") close(done) }() @@ -200,7 +200,7 @@ func TestHandleConnection_DialFail(t *testing.T) { done := make(chan 
struct{}) go func() { - s.handleConnection(clientConn, "127.0.0.1:0") + s.handleConnection(clientConn, "test3") close(done) }()