diff --git a/internal/server/server.go b/internal/server/server.go index 5445e45..ab056bd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -24,7 +24,9 @@ import ( "crypto/tls" "fmt" "net" + "os" "strings" + "sync" "sync/atomic" "time" @@ -71,6 +73,7 @@ type Server struct { pool *GoroutinePool config *config.Config fastServer *fasthttp.Server + fastServers []*fasthttp.Server // 多监听器模式使用 proxies []*proxy.Proxy listeners []net.Listener healthCheckers []*proxy.HealthChecker @@ -424,10 +427,19 @@ func (s *Server) Start() error { return err } - if s.config.HasServers() { + // 根据模式选择启动方式 + mode := s.config.GetMode() + switch mode { + case config.ServerModeSingle: + return s.startSingleMode() + case config.ServerModeVHost: return s.startVHostMode() + case config.ServerModeMultiServer: + return s.startMultiServerMode() + default: + // 默认使用单服务器模式 + return s.startSingleMode() } - return s.startSingleMode() } // startSingleMode 单服务器模式启动。 @@ -442,6 +454,9 @@ func (s *Server) Start() error { // - 静态文件服务作为 fallback 处理非代理路径的请求 // - 使用零拷贝传输优化大文件传输 func (s *Server) startSingleMode() error { + // 使用 Servers[0] 配置(迁移后 Server 字段为空) + serverCfg := &s.config.Servers[0] + router := handler.NewRouter() // 注册状态监控端点(如果配置) @@ -466,13 +481,13 @@ func (s *Server) startSingleMode() error { } // 注册代理路由 - s.registerProxyRoutes(router, &s.config.Server) + s.registerProxyRoutes(router, serverCfg) // 静态文件服务 - s.registerStaticHandlers(router, &s.config.Server) + s.registerStaticHandlers(router, serverCfg) // 构建中间件链 - chain, err := s.buildMiddlewareChain(&s.config.Server) + chain, err := s.buildMiddlewareChain(serverCfg) if err != nil { return err } @@ -489,27 +504,27 @@ func (s *Server) startSingleMode() error { s.fastServer = &fasthttp.Server{ Name: "lolly", Handler: s.handler, - ReadTimeout: s.config.Server.ReadTimeout, - WriteTimeout: s.config.Server.WriteTimeout, - IdleTimeout: s.config.Server.IdleTimeout, - MaxConnsPerIP: s.config.Server.MaxConnsPerIP, - MaxRequestsPerConn: s.config.Server.MaxRequestsPerConn, + ReadTimeout: serverCfg.ReadTimeout, + WriteTimeout: serverCfg.WriteTimeout, + IdleTimeout: serverCfg.IdleTimeout, + MaxConnsPerIP: serverCfg.MaxConnsPerIP, + MaxRequestsPerConn: serverCfg.MaxRequestsPerConn, CloseOnShutdown: true, } s.running = true // 创建监听器并保存,用于热升级 - ln, err := net.Listen("tcp", s.config.Server.Listen) + ln, err := net.Listen("tcp", serverCfg.Listen) if err != nil { return fmt.Errorf("failed to listen: %w", err) } s.listeners = []net.Listener{ln} // 检查是否配置了 SSL/TLS - if s.config.Server.SSL.Cert != "" && s.config.Server.SSL.Key != "" { + if serverCfg.SSL.Cert != "" && serverCfg.SSL.Key != "" { var err error - s.tlsManager, err = ssl.NewTLSManager(&s.config.Server.SSL) + s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL) if err != nil { return fmt.Errorf("创建 TLS 管理器失败: %w", err) } @@ -601,30 +616,33 @@ func (s *Server) startVHostMode() error { // 包装统计追踪 s.handler = s.trackStats(s.handler) + // 使用 Servers[0] 配置(迁移后 Server 字段为空) + serverCfg := &s.config.Servers[0] + s.fastServer = &fasthttp.Server{ Name: "lolly", Handler: s.handler, - ReadTimeout: s.config.Server.ReadTimeout, - WriteTimeout: s.config.Server.WriteTimeout, - IdleTimeout: s.config.Server.IdleTimeout, - MaxConnsPerIP: s.config.Server.MaxConnsPerIP, - MaxRequestsPerConn: s.config.Server.MaxRequestsPerConn, + ReadTimeout: serverCfg.ReadTimeout, + WriteTimeout: serverCfg.WriteTimeout, + IdleTimeout: serverCfg.IdleTimeout, + MaxConnsPerIP: serverCfg.MaxConnsPerIP, + MaxRequestsPerConn: serverCfg.MaxRequestsPerConn, CloseOnShutdown: true, } s.running = true // 创建监听器并保存,用于热升级 - ln, err := net.Listen("tcp", s.config.Server.Listen) + ln, err := net.Listen("tcp", serverCfg.Listen) if err != nil { return fmt.Errorf("failed to listen: %w", err) } s.listeners = []net.Listener{ln} // 检查是否配置了 SSL/TLS - if s.config.Server.SSL.Cert != "" && s.config.Server.SSL.Key != "" { + if serverCfg.SSL.Cert != "" && serverCfg.SSL.Key != "" { var err error - s.tlsManager, err = ssl.NewTLSManager(&s.config.Server.SSL) + s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL) if err != nil { return fmt.Errorf("创建 TLS 管理器失败: %w", err) } @@ -635,6 +653,145 @@ func (s *Server) startVHostMode() error { return s.fastServer.Serve(ln) } +// startMultiServerMode 多服务器模式启动。 +// +// 为每个配置的服务器创建独立的 fasthttp.Server 实例, +// 每个实例监听各自的地址并运行在独立的 goroutine 中。 +// +// 返回值: +// - error: 启动过程中遇到的第一个错误(或全部成功时返回 nil) +// +// 注意事项: +// - 每个服务器有独立的中间件配置 +// - 热升级场景下回退到虚拟主机模式 +// - 使用 goroutine 并行启动多个服务器 +func (s *Server) startMultiServerMode() error { + // 热升级检测:multi_server 热升级未实现,回退到 vhost 模式 + if os.Getenv("GRACEFUL_UPGRADE") == "1" { + logging.Warn().Msg("热升级模式下 multi_server 模式未实现,回退到虚拟主机模式") + return s.startVHostMode() + } + + s.fastServers = make([]*fasthttp.Server, len(s.config.Servers)) + s.listeners = make([]net.Listener, len(s.config.Servers)) + + var wg sync.WaitGroup + errCh := make(chan error, len(s.config.Servers)) + + // 并行创建监听器和 fasthttp.Server + for i := range s.config.Servers { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + serverCfg := &s.config.Servers[idx] + + // 创建监听器 + ln, err := net.Listen("tcp", serverCfg.Listen) + if err != nil { + errCh <- fmt.Errorf("监听地址 %s 失败: %w", serverCfg.Listen, err) + return + } + s.listeners[idx] = ln + + // 创建路由器 + router := handler.NewRouter() + s.registerProxyRoutes(router, serverCfg) + + // 静态文件服务 + s.registerStaticHandlers(router, serverCfg) + + // 构建独立的中间件链 + chain, err := s.buildMiddlewareChain(serverCfg) + if err != nil { + errCh <- fmt.Errorf("构建中间件链失败 (server[%d]): %w", idx, err) + return + } + + // 应用中间件 + h := chain.Apply(router.Handler()) + if s.pool != nil { + h = s.pool.WrapHandler(h) + } + h = s.trackStats(h) + + // 创建 fasthttp.Server + fastSrv := &fasthttp.Server{ + Name: "lolly", + Handler: h, + ReadTimeout: serverCfg.ReadTimeout, + WriteTimeout: serverCfg.WriteTimeout, + IdleTimeout: serverCfg.IdleTimeout, + MaxConnsPerIP: serverCfg.MaxConnsPerIP, + MaxRequestsPerConn: serverCfg.MaxRequestsPerConn, + CloseOnShutdown: true, + } + + // 检查 SSL 配置 + if serverCfg.SSL.Cert != "" && serverCfg.SSL.Key != "" { + tlsManager, err := ssl.NewTLSManager(&serverCfg.SSL) + if err != nil { + errCh <- fmt.Errorf("创建 TLS 管理器失败 (server[%d]): %w", idx, err) + return + } + fastSrv.TLSConfig = tlsManager.GetTLSConfig() + } + + s.fastServers[idx] = fastSrv + }(i) + } + + // 等待所有 goroutine 完成 + wg.Wait() + close(errCh) + + // 检查是否有错误 + var firstErr error + for err := range errCh { + if firstErr == nil { + firstErr = err + } + } + + // 如果有错误,清理已创建的监听器 + if firstErr != nil { + for _, ln := range s.listeners { + if ln != nil { + _ = ln.Close() + } + } + return firstErr + } + + s.running = true + + // 启动所有服务器 + for idx, fastSrv := range s.fastServers { + ln := s.listeners[idx] + if fastSrv == nil || ln == nil { + continue + } + + wg.Add(1) + go func(f *fasthttp.Server, l net.Listener, i int) { + defer wg.Done() + var serveErr error + if f.TLSConfig != nil { + serveErr = f.ServeTLS(l, "", "") + } else { + serveErr = f.Serve(l) + } + if serveErr != nil { + logging.Error().Err(serveErr).Msgf("服务器 [%d] 监听 %s 时发生错误", i, l.Addr()) + } + }(fastSrv, ln, idx) + } + + // 等待服务器停止(阻塞) + wg.Wait() + return nil +} + // registerProxyRoutes 注册代理路由。 // // 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。 @@ -692,6 +849,68 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S } } +// shutdownServers 并行关闭多个 fasthttp.Server 实例。 +// +// 使用 goroutine 并行关闭所有服务器,收集所有错误并返回聚合错误。 +// 部分服务器关闭失败不会影响其他服务器的关闭。 +// +// 参数: +// - ctx: 关闭上下文,用于控制超时和取消 +// - servers: 要关闭的 fasthttp.Server 实例列表 +// +// 返回值: +// - error: 聚合错误,无错误或全部成功时返回 nil +func shutdownServers(ctx context.Context, servers []*fasthttp.Server) error { + if len(servers) == 0 { + return nil + } + + var ( + mu sync.Mutex + errs []error + wg sync.WaitGroup + ) + + for _, srv := range servers { + if srv == nil { + continue + } + wg.Add(1) + go func(s *fasthttp.Server) { + defer wg.Done() + if err := s.Shutdown(); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } + }(srv) + } + + // 等待所有关闭完成或上下文取消 + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + if len(errs) == 0 { + return nil + } + if len(errs) == 1 { + return errs[0] + } + msgs := make([]string, len(errs)) + for i, e := range errs { + msgs[i] = e.Error() + } + return fmt.Errorf("关闭服务器时发生 %d 个错误: %s", len(errs), strings.Join(msgs, "; ")) + case <-ctx.Done(): + return ctx.Err() + } +} + // StopWithTimeout 快速停止服务器(支持自定义超时)。 // // 立即停止服务器,不等待正在处理的请求完成。 @@ -747,11 +966,16 @@ func (s *Server) StopWithTimeout(timeout time.Duration) error { logging.Info().Msg("Lua 引擎已关闭") } - if s.fastServer != nil { - // 使用传入的 timeout - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + // 多服务器模式:并行关闭所有 fasthttp.Server + if len(s.fastServers) > 0 { + return shutdownServers(ctx, s.fastServers) + } + + // 单服务器模式:关闭单个 fasthttp.Server + if s.fastServer != nil { done := make(chan struct{}) go func() { _ = s.fastServer.Shutdown() @@ -762,7 +986,6 @@ func (s *Server) StopWithTimeout(timeout time.Duration) error { case <-done: return nil case <-ctx.Done(): - // timeout,直接返回 return ctx.Err() } } @@ -836,10 +1059,16 @@ func (s *Server) GracefulStop(timeout time.Duration) error { logging.Info().Msg("Lua 引擎已关闭") } - if s.fastServer != nil { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + // 多服务器模式:并行关闭所有 fasthttp.Server + if len(s.fastServers) > 0 { + return shutdownServers(ctx, s.fastServers) + } + + // 单服务器模式:关闭单个 fasthttp.Server + if s.fastServer != nil { done := make(chan struct{}) go func() { _ = s.fastServer.Shutdown()