From f3f78b24a862f31dd8eab2016e6e87c707289087 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 3 Jun 2026 11:42:45 +0800 Subject: [PATCH] feat(server,app): implement proper config hot reload via SIGHUP createListener now checks pre-set s.listeners (Path 2) for hot reload, not just upgradeManager.IsChild() (Path 1). Add DupListener to dup FDs so old/new servers own independent listeners. Reload rebuilds HTTP/2 and HTTP/3. Add matchInheritedListener with TCP any-addr matching. Add requiresFullRestart with VHost server count detection. --- internal/app/app.go | 111 ++++++++++++++++++++++++++++ internal/app/app_test.go | 3 - internal/server/server.go | 103 +++++++++++++++++++++----- internal/server/server_test.go | 130 +++++++++++++++++++++++++++++++++ 4 files changed, 327 insertions(+), 20 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index 2cff744..987c27c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -4,6 +4,7 @@ package app import ( "fmt" + "net" "os" "os/signal" "syscall" @@ -164,11 +165,121 @@ func (a *App) reloadConfig() { return } + if a.srv == nil { + a.cfg = newCfg + a.logger = logging.NewAppLogger(&newCfg.Logging) + a.logger.LogStartup("Config reloaded (no running server)", nil) + return + } + + if a.requiresFullRestart(newCfg) { + logging.Warn().Msg("Config requires full restart (listen address or mode changed). Use SIGUSR2 for graceful upgrade.") + return + } + + listeners := a.srv.GetListeners() + if len(listeners) == 0 { + a.logger.Error().Msg("Cannot reload: server has no saved listeners") + return + } + + duped := make([]net.Listener, len(listeners)) + for i, ln := range listeners { + duped[i], err = server.DupListener(ln) + if err != nil { + a.logger.Error().Err(err).Msg("Failed to dup listener for reload") + return + } + } + + newSrv := server.New(newCfg) + if a.resv != nil { + newSrv.SetResolver(a.resv) + } + newSrv.SetListeners(duped) + + startErr := make(chan error, 1) + go func() { + if err := newSrv.Start(); err != nil { + startErr <- err + } + }() + + select { + case err := <-startErr: + a.logger.Error().Err(err).Msg("Failed to start new server with reloaded config") + for _, ln := range duped { + _ = ln.Close() + } + return + case <-time.After(5 * time.Second): + } + + oldSrv := a.srv + oldHTTP2 := a.http2Srv + oldHTTP3 := a.http3Srv + + a.srv = newSrv a.cfg = newCfg a.logger = logging.NewAppLogger(&newCfg.Logging) + a.http2Srv = nil + a.http3Srv = nil + + a.initVariables() + a.initHTTP2() + a.initHTTP3() + + if a.upgradeMgr != nil { + a.upgradeMgr.SetListeners(newSrv.GetListeners()) + } + + go func() { + if oldHTTP2 != nil { + _ = oldHTTP2.Stop() + } + if oldHTTP3 != nil { + _ = oldHTTP3.Stop() + } + _ = oldSrv.GracefulStop(30 * time.Second) + }() + a.logger.LogStartup("Config reloaded successfully", nil) } +func (a *App) requiresFullRestart(newCfg *config.Config) bool { + if a.cfg.GetMode() != newCfg.GetMode() { + return true + } + oldMode := a.cfg.GetMode() + switch oldMode { + case config.ServerModeSingle: + if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 { + if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen { + return true + } + } + case config.ServerModeVHost: + if len(a.cfg.Servers) != len(newCfg.Servers) { + return true + } + if len(a.cfg.Servers) > 0 && len(newCfg.Servers) > 0 { + if a.cfg.Servers[0].Listen != newCfg.Servers[0].Listen { + return true + } + } + case config.ServerModeMultiServer: + if len(a.cfg.Servers) != len(newCfg.Servers) { + return true + } + for i := range a.cfg.Servers { + if a.cfg.Servers[i].Listen != newCfg.Servers[i].Listen { + return true + } + } + } + return false +} + func (a *App) gracefulUpgrade() { execPath, err := os.Executable() if err != nil { diff --git a/internal/app/app_test.go b/internal/app/app_test.go index dd8b32f..42e4dba 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -461,7 +461,6 @@ func TestHandleSignal_SIGINT(t *testing.T) { // TestHandleSignal_SIGHUP 测试 SIGHUP 信号处理(重载配置) func TestHandleSignal_SIGHUP(t *testing.T) { - // 创建临时配置文件 tmpDir := t.TempDir() cfgPath := filepath.Join(tmpDir, "config.yaml") cfgContent := ` @@ -1448,14 +1447,12 @@ logging: } app.logger = setupTestLogger() - // 发送 SIGHUP 信号 result := app.handleSignal(syscall.SIGHUP) if result != true { t.Error("Expected handleSignal(SIGHUP) to return true") } - // 验证配置已更新 if app.cfg.Servers[0].Listen != ":7070" { t.Errorf("Expected listen ':7070', got '%s'", app.cfg.Servers[0].Listen) } diff --git a/internal/server/server.go b/internal/server/server.go index ddba853..834d282 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -318,32 +318,31 @@ func (s *Server) Start() error { func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) { listenAddr := cfg.Listen + if s.upgradeManager != nil && s.upgradeManager.IsChild() { + inherited, _ := s.upgradeManager.GetInheritedListeners() + if ln := s.matchInheritedListener(inherited, listenAddr); ln != nil { + return ln, nil + } + } + + if len(s.listeners) > 0 { + if ln := s.matchInheritedListener(s.listeners, listenAddr); ln != nil { + return ln, nil + } + } + if strings.HasPrefix(listenAddr, "unix:") { - // Unix Socket 模式 socketPath := listenAddr[5:] - // 1. 检查继承的监听器(热升级场景) - if s.upgradeManager != nil && s.upgradeManager.IsChild() { - inherited, _ := s.upgradeManager.GetInheritedListeners() - for _, ln := range inherited { - if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath { - return ln, nil - } - } - } - - // 2. 清理旧 socket 文件 if _, err := os.Stat(socketPath); err == nil { _ = os.Remove(socketPath) } - // 3. 创建 Unix socket listener listener, err := net.Listen("unix", socketPath) if err != nil { return nil, fmt.Errorf("create unix socket failed: %w", err) } - // 4. 设置 socket 文件权限 mode := 0o666 if cfg.UnixSocket.Mode > 0 { mode = cfg.UnixSocket.Mode @@ -352,19 +351,89 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) logging.Warn().Err(err).Msg("Failed to set socket file permissions") } - // 5. 设置文件所有权(需要 root 权限) if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" { - // 简化处理:仅记录警告,实际实现需要 syscall.Chown logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped") } return listener, nil } - // TCP 模式 return net.Listen("tcp", listenAddr) } +func (s *Server) matchInheritedListener(inherited []net.Listener, listenAddr string) net.Listener { + if len(inherited) == 0 { + return nil + } + + if strings.HasPrefix(listenAddr, "unix:") { + socketPath := listenAddr[5:] + for _, ln := range inherited { + if ln == nil { + continue + } + if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath { + return ln + } + } + return nil + } + + for _, ln := range inherited { + if ln == nil { + continue + } + if ln.Addr().Network() != "tcp" { + continue + } + if s.tcpAddrMatch(ln.Addr().String(), listenAddr) { + return ln + } + } + return nil +} + +func (s *Server) tcpAddrMatch(inherited, target string) bool { + if inherited == target { + return true + } + host1, port1, err1 := net.SplitHostPort(inherited) + host2, port2, err2 := net.SplitHostPort(target) + if err1 != nil || err2 != nil { + return false + } + if port1 != port2 { + return false + } + return host1 == host2 || isAnyAddr(host1) || isAnyAddr(host2) +} + +func isAnyAddr(host string) bool { + return host == "" || host == "0.0.0.0" || host == "::" || host == "[::]" +} + +// DupListener 复制 listener 的文件描述符,返回独立的 listener。 +// +// 用于热重载场景:新旧 server 各自持有独立 FD,互不影响关闭操作。 +func DupListener(ln net.Listener) (net.Listener, error) { + switch l := ln.(type) { + case *net.TCPListener: + file, err := l.File() + if err != nil { + return nil, fmt.Errorf("dup tcp listener: %w", err) + } + return net.FileListener(file) + case *net.UnixListener: + file, err := l.File() + if err != nil { + return nil, fmt.Errorf("dup unix listener: %w", err) + } + return net.FileListener(file) + default: + return nil, fmt.Errorf("unsupported listener type: %T", ln) + } +} + // startSingleMode 单服务器模式启动。 // // 在单服务器模式下,创建单一路由器,注册代理路由和静态文件服务, diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 938cb80..2a4b76c 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -989,6 +989,136 @@ func TestCreateListener_UnixSocketCleanup(t *testing.T) { defer ln.Close() } +func TestDupListener_TCP(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + duped, err := DupListener(ln) + if err != nil { + t.Fatalf("DupListener() error: %v", err) + } + defer duped.Close() + + if duped.Addr().Network() != "tcp" { + t.Errorf("expected tcp, got %s", duped.Addr().Network()) + } + if duped.Addr().String() != ln.Addr().String() { + t.Errorf("expected same address %s, got %s", ln.Addr().String(), duped.Addr().String()) + } +} + +func TestDupListener_Unix(t *testing.T) { + dir := t.TempDir() + socketPath := dir + "/dup.sock" + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + duped, err := DupListener(ln) + if err != nil { + t.Fatalf("DupListener() error: %v", err) + } + defer duped.Close() +} + +func TestDupListener_Unsupported(t *testing.T) { + _, err := DupListener(struct{ net.Listener }{}) + if err == nil { + t.Error("expected error for unsupported type") + } +} + +func TestTcpAddrMatch(t *testing.T) { + s := &Server{} + + tests := []struct { + inherited string + target string + want bool + }{ + {"127.0.0.1:8080", "127.0.0.1:8080", true}, + {"0.0.0.0:8080", ":8080", true}, + {"[::]:8080", ":8080", true}, + {"0.0.0.0:8080", "0.0.0.0:8080", true}, + {"0.0.0.0:8080", "127.0.0.1:8080", true}, + {"127.0.0.1:8080", "0.0.0.0:8080", true}, + {"127.0.0.1:8080", ":9090", false}, + {"127.0.0.1:8080", "192.168.1.1:8080", false}, + } + + for _, tt := range tests { + got := s.tcpAddrMatch(tt.inherited, tt.target) + if got != tt.want { + t.Errorf("tcpAddrMatch(%q, %q) = %v, want %v", tt.inherited, tt.target, got, tt.want) + } + } +} + +func TestMatchInheritedListener_TCP(t *testing.T) { + s := &Server{} + + ln1, _ := net.Listen("tcp", "127.0.0.1:0") + defer ln1.Close() + + ln2, _ := net.Listen("tcp", "127.0.0.1:0") + defer ln2.Close() + + inherited := []net.Listener{ln1, ln2} + + result := s.matchInheritedListener(inherited, "0.0.0.0:99999") + if result != nil { + t.Error("expected nil for non-matching address") + } + + addr1 := ln1.Addr().String() + result = s.matchInheritedListener(inherited, addr1) + if result != ln1 { + t.Errorf("expected ln1 for address %s", addr1) + } +} + +func TestMatchInheritedListener_Empty(t *testing.T) { + s := &Server{} + result := s.matchInheritedListener(nil, ":8080") + if result != nil { + t.Error("expected nil for empty inherited list") + } +} + +func TestMatchInheritedListener_PresetListeners(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{Listen: "127.0.0.1:0"}}, + } + s := New(cfg) + + ln, err := s.createListener(&cfg.Servers[0]) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + s.SetListeners([]net.Listener{ln}) + + addr := ln.Addr().String() + cfg.Servers[0].Listen = addr + + matched, err := s.createListener(&cfg.Servers[0]) + if err != nil { + t.Fatalf("createListener with preset should reuse: %v", err) + } + if matched == nil { + t.Fatal("expected non-nil listener from preset match") + } + if matched.Addr().String() != addr { + t.Errorf("expected same address %s, got %s", addr, matched.Addr().String()) + } +} + // TestServer_StatsMethods 测试服务器统计方法。 func TestServer_StatsMethods(t *testing.T) { cfg := &config.Config{