fix(stream): correct upstream selection and add graceful shutdown
- Fix handleConnection to use the addr parameter for a direct upstream map lookup instead of always selecting the first upstream.
- Add Server.Stop() for graceful shutdown: listener closing, UDP server cleanup, health checker termination, and goroutine joining.
- Add shutdownStream() to App and call it in the SIGTERM/SIGQUIT/SIGUSR2 signal handlers to prevent goroutine and port leaks on shutdown.
This commit is contained in:
parent
f12ffd180f
commit
7204432ca0
@ -123,6 +123,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
|
||||
a.shutdownStream()
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(timeout)
|
||||
@ -139,6 +140,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
} else {
|
||||
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
|
||||
}
|
||||
a.shutdownStream()
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.StopWithTimeout(timeout)
|
||||
@ -329,6 +331,7 @@ func (a *App) gracefulUpgrade() {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
a.shutdownStream()
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(timeout)
|
||||
|
||||
@ -252,6 +252,12 @@ func (a *App) shutdownHTTP2() {
|
||||
}
|
||||
|
||||
// shutdownStream stops the stream server; it does nothing when a.streamSrv is nil.
|
||||
func (a *App) shutdownStream() {
|
||||
if a.streamSrv != nil {
|
||||
a.streamSrv.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// reopenLogs reinitializes the logger from current config.
func (a *App) reopenLogs() {
|
||||
if a.cfg != nil {
|
||||
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -63,7 +64,7 @@ func TestStart_NoListeners(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, s.running.Load())
|
||||
|
||||
s.running.Store(false)
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
func TestStart_WithTCPListeners(t *testing.T) {
|
||||
@ -81,12 +82,7 @@ func TestStart_WithTCPListeners(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, s.running.Load())
|
||||
|
||||
s.running.Store(false)
|
||||
s.mu.RLock()
|
||||
for _, ln := range s.listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
func TestStart_AcceptConnections(t *testing.T) {
|
||||
@ -115,7 +111,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
||||
proxyAddr := ln.Addr().String()
|
||||
|
||||
s.mu.Lock()
|
||||
s.listeners[proxyAddr] = ln
|
||||
s.listeners["test"] = ln
|
||||
s.mu.Unlock()
|
||||
|
||||
err = s.Start()
|
||||
@ -136,12 +132,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
||||
|
||||
_ = clientConn.Close()
|
||||
|
||||
s.running.Store(false)
|
||||
s.mu.RLock()
|
||||
for _, l := range s.listeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
s.Stop()
|
||||
_ = backendLn.Close()
|
||||
}
|
||||
|
||||
@ -637,3 +628,138 @@ func TestStartCleanupTicker_StopsOnSignal(t *testing.T) {
|
||||
t.Fatal("startCleanupTicker did not stop after signal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConnection_MultipleUpstreams(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
backend1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
backend2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backend1.Close()
|
||||
defer backend2.Close()
|
||||
|
||||
go func() {
|
||||
conn, err := backend1.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := conn.Read(buf)
|
||||
_, _ = conn.Write(append([]byte("backend1:"), buf[:n]...))
|
||||
}()
|
||||
|
||||
go func() {
|
||||
conn, err := backend2.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := conn.Read(buf)
|
||||
_, _ = conn.Write(append([]byte("backend2:"), buf[:n]...))
|
||||
}()
|
||||
|
||||
_ = s.AddUpstream("upstream1", []TargetSpec{{Addr: backend1.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||
_ = s.AddUpstream("upstream2", []TargetSpec{{Addr: backend2.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||
s.upstreams["upstream1"].targets[0].healthy.Store(true)
|
||||
s.upstreams["upstream2"].targets[0].healthy.Store(true)
|
||||
|
||||
ln1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s.mu.Lock()
|
||||
s.listeners["upstream1"] = ln1
|
||||
s.listeners["upstream2"] = ln2
|
||||
s.mu.Unlock()
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
defer s.Stop()
|
||||
|
||||
conn1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
|
||||
_, err = conn1.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
_ = conn1.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err := conn1.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend1:")), "should route to backend1, got: %s", string(buf[:n]))
|
||||
|
||||
conn2, err := net.DialTimeout("tcp", ln2.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
|
||||
_, err = conn2.Write([]byte("world"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err = conn2.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend2:")), "should route to backend2, got: %s", string(buf[:n]))
|
||||
}
|
||||
|
||||
func TestStop(t *testing.T) {
|
||||
s := NewServer()
|
||||
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer backendLn.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := backendLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(conn, conn)
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
_ = s.AddUpstream("test", []TargetSpec{{Addr: backendLn.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||
s.upstreams["test"].targets[0].healthy.Store(true)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
s.mu.Lock()
|
||||
s.listeners["test"] = ln
|
||||
s.mu.Unlock()
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
_ = clientConn.Close()
|
||||
|
||||
s.Stop()
|
||||
|
||||
assert.False(t, s.running.Load())
|
||||
|
||||
s.mu.RLock()
|
||||
assert.Empty(t, s.listeners)
|
||||
s.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestStop_Idempotent(t *testing.T) {
|
||||
s := NewServer()
|
||||
s.Stop()
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
func TestStart_DoubleStart(t *testing.T) {
|
||||
s := NewServer()
|
||||
err := s.Start()
|
||||
require.NoError(t, err)
|
||||
err = s.Start()
|
||||
assert.Error(t, err)
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
@ -29,6 +29,7 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"net"
|
||||
@ -286,18 +287,14 @@ func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target {
|
||||
|
||||
// Server TCP/UDP Stream 代理服务器。
|
||||
type Server struct {
|
||||
// listeners TCP 监听器映射,按 upstream 名称索引
|
||||
listeners map[string]net.Listener
|
||||
// udpServers UDP 服务器映射
|
||||
udpServers map[string]*udpServer
|
||||
// upstreams 上游配置映射
|
||||
upstreams map[string]*Upstream
|
||||
// connCount 当前连接数
|
||||
connCount atomic.Int64
|
||||
// mu 读写锁,保护并发访问
|
||||
mu sync.RWMutex
|
||||
// running 运行状态标志
|
||||
running atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// Upstream Stream 上游配置。
|
||||
@ -393,6 +390,7 @@ func NewServer() *Server {
|
||||
listeners: make(map[string]net.Listener),
|
||||
udpServers: make(map[string]*udpServer),
|
||||
upstreams: make(map[string]*Upstream),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -491,25 +489,70 @@ func (s *Server) ListenUDP(addr string, upstreamName string, timeout time.Durati
|
||||
|
||||
// Start 启动 Stream 服务器。
|
||||
func (s *Server) Start() error {
|
||||
s.running.Store(true)
|
||||
if !s.running.CompareAndSwap(false, true) {
|
||||
return fmt.Errorf("stream server already running")
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// 启动 TCP 监听器
|
||||
for addr, listener := range s.listeners {
|
||||
go s.acceptLoop(addr, listener)
|
||||
s.wg.Add(1)
|
||||
go func(a string, ln net.Listener) {
|
||||
defer s.wg.Done()
|
||||
s.acceptLoop(a, ln)
|
||||
}(addr, listener)
|
||||
}
|
||||
|
||||
// 启动 UDP 服务器
|
||||
for _, udpSrv := range s.udpServers {
|
||||
go udpSrv.serve()
|
||||
go udpSrv.startCleanupTicker()
|
||||
s.wg.Add(1)
|
||||
go func(u *udpServer) {
|
||||
defer s.wg.Done()
|
||||
u.serve()
|
||||
}(udpSrv)
|
||||
s.wg.Add(1)
|
||||
go func(u *udpServer) {
|
||||
defer s.wg.Done()
|
||||
u.startCleanupTicker()
|
||||
}(udpSrv)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the stream server, closing all listeners and waiting for goroutines to finish.
|
||||
func (s *Server) Stop() {
|
||||
if !s.running.CompareAndSwap(true, false) {
|
||||
return
|
||||
}
|
||||
|
||||
close(s.stopCh)
|
||||
|
||||
s.mu.Lock()
|
||||
for _, ln := range s.listeners {
|
||||
_ = ln.Close()
|
||||
}
|
||||
for _, udpSrv := range s.udpServers {
|
||||
close(udpSrv.stopCh)
|
||||
if udpSrv.conn != nil {
|
||||
_ = udpSrv.conn.Close()
|
||||
}
|
||||
}
|
||||
for _, upstream := range s.upstreams {
|
||||
if upstream.healthChk != nil && upstream.healthChk.stopCh != nil {
|
||||
close(upstream.healthChk.stopCh)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
s.wg.Wait()
|
||||
|
||||
s.mu.Lock()
|
||||
s.listeners = make(map[string]net.Listener)
|
||||
s.udpServers = make(map[string]*udpServer)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// acceptLoop 接受连接循环。
|
||||
//
|
||||
// 在单独的 goroutine 中运行,持续接受 TCP 连接。
|
||||
@ -523,7 +566,12 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if !s.running.Load() {
|
||||
return // 正常关闭
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
continue
|
||||
}
|
||||
@ -544,19 +592,14 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
||||
// 参数:
|
||||
// - clientConn: 客户端连接
|
||||
// - addr: 监听地址
|
||||
func (s *Server) handleConnection(clientConn net.Conn, _ string) {
|
||||
func (s *Server) handleConnection(clientConn net.Conn, addr string) {
|
||||
defer func() {
|
||||
_ = clientConn.Close()
|
||||
s.connCount.Add(-1)
|
||||
}()
|
||||
|
||||
s.mu.RLock()
|
||||
// 根据监听地址找到对应 upstream(简化:用第一个)
|
||||
var upstream *Upstream
|
||||
for _, up := range s.upstreams {
|
||||
upstream = up
|
||||
break
|
||||
}
|
||||
upstream := s.upstreams[addr]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if upstream == nil {
|
||||
|
||||
@ -157,7 +157,7 @@ func TestHandleConnection_NoHealthyTarget(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||
s.handleConnection(clientConn, "test2")
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@ -200,7 +200,7 @@ func TestHandleConnection_DialFail(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
||||
s.handleConnection(clientConn, "test3")
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user