fix(stream): correct upstream selection and add graceful shutdown
- Fix handleConnection to use its addr parameter for a direct lookup in the upstream map, instead of always selecting the first upstream.
- Add Server.Stop() for graceful shutdown: it closes listeners, cleans up UDP servers, terminates health checkers, and joins goroutines.
- Add shutdownStream() to App and call it in the SIGTERM/SIGQUIT/SIGUSR2 signal handlers to prevent goroutine and port leaks on shutdown.
This commit is contained in:
parent
f12ffd180f
commit
7204432ca0
@ -123,6 +123,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
|||||||
timeout = 30 * time.Second
|
timeout = 30 * time.Second
|
||||||
}
|
}
|
||||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
|
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
|
||||||
|
a.shutdownStream()
|
||||||
a.shutdownHTTP2()
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.GracefulStop(timeout)
|
_ = a.srv.GracefulStop(timeout)
|
||||||
@ -139,6 +140,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
|||||||
} else {
|
} else {
|
||||||
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
|
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
|
||||||
}
|
}
|
||||||
|
a.shutdownStream()
|
||||||
a.shutdownHTTP2()
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.StopWithTimeout(timeout)
|
_ = a.srv.StopWithTimeout(timeout)
|
||||||
@ -329,6 +331,7 @@ func (a *App) gracefulUpgrade() {
|
|||||||
if timeout <= 0 {
|
if timeout <= 0 {
|
||||||
timeout = 30 * time.Second
|
timeout = 30 * time.Second
|
||||||
}
|
}
|
||||||
|
a.shutdownStream()
|
||||||
a.shutdownHTTP2()
|
a.shutdownHTTP2()
|
||||||
a.shutdownHTTP3()
|
a.shutdownHTTP3()
|
||||||
_ = a.srv.GracefulStop(timeout)
|
_ = a.srv.GracefulStop(timeout)
|
||||||
|
|||||||
@ -252,6 +252,12 @@ func (a *App) shutdownHTTP2() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// reopenLogs reinitializes the logger from current config.
|
// reopenLogs reinitializes the logger from current config.
|
||||||
|
func (a *App) shutdownStream() {
|
||||||
|
if a.streamSrv != nil {
|
||||||
|
a.streamSrv.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) reopenLogs() {
|
func (a *App) reopenLogs() {
|
||||||
if a.cfg != nil {
|
if a.cfg != nil {
|
||||||
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
||||||
|
|||||||
@ -12,6 +12,7 @@
|
|||||||
package stream
|
package stream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -63,7 +64,7 @@ func TestStart_NoListeners(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, s.running.Load())
|
assert.True(t, s.running.Load())
|
||||||
|
|
||||||
s.running.Store(false)
|
s.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStart_WithTCPListeners(t *testing.T) {
|
func TestStart_WithTCPListeners(t *testing.T) {
|
||||||
@ -81,12 +82,7 @@ func TestStart_WithTCPListeners(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, s.running.Load())
|
assert.True(t, s.running.Load())
|
||||||
|
|
||||||
s.running.Store(false)
|
s.Stop()
|
||||||
s.mu.RLock()
|
|
||||||
for _, ln := range s.listeners {
|
|
||||||
_ = ln.Close()
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStart_AcceptConnections(t *testing.T) {
|
func TestStart_AcceptConnections(t *testing.T) {
|
||||||
@ -115,7 +111,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
|||||||
proxyAddr := ln.Addr().String()
|
proxyAddr := ln.Addr().String()
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.listeners[proxyAddr] = ln
|
s.listeners["test"] = ln
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
err = s.Start()
|
err = s.Start()
|
||||||
@ -136,12 +132,7 @@ func TestStart_AcceptConnections(t *testing.T) {
|
|||||||
|
|
||||||
_ = clientConn.Close()
|
_ = clientConn.Close()
|
||||||
|
|
||||||
s.running.Store(false)
|
s.Stop()
|
||||||
s.mu.RLock()
|
|
||||||
for _, l := range s.listeners {
|
|
||||||
_ = l.Close()
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
|
||||||
_ = backendLn.Close()
|
_ = backendLn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -637,3 +628,138 @@ func TestStartCleanupTicker_StopsOnSignal(t *testing.T) {
|
|||||||
t.Fatal("startCleanupTicker did not stop after signal")
|
t.Fatal("startCleanupTicker did not stop after signal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleConnection_MultipleUpstreams(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
backend1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
backend2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer backend1.Close()
|
||||||
|
defer backend2.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := backend1.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
_, _ = conn.Write(append([]byte("backend1:"), buf[:n]...))
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := backend2.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
_, _ = conn.Write(append([]byte("backend2:"), buf[:n]...))
|
||||||
|
}()
|
||||||
|
|
||||||
|
_ = s.AddUpstream("upstream1", []TargetSpec{{Addr: backend1.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||||
|
_ = s.AddUpstream("upstream2", []TargetSpec{{Addr: backend2.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||||
|
s.upstreams["upstream1"].targets[0].healthy.Store(true)
|
||||||
|
s.upstreams["upstream2"].targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
ln1, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.listeners["upstream1"] = ln1
|
||||||
|
s.listeners["upstream2"] = ln2
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
err = s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
|
conn1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn1.Close()
|
||||||
|
|
||||||
|
_, err = conn1.Write([]byte("hello"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
_ = conn1.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err := conn1.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend1:")), "should route to backend1, got: %s", string(buf[:n]))
|
||||||
|
|
||||||
|
conn2, err := net.DialTimeout("tcp", ln2.Addr().String(), 2*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn2.Close()
|
||||||
|
|
||||||
|
_, err = conn2.Write([]byte("world"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_ = conn2.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
n, err = conn2.Read(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, bytes.HasPrefix(buf[:n], []byte("backend2:")), "should route to backend2, got: %s", string(buf[:n]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStop(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
|
||||||
|
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer backendLn.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := backendLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = io.Copy(conn, conn)
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_ = s.AddUpstream("test", []TargetSpec{{Addr: backendLn.Addr().String()}}, "round_robin", HealthCheckSpec{})
|
||||||
|
s.upstreams["test"].targets[0].healthy.Store(true)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
s.mu.Lock()
|
||||||
|
s.listeners["test"] = ln
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
err = s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_ = clientConn.Close()
|
||||||
|
|
||||||
|
s.Stop()
|
||||||
|
|
||||||
|
assert.False(t, s.running.Load())
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
assert.Empty(t, s.listeners)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStop_Idempotent(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
s.Stop()
|
||||||
|
s.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStart_DoubleStart(t *testing.T) {
|
||||||
|
s := NewServer()
|
||||||
|
err := s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = s.Start()
|
||||||
|
assert.Error(t, err)
|
||||||
|
s.Stop()
|
||||||
|
}
|
||||||
|
|||||||
@ -29,6 +29,7 @@
|
|||||||
package stream
|
package stream
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -286,18 +287,14 @@ func (i *ipHash) SelectByIP(targets []*Target, clientIP string) *Target {
|
|||||||
|
|
||||||
// Server TCP/UDP Stream 代理服务器。
|
// Server TCP/UDP Stream 代理服务器。
|
||||||
type Server struct {
|
type Server struct {
|
||||||
// listeners TCP 监听器映射,按 upstream 名称索引
|
listeners map[string]net.Listener
|
||||||
listeners map[string]net.Listener
|
|
||||||
// udpServers UDP 服务器映射
|
|
||||||
udpServers map[string]*udpServer
|
udpServers map[string]*udpServer
|
||||||
// upstreams 上游配置映射
|
upstreams map[string]*Upstream
|
||||||
upstreams map[string]*Upstream
|
connCount atomic.Int64
|
||||||
// connCount 当前连接数
|
mu sync.RWMutex
|
||||||
connCount atomic.Int64
|
running atomic.Bool
|
||||||
// mu 读写锁,保护并发访问
|
wg sync.WaitGroup
|
||||||
mu sync.RWMutex
|
stopCh chan struct{}
|
||||||
// running 运行状态标志
|
|
||||||
running atomic.Bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upstream Stream 上游配置。
|
// Upstream Stream 上游配置。
|
||||||
@ -393,6 +390,7 @@ func NewServer() *Server {
|
|||||||
listeners: make(map[string]net.Listener),
|
listeners: make(map[string]net.Listener),
|
||||||
udpServers: make(map[string]*udpServer),
|
udpServers: make(map[string]*udpServer),
|
||||||
upstreams: make(map[string]*Upstream),
|
upstreams: make(map[string]*Upstream),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -491,25 +489,70 @@ func (s *Server) ListenUDP(addr string, upstreamName string, timeout time.Durati
|
|||||||
|
|
||||||
// Start 启动 Stream 服务器。
|
// Start 启动 Stream 服务器。
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
s.running.Store(true)
|
if !s.running.CompareAndSwap(false, true) {
|
||||||
|
return fmt.Errorf("stream server already running")
|
||||||
|
}
|
||||||
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// 启动 TCP 监听器
|
|
||||||
for addr, listener := range s.listeners {
|
for addr, listener := range s.listeners {
|
||||||
go s.acceptLoop(addr, listener)
|
s.wg.Add(1)
|
||||||
|
go func(a string, ln net.Listener) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
s.acceptLoop(a, ln)
|
||||||
|
}(addr, listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动 UDP 服务器
|
|
||||||
for _, udpSrv := range s.udpServers {
|
for _, udpSrv := range s.udpServers {
|
||||||
go udpSrv.serve()
|
s.wg.Add(1)
|
||||||
go udpSrv.startCleanupTicker()
|
go func(u *udpServer) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
u.serve()
|
||||||
|
}(udpSrv)
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func(u *udpServer) {
|
||||||
|
defer s.wg.Done()
|
||||||
|
u.startCleanupTicker()
|
||||||
|
}(udpSrv)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop stops the stream server, closing all listeners and waiting for goroutines to finish.
|
||||||
|
func (s *Server) Stop() {
|
||||||
|
if !s.running.CompareAndSwap(true, false) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
close(s.stopCh)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
for _, ln := range s.listeners {
|
||||||
|
_ = ln.Close()
|
||||||
|
}
|
||||||
|
for _, udpSrv := range s.udpServers {
|
||||||
|
close(udpSrv.stopCh)
|
||||||
|
if udpSrv.conn != nil {
|
||||||
|
_ = udpSrv.conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, upstream := range s.upstreams {
|
||||||
|
if upstream.healthChk != nil && upstream.healthChk.stopCh != nil {
|
||||||
|
close(upstream.healthChk.stopCh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
s.wg.Wait()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.listeners = make(map[string]net.Listener)
|
||||||
|
s.udpServers = make(map[string]*udpServer)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// acceptLoop 接受连接循环。
|
// acceptLoop 接受连接循环。
|
||||||
//
|
//
|
||||||
// 在单独的 goroutine 中运行,持续接受 TCP 连接。
|
// 在单独的 goroutine 中运行,持续接受 TCP 连接。
|
||||||
@ -523,7 +566,12 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
|||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !s.running.Load() {
|
if !s.running.Load() {
|
||||||
return // 正常关闭
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
return
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -544,19 +592,14 @@ func (s *Server) acceptLoop(addr string, listener net.Listener) {
|
|||||||
// 参数:
|
// 参数:
|
||||||
// - clientConn: 客户端连接
|
// - clientConn: 客户端连接
|
||||||
// - addr: 监听地址
|
// - addr: 监听地址
|
||||||
func (s *Server) handleConnection(clientConn net.Conn, _ string) {
|
func (s *Server) handleConnection(clientConn net.Conn, addr string) {
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = clientConn.Close()
|
_ = clientConn.Close()
|
||||||
s.connCount.Add(-1)
|
s.connCount.Add(-1)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
// 根据监听地址找到对应 upstream(简化:用第一个)
|
upstream := s.upstreams[addr]
|
||||||
var upstream *Upstream
|
|
||||||
for _, up := range s.upstreams {
|
|
||||||
upstream = up
|
|
||||||
break
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if upstream == nil {
|
if upstream == nil {
|
||||||
|
|||||||
@ -157,7 +157,7 @@ func TestHandleConnection_NoHealthyTarget(t *testing.T) {
|
|||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
s.handleConnection(clientConn, "test2")
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -200,7 +200,7 @@ func TestHandleConnection_DialFail(t *testing.T) {
|
|||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
s.handleConnection(clientConn, "127.0.0.1:0")
|
s.handleConnection(clientConn, "test3")
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user