refactor: 提取公共逻辑、消除重复代码、加强错误处理
- 提取 App 公共逻辑到 app_common.go,消除 app.go/app_windows.go 重复定义
- 提取 Server 生命周期/中间件/路由逻辑到独立文件(lifecycle.go/middleware_builder.go/router.go)
- 提取 Proxy 缓存处理/头部修改/目标选择到独立模块
- 提取 CheckIPAccess/CheckTokenAuth 到 utils/httperror.go,消除 status/purge 重复实现
- 修复 stream 双向转发:任一方向完成立即关闭双端,避免连接泄漏
- 修复 SSL/TLS 中静默忽略错误的问题,添加日志记录
- 统一日志消息为英文
💘 Generated with Crush
Assisted-by: GLM 5.1 via Crush <crush@charm.land>
This commit is contained in:
parent
6f6a8f0455
commit
cf2fcca7e8
@ -4,196 +4,33 @@ package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/http2"
|
||||
"rua.plus/lolly/internal/http3"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/resolver"
|
||||
"rua.plus/lolly/internal/server"
|
||||
"rua.plus/lolly/internal/stream"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
// App manages the server lifecycle, including HTTP, HTTP/3, Stream servers and graceful upgrades.
|
||||
type App struct {
|
||||
resv resolver.Resolver
|
||||
cfg *config.Config
|
||||
srv *server.Server
|
||||
http3Srv *http3.Server
|
||||
http2Srv *http2.Server
|
||||
streamSrv *stream.Server
|
||||
upgradeMgr *server.UpgradeManager
|
||||
logger *logging.AppLogger
|
||||
cfgPath string
|
||||
pidFile string
|
||||
logFile string
|
||||
listeners []net.Listener
|
||||
}
|
||||
|
||||
// NewApp creates a new App instance with the given config path.
|
||||
func NewApp(cfgPath string) *App {
|
||||
return &App{
|
||||
cfgPath: cfgPath,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPidFile sets the path to the PID file for the app.
|
||||
func (a *App) SetPidFile(path string) {
|
||||
a.pidFile = path
|
||||
}
|
||||
|
||||
// SetLogFile sets the path to the log file for the app.
|
||||
func (a *App) SetLogFile(path string) {
|
||||
a.logFile = path
|
||||
}
|
||||
|
||||
// Run starts the application: loads config, creates servers, and handles signals.
|
||||
func (a *App) Run() int {
|
||||
cfg, err := config.Load(a.cfgPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err)
|
||||
if err := a.loadAndValidateConfig(); err != nil {
|
||||
return 1
|
||||
}
|
||||
a.cfg = cfg
|
||||
a.logger = logging.NewAppLogger(&cfg.Logging)
|
||||
|
||||
variable.SetGlobalVariables(cfg.Variables.Set)
|
||||
if len(cfg.Variables.Set) > 0 {
|
||||
a.logger.LogStartup("全局变量已加载", map[string]string{
|
||||
"count": fmt.Sprintf("%d", len(cfg.Variables.Set)),
|
||||
})
|
||||
}
|
||||
a.initVariables()
|
||||
|
||||
// Inherit parent listeners when running as a graceful upgrade child.
|
||||
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
|
||||
a.logger.LogStartup("检测到热升级模式,继承父进程监听器", nil)
|
||||
a.upgradeMgr = server.NewUpgradeManager(nil)
|
||||
listeners, err := a.upgradeMgr.GetInheritedListeners()
|
||||
if err == nil && len(listeners) > 0 {
|
||||
a.listeners = listeners
|
||||
}
|
||||
}
|
||||
a.inheritListeners()
|
||||
|
||||
a.logger.LogStartup("配置加载成功", map[string]string{"config_path": a.cfgPath})
|
||||
|
||||
mode := a.cfg.GetMode()
|
||||
if mode == config.ServerModeMultiServer {
|
||||
for i, srv := range a.cfg.Servers {
|
||||
a.logger.LogStartup("监听地址", map[string]string{
|
||||
"index": fmt.Sprintf("[%d]", i),
|
||||
"listen": srv.Listen,
|
||||
"name": srv.Name,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
a.logger.LogStartup("监听地址", map[string]string{"listen": a.cfg.Servers[0].Listen})
|
||||
}
|
||||
|
||||
if a.cfg.Resolver.Enabled {
|
||||
a.resv = resolver.New(&a.cfg.Resolver)
|
||||
a.logger.LogStartup("DNS 解析器已启用", map[string]string{
|
||||
"addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses),
|
||||
"ttl": a.cfg.Resolver.TTL().String(),
|
||||
})
|
||||
}
|
||||
|
||||
a.srv = server.New(a.cfg)
|
||||
|
||||
if a.resv != nil {
|
||||
a.srv.SetResolver(a.resv)
|
||||
}
|
||||
|
||||
if len(a.listeners) > 0 {
|
||||
a.srv.SetListeners(a.listeners)
|
||||
}
|
||||
|
||||
if len(a.cfg.Stream) > 0 {
|
||||
a.streamSrv = stream.NewServer()
|
||||
for _, sc := range a.cfg.Stream {
|
||||
targets := make([]stream.TargetSpec, len(sc.Upstream.Targets))
|
||||
for i, t := range sc.Upstream.Targets {
|
||||
targets[i] = stream.TargetSpec{
|
||||
Addr: t.Addr,
|
||||
Weight: t.Weight,
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil {
|
||||
a.logger.Error().Err(err).Msg("添加 Stream 上游失败")
|
||||
}
|
||||
|
||||
if sc.Protocol == "udp" {
|
||||
if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil {
|
||||
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 UDP 失败")
|
||||
}
|
||||
} else {
|
||||
if err := a.streamSrv.ListenTCP(sc.Listen); err != nil {
|
||||
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 TCP 失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
a.logger.LogStartup("Stream 服务器启动中", nil)
|
||||
if err := a.streamSrv.Start(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Stream 服务器启动失败")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if a.cfg.HTTP3.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
|
||||
tlsConfig, err := a.srv.GetTLSConfig()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/3")
|
||||
} else {
|
||||
a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("创建 HTTP/3 服务器失败")
|
||||
} else {
|
||||
go func() {
|
||||
a.logger.LogStartup("HTTP/3 服务器启动中", map[string]string{"listen": a.cfg.HTTP3.Listen})
|
||||
if err := a.http3Srv.Start(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/3 服务器启动失败")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if a.cfg.Servers[0].SSL.HTTP2.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
|
||||
tlsConfig, err := a.srv.GetTLSConfig()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2")
|
||||
} else {
|
||||
a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败")
|
||||
} else {
|
||||
go func() {
|
||||
a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{
|
||||
"listen": a.cfg.Servers[0].Listen,
|
||||
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams),
|
||||
"push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled),
|
||||
})
|
||||
// HTTP/2 shares the main server's listener; ALPN negotiates protocol selection.
|
||||
listeners := a.srv.GetListeners()
|
||||
if len(listeners) > 0 {
|
||||
if err := a.http2Srv.Serve(listeners[0]); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败")
|
||||
}
|
||||
} else {
|
||||
a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
a.logServerAddresses()
|
||||
a.initResolver()
|
||||
a.initServer()
|
||||
a.initStreamServers()
|
||||
a.initHTTP3()
|
||||
a.initHTTP2()
|
||||
|
||||
a.upgradeMgr = server.NewUpgradeManager(a.srv)
|
||||
a.srv.SetUpgradeManager(a.upgradeMgr)
|
||||
@ -207,7 +44,7 @@ func (a *App) Run() int {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
a.logger.LogStartup("HTTP 服务器启动中", nil)
|
||||
a.logger.LogStartup("Starting HTTP server", nil)
|
||||
if err := a.srv.Start(); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
@ -218,24 +55,36 @@ func (a *App) Run() int {
|
||||
for {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
a.logger.Error().Err(err).Msg("服务器启动失败")
|
||||
a.logger.Error().Err(err).Msg("Server failed to start")
|
||||
return 1
|
||||
case sig := <-sigChan:
|
||||
if sig == syscall.SIGINT {
|
||||
sigintCount++
|
||||
if sigintCount >= 3 {
|
||||
a.logger.LogShutdown("收到 3 次 SIGINT,强制退出")
|
||||
a.logger.LogShutdown("Received 3 SIGINT, forcing exit")
|
||||
return 1
|
||||
}
|
||||
}
|
||||
if !a.handleSignal(sig) {
|
||||
a.logger.LogShutdown("服务器已停止")
|
||||
a.logger.LogShutdown("Server stopped")
|
||||
return 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inheritListeners inherits parent listeners during graceful upgrade.
|
||||
func (a *App) inheritListeners() {
|
||||
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
|
||||
a.logger.LogStartup("Graceful upgrade mode detected, inheriting parent listeners", nil)
|
||||
a.upgradeMgr = server.NewUpgradeManager(nil)
|
||||
listeners, err := a.upgradeMgr.GetInheritedListeners()
|
||||
if err == nil && len(listeners) > 0 {
|
||||
a.listeners = listeners
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) {
|
||||
signal.Notify(sigChan,
|
||||
syscall.SIGTERM,
|
||||
@ -250,7 +99,7 @@ func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) {
|
||||
// handleSignal returns false to indicate the app should exit.
|
||||
func (a *App) handleSignal(sig os.Signal) bool {
|
||||
if a.cfg == nil {
|
||||
a.logger.Error().Msg("信号处理失败: 配置为 nil,使用默认超时")
|
||||
a.logger.Error().Msg("Signal handling failed: config is nil, using default timeout")
|
||||
a.cfg = &config.Config{
|
||||
Shutdown: config.ShutdownConfig{
|
||||
GracefulTimeout: 30 * time.Second,
|
||||
@ -265,7 +114,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
if timeout <= 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v)", timeout))
|
||||
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.GracefulStop(timeout)
|
||||
@ -278,9 +127,9 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
}
|
||||
sigTyped, ok := sig.(syscall.Signal)
|
||||
if !ok {
|
||||
a.logger.LogSignal("unknown", "停止服务器")
|
||||
a.logger.LogSignal("unknown", "Stopping server")
|
||||
} else {
|
||||
a.logger.LogSignal(sigName(sigTyped), "停止服务器")
|
||||
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
|
||||
}
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
@ -288,89 +137,65 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
return false
|
||||
|
||||
case syscall.SIGHUP:
|
||||
a.logger.LogSignal("SIGHUP", "重载配置")
|
||||
a.logger.LogSignal("SIGHUP", "Reloading config")
|
||||
a.reloadConfig()
|
||||
return true
|
||||
|
||||
case syscall.SIGUSR1:
|
||||
a.logger.LogSignal("SIGUSR1", "重新打开日志")
|
||||
a.logger.LogSignal("SIGUSR1", "Reopening logs")
|
||||
a.reopenLogs()
|
||||
return true
|
||||
|
||||
case syscall.SIGUSR2:
|
||||
a.logger.LogSignal("SIGUSR2", "执行热升级")
|
||||
a.logger.LogSignal("SIGUSR2", "Performing graceful upgrade")
|
||||
a.gracefulUpgrade()
|
||||
return true
|
||||
|
||||
default:
|
||||
a.logger.Info().Str("signal", sig.String()).Msg("收到未知信号")
|
||||
a.logger.Info().Str("signal", sig.String()).Msg("Received unknown signal")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) shutdownHTTP3() {
|
||||
if a.http3Srv != nil {
|
||||
if err := a.http3Srv.Stop(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/3 服务器关闭失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) shutdownHTTP2() {
|
||||
if a.http2Srv != nil {
|
||||
if err := a.http2Srv.Stop(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) reloadConfig() {
|
||||
newCfg, err := config.Load(a.cfgPath)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("重载配置失败")
|
||||
a.logger.Error().Err(err).Msg("Failed to reload config")
|
||||
return
|
||||
}
|
||||
|
||||
a.cfg = newCfg
|
||||
a.logger = logging.NewAppLogger(&newCfg.Logging)
|
||||
a.logger.LogStartup("配置重载成功", nil)
|
||||
}
|
||||
|
||||
func (a *App) reopenLogs() {
|
||||
if a.cfg != nil {
|
||||
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
||||
a.logger = logging.NewAppLogger(&a.cfg.Logging)
|
||||
}
|
||||
a.logger.LogStartup("日志已重新打开", nil)
|
||||
a.logger.LogStartup("Config reloaded successfully", nil)
|
||||
}
|
||||
|
||||
func (a *App) gracefulUpgrade() {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("获取可执行文件路径失败")
|
||||
a.logger.Error().Err(err).Msg("Failed to get executable path")
|
||||
return
|
||||
}
|
||||
|
||||
if a.srv == nil {
|
||||
a.logger.Error().Msg("热升级失败: 服务器实例为 nil")
|
||||
a.logger.Error().Msg("Graceful upgrade failed: server instance is nil")
|
||||
return
|
||||
}
|
||||
|
||||
listeners := a.srv.GetListeners()
|
||||
if len(listeners) == 0 {
|
||||
a.logger.Error().Msg("热升级失败: 服务器未保存监听器(热升级当前未完全实现)")
|
||||
a.logger.Info().Msg("提示: 热升级需要服务器使用手动监听器管理模式")
|
||||
a.logger.Error().Msg("Graceful upgrade failed: server has no saved listeners (graceful upgrade not fully implemented)")
|
||||
a.logger.Info().Msg("Hint: graceful upgrade requires the server to use manual listener management mode")
|
||||
return
|
||||
}
|
||||
|
||||
a.upgradeMgr.SetListeners(listeners)
|
||||
|
||||
if err := a.upgradeMgr.GracefulUpgrade(execPath); err != nil {
|
||||
a.logger.Error().Err(err).Msg("热升级失败")
|
||||
a.logger.Error().Err(err).Msg("Graceful upgrade failed")
|
||||
return
|
||||
}
|
||||
|
||||
a.logger.LogStartup("热升级已启动,新进程正在接管", nil)
|
||||
a.logger.LogStartup("Graceful upgrade started, new process is taking over", nil)
|
||||
|
||||
timeout := a.cfg.Shutdown.GracefulTimeout
|
||||
if timeout <= 0 {
|
||||
|
||||
242
internal/app/app_common.go
Normal file
242
internal/app/app_common.go
Normal file
@ -0,0 +1,242 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/http2"
|
||||
"rua.plus/lolly/internal/http3"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/resolver"
|
||||
"rua.plus/lolly/internal/server"
|
||||
"rua.plus/lolly/internal/stream"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
// App manages the server lifecycle, including HTTP, HTTP/3, Stream servers and graceful upgrades.
|
||||
type App struct {
|
||||
resv resolver.Resolver
|
||||
cfg *config.Config
|
||||
srv *server.Server
|
||||
http3Srv *http3.Server
|
||||
http2Srv *http2.Server
|
||||
streamSrv *stream.Server
|
||||
upgradeMgr *server.UpgradeManager
|
||||
logger *logging.AppLogger
|
||||
cfgPath string
|
||||
pidFile string
|
||||
logFile string
|
||||
listeners []net.Listener
|
||||
}
|
||||
|
||||
// NewApp creates a new App instance with the given config path.
|
||||
func NewApp(cfgPath string) *App {
|
||||
return &App{
|
||||
cfgPath: cfgPath,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPidFile sets the path to the PID file for the app.
|
||||
func (a *App) SetPidFile(path string) {
|
||||
a.pidFile = path
|
||||
}
|
||||
|
||||
// SetLogFile sets the path to the log file for the app.
|
||||
func (a *App) SetLogFile(path string) {
|
||||
a.logFile = path
|
||||
}
|
||||
|
||||
// loadAndValidateConfig loads configuration and initializes the logger.
|
||||
func (a *App) loadAndValidateConfig() error {
|
||||
cfg, err := config.Load(a.cfgPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to load config: %v\n", err)
|
||||
return err
|
||||
}
|
||||
a.cfg = cfg
|
||||
a.logger = logging.NewAppLogger(&cfg.Logging)
|
||||
return nil
|
||||
}
|
||||
|
||||
// initVariables loads global variables from configuration.
|
||||
func (a *App) initVariables() {
|
||||
variable.SetGlobalVariables(a.cfg.Variables.Set)
|
||||
if len(a.cfg.Variables.Set) > 0 {
|
||||
a.logger.LogStartup("Global variables loaded", map[string]string{
|
||||
"count": fmt.Sprintf("%d", len(a.cfg.Variables.Set)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// logServerAddresses logs the listening addresses based on server mode.
|
||||
func (a *App) logServerAddresses() {
|
||||
a.logger.LogStartup("Config loaded successfully", map[string]string{"config_path": a.cfgPath})
|
||||
|
||||
mode := a.cfg.GetMode()
|
||||
if mode == config.ServerModeMultiServer {
|
||||
for i, srv := range a.cfg.Servers {
|
||||
a.logger.LogStartup("Listening address", map[string]string{
|
||||
"index": fmt.Sprintf("[%d]", i),
|
||||
"listen": srv.Listen,
|
||||
"name": srv.Name,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
a.logger.LogStartup("Listening address", map[string]string{"listen": a.cfg.Servers[0].Listen})
|
||||
}
|
||||
}
|
||||
|
||||
// initResolver initializes the DNS resolver if enabled.
|
||||
func (a *App) initResolver() {
|
||||
if a.cfg.Resolver.Enabled {
|
||||
a.resv = resolver.New(&a.cfg.Resolver)
|
||||
a.logger.LogStartup("DNS resolver enabled", map[string]string{
|
||||
"addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses),
|
||||
"ttl": a.cfg.Resolver.TTL().String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// initServer creates the main server and sets the resolver.
|
||||
func (a *App) initServer() {
|
||||
a.srv = server.New(a.cfg)
|
||||
|
||||
if a.resv != nil {
|
||||
a.srv.SetResolver(a.resv)
|
||||
}
|
||||
|
||||
if len(a.listeners) > 0 {
|
||||
a.srv.SetListeners(a.listeners)
|
||||
}
|
||||
}
|
||||
|
||||
// initStreamServers configures and starts stream servers.
|
||||
func (a *App) initStreamServers() {
|
||||
if len(a.cfg.Stream) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
a.streamSrv = stream.NewServer()
|
||||
for _, sc := range a.cfg.Stream {
|
||||
targets := make([]stream.TargetSpec, len(sc.Upstream.Targets))
|
||||
for i, t := range sc.Upstream.Targets {
|
||||
targets[i] = stream.TargetSpec{
|
||||
Addr: t.Addr,
|
||||
Weight: t.Weight,
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to add Stream upstream")
|
||||
}
|
||||
|
||||
if sc.Protocol == "udp" {
|
||||
if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil {
|
||||
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("Failed to listen on UDP")
|
||||
}
|
||||
} else {
|
||||
if err := a.streamSrv.ListenTCP(sc.Listen); err != nil {
|
||||
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("Failed to listen on TCP")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
a.logger.LogStartup("Starting Stream server", nil)
|
||||
if err := a.streamSrv.Start(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Stream server failed to start")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// initHTTP3 starts the HTTP/3 server if enabled.
|
||||
func (a *App) initHTTP3() {
|
||||
if !a.cfg.HTTP3.Enabled || a.cfg.Servers[0].SSL.Cert == "" {
|
||||
return
|
||||
}
|
||||
|
||||
tlsConfig, err := a.srv.GetTLSConfig()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to get TLS config, skipping HTTP/3")
|
||||
return
|
||||
}
|
||||
|
||||
a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to create HTTP/3 server")
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
a.logger.LogStartup("Starting HTTP/3 server", map[string]string{"listen": a.cfg.HTTP3.Listen})
|
||||
if err := a.http3Srv.Start(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/3 server failed to start")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// initHTTP2 starts the HTTP/2 server if enabled.
|
||||
func (a *App) initHTTP2() {
|
||||
if !a.cfg.Servers[0].SSL.HTTP2.Enabled || a.cfg.Servers[0].SSL.Cert == "" {
|
||||
return
|
||||
}
|
||||
|
||||
tlsConfig, err := a.srv.GetTLSConfig()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to get TLS config, skipping HTTP/2")
|
||||
return
|
||||
}
|
||||
|
||||
a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to create HTTP/2 server")
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
a.logger.LogStartup("Starting HTTP/2 server", map[string]string{
|
||||
"listen": a.cfg.Servers[0].Listen,
|
||||
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams),
|
||||
"push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled),
|
||||
})
|
||||
// HTTP/2 shares the main server's listener; ALPN negotiates protocol selection.
|
||||
listeners := a.srv.GetListeners()
|
||||
if len(listeners) > 0 {
|
||||
if err := a.http2Srv.Serve(listeners[0]); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/2 server failed to start")
|
||||
}
|
||||
} else {
|
||||
a.logger.Error().Msg("HTTP/2 server failed to start: no available listeners")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// shutdownHTTP3 gracefully stops the HTTP/3 server.
|
||||
func (a *App) shutdownHTTP3() {
|
||||
if a.http3Srv != nil {
|
||||
if err := a.http3Srv.Stop(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to shutdown HTTP/3 server")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownHTTP2 gracefully stops the HTTP/2 server.
|
||||
func (a *App) shutdownHTTP2() {
|
||||
if a.http2Srv != nil {
|
||||
if err := a.http2Srv.Stop(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Failed to shutdown HTTP/2 server")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reopenLogs reinitializes the logger from current config.
|
||||
func (a *App) reopenLogs() {
|
||||
if a.cfg != nil {
|
||||
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
||||
a.logger = logging.NewAppLogger(&a.cfg.Logging)
|
||||
}
|
||||
a.logger.LogStartup("Logs reopened", nil)
|
||||
}
|
||||
@ -225,7 +225,7 @@ func TestRun(t *testing.T) {
|
||||
genConfig: true,
|
||||
outputPath: filepath.Join(t.TempDir(), "config.yaml"),
|
||||
wantExitCode: 0,
|
||||
wantContains: "配置已写入:",
|
||||
wantContains: "Config written to:",
|
||||
},
|
||||
{
|
||||
name: "配置文件不存在",
|
||||
@ -233,7 +233,7 @@ func TestRun(t *testing.T) {
|
||||
genConfig: false,
|
||||
showVersion: false,
|
||||
wantExitCode: 1,
|
||||
wantErrContains: "加载配置失败",
|
||||
wantErrContains: "Failed to load config",
|
||||
},
|
||||
{
|
||||
name: "generate 与 import 互斥",
|
||||
@ -252,7 +252,7 @@ func TestRun(t *testing.T) {
|
||||
name: "导入 nginx 配置文件不存在",
|
||||
importPath: "/tmp/nginx.conf",
|
||||
wantExitCode: 1,
|
||||
wantErrContains: "解析 nginx 配置失败",
|
||||
wantErrContains: "failed to parse nginx config",
|
||||
},
|
||||
}
|
||||
|
||||
@ -371,8 +371,8 @@ func TestGenerateConfig(t *testing.T) {
|
||||
t.Errorf("exit code = %d, want 1", exitCode)
|
||||
}
|
||||
|
||||
if !strings.Contains(stderr, "写入文件失败") {
|
||||
t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr)
|
||||
if !strings.Contains(stderr, "Failed to write file") {
|
||||
t.Errorf("stderr should contain 'Failed to write file', actual: %q", stderr)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -1247,8 +1247,8 @@ func TestGenerateConfig_ErrorCase(t *testing.T) {
|
||||
t.Errorf("exit code = %d, want 1", exitCode)
|
||||
}
|
||||
|
||||
if !strings.Contains(stderr, "写入文件失败") {
|
||||
t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr)
|
||||
if !strings.Contains(stderr, "Failed to write file") {
|
||||
t.Errorf("stderr should contain 'Failed to write file', actual: %q", stderr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -6,178 +6,27 @@ package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/http2"
|
||||
"rua.plus/lolly/internal/http3"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/resolver"
|
||||
"rua.plus/lolly/internal/server"
|
||||
"rua.plus/lolly/internal/stream"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
// App manages the server lifecycle (Windows version).
|
||||
type App struct {
|
||||
resv resolver.Resolver
|
||||
cfg *config.Config
|
||||
srv *server.Server
|
||||
http3Srv *http3.Server
|
||||
http2Srv *http2.Server
|
||||
streamSrv *stream.Server
|
||||
logger *logging.AppLogger
|
||||
cfgPath string
|
||||
pidFile string
|
||||
logFile string
|
||||
listeners []net.Listener
|
||||
upgradeMgr *server.UpgradeManager
|
||||
}
|
||||
|
||||
func NewApp(cfgPath string) *App {
|
||||
return &App{
|
||||
cfgPath: cfgPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) SetPidFile(path string) {
|
||||
a.pidFile = path
|
||||
}
|
||||
|
||||
func (a *App) SetLogFile(path string) {
|
||||
a.logFile = path
|
||||
}
|
||||
|
||||
// Run starts the application: loads config, creates servers, and handles signals (Windows version).
|
||||
func (a *App) Run() int {
|
||||
cfg, err := config.Load(a.cfgPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err)
|
||||
if err := a.loadAndValidateConfig(); err != nil {
|
||||
return 1
|
||||
}
|
||||
a.cfg = cfg
|
||||
a.logger = logging.NewAppLogger(&cfg.Logging)
|
||||
|
||||
variable.SetGlobalVariables(cfg.Variables.Set)
|
||||
if len(cfg.Variables.Set) > 0 {
|
||||
a.logger.LogStartup("全局变量已加载", map[string]string{
|
||||
"count": fmt.Sprintf("%d", len(cfg.Variables.Set)),
|
||||
})
|
||||
}
|
||||
|
||||
a.logger.LogStartup("配置加载成功", map[string]string{"config_path": a.cfgPath})
|
||||
|
||||
mode := a.cfg.GetMode()
|
||||
if mode == config.ServerModeMultiServer {
|
||||
for i, srv := range a.cfg.Servers {
|
||||
a.logger.LogStartup("监听地址", map[string]string{
|
||||
"index": fmt.Sprintf("[%d]", i),
|
||||
"listen": srv.Listen,
|
||||
"name": srv.Name,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
a.logger.LogStartup("监听地址", map[string]string{"listen": a.cfg.Servers[0].Listen})
|
||||
}
|
||||
|
||||
if a.cfg.Resolver.Enabled {
|
||||
a.resv = resolver.New(&a.cfg.Resolver)
|
||||
a.logger.LogStartup("DNS 解析器已启用", map[string]string{
|
||||
"addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses),
|
||||
"ttl": a.cfg.Resolver.TTL().String(),
|
||||
})
|
||||
}
|
||||
|
||||
a.srv = server.New(a.cfg)
|
||||
|
||||
if a.resv != nil {
|
||||
a.srv.SetResolver(a.resv)
|
||||
}
|
||||
|
||||
if len(a.cfg.Stream) > 0 {
|
||||
a.streamSrv = stream.NewServer()
|
||||
for _, sc := range a.cfg.Stream {
|
||||
targets := make([]stream.TargetSpec, len(sc.Upstream.Targets))
|
||||
for i, t := range sc.Upstream.Targets {
|
||||
targets[i] = stream.TargetSpec{
|
||||
Addr: t.Addr,
|
||||
Weight: t.Weight,
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil {
|
||||
a.logger.Error().Err(err).Msg("添加 Stream 上游失败")
|
||||
}
|
||||
|
||||
if sc.Protocol == "udp" {
|
||||
if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil {
|
||||
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 UDP 失败")
|
||||
}
|
||||
} else {
|
||||
if err := a.streamSrv.ListenTCP(sc.Listen); err != nil {
|
||||
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 TCP 失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
a.logger.LogStartup("Stream 服务器启动中", nil)
|
||||
if err := a.streamSrv.Start(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("Stream 服务器启动失败")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if a.cfg.HTTP3.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
|
||||
tlsConfig, err := a.srv.GetTLSConfig()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/3")
|
||||
} else {
|
||||
a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("创建 HTTP/3 服务器失败")
|
||||
} else {
|
||||
go func() {
|
||||
a.logger.LogStartup("HTTP/3 服务器启动中", map[string]string{"listen": a.cfg.HTTP3.Listen})
|
||||
if err := a.http3Srv.Start(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/3 服务器启动失败")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if a.cfg.Servers[0].SSL.HTTP2.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
|
||||
tlsConfig, err := a.srv.GetTLSConfig()
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2")
|
||||
} else {
|
||||
a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
|
||||
if err != nil {
|
||||
a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败")
|
||||
} else {
|
||||
go func() {
|
||||
a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{
|
||||
"listen": a.cfg.Servers[0].Listen,
|
||||
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams),
|
||||
"push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled),
|
||||
})
|
||||
listeners := a.srv.GetListeners()
|
||||
if len(listeners) > 0 {
|
||||
if err := a.http2Srv.Serve(listeners[0]); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败")
|
||||
}
|
||||
} else {
|
||||
a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
a.initVariables()
|
||||
a.logServerAddresses()
|
||||
a.initResolver()
|
||||
a.initServer()
|
||||
a.initStreamServers()
|
||||
a.initHTTP3()
|
||||
a.initHTTP2()
|
||||
|
||||
a.upgradeMgr = server.NewUpgradeManager(a.srv)
|
||||
if a.pidFile != "" {
|
||||
@ -190,7 +39,7 @@ func (a *App) Run() int {
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
a.logger.LogStartup("HTTP 服务器启动中", nil)
|
||||
a.logger.LogStartup("Starting HTTP server", nil)
|
||||
if err := a.srv.Start(); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
@ -201,18 +50,18 @@ func (a *App) Run() int {
|
||||
for {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
a.logger.Error().Err(err).Msg("服务器启动失败")
|
||||
a.logger.Error().Err(err).Msg("Server failed to start")
|
||||
return 1
|
||||
case sig := <-sigChan:
|
||||
if sig == syscall.SIGINT {
|
||||
sigintCount++
|
||||
if sigintCount >= 3 {
|
||||
a.logger.LogShutdown("收到 3 次 SIGINT,强制退出")
|
||||
a.logger.LogShutdown("Received 3 SIGINT, forcing exit")
|
||||
return 1
|
||||
}
|
||||
}
|
||||
if !a.handleSignal(sig) {
|
||||
a.logger.LogShutdown("服务器已停止")
|
||||
a.logger.LogShutdown("Server stopped")
|
||||
return 0
|
||||
}
|
||||
}
|
||||
@ -234,47 +83,23 @@ func (a *App) handleSignal(sig os.Signal) bool {
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "停止服务器")
|
||||
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "Stopping server")
|
||||
a.shutdownHTTP2()
|
||||
a.shutdownHTTP3()
|
||||
_ = a.srv.StopWithTimeout(timeout)
|
||||
return false
|
||||
default:
|
||||
a.logger.Info().Str("signal", sig.String()).Msg("收到信号(Windows 忽略)")
|
||||
a.logger.Info().Str("signal", sig.String()).Msg("Received signal (ignored on Windows)")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) shutdownHTTP3() {
|
||||
if a.http3Srv != nil {
|
||||
if err := a.http3Srv.Stop(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/3 服务器关闭失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) shutdownHTTP2() {
|
||||
if a.http2Srv != nil {
|
||||
if err := a.http2Srv.Stop(); err != nil {
|
||||
a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *App) reloadConfig() {
|
||||
// Windows stub - functionality limited
|
||||
}
|
||||
|
||||
func (a *App) reopenLogs() {
|
||||
if a.cfg != nil {
|
||||
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
|
||||
a.logger = logging.NewAppLogger(&a.cfg.Logging)
|
||||
}
|
||||
a.logger.LogStartup("日志已重新打开", nil)
|
||||
}
|
||||
|
||||
func (a *App) gracefulUpgrade() {
|
||||
a.logger.Info().Msg("Windows 不支持热升级")
|
||||
a.logger.Info().Msg("Graceful upgrade is not supported on Windows")
|
||||
}
|
||||
|
||||
func sigName(sig syscall.Signal) string {
|
||||
|
||||
@ -48,7 +48,7 @@ func generateConfig(outputPath string) int {
|
||||
cfg := config.DefaultConfig()
|
||||
yamlData, err := config.GenerateConfigYAML(cfg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "生成配置失败: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
|
||||
return 1
|
||||
}
|
||||
|
||||
@ -56,10 +56,10 @@ func generateConfig(outputPath string) int {
|
||||
fmt.Print(string(yamlData))
|
||||
} else {
|
||||
if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "写入文件失败: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Failed to write file: %v\n", err)
|
||||
return 1
|
||||
}
|
||||
fmt.Printf("配置已写入: %s\n", outputPath)
|
||||
fmt.Printf("Config written to: %s\n", outputPath)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@ -67,12 +67,12 @@ func generateConfig(outputPath string) int {
|
||||
func importNginxConfig(path, outputPath string) error {
|
||||
nginxCfg, err := nginx.ParseFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析 nginx 配置失败: %w", err)
|
||||
return fmt.Errorf("failed to parse nginx config: %w", err)
|
||||
}
|
||||
|
||||
result, err := nginx.Convert(nginxCfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("转换配置失败: %w", err)
|
||||
return fmt.Errorf("failed to convert config: %w", err)
|
||||
}
|
||||
|
||||
for _, w := range result.Warnings {
|
||||
@ -80,26 +80,26 @@ func importNginxConfig(path, outputPath string) error {
|
||||
}
|
||||
|
||||
if validateErr := config.Validate(result.Config); validateErr != nil {
|
||||
return fmt.Errorf("转换后配置验证失败: %w", validateErr)
|
||||
return fmt.Errorf("converted config validation failed: %w", validateErr)
|
||||
}
|
||||
|
||||
yamlData, err := yaml.Marshal(result.Config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化 YAML 失败: %w", err)
|
||||
return fmt.Errorf("failed to marshal YAML: %w", err)
|
||||
}
|
||||
|
||||
if outputPath == "" {
|
||||
if _, err := os.Stdout.Write(yamlData); err != nil {
|
||||
return fmt.Errorf("写入标准输出失败: %w", err)
|
||||
return fmt.Errorf("failed to write to stdout: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
|
||||
return fmt.Errorf("创建输出目录失败: %w", err)
|
||||
return fmt.Errorf("failed to create output directory: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil {
|
||||
return fmt.Errorf("写入文件失败: %w", err)
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
fmt.Printf("配置已写入: %s\n", outputPath)
|
||||
fmt.Printf("Config written to: %s\n", outputPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
78
internal/cache/purge.go
vendored
78
internal/cache/purge.go
vendored
@ -12,7 +12,6 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
@ -20,7 +19,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
// PurgeAPI 缓存清理 API 处理器。
|
||||
@ -117,26 +116,26 @@ func (p *PurgeAPI) Path() string {
|
||||
func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
// 仅允许 POST 方法
|
||||
if string(ctx.Method()) != "POST" {
|
||||
p.sendError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
|
||||
utils.SendJSONError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 IP 访问权限
|
||||
if !p.checkAccess(ctx) {
|
||||
p.sendError(ctx, fasthttp.StatusForbidden, "forbidden")
|
||||
if !utils.CheckIPAccess(ctx, p.allowed) {
|
||||
utils.SendJSONError(ctx, fasthttp.StatusForbidden, "forbidden")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查认证
|
||||
if !p.checkAuth(ctx) {
|
||||
p.sendError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
|
||||
if !utils.CheckTokenAuth(ctx, p.auth) {
|
||||
utils.SendJSONError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req PurgeRequest
|
||||
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
p.sendError(ctx, fasthttp.StatusBadRequest, "invalid request body")
|
||||
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
@ -147,7 +146,7 @@ func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
} else if req.Pattern != "" {
|
||||
deleted = p.purgeByPattern(req.Pattern)
|
||||
} else {
|
||||
p.sendError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
|
||||
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
|
||||
return
|
||||
}
|
||||
|
||||
@ -157,61 +156,6 @@ func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
_ = json.NewEncoder(ctx).Encode(PurgeResponse{Deleted: deleted})
|
||||
}
|
||||
|
||||
// checkAccess 检查客户端 IP 是否在允许列表中。
|
||||
func (p *PurgeAPI) checkAccess(ctx *fasthttp.RequestCtx) bool {
|
||||
// 如果没有配置允许列表,允许所有访问
|
||||
if len(p.allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
clientIP := p.getClientIP(ctx)
|
||||
if clientIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否在允许列表中
|
||||
for _, network := range p.allowed {
|
||||
if network.Contains(clientIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// checkAuth 检查认证。
|
||||
func (p *PurgeAPI) checkAuth(ctx *fasthttp.RequestCtx) bool {
|
||||
// 无需认证
|
||||
if p.auth.Type == "" || p.auth.Type == "none" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Token 认证
|
||||
if p.auth.Type == "token" {
|
||||
// 从 Authorization header 获取 token
|
||||
authHeader := ctx.Request.Header.Peek("Authorization")
|
||||
if len(authHeader) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 支持 Bearer token 格式
|
||||
authStr := string(authHeader)
|
||||
if token, ok := strings.CutPrefix(authStr, "Bearer "); ok {
|
||||
return subtle.ConstantTimeCompare([]byte(token), []byte(p.auth.Token)) == 1
|
||||
}
|
||||
|
||||
// 也支持直接传递 token
|
||||
return subtle.ConstantTimeCompare([]byte(authStr), []byte(p.auth.Token)) == 1
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getClientIP 从请求上下文提取客户端 IP。
|
||||
func (p *PurgeAPI) getClientIP(ctx *fasthttp.RequestCtx) net.IP {
|
||||
return netutil.ExtractClientIPNet(ctx)
|
||||
}
|
||||
|
||||
// purgeByPath 按精确路径清理缓存。
|
||||
func (p *PurgeAPI) purgeByPath(path string) int {
|
||||
if p.cache == nil {
|
||||
@ -323,9 +267,3 @@ func matchPattern(pattern, path string) bool {
|
||||
return MatchPattern(pattern, path)
|
||||
}
|
||||
|
||||
// sendError 发送错误响应。
|
||||
func (p *PurgeAPI) sendError(ctx *fasthttp.RequestCtx, status int, errMsg string) {
|
||||
ctx.SetContentType("application/json; charset=utf-8")
|
||||
ctx.SetStatusCode(status)
|
||||
_ = json.NewEncoder(ctx).Encode(PurgeErrorResponse{Error: errMsg})
|
||||
}
|
||||
|
||||
@ -525,7 +525,7 @@ func validateProxy(p *ProxyConfig) error {
|
||||
}
|
||||
// 使用 regexp.Compile 验证正则语法有效性
|
||||
if _, err := regexp.Compile(p.Path); err != nil {
|
||||
return fmt.Errorf("location_type 为 '%s' 时,path '%s' 不是有效正则: %v", p.LocationType, p.Path, err)
|
||||
return fmt.Errorf("location_type 为 '%s' 时,path '%s' 不是有效正则: %w", p.LocationType, p.Path, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1273,7 +1273,7 @@ func validateRedirectRewrite(cfg *RedirectRewriteConfig) error {
|
||||
patternStr = rule.Pattern[2:] // 去掉 ~* 前缀(大小写不敏感)
|
||||
}
|
||||
if _, err := regexp.Compile(patternStr); err != nil {
|
||||
return fmt.Errorf("redirect_rewrite.rules[%d].pattern invalid regex: %v", i, err)
|
||||
return fmt.Errorf("redirect_rewrite.rules[%d].pattern invalid regex: %w", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
@ -106,7 +107,11 @@ func NewErrorPageManager(cfg *config.ErrorPageConfig) (*ErrorPageManager, error)
|
||||
|
||||
if len(loadErrors) == totalPages {
|
||||
// 全部加载失败,返回错误
|
||||
return nil, fmt.Errorf("所有错误页面加载失败: %v", loadErrors)
|
||||
errs := make([]error, 0, len(loadErrors))
|
||||
for _, e := range loadErrors {
|
||||
errs = append(errs, e)
|
||||
}
|
||||
return nil, fmt.Errorf("所有错误页面加载失败: %w", errors.Join(errs...))
|
||||
}
|
||||
// 部分失败,记录警告(由调用者处理)
|
||||
return manager, &PartialLoadError{Errors: loadErrors}
|
||||
|
||||
@ -201,15 +201,9 @@ func (s *Server) Stop() error {
|
||||
s.running = false
|
||||
|
||||
if s.http3Server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.http3Server.Close(); err != nil {
|
||||
logging.Error().Err(err).Msg("HTTP/3 server close error")
|
||||
}
|
||||
|
||||
// 等待服务完全停止
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
logging.Info().Msg("HTTP/3 server stopped")
|
||||
|
||||
@ -25,6 +25,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
@ -179,7 +180,7 @@ func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandl
|
||||
for ruleIndex < len(m.rules) {
|
||||
// 步骤1: 检查迭代次数是否超过限制(防止无限循环)
|
||||
if iterationCount >= MaxRewriteIterations {
|
||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
||||
utils.SendError(ctx, utils.ErrInternalError)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -41,6 +41,7 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
// HashAlgorithm 表示密码哈希算法类型。
|
||||
@ -381,7 +382,7 @@ func (ba *BasicAuth) extractCredentials(ctx *fasthttp.RequestCtx) (string, strin
|
||||
func (ba *BasicAuth) sendAuthChallenge(ctx *fasthttp.RequestCtx) {
|
||||
ctx.Response.Header.Set("WWW-Authenticate",
|
||||
fmt.Sprintf("Basic realm=\"%s\", charset=\"UTF-8\"", ba.realm))
|
||||
ctx.Error("Unauthorized", fasthttp.StatusUnauthorized)
|
||||
utils.SendError(ctx, utils.ErrUnauthorized)
|
||||
}
|
||||
|
||||
// AddUser 动态添加新用户。
|
||||
|
||||
197
internal/proxy/cache_handler.go
Normal file
197
internal/proxy/cache_handler.go
Normal file
@ -0,0 +1,197 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
)
|
||||
|
||||
// buildCacheKey 构建缓存键字符串。
|
||||
//
|
||||
// 使用请求方法和完整请求 URI 作为缓存键。
|
||||
// 该函数保留用于日志记录和调试场景。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 缓存键(格式 "METHOD:URI")
|
||||
func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string {
|
||||
// 使用请求方法和路径作为缓存键
|
||||
return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI())
|
||||
}
|
||||
|
||||
// buildCacheKeyHash 使用 FNV-64a 计算缓存键的 uint64 哈希值。
|
||||
// 返回哈希值和原始字符串键。
|
||||
// 注意:此函数会先构建字符串键再哈希,存在双重分配。
|
||||
// 对于只需要哈希值的场景,使用 buildCacheKeyHashValue 代替。
|
||||
func (p *Proxy) buildCacheKeyHash(ctx *fasthttp.RequestCtx) (uint64, string) {
|
||||
// 构建原始 key
|
||||
origKey := p.buildCacheKey(ctx)
|
||||
|
||||
// 使用 FNV-64a 计算哈希
|
||||
h := fnv.New64a()
|
||||
h.Write([]byte(origKey))
|
||||
return h.Sum64(), origKey
|
||||
}
|
||||
|
||||
// buildCacheKeyHashValue 直接计算缓存键的哈希值,零字符串分配。
|
||||
// 用于只需要哈希值而不需要原始键的场景。
|
||||
func (p *Proxy) buildCacheKeyHashValue(ctx *fasthttp.RequestCtx) uint64 {
|
||||
h := fnv.New64a()
|
||||
h.Write(ctx.Request.Header.Method())
|
||||
h.Write([]byte(":"))
|
||||
h.Write(ctx.Request.URI().RequestURI())
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// writeCachedResponse 将缓存的响应写入 FastHTTP 响应上下文。
|
||||
//
|
||||
// 设置响应体、状态码、响应头,并添加 X-Cache: HIT 头标记缓存命中。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - entry: 缓存条目,包含响应数据和元数据
|
||||
func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) {
|
||||
ctx.Response.SetBody(entry.Data)
|
||||
ctx.Response.SetStatusCode(entry.Status)
|
||||
for key, value := range entry.Headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
ctx.Response.Header.Set("X-Cache", "HIT")
|
||||
}
|
||||
|
||||
// backgroundRefresh 在后台异步刷新缓存条目。
|
||||
//
|
||||
// 向对应的上游目标发送请求,获取最新响应并更新缓存。
|
||||
// 该方法在独立 goroutine 中运行,不阻塞主请求流程。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 原始 FastHTTP 请求上下文(仅用于复制请求信息)
|
||||
// - target: 要刷新的后端目标
|
||||
// - hashKey: 缓存哈希键
|
||||
// - origKey: 缓存原始键
|
||||
func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, hashKey uint64, origKey string) {
|
||||
// 创建新的请求上下文副本
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// 复制原始请求
|
||||
ctx.Request.CopyTo(req)
|
||||
|
||||
// 如果启用 Revalidate,添加条件请求头
|
||||
if p.config.Cache.Revalidate {
|
||||
if entry, ok, _ := p.cache.Get(hashKey, origKey); ok {
|
||||
if entry.LastModified != "" {
|
||||
req.Header.Set("If-Modified-Since", entry.LastModified)
|
||||
}
|
||||
if entry.ETag != "" {
|
||||
req.Header.Set("If-None-Match", entry.ETag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取客户端
|
||||
client := p.getClient(target.URL)
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
err := client.Do(req, resp)
|
||||
if err != nil {
|
||||
p.cache.ReleaseLock(hashKey, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理 304 Not Modified 响应
|
||||
if resp.StatusCode() == 304 {
|
||||
newHeaders := make(map[string]string)
|
||||
if lm := resp.Header.Peek("Last-Modified"); len(lm) > 0 {
|
||||
newHeaders["Last-Modified"] = string(lm)
|
||||
}
|
||||
if et := resp.Header.Peek("ETag"); len(et) > 0 {
|
||||
newHeaders["ETag"] = string(et)
|
||||
}
|
||||
p.cache.RefreshTTL(hashKey, origKey, newHeaders)
|
||||
return
|
||||
}
|
||||
|
||||
// 提取响应头(使用 pool 复用 map)
|
||||
headers, ok := headersPool.Get().(map[string]string)
|
||||
if !ok {
|
||||
headers = make(map[string]string, 20)
|
||||
}
|
||||
for k := range headers {
|
||||
delete(headers, k)
|
||||
}
|
||||
for key, value := range resp.Header.All() {
|
||||
headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode()))
|
||||
}
|
||||
|
||||
// GetCache 返回代理的 ProxyCache 实例(用于 purge handler)。
|
||||
// 如果缓存未启用,返回 nil。
|
||||
func (p *Proxy) GetCache() *cache.ProxyCache {
|
||||
return p.cache
|
||||
}
|
||||
|
||||
// GetCacheStats 返回代理缓存的统计信息。
|
||||
// 如果缓存未启用,返回 nil。
|
||||
func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats {
|
||||
if p.cache == nil {
|
||||
return nil
|
||||
}
|
||||
stats := p.cache.Stats()
|
||||
return &stats
|
||||
}
|
||||
|
||||
// getCacheDuration 根据状态码获取缓存时间。
|
||||
// 优先级:CacheValid 配置 > MaxAge
|
||||
//
|
||||
// 映射规则:
|
||||
// - 200-299: CacheValid.OK(0 时继承 MaxAge)
|
||||
// - 301/302: CacheValid.Redirect
|
||||
// - 404: CacheValid.NotFound
|
||||
// - 400-499(除 404): CacheValid.ClientError
|
||||
// - 500-599: CacheValid.ServerError
|
||||
// - 其他: 不缓存(返回 0)
|
||||
func (p *Proxy) getCacheDuration(statusCode int) time.Duration {
|
||||
// 无 CacheValid 配置,使用 MaxAge
|
||||
if p.config.CacheValid == nil {
|
||||
return p.config.Cache.MaxAge
|
||||
}
|
||||
|
||||
cv := p.config.CacheValid
|
||||
|
||||
switch {
|
||||
case statusCode >= 200 && statusCode < 300:
|
||||
if cv.OK > 0 {
|
||||
return cv.OK
|
||||
}
|
||||
return p.config.Cache.MaxAge // 0 表示继承 MaxAge
|
||||
|
||||
case statusCode == 301 || statusCode == 302:
|
||||
return cv.Redirect // 0 表示不缓存
|
||||
|
||||
case statusCode == 404:
|
||||
return cv.NotFound
|
||||
|
||||
case statusCode >= 400 && statusCode < 500:
|
||||
return cv.ClientError
|
||||
|
||||
case statusCode >= 500:
|
||||
return cv.ServerError
|
||||
|
||||
default:
|
||||
return 0 // 不缓存
|
||||
}
|
||||
}
|
||||
166
internal/proxy/header_modifier.go
Normal file
166
internal/proxy/header_modifier.go
Normal file
@ -0,0 +1,166 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
// modifyRequestHeaders 在转发请求到后端之前修改请求头。
|
||||
//
|
||||
// 执行以下操作:
|
||||
// 1. 设置 Host header 为目标主机地址
|
||||
// 2. 提取并设置 X-Forwarded-For、X-Real-IP、X-Forwarded-Host、X-Forwarded-Proto
|
||||
// 3. 应用自定义请求头配置(支持变量展开)
|
||||
// 4. 移除配置的请求头
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - target: 选中的后端目标
|
||||
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) {
|
||||
headers := &ctx.Request.Header
|
||||
|
||||
// 设置 Host header 为目标主机
|
||||
// 从 target.URL 提取 host:port(HostClient 连接需要此格式)
|
||||
targetHost := extractHostFromURL(target.URL)
|
||||
if targetHost != "" {
|
||||
headers.Set("Host", targetHost)
|
||||
}
|
||||
|
||||
// 提取并设置 X-Forwarded 系列头
|
||||
fh := ExtractForwardedHeaders(ctx)
|
||||
SetForwardedHeaders(headers, fh, true)
|
||||
|
||||
// 从配置设置自定义请求头(支持变量展开)
|
||||
if p.config.Headers.SetRequest != nil {
|
||||
vc := variable.NewContext(ctx)
|
||||
defer variable.ReleaseContext(vc)
|
||||
for key, value := range p.config.Headers.SetRequest {
|
||||
expanded := vc.Expand(value)
|
||||
if containsCRLF(expanded) {
|
||||
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
|
||||
continue
|
||||
}
|
||||
headers.Set(key, expanded)
|
||||
}
|
||||
}
|
||||
|
||||
// 移除配置的请求头
|
||||
if len(p.config.Headers.Remove) > 0 {
|
||||
for _, key := range p.config.Headers.Remove {
|
||||
headers.Del(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// modifyResponseHeaders 在发送给客户端之前修改响应头。
|
||||
//
|
||||
// 应用自定义响应头配置,支持变量展开(如 $upstream_addr、$status 等)。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
|
||||
respHeaders := &ctx.Response.Header
|
||||
|
||||
// 构建 PassResponse 集合(多处使用)
|
||||
passSet := make(map[string]bool, len(p.config.Headers.PassResponse))
|
||||
for _, h := range p.config.Headers.PassResponse {
|
||||
passSet[h] = true
|
||||
}
|
||||
|
||||
// PassResponse 白名单模式:仅传递列出的头部
|
||||
if len(passSet) > 0 {
|
||||
var toDelete []string
|
||||
for key := range respHeaders.All() {
|
||||
if !passSet[string(key)] {
|
||||
toDelete = append(toDelete, string(key))
|
||||
}
|
||||
}
|
||||
for _, k := range toDelete {
|
||||
respHeaders.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
// HideResponse:移除指定的响应头(PassResponse 优先,跳过已传递的头部)
|
||||
for _, key := range p.config.Headers.HideResponse {
|
||||
if !passSet[key] {
|
||||
respHeaders.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// IgnoreHeaders:从请求和响应中移除(PassResponse 优先)
|
||||
for _, key := range p.config.Headers.IgnoreHeaders {
|
||||
ctx.Request.Header.Del(key)
|
||||
if !passSet[key] {
|
||||
respHeaders.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Cookie 域/路径重写
|
||||
if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" {
|
||||
p.rewriteCookies(respHeaders)
|
||||
}
|
||||
|
||||
// 从配置设置自定义响应头(支持变量展开)
|
||||
if p.config.Headers.SetResponse != nil {
|
||||
vc := variable.NewContext(ctx)
|
||||
defer variable.ReleaseContext(vc)
|
||||
for key, value := range p.config.Headers.SetResponse {
|
||||
expanded := vc.Expand(value)
|
||||
if containsCRLF(expanded) {
|
||||
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
|
||||
continue
|
||||
}
|
||||
respHeaders.Set(key, expanded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。
|
||||
func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) {
|
||||
cookieDomain := p.config.Headers.CookieDomain
|
||||
cookiePath := p.config.Headers.CookiePath
|
||||
if cookieDomain == "" && cookiePath == "" {
|
||||
return
|
||||
}
|
||||
|
||||
cookies := make([]string, 0, respHeaders.Len())
|
||||
for _, value := range respHeaders.Cookies() {
|
||||
cookie := string(value)
|
||||
if cookieDomain != "" {
|
||||
cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain)
|
||||
}
|
||||
if cookiePath != "" {
|
||||
cookie = rewriteCookieAttr(cookie, "Path", cookiePath)
|
||||
}
|
||||
cookies = append(cookies, cookie)
|
||||
}
|
||||
|
||||
if len(cookies) > 0 {
|
||||
respHeaders.Del("Set-Cookie")
|
||||
for _, c := range cookies {
|
||||
respHeaders.Add("Set-Cookie", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值(大小写不敏感)。
|
||||
func rewriteCookieAttr(cookie, attr, newValue string) string {
|
||||
prefix := attr + "="
|
||||
lower := strings.ToLower(cookie)
|
||||
idx := strings.Index(lower, strings.ToLower(prefix))
|
||||
if idx == -1 {
|
||||
return cookie
|
||||
}
|
||||
|
||||
start := idx + len(prefix)
|
||||
end := start
|
||||
for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' {
|
||||
end++
|
||||
}
|
||||
|
||||
return cookie[:start] + newValue + cookie[end:]
|
||||
}
|
||||
@ -33,10 +33,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
urlpath "path"
|
||||
"slices"
|
||||
@ -46,7 +44,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
@ -902,352 +899,6 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
|
||||
return p.selectByBalancer(ctx, targets)
|
||||
}
|
||||
|
||||
// selectByLua 使用 Lua 脚本选择后端目标。
|
||||
//
|
||||
// 执行配置的 Lua 脚本,脚本可通过 ngx.balancer.set_current_peer() 选择目标。
|
||||
// 如果 Lua 脚本执行失败或未调用 set_current_peer,返回 nil 表示需要使用 fallback 算法。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - targets: 候选目标列表
|
||||
//
|
||||
// 返回值:
|
||||
// - *loadbalance.Target: Lua 脚本选中的目标,nil 表示未选择
|
||||
// - error: Lua 执行失败时返回错误
|
||||
func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
|
||||
bctx := &lua.BalancerContext{
|
||||
Targets: targets,
|
||||
ClientIP: clientIP,
|
||||
Retries: p.config.NextUpstream.Tries,
|
||||
}
|
||||
|
||||
// 创建 Lua 协程
|
||||
coro, err := p.luaEngine.NewCoroutine(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create lua coroutine: %w", err)
|
||||
}
|
||||
defer coro.Close()
|
||||
|
||||
// 注册 balancer API
|
||||
L := coro.Co
|
||||
ngx, ok := L.GetGlobal("ngx").(*glua.LTable)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("global 'ngx' is not an LTable")
|
||||
}
|
||||
lua.RegisterBalancerAPI(L, bctx, ngx)
|
||||
|
||||
// 设置超时
|
||||
timeout := p.config.BalancerByLua.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 100 * time.Millisecond
|
||||
}
|
||||
|
||||
// 执行脚本(带超时)
|
||||
execCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
coro.ExecutionContext = execCtx
|
||||
|
||||
err = coro.ExecuteFile(p.config.BalancerByLua.Script)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execute lua script: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否调用了 set_current_peer
|
||||
if !bctx.IsSelected() {
|
||||
return nil, nil // 未选择,返回 nil 表示需使用 fallback
|
||||
}
|
||||
|
||||
return bctx.Selected, nil
|
||||
}
|
||||
|
||||
// selectByFallback 使用 fallback 负载均衡算法选择目标。
|
||||
//
|
||||
// 当 Lua balancer 执行失败或未选择目标时使用。
|
||||
// 对于 IPHash 算法,会自动提取客户端 IP 进行哈希选择。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - targets: 候选目标列表
|
||||
//
|
||||
// 返回值:
|
||||
// - *loadbalance.Target: fallback 算法选中的目标
|
||||
func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.fallbackBalancer
|
||||
p.mu.RUnlock()
|
||||
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectByIP(targets, clientIP)
|
||||
}
|
||||
|
||||
return balancer.Select(targets)
|
||||
}
|
||||
|
||||
// selectByBalancer 使用主负载均衡器选择目标。
|
||||
//
|
||||
// 对于特殊算法(IPHash、ConsistentHash),会从请求上下文中提取
|
||||
// 相应的哈希键(客户端 IP、URI、自定义 Header 等)。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - targets: 候选目标列表
|
||||
//
|
||||
// 返回值:
|
||||
// - *loadbalance.Target: 主负载均衡器选中的目标
|
||||
func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.balancer
|
||||
p.mu.RUnlock()
|
||||
|
||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectByIP(targets, clientIP)
|
||||
}
|
||||
|
||||
// 对于一致性哈希,根据 hash_key 配置选择
|
||||
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
|
||||
hashKey := ch.GetHashKey()
|
||||
key := p.extractHashKey(ctx, hashKey)
|
||||
return ch.SelectByKey(targets, key)
|
||||
}
|
||||
|
||||
return balancer.Select(targets)
|
||||
}
|
||||
|
||||
// selectTargetExcluding 选择后端目标,排除已尝试失败的目标。
|
||||
// 用于故障转移场景,避免重复选择已失败的目标。
|
||||
// 如果没有可用的健康目标则返回 nil。
|
||||
func (p *Proxy) selectTargetExcluding(ctx *fasthttp.RequestCtx, excluded []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.balancer
|
||||
targets := p.targets
|
||||
p.mu.RUnlock()
|
||||
|
||||
if len(targets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectExcludingByIP(targets, excluded, clientIP)
|
||||
}
|
||||
|
||||
// 对于一致性哈希,根据 hash_key 配置选择
|
||||
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
|
||||
hashKey := ch.GetHashKey()
|
||||
key := p.extractHashKey(ctx, hashKey)
|
||||
return ch.SelectExcludingByKey(targets, excluded, key)
|
||||
}
|
||||
|
||||
return balancer.SelectExcluding(targets, excluded)
|
||||
}
|
||||
|
||||
// extractHashKey 根据一致性哈希配置提取哈希键值。
|
||||
//
|
||||
// 支持的 hash_key 配置:
|
||||
// - "ip" 或 "": 使用客户端 IP 地址
|
||||
// - "uri": 使用完整请求 URI
|
||||
// - "header:NAME": 使用指定请求头的值,缺失时回退到客户端 IP
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - hashKey: 哈希键配置
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 提取的哈希键值
|
||||
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
|
||||
switch {
|
||||
case hashKey == "ip" || hashKey == "":
|
||||
return netutil.ExtractClientIP(ctx)
|
||||
case hashKey == "uri":
|
||||
return string(ctx.RequestURI())
|
||||
case strings.HasPrefix(hashKey, "header:"):
|
||||
headerName := strings.TrimPrefix(hashKey, "header:")
|
||||
value := ctx.Request.Header.Peek(headerName)
|
||||
if len(value) > 0 {
|
||||
return string(value)
|
||||
}
|
||||
return netutil.ExtractClientIP(ctx) // fallback to IP
|
||||
default:
|
||||
return netutil.ExtractClientIP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// getClient 返回指定目标 URL 对应的 HostClient 连接池实例。
|
||||
// 如果目标 URL 不存在于连接池中,返回 nil。
|
||||
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
|
||||
key := targetURL
|
||||
if p.config.ProxyBind != "" {
|
||||
key = targetURL + "|" + p.config.ProxyBind
|
||||
}
|
||||
p.mu.RLock()
|
||||
client := p.clients[key]
|
||||
p.mu.RUnlock()
|
||||
return client
|
||||
}
|
||||
|
||||
// modifyRequestHeaders 在转发请求到后端之前修改请求头。
|
||||
//
|
||||
// 执行以下操作:
|
||||
// 1. 设置 Host header 为目标主机地址
|
||||
// 2. 提取并设置 X-Forwarded-For、X-Real-IP、X-Forwarded-Host、X-Forwarded-Proto
|
||||
// 3. 应用自定义请求头配置(支持变量展开)
|
||||
// 4. 移除配置的请求头
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - target: 选中的后端目标
|
||||
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) {
|
||||
headers := &ctx.Request.Header
|
||||
|
||||
// 设置 Host header 为目标主机
|
||||
// 从 target.URL 提取 host:port(HostClient 连接需要此格式)
|
||||
targetHost := extractHostFromURL(target.URL)
|
||||
if targetHost != "" {
|
||||
headers.Set("Host", targetHost)
|
||||
}
|
||||
|
||||
// 提取并设置 X-Forwarded 系列头
|
||||
fh := ExtractForwardedHeaders(ctx)
|
||||
SetForwardedHeaders(headers, fh, true)
|
||||
|
||||
// 从配置设置自定义请求头(支持变量展开)
|
||||
if p.config.Headers.SetRequest != nil {
|
||||
vc := variable.NewContext(ctx)
|
||||
defer variable.ReleaseContext(vc)
|
||||
for key, value := range p.config.Headers.SetRequest {
|
||||
expanded := vc.Expand(value)
|
||||
if containsCRLF(expanded) {
|
||||
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
|
||||
continue
|
||||
}
|
||||
headers.Set(key, expanded)
|
||||
}
|
||||
}
|
||||
|
||||
// 移除配置的请求头
|
||||
if len(p.config.Headers.Remove) > 0 {
|
||||
for _, key := range p.config.Headers.Remove {
|
||||
headers.Del(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// modifyResponseHeaders 在发送给客户端之前修改响应头。
|
||||
//
|
||||
// 应用自定义响应头配置,支持变量展开(如 $upstream_addr、$status 等)。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
|
||||
respHeaders := &ctx.Response.Header
|
||||
|
||||
// 构建 PassResponse 集合(多处使用)
|
||||
passSet := make(map[string]bool, len(p.config.Headers.PassResponse))
|
||||
for _, h := range p.config.Headers.PassResponse {
|
||||
passSet[h] = true
|
||||
}
|
||||
|
||||
// PassResponse 白名单模式:仅传递列出的头部
|
||||
if len(passSet) > 0 {
|
||||
var toDelete []string
|
||||
for key := range respHeaders.All() {
|
||||
if !passSet[string(key)] {
|
||||
toDelete = append(toDelete, string(key))
|
||||
}
|
||||
}
|
||||
for _, k := range toDelete {
|
||||
respHeaders.Del(k)
|
||||
}
|
||||
}
|
||||
|
||||
// HideResponse:移除指定的响应头(PassResponse 优先,跳过已传递的头部)
|
||||
for _, key := range p.config.Headers.HideResponse {
|
||||
if !passSet[key] {
|
||||
respHeaders.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// IgnoreHeaders:从请求和响应中移除(PassResponse 优先)
|
||||
for _, key := range p.config.Headers.IgnoreHeaders {
|
||||
ctx.Request.Header.Del(key)
|
||||
if !passSet[key] {
|
||||
respHeaders.Del(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Cookie 域/路径重写
|
||||
if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" {
|
||||
p.rewriteCookies(respHeaders)
|
||||
}
|
||||
|
||||
// 从配置设置自定义响应头(支持变量展开)
|
||||
if p.config.Headers.SetResponse != nil {
|
||||
vc := variable.NewContext(ctx)
|
||||
defer variable.ReleaseContext(vc)
|
||||
for key, value := range p.config.Headers.SetResponse {
|
||||
expanded := vc.Expand(value)
|
||||
if containsCRLF(expanded) {
|
||||
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
|
||||
continue
|
||||
}
|
||||
respHeaders.Set(key, expanded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。
|
||||
func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) {
|
||||
cookieDomain := p.config.Headers.CookieDomain
|
||||
cookiePath := p.config.Headers.CookiePath
|
||||
if cookieDomain == "" && cookiePath == "" {
|
||||
return
|
||||
}
|
||||
|
||||
cookies := make([]string, 0, respHeaders.Len())
|
||||
for _, value := range respHeaders.Cookies() {
|
||||
cookie := string(value)
|
||||
if cookieDomain != "" {
|
||||
cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain)
|
||||
}
|
||||
if cookiePath != "" {
|
||||
cookie = rewriteCookieAttr(cookie, "Path", cookiePath)
|
||||
}
|
||||
cookies = append(cookies, cookie)
|
||||
}
|
||||
|
||||
if len(cookies) > 0 {
|
||||
respHeaders.Del("Set-Cookie")
|
||||
for _, c := range cookies {
|
||||
respHeaders.Add("Set-Cookie", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值(大小写不敏感)。
|
||||
func rewriteCookieAttr(cookie, attr, newValue string) string {
|
||||
prefix := attr + "="
|
||||
lower := strings.ToLower(cookie)
|
||||
idx := strings.Index(lower, strings.ToLower(prefix))
|
||||
if idx == -1 {
|
||||
return cookie
|
||||
}
|
||||
|
||||
start := idx + len(prefix)
|
||||
end := start
|
||||
for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' {
|
||||
end++
|
||||
}
|
||||
|
||||
return cookie[:start] + newValue + cookie[end:]
|
||||
}
|
||||
|
||||
// isWebSocketRequest 检查请求是否为 WebSocket 升级请求。
|
||||
//
|
||||
// 通过检查 Connection 和 Upgrade 请求头判断:
|
||||
@ -1333,151 +984,6 @@ func (p *Proxy) GetConfig() *config.ProxyConfig {
|
||||
return p.config
|
||||
}
|
||||
|
||||
// buildCacheKey 构建缓存键字符串。
|
||||
//
|
||||
// 使用请求方法和完整请求 URI 作为缓存键。
|
||||
// 该函数保留用于日志记录和调试场景。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 缓存键(格式 "METHOD:URI")
|
||||
func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string {
|
||||
// 使用请求方法和路径作为缓存键
|
||||
return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI())
|
||||
}
|
||||
|
||||
// buildCacheKeyHash 使用 FNV-64a 计算缓存键的 uint64 哈希值。
|
||||
// 返回哈希值和原始字符串键。
|
||||
// 注意:此函数会先构建字符串键再哈希,存在双重分配。
|
||||
// 对于只需要哈希值的场景,使用 buildCacheKeyHashValue 代替。
|
||||
func (p *Proxy) buildCacheKeyHash(ctx *fasthttp.RequestCtx) (uint64, string) {
|
||||
// 构建原始 key
|
||||
origKey := p.buildCacheKey(ctx)
|
||||
|
||||
// 使用 FNV-64a 计算哈希
|
||||
h := fnv.New64a()
|
||||
h.Write([]byte(origKey))
|
||||
return h.Sum64(), origKey
|
||||
}
|
||||
|
||||
// buildCacheKeyHashValue 直接计算缓存键的哈希值,零字符串分配。
|
||||
// 用于只需要哈希值而不需要原始键的场景。
|
||||
func (p *Proxy) buildCacheKeyHashValue(ctx *fasthttp.RequestCtx) uint64 {
|
||||
h := fnv.New64a()
|
||||
h.Write(ctx.Request.Header.Method())
|
||||
h.Write([]byte(":"))
|
||||
h.Write(ctx.Request.URI().RequestURI())
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// writeCachedResponse 将缓存的响应写入 FastHTTP 响应上下文。
|
||||
//
|
||||
// 设置响应体、状态码、响应头,并添加 X-Cache: HIT 头标记缓存命中。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - entry: 缓存条目,包含响应数据和元数据
|
||||
func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) {
|
||||
ctx.Response.SetBody(entry.Data)
|
||||
ctx.Response.SetStatusCode(entry.Status)
|
||||
for key, value := range entry.Headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
ctx.Response.Header.Set("X-Cache", "HIT")
|
||||
}
|
||||
|
||||
// backgroundRefresh 在后台异步刷新缓存条目。
|
||||
//
|
||||
// 向对应的上游目标发送请求,获取最新响应并更新缓存。
|
||||
// 该方法在独立 goroutine 中运行,不阻塞主请求流程。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 原始 FastHTTP 请求上下文(仅用于复制请求信息)
|
||||
// - target: 要刷新的后端目标
|
||||
// - hashKey: 缓存哈希键
|
||||
// - origKey: 缓存原始键
|
||||
func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, hashKey uint64, origKey string) {
|
||||
// 创建新的请求上下文副本
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// 复制原始请求
|
||||
ctx.Request.CopyTo(req)
|
||||
|
||||
// 如果启用 Revalidate,添加条件请求头
|
||||
if p.config.Cache.Revalidate {
|
||||
if entry, ok, _ := p.cache.Get(hashKey, origKey); ok {
|
||||
if entry.LastModified != "" {
|
||||
req.Header.Set("If-Modified-Since", entry.LastModified)
|
||||
}
|
||||
if entry.ETag != "" {
|
||||
req.Header.Set("If-None-Match", entry.ETag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取客户端
|
||||
client := p.getClient(target.URL)
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
err := client.Do(req, resp)
|
||||
if err != nil {
|
||||
p.cache.ReleaseLock(hashKey, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理 304 Not Modified 响应
|
||||
if resp.StatusCode() == 304 {
|
||||
newHeaders := make(map[string]string)
|
||||
if lm := resp.Header.Peek("Last-Modified"); len(lm) > 0 {
|
||||
newHeaders["Last-Modified"] = string(lm)
|
||||
}
|
||||
if et := resp.Header.Peek("ETag"); len(et) > 0 {
|
||||
newHeaders["ETag"] = string(et)
|
||||
}
|
||||
p.cache.RefreshTTL(hashKey, origKey, newHeaders)
|
||||
return
|
||||
}
|
||||
|
||||
// 提取响应头(使用 pool 复用 map)
|
||||
headers, ok := headersPool.Get().(map[string]string)
|
||||
if !ok {
|
||||
headers = make(map[string]string, 20)
|
||||
}
|
||||
for k := range headers {
|
||||
delete(headers, k)
|
||||
}
|
||||
for key, value := range resp.Header.All() {
|
||||
headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode()))
|
||||
}
|
||||
|
||||
// GetCache 返回代理的 ProxyCache 实例(用于 purge handler)。
|
||||
// 如果缓存未启用,返回 nil。
|
||||
func (p *Proxy) GetCache() *cache.ProxyCache {
|
||||
return p.cache
|
||||
}
|
||||
|
||||
// GetCacheStats 返回代理缓存的统计信息。
|
||||
// 如果缓存未启用,返回 nil。
|
||||
func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats {
|
||||
if p.cache == nil {
|
||||
return nil
|
||||
}
|
||||
stats := p.cache.Stats()
|
||||
return &stats
|
||||
}
|
||||
|
||||
// extractHostFromURL 从 URL 字符串中提取 host:port 部分。
|
||||
//
|
||||
// 移除 http:// 或 https:// 协议前缀,以及路径部分,
|
||||
@ -1505,44 +1011,3 @@ func extractHostFromURL(urlStr string) string {
|
||||
return host
|
||||
}
|
||||
|
||||
// getCacheDuration 根据状态码获取缓存时间。
|
||||
// 优先级:CacheValid 配置 > MaxAge
|
||||
//
|
||||
// 映射规则:
|
||||
// - 200-299: CacheValid.OK(0 时继承 MaxAge)
|
||||
// - 301/302: CacheValid.Redirect
|
||||
// - 404: CacheValid.NotFound
|
||||
// - 400-499(除 404): CacheValid.ClientError
|
||||
// - 500-599: CacheValid.ServerError
|
||||
// - 其他: 不缓存(返回 0)
|
||||
func (p *Proxy) getCacheDuration(statusCode int) time.Duration {
|
||||
// 无 CacheValid 配置,使用 MaxAge
|
||||
if p.config.CacheValid == nil {
|
||||
return p.config.Cache.MaxAge
|
||||
}
|
||||
|
||||
cv := p.config.CacheValid
|
||||
|
||||
switch {
|
||||
case statusCode >= 200 && statusCode < 300:
|
||||
if cv.OK > 0 {
|
||||
return cv.OK
|
||||
}
|
||||
return p.config.Cache.MaxAge // 0 表示继承 MaxAge
|
||||
|
||||
case statusCode == 301 || statusCode == 302:
|
||||
return cv.Redirect // 0 表示不缓存
|
||||
|
||||
case statusCode == 404:
|
||||
return cv.NotFound
|
||||
|
||||
case statusCode >= 400 && statusCode < 500:
|
||||
return cv.ClientError
|
||||
|
||||
case statusCode >= 500:
|
||||
return cv.ServerError
|
||||
|
||||
default:
|
||||
return 0 // 不缓存
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,6 +30,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
)
|
||||
|
||||
// tlsVersionMap TLS 版本字符串到 tls 常量的映射表。
|
||||
@ -100,7 +101,9 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls
|
||||
tlsCfg.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// TLS 版本配置
|
||||
// TLS 版本配置:默认 MinVersion = TLS 1.2
|
||||
tlsCfg.MinVersion = tls.VersionTLS12
|
||||
|
||||
if cfg.MinVersion != "" {
|
||||
version, ok := tlsVersionMap[strings.ToUpper(cfg.MinVersion)]
|
||||
if !ok {
|
||||
@ -109,6 +112,11 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls
|
||||
tlsCfg.MinVersion = version
|
||||
}
|
||||
|
||||
// 警告:TLS 1.0/1.1 已不安全,不应在生产环境使用
|
||||
if tlsCfg.MinVersion < tls.VersionTLS12 {
|
||||
logging.Warn().Msgf("上游 TLS MinVersion 设置为 %s(低于 TLS 1.2),存在安全风险", cfg.MinVersion)
|
||||
}
|
||||
|
||||
if cfg.MaxVersion != "" {
|
||||
version, ok := tlsVersionMap[strings.ToUpper(cfg.MaxVersion)]
|
||||
if !ok {
|
||||
|
||||
204
internal/proxy/target_selector.go
Normal file
204
internal/proxy/target_selector.go
Normal file
@ -0,0 +1,204 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// selectByLua 使用 Lua 脚本选择后端目标。
|
||||
//
|
||||
// 执行配置的 Lua 脚本,脚本可通过 ngx.balancer.set_current_peer() 选择目标。
|
||||
// 如果 Lua 脚本执行失败或未调用 set_current_peer,返回 nil 表示需要使用 fallback 算法。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - targets: 候选目标列表
|
||||
//
|
||||
// 返回值:
|
||||
// - *loadbalance.Target: Lua 脚本选中的目标,nil 表示未选择
|
||||
// - error: Lua 执行失败时返回错误
|
||||
func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
|
||||
bctx := &lua.BalancerContext{
|
||||
Targets: targets,
|
||||
ClientIP: clientIP,
|
||||
Retries: p.config.NextUpstream.Tries,
|
||||
}
|
||||
|
||||
// 创建 Lua 协程
|
||||
coro, err := p.luaEngine.NewCoroutine(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create lua coroutine: %w", err)
|
||||
}
|
||||
defer coro.Close()
|
||||
|
||||
// 注册 balancer API
|
||||
L := coro.Co
|
||||
ngx, ok := L.GetGlobal("ngx").(*glua.LTable)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("global 'ngx' is not an LTable")
|
||||
}
|
||||
lua.RegisterBalancerAPI(L, bctx, ngx)
|
||||
|
||||
// 设置超时
|
||||
timeout := p.config.BalancerByLua.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 100 * time.Millisecond
|
||||
}
|
||||
|
||||
// 执行脚本(带超时)
|
||||
execCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
coro.ExecutionContext = execCtx
|
||||
|
||||
err = coro.ExecuteFile(p.config.BalancerByLua.Script)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("execute lua script: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否调用了 set_current_peer
|
||||
if !bctx.IsSelected() {
|
||||
return nil, nil // 未选择,返回 nil 表示需使用 fallback
|
||||
}
|
||||
|
||||
return bctx.Selected, nil
|
||||
}
|
||||
|
||||
// selectByFallback 使用 fallback 负载均衡算法选择目标。
|
||||
//
|
||||
// 当 Lua balancer 执行失败或未选择目标时使用。
|
||||
// 对于 IPHash 算法,会自动提取客户端 IP 进行哈希选择。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - targets: 候选目标列表
|
||||
//
|
||||
// 返回值:
|
||||
// - *loadbalance.Target: fallback 算法选中的目标
|
||||
func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.fallbackBalancer
|
||||
p.mu.RUnlock()
|
||||
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectByIP(targets, clientIP)
|
||||
}
|
||||
|
||||
return balancer.Select(targets)
|
||||
}
|
||||
|
||||
// selectByBalancer 使用主负载均衡器选择目标。
|
||||
//
|
||||
// 对于特殊算法(IPHash、ConsistentHash),会从请求上下文中提取
|
||||
// 相应的哈希键(客户端 IP、URI、自定义 Header 等)。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - targets: 候选目标列表
|
||||
//
|
||||
// 返回值:
|
||||
// - *loadbalance.Target: 主负载均衡器选中的目标
|
||||
func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.balancer
|
||||
p.mu.RUnlock()
|
||||
|
||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectByIP(targets, clientIP)
|
||||
}
|
||||
|
||||
// 对于一致性哈希,根据 hash_key 配置选择
|
||||
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
|
||||
hashKey := ch.GetHashKey()
|
||||
key := p.extractHashKey(ctx, hashKey)
|
||||
return ch.SelectByKey(targets, key)
|
||||
}
|
||||
|
||||
return balancer.Select(targets)
|
||||
}
|
||||
|
||||
// selectTargetExcluding 选择后端目标,排除已尝试失败的目标。
|
||||
// 用于故障转移场景,避免重复选择已失败的目标。
|
||||
// 如果没有可用的健康目标则返回 nil。
|
||||
func (p *Proxy) selectTargetExcluding(ctx *fasthttp.RequestCtx, excluded []*loadbalance.Target) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
balancer := p.balancer
|
||||
targets := p.targets
|
||||
p.mu.RUnlock()
|
||||
|
||||
if len(targets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 对于 IPHash 负载均衡器,提取客户端 IP
|
||||
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
|
||||
clientIP := netutil.ExtractClientIP(ctx)
|
||||
return ipHash.SelectExcludingByIP(targets, excluded, clientIP)
|
||||
}
|
||||
|
||||
// 对于一致性哈希,根据 hash_key 配置选择
|
||||
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
|
||||
hashKey := ch.GetHashKey()
|
||||
key := p.extractHashKey(ctx, hashKey)
|
||||
return ch.SelectExcludingByKey(targets, excluded, key)
|
||||
}
|
||||
|
||||
return balancer.SelectExcluding(targets, excluded)
|
||||
}
|
||||
|
||||
// extractHashKey 根据一致性哈希配置提取哈希键值。
|
||||
//
|
||||
// 支持的 hash_key 配置:
|
||||
// - "ip" 或 "": 使用客户端 IP 地址
|
||||
// - "uri": 使用完整请求 URI
|
||||
// - "header:NAME": 使用指定请求头的值,缺失时回退到客户端 IP
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
// - hashKey: 哈希键配置
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 提取的哈希键值
|
||||
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
|
||||
switch {
|
||||
case hashKey == "ip" || hashKey == "":
|
||||
return netutil.ExtractClientIP(ctx)
|
||||
case hashKey == "uri":
|
||||
return string(ctx.RequestURI())
|
||||
case strings.HasPrefix(hashKey, "header:"):
|
||||
headerName := strings.TrimPrefix(hashKey, "header:")
|
||||
value := ctx.Request.Header.Peek(headerName)
|
||||
if len(value) > 0 {
|
||||
return string(value)
|
||||
}
|
||||
return netutil.ExtractClientIP(ctx) // fallback to IP
|
||||
default:
|
||||
return netutil.ExtractClientIP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// getClient 返回指定目标 URL 对应的 HostClient 连接池实例。
|
||||
// 如果目标 URL 不存在于连接池中,返回 nil。
|
||||
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
|
||||
key := targetURL
|
||||
if p.config.ProxyBind != "" {
|
||||
key = targetURL + "|" + p.config.ProxyBind
|
||||
}
|
||||
p.mu.RLock()
|
||||
client := p.clients[key]
|
||||
p.mu.RUnlock()
|
||||
return client
|
||||
}
|
||||
@ -142,10 +142,10 @@ func initLuaEngine(luaCfg *config.LuaMiddlewareConfig) (*lua.LuaEngine, error) {
|
||||
|
||||
engine, err := lua.NewEngine(engineCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("初始化 Lua 引擎失败: %w", err)
|
||||
return nil, fmt.Errorf("failed to initialize Lua engine: %w", err)
|
||||
}
|
||||
|
||||
logging.Info().Msg("Lua 引擎已启动")
|
||||
logging.Info().Msg("Lua engine started")
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
@ -169,14 +169,14 @@ func initErrorPageManager(errorPageCfg *config.ErrorPageConfig) (*handler.ErrorP
|
||||
if err != nil {
|
||||
// 检查是否是部分加载失败
|
||||
if _, ok := err.(*handler.PartialLoadError); ok {
|
||||
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
|
||||
logging.Warn().Msg("Some error pages failed to load: " + err.Error())
|
||||
// 返回部分加载的管理器
|
||||
return manager, nil
|
||||
}
|
||||
// 全部加载失败
|
||||
return nil, fmt.Errorf("加载错误页面失败: %w", err)
|
||||
return nil, fmt.Errorf("failed to load error pages: %w", err)
|
||||
}
|
||||
|
||||
logging.Info().Msg("错误页面管理器已启动")
|
||||
logging.Info().Msg("Error page manager started")
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
222
internal/server/lifecycle.go
Normal file
222
internal/server/lifecycle.go
Normal file
@ -0,0 +1,222 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
)
|
||||
|
||||
// cleanupResources 清理服务器资源。
|
||||
//
|
||||
// 停止 Goroutine 池、健康检查器,关闭访问日志、TLS 管理器、
|
||||
// AccessControl 和 Lua 引擎。由 StopWithTimeout 和 GracefulStop 共用。
|
||||
func (s *Server) cleanupResources() {
|
||||
// 停止 Goroutine 池
|
||||
if s.pool != nil {
|
||||
s.pool.Stop()
|
||||
}
|
||||
|
||||
// 停止健康检查器
|
||||
for _, hc := range s.healthCheckers {
|
||||
hc.Stop()
|
||||
}
|
||||
|
||||
// 关闭访问日志
|
||||
if s.accessLogMiddleware != nil {
|
||||
_ = s.accessLogMiddleware.Close()
|
||||
}
|
||||
|
||||
// 关闭 TLS 管理器
|
||||
if s.tlsManager != nil {
|
||||
s.tlsManager.Close()
|
||||
}
|
||||
|
||||
// 关闭 AccessControl (释放 GeoIP 资源)
|
||||
if s.accessControl != nil {
|
||||
if err := s.accessControl.Close(); err != nil {
|
||||
logging.Warn().Err(err).Msg("Failed to close AccessControl")
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭 Lua 引擎
|
||||
if s.luaEngine != nil {
|
||||
s.luaEngine.Close()
|
||||
logging.Info().Msg("Lua engine closed")
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownServers 并行关闭多个 fasthttp.Server 实例。
|
||||
//
|
||||
// 使用 goroutine 并行关闭所有服务器,收集所有错误并返回聚合错误。
|
||||
// 部分服务器关闭失败不会影响其他服务器的关闭。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 关闭上下文,用于控制超时和取消
|
||||
// - servers: 要关闭的 fasthttp.Server 实例列表
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 聚合错误,无错误或全部成功时返回 nil
|
||||
func shutdownServers(ctx context.Context, servers []*fasthttp.Server) error {
|
||||
// 防御性检查:nil ctx 使用默认背景
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
errs []error
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
for _, srv := range servers {
|
||||
if srv == nil {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(s *fasthttp.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(); err != nil {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
// 等待所有关闭完成或上下文取消
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(errs) == 1 {
|
||||
return errs[0]
|
||||
}
|
||||
msgs := make([]string, len(errs))
|
||||
for i, e := range errs {
|
||||
msgs[i] = e.Error()
|
||||
}
|
||||
return fmt.Errorf("failed to close servers: %d errors: %s", len(errs), strings.Join(msgs, "; "))
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// StopWithTimeout 快速停止服务器(支持自定义超时)。
|
||||
//
|
||||
// 立即停止服务器,不等待正在处理的请求完成。
|
||||
// 停止所有健康检查器和访问日志中间件。
|
||||
//
|
||||
// 参数:
|
||||
// - timeout: 快速关闭的最大等待时间
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 停止过程中遇到的错误
|
||||
//
|
||||
// 注意事项:
|
||||
// - 对于生产环境,建议使用 GracefulStop 实现优雅关闭
|
||||
// - timeout <= 0 时会使用默认 5s 超时
|
||||
func (s *Server) StopWithTimeout(timeout time.Duration) error {
|
||||
// 防御性检查:如果 timeout <= 0,使用默认值
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
s.running = false
|
||||
s.cleanupResources()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 多服务器模式:并行关闭所有 fasthttp.Server
|
||||
if len(s.fastServers) > 0 {
|
||||
return shutdownServers(ctx, s.fastServers)
|
||||
}
|
||||
|
||||
// 单服务器模式:关闭单个 fasthttp.Server
|
||||
if s.fastServer != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = s.fastServer.Shutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GracefulStop 优雅停止服务器。
|
||||
//
|
||||
// 等待正在处理的请求完成后再停止服务器,确保连接正常关闭。
|
||||
// 如果超时时间到达仍有请求未完成,将返回超时错误。
|
||||
//
|
||||
// 参数:
|
||||
// - timeout: 优雅关闭的最大等待时间
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 停止过程中遇到的错误,超时返回 context.DeadlineExceeded
|
||||
//
|
||||
// 注意事项:
|
||||
// - 推荐在生产环境使用此方法关闭服务器
|
||||
// - 超时后会强制关闭,可能导致部分请求中断
|
||||
func (s *Server) GracefulStop(timeout time.Duration) error {
|
||||
s.running = false
|
||||
s.cleanupResources()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 多服务器模式:并行关闭所有 fasthttp.Server
|
||||
if len(s.fastServers) > 0 {
|
||||
return shutdownServers(ctx, s.fastServers)
|
||||
}
|
||||
|
||||
// 单服务器模式:关闭单个 fasthttp.Server
|
||||
if s.fastServer != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = s.fastServer.Shutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProxyCacheStats 收集所有代理缓存的统计信息。
|
||||
func (s *Server) getProxyCacheStats() ProxyCacheStats {
|
||||
var total ProxyCacheStats
|
||||
for _, p := range s.proxies {
|
||||
if stats := p.GetCacheStats(); stats != nil {
|
||||
total.Entries += stats.Entries
|
||||
total.Pending += stats.Pending
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
@ -75,7 +75,7 @@ func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) {
|
||||
|
||||
_, err = s.buildLuaMiddlewares(luaCfg)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "无效的阶段")
|
||||
assert.Contains(t, err.Error(), "invalid phase")
|
||||
}
|
||||
|
||||
// TestBuildLuaMiddlewares_WithTimeout 测试超时配置
|
||||
|
||||
246
internal/server/middleware_builder.go
Normal file
246
internal/server/middleware_builder.go
Normal file
@ -0,0 +1,246 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
"rua.plus/lolly/internal/middleware/accesslog"
|
||||
"rua.plus/lolly/internal/middleware/bodylimit"
|
||||
"rua.plus/lolly/internal/middleware/compression"
|
||||
"rua.plus/lolly/internal/middleware/errorintercept"
|
||||
"rua.plus/lolly/internal/middleware/rewrite"
|
||||
"rua.plus/lolly/internal/middleware/security"
|
||||
)
|
||||
|
||||
// buildMiddlewareChain 构建中间件链。
|
||||
//
|
||||
// 根据服务器配置按顺序构建中间件链,顺序为:
|
||||
//
|
||||
// AccessLog -> AccessControl -> RateLimiter -> BasicAuth -> Rewrite -> Compression -> SecurityHeaders
|
||||
//
|
||||
// 参数:
|
||||
// - serverCfg: 单个服务器的配置对象
|
||||
//
|
||||
// 返回值:
|
||||
// - *middleware.Chain: 构建完成的中间件链
|
||||
// - error: 构建过程中遇到的错误,如中间件创建失败
|
||||
//
|
||||
// 注意事项:
|
||||
// - 各中间件按顺序依次包装请求处理器
|
||||
// - 未配置的中间件不会添加到链中
|
||||
func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) {
|
||||
var middlewares []middleware.Middleware
|
||||
|
||||
// 1. AccessLog (已集成)
|
||||
s.accessLogMiddleware = accesslog.New(&s.config.Logging)
|
||||
middlewares = append(middlewares, s.accessLogMiddleware)
|
||||
|
||||
// 2. Security: AccessControl (IP 访问控制)
|
||||
if len(serverCfg.Security.Access.Allow) > 0 || len(serverCfg.Security.Access.Deny) > 0 {
|
||||
ac, err := security.NewAccessControl(&serverCfg.Security.Access)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create access control middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, ac)
|
||||
s.accessControl = ac
|
||||
}
|
||||
|
||||
// 3. Security: RateLimiter (速率限制)
|
||||
if serverCfg.Security.RateLimit.RequestRate > 0 {
|
||||
rl, err := security.NewRateLimiter(&serverCfg.Security.RateLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create rate limiter middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, rl)
|
||||
}
|
||||
|
||||
// 3.5 Security: ConnLimiter (连接数限制)
|
||||
if serverCfg.Security.RateLimit.ConnLimit > 0 {
|
||||
cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection limiter middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, cl.Middleware())
|
||||
}
|
||||
|
||||
// 4. Security: BasicAuth (认证)
|
||||
if len(serverCfg.Security.Auth.Users) > 0 {
|
||||
auth, err := security.NewBasicAuth(&serverCfg.Security.Auth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, auth)
|
||||
}
|
||||
|
||||
// 4.3 Security: AuthRequest (外部认证子请求)
|
||||
if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" {
|
||||
authReq, err := security.NewAuthRequest(serverCfg.Security.AuthRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create auth request middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, authReq)
|
||||
}
|
||||
|
||||
// 4.5 BodyLimit (请求体大小限制)
|
||||
// 创建 bodylimit 中间件,使用全局配置或默认值
|
||||
bodyLimitMiddleware := bodylimit.NewWithDefault()
|
||||
if serverCfg.ClientMaxBodySize != "" {
|
||||
bl, err := bodylimit.New(serverCfg.ClientMaxBodySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create body limit middleware: %w", err)
|
||||
}
|
||||
bodyLimitMiddleware = bl
|
||||
}
|
||||
// 添加路径级别的限制配置
|
||||
for i := range serverCfg.Proxy {
|
||||
if serverCfg.Proxy[i].ClientMaxBodySize != "" {
|
||||
if err := bodyLimitMiddleware.AddPathLimit(
|
||||
serverCfg.Proxy[i].Path,
|
||||
serverCfg.Proxy[i].ClientMaxBodySize,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to add path body limit: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
middlewares = append(middlewares, bodyLimitMiddleware)
|
||||
|
||||
// 5. Rewrite (URL 重写)
|
||||
if len(serverCfg.Rewrite) > 0 {
|
||||
rw, err := rewrite.New(serverCfg.Rewrite)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create rewrite middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, rw)
|
||||
}
|
||||
|
||||
// 6. Compression (响应压缩)
|
||||
if serverCfg.Compression.Type != "" {
|
||||
comp, err := compression.New(&serverCfg.Compression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create compression middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, comp)
|
||||
}
|
||||
|
||||
// 7. SecurityHeaders (安全头部)
|
||||
// 如果有任何安全头部配置,则启用
|
||||
if serverCfg.Security.Headers.XFrameOptions != "" ||
|
||||
serverCfg.Security.Headers.XContentTypeOptions != "" ||
|
||||
serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
|
||||
serverCfg.Security.Headers.ReferrerPolicy != "" ||
|
||||
serverCfg.Security.Headers.PermissionsPolicy != "" {
|
||||
headers := security.NewHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS)
|
||||
middlewares = append(middlewares, headers)
|
||||
}
|
||||
|
||||
// 8. ErrorIntercept (错误页面拦截)
|
||||
// 如果配置了错误页面,添加错误拦截中间件
|
||||
if s.errorPageManager != nil && s.errorPageManager.IsConfigured() {
|
||||
ei := errorintercept.New(s.errorPageManager)
|
||||
middlewares = append(middlewares, ei)
|
||||
}
|
||||
|
||||
// Lua 中间件(可选)
|
||||
if s.luaEngine != nil && serverCfg.Lua != nil && serverCfg.Lua.Enabled {
|
||||
luaMiddlewares, err := s.buildLuaMiddlewares(serverCfg.Lua)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Lua middleware: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, luaMiddlewares...)
|
||||
}
|
||||
|
||||
return middleware.NewChain(middlewares...), nil
|
||||
}
|
||||
|
||||
// buildLuaMiddlewares 根据 Lua 配置创建中间件。
|
||||
//
|
||||
// 根据 Scripts 配置创建 LuaMiddleware 或 MultiPhaseLuaMiddleware。
|
||||
// 支持单脚本和多阶段脚本配置。
|
||||
//
|
||||
// 参数:
|
||||
// - luaCfg: Lua 配置对象
|
||||
//
|
||||
// 返回值:
|
||||
// - []middleware.Middleware: 创建的中间件列表
|
||||
// - error: 创建过程中遇到的错误
|
||||
func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) {
|
||||
if s.luaEngine == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 按阶段分组脚本
|
||||
phaseScripts := make(map[string][]config.LuaScriptConfig)
|
||||
for _, script := range luaCfg.Scripts {
|
||||
// 默认启用
|
||||
enabled := script.Enabled
|
||||
if !enabled && script.Timeout == 0 && script.Path != "" {
|
||||
enabled = true // 零值时默认启用
|
||||
}
|
||||
if enabled {
|
||||
phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script)
|
||||
}
|
||||
}
|
||||
|
||||
var middlewares []middleware.Middleware
|
||||
|
||||
// 为每个阶段创建中间件
|
||||
for phase, scripts := range phaseScripts {
|
||||
if len(scripts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 单脚本:直接创建 LuaMiddleware
|
||||
if len(scripts) == 1 {
|
||||
script := scripts[0]
|
||||
luaPhase, err := lua.ParsePhase(phase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid phase '%s': %w", phase, err)
|
||||
}
|
||||
|
||||
timeout := script.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
cfg := lua.LuaMiddlewareConfig{
|
||||
ScriptPath: script.Path,
|
||||
Phase: luaPhase,
|
||||
Timeout: timeout,
|
||||
Name: fmt.Sprintf("lua-%s", phase),
|
||||
}
|
||||
|
||||
mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Lua middleware (phase=%s): %w", phase, err)
|
||||
}
|
||||
|
||||
middlewares = append(middlewares, mw)
|
||||
} else {
|
||||
// 多脚本:创建 MultiPhaseLuaMiddleware
|
||||
multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase))
|
||||
for _, script := range scripts {
|
||||
luaPhase, err := lua.ParsePhase(phase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid phase '%s': %w", phase, err)
|
||||
}
|
||||
|
||||
timeout := script.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
err = multi.AddPhase(luaPhase, script.Path, timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add Lua phase (phase=%s): %w", phase, err)
|
||||
}
|
||||
}
|
||||
|
||||
middlewares = append(middlewares, multi)
|
||||
}
|
||||
}
|
||||
|
||||
return middlewares, nil
|
||||
}
|
||||
@ -76,7 +76,7 @@ func NewPprofHandler(cfg *config.PprofConfig) (*PprofHandler, error) {
|
||||
// 尝试解析 CIDR
|
||||
_, net, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析 IP/CIDR 失败: %s: %w", ipStr, err)
|
||||
return nil, fmt.Errorf("failed to parse IP/CIDR: %s: %w", ipStr, err)
|
||||
}
|
||||
h.allowedNets = append(h.allowedNets, net)
|
||||
}
|
||||
|
||||
@ -7,12 +7,11 @@ import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
// PurgeHandler 缓存清理 API 处理器。
|
||||
@ -104,26 +103,26 @@ func (h *PurgeHandler) Path() string {
|
||||
func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
// 仅允许 POST 方法
|
||||
if string(ctx.Method()) != "POST" {
|
||||
h.sendError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
|
||||
utils.SendJSONError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 IP 访问权限
|
||||
if !h.checkAccess(ctx) {
|
||||
h.sendError(ctx, fasthttp.StatusForbidden, "forbidden")
|
||||
if !utils.CheckIPAccess(ctx, h.allowed) {
|
||||
utils.SendJSONError(ctx, fasthttp.StatusForbidden, "forbidden")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查认证
|
||||
if !h.checkAuth(ctx) {
|
||||
h.sendError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
|
||||
if !utils.CheckTokenAuth(ctx, h.auth) {
|
||||
utils.SendJSONError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var req cache.PurgeRequest
|
||||
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
|
||||
h.sendError(ctx, fasthttp.StatusBadRequest, "invalid request body")
|
||||
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
@ -134,7 +133,7 @@ func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
} else if req.Pattern != "" {
|
||||
deleted = h.purgeByPattern(req.Pattern, req.Method)
|
||||
} else {
|
||||
h.sendError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
|
||||
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
|
||||
return
|
||||
}
|
||||
|
||||
@ -144,55 +143,6 @@ func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
_ = json.NewEncoder(ctx).Encode(cache.PurgeResponse{Deleted: deleted})
|
||||
}
|
||||
|
||||
// checkAccess 检查客户端 IP 是否在允许列表中。
|
||||
func (h *PurgeHandler) checkAccess(ctx *fasthttp.RequestCtx) bool {
|
||||
// 如果没有配置允许列表,允许所有访问
|
||||
if len(h.allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
clientIP := netutil.ExtractClientIPNet(ctx)
|
||||
if clientIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否在允许列表中
|
||||
for _, network := range h.allowed {
|
||||
if network.Contains(clientIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// checkAuth 检查认证。
|
||||
func (h *PurgeHandler) checkAuth(ctx *fasthttp.RequestCtx) bool {
|
||||
// 无需认证
|
||||
if h.auth.Type == "" || h.auth.Type == "none" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Token 认证
|
||||
if h.auth.Type == "token" {
|
||||
authHeader := ctx.Request.Header.Peek("Authorization")
|
||||
if len(authHeader) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
authStr := string(authHeader)
|
||||
// 支持 Bearer token 格式
|
||||
if token, ok := strings.CutPrefix(authStr, "Bearer "); ok {
|
||||
return token == h.auth.Token
|
||||
}
|
||||
|
||||
// 也支持直接传递 token
|
||||
return authStr == h.auth.Token
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// purgeByPath 按精确路径清理缓存。
|
||||
func (h *PurgeHandler) purgeByPath(path string, method string) int {
|
||||
if h.server == nil {
|
||||
@ -229,13 +179,6 @@ func (h *PurgeHandler) purgeByPattern(pattern string, method string) int {
|
||||
return deleted
|
||||
}
|
||||
|
||||
// sendError 发送错误响应。
|
||||
func (h *PurgeHandler) sendError(ctx *fasthttp.RequestCtx, status int, errMsg string) {
|
||||
ctx.SetContentType("application/json; charset=utf-8")
|
||||
ctx.SetStatusCode(status)
|
||||
_ = json.NewEncoder(ctx).Encode(cache.PurgeErrorResponse{Error: errMsg})
|
||||
}
|
||||
|
||||
// PurgeByPathForTest 测试用的导出方法。
|
||||
func (h *PurgeHandler) PurgeByPathForTest(path string, method string) int {
|
||||
return h.purgeByPath(path, method)
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/proxy"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
func TestPurgeHandler_Path(t *testing.T) {
|
||||
@ -238,7 +239,7 @@ func TestPurgeHandler_checkAccess(t *testing.T) {
|
||||
|
||||
if len(h.allowed) == 0 {
|
||||
// 无白名单时应允许所有访问
|
||||
if !h.checkAccess(nil) {
|
||||
if !utils.CheckIPAccess(nil, h.allowed) {
|
||||
t.Error("expected access to be true when no allow list configured")
|
||||
}
|
||||
return
|
||||
@ -281,7 +282,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if !h.checkAuth(ctx) {
|
||||
if !utils.CheckTokenAuth(ctx, h.auth) {
|
||||
t.Error("expected auth to pass when no auth configured")
|
||||
}
|
||||
})
|
||||
@ -301,7 +302,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if !h.checkAuth(ctx) {
|
||||
if !utils.CheckTokenAuth(ctx, h.auth) {
|
||||
t.Error("expected auth to pass when type is none")
|
||||
}
|
||||
})
|
||||
@ -323,7 +324,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("Authorization", "Bearer secret-token")
|
||||
|
||||
if !h.checkAuth(ctx) {
|
||||
if !utils.CheckTokenAuth(ctx, h.auth) {
|
||||
t.Error("expected auth to pass with correct Bearer token")
|
||||
}
|
||||
})
|
||||
@ -345,7 +346,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("Authorization", "secret-token")
|
||||
|
||||
if !h.checkAuth(ctx) {
|
||||
if !utils.CheckTokenAuth(ctx, h.auth) {
|
||||
t.Error("expected auth to pass with correct direct token")
|
||||
}
|
||||
})
|
||||
@ -367,7 +368,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("Authorization", "Bearer wrong-token")
|
||||
|
||||
if h.checkAuth(ctx) {
|
||||
if utils.CheckTokenAuth(ctx, h.auth) {
|
||||
t.Error("expected auth to fail with wrong token")
|
||||
}
|
||||
})
|
||||
@ -388,7 +389,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
if h.checkAuth(ctx) {
|
||||
if utils.CheckTokenAuth(ctx, h.auth) {
|
||||
t.Error("expected auth to fail when Authorization header is missing")
|
||||
}
|
||||
})
|
||||
@ -410,7 +411,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("Authorization", "Bearer secret-token")
|
||||
|
||||
if h.checkAuth(ctx) {
|
||||
if utils.CheckTokenAuth(ctx, h.auth) {
|
||||
t.Error("expected auth to fail for unknown auth type")
|
||||
}
|
||||
})
|
||||
@ -570,7 +571,7 @@ func TestPurgeHandler_SendError(t *testing.T) {
|
||||
Allow: []string{},
|
||||
}
|
||||
|
||||
h, err := NewPurgeHandler(nil, cfg)
|
||||
_, err := NewPurgeHandler(nil, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@ -578,7 +579,7 @@ func TestPurgeHandler_SendError(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
h.sendError(ctx, tt.status, tt.errMsg)
|
||||
utils.SendJSONError(ctx, tt.status, tt.errMsg)
|
||||
|
||||
if ctx.Response.StatusCode() != tt.status {
|
||||
t.Errorf("expected status %d, got %d", tt.status, ctx.Response.StatusCode())
|
||||
@ -690,7 +691,7 @@ func TestPurgeHandler_checkAccess_NilContext(t *testing.T) {
|
||||
}
|
||||
|
||||
// Empty allow list should allow access (returns true even with nil context)
|
||||
if !h.checkAccess(nil) {
|
||||
if !utils.CheckIPAccess(nil, h.allowed) {
|
||||
t.Error("expected checkAccess to return true with empty allow list")
|
||||
}
|
||||
})
|
||||
@ -779,7 +780,7 @@ func TestPurgeHandler_checkAccess_WithAllowedIP(t *testing.T) {
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// context with nil remote address - should return false (no client IP)
|
||||
if h.checkAccess(ctx) {
|
||||
if utils.CheckIPAccess(ctx, h.allowed) {
|
||||
t.Error("expected checkAccess to return false with no client IP")
|
||||
}
|
||||
})
|
||||
|
||||
281
internal/server/router.go
Normal file
281
internal/server/router.go
Normal file
@ -0,0 +1,281 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/handler"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/matcher"
|
||||
"rua.plus/lolly/internal/proxy"
|
||||
)
|
||||
|
||||
// registerProxyRoutesWithLocationEngine 使用 LocationEngine 注册代理路由。
|
||||
//
|
||||
// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。
|
||||
// 支持通过 LocationType 配置不同的匹配方式。
|
||||
func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) {
|
||||
for i := range serverCfg.Proxy {
|
||||
proxyCfg := &serverCfg.Proxy[i]
|
||||
|
||||
// 转换目标
|
||||
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
|
||||
for j, t := range proxyCfg.Targets {
|
||||
failTimeout := t.FailTimeout
|
||||
if t.MaxFails > 0 && failTimeout == 0 {
|
||||
failTimeout = 10 * time.Second
|
||||
}
|
||||
targets[j] = loadbalance.NewTargetFromConfig(
|
||||
t.URL, t.Weight,
|
||||
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
|
||||
t.Backup, t.Down, t.ProxyURI,
|
||||
)
|
||||
}
|
||||
|
||||
// 传递 Transport 配置和 Lua 引擎
|
||||
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
|
||||
if err != nil {
|
||||
logging.Error().Msg("Failed to create proxy: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置 DNS 解析器(如果已配置)
|
||||
if s.resolver != nil {
|
||||
p.SetResolver(s.resolver)
|
||||
if err := p.Start(); err != nil {
|
||||
logging.Error().Err(err).Msg("Failed to start proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// 启动健康检查
|
||||
if proxyCfg.HealthCheck.Interval > 0 {
|
||||
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
|
||||
hc.Start()
|
||||
s.healthCheckers = append(s.healthCheckers, hc)
|
||||
// 设置被动健康检查
|
||||
p.SetHealthChecker(hc)
|
||||
}
|
||||
|
||||
// 保存代理实例用于缓存统计
|
||||
s.proxies = append(s.proxies, p)
|
||||
|
||||
// 根据 LocationType 注册路由
|
||||
locType := proxyCfg.LocationType
|
||||
if locType == "" {
|
||||
locType = matcher.LocationTypePrefix
|
||||
}
|
||||
|
||||
switch locType {
|
||||
case matcher.LocationTypeExact:
|
||||
_ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
case matcher.LocationTypePrefixPriority:
|
||||
_ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless:
|
||||
caseInsensitive := locType == matcher.LocationTypeRegexCaseless
|
||||
_ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal)
|
||||
case matcher.LocationTypeNamed:
|
||||
if proxyCfg.LocationName != "" {
|
||||
_ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP)
|
||||
}
|
||||
case matcher.LocationTypePrefix:
|
||||
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
default:
|
||||
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。
|
||||
func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) {
|
||||
for _, static := range cfg.Static {
|
||||
path := static.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
staticHandler := handler.NewStaticHandler(
|
||||
static.Root,
|
||||
path,
|
||||
static.Index,
|
||||
true, // useSendfile
|
||||
)
|
||||
// 设置 alias(与 root 互斥)
|
||||
if static.Alias != "" {
|
||||
staticHandler.SetAlias(static.Alias)
|
||||
}
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
// 设置默认缓存 TTL (5s)
|
||||
staticHandler.SetCacheTTL(5 * time.Second)
|
||||
}
|
||||
if cfg.Compression.GzipStatic {
|
||||
// extensions: 源文件类型,为空使用默认值
|
||||
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz)
|
||||
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
|
||||
}
|
||||
|
||||
// 设置符号链接安全检查
|
||||
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
|
||||
|
||||
// 设置 internal 限制
|
||||
staticHandler.SetInternal(static.Internal)
|
||||
|
||||
// 设置缓存过期时间
|
||||
if static.Expires != "" {
|
||||
staticHandler.SetExpires(static.Expires)
|
||||
}
|
||||
|
||||
// 根据 LocationType 注册路由
|
||||
locType := static.LocationType
|
||||
if locType == "" {
|
||||
locType = matcher.LocationTypePrefix
|
||||
}
|
||||
|
||||
switch locType {
|
||||
case matcher.LocationTypeExact:
|
||||
_ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal)
|
||||
case matcher.LocationTypePrefixPriority:
|
||||
_ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal)
|
||||
case matcher.LocationTypePrefix:
|
||||
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
|
||||
default:
|
||||
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerProxyRoutes 注册代理路由。
|
||||
//
|
||||
// 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。
|
||||
// 支持 GET、POST、PUT、DELETE、HEAD 等 HTTP 方法。
|
||||
//
|
||||
// 参数:
|
||||
// - router: 路由器实例,用于注册路由规则
|
||||
// - serverCfg: 服务器配置,包含代理目标、负载均衡、健康检查等设置
|
||||
//
|
||||
// 注意事项:
|
||||
// - 代理目标初始状态默认为健康
|
||||
// - 健康检查根据配置自动启动
|
||||
func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.ServerConfig) {
|
||||
for i := range serverCfg.Proxy {
|
||||
proxyCfg := &serverCfg.Proxy[i]
|
||||
|
||||
// 转换目标
|
||||
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
|
||||
for j, t := range proxyCfg.Targets {
|
||||
failTimeout := t.FailTimeout
|
||||
if t.MaxFails > 0 && failTimeout == 0 {
|
||||
failTimeout = 10 * time.Second
|
||||
}
|
||||
targets[j] = loadbalance.NewTargetFromConfig(
|
||||
t.URL, t.Weight,
|
||||
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
|
||||
t.Backup, t.Down, t.ProxyURI,
|
||||
)
|
||||
}
|
||||
|
||||
// 传递 Transport 配置和 Lua 引擎
|
||||
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
|
||||
if err != nil {
|
||||
logging.Error().Msg("Failed to create proxy: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置 DNS 解析器(如果已配置)
|
||||
if s.resolver != nil {
|
||||
p.SetResolver(s.resolver)
|
||||
if err := p.Start(); err != nil {
|
||||
logging.Error().Err(err).Msg("Failed to start proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// 启动健康检查
|
||||
if proxyCfg.HealthCheck.Interval > 0 {
|
||||
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
|
||||
hc.Start()
|
||||
s.healthCheckers = append(s.healthCheckers, hc)
|
||||
// 设置被动健康检查
|
||||
p.SetHealthChecker(hc)
|
||||
}
|
||||
|
||||
// 保存代理实例用于缓存统计
|
||||
s.proxies = append(s.proxies, p)
|
||||
|
||||
// 使用前缀匹配(通配符)注册代理路由
|
||||
// path: / 匹配所有子路径如 /sorry/index
|
||||
// path: /api/ 匹配 /api/* 所有子路径
|
||||
routePath := proxyCfg.Path
|
||||
// 确保通配符路由格式正确
|
||||
if !strings.HasSuffix(routePath, "/") && routePath != "/" {
|
||||
routePath += "/"
|
||||
}
|
||||
wildcardPath := routePath + "{path:*}"
|
||||
router.GET(wildcardPath, p.ServeHTTP)
|
||||
router.POST(wildcardPath, p.ServeHTTP)
|
||||
router.PUT(wildcardPath, p.ServeHTTP)
|
||||
router.DELETE(wildcardPath, p.ServeHTTP)
|
||||
router.HEAD(wildcardPath, p.ServeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
// registerStaticHandlers 注册静态文件处理器。
|
||||
//
|
||||
// 为路由器注册静态文件服务,支持多个静态目录、文件缓存和预压缩文件。
|
||||
//
|
||||
// 参数:
|
||||
// - router: 路由器实例,用于注册路由规则
|
||||
// - cfg: 服务器配置,包含静态文件和压缩设置
|
||||
func (s *Server) registerStaticHandlers(router *handler.Router, cfg *config.ServerConfig) {
|
||||
for _, static := range cfg.Static {
|
||||
path := static.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
staticHandler := handler.NewStaticHandler(
|
||||
static.Root,
|
||||
path,
|
||||
static.Index,
|
||||
true, // useSendfile
|
||||
)
|
||||
// 设置 alias(与 root 互斥)
|
||||
if static.Alias != "" {
|
||||
staticHandler.SetAlias(static.Alias)
|
||||
}
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
// 设置默认缓存 TTL (5s)
|
||||
staticHandler.SetCacheTTL(5 * time.Second)
|
||||
}
|
||||
if cfg.Compression.GzipStatic {
|
||||
// extensions: 源文件类型,为空使用默认值
|
||||
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz)
|
||||
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
|
||||
}
|
||||
|
||||
// 设置符号链接安全检查
|
||||
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
|
||||
|
||||
// 设置缓存过期时间
|
||||
if static.Expires != "" {
|
||||
staticHandler.SetExpires(static.Expires)
|
||||
}
|
||||
|
||||
// 设置 try_files 配置
|
||||
if len(static.TryFiles) > 0 {
|
||||
// 注意:tryFilesPass 需要路由器支持,当前实现传入 nil
|
||||
// 如果 tryFilesPass 为 true,需要额外处理
|
||||
staticHandler.SetTryFiles(static.TryFiles, static.TryFilesPass, router)
|
||||
}
|
||||
|
||||
// 注册路由:确保路径以 / 结尾
|
||||
routePath := path
|
||||
if !strings.HasSuffix(routePath, "/") {
|
||||
routePath += "/"
|
||||
}
|
||||
router.GET(routePath+"{filepath:*}", staticHandler.Handle)
|
||||
router.HEAD(routePath+"{filepath:*}", staticHandler.Handle)
|
||||
}
|
||||
}
|
||||
@ -20,7 +20,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -34,16 +33,10 @@ import (
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/handler"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
"rua.plus/lolly/internal/matcher"
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
"rua.plus/lolly/internal/middleware/accesslog"
|
||||
"rua.plus/lolly/internal/middleware/bodylimit"
|
||||
"rua.plus/lolly/internal/middleware/compression"
|
||||
"rua.plus/lolly/internal/middleware/errorintercept"
|
||||
"rua.plus/lolly/internal/middleware/rewrite"
|
||||
"rua.plus/lolly/internal/middleware/security"
|
||||
"rua.plus/lolly/internal/mimeutil"
|
||||
"rua.plus/lolly/internal/proxy"
|
||||
@ -211,236 +204,6 @@ func (s *Server) GetHandler() fasthttp.RequestHandler {
|
||||
return s.handler
|
||||
}
|
||||
|
||||
// buildMiddlewareChain 构建中间件链。
|
||||
//
|
||||
// 根据服务器配置按顺序构建中间件链,顺序为:
|
||||
//
|
||||
// AccessLog -> AccessControl -> RateLimiter -> BasicAuth -> Rewrite -> Compression -> SecurityHeaders
|
||||
//
|
||||
// 参数:
|
||||
// - serverCfg: 单个服务器的配置对象
|
||||
//
|
||||
// 返回值:
|
||||
// - *middleware.Chain: 构建完成的中间件链
|
||||
// - error: 构建过程中遇到的错误,如中间件创建失败
|
||||
//
|
||||
// 注意事项:
|
||||
// - 各中间件按顺序依次包装请求处理器
|
||||
// - 未配置的中间件不会添加到链中
|
||||
func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) {
|
||||
var middlewares []middleware.Middleware
|
||||
|
||||
// 1. AccessLog (已集成)
|
||||
s.accessLogMiddleware = accesslog.New(&s.config.Logging)
|
||||
middlewares = append(middlewares, s.accessLogMiddleware)
|
||||
|
||||
// 2. Security: AccessControl (IP 访问控制)
|
||||
if len(serverCfg.Security.Access.Allow) > 0 || len(serverCfg.Security.Access.Deny) > 0 {
|
||||
ac, err := security.NewAccessControl(&serverCfg.Security.Access)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建访问控制中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, ac)
|
||||
s.accessControl = ac
|
||||
}
|
||||
|
||||
// 3. Security: RateLimiter (速率限制)
|
||||
if serverCfg.Security.RateLimit.RequestRate > 0 {
|
||||
rl, err := security.NewRateLimiter(&serverCfg.Security.RateLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建限流中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, rl)
|
||||
}
|
||||
|
||||
// 3.5 Security: ConnLimiter (连接数限制)
|
||||
if serverCfg.Security.RateLimit.ConnLimit > 0 {
|
||||
cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建连接限制中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, cl.Middleware())
|
||||
}
|
||||
|
||||
// 4. Security: BasicAuth (认证)
|
||||
if len(serverCfg.Security.Auth.Users) > 0 {
|
||||
auth, err := security.NewBasicAuth(&serverCfg.Security.Auth)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建认证中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, auth)
|
||||
}
|
||||
|
||||
// 4.3 Security: AuthRequest (外部认证子请求)
|
||||
if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" {
|
||||
authReq, err := security.NewAuthRequest(serverCfg.Security.AuthRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建外部认证中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, authReq)
|
||||
}
|
||||
|
||||
// 4.5 BodyLimit (请求体大小限制)
|
||||
// 创建 bodylimit 中间件,使用全局配置或默认值
|
||||
bodyLimitMiddleware := bodylimit.NewWithDefault()
|
||||
if serverCfg.ClientMaxBodySize != "" {
|
||||
bl, err := bodylimit.New(serverCfg.ClientMaxBodySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求体限制中间件失败: %w", err)
|
||||
}
|
||||
bodyLimitMiddleware = bl
|
||||
}
|
||||
// 添加路径级别的限制配置
|
||||
for i := range serverCfg.Proxy {
|
||||
if serverCfg.Proxy[i].ClientMaxBodySize != "" {
|
||||
if err := bodyLimitMiddleware.AddPathLimit(
|
||||
serverCfg.Proxy[i].Path,
|
||||
serverCfg.Proxy[i].ClientMaxBodySize,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("添加路径请求体限制失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
middlewares = append(middlewares, bodyLimitMiddleware)
|
||||
|
||||
// 5. Rewrite (URL 重写)
|
||||
if len(serverCfg.Rewrite) > 0 {
|
||||
rw, err := rewrite.New(serverCfg.Rewrite)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建重写中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, rw)
|
||||
}
|
||||
|
||||
// 6. Compression (响应压缩)
|
||||
if serverCfg.Compression.Type != "" {
|
||||
comp, err := compression.New(&serverCfg.Compression)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建压缩中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, comp)
|
||||
}
|
||||
|
||||
// 7. SecurityHeaders (安全头部)
|
||||
// 如果有任何安全头部配置,则启用
|
||||
if serverCfg.Security.Headers.XFrameOptions != "" ||
|
||||
serverCfg.Security.Headers.XContentTypeOptions != "" ||
|
||||
serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
|
||||
serverCfg.Security.Headers.ReferrerPolicy != "" ||
|
||||
serverCfg.Security.Headers.PermissionsPolicy != "" {
|
||||
headers := security.NewHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS)
|
||||
middlewares = append(middlewares, headers)
|
||||
}
|
||||
|
||||
// 8. ErrorIntercept (错误页面拦截)
|
||||
// 如果配置了错误页面,添加错误拦截中间件
|
||||
if s.errorPageManager != nil && s.errorPageManager.IsConfigured() {
|
||||
ei := errorintercept.New(s.errorPageManager)
|
||||
middlewares = append(middlewares, ei)
|
||||
}
|
||||
|
||||
// Lua 中间件(可选)
|
||||
if s.luaEngine != nil && serverCfg.Lua != nil && serverCfg.Lua.Enabled {
|
||||
luaMiddlewares, err := s.buildLuaMiddlewares(serverCfg.Lua)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 Lua 中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, luaMiddlewares...)
|
||||
}
|
||||
|
||||
return middleware.NewChain(middlewares...), nil
|
||||
}
|
||||
|
||||
// buildLuaMiddlewares 根据 Lua 配置创建中间件。
|
||||
//
|
||||
// 根据 Scripts 配置创建 LuaMiddleware 或 MultiPhaseLuaMiddleware。
|
||||
// 支持单脚本和多阶段脚本配置。
|
||||
//
|
||||
// 参数:
|
||||
// - luaCfg: Lua 配置对象
|
||||
//
|
||||
// 返回值:
|
||||
// - []middleware.Middleware: 创建的中间件列表
|
||||
// - error: 创建过程中遇到的错误
|
||||
func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) {
|
||||
if s.luaEngine == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 按阶段分组脚本
|
||||
phaseScripts := make(map[string][]config.LuaScriptConfig)
|
||||
for _, script := range luaCfg.Scripts {
|
||||
// 默认启用
|
||||
enabled := script.Enabled
|
||||
if !enabled && script.Timeout == 0 && script.Path != "" {
|
||||
enabled = true // 零值时默认启用
|
||||
}
|
||||
if enabled {
|
||||
phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script)
|
||||
}
|
||||
}
|
||||
|
||||
var middlewares []middleware.Middleware
|
||||
|
||||
// 为每个阶段创建中间件
|
||||
for phase, scripts := range phaseScripts {
|
||||
if len(scripts) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 单脚本:直接创建 LuaMiddleware
|
||||
if len(scripts) == 1 {
|
||||
script := scripts[0]
|
||||
luaPhase, err := lua.ParsePhase(phase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的阶段 '%s': %w", phase, err)
|
||||
}
|
||||
|
||||
timeout := script.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
cfg := lua.LuaMiddlewareConfig{
|
||||
ScriptPath: script.Path,
|
||||
Phase: luaPhase,
|
||||
Timeout: timeout,
|
||||
Name: fmt.Sprintf("lua-%s", phase),
|
||||
}
|
||||
|
||||
mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建 Lua 中间件失败 (phase=%s): %w", phase, err)
|
||||
}
|
||||
|
||||
middlewares = append(middlewares, mw)
|
||||
} else {
|
||||
// 多脚本:创建 MultiPhaseLuaMiddleware
|
||||
multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase))
|
||||
for _, script := range scripts {
|
||||
luaPhase, err := lua.ParsePhase(phase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的阶段 '%s': %w", phase, err)
|
||||
}
|
||||
|
||||
timeout := script.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 30 * time.Second
|
||||
}
|
||||
|
||||
err = multi.AddPhase(luaPhase, script.Path, timeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加 Lua 阶段失败 (phase=%s): %w", phase, err)
|
||||
}
|
||||
}
|
||||
|
||||
middlewares = append(middlewares, multi)
|
||||
}
|
||||
}
|
||||
|
||||
return middlewares, nil
|
||||
}
|
||||
|
||||
// Start 启动 HTTP 服务器。
|
||||
//
|
||||
// 初始化日志系统、性能优化组件(Goroutine池、文件缓存),
|
||||
@ -549,13 +312,13 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error)
|
||||
mode = cfg.UnixSocket.Mode
|
||||
}
|
||||
if err := os.Chmod(socketPath, os.FileMode(mode)); err != nil {
|
||||
logging.Warn().Err(err).Msg("设置 socket 文件权限失败")
|
||||
logging.Warn().Err(err).Msg("Failed to set socket file permissions")
|
||||
}
|
||||
|
||||
// 5. 设置文件所有权(需要 root 权限)
|
||||
if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" {
|
||||
// 简化处理:仅记录警告,实际实现需要 syscall.Chown
|
||||
logging.Warn().Msg("Unix socket 用户/组配置需要 root 权限,已跳过")
|
||||
logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped")
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
@ -590,7 +353,7 @@ func (s *Server) startSingleMode() error {
|
||||
if s.config.Monitoring.Status.Path != "" || len(s.config.Monitoring.Status.Allow) > 0 {
|
||||
statusHandler, err := NewStatusHandler(s, &s.config.Monitoring.Status)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建状态处理器失败: " + err.Error())
|
||||
logging.Error().Msg("Failed to create status handler: " + err.Error())
|
||||
} else {
|
||||
_ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false)
|
||||
}
|
||||
@ -600,7 +363,7 @@ func (s *Server) startSingleMode() error {
|
||||
if s.config.Monitoring.Pprof.Enabled {
|
||||
pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建 pprof 处理器失败: " + err.Error())
|
||||
logging.Error().Msg("Failed to create pprof handler: " + err.Error())
|
||||
} else {
|
||||
_ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false)
|
||||
_ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false)
|
||||
@ -611,7 +374,7 @@ func (s *Server) startSingleMode() error {
|
||||
if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled {
|
||||
purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建缓存清理处理器失败: " + err.Error())
|
||||
logging.Error().Msg("Failed to create cache purge handler: " + err.Error())
|
||||
} else {
|
||||
_ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false)
|
||||
}
|
||||
@ -685,7 +448,7 @@ func (s *Server) startSingleMode() error {
|
||||
var err error
|
||||
s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 TLS 管理器失败: %w", err)
|
||||
return fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
s.fastServer.TLSConfig = s.tlsManager.GetTLSConfig()
|
||||
return s.fastServer.ServeTLS(ln, "", "")
|
||||
@ -747,7 +510,7 @@ func (s *Server) startVHostMode() error {
|
||||
if s.config.Monitoring.Status.Enabled {
|
||||
statusHandler, err := NewStatusHandler(s, &s.config.Monitoring.Status)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建状态处理器失败: " + err.Error())
|
||||
logging.Error().Msg("Failed to create status handler: " + err.Error())
|
||||
} else {
|
||||
router.GET(statusHandler.Path(), statusHandler.ServeHTTP)
|
||||
}
|
||||
@ -757,7 +520,7 @@ func (s *Server) startVHostMode() error {
|
||||
if s.config.Monitoring.Pprof.Enabled {
|
||||
pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建 pprof 处理器失败: " + err.Error())
|
||||
logging.Error().Msg("Failed to create pprof handler: " + err.Error())
|
||||
} else {
|
||||
router.GET(pprofHandler.Path(), pprofHandler.ServeHTTP)
|
||||
router.GET(pprofHandler.Path()+"/{profile:*}", pprofHandler.ServeHTTP)
|
||||
@ -769,7 +532,7 @@ func (s *Server) startVHostMode() error {
|
||||
if defaultSrv != nil && defaultSrv.CacheAPI != nil && defaultSrv.CacheAPI.Enabled {
|
||||
purgeHandler, err := NewPurgeHandler(s, defaultSrv.CacheAPI)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建缓存清理处理器失败: " + err.Error())
|
||||
logging.Error().Msg("Failed to create cache purge handler: " + err.Error())
|
||||
} else {
|
||||
router.POST(purgeHandler.Path(), purgeHandler.ServeHTTP)
|
||||
}
|
||||
@ -829,7 +592,7 @@ func (s *Server) startVHostMode() error {
|
||||
var err error
|
||||
s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建 TLS 管理器失败: %w", err)
|
||||
return fmt.Errorf("failed to create TLS manager: %w", err)
|
||||
}
|
||||
s.fastServer.TLSConfig = s.tlsManager.GetTLSConfig()
|
||||
return s.fastServer.ServeTLS(ln, "", "")
|
||||
@ -853,7 +616,7 @@ func (s *Server) startVHostMode() error {
|
||||
func (s *Server) startMultiServerMode() error {
|
||||
// 热升级检测:multi_server 热升级未实现,回退到 vhost 模式
|
||||
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
|
||||
logging.Warn().Msg("热升级模式下 multi_server 模式未实现,回退到虚拟主机模式")
|
||||
logging.Warn().Msg("multi_server mode not implemented for graceful upgrade, falling back to vhost mode")
|
||||
return s.startVHostMode()
|
||||
}
|
||||
|
||||
@ -874,7 +637,7 @@ func (s *Server) startMultiServerMode() error {
|
||||
// 创建监听器
|
||||
ln, err := s.createListener(serverCfg)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("监听地址 %s 失败: %w", serverCfg.Listen, err)
|
||||
errCh <- fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err)
|
||||
return
|
||||
}
|
||||
s.listeners[idx] = ln
|
||||
@ -886,7 +649,7 @@ func (s *Server) startMultiServerMode() error {
|
||||
if idx == 0 && serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled {
|
||||
purgeHandler, purgeErr := NewPurgeHandler(s, serverCfg.CacheAPI)
|
||||
if purgeErr != nil {
|
||||
errCh <- fmt.Errorf("创建缓存清理处理器失败 (server[%d]): %w", idx, purgeErr)
|
||||
errCh <- fmt.Errorf("failed to create cache purge handler (server[%d]): %w", idx, purgeErr)
|
||||
return
|
||||
}
|
||||
router.POST(purgeHandler.Path(), purgeHandler.ServeHTTP)
|
||||
@ -900,7 +663,7 @@ func (s *Server) startMultiServerMode() error {
|
||||
// 构建独立的中间件链
|
||||
chain, err := s.buildMiddlewareChain(serverCfg)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("构建中间件链失败 (server[%d]): %w", idx, err)
|
||||
errCh <- fmt.Errorf("failed to build middleware chain (server[%d]): %w", idx, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -927,7 +690,7 @@ func (s *Server) startMultiServerMode() error {
|
||||
if serverCfg.SSL.Cert != "" && serverCfg.SSL.Key != "" {
|
||||
tlsManager, err := ssl.NewTLSManager(&serverCfg.SSL)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("创建 TLS 管理器失败 (server[%d]): %w", idx, err)
|
||||
errCh <- fmt.Errorf("failed to create TLS manager (server[%d]): %w", idx, err)
|
||||
return
|
||||
}
|
||||
fastSrv.TLSConfig = tlsManager.GetTLSConfig()
|
||||
@ -978,7 +741,7 @@ func (s *Server) startMultiServerMode() error {
|
||||
serveErr = f.Serve(l)
|
||||
}
|
||||
if serveErr != nil {
|
||||
logging.Error().Err(serveErr).Msgf("服务器 [%d] 监听 %s 时发生错误", i, l.Addr())
|
||||
logging.Error().Err(serveErr).Msgf("Server [%d] error while listening on %s", i, l.Addr())
|
||||
}
|
||||
}(fastSrv, ln, idx)
|
||||
}
|
||||
@ -988,509 +751,6 @@ func (s *Server) startMultiServerMode() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerProxyRoutesWithLocationEngine 使用 LocationEngine 注册代理路由。
|
||||
//
|
||||
// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。
|
||||
// 支持通过 LocationType 配置不同的匹配方式。
|
||||
func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) {
|
||||
for i := range serverCfg.Proxy {
|
||||
proxyCfg := &serverCfg.Proxy[i]
|
||||
|
||||
// 转换目标
|
||||
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
|
||||
for j, t := range proxyCfg.Targets {
|
||||
failTimeout := t.FailTimeout
|
||||
if t.MaxFails > 0 && failTimeout == 0 {
|
||||
failTimeout = 10 * time.Second
|
||||
}
|
||||
targets[j] = loadbalance.NewTargetFromConfig(
|
||||
t.URL, t.Weight,
|
||||
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
|
||||
t.Backup, t.Down, t.ProxyURI,
|
||||
)
|
||||
}
|
||||
|
||||
// 传递 Transport 配置和 Lua 引擎
|
||||
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建代理失败: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置 DNS 解析器(如果已配置)
|
||||
if s.resolver != nil {
|
||||
p.SetResolver(s.resolver)
|
||||
if err := p.Start(); err != nil {
|
||||
logging.Error().Err(err).Msg("启动代理失败")
|
||||
}
|
||||
}
|
||||
|
||||
// 启动健康检查
|
||||
if proxyCfg.HealthCheck.Interval > 0 {
|
||||
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
|
||||
hc.Start()
|
||||
s.healthCheckers = append(s.healthCheckers, hc)
|
||||
// 设置被动健康检查
|
||||
p.SetHealthChecker(hc)
|
||||
}
|
||||
|
||||
// 保存代理实例用于缓存统计
|
||||
s.proxies = append(s.proxies, p)
|
||||
|
||||
// 根据 LocationType 注册路由
|
||||
locType := proxyCfg.LocationType
|
||||
if locType == "" {
|
||||
locType = matcher.LocationTypePrefix
|
||||
}
|
||||
|
||||
switch locType {
|
||||
case matcher.LocationTypeExact:
|
||||
_ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
case matcher.LocationTypePrefixPriority:
|
||||
_ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless:
|
||||
caseInsensitive := locType == matcher.LocationTypeRegexCaseless
|
||||
_ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal)
|
||||
case matcher.LocationTypeNamed:
|
||||
if proxyCfg.LocationName != "" {
|
||||
_ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP)
|
||||
}
|
||||
case matcher.LocationTypePrefix:
|
||||
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
default:
|
||||
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。
|
||||
func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) {
|
||||
for _, static := range cfg.Static {
|
||||
path := static.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
staticHandler := handler.NewStaticHandler(
|
||||
static.Root,
|
||||
path,
|
||||
static.Index,
|
||||
true, // useSendfile
|
||||
)
|
||||
// 设置 alias(与 root 互斥)
|
||||
if static.Alias != "" {
|
||||
staticHandler.SetAlias(static.Alias)
|
||||
}
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
// 设置默认缓存 TTL (5s)
|
||||
staticHandler.SetCacheTTL(5 * time.Second)
|
||||
}
|
||||
if cfg.Compression.GzipStatic {
|
||||
// extensions: 源文件类型,为空使用默认值
|
||||
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz)
|
||||
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
|
||||
}
|
||||
|
||||
// 设置符号链接安全检查
|
||||
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
|
||||
|
||||
// 设置 internal 限制
|
||||
staticHandler.SetInternal(static.Internal)
|
||||
|
||||
// 设置缓存过期时间
|
||||
if static.Expires != "" {
|
||||
staticHandler.SetExpires(static.Expires)
|
||||
}
|
||||
|
||||
// 根据 LocationType 注册路由
|
||||
locType := static.LocationType
|
||||
if locType == "" {
|
||||
locType = matcher.LocationTypePrefix
|
||||
}
|
||||
|
||||
switch locType {
|
||||
case matcher.LocationTypeExact:
|
||||
_ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal)
|
||||
case matcher.LocationTypePrefixPriority:
|
||||
_ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal)
|
||||
case matcher.LocationTypePrefix:
|
||||
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
|
||||
default:
|
||||
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// registerProxyRoutes 注册代理路由。
|
||||
//
|
||||
// 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。
|
||||
// 支持 GET、POST、PUT、DELETE、HEAD 等 HTTP 方法。
|
||||
//
|
||||
// 参数:
|
||||
// - router: 路由器实例,用于注册路由规则
|
||||
// - serverCfg: 服务器配置,包含代理目标、负载均衡、健康检查等设置
|
||||
//
|
||||
// 注意事项:
|
||||
// - 代理目标初始状态默认为健康
|
||||
// - 健康检查根据配置自动启动
|
||||
func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.ServerConfig) {
|
||||
for i := range serverCfg.Proxy {
|
||||
proxyCfg := &serverCfg.Proxy[i]
|
||||
|
||||
// 转换目标
|
||||
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
|
||||
for j, t := range proxyCfg.Targets {
|
||||
failTimeout := t.FailTimeout
|
||||
if t.MaxFails > 0 && failTimeout == 0 {
|
||||
failTimeout = 10 * time.Second
|
||||
}
|
||||
targets[j] = loadbalance.NewTargetFromConfig(
|
||||
t.URL, t.Weight,
|
||||
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
|
||||
t.Backup, t.Down, t.ProxyURI,
|
||||
)
|
||||
}
|
||||
|
||||
// 传递 Transport 配置和 Lua 引擎
|
||||
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
|
||||
if err != nil {
|
||||
logging.Error().Msg("创建代理失败: " + err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 设置 DNS 解析器(如果已配置)
|
||||
if s.resolver != nil {
|
||||
p.SetResolver(s.resolver)
|
||||
if err := p.Start(); err != nil {
|
||||
logging.Error().Err(err).Msg("启动代理失败")
|
||||
}
|
||||
}
|
||||
|
||||
// 启动健康检查
|
||||
if proxyCfg.HealthCheck.Interval > 0 {
|
||||
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
|
||||
hc.Start()
|
||||
s.healthCheckers = append(s.healthCheckers, hc)
|
||||
// 设置被动健康检查
|
||||
p.SetHealthChecker(hc)
|
||||
}
|
||||
|
||||
// 保存代理实例用于缓存统计
|
||||
s.proxies = append(s.proxies, p)
|
||||
|
||||
// 使用前缀匹配(通配符)注册代理路由
|
||||
// path: / 匹配所有子路径如 /sorry/index
|
||||
// path: /api/ 匹配 /api/* 所有子路径
|
||||
routePath := proxyCfg.Path
|
||||
// 确保通配符路由格式正确
|
||||
if !strings.HasSuffix(routePath, "/") && routePath != "/" {
|
||||
routePath += "/"
|
||||
}
|
||||
wildcardPath := routePath + "{path:*}"
|
||||
router.GET(wildcardPath, p.ServeHTTP)
|
||||
router.POST(wildcardPath, p.ServeHTTP)
|
||||
router.PUT(wildcardPath, p.ServeHTTP)
|
||||
router.DELETE(wildcardPath, p.ServeHTTP)
|
||||
router.HEAD(wildcardPath, p.ServeHTTP)
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownServers 并行关闭多个 fasthttp.Server 实例。
|
||||
//
|
||||
// 使用 goroutine 并行关闭所有服务器,收集所有错误并返回聚合错误。
|
||||
// 部分服务器关闭失败不会影响其他服务器的关闭。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 关闭上下文,用于控制超时和取消
|
||||
// - servers: 要关闭的 fasthttp.Server 实例列表
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 聚合错误,无错误或全部成功时返回 nil
|
||||
func shutdownServers(ctx context.Context, servers []*fasthttp.Server) error {
|
||||
// 防御性检查:nil ctx 使用默认背景
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
errs []error
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
for _, srv := range servers {
|
||||
if srv == nil {
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(s *fasthttp.Server) {
|
||||
defer wg.Done()
|
||||
if err := s.Shutdown(); err != nil {
|
||||
mu.Lock()
|
||||
errs = append(errs, err)
|
||||
mu.Unlock()
|
||||
}
|
||||
}(srv)
|
||||
}
|
||||
|
||||
// 等待所有关闭完成或上下文取消
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if len(errs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(errs) == 1 {
|
||||
return errs[0]
|
||||
}
|
||||
msgs := make([]string, len(errs))
|
||||
for i, e := range errs {
|
||||
msgs[i] = e.Error()
|
||||
}
|
||||
return fmt.Errorf("关闭服务器时发生 %d 个错误: %s", len(errs), strings.Join(msgs, "; "))
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// StopWithTimeout 快速停止服务器(支持自定义超时)。
|
||||
//
|
||||
// 立即停止服务器,不等待正在处理的请求完成。
|
||||
// 停止所有健康检查器和访问日志中间件。
|
||||
//
|
||||
// 参数:
|
||||
// - timeout: 快速关闭的最大等待时间
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 停止过程中遇到的错误
|
||||
//
|
||||
// 注意事项:
|
||||
// - 对于生产环境,建议使用 GracefulStop 实现优雅关闭
|
||||
// - timeout <= 0 时会使用默认 5s 超时
|
||||
func (s *Server) StopWithTimeout(timeout time.Duration) error {
|
||||
// 防御性检查:如果 timeout <= 0,使用默认值
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
s.running = false
|
||||
|
||||
// 停止 Goroutine 池
|
||||
if s.pool != nil {
|
||||
s.pool.Stop()
|
||||
}
|
||||
|
||||
// 停止健康检查器
|
||||
for _, hc := range s.healthCheckers {
|
||||
hc.Stop()
|
||||
}
|
||||
|
||||
// 关闭访问日志
|
||||
if s.accessLogMiddleware != nil {
|
||||
_ = s.accessLogMiddleware.Close()
|
||||
}
|
||||
|
||||
// 关闭 TLS 管理器
|
||||
if s.tlsManager != nil {
|
||||
s.tlsManager.Close()
|
||||
}
|
||||
|
||||
// 关闭 AccessControl (释放 GeoIP 资源)
|
||||
if s.accessControl != nil {
|
||||
if err := s.accessControl.Close(); err != nil {
|
||||
logging.Warn().Err(err).Msg("关闭 AccessControl 失败")
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭 Lua 引擎
|
||||
if s.luaEngine != nil {
|
||||
s.luaEngine.Close()
|
||||
logging.Info().Msg("Lua 引擎已关闭")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 多服务器模式:并行关闭所有 fasthttp.Server
|
||||
if len(s.fastServers) > 0 {
|
||||
return shutdownServers(ctx, s.fastServers)
|
||||
}
|
||||
|
||||
// 单服务器模式:关闭单个 fasthttp.Server
|
||||
if s.fastServer != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = s.fastServer.Shutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GracefulStop 优雅停止服务器。
|
||||
//
|
||||
// 等待正在处理的请求完成后再停止服务器,确保连接正常关闭。
|
||||
// 如果超时时间到达仍有请求未完成,将返回超时错误。
|
||||
//
|
||||
// 参数:
|
||||
// - timeout: 优雅关闭的最大等待时间
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 停止过程中遇到的错误,超时返回 context.DeadlineExceeded
|
||||
//
|
||||
// 注意事项:
|
||||
// - 推荐在生产环境使用此方法关闭服务器
|
||||
// - 超时后会强制关闭,可能导致部分请求中断
|
||||
func (s *Server) GracefulStop(timeout time.Duration) error {
|
||||
s.running = false
|
||||
|
||||
// 停止 Goroutine 池
|
||||
if s.pool != nil {
|
||||
s.pool.Stop()
|
||||
}
|
||||
|
||||
// 停止健康检查器
|
||||
for _, hc := range s.healthCheckers {
|
||||
hc.Stop()
|
||||
}
|
||||
|
||||
// 关闭访问日志
|
||||
if s.accessLogMiddleware != nil {
|
||||
_ = s.accessLogMiddleware.Close()
|
||||
}
|
||||
|
||||
// 关闭 TLS 管理器
|
||||
if s.tlsManager != nil {
|
||||
s.tlsManager.Close()
|
||||
}
|
||||
|
||||
// 关闭 AccessControl (释放 GeoIP 资源)
|
||||
if s.accessControl != nil {
|
||||
if err := s.accessControl.Close(); err != nil {
|
||||
logging.Warn().Err(err).Msg("关闭 AccessControl 失败")
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭 Lua 引擎
|
||||
if s.luaEngine != nil {
|
||||
s.luaEngine.Close()
|
||||
logging.Info().Msg("Lua 引擎已关闭")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 多服务器模式:并行关闭所有 fasthttp.Server
|
||||
if len(s.fastServers) > 0 {
|
||||
return shutdownServers(ctx, s.fastServers)
|
||||
}
|
||||
|
||||
// 单服务器模式:关闭单个 fasthttp.Server
|
||||
if s.fastServer != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = s.fastServer.Shutdown()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProxyCacheStats 收集所有代理缓存的统计信息。
|
||||
func (s *Server) getProxyCacheStats() ProxyCacheStats {
|
||||
var total ProxyCacheStats
|
||||
for _, p := range s.proxies {
|
||||
if stats := p.GetCacheStats(); stats != nil {
|
||||
total.Entries += stats.Entries
|
||||
total.Pending += stats.Pending
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// registerStaticHandlers 注册静态文件处理器。
|
||||
//
|
||||
// 为路由器注册静态文件服务,支持多个静态目录、文件缓存和预压缩文件。
|
||||
//
|
||||
// 参数:
|
||||
// - router: 路由器实例,用于注册路由规则
|
||||
// - cfg: 服务器配置,包含静态文件和压缩设置
|
||||
func (s *Server) registerStaticHandlers(router *handler.Router, cfg *config.ServerConfig) {
|
||||
for _, static := range cfg.Static {
|
||||
path := static.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
staticHandler := handler.NewStaticHandler(
|
||||
static.Root,
|
||||
path,
|
||||
static.Index,
|
||||
true, // useSendfile
|
||||
)
|
||||
// 设置 alias(与 root 互斥)
|
||||
if static.Alias != "" {
|
||||
staticHandler.SetAlias(static.Alias)
|
||||
}
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
// 设置默认缓存 TTL (5s)
|
||||
staticHandler.SetCacheTTL(5 * time.Second)
|
||||
}
|
||||
if cfg.Compression.GzipStatic {
|
||||
// extensions: 源文件类型,为空使用默认值
|
||||
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz)
|
||||
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
|
||||
}
|
||||
|
||||
// 设置符号链接安全检查
|
||||
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
|
||||
|
||||
// 设置缓存过期时间
|
||||
if static.Expires != "" {
|
||||
staticHandler.SetExpires(static.Expires)
|
||||
}
|
||||
|
||||
// 设置 try_files 配置
|
||||
if len(static.TryFiles) > 0 {
|
||||
// 注意:tryFilesPass 需要路由器支持,当前实现传入 nil
|
||||
// 如果 tryFilesPass 为 true,需要额外处理
|
||||
staticHandler.SetTryFiles(static.TryFiles, static.TryFilesPass, router)
|
||||
}
|
||||
|
||||
// 注册路由:确保路径以 / 结尾
|
||||
routePath := path
|
||||
if !strings.HasSuffix(routePath, "/") {
|
||||
routePath += "/"
|
||||
}
|
||||
router.GET(routePath+"{filepath:*}", staticHandler.Handle)
|
||||
router.HEAD(routePath+"{filepath:*}", staticHandler.Handle)
|
||||
}
|
||||
}
|
||||
|
||||
// SetResolver 设置 DNS 解析器。
|
||||
func (s *Server) SetResolver(r resolver.Resolver) {
|
||||
s.resolver = r
|
||||
|
||||
@ -21,7 +21,6 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
@ -194,7 +193,7 @@ func (h *StatusHandler) Path() string {
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (h *StatusHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
// 步骤1: 检查 IP 访问权限
|
||||
if !h.checkAccess(ctx) {
|
||||
if !utils.CheckIPAccess(ctx, h.allowed) {
|
||||
utils.SendErrorWithDetail(ctx, utils.ErrForbidden, "Access denied")
|
||||
return
|
||||
}
|
||||
@ -509,34 +508,6 @@ func (h *StatusHandler) serveHTML(ctx *fasthttp.RequestCtx, status *Status) {
|
||||
}
|
||||
}
|
||||
|
||||
// checkAccess 检查客户端 IP 是否在允许列表中。
|
||||
//
|
||||
// 如果未配置允许列表,则允许所有访问。
|
||||
// 检查时支持代理头部(X-Forwarded-For、X-Real-IP)。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: true 表示允许访问,false 表示拒绝
|
||||
func (h *StatusHandler) checkAccess(ctx *fasthttp.RequestCtx) bool {
|
||||
// 如果没有配置允许列表,允许所有访问
|
||||
if len(h.allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
clientIP := netutil.ExtractClientIPNet(ctx)
|
||||
|
||||
// 检查是否在允许列表中
|
||||
for _, network := range h.allowed {
|
||||
if network.Contains(clientIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// collectStatus 收集服务器状态数据。
|
||||
//
|
||||
// 从服务器实例读取各项统计指标,构建状态响应对象。
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
func TestNewStatusHandler_CIDR(t *testing.T) {
|
||||
@ -896,7 +897,7 @@ func TestStatusHandler_checkAccess_AllowList(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{IP: tt.remoteIP, Port: 12345})
|
||||
|
||||
got := h.checkAccess(ctx)
|
||||
got := utils.CheckIPAccess(ctx, h.allowed)
|
||||
if got != tt.wantAccess {
|
||||
t.Errorf("expected access %v, got %v", tt.wantAccess, got)
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
// VHostManager 虚拟主机管理器。
|
||||
@ -212,7 +213,7 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler {
|
||||
if vhost := v.FindHost(host); vhost != nil {
|
||||
vhost.handler(ctx)
|
||||
} else {
|
||||
ctx.Error("Host not found", fasthttp.StatusNotFound)
|
||||
utils.SendError(ctx, utils.ErrNotFound)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -192,9 +193,7 @@ func (m *SessionTicketManager) RotateKey() error {
|
||||
// 如果有密钥文件,保存所有密钥
|
||||
if m.config.KeyFile != "" {
|
||||
if err := m.saveKeys(); err != nil {
|
||||
// 保存失败不影响运行,记录错误即可
|
||||
// 这里可以考虑添加日志
|
||||
_ = err
|
||||
logging.Warn().Err(err).Msg("Session Ticket 密钥保存失败")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/logging"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
@ -145,9 +146,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
||||
if cfg.SessionTickets.Enabled {
|
||||
sessionTicketMgr, err := NewSessionTicketManager(cfg.SessionTickets)
|
||||
if err != nil {
|
||||
// Session Tickets 初始化失败不阻止 TLS 工作
|
||||
// 可以记录日志
|
||||
_ = err
|
||||
logging.Warn().Err(err).Msg("Session Ticket 初始化失败,TLS 性能可能降级")
|
||||
} else {
|
||||
manager.sessionTicketMgr = sessionTicketMgr
|
||||
// 应用 Session Tickets 到 TLS 配置
|
||||
@ -174,9 +173,9 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
||||
issuerCert, err := x509.ParseCertificate(cert.Certificate[1])
|
||||
if err == nil {
|
||||
manager.issuers[serial] = issuerCert
|
||||
// 注册证书用于 OCSP Stapling
|
||||
// 错误会记录日志但不会阻止 TLS 工作
|
||||
_ = ocspMgr.RegisterCertificate(parsedCert, issuerCert)
|
||||
if err := ocspMgr.RegisterCertificate(parsedCert, issuerCert); err != nil {
|
||||
logging.Warn().Err(err).Msg("OCSP Stapling 注册失败")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -192,9 +191,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
|
||||
if cfg.ClientVerify.Enabled {
|
||||
clientVerifier, err := NewClientVerifier(cfg.ClientVerify)
|
||||
if err != nil {
|
||||
// 客户端验证配置失败不阻止 TLS 工作
|
||||
// 可以记录日志
|
||||
_ = err
|
||||
logging.Warn().Err(err).Msg("客户端证书验证配置失败")
|
||||
} else {
|
||||
manager.clientVerifier = clientVerifier
|
||||
clientVerifier.ConfigureTLS(tlsCfg)
|
||||
|
||||
@ -571,9 +571,22 @@ func (s *Server) handleConnection(clientConn net.Conn, _ string) {
|
||||
}
|
||||
defer func() { _ = targetConn.Close() }()
|
||||
|
||||
// 双向数据转发
|
||||
go func() { _, _ = io.Copy(targetConn, clientConn) }()
|
||||
// 双向数据转发:任一方向完成/出错时立即关闭双端连接,迫使另一方向退出
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(targetConn, clientConn)
|
||||
_ = clientConn.Close()
|
||||
_ = targetConn.Close()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(clientConn, targetConn)
|
||||
_ = targetConn.Close()
|
||||
_ = clientConn.Close()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Select 选择健康的上游目标。
|
||||
|
||||
@ -4,7 +4,16 @@
|
||||
// the scattered pattern of ctx.Error throughout the codebase.
|
||||
package utils
|
||||
|
||||
import "github.com/valyala/fasthttp"
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
// HTTPError represents an HTTP error with a message and status code.
|
||||
type HTTPError struct {
|
||||
@ -37,3 +46,57 @@ func SendErrorWithDetail(ctx *fasthttp.RequestCtx, err HTTPError, detail string)
|
||||
SendError(ctx, err)
|
||||
}
|
||||
}
|
||||
|
||||
// SendJSONError sends a JSON error response.
|
||||
func SendJSONError(ctx *fasthttp.RequestCtx, status int, errMsg string) {
|
||||
ctx.SetContentType("application/json; charset=utf-8")
|
||||
ctx.SetStatusCode(status)
|
||||
_ = json.NewEncoder(ctx).Encode(struct {
|
||||
Error string `json:"error"`
|
||||
}{Error: errMsg})
|
||||
}
|
||||
|
||||
// CheckIPAccess checks whether the client IP is in the allowed list.
|
||||
// If allowed is empty, all access is permitted.
|
||||
func CheckIPAccess(ctx *fasthttp.RequestCtx, allowed []net.IPNet) bool {
|
||||
if len(allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
clientIP := netutil.ExtractClientIPNet(ctx)
|
||||
if clientIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, network := range allowed {
|
||||
if network.Contains(clientIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckTokenAuth checks token-based authentication.
|
||||
// Returns true if auth is disabled or the token matches.
|
||||
func CheckTokenAuth(ctx *fasthttp.RequestCtx, auth config.CacheAPIAuthConfig) bool {
|
||||
if auth.Type == "" || auth.Type == "none" {
|
||||
return true
|
||||
}
|
||||
|
||||
if auth.Type == "token" {
|
||||
authHeader := ctx.Request.Header.Peek("Authorization")
|
||||
if len(authHeader) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
authStr := string(authHeader)
|
||||
if token, ok := strings.CutPrefix(authStr, "Bearer "); ok {
|
||||
return subtle.ConstantTimeCompare([]byte(token), []byte(auth.Token)) == 1
|
||||
}
|
||||
|
||||
return subtle.ConstantTimeCompare([]byte(authStr), []byte(auth.Token)) == 1
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user