diff --git a/internal/app/app.go b/internal/app/app.go index 8bc2e60..ddbd43c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -296,6 +296,7 @@ func (a *App) Run() int { // 创建升级管理器 a.upgradeMgr = server.NewUpgradeManager(a.srv) + a.srv.SetUpgradeManager(a.upgradeMgr) if a.pidFile != "" { a.upgradeMgr.SetPidFile(a.pidFile) _ = a.upgradeMgr.WritePid() diff --git a/internal/server/server.go b/internal/server/server.go index 0a3eeef..4773a09 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -37,6 +37,7 @@ import ( "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/logging" "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/matcher" "rua.plus/lolly/internal/middleware" "rua.plus/lolly/internal/middleware/accesslog" "rua.plus/lolly/internal/middleware/bodylimit" @@ -71,6 +72,7 @@ type Server struct { errorPageManager *handler.ErrorPageManager fileCache *cache.FileCache pool *GoroutinePool + upgradeManager *UpgradeManager config *config.Config fastServer *fasthttp.Server fastServers []*fasthttp.Server // 多监听器模式使用 @@ -81,6 +83,7 @@ type Server struct { requests atomic.Int64 bytesSent atomic.Int64 bytesReceived atomic.Int64 + locationEngine *matcher.LocationEngine running bool } @@ -136,6 +139,17 @@ func (s *Server) SetListeners(listeners []net.Listener) { s.listeners = listeners } +// SetUpgradeManager 设置升级管理器。 +// +// 用于从外部(App 层)注入升级管理器,使服务器能够在 +// createListener 中检查热升级状态和继承的监听器。 +// +// 参数: +// - mgr: 升级管理器实例 +func (s *Server) SetUpgradeManager(mgr *UpgradeManager) { + s.upgradeManager = mgr +} + // GetTLSConfig 获取 TLS 配置。 // // 返回服务器的 TLS 配置,用于 HTTP/3 等需要 TLS 的协议。 @@ -447,6 +461,73 @@ func (s *Server) Start() error { } } +// createListener 根据配置创建监听器。 +// +// 支持两种监听器格式: +// - "unix:/path/to/socket" -> Unix domain socket +// - ":8080" / "127.0.0.1:8080" -> TCP +// +// Unix socket 模式下会自动处理: +// - 热升级时继承的监听器复用 +// - 旧 socket 文件清理 +// - socket 文件权限设置 +// +// 参数: +// - cfg: 服务器配置 +// +// 返回值: +// - net.Listener: 创建的监听器 +// - error: 创建失败时返回错误 +func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) { + listenAddr := cfg.Listen + + if strings.HasPrefix(listenAddr, "unix:") { + // Unix Socket 模式 + socketPath := listenAddr[5:] + + // 1. 检查继承的监听器(热升级场景) + if s.upgradeManager != nil && s.upgradeManager.IsChild() { + inherited, _ := s.upgradeManager.GetInheritedListeners() + for _, ln := range inherited { + if ln.Addr().Network() == "unix" && ln.Addr().String() == socketPath { + return ln, nil + } + } + } + + // 2. 清理旧 socket 文件 + if _, err := os.Stat(socketPath); err == nil { + os.Remove(socketPath) + } + + // 3. 创建 Unix socket listener + listener, err := net.Listen("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("create unix socket failed: %w", err) + } + + // 4. 设置 socket 文件权限 + mode := 0o666 + if cfg.UnixSocket.Mode > 0 { + mode = cfg.UnixSocket.Mode + } + if err := os.Chmod(socketPath, os.FileMode(mode)); err != nil { + logging.Warn().Err(err).Msg("设置 socket 文件权限失败") + } + + // 5. 设置文件所有权(需要 root 权限) + if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" { + // 简化处理:仅记录警告,实际实现需要 syscall.Chown + logging.Warn().Msg("Unix socket 用户/组配置需要 root 权限,已跳过") + } + + return listener, nil + } + + // TCP 模式 + return net.Listen("tcp", listenAddr) +} + // startSingleMode 单服务器模式启动。 // // 在单服务器模式下,创建单一路由器,注册代理路由和静态文件服务, @@ -462,7 +543,8 @@ func (s *Server) startSingleMode() error { // 使用 Servers[0] 配置(迁移后 Server 字段为空) serverCfg := &s.config.Servers[0] - router := handler.NewRouter() + // 创建 LocationEngine + s.locationEngine = matcher.NewLocationEngine() // 注册状态监控端点(如果配置) if s.config.Monitoring.Status.Path != "" || len(s.config.Monitoring.Status.Allow) > 0 { @@ -470,7 +552,7 @@ func (s *Server) startSingleMode() error { if err != nil { logging.Error().Msg("创建状态处理器失败: " + err.Error()) } else { - router.GET(statusHandler.Path(), statusHandler.ServeHTTP) + _ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP) } } @@ -480,8 +562,8 @@ func (s *Server) startSingleMode() error { if err != nil { logging.Error().Msg("创建 pprof 处理器失败: " + err.Error()) } else { - router.GET(pprofHandler.Path(), pprofHandler.ServeHTTP) - router.GET(pprofHandler.Path()+"/{profile:*}", pprofHandler.ServeHTTP) + _ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP) + _ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP) } } @@ -491,15 +573,18 @@ func (s *Server) startSingleMode() error { if err != nil { logging.Error().Msg("创建缓存清理处理器失败: " + err.Error()) } else { - router.POST(purgeHandler.Path(), purgeHandler.ServeHTTP) + _ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP) } } // 注册代理路由 - s.registerProxyRoutes(router, serverCfg) + s.registerProxyRoutesWithLocationEngine(serverCfg) // 静态文件服务 - s.registerStaticHandlers(router, serverCfg) + s.registerStaticHandlersWithLocationEngine(serverCfg) + + // 标记 LocationEngine 初始化完成 + s.locationEngine.MarkInitialized() // 构建中间件链 chain, err := s.buildMiddlewareChain(serverCfg) @@ -507,8 +592,22 @@ func (s *Server) startSingleMode() error { return err } - // 应用 GoroutinePool(如果启用) - handler := chain.Apply(router.Handler()) + // 创建主请求处理器,使用 LocationEngine 匹配路由 + locationEngine := s.locationEngine + baseHandler := func(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + result := locationEngine.Match(path) + if result != nil && result.Handler != nil { + result.Handler(ctx) + return + } + // 无匹配,返回 404 + ctx.SetStatusCode(404) + ctx.SetBodyString("Not Found") + } + + // 应用中间件 + handler := chain.Apply(baseHandler) if s.pool != nil { handler = s.pool.WrapHandler(handler) } @@ -535,7 +634,7 @@ func (s *Server) startSingleMode() error { s.running = true // 创建监听器并保存,用于热升级 - ln, err := net.Listen("tcp", serverCfg.Listen) + ln, err := s.createListener(serverCfg) if err != nil { return fmt.Errorf("failed to listen: %w", err) } @@ -587,7 +686,9 @@ func (s *Server) startVHostMode() error { handler = s.pool.WrapHandler(handler) } - vhostMgr.AddHost(s.config.Servers[i].Name, handler) + if err := vhostMgr.AddHost(s.config.Servers[i].Name, handler); err != nil { + return err + } } // 默认主机 @@ -839,6 +940,115 @@ func (s *Server) startMultiServerMode() error { return nil } +// registerProxyRoutesWithLocationEngine 使用 LocationEngine 注册代理路由。 +// +// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。 +// 支持通过 LocationType 配置不同的匹配方式。 +func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) { + for i := range serverCfg.Proxy { + proxyCfg := &serverCfg.Proxy[i] + + // 转换目标 + targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) + for j, t := range proxyCfg.Targets { + targets[j] = loadbalance.NewTargetFromConfig(t.URL, t.Weight) + } + + // 传递 Transport 配置和 Lua 引擎 + p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine) + if err != nil { + logging.Error().Msg("创建代理失败: " + err.Error()) + continue + } + + // 设置 DNS 解析器(如果已配置) + if s.resolver != nil { + p.SetResolver(s.resolver) + if err := p.Start(); err != nil { + logging.Error().Err(err).Msg("启动代理失败") + } + } + + // 启动健康检查 + if proxyCfg.HealthCheck.Interval > 0 { + hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck) + hc.Start() + s.healthCheckers = append(s.healthCheckers, hc) + // 设置被动健康检查 + p.SetHealthChecker(hc) + } + + // 保存代理实例用于缓存统计 + s.proxies = append(s.proxies, p) + + // 根据 LocationType 注册路由 + locType := proxyCfg.LocationType + if locType == "" { + locType = "prefix" + } + + switch locType { + case "exact": + _ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP) + case "prefix_priority": + _ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP) + case "regex", "regex_caseless": + caseInsensitive := locType == "regex_caseless" + _ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive) + case "named": + if proxyCfg.LocationName != "" { + _ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP) + } + case "prefix": + _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP) + default: + _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP) + } + } +} + +// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。 +func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) { + for _, static := range cfg.Static { + path := static.Path + if path == "" { + path = "/" + } + + staticHandler := handler.NewStaticHandler( + static.Root, + path, + static.Index, + true, // useSendfile + ) + if s.fileCache != nil { + staticHandler.SetFileCache(s.fileCache) + // 设置默认缓存 TTL (5s) + staticHandler.SetCacheTTL(5 * time.Second) + } + if cfg.Compression.GzipStatic { + staticHandler.SetGzipStatic(true, cfg.Compression.GzipStaticExtensions) + } + + // 根据 LocationType 注册路由 + locType := static.LocationType + if locType == "" { + locType = "prefix" + } + + switch locType { + case "exact": + _ = s.locationEngine.AddExact(path, staticHandler.Handle) + case "prefix_priority": + _ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle) + case "prefix": + _ = s.locationEngine.AddPrefix(path, staticHandler.Handle) + default: + _ = s.locationEngine.AddPrefix(path, staticHandler.Handle) + } + } +} + // registerProxyRoutes 注册代理路由。 // // 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。