diff --git a/internal/app/app.go b/internal/app/app.go index fc8aa6a..2cff744 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -4,196 +4,33 @@ package app import ( "fmt" - "net" "os" "os/signal" "syscall" "time" "rua.plus/lolly/internal/config" - "rua.plus/lolly/internal/http2" - "rua.plus/lolly/internal/http3" "rua.plus/lolly/internal/logging" - "rua.plus/lolly/internal/resolver" "rua.plus/lolly/internal/server" - "rua.plus/lolly/internal/stream" - "rua.plus/lolly/internal/variable" ) -// App manages the server lifecycle, including HTTP, HTTP/3, Stream servers and graceful upgrades. -type App struct { - resv resolver.Resolver - cfg *config.Config - srv *server.Server - http3Srv *http3.Server - http2Srv *http2.Server - streamSrv *stream.Server - upgradeMgr *server.UpgradeManager - logger *logging.AppLogger - cfgPath string - pidFile string - logFile string - listeners []net.Listener -} - -// NewApp creates a new App instance with the given config path. -func NewApp(cfgPath string) *App { - return &App{ - cfgPath: cfgPath, - } -} - -// SetPidFile sets the path to the PID file for the app. -func (a *App) SetPidFile(path string) { - a.pidFile = path -} - -// SetLogFile sets the path to the log file for the app. -func (a *App) SetLogFile(path string) { - a.logFile = path -} - // Run starts the application: loads config, creates servers, and handles signals. func (a *App) Run() int { - cfg, err := config.Load(a.cfgPath) - if err != nil { - fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err) + if err := a.loadAndValidateConfig(); err != nil { return 1 } - a.cfg = cfg - a.logger = logging.NewAppLogger(&cfg.Logging) - variable.SetGlobalVariables(cfg.Variables.Set) - if len(cfg.Variables.Set) > 0 { - a.logger.LogStartup("全局变量已加载", map[string]string{ - "count": fmt.Sprintf("%d", len(cfg.Variables.Set)), - }) - } + a.initVariables() // Inherit parent listeners when running as a graceful upgrade child. - if os.Getenv("GRACEFUL_UPGRADE") == "1" { - a.logger.LogStartup("检测到热升级模式,继承父进程监听器", nil) - a.upgradeMgr = server.NewUpgradeManager(nil) - listeners, err := a.upgradeMgr.GetInheritedListeners() - if err == nil && len(listeners) > 0 { - a.listeners = listeners - } - } + a.inheritListeners() - a.logger.LogStartup("配置加载成功", map[string]string{"config_path": a.cfgPath}) - - mode := a.cfg.GetMode() - if mode == config.ServerModeMultiServer { - for i, srv := range a.cfg.Servers { - a.logger.LogStartup("监听地址", map[string]string{ - "index": fmt.Sprintf("[%d]", i), - "listen": srv.Listen, - "name": srv.Name, - }) - } - } else { - a.logger.LogStartup("监听地址", map[string]string{"listen": a.cfg.Servers[0].Listen}) - } - - if a.cfg.Resolver.Enabled { - a.resv = resolver.New(&a.cfg.Resolver) - a.logger.LogStartup("DNS 解析器已启用", map[string]string{ - "addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses), - "ttl": a.cfg.Resolver.TTL().String(), - }) - } - - a.srv = server.New(a.cfg) - - if a.resv != nil { - a.srv.SetResolver(a.resv) - } - - if len(a.listeners) > 0 { - a.srv.SetListeners(a.listeners) - } - - if len(a.cfg.Stream) > 0 { - a.streamSrv = stream.NewServer() - for _, sc := range a.cfg.Stream { - targets := make([]stream.TargetSpec, len(sc.Upstream.Targets)) - for i, t := range sc.Upstream.Targets { - targets[i] = stream.TargetSpec{ - Addr: t.Addr, - Weight: t.Weight, - } - } - - if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil { - a.logger.Error().Err(err).Msg("添加 Stream 上游失败") - } - - if sc.Protocol == "udp" { - if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil { - a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 UDP 失败") - } - } else { - if err := a.streamSrv.ListenTCP(sc.Listen); err != nil { - a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 TCP 失败") - } - } - } - - go func() { - a.logger.LogStartup("Stream 服务器启动中", nil) - if err := a.streamSrv.Start(); err != nil { - a.logger.Error().Err(err).Msg("Stream 服务器启动失败") - } - }() - } - - if a.cfg.HTTP3.Enabled && a.cfg.Servers[0].SSL.Cert != "" { - tlsConfig, err := a.srv.GetTLSConfig() - if err != nil { - a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/3") - } else { - a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig) - if err != nil { - a.logger.Error().Err(err).Msg("创建 HTTP/3 服务器失败") - } else { - go func() { - a.logger.LogStartup("HTTP/3 服务器启动中", map[string]string{"listen": a.cfg.HTTP3.Listen}) - if err := a.http3Srv.Start(); err != nil { - a.logger.Error().Err(err).Msg("HTTP/3 服务器启动失败") - } - }() - } - } - } - - if a.cfg.Servers[0].SSL.HTTP2.Enabled && a.cfg.Servers[0].SSL.Cert != "" { - tlsConfig, err := a.srv.GetTLSConfig() - if err != nil { - a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2") - } else { - a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig) - if err != nil { - a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败") - } else { - go func() { - a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{ - "listen": a.cfg.Servers[0].Listen, - "max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams), - "push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled), - }) - // HTTP/2 shares the main server's listener; ALPN negotiates protocol selection. - listeners := a.srv.GetListeners() - if len(listeners) > 0 { - if err := a.http2Srv.Serve(listeners[0]); err != nil { - a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败") - } - } else { - a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器") - } - }() - } - } - } + a.logServerAddresses() + a.initResolver() + a.initServer() + a.initStreamServers() + a.initHTTP3() + a.initHTTP2() a.upgradeMgr = server.NewUpgradeManager(a.srv) a.srv.SetUpgradeManager(a.upgradeMgr) @@ -207,7 +44,7 @@ func (a *App) Run() int { errChan := make(chan error, 1) go func() { - a.logger.LogStartup("HTTP 服务器启动中", nil) + a.logger.LogStartup("Starting HTTP server", nil) if err := a.srv.Start(); err != nil { errChan <- err } @@ -218,24 +55,36 @@ func (a *App) Run() int { for { select { case err := <-errChan: - a.logger.Error().Err(err).Msg("服务器启动失败") + a.logger.Error().Err(err).Msg("Server failed to start") return 1 case sig := <-sigChan: if sig == syscall.SIGINT { sigintCount++ if sigintCount >= 3 { - a.logger.LogShutdown("收到 3 次 SIGINT,强制退出") + a.logger.LogShutdown("Received 3 SIGINT, forcing exit") return 1 } } if !a.handleSignal(sig) { - a.logger.LogShutdown("服务器已停止") + a.logger.LogShutdown("Server stopped") return 0 } } } } +// inheritListeners inherits parent listeners during graceful upgrade. +func (a *App) inheritListeners() { + if os.Getenv("GRACEFUL_UPGRADE") == "1" { + a.logger.LogStartup("Graceful upgrade mode detected, inheriting parent listeners", nil) + a.upgradeMgr = server.NewUpgradeManager(nil) + listeners, err := a.upgradeMgr.GetInheritedListeners() + if err == nil && len(listeners) > 0 { + a.listeners = listeners + } + } +} + func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) { signal.Notify(sigChan, syscall.SIGTERM, @@ -250,7 +99,7 @@ func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) { // handleSignal returns false to indicate the app should exit. func (a *App) handleSignal(sig os.Signal) bool { if a.cfg == nil { - a.logger.Error().Msg("信号处理失败: 配置为 nil,使用默认超时") + a.logger.Error().Msg("Signal handling failed: config is nil, using default timeout") a.cfg = &config.Config{ Shutdown: config.ShutdownConfig{ GracefulTimeout: 30 * time.Second, @@ -265,7 +114,7 @@ func (a *App) handleSignal(sig os.Signal) bool { if timeout <= 0 { timeout = 30 * time.Second } - a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v)", timeout)) + a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout)) a.shutdownHTTP2() a.shutdownHTTP3() _ = a.srv.GracefulStop(timeout) @@ -278,9 +127,9 @@ func (a *App) handleSignal(sig os.Signal) bool { } sigTyped, ok := sig.(syscall.Signal) if !ok { - a.logger.LogSignal("unknown", "停止服务器") + a.logger.LogSignal("unknown", "Stopping server") } else { - a.logger.LogSignal(sigName(sigTyped), "停止服务器") + a.logger.LogSignal(sigName(sigTyped), "Stopping server") } a.shutdownHTTP2() a.shutdownHTTP3() @@ -288,89 +137,65 @@ func (a *App) handleSignal(sig os.Signal) bool { return false case syscall.SIGHUP: - a.logger.LogSignal("SIGHUP", "重载配置") + a.logger.LogSignal("SIGHUP", "Reloading config") a.reloadConfig() return true case syscall.SIGUSR1: - a.logger.LogSignal("SIGUSR1", "重新打开日志") + a.logger.LogSignal("SIGUSR1", "Reopening logs") a.reopenLogs() return true case syscall.SIGUSR2: - a.logger.LogSignal("SIGUSR2", "执行热升级") + a.logger.LogSignal("SIGUSR2", "Performing graceful upgrade") a.gracefulUpgrade() return true default: - a.logger.Info().Str("signal", sig.String()).Msg("收到未知信号") + a.logger.Info().Str("signal", sig.String()).Msg("Received unknown signal") return true } } -func (a *App) shutdownHTTP3() { - if a.http3Srv != nil { - if err := a.http3Srv.Stop(); err != nil { - a.logger.Error().Err(err).Msg("HTTP/3 服务器关闭失败") - } - } -} - -func (a *App) shutdownHTTP2() { - if a.http2Srv != nil { - if err := a.http2Srv.Stop(); err != nil { - a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败") - } - } -} - func (a *App) reloadConfig() { newCfg, err := config.Load(a.cfgPath) if err != nil { - a.logger.Error().Err(err).Msg("重载配置失败") + a.logger.Error().Err(err).Msg("Failed to reload config") return } a.cfg = newCfg a.logger = logging.NewAppLogger(&newCfg.Logging) - a.logger.LogStartup("配置重载成功", nil) -} - -func (a *App) reopenLogs() { - if a.cfg != nil { - logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format) - a.logger = logging.NewAppLogger(&a.cfg.Logging) - } - a.logger.LogStartup("日志已重新打开", nil) + a.logger.LogStartup("Config reloaded successfully", nil) } func (a *App) gracefulUpgrade() { execPath, err := os.Executable() if err != nil { - a.logger.Error().Err(err).Msg("获取可执行文件路径失败") + a.logger.Error().Err(err).Msg("Failed to get executable path") return } if a.srv == nil { - a.logger.Error().Msg("热升级失败: 服务器实例为 nil") + a.logger.Error().Msg("Graceful upgrade failed: server instance is nil") return } listeners := a.srv.GetListeners() if len(listeners) == 0 { - a.logger.Error().Msg("热升级失败: 服务器未保存监听器(热升级当前未完全实现)") - a.logger.Info().Msg("提示: 热升级需要服务器使用手动监听器管理模式") + a.logger.Error().Msg("Graceful upgrade failed: server has no saved listeners (graceful upgrade not fully implemented)") + a.logger.Info().Msg("Hint: graceful upgrade requires the server to use manual listener management mode") return } a.upgradeMgr.SetListeners(listeners) if err := a.upgradeMgr.GracefulUpgrade(execPath); err != nil { - a.logger.Error().Err(err).Msg("热升级失败") + a.logger.Error().Err(err).Msg("Graceful upgrade failed") return } - a.logger.LogStartup("热升级已启动,新进程正在接管", nil) + a.logger.LogStartup("Graceful upgrade started, new process is taking over", nil) timeout := a.cfg.Shutdown.GracefulTimeout if timeout <= 0 { diff --git a/internal/app/app_common.go b/internal/app/app_common.go new file mode 100644 index 0000000..82fe474 --- /dev/null +++ b/internal/app/app_common.go @@ -0,0 +1,242 @@ +package app + +import ( + "fmt" + "net" + "os" + "time" + + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/http2" + "rua.plus/lolly/internal/http3" + "rua.plus/lolly/internal/logging" + "rua.plus/lolly/internal/resolver" + "rua.plus/lolly/internal/server" + "rua.plus/lolly/internal/stream" + "rua.plus/lolly/internal/variable" +) + +// App manages the server lifecycle, including HTTP, HTTP/3, Stream servers and graceful upgrades. +type App struct { + resv resolver.Resolver + cfg *config.Config + srv *server.Server + http3Srv *http3.Server + http2Srv *http2.Server + streamSrv *stream.Server + upgradeMgr *server.UpgradeManager + logger *logging.AppLogger + cfgPath string + pidFile string + logFile string + listeners []net.Listener +} + +// NewApp creates a new App instance with the given config path. +func NewApp(cfgPath string) *App { + return &App{ + cfgPath: cfgPath, + } +} + +// SetPidFile sets the path to the PID file for the app. +func (a *App) SetPidFile(path string) { + a.pidFile = path +} + +// SetLogFile sets the path to the log file for the app. +func (a *App) SetLogFile(path string) { + a.logFile = path +} + +// loadAndValidateConfig loads configuration and initializes the logger. +func (a *App) loadAndValidateConfig() error { + cfg, err := config.Load(a.cfgPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to load config: %v\n", err) + return err + } + a.cfg = cfg + a.logger = logging.NewAppLogger(&cfg.Logging) + return nil +} + +// initVariables loads global variables from configuration. +func (a *App) initVariables() { + variable.SetGlobalVariables(a.cfg.Variables.Set) + if len(a.cfg.Variables.Set) > 0 { + a.logger.LogStartup("Global variables loaded", map[string]string{ + "count": fmt.Sprintf("%d", len(a.cfg.Variables.Set)), + }) + } +} + +// logServerAddresses logs the listening addresses based on server mode. +func (a *App) logServerAddresses() { + a.logger.LogStartup("Config loaded successfully", map[string]string{"config_path": a.cfgPath}) + + mode := a.cfg.GetMode() + if mode == config.ServerModeMultiServer { + for i, srv := range a.cfg.Servers { + a.logger.LogStartup("Listening address", map[string]string{ + "index": fmt.Sprintf("[%d]", i), + "listen": srv.Listen, + "name": srv.Name, + }) + } + } else { + a.logger.LogStartup("Listening address", map[string]string{"listen": a.cfg.Servers[0].Listen}) + } +} + +// initResolver initializes the DNS resolver if enabled. +func (a *App) initResolver() { + if a.cfg.Resolver.Enabled { + a.resv = resolver.New(&a.cfg.Resolver) + a.logger.LogStartup("DNS resolver enabled", map[string]string{ + "addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses), + "ttl": a.cfg.Resolver.TTL().String(), + }) + } +} + +// initServer creates the main server and sets the resolver. +func (a *App) initServer() { + a.srv = server.New(a.cfg) + + if a.resv != nil { + a.srv.SetResolver(a.resv) + } + + if len(a.listeners) > 0 { + a.srv.SetListeners(a.listeners) + } +} + +// initStreamServers configures and starts stream servers. +func (a *App) initStreamServers() { + if len(a.cfg.Stream) == 0 { + return + } + + a.streamSrv = stream.NewServer() + for _, sc := range a.cfg.Stream { + targets := make([]stream.TargetSpec, len(sc.Upstream.Targets)) + for i, t := range sc.Upstream.Targets { + targets[i] = stream.TargetSpec{ + Addr: t.Addr, + Weight: t.Weight, + } + } + + if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil { + a.logger.Error().Err(err).Msg("Failed to add Stream upstream") + } + + if sc.Protocol == "udp" { + if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil { + a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("Failed to listen on UDP") + } + } else { + if err := a.streamSrv.ListenTCP(sc.Listen); err != nil { + a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("Failed to listen on TCP") + } + } + } + + go func() { + a.logger.LogStartup("Starting Stream server", nil) + if err := a.streamSrv.Start(); err != nil { + a.logger.Error().Err(err).Msg("Stream server failed to start") + } + }() +} + +// initHTTP3 starts the HTTP/3 server if enabled. +func (a *App) initHTTP3() { + if !a.cfg.HTTP3.Enabled || a.cfg.Servers[0].SSL.Cert == "" { + return + } + + tlsConfig, err := a.srv.GetTLSConfig() + if err != nil { + a.logger.Error().Err(err).Msg("Failed to get TLS config, skipping HTTP/3") + return + } + + a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig) + if err != nil { + a.logger.Error().Err(err).Msg("Failed to create HTTP/3 server") + return + } + + go func() { + a.logger.LogStartup("Starting HTTP/3 server", map[string]string{"listen": a.cfg.HTTP3.Listen}) + if err := a.http3Srv.Start(); err != nil { + a.logger.Error().Err(err).Msg("HTTP/3 server failed to start") + } + }() +} + +// initHTTP2 starts the HTTP/2 server if enabled. +func (a *App) initHTTP2() { + if !a.cfg.Servers[0].SSL.HTTP2.Enabled || a.cfg.Servers[0].SSL.Cert == "" { + return + } + + tlsConfig, err := a.srv.GetTLSConfig() + if err != nil { + a.logger.Error().Err(err).Msg("Failed to get TLS config, skipping HTTP/2") + return + } + + a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig) + if err != nil { + a.logger.Error().Err(err).Msg("Failed to create HTTP/2 server") + return + } + + go func() { + a.logger.LogStartup("Starting HTTP/2 server", map[string]string{ + "listen": a.cfg.Servers[0].Listen, + "max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams), + "push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled), + }) + // HTTP/2 shares the main server's listener; ALPN negotiates protocol selection. + listeners := a.srv.GetListeners() + if len(listeners) > 0 { + if err := a.http2Srv.Serve(listeners[0]); err != nil { + a.logger.Error().Err(err).Msg("HTTP/2 server failed to start") + } + } else { + a.logger.Error().Msg("HTTP/2 server failed to start: no available listeners") + } + }() +} + +// shutdownHTTP3 gracefully stops the HTTP/3 server. +func (a *App) shutdownHTTP3() { + if a.http3Srv != nil { + if err := a.http3Srv.Stop(); err != nil { + a.logger.Error().Err(err).Msg("Failed to shutdown HTTP/3 server") + } + } +} + +// shutdownHTTP2 gracefully stops the HTTP/2 server. +func (a *App) shutdownHTTP2() { + if a.http2Srv != nil { + if err := a.http2Srv.Stop(); err != nil { + a.logger.Error().Err(err).Msg("Failed to shutdown HTTP/2 server") + } + } +} + +// reopenLogs reinitializes the logger from current config. +func (a *App) reopenLogs() { + if a.cfg != nil { + logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format) + a.logger = logging.NewAppLogger(&a.cfg.Logging) + } + a.logger.LogStartup("Logs reopened", nil) +} diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 80cc6f1..d9847cb 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -225,7 +225,7 @@ func TestRun(t *testing.T) { genConfig: true, outputPath: filepath.Join(t.TempDir(), "config.yaml"), wantExitCode: 0, - wantContains: "配置已写入:", + wantContains: "Config written to:", }, { name: "配置文件不存在", @@ -233,7 +233,7 @@ func TestRun(t *testing.T) { genConfig: false, showVersion: false, wantExitCode: 1, - wantErrContains: "加载配置失败", + wantErrContains: "Failed to load config", }, { name: "generate 与 import 互斥", @@ -252,7 +252,7 @@ func TestRun(t *testing.T) { name: "导入 nginx 配置文件不存在", importPath: "/tmp/nginx.conf", wantExitCode: 1, - wantErrContains: "解析 nginx 配置失败", + wantErrContains: "failed to parse nginx config", }, } @@ -371,8 +371,8 @@ func TestGenerateConfig(t *testing.T) { t.Errorf("exit code = %d, want 1", exitCode) } - if !strings.Contains(stderr, "写入文件失败") { - t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr) + if !strings.Contains(stderr, "Failed to write file") { + t.Errorf("stderr should contain 'Failed to write file', actual: %q", stderr) } }) } @@ -1247,8 +1247,8 @@ func TestGenerateConfig_ErrorCase(t *testing.T) { t.Errorf("exit code = %d, want 1", exitCode) } - if !strings.Contains(stderr, "写入文件失败") { - t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr) + if !strings.Contains(stderr, "Failed to write file") { + t.Errorf("stderr should contain 'Failed to write file', actual: %q", stderr) } }) } diff --git a/internal/app/app_windows.go b/internal/app/app_windows.go index f3ce544..9d150f6 100644 --- a/internal/app/app_windows.go +++ b/internal/app/app_windows.go @@ -6,178 +6,27 @@ package app import ( "fmt" - "net" "os" "os/signal" "syscall" "time" - "rua.plus/lolly/internal/config" - "rua.plus/lolly/internal/http2" - "rua.plus/lolly/internal/http3" - "rua.plus/lolly/internal/logging" - "rua.plus/lolly/internal/resolver" "rua.plus/lolly/internal/server" - "rua.plus/lolly/internal/stream" - "rua.plus/lolly/internal/variable" ) -// App manages the server lifecycle (Windows version). -type App struct { - resv resolver.Resolver - cfg *config.Config - srv *server.Server - http3Srv *http3.Server - http2Srv *http2.Server - streamSrv *stream.Server - logger *logging.AppLogger - cfgPath string - pidFile string - logFile string - listeners []net.Listener - upgradeMgr *server.UpgradeManager -} - -func NewApp(cfgPath string) *App { - return &App{ - cfgPath: cfgPath, - } -} - -func (a *App) SetPidFile(path string) { - a.pidFile = path -} - -func (a *App) SetLogFile(path string) { - a.logFile = path -} - // Run starts the application: loads config, creates servers, and handles signals (Windows version). func (a *App) Run() int { - cfg, err := config.Load(a.cfgPath) - if err != nil { - fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err) + if err := a.loadAndValidateConfig(); err != nil { return 1 } - a.cfg = cfg - a.logger = logging.NewAppLogger(&cfg.Logging) - variable.SetGlobalVariables(cfg.Variables.Set) - if len(cfg.Variables.Set) > 0 { - a.logger.LogStartup("全局变量已加载", map[string]string{ - "count": fmt.Sprintf("%d", len(cfg.Variables.Set)), - }) - } - - a.logger.LogStartup("配置加载成功", map[string]string{"config_path": a.cfgPath}) - - mode := a.cfg.GetMode() - if mode == config.ServerModeMultiServer { - for i, srv := range a.cfg.Servers { - a.logger.LogStartup("监听地址", map[string]string{ - "index": fmt.Sprintf("[%d]", i), - "listen": srv.Listen, - "name": srv.Name, - }) - } - } else { - a.logger.LogStartup("监听地址", map[string]string{"listen": a.cfg.Servers[0].Listen}) - } - - if a.cfg.Resolver.Enabled { - a.resv = resolver.New(&a.cfg.Resolver) - a.logger.LogStartup("DNS 解析器已启用", map[string]string{ - "addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses), - "ttl": a.cfg.Resolver.TTL().String(), - }) - } - - a.srv = server.New(a.cfg) - - if a.resv != nil { - a.srv.SetResolver(a.resv) - } - - if len(a.cfg.Stream) > 0 { - a.streamSrv = stream.NewServer() - for _, sc := range a.cfg.Stream { - targets := make([]stream.TargetSpec, len(sc.Upstream.Targets)) - for i, t := range sc.Upstream.Targets { - targets[i] = stream.TargetSpec{ - Addr: t.Addr, - Weight: t.Weight, - } - } - - if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil { - a.logger.Error().Err(err).Msg("添加 Stream 上游失败") - } - - if sc.Protocol == "udp" { - if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil { - a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 UDP 失败") - } - } else { - if err := a.streamSrv.ListenTCP(sc.Listen); err != nil { - a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 TCP 失败") - } - } - } - - go func() { - a.logger.LogStartup("Stream 服务器启动中", nil) - if err := a.streamSrv.Start(); err != nil { - a.logger.Error().Err(err).Msg("Stream 服务器启动失败") - } - }() - } - - if a.cfg.HTTP3.Enabled && a.cfg.Servers[0].SSL.Cert != "" { - tlsConfig, err := a.srv.GetTLSConfig() - if err != nil { - a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/3") - } else { - a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig) - if err != nil { - a.logger.Error().Err(err).Msg("创建 HTTP/3 服务器失败") - } else { - go func() { - a.logger.LogStartup("HTTP/3 服务器启动中", map[string]string{"listen": a.cfg.HTTP3.Listen}) - if err := a.http3Srv.Start(); err != nil { - a.logger.Error().Err(err).Msg("HTTP/3 服务器启动失败") - } - }() - } - } - } - - if a.cfg.Servers[0].SSL.HTTP2.Enabled && a.cfg.Servers[0].SSL.Cert != "" { - tlsConfig, err := a.srv.GetTLSConfig() - if err != nil { - a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2") - } else { - a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig) - if err != nil { - a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败") - } else { - go func() { - a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{ - "listen": a.cfg.Servers[0].Listen, - "max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams), - "push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled), - }) - listeners := a.srv.GetListeners() - if len(listeners) > 0 { - if err := a.http2Srv.Serve(listeners[0]); err != nil { - a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败") - } - } else { - a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器") - } - }() - } - } - } + a.initVariables() + a.logServerAddresses() + a.initResolver() + a.initServer() + a.initStreamServers() + a.initHTTP3() + a.initHTTP2() a.upgradeMgr = server.NewUpgradeManager(a.srv) if a.pidFile != "" { @@ -190,7 +39,7 @@ func (a *App) Run() int { errChan := make(chan error, 1) go func() { - a.logger.LogStartup("HTTP 服务器启动中", nil) + a.logger.LogStartup("Starting HTTP server", nil) if err := a.srv.Start(); err != nil { errChan <- err } @@ -201,18 +50,18 @@ func (a *App) Run() int { for { select { case err := <-errChan: - a.logger.Error().Err(err).Msg("服务器启动失败") + a.logger.Error().Err(err).Msg("Server failed to start") return 1 case sig := <-sigChan: if sig == syscall.SIGINT { sigintCount++ if sigintCount >= 3 { - a.logger.LogShutdown("收到 3 次 SIGINT,强制退出") + a.logger.LogShutdown("Received 3 SIGINT, forcing exit") return 1 } } if !a.handleSignal(sig) { - a.logger.LogShutdown("服务器已停止") + a.logger.LogShutdown("Server stopped") return 0 } } @@ -234,47 +83,23 @@ func (a *App) handleSignal(sig os.Signal) bool { if timeout <= 0 { timeout = 5 * time.Second } - a.logger.LogSignal(sigName(sig.(syscall.Signal)), "停止服务器") + a.logger.LogSignal(sigName(sig.(syscall.Signal)), "Stopping server") a.shutdownHTTP2() a.shutdownHTTP3() _ = a.srv.StopWithTimeout(timeout) return false default: - a.logger.Info().Str("signal", sig.String()).Msg("收到信号(Windows 忽略)") + a.logger.Info().Str("signal", sig.String()).Msg("Received signal (ignored on Windows)") return true } } -func (a *App) shutdownHTTP3() { - if a.http3Srv != nil { - if err := a.http3Srv.Stop(); err != nil { - a.logger.Error().Err(err).Msg("HTTP/3 服务器关闭失败") - } - } -} - -func (a *App) shutdownHTTP2() { - if a.http2Srv != nil { - if err := a.http2Srv.Stop(); err != nil { - a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败") - } - } -} - func (a *App) reloadConfig() { // Windows stub - functionality limited } -func (a *App) reopenLogs() { - if a.cfg != nil { - logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format) - a.logger = logging.NewAppLogger(&a.cfg.Logging) - } - a.logger.LogStartup("日志已重新打开", nil) -} - func (a *App) gracefulUpgrade() { - a.logger.Info().Msg("Windows 不支持热升级") + a.logger.Info().Msg("Graceful upgrade is not supported on Windows") } func sigName(sig syscall.Signal) string { diff --git a/internal/app/import.go b/internal/app/import.go index f373d32..0df1198 100644 --- a/internal/app/import.go +++ b/internal/app/import.go @@ -48,7 +48,7 @@ func generateConfig(outputPath string) int { cfg := config.DefaultConfig() yamlData, err := config.GenerateConfigYAML(cfg) if err != nil { - fmt.Fprintf(os.Stderr, "生成配置失败: %v\n", err) + fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err) return 1 } @@ -56,10 +56,10 @@ func generateConfig(outputPath string) int { fmt.Print(string(yamlData)) } else { if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil { - fmt.Fprintf(os.Stderr, "写入文件失败: %v\n", err) + fmt.Fprintf(os.Stderr, "Failed to write file: %v\n", err) return 1 } - fmt.Printf("配置已写入: %s\n", outputPath) + fmt.Printf("Config written to: %s\n", outputPath) } return 0 } @@ -67,12 +67,12 @@ func generateConfig(outputPath string) int { func importNginxConfig(path, outputPath string) error { nginxCfg, err := nginx.ParseFile(path) if err != nil { - return fmt.Errorf("解析 nginx 配置失败: %w", err) + return fmt.Errorf("failed to parse nginx config: %w", err) } result, err := nginx.Convert(nginxCfg) if err != nil { - return fmt.Errorf("转换配置失败: %w", err) + return fmt.Errorf("failed to convert config: %w", err) } for _, w := range result.Warnings { @@ -80,26 +80,26 @@ func importNginxConfig(path, outputPath string) error { } if validateErr := config.Validate(result.Config); validateErr != nil { - return fmt.Errorf("转换后配置验证失败: %w", validateErr) + return fmt.Errorf("converted config validation failed: %w", validateErr) } yamlData, err := yaml.Marshal(result.Config) if err != nil { - return fmt.Errorf("序列化 YAML 失败: %w", err) + return fmt.Errorf("failed to marshal YAML: %w", err) } if outputPath == "" { if _, err := os.Stdout.Write(yamlData); err != nil { - return fmt.Errorf("写入标准输出失败: %w", err) + return fmt.Errorf("failed to write to stdout: %w", err) } } else { if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil { - return fmt.Errorf("创建输出目录失败: %w", err) + return fmt.Errorf("failed to create output directory: %w", err) } if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil { - return fmt.Errorf("写入文件失败: %w", err) + return fmt.Errorf("failed to write file: %w", err) } - fmt.Printf("配置已写入: %s\n", outputPath) + fmt.Printf("Config written to: %s\n", outputPath) } return nil diff --git a/internal/cache/purge.go b/internal/cache/purge.go index 5256e70..75b16b4 100644 --- a/internal/cache/purge.go +++ b/internal/cache/purge.go @@ -12,7 +12,6 @@ package cache import ( - "crypto/subtle" "encoding/json" "hash/fnv" "net" @@ -20,7 +19,7 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" - "rua.plus/lolly/internal/netutil" + "rua.plus/lolly/internal/utils" ) // PurgeAPI 缓存清理 API 处理器。 @@ -117,26 +116,26 @@ func (p *PurgeAPI) Path() string { func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) { // 仅允许 POST 方法 if string(ctx.Method()) != "POST" { - p.sendError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed") + utils.SendJSONError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed") return } // 检查 IP 访问权限 - if !p.checkAccess(ctx) { - p.sendError(ctx, fasthttp.StatusForbidden, "forbidden") + if !utils.CheckIPAccess(ctx, p.allowed) { + utils.SendJSONError(ctx, fasthttp.StatusForbidden, "forbidden") return } // 检查认证 - if !p.checkAuth(ctx) { - p.sendError(ctx, fasthttp.StatusUnauthorized, "unauthorized") + if !utils.CheckTokenAuth(ctx, p.auth) { + utils.SendJSONError(ctx, fasthttp.StatusUnauthorized, "unauthorized") return } // 解析请求体 var req PurgeRequest if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { - p.sendError(ctx, fasthttp.StatusBadRequest, "invalid request body") + utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "invalid request body") return } @@ -147,7 +146,7 @@ func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) { } else if req.Pattern != "" { deleted = p.purgeByPattern(req.Pattern) } else { - p.sendError(ctx, fasthttp.StatusBadRequest, "missing path or pattern") + utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "missing path or pattern") return } @@ -157,61 +156,6 @@ func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) { _ = json.NewEncoder(ctx).Encode(PurgeResponse{Deleted: deleted}) } -// checkAccess 检查客户端 IP 是否在允许列表中。 -func (p *PurgeAPI) checkAccess(ctx *fasthttp.RequestCtx) bool { - // 如果没有配置允许列表,允许所有访问 - if len(p.allowed) == 0 { - return true - } - - clientIP := p.getClientIP(ctx) - if clientIP == nil { - return false - } - - // 检查是否在允许列表中 - for _, network := range p.allowed { - if network.Contains(clientIP) { - return true - } - } - - return false -} - -// checkAuth 检查认证。 -func (p *PurgeAPI) checkAuth(ctx *fasthttp.RequestCtx) bool { - // 无需认证 - if p.auth.Type == "" || p.auth.Type == "none" { - return true - } - - // Token 认证 - if p.auth.Type == "token" { - // 从 Authorization header 获取 token - authHeader := ctx.Request.Header.Peek("Authorization") - if len(authHeader) == 0 { - return false - } - - // 支持 Bearer token 格式 - authStr := string(authHeader) - if token, ok := strings.CutPrefix(authStr, "Bearer "); ok { - return subtle.ConstantTimeCompare([]byte(token), []byte(p.auth.Token)) == 1 - } - - // 也支持直接传递 token - return subtle.ConstantTimeCompare([]byte(authStr), []byte(p.auth.Token)) == 1 - } - - return false -} - -// getClientIP 从请求上下文提取客户端 IP。 -func (p *PurgeAPI) getClientIP(ctx *fasthttp.RequestCtx) net.IP { - return netutil.ExtractClientIPNet(ctx) -} - // purgeByPath 按精确路径清理缓存。 func (p *PurgeAPI) purgeByPath(path string) int { if p.cache == nil { @@ -323,9 +267,3 @@ func matchPattern(pattern, path string) bool { return MatchPattern(pattern, path) } -// sendError 发送错误响应。 -func (p *PurgeAPI) sendError(ctx *fasthttp.RequestCtx, status int, errMsg string) { - ctx.SetContentType("application/json; charset=utf-8") - ctx.SetStatusCode(status) - _ = json.NewEncoder(ctx).Encode(PurgeErrorResponse{Error: errMsg}) -} diff --git a/internal/config/validate.go b/internal/config/validate.go index f4a1e6c..90ab61f 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -525,7 +525,7 @@ func validateProxy(p *ProxyConfig) error { } // 使用 regexp.Compile 验证正则语法有效性 if _, err := regexp.Compile(p.Path); err != nil { - return fmt.Errorf("location_type 为 '%s' 时,path '%s' 不是有效正则: %v", p.LocationType, p.Path, err) + return fmt.Errorf("location_type 为 '%s' 时,path '%s' 不是有效正则: %w", p.LocationType, p.Path, err) } } @@ -1273,7 +1273,7 @@ func validateRedirectRewrite(cfg *RedirectRewriteConfig) error { patternStr = rule.Pattern[2:] // 去掉 ~* 前缀(大小写不敏感) } if _, err := regexp.Compile(patternStr); err != nil { - return fmt.Errorf("redirect_rewrite.rules[%d].pattern invalid regex: %v", i, err) + return fmt.Errorf("redirect_rewrite.rules[%d].pattern invalid regex: %w", i, err) } } } diff --git a/internal/handler/errorpage.go b/internal/handler/errorpage.go index f474c5d..7e40c25 100644 --- a/internal/handler/errorpage.go +++ b/internal/handler/errorpage.go @@ -18,6 +18,7 @@ package handler import ( + "errors" "fmt" "os" "sync" @@ -106,7 +107,11 @@ func NewErrorPageManager(cfg *config.ErrorPageConfig) (*ErrorPageManager, error) if len(loadErrors) == totalPages { // 全部加载失败,返回错误 - return nil, fmt.Errorf("所有错误页面加载失败: %v", loadErrors) + errs := make([]error, 0, len(loadErrors)) + for _, e := range loadErrors { + errs = append(errs, e) + } + return nil, fmt.Errorf("所有错误页面加载失败: %w", errors.Join(errs...)) } // 部分失败,记录警告(由调用者处理) return manager, &PartialLoadError{Errors: loadErrors} diff --git a/internal/http3/server.go b/internal/http3/server.go index 184d26c..ac7f86b 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -201,15 +201,9 @@ func (s *Server) Stop() error { s.running = false if s.http3Server != nil { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - if err := s.http3Server.Close(); err != nil { logging.Error().Err(err).Msg("HTTP/3 server close error") } - - // 等待服务完全停止 - <-ctx.Done() } logging.Info().Msg("HTTP/3 server stopped") diff --git a/internal/middleware/rewrite/rewrite.go b/internal/middleware/rewrite/rewrite.go index 616dc50..e9db6ea 100644 --- a/internal/middleware/rewrite/rewrite.go +++ b/internal/middleware/rewrite/rewrite.go @@ -25,6 +25,7 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/utils" "rua.plus/lolly/internal/variable" ) @@ -179,7 +180,7 @@ func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandl for ruleIndex < len(m.rules) { // 步骤1: 检查迭代次数是否超过限制(防止无限循环) if iterationCount >= MaxRewriteIterations { - ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) + utils.SendError(ctx, utils.ErrInternalError) return } diff --git a/internal/middleware/security/auth.go b/internal/middleware/security/auth.go index ce3c9e7..7e759c9 100644 --- a/internal/middleware/security/auth.go +++ b/internal/middleware/security/auth.go @@ -41,6 +41,7 @@ import ( "golang.org/x/crypto/bcrypt" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/middleware" + "rua.plus/lolly/internal/utils" ) // HashAlgorithm 表示密码哈希算法类型。 @@ -381,7 +382,7 @@ func (ba *BasicAuth) extractCredentials(ctx *fasthttp.RequestCtx) (string, strin func (ba *BasicAuth) sendAuthChallenge(ctx *fasthttp.RequestCtx) { ctx.Response.Header.Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\", charset=\"UTF-8\"", ba.realm)) - ctx.Error("Unauthorized", fasthttp.StatusUnauthorized) + utils.SendError(ctx, utils.ErrUnauthorized) } // AddUser 动态添加新用户。 diff --git a/internal/proxy/cache_handler.go b/internal/proxy/cache_handler.go new file mode 100644 index 0000000..fc4b821 --- /dev/null +++ b/internal/proxy/cache_handler.go @@ -0,0 +1,197 @@ +package proxy + +import ( + "hash/fnv" + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/cache" + "rua.plus/lolly/internal/loadbalance" +) + +// buildCacheKey 构建缓存键字符串。 +// +// 使用请求方法和完整请求 URI 作为缓存键。 +// 该函数保留用于日志记录和调试场景。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// +// 返回值: +// - string: 缓存键(格式 "METHOD:URI") +func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string { + // 使用请求方法和路径作为缓存键 + return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI()) +} + +// buildCacheKeyHash 使用 FNV-64a 计算缓存键的 uint64 哈希值。 +// 返回哈希值和原始字符串键。 +// 注意:此函数会先构建字符串键再哈希,存在双重分配。 +// 对于只需要哈希值的场景,使用 buildCacheKeyHashValue 代替。 +func (p *Proxy) buildCacheKeyHash(ctx *fasthttp.RequestCtx) (uint64, string) { + // 构建原始 key + origKey := p.buildCacheKey(ctx) + + // 使用 FNV-64a 计算哈希 + h := fnv.New64a() + h.Write([]byte(origKey)) + return h.Sum64(), origKey +} + +// buildCacheKeyHashValue 直接计算缓存键的哈希值,零字符串分配。 +// 用于只需要哈希值而不需要原始键的场景。 +func (p *Proxy) buildCacheKeyHashValue(ctx *fasthttp.RequestCtx) uint64 { + h := fnv.New64a() + h.Write(ctx.Request.Header.Method()) + h.Write([]byte(":")) + h.Write(ctx.Request.URI().RequestURI()) + return h.Sum64() +} + +// writeCachedResponse 将缓存的响应写入 FastHTTP 响应上下文。 +// +// 设置响应体、状态码、响应头,并添加 X-Cache: HIT 头标记缓存命中。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// - entry: 缓存条目,包含响应数据和元数据 +func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) { + ctx.Response.SetBody(entry.Data) + ctx.Response.SetStatusCode(entry.Status) + for key, value := range entry.Headers { + ctx.Response.Header.Set(key, value) + } + ctx.Response.Header.Set("X-Cache", "HIT") +} + +// backgroundRefresh 在后台异步刷新缓存条目。 +// +// 向对应的上游目标发送请求,获取最新响应并更新缓存。 +// 该方法在独立 goroutine 中运行,不阻塞主请求流程。 +// +// 参数: +// - ctx: 原始 FastHTTP 请求上下文(仅用于复制请求信息) +// - target: 要刷新的后端目标 +// - hashKey: 缓存哈希键 +// - origKey: 缓存原始键 +func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, hashKey uint64, origKey string) { + // 创建新的请求上下文副本 + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // 复制原始请求 + ctx.Request.CopyTo(req) + + // 如果启用 Revalidate,添加条件请求头 + if p.config.Cache.Revalidate { + if entry, ok, _ := p.cache.Get(hashKey, origKey); ok { + if entry.LastModified != "" { + req.Header.Set("If-Modified-Since", entry.LastModified) + } + if entry.ETag != "" { + req.Header.Set("If-None-Match", entry.ETag) + } + } + } + + // 获取客户端 + client := p.getClient(target.URL) + if client == nil { + return + } + + // 执行请求 + err := client.Do(req, resp) + if err != nil { + p.cache.ReleaseLock(hashKey, err) + return + } + + // 处理 304 Not Modified 响应 + if resp.StatusCode() == 304 { + newHeaders := make(map[string]string) + if lm := resp.Header.Peek("Last-Modified"); len(lm) > 0 { + newHeaders["Last-Modified"] = string(lm) + } + if et := resp.Header.Peek("ETag"); len(et) > 0 { + newHeaders["ETag"] = string(et) + } + p.cache.RefreshTTL(hashKey, origKey, newHeaders) + return + } + + // 提取响应头(使用 pool 复用 map) + headers, ok := headersPool.Get().(map[string]string) + if !ok { + headers = make(map[string]string, 20) + } + for k := range headers { + delete(headers, k) + } + for key, value := range resp.Header.All() { + headers[string(key)] = string(value) + } + + // 更新缓存 + p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode())) +} + +// GetCache 返回代理的 ProxyCache 实例(用于 purge handler)。 +// 如果缓存未启用,返回 nil。 +func (p *Proxy) GetCache() *cache.ProxyCache { + return p.cache +} + +// GetCacheStats 返回代理缓存的统计信息。 +// 如果缓存未启用,返回 nil。 +func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats { + if p.cache == nil { + return nil + } + stats := p.cache.Stats() + return &stats +} + +// getCacheDuration 根据状态码获取缓存时间。 +// 优先级:CacheValid 配置 > MaxAge +// +// 映射规则: +// - 200-299: CacheValid.OK(0 时继承 MaxAge) +// - 301/302: CacheValid.Redirect +// - 404: CacheValid.NotFound +// - 400-499(除 404): CacheValid.ClientError +// - 500-599: CacheValid.ServerError +// - 其他: 不缓存(返回 0) +func (p *Proxy) getCacheDuration(statusCode int) time.Duration { + // 无 CacheValid 配置,使用 MaxAge + if p.config.CacheValid == nil { + return p.config.Cache.MaxAge + } + + cv := p.config.CacheValid + + switch { + case statusCode >= 200 && statusCode < 300: + if cv.OK > 0 { + return cv.OK + } + return p.config.Cache.MaxAge // 0 表示继承 MaxAge + + case statusCode == 301 || statusCode == 302: + return cv.Redirect // 0 表示不缓存 + + case statusCode == 404: + return cv.NotFound + + case statusCode >= 400 && statusCode < 500: + return cv.ClientError + + case statusCode >= 500: + return cv.ServerError + + default: + return 0 // 不缓存 + } +} diff --git a/internal/proxy/header_modifier.go b/internal/proxy/header_modifier.go new file mode 100644 index 0000000..06ed394 --- /dev/null +++ b/internal/proxy/header_modifier.go @@ -0,0 +1,166 @@ +package proxy + +import ( + "strings" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/logging" + "rua.plus/lolly/internal/variable" +) + +// modifyRequestHeaders 在转发请求到后端之前修改请求头。 +// +// 执行以下操作: +// 1. 设置 Host header 为目标主机地址 +// 2. 提取并设置 X-Forwarded-For、X-Real-IP、X-Forwarded-Host、X-Forwarded-Proto +// 3. 应用自定义请求头配置(支持变量展开) +// 4. 移除配置的请求头 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// - target: 选中的后端目标 +func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) { + headers := &ctx.Request.Header + + // 设置 Host header 为目标主机 + // 从 target.URL 提取 host:port(HostClient 连接需要此格式) + targetHost := extractHostFromURL(target.URL) + if targetHost != "" { + headers.Set("Host", targetHost) + } + + // 提取并设置 X-Forwarded 系列头 + fh := ExtractForwardedHeaders(ctx) + SetForwardedHeaders(headers, fh, true) + + // 从配置设置自定义请求头(支持变量展开) + if p.config.Headers.SetRequest != nil { + vc := variable.NewContext(ctx) + defer variable.ReleaseContext(vc) + for key, value := range p.config.Headers.SetRequest { + expanded := vc.Expand(value) + if containsCRLF(expanded) { + logging.Warn().Msgf("rejected CRLF in header value: %s", key) + continue + } + headers.Set(key, expanded) + } + } + + // 移除配置的请求头 + if len(p.config.Headers.Remove) > 0 { + for _, key := range p.config.Headers.Remove { + headers.Del(key) + } + } +} + +// modifyResponseHeaders 在发送给客户端之前修改响应头。 +// +// 应用自定义响应头配置,支持变量展开(如 $upstream_addr、$status 等)。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) { + respHeaders := &ctx.Response.Header + + // 构建 PassResponse 集合(多处使用) + passSet := make(map[string]bool, len(p.config.Headers.PassResponse)) + for _, h := range p.config.Headers.PassResponse { + passSet[h] = true + } + + // PassResponse 白名单模式:仅传递列出的头部 + if len(passSet) > 0 { + var toDelete []string + for key := range respHeaders.All() { + if !passSet[string(key)] { + toDelete = append(toDelete, string(key)) + } + } + for _, k := range toDelete { + respHeaders.Del(k) + } + } + + // HideResponse:移除指定的响应头(PassResponse 优先,跳过已传递的头部) + for _, key := range p.config.Headers.HideResponse { + if !passSet[key] { + respHeaders.Del(key) + } + } + + // IgnoreHeaders:从请求和响应中移除(PassResponse 优先) + for _, key := range p.config.Headers.IgnoreHeaders { + ctx.Request.Header.Del(key) + if !passSet[key] { + respHeaders.Del(key) + } + } + + // Cookie 域/路径重写 + if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" { + p.rewriteCookies(respHeaders) + } + + // 从配置设置自定义响应头(支持变量展开) + if p.config.Headers.SetResponse != nil { + vc := variable.NewContext(ctx) + defer variable.ReleaseContext(vc) + for key, value := range p.config.Headers.SetResponse { + expanded := vc.Expand(value) + if containsCRLF(expanded) { + logging.Warn().Msgf("rejected CRLF in header value: %s", key) + continue + } + respHeaders.Set(key, expanded) + } + } +} + +// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。 +func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) { + cookieDomain := p.config.Headers.CookieDomain + cookiePath := p.config.Headers.CookiePath + if cookieDomain == "" && cookiePath == "" { + return + } + + cookies := make([]string, 0, respHeaders.Len()) + for _, value := range respHeaders.Cookies() { + cookie := string(value) + if cookieDomain != "" { + cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain) + } + if cookiePath != "" { + cookie = rewriteCookieAttr(cookie, "Path", cookiePath) + } + cookies = append(cookies, cookie) + } + + if len(cookies) > 0 { + respHeaders.Del("Set-Cookie") + for _, c := range cookies { + respHeaders.Add("Set-Cookie", c) + } + } +} + +// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值(大小写不敏感)。 +func rewriteCookieAttr(cookie, attr, newValue string) string { + prefix := attr + "=" + lower := strings.ToLower(cookie) + idx := strings.Index(lower, strings.ToLower(prefix)) + if idx == -1 { + return cookie + } + + start := idx + len(prefix) + end := start + for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' { + end++ + } + + return cookie[:start] + newValue + cookie[end:] +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 8bc1b6c..4e6b5f4 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -33,10 +33,8 @@ package proxy import ( "bytes" - "context" "errors" "fmt" - "hash/fnv" "net" urlpath "path" "slices" @@ -46,7 +44,6 @@ import ( "time" "github.com/valyala/fasthttp" - glua "github.com/yuin/gopher-lua" "rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" @@ -902,352 +899,6 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target { return p.selectByBalancer(ctx, targets) } -// selectByLua 使用 Lua 脚本选择后端目标。 -// -// 执行配置的 Lua 脚本,脚本可通过 ngx.balancer.set_current_peer() 选择目标。 -// 如果 Lua 脚本执行失败或未调用 set_current_peer,返回 nil 表示需要使用 fallback 算法。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// - targets: 候选目标列表 -// -// 返回值: -// - *loadbalance.Target: Lua 脚本选中的目标,nil 表示未选择 -// - error: Lua 执行失败时返回错误 -func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) { - clientIP := netutil.ExtractClientIP(ctx) - - bctx := &lua.BalancerContext{ - Targets: targets, - ClientIP: clientIP, - Retries: p.config.NextUpstream.Tries, - } - - // 创建 Lua 协程 - coro, err := p.luaEngine.NewCoroutine(ctx) - if err != nil { - return nil, fmt.Errorf("create lua coroutine: %w", err) - } - defer coro.Close() - - // 注册 balancer API - L := coro.Co - ngx, ok := L.GetGlobal("ngx").(*glua.LTable) - if !ok { - return nil, fmt.Errorf("global 'ngx' is not an LTable") - } - lua.RegisterBalancerAPI(L, bctx, ngx) - - // 设置超时 - timeout := p.config.BalancerByLua.Timeout - if timeout <= 0 { - timeout = 100 * time.Millisecond - } - - // 执行脚本(带超时) - execCtx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - coro.ExecutionContext = execCtx - - err = coro.ExecuteFile(p.config.BalancerByLua.Script) - if err != nil { - return nil, fmt.Errorf("execute lua script: %w", err) - } - - // 检查是否调用了 set_current_peer - if !bctx.IsSelected() { - return nil, nil // 未选择,返回 nil 表示需使用 fallback - } - - return bctx.Selected, nil -} - -// selectByFallback 使用 fallback 负载均衡算法选择目标。 -// -// 当 Lua balancer 执行失败或未选择目标时使用。 -// 对于 IPHash 算法,会自动提取客户端 IP 进行哈希选择。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// - targets: 候选目标列表 -// -// 返回值: -// - *loadbalance.Target: fallback 算法选中的目标 -func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target { - p.mu.RLock() - balancer := p.fallbackBalancer - p.mu.RUnlock() - - if ipHash, ok := balancer.(*loadbalance.IPHash); ok { - clientIP := netutil.ExtractClientIP(ctx) - return ipHash.SelectByIP(targets, clientIP) - } - - return balancer.Select(targets) -} - -// selectByBalancer 使用主负载均衡器选择目标。 -// -// 对于特殊算法(IPHash、ConsistentHash),会从请求上下文中提取 -// 相应的哈希键(客户端 IP、URI、自定义 Header 等)。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// - targets: 候选目标列表 -// -// 返回值: -// - *loadbalance.Target: 主负载均衡器选中的目标 -func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target { - p.mu.RLock() - balancer := p.balancer - p.mu.RUnlock() - - // 对于 IPHash 负载均衡器,提取客户端 IP - if ipHash, ok := balancer.(*loadbalance.IPHash); ok { - clientIP := netutil.ExtractClientIP(ctx) - return ipHash.SelectByIP(targets, clientIP) - } - - // 对于一致性哈希,根据 hash_key 配置选择 - if ch, ok := balancer.(*loadbalance.ConsistentHash); ok { - hashKey := ch.GetHashKey() - key := p.extractHashKey(ctx, hashKey) - return ch.SelectByKey(targets, key) - } - - return balancer.Select(targets) -} - -// selectTargetExcluding 选择后端目标,排除已尝试失败的目标。 -// 用于故障转移场景,避免重复选择已失败的目标。 -// 如果没有可用的健康目标则返回 nil。 -func (p *Proxy) selectTargetExcluding(ctx *fasthttp.RequestCtx, excluded []*loadbalance.Target) *loadbalance.Target { - p.mu.RLock() - balancer := p.balancer - targets := p.targets - p.mu.RUnlock() - - if len(targets) == 0 { - return nil - } - - // 对于 IPHash 负载均衡器,提取客户端 IP - if ipHash, ok := balancer.(*loadbalance.IPHash); ok { - clientIP := netutil.ExtractClientIP(ctx) - return ipHash.SelectExcludingByIP(targets, excluded, clientIP) - } - - // 对于一致性哈希,根据 hash_key 配置选择 - if ch, ok := balancer.(*loadbalance.ConsistentHash); ok { - hashKey := ch.GetHashKey() - key := p.extractHashKey(ctx, hashKey) - return ch.SelectExcludingByKey(targets, excluded, key) - } - - return balancer.SelectExcluding(targets, excluded) -} - -// extractHashKey 根据一致性哈希配置提取哈希键值。 -// -// 支持的 hash_key 配置: -// - "ip" 或 "": 使用客户端 IP 地址 -// - "uri": 使用完整请求 URI -// - "header:NAME": 使用指定请求头的值,缺失时回退到客户端 IP -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// - hashKey: 哈希键配置 -// -// 返回值: -// - string: 提取的哈希键值 -func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string { - switch { - case hashKey == "ip" || hashKey == "": - return netutil.ExtractClientIP(ctx) - case hashKey == "uri": - return string(ctx.RequestURI()) - case strings.HasPrefix(hashKey, "header:"): - headerName := strings.TrimPrefix(hashKey, "header:") - value := ctx.Request.Header.Peek(headerName) - if len(value) > 0 { - return string(value) - } - return netutil.ExtractClientIP(ctx) // fallback to IP - default: - return netutil.ExtractClientIP(ctx) - } -} - -// getClient 返回指定目标 URL 对应的 HostClient 连接池实例。 -// 如果目标 URL 不存在于连接池中,返回 nil。 -func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient { - key := targetURL - if p.config.ProxyBind != "" { - key = targetURL + "|" + p.config.ProxyBind - } - p.mu.RLock() - client := p.clients[key] - p.mu.RUnlock() - return client -} - -// modifyRequestHeaders 在转发请求到后端之前修改请求头。 -// -// 执行以下操作: -// 1. 设置 Host header 为目标主机地址 -// 2. 提取并设置 X-Forwarded-For、X-Real-IP、X-Forwarded-Host、X-Forwarded-Proto -// 3. 应用自定义请求头配置(支持变量展开) -// 4. 移除配置的请求头 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// - target: 选中的后端目标 -func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) { - headers := &ctx.Request.Header - - // 设置 Host header 为目标主机 - // 从 target.URL 提取 host:port(HostClient 连接需要此格式) - targetHost := extractHostFromURL(target.URL) - if targetHost != "" { - headers.Set("Host", targetHost) - } - - // 提取并设置 X-Forwarded 系列头 - fh := ExtractForwardedHeaders(ctx) - SetForwardedHeaders(headers, fh, true) - - // 从配置设置自定义请求头(支持变量展开) - if p.config.Headers.SetRequest != nil { - vc := variable.NewContext(ctx) - defer variable.ReleaseContext(vc) - for key, value := range p.config.Headers.SetRequest { - expanded := vc.Expand(value) - if containsCRLF(expanded) { - logging.Warn().Msgf("rejected CRLF in header value: %s", key) - continue - } - headers.Set(key, expanded) - } - } - - // 移除配置的请求头 - if len(p.config.Headers.Remove) > 0 { - for _, key := range p.config.Headers.Remove { - headers.Del(key) - } - } -} - -// modifyResponseHeaders 在发送给客户端之前修改响应头。 -// -// 应用自定义响应头配置,支持变量展开(如 $upstream_addr、$status 等)。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) { - respHeaders := &ctx.Response.Header - - // 构建 PassResponse 集合(多处使用) - passSet := make(map[string]bool, len(p.config.Headers.PassResponse)) - for _, h := range p.config.Headers.PassResponse { - passSet[h] = true - } - - // PassResponse 白名单模式:仅传递列出的头部 - if len(passSet) > 0 { - var toDelete []string - for key := range respHeaders.All() { - if !passSet[string(key)] { - toDelete = append(toDelete, string(key)) - } - } - for _, k := range toDelete { - respHeaders.Del(k) - } - } - - // HideResponse:移除指定的响应头(PassResponse 优先,跳过已传递的头部) - for _, key := range p.config.Headers.HideResponse { - if !passSet[key] { - respHeaders.Del(key) - } - } - - // IgnoreHeaders:从请求和响应中移除(PassResponse 优先) - for _, key := range p.config.Headers.IgnoreHeaders { - ctx.Request.Header.Del(key) - if !passSet[key] { - respHeaders.Del(key) - } - } - - // Cookie 域/路径重写 - if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" { - p.rewriteCookies(respHeaders) - } - - // 从配置设置自定义响应头(支持变量展开) - if p.config.Headers.SetResponse != nil { - vc := variable.NewContext(ctx) - defer variable.ReleaseContext(vc) - for key, value := range p.config.Headers.SetResponse { - expanded := vc.Expand(value) - if containsCRLF(expanded) { - logging.Warn().Msgf("rejected CRLF in header value: %s", key) - continue - } - respHeaders.Set(key, expanded) - } - } -} - -// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。 -func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) { - cookieDomain := p.config.Headers.CookieDomain - cookiePath := p.config.Headers.CookiePath - if cookieDomain == "" && cookiePath == "" { - return - } - - cookies := make([]string, 0, respHeaders.Len()) - for _, value := range respHeaders.Cookies() { - cookie := string(value) - if cookieDomain != "" { - cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain) - } - if cookiePath != "" { - cookie = rewriteCookieAttr(cookie, "Path", cookiePath) - } - cookies = append(cookies, cookie) - } - - if len(cookies) > 0 { - respHeaders.Del("Set-Cookie") - for _, c := range cookies { - respHeaders.Add("Set-Cookie", c) - } - } -} - -// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值(大小写不敏感)。 -func rewriteCookieAttr(cookie, attr, newValue string) string { - prefix := attr + "=" - lower := strings.ToLower(cookie) - idx := strings.Index(lower, strings.ToLower(prefix)) - if idx == -1 { - return cookie - } - - start := idx + len(prefix) - end := start - for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' { - end++ - } - - return cookie[:start] + newValue + cookie[end:] -} - // isWebSocketRequest 检查请求是否为 WebSocket 升级请求。 // // 通过检查 Connection 和 Upgrade 请求头判断: @@ -1333,151 +984,6 @@ func (p *Proxy) GetConfig() *config.ProxyConfig { return p.config } -// buildCacheKey 构建缓存键字符串。 -// -// 使用请求方法和完整请求 URI 作为缓存键。 -// 该函数保留用于日志记录和调试场景。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// -// 返回值: -// - string: 缓存键(格式 "METHOD:URI") -func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string { - // 使用请求方法和路径作为缓存键 - return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI()) -} - -// buildCacheKeyHash 使用 FNV-64a 计算缓存键的 uint64 哈希值。 -// 返回哈希值和原始字符串键。 -// 注意:此函数会先构建字符串键再哈希,存在双重分配。 -// 对于只需要哈希值的场景,使用 buildCacheKeyHashValue 代替。 -func (p *Proxy) buildCacheKeyHash(ctx *fasthttp.RequestCtx) (uint64, string) { - // 构建原始 key - origKey := p.buildCacheKey(ctx) - - // 使用 FNV-64a 计算哈希 - h := fnv.New64a() - h.Write([]byte(origKey)) - return h.Sum64(), origKey -} - -// buildCacheKeyHashValue 直接计算缓存键的哈希值,零字符串分配。 -// 用于只需要哈希值而不需要原始键的场景。 -func (p *Proxy) buildCacheKeyHashValue(ctx *fasthttp.RequestCtx) uint64 { - h := fnv.New64a() - h.Write(ctx.Request.Header.Method()) - h.Write([]byte(":")) - h.Write(ctx.Request.URI().RequestURI()) - return h.Sum64() -} - -// writeCachedResponse 将缓存的响应写入 FastHTTP 响应上下文。 -// -// 设置响应体、状态码、响应头,并添加 X-Cache: HIT 头标记缓存命中。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// - entry: 缓存条目,包含响应数据和元数据 -func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) { - ctx.Response.SetBody(entry.Data) - ctx.Response.SetStatusCode(entry.Status) - for key, value := range entry.Headers { - ctx.Response.Header.Set(key, value) - } - ctx.Response.Header.Set("X-Cache", "HIT") -} - -// backgroundRefresh 在后台异步刷新缓存条目。 -// -// 向对应的上游目标发送请求,获取最新响应并更新缓存。 -// 该方法在独立 goroutine 中运行,不阻塞主请求流程。 -// -// 参数: -// - ctx: 原始 FastHTTP 请求上下文(仅用于复制请求信息) -// - target: 要刷新的后端目标 -// - hashKey: 缓存哈希键 -// - origKey: 缓存原始键 -func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, hashKey uint64, origKey string) { - // 创建新的请求上下文副本 - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // 复制原始请求 - ctx.Request.CopyTo(req) - - // 如果启用 Revalidate,添加条件请求头 - if p.config.Cache.Revalidate { - if entry, ok, _ := p.cache.Get(hashKey, origKey); ok { - if entry.LastModified != "" { - req.Header.Set("If-Modified-Since", entry.LastModified) - } - if entry.ETag != "" { - req.Header.Set("If-None-Match", entry.ETag) - } - } - } - - // 获取客户端 - client := p.getClient(target.URL) - if client == nil { - return - } - - // 执行请求 - err := client.Do(req, resp) - if err != nil { - p.cache.ReleaseLock(hashKey, err) - return - } - - // 处理 304 Not Modified 响应 - if resp.StatusCode() == 304 { - newHeaders := make(map[string]string) - if lm := resp.Header.Peek("Last-Modified"); len(lm) > 0 { - newHeaders["Last-Modified"] = string(lm) - } - if et := resp.Header.Peek("ETag"); len(et) > 0 { - newHeaders["ETag"] = string(et) - } - p.cache.RefreshTTL(hashKey, origKey, newHeaders) - return - } - - // 提取响应头(使用 pool 复用 map) - headers, ok := headersPool.Get().(map[string]string) - if !ok { - headers = make(map[string]string, 20) - } - for k := range headers { - delete(headers, k) - } - for key, value := range resp.Header.All() { - headers[string(key)] = string(value) - } - - // 更新缓存 - p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode())) -} - -// GetCache 返回代理的 ProxyCache 实例(用于 purge handler)。 -// 如果缓存未启用,返回 nil。 -func (p *Proxy) GetCache() *cache.ProxyCache { - return p.cache -} - -// GetCacheStats 返回代理缓存的统计信息。 -// 如果缓存未启用,返回 nil。 -func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats { - if p.cache == nil { - return nil - } - stats := p.cache.Stats() - return &stats -} - // extractHostFromURL 从 URL 字符串中提取 host:port 部分。 // // 移除 http:// 或 https:// 协议前缀,以及路径部分, @@ -1505,44 +1011,3 @@ func extractHostFromURL(urlStr string) string { return host } -// getCacheDuration 根据状态码获取缓存时间。 -// 优先级:CacheValid 配置 > MaxAge -// -// 映射规则: -// - 200-299: CacheValid.OK(0 时继承 MaxAge) -// - 301/302: CacheValid.Redirect -// - 404: CacheValid.NotFound -// - 400-499(除 404): CacheValid.ClientError -// - 500-599: CacheValid.ServerError -// - 其他: 不缓存(返回 0) -func (p *Proxy) getCacheDuration(statusCode int) time.Duration { - // 无 CacheValid 配置,使用 MaxAge - if p.config.CacheValid == nil { - return p.config.Cache.MaxAge - } - - cv := p.config.CacheValid - - switch { - case statusCode >= 200 && statusCode < 300: - if cv.OK > 0 { - return cv.OK - } - return p.config.Cache.MaxAge // 0 表示继承 MaxAge - - case statusCode == 301 || statusCode == 302: - return cv.Redirect // 0 表示不缓存 - - case statusCode == 404: - return cv.NotFound - - case statusCode >= 400 && statusCode < 500: - return cv.ClientError - - case statusCode >= 500: - return cv.ServerError - - default: - return 0 // 不缓存 - } -} diff --git a/internal/proxy/proxy_ssl.go b/internal/proxy/proxy_ssl.go index f4596fa..2efb824 100644 --- a/internal/proxy/proxy_ssl.go +++ b/internal/proxy/proxy_ssl.go @@ -30,6 +30,7 @@ import ( "strings" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/logging" ) // tlsVersionMap TLS 版本字符串到 tls 常量的映射表。 @@ -100,7 +101,9 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls tlsCfg.Certificates = []tls.Certificate{cert} } - // TLS 版本配置 + // TLS 版本配置:默认 MinVersion = TLS 1.2 + tlsCfg.MinVersion = tls.VersionTLS12 + if cfg.MinVersion != "" { version, ok := tlsVersionMap[strings.ToUpper(cfg.MinVersion)] if !ok { @@ -109,6 +112,11 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls tlsCfg.MinVersion = version } + // 警告:TLS 1.0/1.1 已不安全,不应在生产环境使用 + if tlsCfg.MinVersion < tls.VersionTLS12 { + logging.Warn().Msgf("上游 TLS MinVersion 设置为 %s(低于 TLS 1.2),存在安全风险", cfg.MinVersion) + } + if cfg.MaxVersion != "" { version, ok := tlsVersionMap[strings.ToUpper(cfg.MaxVersion)] if !ok { diff --git a/internal/proxy/target_selector.go b/internal/proxy/target_selector.go new file mode 100644 index 0000000..6393707 --- /dev/null +++ b/internal/proxy/target_selector.go @@ -0,0 +1,204 @@ +package proxy + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/valyala/fasthttp" + glua "github.com/yuin/gopher-lua" + "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/netutil" +) + +// selectByLua 使用 Lua 脚本选择后端目标。 +// +// 执行配置的 Lua 脚本,脚本可通过 ngx.balancer.set_current_peer() 选择目标。 +// 如果 Lua 脚本执行失败或未调用 set_current_peer,返回 nil 表示需要使用 fallback 算法。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// - targets: 候选目标列表 +// +// 返回值: +// - *loadbalance.Target: Lua 脚本选中的目标,nil 表示未选择 +// - error: Lua 执行失败时返回错误 +func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) { + clientIP := netutil.ExtractClientIP(ctx) + + bctx := &lua.BalancerContext{ + Targets: targets, + ClientIP: clientIP, + Retries: p.config.NextUpstream.Tries, + } + + // 创建 Lua 协程 + coro, err := p.luaEngine.NewCoroutine(ctx) + if err != nil { + return nil, fmt.Errorf("create lua coroutine: %w", err) + } + defer coro.Close() + + // 注册 balancer API + L := coro.Co + ngx, ok := L.GetGlobal("ngx").(*glua.LTable) + if !ok { + return nil, fmt.Errorf("global 'ngx' is not an LTable") + } + lua.RegisterBalancerAPI(L, bctx, ngx) + + // 设置超时 + timeout := p.config.BalancerByLua.Timeout + if timeout <= 0 { + timeout = 100 * time.Millisecond + } + + // 执行脚本(带超时) + execCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + coro.ExecutionContext = execCtx + + err = coro.ExecuteFile(p.config.BalancerByLua.Script) + if err != nil { + return nil, fmt.Errorf("execute lua script: %w", err) + } + + // 检查是否调用了 set_current_peer + if !bctx.IsSelected() { + return nil, nil // 未选择,返回 nil 表示需使用 fallback + } + + return bctx.Selected, nil +} + +// selectByFallback 使用 fallback 负载均衡算法选择目标。 +// +// 当 Lua balancer 执行失败或未选择目标时使用。 +// 对于 IPHash 算法,会自动提取客户端 IP 进行哈希选择。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// - targets: 候选目标列表 +// +// 返回值: +// - *loadbalance.Target: fallback 算法选中的目标 +func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target { + p.mu.RLock() + balancer := p.fallbackBalancer + p.mu.RUnlock() + + if ipHash, ok := balancer.(*loadbalance.IPHash); ok { + clientIP := netutil.ExtractClientIP(ctx) + return ipHash.SelectByIP(targets, clientIP) + } + + return balancer.Select(targets) +} + +// selectByBalancer 使用主负载均衡器选择目标。 +// +// 对于特殊算法(IPHash、ConsistentHash),会从请求上下文中提取 +// 相应的哈希键(客户端 IP、URI、自定义 Header 等)。 +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// - targets: 候选目标列表 +// +// 返回值: +// - *loadbalance.Target: 主负载均衡器选中的目标 +func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target { + p.mu.RLock() + balancer := p.balancer + p.mu.RUnlock() + + // 对于 IPHash 负载均衡器,提取客户端 IP + if ipHash, ok := balancer.(*loadbalance.IPHash); ok { + clientIP := netutil.ExtractClientIP(ctx) + return ipHash.SelectByIP(targets, clientIP) + } + + // 对于一致性哈希,根据 hash_key 配置选择 + if ch, ok := balancer.(*loadbalance.ConsistentHash); ok { + hashKey := ch.GetHashKey() + key := p.extractHashKey(ctx, hashKey) + return ch.SelectByKey(targets, key) + } + + return balancer.Select(targets) +} + +// selectTargetExcluding 选择后端目标,排除已尝试失败的目标。 +// 用于故障转移场景,避免重复选择已失败的目标。 +// 如果没有可用的健康目标则返回 nil。 +func (p *Proxy) selectTargetExcluding(ctx *fasthttp.RequestCtx, excluded []*loadbalance.Target) *loadbalance.Target { + p.mu.RLock() + balancer := p.balancer + targets := p.targets + p.mu.RUnlock() + + if len(targets) == 0 { + return nil + } + + // 对于 IPHash 负载均衡器,提取客户端 IP + if ipHash, ok := balancer.(*loadbalance.IPHash); ok { + clientIP := netutil.ExtractClientIP(ctx) + return ipHash.SelectExcludingByIP(targets, excluded, clientIP) + } + + // 对于一致性哈希,根据 hash_key 配置选择 + if ch, ok := balancer.(*loadbalance.ConsistentHash); ok { + hashKey := ch.GetHashKey() + key := p.extractHashKey(ctx, hashKey) + return ch.SelectExcludingByKey(targets, excluded, key) + } + + return balancer.SelectExcluding(targets, excluded) +} + +// extractHashKey 根据一致性哈希配置提取哈希键值。 +// +// 支持的 hash_key 配置: +// - "ip" 或 "": 使用客户端 IP 地址 +// - "uri": 使用完整请求 URI +// - "header:NAME": 使用指定请求头的值,缺失时回退到客户端 IP +// +// 参数: +// - ctx: FastHTTP 请求上下文 +// - hashKey: 哈希键配置 +// +// 返回值: +// - string: 提取的哈希键值 +func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string { + switch { + case hashKey == "ip" || hashKey == "": + return netutil.ExtractClientIP(ctx) + case hashKey == "uri": + return string(ctx.RequestURI()) + case strings.HasPrefix(hashKey, "header:"): + headerName := strings.TrimPrefix(hashKey, "header:") + value := ctx.Request.Header.Peek(headerName) + if len(value) > 0 { + return string(value) + } + return netutil.ExtractClientIP(ctx) // fallback to IP + default: + return netutil.ExtractClientIP(ctx) + } +} + +// getClient 返回指定目标 URL 对应的 HostClient 连接池实例。 +// 如果目标 URL 不存在于连接池中,返回 nil。 +func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient { + key := targetURL + if p.config.ProxyBind != "" { + key = targetURL + "|" + p.config.ProxyBind + } + p.mu.RLock() + client := p.clients[key] + p.mu.RUnlock() + return client +} diff --git a/internal/server/init.go b/internal/server/init.go index 2643fea..854aab9 100644 --- a/internal/server/init.go +++ b/internal/server/init.go @@ -142,10 +142,10 @@ func initLuaEngine(luaCfg *config.LuaMiddlewareConfig) (*lua.LuaEngine, error) { engine, err := lua.NewEngine(engineCfg) if err != nil { - return nil, fmt.Errorf("初始化 Lua 引擎失败: %w", err) + return nil, fmt.Errorf("failed to initialize Lua engine: %w", err) } - logging.Info().Msg("Lua 引擎已启动") + logging.Info().Msg("Lua engine started") return engine, nil } @@ -169,14 +169,14 @@ func initErrorPageManager(errorPageCfg *config.ErrorPageConfig) (*handler.ErrorP if err != nil { // 检查是否是部分加载失败 if _, ok := err.(*handler.PartialLoadError); ok { - logging.Warn().Msg("部分错误页面加载失败: " + err.Error()) + logging.Warn().Msg("Some error pages failed to load: " + err.Error()) // 返回部分加载的管理器 return manager, nil } // 全部加载失败 - return nil, fmt.Errorf("加载错误页面失败: %w", err) + return nil, fmt.Errorf("failed to load error pages: %w", err) } - logging.Info().Msg("错误页面管理器已启动") + logging.Info().Msg("Error page manager started") return manager, nil } diff --git a/internal/server/lifecycle.go b/internal/server/lifecycle.go new file mode 100644 index 0000000..6eeb41f --- /dev/null +++ b/internal/server/lifecycle.go @@ -0,0 +1,222 @@ +package server + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/logging" +) + +// cleanupResources 清理服务器资源。 +// +// 停止 Goroutine 池、健康检查器,关闭访问日志、TLS 管理器、 +// AccessControl 和 Lua 引擎。由 StopWithTimeout 和 GracefulStop 共用。 +func (s *Server) cleanupResources() { + // 停止 Goroutine 池 + if s.pool != nil { + s.pool.Stop() + } + + // 停止健康检查器 + for _, hc := range s.healthCheckers { + hc.Stop() + } + + // 关闭访问日志 + if s.accessLogMiddleware != nil { + _ = s.accessLogMiddleware.Close() + } + + // 关闭 TLS 管理器 + if s.tlsManager != nil { + s.tlsManager.Close() + } + + // 关闭 AccessControl (释放 GeoIP 资源) + if s.accessControl != nil { + if err := s.accessControl.Close(); err != nil { + logging.Warn().Err(err).Msg("Failed to close AccessControl") + } + } + + // 关闭 Lua 引擎 + if s.luaEngine != nil { + s.luaEngine.Close() + logging.Info().Msg("Lua engine closed") + } +} + +// shutdownServers 并行关闭多个 fasthttp.Server 实例。 +// +// 使用 goroutine 并行关闭所有服务器,收集所有错误并返回聚合错误。 +// 部分服务器关闭失败不会影响其他服务器的关闭。 +// +// 参数: +// - ctx: 关闭上下文,用于控制超时和取消 +// - servers: 要关闭的 fasthttp.Server 实例列表 +// +// 返回值: +// - error: 聚合错误,无错误或全部成功时返回 nil +func shutdownServers(ctx context.Context, servers []*fasthttp.Server) error { + // 防御性检查:nil ctx 使用默认背景 + if ctx == nil { + ctx = context.Background() + } + if len(servers) == 0 { + return nil + } + + var ( + mu sync.Mutex + errs []error + wg sync.WaitGroup + ) + + for _, srv := range servers { + if srv == nil { + continue + } + wg.Add(1) + go func(s *fasthttp.Server) { + defer wg.Done() + if err := s.Shutdown(); err != nil { + mu.Lock() + errs = append(errs, err) + mu.Unlock() + } + }(srv) + } + + // 等待所有关闭完成或上下文取消 + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + if len(errs) == 0 { + return nil + } + if len(errs) == 1 { + return errs[0] + } + msgs := make([]string, len(errs)) + for i, e := range errs { + msgs[i] = e.Error() + } + return fmt.Errorf("failed to close servers: %d errors: %s", len(errs), strings.Join(msgs, "; ")) + case <-ctx.Done(): + return ctx.Err() + } +} + +// StopWithTimeout 快速停止服务器(支持自定义超时)。 +// +// 立即停止服务器,不等待正在处理的请求完成。 +// 停止所有健康检查器和访问日志中间件。 +// +// 参数: +// - timeout: 快速关闭的最大等待时间 +// +// 返回值: +// - error: 停止过程中遇到的错误 +// +// 注意事项: +// - 对于生产环境,建议使用 GracefulStop 实现优雅关闭 +// - timeout <= 0 时会使用默认 5s 超时 +func (s *Server) StopWithTimeout(timeout time.Duration) error { + // 防御性检查:如果 timeout <= 0,使用默认值 + if timeout <= 0 { + timeout = 5 * time.Second + } + + s.running = false + s.cleanupResources() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // 多服务器模式:并行关闭所有 fasthttp.Server + if len(s.fastServers) > 0 { + return shutdownServers(ctx, s.fastServers) + } + + // 单服务器模式:关闭单个 fasthttp.Server + if s.fastServer != nil { + done := make(chan struct{}) + go func() { + _ = s.fastServer.Shutdown() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + return nil +} + +// GracefulStop 优雅停止服务器。 +// +// 等待正在处理的请求完成后再停止服务器,确保连接正常关闭。 +// 如果超时时间到达仍有请求未完成,将返回超时错误。 +// +// 参数: +// - timeout: 优雅关闭的最大等待时间 +// +// 返回值: +// - error: 停止过程中遇到的错误,超时返回 context.DeadlineExceeded +// +// 注意事项: +// - 推荐在生产环境使用此方法关闭服务器 +// - 超时后会强制关闭,可能导致部分请求中断 +func (s *Server) GracefulStop(timeout time.Duration) error { + s.running = false + s.cleanupResources() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + // 多服务器模式:并行关闭所有 fasthttp.Server + if len(s.fastServers) > 0 { + return shutdownServers(ctx, s.fastServers) + } + + // 单服务器模式:关闭单个 fasthttp.Server + if s.fastServer != nil { + done := make(chan struct{}) + go func() { + _ = s.fastServer.Shutdown() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + return nil +} + +// getProxyCacheStats 收集所有代理缓存的统计信息。 +func (s *Server) getProxyCacheStats() ProxyCacheStats { + var total ProxyCacheStats + for _, p := range s.proxies { + if stats := p.GetCacheStats(); stats != nil { + total.Entries += stats.Entries + total.Pending += stats.Pending + } + } + return total +} diff --git a/internal/server/lua_integration_test.go b/internal/server/lua_integration_test.go index 7027f3c..4061eba 100644 --- a/internal/server/lua_integration_test.go +++ b/internal/server/lua_integration_test.go @@ -75,7 +75,7 @@ func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) { _, err = s.buildLuaMiddlewares(luaCfg) assert.Error(t, err) - assert.Contains(t, err.Error(), "无效的阶段") + assert.Contains(t, err.Error(), "invalid phase") } // TestBuildLuaMiddlewares_WithTimeout 测试超时配置 diff --git a/internal/server/middleware_builder.go b/internal/server/middleware_builder.go new file mode 100644 index 0000000..0d4974f --- /dev/null +++ b/internal/server/middleware_builder.go @@ -0,0 +1,246 @@ +package server + +import ( + "fmt" + "time" + + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/middleware" + "rua.plus/lolly/internal/middleware/accesslog" + "rua.plus/lolly/internal/middleware/bodylimit" + "rua.plus/lolly/internal/middleware/compression" + "rua.plus/lolly/internal/middleware/errorintercept" + "rua.plus/lolly/internal/middleware/rewrite" + "rua.plus/lolly/internal/middleware/security" +) + +// buildMiddlewareChain 构建中间件链。 +// +// 根据服务器配置按顺序构建中间件链,顺序为: +// +// AccessLog -> AccessControl -> RateLimiter -> BasicAuth -> Rewrite -> Compression -> SecurityHeaders +// +// 参数: +// - serverCfg: 单个服务器的配置对象 +// +// 返回值: +// - *middleware.Chain: 构建完成的中间件链 +// - error: 构建过程中遇到的错误,如中间件创建失败 +// +// 注意事项: +// - 各中间件按顺序依次包装请求处理器 +// - 未配置的中间件不会添加到链中 +func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) { + var middlewares []middleware.Middleware + + // 1. AccessLog (已集成) + s.accessLogMiddleware = accesslog.New(&s.config.Logging) + middlewares = append(middlewares, s.accessLogMiddleware) + + // 2. Security: AccessControl (IP 访问控制) + if len(serverCfg.Security.Access.Allow) > 0 || len(serverCfg.Security.Access.Deny) > 0 { + ac, err := security.NewAccessControl(&serverCfg.Security.Access) + if err != nil { + return nil, fmt.Errorf("failed to create access control middleware: %w", err) + } + middlewares = append(middlewares, ac) + s.accessControl = ac + } + + // 3. Security: RateLimiter (速率限制) + if serverCfg.Security.RateLimit.RequestRate > 0 { + rl, err := security.NewRateLimiter(&serverCfg.Security.RateLimit) + if err != nil { + return nil, fmt.Errorf("failed to create rate limiter middleware: %w", err) + } + middlewares = append(middlewares, rl) + } + + // 3.5 Security: ConnLimiter (连接数限制) + if serverCfg.Security.RateLimit.ConnLimit > 0 { + cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key) + if err != nil { + return nil, fmt.Errorf("failed to create connection limiter middleware: %w", err) + } + middlewares = append(middlewares, cl.Middleware()) + } + + // 4. Security: BasicAuth (认证) + if len(serverCfg.Security.Auth.Users) > 0 { + auth, err := security.NewBasicAuth(&serverCfg.Security.Auth) + if err != nil { + return nil, fmt.Errorf("failed to create auth middleware: %w", err) + } + middlewares = append(middlewares, auth) + } + + // 4.3 Security: AuthRequest (外部认证子请求) + if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" { + authReq, err := security.NewAuthRequest(serverCfg.Security.AuthRequest) + if err != nil { + return nil, fmt.Errorf("failed to create auth request middleware: %w", err) + } + middlewares = append(middlewares, authReq) + } + + // 4.5 BodyLimit (请求体大小限制) + // 创建 bodylimit 中间件,使用全局配置或默认值 + bodyLimitMiddleware := bodylimit.NewWithDefault() + if serverCfg.ClientMaxBodySize != "" { + bl, err := bodylimit.New(serverCfg.ClientMaxBodySize) + if err != nil { + return nil, fmt.Errorf("failed to create body limit middleware: %w", err) + } + bodyLimitMiddleware = bl + } + // 添加路径级别的限制配置 + for i := range serverCfg.Proxy { + if serverCfg.Proxy[i].ClientMaxBodySize != "" { + if err := bodyLimitMiddleware.AddPathLimit( + serverCfg.Proxy[i].Path, + serverCfg.Proxy[i].ClientMaxBodySize, + ); err != nil { + return nil, fmt.Errorf("failed to add path body limit: %w", err) + } + } + } + middlewares = append(middlewares, bodyLimitMiddleware) + + // 5. Rewrite (URL 重写) + if len(serverCfg.Rewrite) > 0 { + rw, err := rewrite.New(serverCfg.Rewrite) + if err != nil { + return nil, fmt.Errorf("failed to create rewrite middleware: %w", err) + } + middlewares = append(middlewares, rw) + } + + // 6. Compression (响应压缩) + if serverCfg.Compression.Type != "" { + comp, err := compression.New(&serverCfg.Compression) + if err != nil { + return nil, fmt.Errorf("failed to create compression middleware: %w", err) + } + middlewares = append(middlewares, comp) + } + + // 7. SecurityHeaders (安全头部) + // 如果有任何安全头部配置,则启用 + if serverCfg.Security.Headers.XFrameOptions != "" || + serverCfg.Security.Headers.XContentTypeOptions != "" || + serverCfg.Security.Headers.ContentSecurityPolicy != "" || + serverCfg.Security.Headers.ReferrerPolicy != "" || + serverCfg.Security.Headers.PermissionsPolicy != "" { + headers := security.NewHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS) + middlewares = append(middlewares, headers) + } + + // 8. ErrorIntercept (错误页面拦截) + // 如果配置了错误页面,添加错误拦截中间件 + if s.errorPageManager != nil && s.errorPageManager.IsConfigured() { + ei := errorintercept.New(s.errorPageManager) + middlewares = append(middlewares, ei) + } + + // Lua 中间件(可选) + if s.luaEngine != nil && serverCfg.Lua != nil && serverCfg.Lua.Enabled { + luaMiddlewares, err := s.buildLuaMiddlewares(serverCfg.Lua) + if err != nil { + return nil, fmt.Errorf("failed to create Lua middleware: %w", err) + } + middlewares = append(middlewares, luaMiddlewares...) + } + + return middleware.NewChain(middlewares...), nil +} + +// buildLuaMiddlewares 根据 Lua 配置创建中间件。 +// +// 根据 Scripts 配置创建 LuaMiddleware 或 MultiPhaseLuaMiddleware。 +// 支持单脚本和多阶段脚本配置。 +// +// 参数: +// - luaCfg: Lua 配置对象 +// +// 返回值: +// - []middleware.Middleware: 创建的中间件列表 +// - error: 创建过程中遇到的错误 +func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) { + if s.luaEngine == nil { + return nil, nil + } + + // 按阶段分组脚本 + phaseScripts := make(map[string][]config.LuaScriptConfig) + for _, script := range luaCfg.Scripts { + // 默认启用 + enabled := script.Enabled + if !enabled && script.Timeout == 0 && script.Path != "" { + enabled = true // 零值时默认启用 + } + if enabled { + phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script) + } + } + + var middlewares []middleware.Middleware + + // 为每个阶段创建中间件 + for phase, scripts := range phaseScripts { + if len(scripts) == 0 { + continue + } + + // 单脚本:直接创建 LuaMiddleware + if len(scripts) == 1 { + script := scripts[0] + luaPhase, err := lua.ParsePhase(phase) + if err != nil { + return nil, fmt.Errorf("invalid phase '%s': %w", phase, err) + } + + timeout := script.Timeout + if timeout == 0 { + timeout = 30 * time.Second + } + + cfg := lua.LuaMiddlewareConfig{ + ScriptPath: script.Path, + Phase: luaPhase, + Timeout: timeout, + Name: fmt.Sprintf("lua-%s", phase), + } + + mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg) + if err != nil { + return nil, fmt.Errorf("failed to create Lua middleware (phase=%s): %w", phase, err) + } + + middlewares = append(middlewares, mw) + } else { + // 多脚本:创建 MultiPhaseLuaMiddleware + multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase)) + for _, script := range scripts { + luaPhase, err := lua.ParsePhase(phase) + if err != nil { + return nil, fmt.Errorf("invalid phase '%s': %w", phase, err) + } + + timeout := script.Timeout + if timeout == 0 { + timeout = 30 * time.Second + } + + err = multi.AddPhase(luaPhase, script.Path, timeout) + if err != nil { + return nil, fmt.Errorf("failed to add Lua phase (phase=%s): %w", phase, err) + } + } + + middlewares = append(middlewares, multi) + } + } + + return middlewares, nil +} diff --git a/internal/server/pprof.go b/internal/server/pprof.go index 34b083f..bd5ff1e 100644 --- a/internal/server/pprof.go +++ b/internal/server/pprof.go @@ -76,7 +76,7 @@ func NewPprofHandler(cfg *config.PprofConfig) (*PprofHandler, error) { // 尝试解析 CIDR _, net, err := net.ParseCIDR(ipStr) if err != nil { - return nil, fmt.Errorf("解析 IP/CIDR 失败: %s: %w", ipStr, err) + return nil, fmt.Errorf("failed to parse IP/CIDR: %s: %w", ipStr, err) } h.allowedNets = append(h.allowedNets, net) } diff --git a/internal/server/purge.go b/internal/server/purge.go index 2ff0c75..d6895f9 100644 --- a/internal/server/purge.go +++ b/internal/server/purge.go @@ -7,12 +7,11 @@ import ( "encoding/json" "net" "net/netip" - "strings" "github.com/valyala/fasthttp" "rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/config" - "rua.plus/lolly/internal/netutil" + "rua.plus/lolly/internal/utils" ) // PurgeHandler 缓存清理 API 处理器。 @@ -104,26 +103,26 @@ func (h *PurgeHandler) Path() string { func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) { // 仅允许 POST 方法 if string(ctx.Method()) != "POST" { - h.sendError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed") + utils.SendJSONError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed") return } // 检查 IP 访问权限 - if !h.checkAccess(ctx) { - h.sendError(ctx, fasthttp.StatusForbidden, "forbidden") + if !utils.CheckIPAccess(ctx, h.allowed) { + utils.SendJSONError(ctx, fasthttp.StatusForbidden, "forbidden") return } // 检查认证 - if !h.checkAuth(ctx) { - h.sendError(ctx, fasthttp.StatusUnauthorized, "unauthorized") + if !utils.CheckTokenAuth(ctx, h.auth) { + utils.SendJSONError(ctx, fasthttp.StatusUnauthorized, "unauthorized") return } // 解析请求体 var req cache.PurgeRequest if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { - h.sendError(ctx, fasthttp.StatusBadRequest, "invalid request body") + utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "invalid request body") return } @@ -134,7 +133,7 @@ func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) { } else if req.Pattern != "" { deleted = h.purgeByPattern(req.Pattern, req.Method) } else { - h.sendError(ctx, fasthttp.StatusBadRequest, "missing path or pattern") + utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "missing path or pattern") return } @@ -144,55 +143,6 @@ func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) { _ = json.NewEncoder(ctx).Encode(cache.PurgeResponse{Deleted: deleted}) } -// checkAccess 检查客户端 IP 是否在允许列表中。 -func (h *PurgeHandler) checkAccess(ctx *fasthttp.RequestCtx) bool { - // 如果没有配置允许列表,允许所有访问 - if len(h.allowed) == 0 { - return true - } - - clientIP := netutil.ExtractClientIPNet(ctx) - if clientIP == nil { - return false - } - - // 检查是否在允许列表中 - for _, network := range h.allowed { - if network.Contains(clientIP) { - return true - } - } - - return false -} - -// checkAuth 检查认证。 -func (h *PurgeHandler) checkAuth(ctx *fasthttp.RequestCtx) bool { - // 无需认证 - if h.auth.Type == "" || h.auth.Type == "none" { - return true - } - - // Token 认证 - if h.auth.Type == "token" { - authHeader := ctx.Request.Header.Peek("Authorization") - if len(authHeader) == 0 { - return false - } - - authStr := string(authHeader) - // 支持 Bearer token 格式 - if token, ok := strings.CutPrefix(authStr, "Bearer "); ok { - return token == h.auth.Token - } - - // 也支持直接传递 token - return authStr == h.auth.Token - } - - return false -} - // purgeByPath 按精确路径清理缓存。 func (h *PurgeHandler) purgeByPath(path string, method string) int { if h.server == nil { @@ -229,13 +179,6 @@ func (h *PurgeHandler) purgeByPattern(pattern string, method string) int { return deleted } -// sendError 发送错误响应。 -func (h *PurgeHandler) sendError(ctx *fasthttp.RequestCtx, status int, errMsg string) { - ctx.SetContentType("application/json; charset=utf-8") - ctx.SetStatusCode(status) - _ = json.NewEncoder(ctx).Encode(cache.PurgeErrorResponse{Error: errMsg}) -} - // PurgeByPathForTest 测试用的导出方法。 func (h *PurgeHandler) PurgeByPathForTest(path string, method string) int { return h.purgeByPath(path, method) diff --git a/internal/server/purge_test.go b/internal/server/purge_test.go index 886c0a4..398c9e2 100644 --- a/internal/server/purge_test.go +++ b/internal/server/purge_test.go @@ -24,6 +24,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/proxy" + "rua.plus/lolly/internal/utils" ) func TestPurgeHandler_Path(t *testing.T) { @@ -238,7 +239,7 @@ func TestPurgeHandler_checkAccess(t *testing.T) { if len(h.allowed) == 0 { // 无白名单时应允许所有访问 - if !h.checkAccess(nil) { + if !utils.CheckIPAccess(nil, h.allowed) { t.Error("expected access to be true when no allow list configured") } return @@ -281,7 +282,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) { } ctx := &fasthttp.RequestCtx{} - if !h.checkAuth(ctx) { + if !utils.CheckTokenAuth(ctx, h.auth) { t.Error("expected auth to pass when no auth configured") } }) @@ -301,7 +302,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) { } ctx := &fasthttp.RequestCtx{} - if !h.checkAuth(ctx) { + if !utils.CheckTokenAuth(ctx, h.auth) { t.Error("expected auth to pass when type is none") } }) @@ -323,7 +324,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("Authorization", "Bearer secret-token") - if !h.checkAuth(ctx) { + if !utils.CheckTokenAuth(ctx, h.auth) { t.Error("expected auth to pass with correct Bearer token") } }) @@ -345,7 +346,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("Authorization", "secret-token") - if !h.checkAuth(ctx) { + if !utils.CheckTokenAuth(ctx, h.auth) { t.Error("expected auth to pass with correct direct token") } }) @@ -367,7 +368,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("Authorization", "Bearer wrong-token") - if h.checkAuth(ctx) { + if utils.CheckTokenAuth(ctx, h.auth) { t.Error("expected auth to fail with wrong token") } }) @@ -388,7 +389,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) { ctx := &fasthttp.RequestCtx{} - if h.checkAuth(ctx) { + if utils.CheckTokenAuth(ctx, h.auth) { t.Error("expected auth to fail when Authorization header is missing") } }) @@ -410,7 +411,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.Set("Authorization", "Bearer secret-token") - if h.checkAuth(ctx) { + if utils.CheckTokenAuth(ctx, h.auth) { t.Error("expected auth to fail for unknown auth type") } }) @@ -570,7 +571,7 @@ func TestPurgeHandler_SendError(t *testing.T) { Allow: []string{}, } - h, err := NewPurgeHandler(nil, cfg) + _, err := NewPurgeHandler(nil, cfg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -578,7 +579,7 @@ func TestPurgeHandler_SendError(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Init(&fasthttp.Request{}, nil, nil) - h.sendError(ctx, tt.status, tt.errMsg) + utils.SendJSONError(ctx, tt.status, tt.errMsg) if ctx.Response.StatusCode() != tt.status { t.Errorf("expected status %d, got %d", tt.status, ctx.Response.StatusCode()) @@ -690,7 +691,7 @@ func TestPurgeHandler_checkAccess_NilContext(t *testing.T) { } // Empty allow list should allow access (returns true even with nil context) - if !h.checkAccess(nil) { + if !utils.CheckIPAccess(nil, h.allowed) { t.Error("expected checkAccess to return true with empty allow list") } }) @@ -779,7 +780,7 @@ func TestPurgeHandler_checkAccess_WithAllowedIP(t *testing.T) { ctx.Init(&fasthttp.Request{}, nil, nil) // context with nil remote address - should return false (no client IP) - if h.checkAccess(ctx) { + if utils.CheckIPAccess(ctx, h.allowed) { t.Error("expected checkAccess to return false with no client IP") } }) diff --git a/internal/server/router.go b/internal/server/router.go new file mode 100644 index 0000000..d39e091 --- /dev/null +++ b/internal/server/router.go @@ -0,0 +1,281 @@ +package server + +import ( + "strings" + "time" + + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/handler" + "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/logging" + "rua.plus/lolly/internal/matcher" + "rua.plus/lolly/internal/proxy" +) + +// registerProxyRoutesWithLocationEngine 使用 LocationEngine 注册代理路由。 +// +// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。 +// 支持通过 LocationType 配置不同的匹配方式。 +func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) { + for i := range serverCfg.Proxy { + proxyCfg := &serverCfg.Proxy[i] + + // 转换目标 + targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) + for j, t := range proxyCfg.Targets { + failTimeout := t.FailTimeout + if t.MaxFails > 0 && failTimeout == 0 { + failTimeout = 10 * time.Second + } + targets[j] = loadbalance.NewTargetFromConfig( + t.URL, t.Weight, + int64(t.MaxConns), int64(t.MaxFails), failTimeout, + t.Backup, t.Down, t.ProxyURI, + ) + } + + // 传递 Transport 配置和 Lua 引擎 + p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine) + if err != nil { + logging.Error().Msg("Failed to create proxy: " + err.Error()) + continue + } + + // 设置 DNS 解析器(如果已配置) + if s.resolver != nil { + p.SetResolver(s.resolver) + if err := p.Start(); err != nil { + logging.Error().Err(err).Msg("Failed to start proxy") + } + } + + // 启动健康检查 + if proxyCfg.HealthCheck.Interval > 0 { + hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck) + hc.Start() + s.healthCheckers = append(s.healthCheckers, hc) + // 设置被动健康检查 + p.SetHealthChecker(hc) + } + + // 保存代理实例用于缓存统计 + s.proxies = append(s.proxies, p) + + // 根据 LocationType 注册路由 + locType := proxyCfg.LocationType + if locType == "" { + locType = matcher.LocationTypePrefix + } + + switch locType { + case matcher.LocationTypeExact: + _ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + case matcher.LocationTypePrefixPriority: + _ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless: + caseInsensitive := locType == matcher.LocationTypeRegexCaseless + _ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal) + case matcher.LocationTypeNamed: + if proxyCfg.LocationName != "" { + _ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP) + } + case matcher.LocationTypePrefix: + _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + default: + _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) + } + } +} + +// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。 +func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) { + for _, static := range cfg.Static { + path := static.Path + if path == "" { + path = "/" + } + + staticHandler := handler.NewStaticHandler( + static.Root, + path, + static.Index, + true, // useSendfile + ) + // 设置 alias(与 root 互斥) + if static.Alias != "" { + staticHandler.SetAlias(static.Alias) + } + if s.fileCache != nil { + staticHandler.SetFileCache(s.fileCache) + // 设置默认缓存 TTL (5s) + staticHandler.SetCacheTTL(5 * time.Second) + } + if cfg.Compression.GzipStatic { + // extensions: 源文件类型,为空使用默认值 + // GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz) + staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions) + } + + // 设置符号链接安全检查 + staticHandler.SetSymlinkCheck(static.SymlinkCheck) + + // 设置 internal 限制 + staticHandler.SetInternal(static.Internal) + + // 设置缓存过期时间 + if static.Expires != "" { + staticHandler.SetExpires(static.Expires) + } + + // 根据 LocationType 注册路由 + locType := static.LocationType + if locType == "" { + locType = matcher.LocationTypePrefix + } + + switch locType { + case matcher.LocationTypeExact: + _ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal) + case matcher.LocationTypePrefixPriority: + _ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal) + case matcher.LocationTypePrefix: + _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) + default: + _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) + } + } +} + +// registerProxyRoutes 注册代理路由。 +// +// 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。 +// 支持 GET、POST、PUT、DELETE、HEAD 等 HTTP 方法。 +// +// 参数: +// - router: 路由器实例,用于注册路由规则 +// - serverCfg: 服务器配置,包含代理目标、负载均衡、健康检查等设置 +// +// 注意事项: +// - 代理目标初始状态默认为健康 +// - 健康检查根据配置自动启动 +func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.ServerConfig) { + for i := range serverCfg.Proxy { + proxyCfg := &serverCfg.Proxy[i] + + // 转换目标 + targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) + for j, t := range proxyCfg.Targets { + failTimeout := t.FailTimeout + if t.MaxFails > 0 && failTimeout == 0 { + failTimeout = 10 * time.Second + } + targets[j] = loadbalance.NewTargetFromConfig( + t.URL, t.Weight, + int64(t.MaxConns), int64(t.MaxFails), failTimeout, + t.Backup, t.Down, t.ProxyURI, + ) + } + + // 传递 Transport 配置和 Lua 引擎 + p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine) + if err != nil { + logging.Error().Msg("Failed to create proxy: " + err.Error()) + continue + } + + // 设置 DNS 解析器(如果已配置) + if s.resolver != nil { + p.SetResolver(s.resolver) + if err := p.Start(); err != nil { + logging.Error().Err(err).Msg("Failed to start proxy") + } + } + + // 启动健康检查 + if proxyCfg.HealthCheck.Interval > 0 { + hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck) + hc.Start() + s.healthCheckers = append(s.healthCheckers, hc) + // 设置被动健康检查 + p.SetHealthChecker(hc) + } + + // 保存代理实例用于缓存统计 + s.proxies = append(s.proxies, p) + + // 使用前缀匹配(通配符)注册代理路由 + // path: / 匹配所有子路径如 /sorry/index + // path: /api/ 匹配 /api/* 所有子路径 + routePath := proxyCfg.Path + // 确保通配符路由格式正确 + if !strings.HasSuffix(routePath, "/") && routePath != "/" { + routePath += "/" + } + wildcardPath := routePath + "{path:*}" + router.GET(wildcardPath, p.ServeHTTP) + router.POST(wildcardPath, p.ServeHTTP) + router.PUT(wildcardPath, p.ServeHTTP) + router.DELETE(wildcardPath, p.ServeHTTP) + router.HEAD(wildcardPath, p.ServeHTTP) + } +} + +// registerStaticHandlers 注册静态文件处理器。 +// +// 为路由器注册静态文件服务,支持多个静态目录、文件缓存和预压缩文件。 +// +// 参数: +// - router: 路由器实例,用于注册路由规则 +// - cfg: 服务器配置,包含静态文件和压缩设置 +func (s *Server) registerStaticHandlers(router *handler.Router, cfg *config.ServerConfig) { + for _, static := range cfg.Static { + path := static.Path + if path == "" { + path = "/" + } + + staticHandler := handler.NewStaticHandler( + static.Root, + path, + static.Index, + true, // useSendfile + ) + // 设置 alias(与 root 互斥) + if static.Alias != "" { + staticHandler.SetAlias(static.Alias) + } + if s.fileCache != nil { + staticHandler.SetFileCache(s.fileCache) + // 设置默认缓存 TTL (5s) + staticHandler.SetCacheTTL(5 * time.Second) + } + if cfg.Compression.GzipStatic { + // extensions: 源文件类型,为空使用默认值 + // GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz) + staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions) + } + + // 设置符号链接安全检查 + staticHandler.SetSymlinkCheck(static.SymlinkCheck) + + // 设置缓存过期时间 + if static.Expires != "" { + staticHandler.SetExpires(static.Expires) + } + + // 设置 try_files 配置 + if len(static.TryFiles) > 0 { + // 注意:tryFilesPass 需要路由器支持,当前实现传入 nil + // 如果 tryFilesPass 为 true,需要额外处理 + staticHandler.SetTryFiles(static.TryFiles, static.TryFilesPass, router) + } + + // 注册路由:确保路径以 / 结尾 + routePath := path + if !strings.HasSuffix(routePath, "/") { + routePath += "/" + } + router.GET(routePath+"{filepath:*}", staticHandler.Handle) + router.HEAD(routePath+"{filepath:*}", staticHandler.Handle) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index a782547..4174e90 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,7 +20,6 @@ package server import ( - "context" "crypto/tls" "fmt" "net" @@ -34,16 +33,10 @@ import ( "rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/handler" - "rua.plus/lolly/internal/loadbalance" "rua.plus/lolly/internal/logging" "rua.plus/lolly/internal/lua" "rua.plus/lolly/internal/matcher" - "rua.plus/lolly/internal/middleware" "rua.plus/lolly/internal/middleware/accesslog" - "rua.plus/lolly/internal/middleware/bodylimit" - "rua.plus/lolly/internal/middleware/compression" - "rua.plus/lolly/internal/middleware/errorintercept" - "rua.plus/lolly/internal/middleware/rewrite" "rua.plus/lolly/internal/middleware/security" "rua.plus/lolly/internal/mimeutil" "rua.plus/lolly/internal/proxy" @@ -211,236 +204,6 @@ func (s *Server) GetHandler() fasthttp.RequestHandler { return s.handler } -// buildMiddlewareChain 构建中间件链。 -// -// 根据服务器配置按顺序构建中间件链,顺序为: -// -// AccessLog -> AccessControl -> RateLimiter -> BasicAuth -> Rewrite -> Compression -> SecurityHeaders -// -// 参数: -// - serverCfg: 单个服务器的配置对象 -// -// 返回值: -// - *middleware.Chain: 构建完成的中间件链 -// - error: 构建过程中遇到的错误,如中间件创建失败 -// -// 注意事项: -// - 各中间件按顺序依次包装请求处理器 -// - 未配置的中间件不会添加到链中 -func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) { - var middlewares []middleware.Middleware - - // 1. AccessLog (已集成) - s.accessLogMiddleware = accesslog.New(&s.config.Logging) - middlewares = append(middlewares, s.accessLogMiddleware) - - // 2. Security: AccessControl (IP 访问控制) - if len(serverCfg.Security.Access.Allow) > 0 || len(serverCfg.Security.Access.Deny) > 0 { - ac, err := security.NewAccessControl(&serverCfg.Security.Access) - if err != nil { - return nil, fmt.Errorf("创建访问控制中间件失败: %w", err) - } - middlewares = append(middlewares, ac) - s.accessControl = ac - } - - // 3. Security: RateLimiter (速率限制) - if serverCfg.Security.RateLimit.RequestRate > 0 { - rl, err := security.NewRateLimiter(&serverCfg.Security.RateLimit) - if err != nil { - return nil, fmt.Errorf("创建限流中间件失败: %w", err) - } - middlewares = append(middlewares, rl) - } - - // 3.5 Security: ConnLimiter (连接数限制) - if serverCfg.Security.RateLimit.ConnLimit > 0 { - cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key) - if err != nil { - return nil, fmt.Errorf("创建连接限制中间件失败: %w", err) - } - middlewares = append(middlewares, cl.Middleware()) - } - - // 4. Security: BasicAuth (认证) - if len(serverCfg.Security.Auth.Users) > 0 { - auth, err := security.NewBasicAuth(&serverCfg.Security.Auth) - if err != nil { - return nil, fmt.Errorf("创建认证中间件失败: %w", err) - } - middlewares = append(middlewares, auth) - } - - // 4.3 Security: AuthRequest (外部认证子请求) - if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" { - authReq, err := security.NewAuthRequest(serverCfg.Security.AuthRequest) - if err != nil { - return nil, fmt.Errorf("创建外部认证中间件失败: %w", err) - } - middlewares = append(middlewares, authReq) - } - - // 4.5 BodyLimit (请求体大小限制) - // 创建 bodylimit 中间件,使用全局配置或默认值 - bodyLimitMiddleware := bodylimit.NewWithDefault() - if serverCfg.ClientMaxBodySize != "" { - bl, err := bodylimit.New(serverCfg.ClientMaxBodySize) - if err != nil { - return nil, fmt.Errorf("创建请求体限制中间件失败: %w", err) - } - bodyLimitMiddleware = bl - } - // 添加路径级别的限制配置 - for i := range serverCfg.Proxy { - if serverCfg.Proxy[i].ClientMaxBodySize != "" { - if err := bodyLimitMiddleware.AddPathLimit( - serverCfg.Proxy[i].Path, - serverCfg.Proxy[i].ClientMaxBodySize, - ); err != nil { - return nil, fmt.Errorf("添加路径请求体限制失败: %w", err) - } - } - } - middlewares = append(middlewares, bodyLimitMiddleware) - - // 5. Rewrite (URL 重写) - if len(serverCfg.Rewrite) > 0 { - rw, err := rewrite.New(serverCfg.Rewrite) - if err != nil { - return nil, fmt.Errorf("创建重写中间件失败: %w", err) - } - middlewares = append(middlewares, rw) - } - - // 6. Compression (响应压缩) - if serverCfg.Compression.Type != "" { - comp, err := compression.New(&serverCfg.Compression) - if err != nil { - return nil, fmt.Errorf("创建压缩中间件失败: %w", err) - } - middlewares = append(middlewares, comp) - } - - // 7. SecurityHeaders (安全头部) - // 如果有任何安全头部配置,则启用 - if serverCfg.Security.Headers.XFrameOptions != "" || - serverCfg.Security.Headers.XContentTypeOptions != "" || - serverCfg.Security.Headers.ContentSecurityPolicy != "" || - serverCfg.Security.Headers.ReferrerPolicy != "" || - serverCfg.Security.Headers.PermissionsPolicy != "" { - headers := security.NewHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS) - middlewares = append(middlewares, headers) - } - - // 8. ErrorIntercept (错误页面拦截) - // 如果配置了错误页面,添加错误拦截中间件 - if s.errorPageManager != nil && s.errorPageManager.IsConfigured() { - ei := errorintercept.New(s.errorPageManager) - middlewares = append(middlewares, ei) - } - - // Lua 中间件(可选) - if s.luaEngine != nil && serverCfg.Lua != nil && serverCfg.Lua.Enabled { - luaMiddlewares, err := s.buildLuaMiddlewares(serverCfg.Lua) - if err != nil { - return nil, fmt.Errorf("创建 Lua 中间件失败: %w", err) - } - middlewares = append(middlewares, luaMiddlewares...) - } - - return middleware.NewChain(middlewares...), nil -} - -// buildLuaMiddlewares 根据 Lua 配置创建中间件。 -// -// 根据 Scripts 配置创建 LuaMiddleware 或 MultiPhaseLuaMiddleware。 -// 支持单脚本和多阶段脚本配置。 -// -// 参数: -// - luaCfg: Lua 配置对象 -// -// 返回值: -// - []middleware.Middleware: 创建的中间件列表 -// - error: 创建过程中遇到的错误 -func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) { - if s.luaEngine == nil { - return nil, nil - } - - // 按阶段分组脚本 - phaseScripts := make(map[string][]config.LuaScriptConfig) - for _, script := range luaCfg.Scripts { - // 默认启用 - enabled := script.Enabled - if !enabled && script.Timeout == 0 && script.Path != "" { - enabled = true // 零值时默认启用 - } - if enabled { - phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script) - } - } - - var middlewares []middleware.Middleware - - // 为每个阶段创建中间件 - for phase, scripts := range phaseScripts { - if len(scripts) == 0 { - continue - } - - // 单脚本:直接创建 LuaMiddleware - if len(scripts) == 1 { - script := scripts[0] - luaPhase, err := lua.ParsePhase(phase) - if err != nil { - return nil, fmt.Errorf("无效的阶段 '%s': %w", phase, err) - } - - timeout := script.Timeout - if timeout == 0 { - timeout = 30 * time.Second - } - - cfg := lua.LuaMiddlewareConfig{ - ScriptPath: script.Path, - Phase: luaPhase, - Timeout: timeout, - Name: fmt.Sprintf("lua-%s", phase), - } - - mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg) - if err != nil { - return nil, fmt.Errorf("创建 Lua 中间件失败 (phase=%s): %w", phase, err) - } - - middlewares = append(middlewares, mw) - } else { - // 多脚本:创建 MultiPhaseLuaMiddleware - multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase)) - for _, script := range scripts { - luaPhase, err := lua.ParsePhase(phase) - if err != nil { - return nil, fmt.Errorf("无效的阶段 '%s': %w", phase, err) - } - - timeout := script.Timeout - if timeout == 0 { - timeout = 30 * time.Second - } - - err = multi.AddPhase(luaPhase, script.Path, timeout) - if err != nil { - return nil, fmt.Errorf("添加 Lua 阶段失败 (phase=%s): %w", phase, err) - } - } - - middlewares = append(middlewares, multi) - } - } - - return middlewares, nil -} - // Start 启动 HTTP 服务器。 // // 初始化日志系统、性能优化组件(Goroutine池、文件缓存), @@ -549,13 +312,13 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error) mode = cfg.UnixSocket.Mode } if err := os.Chmod(socketPath, os.FileMode(mode)); err != nil { - logging.Warn().Err(err).Msg("设置 socket 文件权限失败") + logging.Warn().Err(err).Msg("Failed to set socket file permissions") } // 5. 设置文件所有权(需要 root 权限) if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" { // 简化处理:仅记录警告,实际实现需要 syscall.Chown - logging.Warn().Msg("Unix socket 用户/组配置需要 root 权限,已跳过") + logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped") } return listener, nil @@ -590,7 +353,7 @@ func (s *Server) startSingleMode() error { if s.config.Monitoring.Status.Path != "" || len(s.config.Monitoring.Status.Allow) > 0 { statusHandler, err := NewStatusHandler(s, &s.config.Monitoring.Status) if err != nil { - logging.Error().Msg("创建状态处理器失败: " + err.Error()) + logging.Error().Msg("Failed to create status handler: " + err.Error()) } else { _ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false) } @@ -600,7 +363,7 @@ func (s *Server) startSingleMode() error { if s.config.Monitoring.Pprof.Enabled { pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof) if err != nil { - logging.Error().Msg("创建 pprof 处理器失败: " + err.Error()) + logging.Error().Msg("Failed to create pprof handler: " + err.Error()) } else { _ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false) _ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false) @@ -611,7 +374,7 @@ func (s *Server) startSingleMode() error { if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled { purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI) if err != nil { - logging.Error().Msg("创建缓存清理处理器失败: " + err.Error()) + logging.Error().Msg("Failed to create cache purge handler: " + err.Error()) } else { _ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false) } @@ -685,7 +448,7 @@ func (s *Server) startSingleMode() error { var err error s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL) if err != nil { - return fmt.Errorf("创建 TLS 管理器失败: %w", err) + return fmt.Errorf("failed to create TLS manager: %w", err) } s.fastServer.TLSConfig = s.tlsManager.GetTLSConfig() return s.fastServer.ServeTLS(ln, "", "") @@ -747,7 +510,7 @@ func (s *Server) startVHostMode() error { if s.config.Monitoring.Status.Enabled { statusHandler, err := NewStatusHandler(s, &s.config.Monitoring.Status) if err != nil { - logging.Error().Msg("创建状态处理器失败: " + err.Error()) + logging.Error().Msg("Failed to create status handler: " + err.Error()) } else { router.GET(statusHandler.Path(), statusHandler.ServeHTTP) } @@ -757,7 +520,7 @@ func (s *Server) startVHostMode() error { if s.config.Monitoring.Pprof.Enabled { pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof) if err != nil { - logging.Error().Msg("创建 pprof 处理器失败: " + err.Error()) + logging.Error().Msg("Failed to create pprof handler: " + err.Error()) } else { router.GET(pprofHandler.Path(), pprofHandler.ServeHTTP) router.GET(pprofHandler.Path()+"/{profile:*}", pprofHandler.ServeHTTP) @@ -769,7 +532,7 @@ func (s *Server) startVHostMode() error { if defaultSrv != nil && defaultSrv.CacheAPI != nil && defaultSrv.CacheAPI.Enabled { purgeHandler, err := NewPurgeHandler(s, defaultSrv.CacheAPI) if err != nil { - logging.Error().Msg("创建缓存清理处理器失败: " + err.Error()) + logging.Error().Msg("Failed to create cache purge handler: " + err.Error()) } else { router.POST(purgeHandler.Path(), purgeHandler.ServeHTTP) } @@ -829,7 +592,7 @@ func (s *Server) startVHostMode() error { var err error s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL) if err != nil { - return fmt.Errorf("创建 TLS 管理器失败: %w", err) + return fmt.Errorf("failed to create TLS manager: %w", err) } s.fastServer.TLSConfig = s.tlsManager.GetTLSConfig() return s.fastServer.ServeTLS(ln, "", "") @@ -853,7 +616,7 @@ func (s *Server) startVHostMode() error { func (s *Server) startMultiServerMode() error { // 热升级检测:multi_server 热升级未实现,回退到 vhost 模式 if os.Getenv("GRACEFUL_UPGRADE") == "1" { - logging.Warn().Msg("热升级模式下 multi_server 模式未实现,回退到虚拟主机模式") + logging.Warn().Msg("multi_server mode not implemented for graceful upgrade, falling back to vhost mode") return s.startVHostMode() } @@ -874,7 +637,7 @@ func (s *Server) startMultiServerMode() error { // 创建监听器 ln, err := s.createListener(serverCfg) if err != nil { - errCh <- fmt.Errorf("监听地址 %s 失败: %w", serverCfg.Listen, err) + errCh <- fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err) return } s.listeners[idx] = ln @@ -886,7 +649,7 @@ func (s *Server) startMultiServerMode() error { if idx == 0 && serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled { purgeHandler, purgeErr := NewPurgeHandler(s, serverCfg.CacheAPI) if purgeErr != nil { - errCh <- fmt.Errorf("创建缓存清理处理器失败 (server[%d]): %w", idx, purgeErr) + errCh <- fmt.Errorf("failed to create cache purge handler (server[%d]): %w", idx, purgeErr) return } router.POST(purgeHandler.Path(), purgeHandler.ServeHTTP) @@ -900,7 +663,7 @@ func (s *Server) startMultiServerMode() error { // 构建独立的中间件链 chain, err := s.buildMiddlewareChain(serverCfg) if err != nil { - errCh <- fmt.Errorf("构建中间件链失败 (server[%d]): %w", idx, err) + errCh <- fmt.Errorf("failed to build middleware chain (server[%d]): %w", idx, err) return } @@ -927,7 +690,7 @@ func (s *Server) startMultiServerMode() error { if serverCfg.SSL.Cert != "" && serverCfg.SSL.Key != "" { tlsManager, err := ssl.NewTLSManager(&serverCfg.SSL) if err != nil { - errCh <- fmt.Errorf("创建 TLS 管理器失败 (server[%d]): %w", idx, err) + errCh <- fmt.Errorf("failed to create TLS manager (server[%d]): %w", idx, err) return } fastSrv.TLSConfig = tlsManager.GetTLSConfig() @@ -978,7 +741,7 @@ func (s *Server) startMultiServerMode() error { serveErr = f.Serve(l) } if serveErr != nil { - logging.Error().Err(serveErr).Msgf("服务器 [%d] 监听 %s 时发生错误", i, l.Addr()) + logging.Error().Err(serveErr).Msgf("Server [%d] error while listening on %s", i, l.Addr()) } }(fastSrv, ln, idx) } @@ -988,509 +751,6 @@ func (s *Server) startMultiServerMode() error { return nil } -// registerProxyRoutesWithLocationEngine 使用 LocationEngine 注册代理路由。 -// -// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。 -// 支持通过 LocationType 配置不同的匹配方式。 -func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) { - for i := range serverCfg.Proxy { - proxyCfg := &serverCfg.Proxy[i] - - // 转换目标 - targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) - for j, t := range proxyCfg.Targets { - failTimeout := t.FailTimeout - if t.MaxFails > 0 && failTimeout == 0 { - failTimeout = 10 * time.Second - } - targets[j] = loadbalance.NewTargetFromConfig( - t.URL, t.Weight, - int64(t.MaxConns), int64(t.MaxFails), failTimeout, - t.Backup, t.Down, t.ProxyURI, - ) - } - - // 传递 Transport 配置和 Lua 引擎 - p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine) - if err != nil { - logging.Error().Msg("创建代理失败: " + err.Error()) - continue - } - - // 设置 DNS 解析器(如果已配置) - if s.resolver != nil { - p.SetResolver(s.resolver) - if err := p.Start(); err != nil { - logging.Error().Err(err).Msg("启动代理失败") - } - } - - // 启动健康检查 - if proxyCfg.HealthCheck.Interval > 0 { - hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck) - hc.Start() - s.healthCheckers = append(s.healthCheckers, hc) - // 设置被动健康检查 - p.SetHealthChecker(hc) - } - - // 保存代理实例用于缓存统计 - s.proxies = append(s.proxies, p) - - // 根据 LocationType 注册路由 - locType := proxyCfg.LocationType - if locType == "" { - locType = matcher.LocationTypePrefix - } - - switch locType { - case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) - case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) - case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless: - caseInsensitive := locType == matcher.LocationTypeRegexCaseless - _ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal) - case matcher.LocationTypeNamed: - if proxyCfg.LocationName != "" { - _ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP) - } - case matcher.LocationTypePrefix: - _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) - default: - _ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal) - } - } -} - -// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。 -func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) { - for _, static := range cfg.Static { - path := static.Path - if path == "" { - path = "/" - } - - staticHandler := handler.NewStaticHandler( - static.Root, - path, - static.Index, - true, // useSendfile - ) - // 设置 alias(与 root 互斥) - if static.Alias != "" { - staticHandler.SetAlias(static.Alias) - } - if s.fileCache != nil { - staticHandler.SetFileCache(s.fileCache) - // 设置默认缓存 TTL (5s) - staticHandler.SetCacheTTL(5 * time.Second) - } - if cfg.Compression.GzipStatic { - // extensions: 源文件类型,为空使用默认值 - // GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz) - staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions) - } - - // 设置符号链接安全检查 - staticHandler.SetSymlinkCheck(static.SymlinkCheck) - - // 设置 internal 限制 - staticHandler.SetInternal(static.Internal) - - // 设置缓存过期时间 - if static.Expires != "" { - staticHandler.SetExpires(static.Expires) - } - - // 根据 LocationType 注册路由 - locType := static.LocationType - if locType == "" { - locType = matcher.LocationTypePrefix - } - - switch locType { - case matcher.LocationTypeExact: - _ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal) - case matcher.LocationTypePrefixPriority: - _ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal) - case matcher.LocationTypePrefix: - _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) - default: - _ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal) - } - } -} - -// registerProxyRoutes 注册代理路由。 -// -// 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。 -// 支持 GET、POST、PUT、DELETE、HEAD 等 HTTP 方法。 -// -// 参数: -// - router: 路由器实例,用于注册路由规则 -// - serverCfg: 服务器配置,包含代理目标、负载均衡、健康检查等设置 -// -// 注意事项: -// - 代理目标初始状态默认为健康 -// - 健康检查根据配置自动启动 -func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.ServerConfig) { - for i := range serverCfg.Proxy { - proxyCfg := &serverCfg.Proxy[i] - - // 转换目标 - targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) - for j, t := range proxyCfg.Targets { - failTimeout := t.FailTimeout - if t.MaxFails > 0 && failTimeout == 0 { - failTimeout = 10 * time.Second - } - targets[j] = loadbalance.NewTargetFromConfig( - t.URL, t.Weight, - int64(t.MaxConns), int64(t.MaxFails), failTimeout, - t.Backup, t.Down, t.ProxyURI, - ) - } - - // 传递 Transport 配置和 Lua 引擎 - p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine) - if err != nil { - logging.Error().Msg("创建代理失败: " + err.Error()) - continue - } - - // 设置 DNS 解析器(如果已配置) - if s.resolver != nil { - p.SetResolver(s.resolver) - if err := p.Start(); err != nil { - logging.Error().Err(err).Msg("启动代理失败") - } - } - - // 启动健康检查 - if proxyCfg.HealthCheck.Interval > 0 { - hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck) - hc.Start() - s.healthCheckers = append(s.healthCheckers, hc) - // 设置被动健康检查 - p.SetHealthChecker(hc) - } - - // 保存代理实例用于缓存统计 - s.proxies = append(s.proxies, p) - - // 使用前缀匹配(通配符)注册代理路由 - // path: / 匹配所有子路径如 /sorry/index - // path: /api/ 匹配 /api/* 所有子路径 - routePath := proxyCfg.Path - // 确保通配符路由格式正确 - if !strings.HasSuffix(routePath, "/") && routePath != "/" { - routePath += "/" - } - wildcardPath := routePath + "{path:*}" - router.GET(wildcardPath, p.ServeHTTP) - router.POST(wildcardPath, p.ServeHTTP) - router.PUT(wildcardPath, p.ServeHTTP) - router.DELETE(wildcardPath, p.ServeHTTP) - router.HEAD(wildcardPath, p.ServeHTTP) - } -} - -// shutdownServers 并行关闭多个 fasthttp.Server 实例。 -// -// 使用 goroutine 并行关闭所有服务器,收集所有错误并返回聚合错误。 -// 部分服务器关闭失败不会影响其他服务器的关闭。 -// -// 参数: -// - ctx: 关闭上下文,用于控制超时和取消 -// - servers: 要关闭的 fasthttp.Server 实例列表 -// -// 返回值: -// - error: 聚合错误,无错误或全部成功时返回 nil -func shutdownServers(ctx context.Context, servers []*fasthttp.Server) error { - // 防御性检查:nil ctx 使用默认背景 - if ctx == nil { - ctx = context.Background() - } - if len(servers) == 0 { - return nil - } - - var ( - mu sync.Mutex - errs []error - wg sync.WaitGroup - ) - - for _, srv := range servers { - if srv == nil { - continue - } - wg.Add(1) - go func(s *fasthttp.Server) { - defer wg.Done() - if err := s.Shutdown(); err != nil { - mu.Lock() - errs = append(errs, err) - mu.Unlock() - } - }(srv) - } - - // 等待所有关闭完成或上下文取消 - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - select { - case <-done: - if len(errs) == 0 { - return nil - } - if len(errs) == 1 { - return errs[0] - } - msgs := make([]string, len(errs)) - for i, e := range errs { - msgs[i] = e.Error() - } - return fmt.Errorf("关闭服务器时发生 %d 个错误: %s", len(errs), strings.Join(msgs, "; ")) - case <-ctx.Done(): - return ctx.Err() - } -} - -// StopWithTimeout 快速停止服务器(支持自定义超时)。 -// -// 立即停止服务器,不等待正在处理的请求完成。 -// 停止所有健康检查器和访问日志中间件。 -// -// 参数: -// - timeout: 快速关闭的最大等待时间 -// -// 返回值: -// - error: 停止过程中遇到的错误 -// -// 注意事项: -// - 对于生产环境,建议使用 GracefulStop 实现优雅关闭 -// - timeout <= 0 时会使用默认 5s 超时 -func (s *Server) StopWithTimeout(timeout time.Duration) error { - // 防御性检查:如果 timeout <= 0,使用默认值 - if timeout <= 0 { - timeout = 5 * time.Second - } - - s.running = false - - // 停止 Goroutine 池 - if s.pool != nil { - s.pool.Stop() - } - - // 停止健康检查器 - for _, hc := range s.healthCheckers { - hc.Stop() - } - - // 关闭访问日志 - if s.accessLogMiddleware != nil { - _ = s.accessLogMiddleware.Close() - } - - // 关闭 TLS 管理器 - if s.tlsManager != nil { - s.tlsManager.Close() - } - - // 关闭 AccessControl (释放 GeoIP 资源) - if s.accessControl != nil { - if err := s.accessControl.Close(); err != nil { - logging.Warn().Err(err).Msg("关闭 AccessControl 失败") - } - } - - // 关闭 Lua 引擎 - if s.luaEngine != nil { - s.luaEngine.Close() - logging.Info().Msg("Lua 引擎已关闭") - } - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - // 多服务器模式:并行关闭所有 fasthttp.Server - if len(s.fastServers) > 0 { - return shutdownServers(ctx, s.fastServers) - } - - // 单服务器模式:关闭单个 fasthttp.Server - if s.fastServer != nil { - done := make(chan struct{}) - go func() { - _ = s.fastServer.Shutdown() - close(done) - }() - - select { - case <-done: - return nil - case <-ctx.Done(): - return ctx.Err() - } - } - return nil -} - -// GracefulStop 优雅停止服务器。 -// -// 等待正在处理的请求完成后再停止服务器,确保连接正常关闭。 -// 如果超时时间到达仍有请求未完成,将返回超时错误。 -// -// 参数: -// - timeout: 优雅关闭的最大等待时间 -// -// 返回值: -// - error: 停止过程中遇到的错误,超时返回 context.DeadlineExceeded -// -// 注意事项: -// - 推荐在生产环境使用此方法关闭服务器 -// - 超时后会强制关闭,可能导致部分请求中断 -func (s *Server) GracefulStop(timeout time.Duration) error { - s.running = false - - // 停止 Goroutine 池 - if s.pool != nil { - s.pool.Stop() - } - - // 停止健康检查器 - for _, hc := range s.healthCheckers { - hc.Stop() - } - - // 关闭访问日志 - if s.accessLogMiddleware != nil { - _ = s.accessLogMiddleware.Close() - } - - // 关闭 TLS 管理器 - if s.tlsManager != nil { - s.tlsManager.Close() - } - - // 关闭 AccessControl (释放 GeoIP 资源) - if s.accessControl != nil { - if err := s.accessControl.Close(); err != nil { - logging.Warn().Err(err).Msg("关闭 AccessControl 失败") - } - } - - // 关闭 Lua 引擎 - if s.luaEngine != nil { - s.luaEngine.Close() - logging.Info().Msg("Lua 引擎已关闭") - } - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - // 多服务器模式:并行关闭所有 fasthttp.Server - if len(s.fastServers) > 0 { - return shutdownServers(ctx, s.fastServers) - } - - // 单服务器模式:关闭单个 fasthttp.Server - if s.fastServer != nil { - done := make(chan struct{}) - go func() { - _ = s.fastServer.Shutdown() - close(done) - }() - - select { - case <-done: - return nil - case <-ctx.Done(): - return ctx.Err() - } - } - return nil -} - -// getProxyCacheStats 收集所有代理缓存的统计信息。 -func (s *Server) getProxyCacheStats() ProxyCacheStats { - var total ProxyCacheStats - for _, p := range s.proxies { - if stats := p.GetCacheStats(); stats != nil { - total.Entries += stats.Entries - total.Pending += stats.Pending - } - } - return total -} - -// registerStaticHandlers 注册静态文件处理器。 -// -// 为路由器注册静态文件服务,支持多个静态目录、文件缓存和预压缩文件。 -// -// 参数: -// - router: 路由器实例,用于注册路由规则 -// - cfg: 服务器配置,包含静态文件和压缩设置 -func (s *Server) registerStaticHandlers(router *handler.Router, cfg *config.ServerConfig) { - for _, static := range cfg.Static { - path := static.Path - if path == "" { - path = "/" - } - - staticHandler := handler.NewStaticHandler( - static.Root, - path, - static.Index, - true, // useSendfile - ) - // 设置 alias(与 root 互斥) - if static.Alias != "" { - staticHandler.SetAlias(static.Alias) - } - if s.fileCache != nil { - staticHandler.SetFileCache(s.fileCache) - // 设置默认缓存 TTL (5s) - staticHandler.SetCacheTTL(5 * time.Second) - } - if cfg.Compression.GzipStatic { - // extensions: 源文件类型,为空使用默认值 - // GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz) - staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions) - } - - // 设置符号链接安全检查 - staticHandler.SetSymlinkCheck(static.SymlinkCheck) - - // 设置缓存过期时间 - if static.Expires != "" { - staticHandler.SetExpires(static.Expires) - } - - // 设置 try_files 配置 - if len(static.TryFiles) > 0 { - // 注意:tryFilesPass 需要路由器支持,当前实现传入 nil - // 如果 tryFilesPass 为 true,需要额外处理 - staticHandler.SetTryFiles(static.TryFiles, static.TryFilesPass, router) - } - - // 注册路由:确保路径以 / 结尾 - routePath := path - if !strings.HasSuffix(routePath, "/") { - routePath += "/" - } - router.GET(routePath+"{filepath:*}", staticHandler.Handle) - router.HEAD(routePath+"{filepath:*}", staticHandler.Handle) - } -} - // SetResolver 设置 DNS 解析器。 func (s *Server) SetResolver(r resolver.Resolver) { s.resolver = r diff --git a/internal/server/status.go b/internal/server/status.go index 4a4002c..ef4d51c 100644 --- a/internal/server/status.go +++ b/internal/server/status.go @@ -21,7 +21,6 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" - "rua.plus/lolly/internal/netutil" "rua.plus/lolly/internal/utils" ) @@ -194,7 +193,7 @@ func (h *StatusHandler) Path() string { // - ctx: FastHTTP 请求上下文 func (h *StatusHandler) ServeHTTP(ctx *fasthttp.RequestCtx) { // 步骤1: 检查 IP 访问权限 - if !h.checkAccess(ctx) { + if !utils.CheckIPAccess(ctx, h.allowed) { utils.SendErrorWithDetail(ctx, utils.ErrForbidden, "Access denied") return } @@ -509,34 +508,6 @@ func (h *StatusHandler) serveHTML(ctx *fasthttp.RequestCtx, status *Status) { } } -// checkAccess 检查客户端 IP 是否在允许列表中。 -// -// 如果未配置允许列表,则允许所有访问。 -// 检查时支持代理头部(X-Forwarded-For、X-Real-IP)。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// -// 返回值: -// - bool: true 表示允许访问,false 表示拒绝 -func (h *StatusHandler) checkAccess(ctx *fasthttp.RequestCtx) bool { - // 如果没有配置允许列表,允许所有访问 - if len(h.allowed) == 0 { - return true - } - - clientIP := netutil.ExtractClientIPNet(ctx) - - // 检查是否在允许列表中 - for _, network := range h.allowed { - if network.Contains(clientIP) { - return true - } - } - - return false -} - // collectStatus 收集服务器状态数据。 // // 从服务器实例读取各项统计指标,构建状态响应对象。 diff --git a/internal/server/status_test.go b/internal/server/status_test.go index 48d6fdc..4c7fb50 100644 --- a/internal/server/status_test.go +++ b/internal/server/status_test.go @@ -24,6 +24,7 @@ import ( "rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/netutil" + "rua.plus/lolly/internal/utils" ) func TestNewStatusHandler_CIDR(t *testing.T) { @@ -896,7 +897,7 @@ func TestStatusHandler_checkAccess_AllowList(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.SetRemoteAddr(&net.TCPAddr{IP: tt.remoteIP, Port: 12345}) - got := h.checkAccess(ctx) + got := utils.CheckIPAccess(ctx, h.allowed) if got != tt.wantAccess { t.Errorf("expected access %v, got %v", tt.wantAccess, got) } diff --git a/internal/server/vhost.go b/internal/server/vhost.go index ea04f7e..661a354 100644 --- a/internal/server/vhost.go +++ b/internal/server/vhost.go @@ -23,6 +23,7 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/netutil" + "rua.plus/lolly/internal/utils" ) // VHostManager 虚拟主机管理器。 @@ -212,7 +213,7 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler { if vhost := v.FindHost(host); vhost != nil { vhost.handler(ctx) } else { - ctx.Error("Host not found", fasthttp.StatusNotFound) + utils.SendError(ctx, utils.ErrNotFound) } } } diff --git a/internal/ssl/session_tickets.go b/internal/ssl/session_tickets.go index f004aaf..72df623 100644 --- a/internal/ssl/session_tickets.go +++ b/internal/ssl/session_tickets.go @@ -22,6 +22,7 @@ import ( "time" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/logging" ) const ( @@ -192,9 +193,7 @@ func (m *SessionTicketManager) RotateKey() error { // 如果有密钥文件,保存所有密钥 if m.config.KeyFile != "" { if err := m.saveKeys(); err != nil { - // 保存失败不影响运行,记录错误即可 - // 这里可以考虑添加日志 - _ = err + logging.Warn().Err(err).Msg("Session Ticket 密钥保存失败") } } diff --git a/internal/ssl/ssl.go b/internal/ssl/ssl.go index 2c4550e..79b82bd 100644 --- a/internal/ssl/ssl.go +++ b/internal/ssl/ssl.go @@ -47,6 +47,7 @@ import ( "sync" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/logging" "rua.plus/lolly/internal/netutil" ) @@ -145,9 +146,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) { if cfg.SessionTickets.Enabled { sessionTicketMgr, err := NewSessionTicketManager(cfg.SessionTickets) if err != nil { - // Session Tickets 初始化失败不阻止 TLS 工作 - // 可以记录日志 - _ = err + logging.Warn().Err(err).Msg("Session Ticket 初始化失败,TLS 性能可能降级") } else { manager.sessionTicketMgr = sessionTicketMgr // 应用 Session Tickets 到 TLS 配置 @@ -174,9 +173,9 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) { issuerCert, err := x509.ParseCertificate(cert.Certificate[1]) if err == nil { manager.issuers[serial] = issuerCert - // 注册证书用于 OCSP Stapling - // 错误会记录日志但不会阻止 TLS 工作 - _ = ocspMgr.RegisterCertificate(parsedCert, issuerCert) + if err := ocspMgr.RegisterCertificate(parsedCert, issuerCert); err != nil { + logging.Warn().Err(err).Msg("OCSP Stapling 注册失败") + } } } @@ -192,9 +191,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) { if cfg.ClientVerify.Enabled { clientVerifier, err := NewClientVerifier(cfg.ClientVerify) if err != nil { - // 客户端验证配置失败不阻止 TLS 工作 - // 可以记录日志 - _ = err + logging.Warn().Err(err).Msg("客户端证书验证配置失败") } else { manager.clientVerifier = clientVerifier clientVerifier.ConfigureTLS(tlsCfg) diff --git a/internal/stream/stream.go b/internal/stream/stream.go index c0f5d5e..32a1262 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -571,9 +571,22 @@ func (s *Server) handleConnection(clientConn net.Conn, _ string) { } defer func() { _ = targetConn.Close() }() - // 双向数据转发 - go func() { _, _ = io.Copy(targetConn, clientConn) }() - _, _ = io.Copy(clientConn, targetConn) + // 双向数据转发:任一方向完成/出错时立即关闭双端连接,迫使另一方向退出 + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, _ = io.Copy(targetConn, clientConn) + _ = clientConn.Close() + _ = targetConn.Close() + }() + go func() { + defer wg.Done() + _, _ = io.Copy(clientConn, targetConn) + _ = targetConn.Close() + _ = clientConn.Close() + }() + wg.Wait() } // Select 选择健康的上游目标。 diff --git a/internal/utils/httperror.go b/internal/utils/httperror.go index 6e0ee16..d2c3aab 100644 --- a/internal/utils/httperror.go +++ b/internal/utils/httperror.go @@ -4,7 +4,16 @@ // the scattered pattern of ctx.Error throughout the codebase. package utils -import "github.com/valyala/fasthttp" +import ( + "crypto/subtle" + "encoding/json" + "net" + "strings" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/netutil" +) // HTTPError represents an HTTP error with a message and status code. type HTTPError struct { @@ -37,3 +46,57 @@ func SendErrorWithDetail(ctx *fasthttp.RequestCtx, err HTTPError, detail string) SendError(ctx, err) } } + +// SendJSONError sends a JSON error response. +func SendJSONError(ctx *fasthttp.RequestCtx, status int, errMsg string) { + ctx.SetContentType("application/json; charset=utf-8") + ctx.SetStatusCode(status) + _ = json.NewEncoder(ctx).Encode(struct { + Error string `json:"error"` + }{Error: errMsg}) +} + +// CheckIPAccess checks whether the client IP is in the allowed list. +// If allowed is empty, all access is permitted. +func CheckIPAccess(ctx *fasthttp.RequestCtx, allowed []net.IPNet) bool { + if len(allowed) == 0 { + return true + } + + clientIP := netutil.ExtractClientIPNet(ctx) + if clientIP == nil { + return false + } + + for _, network := range allowed { + if network.Contains(clientIP) { + return true + } + } + + return false +} + +// CheckTokenAuth checks token-based authentication. +// Returns true if auth is disabled or the token matches. +func CheckTokenAuth(ctx *fasthttp.RequestCtx, auth config.CacheAPIAuthConfig) bool { + if auth.Type == "" || auth.Type == "none" { + return true + } + + if auth.Type == "token" { + authHeader := ctx.Request.Header.Peek("Authorization") + if len(authHeader) == 0 { + return false + } + + authStr := string(authHeader) + if token, ok := strings.CutPrefix(authStr, "Bearer "); ok { + return subtle.ConstantTimeCompare([]byte(token), []byte(auth.Token)) == 1 + } + + return subtle.ConstantTimeCompare([]byte(authStr), []byte(auth.Token)) == 1 + } + + return false +}