fix(stream): correct upstream selection and add graceful shutdown

- Fix handleConnection to use addr parameter for direct upstream map
  lookup instead of always selecting the first upstream
- Add Server.Stop() for graceful shutdown with listener closing, UDP
  server cleanup, health checker termination, and goroutine joining
- Add shutdownStream() to App and call it in SIGTERM/SIGQUIT/SIGUSR2
  signal handlers to prevent goroutine and port leaks on shutdown
This commit is contained in:
xfy 2026-06-10 13:45:35 +08:00
parent f12ffd180f
commit 7204432ca0
5 changed files with 219 additions and 41 deletions

View File

@ -123,6 +123,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
timeout = 30 * time.Second
}
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
a.shutdownStream()
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.GracefulStop(timeout)
@ -139,6 +140,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
} else {
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
}
a.shutdownStream()
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.StopWithTimeout(timeout)
@ -329,6 +331,7 @@ func (a *App) gracefulUpgrade() {
if timeout <= 0 {
timeout = 30 * time.Second
}
a.shutdownStream()
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.GracefulStop(timeout)

View File

@ -252,6 +252,12 @@ func (a *App) shutdownHTTP2() {
}
// reopenLogs reinitializes the logger from current config.
func (a *App) shutdownStream() {
if a.streamSrv != nil {
a.streamSrv.Stop()
}
}
func (a *App) reopenLogs() {
if a.cfg != nil {
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)

View File

@ -12,6 +12,7 @@
package stream
import (
"bytes"
"fmt"
"io"
"net"
@ -63,7 +64,7 @@ func TestStart_NoListeners(t *testing.T) {
require.NoError(t, err)
assert.True(t, s.running.Load())
s.running.Store(false)
s.Stop()
}
func TestStart_WithTCPListeners(t *testing.T) {
@ -81,12 +82,7 @@ func TestStart_WithTCPListeners(t *testing.T) {
require.NoError(t, err)
assert.True(t, s.running.Load())
s.running.Store(false)
s.mu.RLock()
for _, ln := range s.listeners {
_ = ln.Close()
}
s.mu.RUnlock()
s.Stop()
}
func TestStart_AcceptConnections(t *testing.T) {
@ -115,7 +111,7 @@ func TestStart_AcceptConnections(t *testing.T) {
proxyAddr := ln.Addr().String()
s.mu.Lock()
s.listeners[proxyAddr] = ln
s.listeners["test"] = ln
s.mu.Unlock()
err = s.Start()
@ -136,12 +132,7 @@ func TestStart_AcceptConnections(t *testing.T) {
_ = clientConn.Close()
s.running.Store(false)
s.mu.RLock()
for _, l := range s.listeners {
_ = l.Close()
}
s.mu.RUnlock()
s.Stop()
_ = backendLn.Close()
}
@ -637,3 +628,138 @@ func TestStartCleanupTicker_StopsOnSignal(t *testing.T) {
t.Fatal("startCleanupTicker did not stop after signal")
}
}
func TestHandleConnection_MultipleUpstreams(t *testing.T) {
s := NewServer()
backend1, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
backend2, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer backend1.Close()
defer backend2.Close()
go func() {
conn, err := backend1.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(append([]byte("backend1:"), buf[:n]...))
}()
go func() {
conn, err := backend2.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 1024)
n, _ := conn.Read(buf)
_, _ = conn.Write(append([]byte("backend2:"), buf[:n]...))
}()
_ = s.AddUpstream("upstream1", []TargetSpec{{Addr: backend1.Addr().String()}}, "round_robin", HealthCheckSpec{})
_ = s.AddUpstream("upstream2", []TargetSpec{{Addr: backend2.Addr().String()}}, "round_robin", HealthCheckSpec{})
s.upstreams["upstream1"].targets[0].healthy.Store(true)
s.upstreams["upstream2"].targets[0].healthy.Store(true)
ln1, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ln2, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s.mu.Lock()
s.listeners["upstream1"] = ln1
s.listeners["upstream2"] = ln2
s.mu.Unlock()
err = s.Start()
require.NoError(t, err)
defer s.Stop()
conn1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn1.Close()
_, err = conn1.Write([]byte("hello"))
require.NoError(t, err)
buf := make([]byte, 1024)
_ = conn1.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := conn1.Read(buf)
require.NoError(t, err)
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend1:")), "should route to backend1, got: %s", string(buf[:n]))
conn2, err := net.DialTimeout("tcp", ln2.Addr().String(), 2*time.Second)
require.NoError(t, err)
defer conn2.Close()
_, err = conn2.Write([]byte("world"))
require.NoError(t, err)
_ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err = conn2.Read(buf)
require.NoError(t, err)
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend2:")), "should route to backend2, got: %s", string(buf[:n]))
}
func TestStop(t *testing.T) {
s := NewServer()
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer backendLn.Close()
go func() {
for {
conn, err := backendLn.Accept()
if err != nil {
return
}
_, _ = io.Copy(conn, conn)
_ = conn.Close()
}
}()
_ = s.AddUpstream("test", []TargetSpec{{Addr: backendLn.Addr().String()}}, "round_robin", HealthCheckSpec{})
s.upstreams["test"].targets[0].healthy.Store(true)
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
s.mu.Lock()
s.listeners["test"] = ln
s.mu.Unlock()
err = s.Start()
require.NoError(t, err)
clientConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
require.NoError(t, err)
_ = clientConn.Close()
s.Stop()
assert.False(t, s.running.Load())
s.mu.RLock()
assert.Empty(t, s.listeners)
s.mu.RUnlock()
}
func TestStop_Idempotent(t *testing.T) {
s := NewServer()
s.Stop()
s.Stop()
}
func TestStart_DoubleStart(t *testing.T) {
s := NewServer()
err := s.Start()
require.NoError(t, err)
err = s.Start()
assert.Error(t, err)
s.Stop()
}

View File

@ -29,6 +29,7 @@
package stream
import (
"fmt"
"hash/fnv"
"io"
"net"
@ -286,18 +287,14 @@ func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target {
// Server TCP/UDP Stream 代理服务器。
type Server struct {
// listeners TCP 监听器映射,按 upstream 名称索引
listeners map[string]net.Listener
// udpServers UDP 服务器映射
udpServers map[string]*udpServer
// upstreams 上游配置映射
upstreams map[string]*Upstream
// connCount 当前连接数
connCount atomic.Int64
// mu 读写锁,保护并发访问
mu sync.RWMutex
// running 运行状态标志
running atomic.Bool
wg sync.WaitGroup
stopCh chan struct{}
}
// Upstream Stream 上游配置。
@ -393,6 +390,7 @@ func NewServer() *Server {
listeners: make(map[string]net.Listener),
udpServers: make(map[string]*udpServer),
upstreams: make(map[string]*Upstream),
stopCh: make(chan struct{}),
}
}
@ -491,25 +489,70 @@ func (s *Server) ListenUDP(addr string, upstreamName string, timeout time.Durati
// Start 启动 Stream 服务器。
func (s *Server) Start() error {
s.running.Store(true)
if !s.running.CompareAndSwap(false, true) {
return fmt.Errorf("stream server already running")
}
s.mu.RLock()
defer s.mu.RUnlock()
// 启动 TCP 监听器
for addr, listener := range s.listeners {
go s.acceptLoop(addr, listener)
s.wg.Add(1)
go func(a string, ln net.Listener) {
defer s.wg.Done()
s.acceptLoop(a, ln)
}(addr, listener)
}
// 启动 UDP 服务器
for _, udpSrv := range s.udpServers {
go udpSrv.serve()
go udpSrv.startCleanupTicker()
s.wg.Add(1)
go func(u *udpServer) {
defer s.wg.Done()
u.serve()
}(udpSrv)
s.wg.Add(1)
go func(u *udpServer) {
defer s.wg.Done()
u.startCleanupTicker()
}(udpSrv)
}
return nil
}
// Stop stops the stream server, closing all listeners and waiting for goroutines to finish.
func (s *Server) Stop() {
if !s.running.CompareAndSwap(true, false) {
return
}
close(s.stopCh)
s.mu.Lock()
for _, ln := range s.listeners {
_ = ln.Close()
}
for _, udpSrv := range s.udpServers {
close(udpSrv.stopCh)
if udpSrv.conn != nil {
_ = udpSrv.conn.Close()
}
}
for _, upstream := range s.upstreams {
if upstream.healthChk != nil && upstream.healthChk.stopCh != nil {
close(upstream.healthChk.stopCh)
}
}
s.mu.Unlock()
s.wg.Wait()
s.mu.Lock()
s.listeners = make(map[string]net.Listener)
s.udpServers = make(map[string]*udpServer)
s.mu.Unlock()
}
// acceptLoop 接受连接循环。
//
// 在单独的 goroutine 中运行,持续接受 TCP 连接。
@ -523,7 +566,12 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
conn, err := listener.Accept()
if err != nil {
if !s.running.Load() {
return // 正常关闭
return
}
select {
case <-s.stopCh:
return
default:
}
continue
}
@ -544,19 +592,14 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
// 参数:
// - clientConn: 客户端连接
// - addr: 监听地址
func (s *Server) handleConnection(clientConn net.Conn, _ string) {
func (s *Server) handleConnection(clientConn net.Conn, addr string) {
defer func() {
_ = clientConn.Close()
s.connCount.Add(-1)
}()
s.mu.RLock()
// 根据监听地址找到对应 upstream简化用第一个
var upstream *Upstream
for _, up := range s.upstreams {
upstream = up
break
}
upstream := s.upstreams[addr]
s.mu.RUnlock()
if upstream == nil {

View File

@ -157,7 +157,7 @@ func TestHandleConnection_NoHealthyTarget(t *testing.T) {
done := make(chan struct{})
go func() {
s.handleConnection(clientConn, "127.0.0.1:0")
s.handleConnection(clientConn, "test2")
close(done)
}()
@ -200,7 +200,7 @@ func TestHandleConnection_DialFail(t *testing.T) {
done := make(chan struct{})
go func() {
s.handleConnection(clientConn, "127.0.0.1:0")
s.handleConnection(clientConn, "test3")
close(done)
}()