refactor: 提取公共逻辑、消除重复代码、加强错误处理

- 提取 App 公共逻辑到 app_common.go,消除 app.go/app_windows.go 重复定义
- 提取 Server 生命周期/中间件/路由逻辑到独立文件(lifecycle.go/middleware_builder.go/router.go)
- 提取 Proxy 缓存处理/头部修改/目标选择到独立模块
- 提取 CheckIPAccess/CheckTokenAuth 到 utils/httperror.go,消除 status/purge 重复实现
- 修复 stream 双向转发:任一方向完成立即关闭双端,避免连接泄漏
- 修复 SSL/TLS 中静默忽略错误的问题,添加日志记录
- 统一日志消息为英文

💘 Generated with Crush

Assisted-by: GLM 5.1 via Crush <crush@charm.land>
This commit is contained in:
xfy 2026-04-28 18:00:48 +08:00
parent 6f6a8f0455
commit cf2fcca7e8
32 changed files with 1798 additions and 1929 deletions

View File

@ -4,196 +4,33 @@ package app
import (
"fmt"
"net"
"os"
"os/signal"
"syscall"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/http2"
"rua.plus/lolly/internal/http3"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/resolver"
"rua.plus/lolly/internal/server"
"rua.plus/lolly/internal/stream"
"rua.plus/lolly/internal/variable"
)
// App manages the server lifecycle, including HTTP, HTTP/3, Stream servers and graceful upgrades.
type App struct {
resv resolver.Resolver
cfg *config.Config
srv *server.Server
http3Srv *http3.Server
http2Srv *http2.Server
streamSrv *stream.Server
upgradeMgr *server.UpgradeManager
logger *logging.AppLogger
cfgPath string
pidFile string
logFile string
listeners []net.Listener
}
// NewApp creates a new App instance with the given config path.
func NewApp(cfgPath string) *App {
return &App{
cfgPath: cfgPath,
}
}
// SetPidFile sets the path to the PID file for the app.
func (a *App) SetPidFile(path string) {
a.pidFile = path
}
// SetLogFile sets the path to the log file for the app.
func (a *App) SetLogFile(path string) {
a.logFile = path
}
// Run starts the application: loads config, creates servers, and handles signals.
func (a *App) Run() int {
cfg, err := config.Load(a.cfgPath)
if err != nil {
fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err)
if err := a.loadAndValidateConfig(); err != nil {
return 1
}
a.cfg = cfg
a.logger = logging.NewAppLogger(&cfg.Logging)
variable.SetGlobalVariables(cfg.Variables.Set)
if len(cfg.Variables.Set) > 0 {
a.logger.LogStartup("全局变量已加载", map[string]string{
"count": fmt.Sprintf("%d", len(cfg.Variables.Set)),
})
}
a.initVariables()
// Inherit parent listeners when running as a graceful upgrade child.
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
a.logger.LogStartup("检测到热升级模式,继承父进程监听器", nil)
a.upgradeMgr = server.NewUpgradeManager(nil)
listeners, err := a.upgradeMgr.GetInheritedListeners()
if err == nil && len(listeners) > 0 {
a.listeners = listeners
}
}
a.inheritListeners()
a.logger.LogStartup("配置加载成功", map[string]string{"config_path": a.cfgPath})
mode := a.cfg.GetMode()
if mode == config.ServerModeMultiServer {
for i, srv := range a.cfg.Servers {
a.logger.LogStartup("监听地址", map[string]string{
"index": fmt.Sprintf("[%d]", i),
"listen": srv.Listen,
"name": srv.Name,
})
}
} else {
a.logger.LogStartup("监听地址", map[string]string{"listen": a.cfg.Servers[0].Listen})
}
if a.cfg.Resolver.Enabled {
a.resv = resolver.New(&a.cfg.Resolver)
a.logger.LogStartup("DNS 解析器已启用", map[string]string{
"addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses),
"ttl": a.cfg.Resolver.TTL().String(),
})
}
a.srv = server.New(a.cfg)
if a.resv != nil {
a.srv.SetResolver(a.resv)
}
if len(a.listeners) > 0 {
a.srv.SetListeners(a.listeners)
}
if len(a.cfg.Stream) > 0 {
a.streamSrv = stream.NewServer()
for _, sc := range a.cfg.Stream {
targets := make([]stream.TargetSpec, len(sc.Upstream.Targets))
for i, t := range sc.Upstream.Targets {
targets[i] = stream.TargetSpec{
Addr: t.Addr,
Weight: t.Weight,
}
}
if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil {
a.logger.Error().Err(err).Msg("添加 Stream 上游失败")
}
if sc.Protocol == "udp" {
if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil {
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 UDP 失败")
}
} else {
if err := a.streamSrv.ListenTCP(sc.Listen); err != nil {
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 TCP 失败")
}
}
}
go func() {
a.logger.LogStartup("Stream 服务器启动中", nil)
if err := a.streamSrv.Start(); err != nil {
a.logger.Error().Err(err).Msg("Stream 服务器启动失败")
}
}()
}
if a.cfg.HTTP3.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
tlsConfig, err := a.srv.GetTLSConfig()
if err != nil {
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/3")
} else {
a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig)
if err != nil {
a.logger.Error().Err(err).Msg("创建 HTTP/3 服务器失败")
} else {
go func() {
a.logger.LogStartup("HTTP/3 服务器启动中", map[string]string{"listen": a.cfg.HTTP3.Listen})
if err := a.http3Srv.Start(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/3 服务器启动失败")
}
}()
}
}
}
if a.cfg.Servers[0].SSL.HTTP2.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
tlsConfig, err := a.srv.GetTLSConfig()
if err != nil {
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2")
} else {
a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
if err != nil {
a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败")
} else {
go func() {
a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{
"listen": a.cfg.Servers[0].Listen,
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams),
"push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled),
})
// HTTP/2 shares the main server's listener; ALPN negotiates protocol selection.
listeners := a.srv.GetListeners()
if len(listeners) > 0 {
if err := a.http2Srv.Serve(listeners[0]); err != nil {
a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败")
}
} else {
a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器")
}
}()
}
}
}
a.logServerAddresses()
a.initResolver()
a.initServer()
a.initStreamServers()
a.initHTTP3()
a.initHTTP2()
a.upgradeMgr = server.NewUpgradeManager(a.srv)
a.srv.SetUpgradeManager(a.upgradeMgr)
@ -207,7 +44,7 @@ func (a *App) Run() int {
errChan := make(chan error, 1)
go func() {
a.logger.LogStartup("HTTP 服务器启动中", nil)
a.logger.LogStartup("Starting HTTP server", nil)
if err := a.srv.Start(); err != nil {
errChan <- err
}
@ -218,24 +55,36 @@ func (a *App) Run() int {
for {
select {
case err := <-errChan:
a.logger.Error().Err(err).Msg("服务器启动失败")
a.logger.Error().Err(err).Msg("Server failed to start")
return 1
case sig := <-sigChan:
if sig == syscall.SIGINT {
sigintCount++
if sigintCount >= 3 {
a.logger.LogShutdown("收到 3 次 SIGINT强制退出")
a.logger.LogShutdown("Received 3 SIGINT, forcing exit")
return 1
}
}
if !a.handleSignal(sig) {
a.logger.LogShutdown("服务器已停止")
a.logger.LogShutdown("Server stopped")
return 0
}
}
}
}
// inheritListeners inherits parent listeners during graceful upgrade.
func (a *App) inheritListeners() {
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
a.logger.LogStartup("Graceful upgrade mode detected, inheriting parent listeners", nil)
a.upgradeMgr = server.NewUpgradeManager(nil)
listeners, err := a.upgradeMgr.GetInheritedListeners()
if err == nil && len(listeners) > 0 {
a.listeners = listeners
}
}
}
func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) {
signal.Notify(sigChan,
syscall.SIGTERM,
@ -250,7 +99,7 @@ func (a *App) setupSignalHandlers(sigChan chan<- os.Signal) {
// handleSignal returns false to indicate the app should exit.
func (a *App) handleSignal(sig os.Signal) bool {
if a.cfg == nil {
a.logger.Error().Msg("信号处理失败: 配置为 nil使用默认超时")
a.logger.Error().Msg("Signal handling failed: config is nil, using default timeout")
a.cfg = &config.Config{
Shutdown: config.ShutdownConfig{
GracefulTimeout: 30 * time.Second,
@ -265,7 +114,7 @@ func (a *App) handleSignal(sig os.Signal) bool {
if timeout <= 0 {
timeout = 30 * time.Second
}
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("优雅停止(等待 %v", timeout))
a.logger.LogSignal("SIGQUIT", fmt.Sprintf("Graceful stop (waiting %v)", timeout))
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.GracefulStop(timeout)
@ -278,9 +127,9 @@ func (a *App) handleSignal(sig os.Signal) bool {
}
sigTyped, ok := sig.(syscall.Signal)
if !ok {
a.logger.LogSignal("unknown", "停止服务器")
a.logger.LogSignal("unknown", "Stopping server")
} else {
a.logger.LogSignal(sigName(sigTyped), "停止服务器")
a.logger.LogSignal(sigName(sigTyped), "Stopping server")
}
a.shutdownHTTP2()
a.shutdownHTTP3()
@ -288,89 +137,65 @@ func (a *App) handleSignal(sig os.Signal) bool {
return false
case syscall.SIGHUP:
a.logger.LogSignal("SIGHUP", "重载配置")
a.logger.LogSignal("SIGHUP", "Reloading config")
a.reloadConfig()
return true
case syscall.SIGUSR1:
a.logger.LogSignal("SIGUSR1", "重新打开日志")
a.logger.LogSignal("SIGUSR1", "Reopening logs")
a.reopenLogs()
return true
case syscall.SIGUSR2:
a.logger.LogSignal("SIGUSR2", "执行热升级")
a.logger.LogSignal("SIGUSR2", "Performing graceful upgrade")
a.gracefulUpgrade()
return true
default:
a.logger.Info().Str("signal", sig.String()).Msg("收到未知信号")
a.logger.Info().Str("signal", sig.String()).Msg("Received unknown signal")
return true
}
}
func (a *App) shutdownHTTP3() {
if a.http3Srv != nil {
if err := a.http3Srv.Stop(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/3 服务器关闭失败")
}
}
}
func (a *App) shutdownHTTP2() {
if a.http2Srv != nil {
if err := a.http2Srv.Stop(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败")
}
}
}
func (a *App) reloadConfig() {
newCfg, err := config.Load(a.cfgPath)
if err != nil {
a.logger.Error().Err(err).Msg("重载配置失败")
a.logger.Error().Err(err).Msg("Failed to reload config")
return
}
a.cfg = newCfg
a.logger = logging.NewAppLogger(&newCfg.Logging)
a.logger.LogStartup("配置重载成功", nil)
}
func (a *App) reopenLogs() {
if a.cfg != nil {
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
a.logger = logging.NewAppLogger(&a.cfg.Logging)
}
a.logger.LogStartup("日志已重新打开", nil)
a.logger.LogStartup("Config reloaded successfully", nil)
}
func (a *App) gracefulUpgrade() {
execPath, err := os.Executable()
if err != nil {
a.logger.Error().Err(err).Msg("获取可执行文件路径失败")
a.logger.Error().Err(err).Msg("Failed to get executable path")
return
}
if a.srv == nil {
a.logger.Error().Msg("热升级失败: 服务器实例为 nil")
a.logger.Error().Msg("Graceful upgrade failed: server instance is nil")
return
}
listeners := a.srv.GetListeners()
if len(listeners) == 0 {
a.logger.Error().Msg("热升级失败: 服务器未保存监听器(热升级当前未完全实现)")
a.logger.Info().Msg("提示: 热升级需要服务器使用手动监听器管理模式")
a.logger.Error().Msg("Graceful upgrade failed: server has no saved listeners (graceful upgrade not fully implemented)")
a.logger.Info().Msg("Hint: graceful upgrade requires the server to use manual listener management mode")
return
}
a.upgradeMgr.SetListeners(listeners)
if err := a.upgradeMgr.GracefulUpgrade(execPath); err != nil {
a.logger.Error().Err(err).Msg("热升级失败")
a.logger.Error().Err(err).Msg("Graceful upgrade failed")
return
}
a.logger.LogStartup("热升级已启动,新进程正在接管", nil)
a.logger.LogStartup("Graceful upgrade started, new process is taking over", nil)
timeout := a.cfg.Shutdown.GracefulTimeout
if timeout <= 0 {

242
internal/app/app_common.go Normal file
View File

@ -0,0 +1,242 @@
package app
import (
"fmt"
"net"
"os"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/http2"
"rua.plus/lolly/internal/http3"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/resolver"
"rua.plus/lolly/internal/server"
"rua.plus/lolly/internal/stream"
"rua.plus/lolly/internal/variable"
)
// App manages the server lifecycle, including HTTP, HTTP/3, Stream servers and graceful upgrades.
type App struct {
resv resolver.Resolver
cfg *config.Config
srv *server.Server
http3Srv *http3.Server
http2Srv *http2.Server
streamSrv *stream.Server
upgradeMgr *server.UpgradeManager
logger *logging.AppLogger
cfgPath string
pidFile string
logFile string
listeners []net.Listener
}
// NewApp creates a new App instance with the given config path.
func NewApp(cfgPath string) *App {
return &App{
cfgPath: cfgPath,
}
}
// SetPidFile sets the path to the PID file for the app.
func (a *App) SetPidFile(path string) {
a.pidFile = path
}
// SetLogFile sets the path to the log file for the app.
func (a *App) SetLogFile(path string) {
a.logFile = path
}
// loadAndValidateConfig loads configuration and initializes the logger.
func (a *App) loadAndValidateConfig() error {
cfg, err := config.Load(a.cfgPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to load config: %v\n", err)
return err
}
a.cfg = cfg
a.logger = logging.NewAppLogger(&cfg.Logging)
return nil
}
// initVariables loads global variables from configuration.
func (a *App) initVariables() {
variable.SetGlobalVariables(a.cfg.Variables.Set)
if len(a.cfg.Variables.Set) > 0 {
a.logger.LogStartup("Global variables loaded", map[string]string{
"count": fmt.Sprintf("%d", len(a.cfg.Variables.Set)),
})
}
}
// logServerAddresses logs the listening addresses based on server mode.
func (a *App) logServerAddresses() {
a.logger.LogStartup("Config loaded successfully", map[string]string{"config_path": a.cfgPath})
mode := a.cfg.GetMode()
if mode == config.ServerModeMultiServer {
for i, srv := range a.cfg.Servers {
a.logger.LogStartup("Listening address", map[string]string{
"index": fmt.Sprintf("[%d]", i),
"listen": srv.Listen,
"name": srv.Name,
})
}
} else {
a.logger.LogStartup("Listening address", map[string]string{"listen": a.cfg.Servers[0].Listen})
}
}
// initResolver initializes the DNS resolver if enabled.
func (a *App) initResolver() {
if a.cfg.Resolver.Enabled {
a.resv = resolver.New(&a.cfg.Resolver)
a.logger.LogStartup("DNS resolver enabled", map[string]string{
"addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses),
"ttl": a.cfg.Resolver.TTL().String(),
})
}
}
// initServer creates the main server and sets the resolver.
func (a *App) initServer() {
a.srv = server.New(a.cfg)
if a.resv != nil {
a.srv.SetResolver(a.resv)
}
if len(a.listeners) > 0 {
a.srv.SetListeners(a.listeners)
}
}
// initStreamServers configures and starts stream servers.
func (a *App) initStreamServers() {
if len(a.cfg.Stream) == 0 {
return
}
a.streamSrv = stream.NewServer()
for _, sc := range a.cfg.Stream {
targets := make([]stream.TargetSpec, len(sc.Upstream.Targets))
for i, t := range sc.Upstream.Targets {
targets[i] = stream.TargetSpec{
Addr: t.Addr,
Weight: t.Weight,
}
}
if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil {
a.logger.Error().Err(err).Msg("Failed to add Stream upstream")
}
if sc.Protocol == "udp" {
if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil {
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("Failed to listen on UDP")
}
} else {
if err := a.streamSrv.ListenTCP(sc.Listen); err != nil {
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("Failed to listen on TCP")
}
}
}
go func() {
a.logger.LogStartup("Starting Stream server", nil)
if err := a.streamSrv.Start(); err != nil {
a.logger.Error().Err(err).Msg("Stream server failed to start")
}
}()
}
// initHTTP3 starts the HTTP/3 server if enabled.
func (a *App) initHTTP3() {
if !a.cfg.HTTP3.Enabled || a.cfg.Servers[0].SSL.Cert == "" {
return
}
tlsConfig, err := a.srv.GetTLSConfig()
if err != nil {
a.logger.Error().Err(err).Msg("Failed to get TLS config, skipping HTTP/3")
return
}
a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig)
if err != nil {
a.logger.Error().Err(err).Msg("Failed to create HTTP/3 server")
return
}
go func() {
a.logger.LogStartup("Starting HTTP/3 server", map[string]string{"listen": a.cfg.HTTP3.Listen})
if err := a.http3Srv.Start(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/3 server failed to start")
}
}()
}
// initHTTP2 starts the HTTP/2 server if enabled.
func (a *App) initHTTP2() {
if !a.cfg.Servers[0].SSL.HTTP2.Enabled || a.cfg.Servers[0].SSL.Cert == "" {
return
}
tlsConfig, err := a.srv.GetTLSConfig()
if err != nil {
a.logger.Error().Err(err).Msg("Failed to get TLS config, skipping HTTP/2")
return
}
a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
if err != nil {
a.logger.Error().Err(err).Msg("Failed to create HTTP/2 server")
return
}
go func() {
a.logger.LogStartup("Starting HTTP/2 server", map[string]string{
"listen": a.cfg.Servers[0].Listen,
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams),
"push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled),
})
// HTTP/2 shares the main server's listener; ALPN negotiates protocol selection.
listeners := a.srv.GetListeners()
if len(listeners) > 0 {
if err := a.http2Srv.Serve(listeners[0]); err != nil {
a.logger.Error().Err(err).Msg("HTTP/2 server failed to start")
}
} else {
a.logger.Error().Msg("HTTP/2 server failed to start: no available listeners")
}
}()
}
// shutdownHTTP3 gracefully stops the HTTP/3 server.
func (a *App) shutdownHTTP3() {
if a.http3Srv != nil {
if err := a.http3Srv.Stop(); err != nil {
a.logger.Error().Err(err).Msg("Failed to shutdown HTTP/3 server")
}
}
}
// shutdownHTTP2 gracefully stops the HTTP/2 server.
func (a *App) shutdownHTTP2() {
if a.http2Srv != nil {
if err := a.http2Srv.Stop(); err != nil {
a.logger.Error().Err(err).Msg("Failed to shutdown HTTP/2 server")
}
}
}
// reopenLogs reinitializes the logger from current config.
func (a *App) reopenLogs() {
if a.cfg != nil {
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
a.logger = logging.NewAppLogger(&a.cfg.Logging)
}
a.logger.LogStartup("Logs reopened", nil)
}

View File

@ -225,7 +225,7 @@ func TestRun(t *testing.T) {
genConfig: true,
outputPath: filepath.Join(t.TempDir(), "config.yaml"),
wantExitCode: 0,
wantContains: "配置已写入:",
wantContains: "Config written to:",
},
{
name: "配置文件不存在",
@ -233,7 +233,7 @@ func TestRun(t *testing.T) {
genConfig: false,
showVersion: false,
wantExitCode: 1,
wantErrContains: "加载配置失败",
wantErrContains: "Failed to load config",
},
{
name: "generate 与 import 互斥",
@ -252,7 +252,7 @@ func TestRun(t *testing.T) {
name: "导入 nginx 配置文件不存在",
importPath: "/tmp/nginx.conf",
wantExitCode: 1,
wantErrContains: "解析 nginx 配置失败",
wantErrContains: "failed to parse nginx config",
},
}
@ -371,8 +371,8 @@ func TestGenerateConfig(t *testing.T) {
t.Errorf("exit code = %d, want 1", exitCode)
}
if !strings.Contains(stderr, "写入文件失败") {
t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr)
if !strings.Contains(stderr, "Failed to write file") {
t.Errorf("stderr should contain 'Failed to write file', actual: %q", stderr)
}
})
}
@ -1247,8 +1247,8 @@ func TestGenerateConfig_ErrorCase(t *testing.T) {
t.Errorf("exit code = %d, want 1", exitCode)
}
if !strings.Contains(stderr, "写入文件失败") {
t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr)
if !strings.Contains(stderr, "Failed to write file") {
t.Errorf("stderr should contain 'Failed to write file', actual: %q", stderr)
}
})
}

View File

@ -6,178 +6,27 @@ package app
import (
"fmt"
"net"
"os"
"os/signal"
"syscall"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/http2"
"rua.plus/lolly/internal/http3"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/resolver"
"rua.plus/lolly/internal/server"
"rua.plus/lolly/internal/stream"
"rua.plus/lolly/internal/variable"
)
// App manages the server lifecycle (Windows version).
type App struct {
resv resolver.Resolver
cfg *config.Config
srv *server.Server
http3Srv *http3.Server
http2Srv *http2.Server
streamSrv *stream.Server
logger *logging.AppLogger
cfgPath string
pidFile string
logFile string
listeners []net.Listener
upgradeMgr *server.UpgradeManager
}
func NewApp(cfgPath string) *App {
return &App{
cfgPath: cfgPath,
}
}
func (a *App) SetPidFile(path string) {
a.pidFile = path
}
func (a *App) SetLogFile(path string) {
a.logFile = path
}
// Run starts the application: loads config, creates servers, and handles signals (Windows version).
func (a *App) Run() int {
cfg, err := config.Load(a.cfgPath)
if err != nil {
fmt.Fprintf(os.Stderr, "加载配置失败: %v\n", err)
if err := a.loadAndValidateConfig(); err != nil {
return 1
}
a.cfg = cfg
a.logger = logging.NewAppLogger(&cfg.Logging)
variable.SetGlobalVariables(cfg.Variables.Set)
if len(cfg.Variables.Set) > 0 {
a.logger.LogStartup("全局变量已加载", map[string]string{
"count": fmt.Sprintf("%d", len(cfg.Variables.Set)),
})
}
a.logger.LogStartup("配置加载成功", map[string]string{"config_path": a.cfgPath})
mode := a.cfg.GetMode()
if mode == config.ServerModeMultiServer {
for i, srv := range a.cfg.Servers {
a.logger.LogStartup("监听地址", map[string]string{
"index": fmt.Sprintf("[%d]", i),
"listen": srv.Listen,
"name": srv.Name,
})
}
} else {
a.logger.LogStartup("监听地址", map[string]string{"listen": a.cfg.Servers[0].Listen})
}
if a.cfg.Resolver.Enabled {
a.resv = resolver.New(&a.cfg.Resolver)
a.logger.LogStartup("DNS 解析器已启用", map[string]string{
"addresses": fmt.Sprintf("%v", a.cfg.Resolver.Addresses),
"ttl": a.cfg.Resolver.TTL().String(),
})
}
a.srv = server.New(a.cfg)
if a.resv != nil {
a.srv.SetResolver(a.resv)
}
if len(a.cfg.Stream) > 0 {
a.streamSrv = stream.NewServer()
for _, sc := range a.cfg.Stream {
targets := make([]stream.TargetSpec, len(sc.Upstream.Targets))
for i, t := range sc.Upstream.Targets {
targets[i] = stream.TargetSpec{
Addr: t.Addr,
Weight: t.Weight,
}
}
if err := a.streamSrv.AddUpstream(sc.Listen, targets, sc.Upstream.LoadBalance, stream.HealthCheckSpec{}); err != nil {
a.logger.Error().Err(err).Msg("添加 Stream 上游失败")
}
if sc.Protocol == "udp" {
if err := a.streamSrv.ListenUDP(sc.Listen, sc.Listen, 60*time.Second); err != nil {
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 UDP 失败")
}
} else {
if err := a.streamSrv.ListenTCP(sc.Listen); err != nil {
a.logger.Error().Err(err).Str("listen", sc.Listen).Msg("监听 TCP 失败")
}
}
}
go func() {
a.logger.LogStartup("Stream 服务器启动中", nil)
if err := a.streamSrv.Start(); err != nil {
a.logger.Error().Err(err).Msg("Stream 服务器启动失败")
}
}()
}
if a.cfg.HTTP3.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
tlsConfig, err := a.srv.GetTLSConfig()
if err != nil {
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/3")
} else {
a.http3Srv, err = http3.NewServer(&a.cfg.HTTP3, a.srv.GetHandler(), tlsConfig)
if err != nil {
a.logger.Error().Err(err).Msg("创建 HTTP/3 服务器失败")
} else {
go func() {
a.logger.LogStartup("HTTP/3 服务器启动中", map[string]string{"listen": a.cfg.HTTP3.Listen})
if err := a.http3Srv.Start(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/3 服务器启动失败")
}
}()
}
}
}
if a.cfg.Servers[0].SSL.HTTP2.Enabled && a.cfg.Servers[0].SSL.Cert != "" {
tlsConfig, err := a.srv.GetTLSConfig()
if err != nil {
a.logger.Error().Err(err).Msg("获取 TLS 配置失败,跳过 HTTP/2")
} else {
a.http2Srv, err = http2.NewServer(&a.cfg.Servers[0].SSL.HTTP2, a.srv.GetHandler(), tlsConfig)
if err != nil {
a.logger.Error().Err(err).Msg("创建 HTTP/2 服务器失败")
} else {
go func() {
a.logger.LogStartup("HTTP/2 服务器启动中", map[string]string{
"listen": a.cfg.Servers[0].Listen,
"max_concurrent_streams": fmt.Sprintf("%d", a.cfg.Servers[0].SSL.HTTP2.MaxConcurrentStreams),
"push_enabled": fmt.Sprintf("%t", a.cfg.Servers[0].SSL.HTTP2.PushEnabled),
})
listeners := a.srv.GetListeners()
if len(listeners) > 0 {
if err := a.http2Srv.Serve(listeners[0]); err != nil {
a.logger.Error().Err(err).Msg("HTTP/2 服务器启动失败")
}
} else {
a.logger.Error().Msg("HTTP/2 服务器启动失败: 无可用监听器")
}
}()
}
}
}
a.initVariables()
a.logServerAddresses()
a.initResolver()
a.initServer()
a.initStreamServers()
a.initHTTP3()
a.initHTTP2()
a.upgradeMgr = server.NewUpgradeManager(a.srv)
if a.pidFile != "" {
@ -190,7 +39,7 @@ func (a *App) Run() int {
errChan := make(chan error, 1)
go func() {
a.logger.LogStartup("HTTP 服务器启动中", nil)
a.logger.LogStartup("Starting HTTP server", nil)
if err := a.srv.Start(); err != nil {
errChan <- err
}
@ -201,18 +50,18 @@ func (a *App) Run() int {
for {
select {
case err := <-errChan:
a.logger.Error().Err(err).Msg("服务器启动失败")
a.logger.Error().Err(err).Msg("Server failed to start")
return 1
case sig := <-sigChan:
if sig == syscall.SIGINT {
sigintCount++
if sigintCount >= 3 {
a.logger.LogShutdown("收到 3 次 SIGINT强制退出")
a.logger.LogShutdown("Received 3 SIGINT, forcing exit")
return 1
}
}
if !a.handleSignal(sig) {
a.logger.LogShutdown("服务器已停止")
a.logger.LogShutdown("Server stopped")
return 0
}
}
@ -234,47 +83,23 @@ func (a *App) handleSignal(sig os.Signal) bool {
if timeout <= 0 {
timeout = 5 * time.Second
}
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "停止服务器")
a.logger.LogSignal(sigName(sig.(syscall.Signal)), "Stopping server")
a.shutdownHTTP2()
a.shutdownHTTP3()
_ = a.srv.StopWithTimeout(timeout)
return false
default:
a.logger.Info().Str("signal", sig.String()).Msg("收到信号Windows 忽略)")
a.logger.Info().Str("signal", sig.String()).Msg("Received signal (ignored on Windows)")
return true
}
}
func (a *App) shutdownHTTP3() {
if a.http3Srv != nil {
if err := a.http3Srv.Stop(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/3 服务器关闭失败")
}
}
}
func (a *App) shutdownHTTP2() {
if a.http2Srv != nil {
if err := a.http2Srv.Stop(); err != nil {
a.logger.Error().Err(err).Msg("HTTP/2 服务器关闭失败")
}
}
}
func (a *App) reloadConfig() {
// Windows stub - functionality limited
}
func (a *App) reopenLogs() {
if a.cfg != nil {
logging.Init(a.cfg.Logging.Error.Level, a.cfg.Logging.Format)
a.logger = logging.NewAppLogger(&a.cfg.Logging)
}
a.logger.LogStartup("日志已重新打开", nil)
}
func (a *App) gracefulUpgrade() {
a.logger.Info().Msg("Windows 不支持热升级")
a.logger.Info().Msg("Graceful upgrade is not supported on Windows")
}
func sigName(sig syscall.Signal) string {

View File

@ -48,7 +48,7 @@ func generateConfig(outputPath string) int {
cfg := config.DefaultConfig()
yamlData, err := config.GenerateConfigYAML(cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "生成配置失败: %v\n", err)
fmt.Fprintf(os.Stderr, "Failed to generate config: %v\n", err)
return 1
}
@ -56,10 +56,10 @@ func generateConfig(outputPath string) int {
fmt.Print(string(yamlData))
} else {
if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "写入文件失败: %v\n", err)
fmt.Fprintf(os.Stderr, "Failed to write file: %v\n", err)
return 1
}
fmt.Printf("配置已写入: %s\n", outputPath)
fmt.Printf("Config written to: %s\n", outputPath)
}
return 0
}
@ -67,12 +67,12 @@ func generateConfig(outputPath string) int {
func importNginxConfig(path, outputPath string) error {
nginxCfg, err := nginx.ParseFile(path)
if err != nil {
return fmt.Errorf("解析 nginx 配置失败: %w", err)
return fmt.Errorf("failed to parse nginx config: %w", err)
}
result, err := nginx.Convert(nginxCfg)
if err != nil {
return fmt.Errorf("转换配置失败: %w", err)
return fmt.Errorf("failed to convert config: %w", err)
}
for _, w := range result.Warnings {
@ -80,26 +80,26 @@ func importNginxConfig(path, outputPath string) error {
}
if validateErr := config.Validate(result.Config); validateErr != nil {
return fmt.Errorf("转换后配置验证失败: %w", validateErr)
return fmt.Errorf("converted config validation failed: %w", validateErr)
}
yamlData, err := yaml.Marshal(result.Config)
if err != nil {
return fmt.Errorf("序列化 YAML 失败: %w", err)
return fmt.Errorf("failed to marshal YAML: %w", err)
}
if outputPath == "" {
if _, err := os.Stdout.Write(yamlData); err != nil {
return fmt.Errorf("写入标准输出失败: %w", err)
return fmt.Errorf("failed to write to stdout: %w", err)
}
} else {
if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
return fmt.Errorf("创建输出目录失败: %w", err)
return fmt.Errorf("failed to create output directory: %w", err)
}
if err := os.WriteFile(outputPath, yamlData, 0o644); err != nil {
return fmt.Errorf("写入文件失败: %w", err)
return fmt.Errorf("failed to write file: %w", err)
}
fmt.Printf("配置已写入: %s\n", outputPath)
fmt.Printf("Config written to: %s\n", outputPath)
}
return nil

View File

@ -12,7 +12,6 @@
package cache
import (
"crypto/subtle"
"encoding/json"
"hash/fnv"
"net"
@ -20,7 +19,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/utils"
)
// PurgeAPI 缓存清理 API 处理器。
@ -117,26 +116,26 @@ func (p *PurgeAPI) Path() string {
func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 仅允许 POST 方法
if string(ctx.Method()) != "POST" {
p.sendError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
utils.SendJSONError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
return
}
// 检查 IP 访问权限
if !p.checkAccess(ctx) {
p.sendError(ctx, fasthttp.StatusForbidden, "forbidden")
if !utils.CheckIPAccess(ctx, p.allowed) {
utils.SendJSONError(ctx, fasthttp.StatusForbidden, "forbidden")
return
}
// 检查认证
if !p.checkAuth(ctx) {
p.sendError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
if !utils.CheckTokenAuth(ctx, p.auth) {
utils.SendJSONError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
return
}
// 解析请求体
var req PurgeRequest
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
p.sendError(ctx, fasthttp.StatusBadRequest, "invalid request body")
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "invalid request body")
return
}
@ -147,7 +146,7 @@ func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) {
} else if req.Pattern != "" {
deleted = p.purgeByPattern(req.Pattern)
} else {
p.sendError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
return
}
@ -157,61 +156,6 @@ func (p *PurgeAPI) ServeHTTP(ctx *fasthttp.RequestCtx) {
_ = json.NewEncoder(ctx).Encode(PurgeResponse{Deleted: deleted})
}
// checkAccess 检查客户端 IP 是否在允许列表中。
func (p *PurgeAPI) checkAccess(ctx *fasthttp.RequestCtx) bool {
// 如果没有配置允许列表,允许所有访问
if len(p.allowed) == 0 {
return true
}
clientIP := p.getClientIP(ctx)
if clientIP == nil {
return false
}
// 检查是否在允许列表中
for _, network := range p.allowed {
if network.Contains(clientIP) {
return true
}
}
return false
}
// checkAuth 检查认证。
func (p *PurgeAPI) checkAuth(ctx *fasthttp.RequestCtx) bool {
// 无需认证
if p.auth.Type == "" || p.auth.Type == "none" {
return true
}
// Token 认证
if p.auth.Type == "token" {
// 从 Authorization header 获取 token
authHeader := ctx.Request.Header.Peek("Authorization")
if len(authHeader) == 0 {
return false
}
// 支持 Bearer token 格式
authStr := string(authHeader)
if token, ok := strings.CutPrefix(authStr, "Bearer "); ok {
return subtle.ConstantTimeCompare([]byte(token), []byte(p.auth.Token)) == 1
}
// 也支持直接传递 token
return subtle.ConstantTimeCompare([]byte(authStr), []byte(p.auth.Token)) == 1
}
return false
}
// getClientIP 从请求上下文提取客户端 IP。
func (p *PurgeAPI) getClientIP(ctx *fasthttp.RequestCtx) net.IP {
return netutil.ExtractClientIPNet(ctx)
}
// purgeByPath 按精确路径清理缓存。
func (p *PurgeAPI) purgeByPath(path string) int {
if p.cache == nil {
@ -323,9 +267,3 @@ func matchPattern(pattern, path string) bool {
return MatchPattern(pattern, path)
}
// sendError 发送错误响应。
func (p *PurgeAPI) sendError(ctx *fasthttp.RequestCtx, status int, errMsg string) {
ctx.SetContentType("application/json; charset=utf-8")
ctx.SetStatusCode(status)
_ = json.NewEncoder(ctx).Encode(PurgeErrorResponse{Error: errMsg})
}

View File

@ -525,7 +525,7 @@ func validateProxy(p *ProxyConfig) error {
}
// 使用 regexp.Compile 验证正则语法有效性
if _, err := regexp.Compile(p.Path); err != nil {
return fmt.Errorf("location_type 为 '%s' 时path '%s' 不是有效正则: %v", p.LocationType, p.Path, err)
return fmt.Errorf("location_type 为 '%s' 时path '%s' 不是有效正则: %w", p.LocationType, p.Path, err)
}
}
@ -1273,7 +1273,7 @@ func validateRedirectRewrite(cfg *RedirectRewriteConfig) error {
patternStr = rule.Pattern[2:] // 去掉 ~* 前缀(大小写不敏感)
}
if _, err := regexp.Compile(patternStr); err != nil {
return fmt.Errorf("redirect_rewrite.rules[%d].pattern invalid regex: %v", i, err)
return fmt.Errorf("redirect_rewrite.rules[%d].pattern invalid regex: %w", i, err)
}
}
}

View File

@ -18,6 +18,7 @@
package handler
import (
"errors"
"fmt"
"os"
"sync"
@ -106,7 +107,11 @@ func NewErrorPageManager(cfg *config.ErrorPageConfig) (*ErrorPageManager, error)
if len(loadErrors) == totalPages {
// 全部加载失败,返回错误
return nil, fmt.Errorf("所有错误页面加载失败: %v", loadErrors)
errs := make([]error, 0, len(loadErrors))
for _, e := range loadErrors {
errs = append(errs, e)
}
return nil, fmt.Errorf("所有错误页面加载失败: %w", errors.Join(errs...))
}
// 部分失败,记录警告(由调用者处理)
return manager, &PartialLoadError{Errors: loadErrors}

View File

@ -201,15 +201,9 @@ func (s *Server) Stop() error {
s.running = false
if s.http3Server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := s.http3Server.Close(); err != nil {
logging.Error().Err(err).Msg("HTTP/3 server close error")
}
// 等待服务完全停止
<-ctx.Done()
}
logging.Info().Msg("HTTP/3 server stopped")

View File

@ -25,6 +25,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/utils"
"rua.plus/lolly/internal/variable"
)
@ -179,7 +180,7 @@ func (m *Middleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandl
for ruleIndex < len(m.rules) {
// 步骤1: 检查迭代次数是否超过限制(防止无限循环)
if iterationCount >= MaxRewriteIterations {
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
utils.SendError(ctx, utils.ErrInternalError)
return
}

View File

@ -41,6 +41,7 @@ import (
"golang.org/x/crypto/bcrypt"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/utils"
)
// HashAlgorithm 表示密码哈希算法类型。
@ -381,7 +382,7 @@ func (ba *BasicAuth) extractCredentials(ctx *fasthttp.RequestCtx) (string, strin
func (ba *BasicAuth) sendAuthChallenge(ctx *fasthttp.RequestCtx) {
ctx.Response.Header.Set("WWW-Authenticate",
fmt.Sprintf("Basic realm=\"%s\", charset=\"UTF-8\"", ba.realm))
ctx.Error("Unauthorized", fasthttp.StatusUnauthorized)
utils.SendError(ctx, utils.ErrUnauthorized)
}
// AddUser 动态添加新用户。

View File

@ -0,0 +1,197 @@
package proxy
import (
"hash/fnv"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/loadbalance"
)
// buildCacheKey 构建缓存键字符串。
//
// 使用请求方法和完整请求 URI 作为缓存键。
// 该函数保留用于日志记录和调试场景。
//
// 参数:
// - ctx: FastHTTP 请求上下文
//
// 返回值:
// - string: 缓存键(格式 "METHOD:URI"
func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string {
// 使用请求方法和路径作为缓存键
return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI())
}
// buildCacheKeyHash 使用 FNV-64a 计算缓存键的 uint64 哈希值。
// 返回哈希值和原始字符串键。
// 注意:此函数会先构建字符串键再哈希,存在双重分配。
// 对于只需要哈希值的场景,使用 buildCacheKeyHashValue 代替。
func (p *Proxy) buildCacheKeyHash(ctx *fasthttp.RequestCtx) (uint64, string) {
// 构建原始 key
origKey := p.buildCacheKey(ctx)
// 使用 FNV-64a 计算哈希
h := fnv.New64a()
h.Write([]byte(origKey))
return h.Sum64(), origKey
}
// buildCacheKeyHashValue 直接计算缓存键的哈希值,零字符串分配。
// 用于只需要哈希值而不需要原始键的场景。
func (p *Proxy) buildCacheKeyHashValue(ctx *fasthttp.RequestCtx) uint64 {
h := fnv.New64a()
h.Write(ctx.Request.Header.Method())
h.Write([]byte(":"))
h.Write(ctx.Request.URI().RequestURI())
return h.Sum64()
}
// writeCachedResponse 将缓存的响应写入 FastHTTP 响应上下文。
//
// 设置响应体、状态码、响应头,并添加 X-Cache: HIT 头标记缓存命中。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - entry: 缓存条目,包含响应数据和元数据
func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) {
ctx.Response.SetBody(entry.Data)
ctx.Response.SetStatusCode(entry.Status)
for key, value := range entry.Headers {
ctx.Response.Header.Set(key, value)
}
ctx.Response.Header.Set("X-Cache", "HIT")
}
// backgroundRefresh 在后台异步刷新缓存条目。
//
// 向对应的上游目标发送请求,获取最新响应并更新缓存。
// 该方法在独立 goroutine 中运行,不阻塞主请求流程。
//
// 参数:
// - ctx: 原始 FastHTTP 请求上下文(仅用于复制请求信息)
// - target: 要刷新的后端目标
// - hashKey: 缓存哈希键
// - origKey: 缓存原始键
func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, hashKey uint64, origKey string) {
// 创建新的请求上下文副本
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// 复制原始请求
ctx.Request.CopyTo(req)
// 如果启用 Revalidate添加条件请求头
if p.config.Cache.Revalidate {
if entry, ok, _ := p.cache.Get(hashKey, origKey); ok {
if entry.LastModified != "" {
req.Header.Set("If-Modified-Since", entry.LastModified)
}
if entry.ETag != "" {
req.Header.Set("If-None-Match", entry.ETag)
}
}
}
// 获取客户端
client := p.getClient(target.URL)
if client == nil {
return
}
// 执行请求
err := client.Do(req, resp)
if err != nil {
p.cache.ReleaseLock(hashKey, err)
return
}
// 处理 304 Not Modified 响应
if resp.StatusCode() == 304 {
newHeaders := make(map[string]string)
if lm := resp.Header.Peek("Last-Modified"); len(lm) > 0 {
newHeaders["Last-Modified"] = string(lm)
}
if et := resp.Header.Peek("ETag"); len(et) > 0 {
newHeaders["ETag"] = string(et)
}
p.cache.RefreshTTL(hashKey, origKey, newHeaders)
return
}
// 提取响应头(使用 pool 复用 map
headers, ok := headersPool.Get().(map[string]string)
if !ok {
headers = make(map[string]string, 20)
}
for k := range headers {
delete(headers, k)
}
for key, value := range resp.Header.All() {
headers[string(key)] = string(value)
}
// 更新缓存
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode()))
}
// GetCache 返回代理的 ProxyCache 实例(用于 purge handler
// 如果缓存未启用,返回 nil。
func (p *Proxy) GetCache() *cache.ProxyCache {
return p.cache
}
// GetCacheStats 返回代理缓存的统计信息。
// 如果缓存未启用,返回 nil。
func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats {
if p.cache == nil {
return nil
}
stats := p.cache.Stats()
return &stats
}
// getCacheDuration 根据状态码获取缓存时间。
// 优先级CacheValid 配置 > MaxAge
//
// 映射规则:
// - 200-299: CacheValid.OK0 时继承 MaxAge
// - 301/302: CacheValid.Redirect
// - 404: CacheValid.NotFound
// - 400-499除 404: CacheValid.ClientError
// - 500-599: CacheValid.ServerError
// - 其他: 不缓存(返回 0
func (p *Proxy) getCacheDuration(statusCode int) time.Duration {
// 无 CacheValid 配置,使用 MaxAge
if p.config.CacheValid == nil {
return p.config.Cache.MaxAge
}
cv := p.config.CacheValid
switch {
case statusCode >= 200 && statusCode < 300:
if cv.OK > 0 {
return cv.OK
}
return p.config.Cache.MaxAge // 0 表示继承 MaxAge
case statusCode == 301 || statusCode == 302:
return cv.Redirect // 0 表示不缓存
case statusCode == 404:
return cv.NotFound
case statusCode >= 400 && statusCode < 500:
return cv.ClientError
case statusCode >= 500:
return cv.ServerError
default:
return 0 // 不缓存
}
}

View File

@ -0,0 +1,166 @@
package proxy
import (
"strings"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/variable"
)
// modifyRequestHeaders 在转发请求到后端之前修改请求头。
//
// 执行以下操作:
// 1. 设置 Host header 为目标主机地址
// 2. 提取并设置 X-Forwarded-For、X-Real-IP、X-Forwarded-Host、X-Forwarded-Proto
// 3. 应用自定义请求头配置(支持变量展开)
// 4. 移除配置的请求头
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - target: 选中的后端目标
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) {
headers := &ctx.Request.Header
// 设置 Host header 为目标主机
// 从 target.URL 提取 host:portHostClient 连接需要此格式)
targetHost := extractHostFromURL(target.URL)
if targetHost != "" {
headers.Set("Host", targetHost)
}
// 提取并设置 X-Forwarded 系列头
fh := ExtractForwardedHeaders(ctx)
SetForwardedHeaders(headers, fh, true)
// 从配置设置自定义请求头(支持变量展开)
if p.config.Headers.SetRequest != nil {
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetRequest {
expanded := vc.Expand(value)
if containsCRLF(expanded) {
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
continue
}
headers.Set(key, expanded)
}
}
// 移除配置的请求头
if len(p.config.Headers.Remove) > 0 {
for _, key := range p.config.Headers.Remove {
headers.Del(key)
}
}
}
// modifyResponseHeaders 在发送给客户端之前修改响应头。
//
// 应用自定义响应头配置,支持变量展开(如 $upstream_addr、$status 等)。
//
// 参数:
// - ctx: FastHTTP 请求上下文
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
respHeaders := &ctx.Response.Header
// 构建 PassResponse 集合(多处使用)
passSet := make(map[string]bool, len(p.config.Headers.PassResponse))
for _, h := range p.config.Headers.PassResponse {
passSet[h] = true
}
// PassResponse 白名单模式:仅传递列出的头部
if len(passSet) > 0 {
var toDelete []string
for key := range respHeaders.All() {
if !passSet[string(key)] {
toDelete = append(toDelete, string(key))
}
}
for _, k := range toDelete {
respHeaders.Del(k)
}
}
// HideResponse移除指定的响应头PassResponse 优先,跳过已传递的头部)
for _, key := range p.config.Headers.HideResponse {
if !passSet[key] {
respHeaders.Del(key)
}
}
// IgnoreHeaders从请求和响应中移除PassResponse 优先)
for _, key := range p.config.Headers.IgnoreHeaders {
ctx.Request.Header.Del(key)
if !passSet[key] {
respHeaders.Del(key)
}
}
// Cookie 域/路径重写
if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" {
p.rewriteCookies(respHeaders)
}
// 从配置设置自定义响应头(支持变量展开)
if p.config.Headers.SetResponse != nil {
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetResponse {
expanded := vc.Expand(value)
if containsCRLF(expanded) {
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
continue
}
respHeaders.Set(key, expanded)
}
}
}
// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。
func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) {
cookieDomain := p.config.Headers.CookieDomain
cookiePath := p.config.Headers.CookiePath
if cookieDomain == "" && cookiePath == "" {
return
}
cookies := make([]string, 0, respHeaders.Len())
for _, value := range respHeaders.Cookies() {
cookie := string(value)
if cookieDomain != "" {
cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain)
}
if cookiePath != "" {
cookie = rewriteCookieAttr(cookie, "Path", cookiePath)
}
cookies = append(cookies, cookie)
}
if len(cookies) > 0 {
respHeaders.Del("Set-Cookie")
for _, c := range cookies {
respHeaders.Add("Set-Cookie", c)
}
}
}
// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值(大小写不敏感)。
func rewriteCookieAttr(cookie, attr, newValue string) string {
prefix := attr + "="
lower := strings.ToLower(cookie)
idx := strings.Index(lower, strings.ToLower(prefix))
if idx == -1 {
return cookie
}
start := idx + len(prefix)
end := start
for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' {
end++
}
return cookie[:start] + newValue + cookie[end:]
}

View File

@ -33,10 +33,8 @@ package proxy
import (
"bytes"
"context"
"errors"
"fmt"
"hash/fnv"
"net"
urlpath "path"
"slices"
@ -46,7 +44,6 @@ import (
"time"
"github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua"
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
@ -902,352 +899,6 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
return p.selectByBalancer(ctx, targets)
}
// selectByLua 使用 Lua 脚本选择后端目标。
//
// 执行配置的 Lua 脚本,脚本可通过 ngx.balancer.set_current_peer() 选择目标。
// 如果 Lua 脚本执行失败或未调用 set_current_peer返回 nil 表示需要使用 fallback 算法。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - targets: 候选目标列表
//
// 返回值:
// - *loadbalance.Target: Lua 脚本选中的目标nil 表示未选择
// - error: Lua 执行失败时返回错误
func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) {
clientIP := netutil.ExtractClientIP(ctx)
bctx := &lua.BalancerContext{
Targets: targets,
ClientIP: clientIP,
Retries: p.config.NextUpstream.Tries,
}
// 创建 Lua 协程
coro, err := p.luaEngine.NewCoroutine(ctx)
if err != nil {
return nil, fmt.Errorf("create lua coroutine: %w", err)
}
defer coro.Close()
// 注册 balancer API
L := coro.Co
ngx, ok := L.GetGlobal("ngx").(*glua.LTable)
if !ok {
return nil, fmt.Errorf("global 'ngx' is not an LTable")
}
lua.RegisterBalancerAPI(L, bctx, ngx)
// 设置超时
timeout := p.config.BalancerByLua.Timeout
if timeout <= 0 {
timeout = 100 * time.Millisecond
}
// 执行脚本(带超时)
execCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
coro.ExecutionContext = execCtx
err = coro.ExecuteFile(p.config.BalancerByLua.Script)
if err != nil {
return nil, fmt.Errorf("execute lua script: %w", err)
}
// 检查是否调用了 set_current_peer
if !bctx.IsSelected() {
return nil, nil // 未选择,返回 nil 表示需使用 fallback
}
return bctx.Selected, nil
}
// selectByFallback 使用 fallback 负载均衡算法选择目标。
//
// 当 Lua balancer 执行失败或未选择目标时使用。
// 对于 IPHash 算法,会自动提取客户端 IP 进行哈希选择。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - targets: 候选目标列表
//
// 返回值:
// - *loadbalance.Target: fallback 算法选中的目标
func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.fallbackBalancer
p.mu.RUnlock()
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectByIP(targets, clientIP)
}
return balancer.Select(targets)
}
// selectByBalancer 使用主负载均衡器选择目标。
//
// 对于特殊算法IPHash、ConsistentHash会从请求上下文中提取
// 相应的哈希键(客户端 IP、URI、自定义 Header 等)。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - targets: 候选目标列表
//
// 返回值:
// - *loadbalance.Target: 主负载均衡器选中的目标
func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.balancer
p.mu.RUnlock()
// 对于 IPHash 负载均衡器,提取客户端 IP
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectByIP(targets, clientIP)
}
// 对于一致性哈希,根据 hash_key 配置选择
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
hashKey := ch.GetHashKey()
key := p.extractHashKey(ctx, hashKey)
return ch.SelectByKey(targets, key)
}
return balancer.Select(targets)
}
// selectTargetExcluding 选择后端目标,排除已尝试失败的目标。
// 用于故障转移场景,避免重复选择已失败的目标。
// 如果没有可用的健康目标则返回 nil。
func (p *Proxy) selectTargetExcluding(ctx *fasthttp.RequestCtx, excluded []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.balancer
targets := p.targets
p.mu.RUnlock()
if len(targets) == 0 {
return nil
}
// 对于 IPHash 负载均衡器,提取客户端 IP
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectExcludingByIP(targets, excluded, clientIP)
}
// 对于一致性哈希,根据 hash_key 配置选择
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
hashKey := ch.GetHashKey()
key := p.extractHashKey(ctx, hashKey)
return ch.SelectExcludingByKey(targets, excluded, key)
}
return balancer.SelectExcluding(targets, excluded)
}
// extractHashKey 根据一致性哈希配置提取哈希键值。
//
// 支持的 hash_key 配置:
// - "ip" 或 "": 使用客户端 IP 地址
// - "uri": 使用完整请求 URI
// - "header:NAME": 使用指定请求头的值,缺失时回退到客户端 IP
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - hashKey: 哈希键配置
//
// 返回值:
// - string: 提取的哈希键值
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
switch {
case hashKey == "ip" || hashKey == "":
return netutil.ExtractClientIP(ctx)
case hashKey == "uri":
return string(ctx.RequestURI())
case strings.HasPrefix(hashKey, "header:"):
headerName := strings.TrimPrefix(hashKey, "header:")
value := ctx.Request.Header.Peek(headerName)
if len(value) > 0 {
return string(value)
}
return netutil.ExtractClientIP(ctx) // fallback to IP
default:
return netutil.ExtractClientIP(ctx)
}
}
// getClient 返回指定目标 URL 对应的 HostClient 连接池实例。
// 如果目标 URL 不存在于连接池中,返回 nil。
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
key := targetURL
if p.config.ProxyBind != "" {
key = targetURL + "|" + p.config.ProxyBind
}
p.mu.RLock()
client := p.clients[key]
p.mu.RUnlock()
return client
}
// modifyRequestHeaders 在转发请求到后端之前修改请求头。
//
// 执行以下操作:
// 1. 设置 Host header 为目标主机地址
// 2. 提取并设置 X-Forwarded-For、X-Real-IP、X-Forwarded-Host、X-Forwarded-Proto
// 3. 应用自定义请求头配置(支持变量展开)
// 4. 移除配置的请求头
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - target: 选中的后端目标
func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalance.Target) {
headers := &ctx.Request.Header
// 设置 Host header 为目标主机
// 从 target.URL 提取 host:portHostClient 连接需要此格式)
targetHost := extractHostFromURL(target.URL)
if targetHost != "" {
headers.Set("Host", targetHost)
}
// 提取并设置 X-Forwarded 系列头
fh := ExtractForwardedHeaders(ctx)
SetForwardedHeaders(headers, fh, true)
// 从配置设置自定义请求头(支持变量展开)
if p.config.Headers.SetRequest != nil {
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetRequest {
expanded := vc.Expand(value)
if containsCRLF(expanded) {
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
continue
}
headers.Set(key, expanded)
}
}
// 移除配置的请求头
if len(p.config.Headers.Remove) > 0 {
for _, key := range p.config.Headers.Remove {
headers.Del(key)
}
}
}
// modifyResponseHeaders 在发送给客户端之前修改响应头。
//
// 应用自定义响应头配置,支持变量展开(如 $upstream_addr、$status 等)。
//
// 参数:
// - ctx: FastHTTP 请求上下文
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
respHeaders := &ctx.Response.Header
// 构建 PassResponse 集合(多处使用)
passSet := make(map[string]bool, len(p.config.Headers.PassResponse))
for _, h := range p.config.Headers.PassResponse {
passSet[h] = true
}
// PassResponse 白名单模式:仅传递列出的头部
if len(passSet) > 0 {
var toDelete []string
for key := range respHeaders.All() {
if !passSet[string(key)] {
toDelete = append(toDelete, string(key))
}
}
for _, k := range toDelete {
respHeaders.Del(k)
}
}
// HideResponse移除指定的响应头PassResponse 优先,跳过已传递的头部)
for _, key := range p.config.Headers.HideResponse {
if !passSet[key] {
respHeaders.Del(key)
}
}
// IgnoreHeaders从请求和响应中移除PassResponse 优先)
for _, key := range p.config.Headers.IgnoreHeaders {
ctx.Request.Header.Del(key)
if !passSet[key] {
respHeaders.Del(key)
}
}
// Cookie 域/路径重写
if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" {
p.rewriteCookies(respHeaders)
}
// 从配置设置自定义响应头(支持变量展开)
if p.config.Headers.SetResponse != nil {
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetResponse {
expanded := vc.Expand(value)
if containsCRLF(expanded) {
logging.Warn().Msgf("rejected CRLF in header value: %s", key)
continue
}
respHeaders.Set(key, expanded)
}
}
}
// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。
func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) {
cookieDomain := p.config.Headers.CookieDomain
cookiePath := p.config.Headers.CookiePath
if cookieDomain == "" && cookiePath == "" {
return
}
cookies := make([]string, 0, respHeaders.Len())
for _, value := range respHeaders.Cookies() {
cookie := string(value)
if cookieDomain != "" {
cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain)
}
if cookiePath != "" {
cookie = rewriteCookieAttr(cookie, "Path", cookiePath)
}
cookies = append(cookies, cookie)
}
if len(cookies) > 0 {
respHeaders.Del("Set-Cookie")
for _, c := range cookies {
respHeaders.Add("Set-Cookie", c)
}
}
}
// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值(大小写不敏感)。
func rewriteCookieAttr(cookie, attr, newValue string) string {
prefix := attr + "="
lower := strings.ToLower(cookie)
idx := strings.Index(lower, strings.ToLower(prefix))
if idx == -1 {
return cookie
}
start := idx + len(prefix)
end := start
for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' {
end++
}
return cookie[:start] + newValue + cookie[end:]
}
// isWebSocketRequest 检查请求是否为 WebSocket 升级请求。
//
// 通过检查 Connection 和 Upgrade 请求头判断:
@ -1333,151 +984,6 @@ func (p *Proxy) GetConfig() *config.ProxyConfig {
return p.config
}
// buildCacheKey 构建缓存键字符串。
//
// 使用请求方法和完整请求 URI 作为缓存键。
// 该函数保留用于日志记录和调试场景。
//
// 参数:
// - ctx: FastHTTP 请求上下文
//
// 返回值:
// - string: 缓存键(格式 "METHOD:URI"
func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string {
// 使用请求方法和路径作为缓存键
return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI())
}
// buildCacheKeyHash 使用 FNV-64a 计算缓存键的 uint64 哈希值。
// 返回哈希值和原始字符串键。
// 注意:此函数会先构建字符串键再哈希,存在双重分配。
// 对于只需要哈希值的场景,使用 buildCacheKeyHashValue 代替。
func (p *Proxy) buildCacheKeyHash(ctx *fasthttp.RequestCtx) (uint64, string) {
// 构建原始 key
origKey := p.buildCacheKey(ctx)
// 使用 FNV-64a 计算哈希
h := fnv.New64a()
h.Write([]byte(origKey))
return h.Sum64(), origKey
}
// buildCacheKeyHashValue 直接计算缓存键的哈希值,零字符串分配。
// 用于只需要哈希值而不需要原始键的场景。
func (p *Proxy) buildCacheKeyHashValue(ctx *fasthttp.RequestCtx) uint64 {
h := fnv.New64a()
h.Write(ctx.Request.Header.Method())
h.Write([]byte(":"))
h.Write(ctx.Request.URI().RequestURI())
return h.Sum64()
}
// writeCachedResponse 将缓存的响应写入 FastHTTP 响应上下文。
//
// 设置响应体、状态码、响应头,并添加 X-Cache: HIT 头标记缓存命中。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - entry: 缓存条目,包含响应数据和元数据
func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) {
ctx.Response.SetBody(entry.Data)
ctx.Response.SetStatusCode(entry.Status)
for key, value := range entry.Headers {
ctx.Response.Header.Set(key, value)
}
ctx.Response.Header.Set("X-Cache", "HIT")
}
// backgroundRefresh 在后台异步刷新缓存条目。
//
// 向对应的上游目标发送请求,获取最新响应并更新缓存。
// 该方法在独立 goroutine 中运行,不阻塞主请求流程。
//
// 参数:
// - ctx: 原始 FastHTTP 请求上下文(仅用于复制请求信息)
// - target: 要刷新的后端目标
// - hashKey: 缓存哈希键
// - origKey: 缓存原始键
func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, hashKey uint64, origKey string) {
// 创建新的请求上下文副本
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// 复制原始请求
ctx.Request.CopyTo(req)
// 如果启用 Revalidate添加条件请求头
if p.config.Cache.Revalidate {
if entry, ok, _ := p.cache.Get(hashKey, origKey); ok {
if entry.LastModified != "" {
req.Header.Set("If-Modified-Since", entry.LastModified)
}
if entry.ETag != "" {
req.Header.Set("If-None-Match", entry.ETag)
}
}
}
// 获取客户端
client := p.getClient(target.URL)
if client == nil {
return
}
// 执行请求
err := client.Do(req, resp)
if err != nil {
p.cache.ReleaseLock(hashKey, err)
return
}
// 处理 304 Not Modified 响应
if resp.StatusCode() == 304 {
newHeaders := make(map[string]string)
if lm := resp.Header.Peek("Last-Modified"); len(lm) > 0 {
newHeaders["Last-Modified"] = string(lm)
}
if et := resp.Header.Peek("ETag"); len(et) > 0 {
newHeaders["ETag"] = string(et)
}
p.cache.RefreshTTL(hashKey, origKey, newHeaders)
return
}
// 提取响应头(使用 pool 复用 map
headers, ok := headersPool.Get().(map[string]string)
if !ok {
headers = make(map[string]string, 20)
}
for k := range headers {
delete(headers, k)
}
for key, value := range resp.Header.All() {
headers[string(key)] = string(value)
}
// 更新缓存
p.cache.Set(hashKey, origKey, resp.Body(), headers, resp.StatusCode(), p.getCacheDuration(resp.StatusCode()))
}
// GetCache 返回代理的 ProxyCache 实例(用于 purge handler
// 如果缓存未启用,返回 nil。
func (p *Proxy) GetCache() *cache.ProxyCache {
return p.cache
}
// GetCacheStats 返回代理缓存的统计信息。
// 如果缓存未启用,返回 nil。
func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats {
if p.cache == nil {
return nil
}
stats := p.cache.Stats()
return &stats
}
// extractHostFromURL 从 URL 字符串中提取 host:port 部分。
//
// 移除 http:// 或 https:// 协议前缀,以及路径部分,
@ -1505,44 +1011,3 @@ func extractHostFromURL(urlStr string) string {
return host
}
// getCacheDuration 根据状态码获取缓存时间。
// 优先级CacheValid 配置 > MaxAge
//
// 映射规则:
// - 200-299: CacheValid.OK0 时继承 MaxAge
// - 301/302: CacheValid.Redirect
// - 404: CacheValid.NotFound
// - 400-499除 404: CacheValid.ClientError
// - 500-599: CacheValid.ServerError
// - 其他: 不缓存(返回 0
func (p *Proxy) getCacheDuration(statusCode int) time.Duration {
// 无 CacheValid 配置,使用 MaxAge
if p.config.CacheValid == nil {
return p.config.Cache.MaxAge
}
cv := p.config.CacheValid
switch {
case statusCode >= 200 && statusCode < 300:
if cv.OK > 0 {
return cv.OK
}
return p.config.Cache.MaxAge // 0 表示继承 MaxAge
case statusCode == 301 || statusCode == 302:
return cv.Redirect // 0 表示不缓存
case statusCode == 404:
return cv.NotFound
case statusCode >= 400 && statusCode < 500:
return cv.ClientError
case statusCode >= 500:
return cv.ServerError
default:
return 0 // 不缓存
}
}

View File

@ -30,6 +30,7 @@ import (
"strings"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/logging"
)
// tlsVersionMap TLS 版本字符串到 tls 常量的映射表。
@ -100,7 +101,9 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls
tlsCfg.Certificates = []tls.Certificate{cert}
}
// TLS 版本配置
// TLS 版本配置:默认 MinVersion = TLS 1.2
tlsCfg.MinVersion = tls.VersionTLS12
if cfg.MinVersion != "" {
version, ok := tlsVersionMap[strings.ToUpper(cfg.MinVersion)]
if !ok {
@ -109,6 +112,11 @@ func CreateTLSConfig(cfg *config.ProxySSLConfig, defaultServerName string) (*tls
tlsCfg.MinVersion = version
}
// 警告TLS 1.0/1.1 已不安全,不应在生产环境使用
if tlsCfg.MinVersion < tls.VersionTLS12 {
logging.Warn().Msgf("上游 TLS MinVersion 设置为 %s低于 TLS 1.2),存在安全风险", cfg.MinVersion)
}
if cfg.MaxVersion != "" {
version, ok := tlsVersionMap[strings.ToUpper(cfg.MaxVersion)]
if !ok {

View File

@ -0,0 +1,204 @@
package proxy
import (
"context"
"fmt"
"strings"
"time"
"github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/netutil"
)
// selectByLua 使用 Lua 脚本选择后端目标。
//
// 执行配置的 Lua 脚本,脚本可通过 ngx.balancer.set_current_peer() 选择目标。
// 如果 Lua 脚本执行失败或未调用 set_current_peer返回 nil 表示需要使用 fallback 算法。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - targets: 候选目标列表
//
// 返回值:
// - *loadbalance.Target: Lua 脚本选中的目标nil 表示未选择
// - error: Lua 执行失败时返回错误
func (p *Proxy) selectByLua(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) (*loadbalance.Target, error) {
clientIP := netutil.ExtractClientIP(ctx)
bctx := &lua.BalancerContext{
Targets: targets,
ClientIP: clientIP,
Retries: p.config.NextUpstream.Tries,
}
// 创建 Lua 协程
coro, err := p.luaEngine.NewCoroutine(ctx)
if err != nil {
return nil, fmt.Errorf("create lua coroutine: %w", err)
}
defer coro.Close()
// 注册 balancer API
L := coro.Co
ngx, ok := L.GetGlobal("ngx").(*glua.LTable)
if !ok {
return nil, fmt.Errorf("global 'ngx' is not an LTable")
}
lua.RegisterBalancerAPI(L, bctx, ngx)
// 设置超时
timeout := p.config.BalancerByLua.Timeout
if timeout <= 0 {
timeout = 100 * time.Millisecond
}
// 执行脚本(带超时)
execCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
coro.ExecutionContext = execCtx
err = coro.ExecuteFile(p.config.BalancerByLua.Script)
if err != nil {
return nil, fmt.Errorf("execute lua script: %w", err)
}
// 检查是否调用了 set_current_peer
if !bctx.IsSelected() {
return nil, nil // 未选择,返回 nil 表示需使用 fallback
}
return bctx.Selected, nil
}
// selectByFallback 使用 fallback 负载均衡算法选择目标。
//
// 当 Lua balancer 执行失败或未选择目标时使用。
// 对于 IPHash 算法,会自动提取客户端 IP 进行哈希选择。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - targets: 候选目标列表
//
// 返回值:
// - *loadbalance.Target: fallback 算法选中的目标
func (p *Proxy) selectByFallback(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.fallbackBalancer
p.mu.RUnlock()
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectByIP(targets, clientIP)
}
return balancer.Select(targets)
}
// selectByBalancer 使用主负载均衡器选择目标。
//
// 对于特殊算法IPHash、ConsistentHash会从请求上下文中提取
// 相应的哈希键(客户端 IP、URI、自定义 Header 等)。
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - targets: 候选目标列表
//
// 返回值:
// - *loadbalance.Target: 主负载均衡器选中的目标
func (p *Proxy) selectByBalancer(ctx *fasthttp.RequestCtx, targets []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.balancer
p.mu.RUnlock()
// 对于 IPHash 负载均衡器,提取客户端 IP
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectByIP(targets, clientIP)
}
// 对于一致性哈希,根据 hash_key 配置选择
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
hashKey := ch.GetHashKey()
key := p.extractHashKey(ctx, hashKey)
return ch.SelectByKey(targets, key)
}
return balancer.Select(targets)
}
// selectTargetExcluding 选择后端目标,排除已尝试失败的目标。
// 用于故障转移场景,避免重复选择已失败的目标。
// 如果没有可用的健康目标则返回 nil。
func (p *Proxy) selectTargetExcluding(ctx *fasthttp.RequestCtx, excluded []*loadbalance.Target) *loadbalance.Target {
p.mu.RLock()
balancer := p.balancer
targets := p.targets
p.mu.RUnlock()
if len(targets) == 0 {
return nil
}
// 对于 IPHash 负载均衡器,提取客户端 IP
if ipHash, ok := balancer.(*loadbalance.IPHash); ok {
clientIP := netutil.ExtractClientIP(ctx)
return ipHash.SelectExcludingByIP(targets, excluded, clientIP)
}
// 对于一致性哈希,根据 hash_key 配置选择
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
hashKey := ch.GetHashKey()
key := p.extractHashKey(ctx, hashKey)
return ch.SelectExcludingByKey(targets, excluded, key)
}
return balancer.SelectExcluding(targets, excluded)
}
// extractHashKey 根据一致性哈希配置提取哈希键值。
//
// 支持的 hash_key 配置:
// - "ip" 或 "": 使用客户端 IP 地址
// - "uri": 使用完整请求 URI
// - "header:NAME": 使用指定请求头的值,缺失时回退到客户端 IP
//
// 参数:
// - ctx: FastHTTP 请求上下文
// - hashKey: 哈希键配置
//
// 返回值:
// - string: 提取的哈希键值
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
switch {
case hashKey == "ip" || hashKey == "":
return netutil.ExtractClientIP(ctx)
case hashKey == "uri":
return string(ctx.RequestURI())
case strings.HasPrefix(hashKey, "header:"):
headerName := strings.TrimPrefix(hashKey, "header:")
value := ctx.Request.Header.Peek(headerName)
if len(value) > 0 {
return string(value)
}
return netutil.ExtractClientIP(ctx) // fallback to IP
default:
return netutil.ExtractClientIP(ctx)
}
}
// getClient 返回指定目标 URL 对应的 HostClient 连接池实例。
// 如果目标 URL 不存在于连接池中,返回 nil。
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
key := targetURL
if p.config.ProxyBind != "" {
key = targetURL + "|" + p.config.ProxyBind
}
p.mu.RLock()
client := p.clients[key]
p.mu.RUnlock()
return client
}

View File

@ -142,10 +142,10 @@ func initLuaEngine(luaCfg *config.LuaMiddlewareConfig) (*lua.LuaEngine, error) {
engine, err := lua.NewEngine(engineCfg)
if err != nil {
return nil, fmt.Errorf("初始化 Lua 引擎失败: %w", err)
return nil, fmt.Errorf("failed to initialize Lua engine: %w", err)
}
logging.Info().Msg("Lua 引擎已启动")
logging.Info().Msg("Lua engine started")
return engine, nil
}
@ -169,14 +169,14 @@ func initErrorPageManager(errorPageCfg *config.ErrorPageConfig) (*handler.ErrorP
if err != nil {
// 检查是否是部分加载失败
if _, ok := err.(*handler.PartialLoadError); ok {
logging.Warn().Msg("部分错误页面加载失败: " + err.Error())
logging.Warn().Msg("Some error pages failed to load: " + err.Error())
// 返回部分加载的管理器
return manager, nil
}
// 全部加载失败
return nil, fmt.Errorf("加载错误页面失败: %w", err)
return nil, fmt.Errorf("failed to load error pages: %w", err)
}
logging.Info().Msg("错误页面管理器已启动")
logging.Info().Msg("Error page manager started")
return manager, nil
}

View File

@ -0,0 +1,222 @@
package server
import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/logging"
)
// cleanupResources 清理服务器资源。
//
// 停止 Goroutine 池、健康检查器关闭访问日志、TLS 管理器、
// AccessControl 和 Lua 引擎。由 StopWithTimeout 和 GracefulStop 共用。
func (s *Server) cleanupResources() {
// 停止 Goroutine 池
if s.pool != nil {
s.pool.Stop()
}
// 停止健康检查器
for _, hc := range s.healthCheckers {
hc.Stop()
}
// 关闭访问日志
if s.accessLogMiddleware != nil {
_ = s.accessLogMiddleware.Close()
}
// 关闭 TLS 管理器
if s.tlsManager != nil {
s.tlsManager.Close()
}
// 关闭 AccessControl (释放 GeoIP 资源)
if s.accessControl != nil {
if err := s.accessControl.Close(); err != nil {
logging.Warn().Err(err).Msg("Failed to close AccessControl")
}
}
// 关闭 Lua 引擎
if s.luaEngine != nil {
s.luaEngine.Close()
logging.Info().Msg("Lua engine closed")
}
}
// shutdownServers 并行关闭多个 fasthttp.Server 实例。
//
// 使用 goroutine 并行关闭所有服务器,收集所有错误并返回聚合错误。
// 部分服务器关闭失败不会影响其他服务器的关闭。
//
// 参数:
// - ctx: 关闭上下文,用于控制超时和取消
// - servers: 要关闭的 fasthttp.Server 实例列表
//
// 返回值:
// - error: 聚合错误,无错误或全部成功时返回 nil
func shutdownServers(ctx context.Context, servers []*fasthttp.Server) error {
// 防御性检查nil ctx 使用默认背景
if ctx == nil {
ctx = context.Background()
}
if len(servers) == 0 {
return nil
}
var (
mu sync.Mutex
errs []error
wg sync.WaitGroup
)
for _, srv := range servers {
if srv == nil {
continue
}
wg.Add(1)
go func(s *fasthttp.Server) {
defer wg.Done()
if err := s.Shutdown(); err != nil {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
}(srv)
}
// 等待所有关闭完成或上下文取消
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
if len(errs) == 0 {
return nil
}
if len(errs) == 1 {
return errs[0]
}
msgs := make([]string, len(errs))
for i, e := range errs {
msgs[i] = e.Error()
}
return fmt.Errorf("failed to close servers: %d errors: %s", len(errs), strings.Join(msgs, "; "))
case <-ctx.Done():
return ctx.Err()
}
}
// StopWithTimeout 快速停止服务器(支持自定义超时)。
//
// 立即停止服务器,不等待正在处理的请求完成。
// 停止所有健康检查器和访问日志中间件。
//
// 参数:
// - timeout: 快速关闭的最大等待时间
//
// 返回值:
// - error: 停止过程中遇到的错误
//
// 注意事项:
// - 对于生产环境,建议使用 GracefulStop 实现优雅关闭
// - timeout <= 0 时会使用默认 5s 超时
func (s *Server) StopWithTimeout(timeout time.Duration) error {
// 防御性检查:如果 timeout <= 0使用默认值
if timeout <= 0 {
timeout = 5 * time.Second
}
s.running = false
s.cleanupResources()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 多服务器模式:并行关闭所有 fasthttp.Server
if len(s.fastServers) > 0 {
return shutdownServers(ctx, s.fastServers)
}
// 单服务器模式:关闭单个 fasthttp.Server
if s.fastServer != nil {
done := make(chan struct{})
go func() {
_ = s.fastServer.Shutdown()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
// GracefulStop 优雅停止服务器。
//
// 等待正在处理的请求完成后再停止服务器,确保连接正常关闭。
// 如果超时时间到达仍有请求未完成,将返回超时错误。
//
// 参数:
// - timeout: 优雅关闭的最大等待时间
//
// 返回值:
// - error: 停止过程中遇到的错误,超时返回 context.DeadlineExceeded
//
// 注意事项:
// - 推荐在生产环境使用此方法关闭服务器
// - 超时后会强制关闭,可能导致部分请求中断
func (s *Server) GracefulStop(timeout time.Duration) error {
s.running = false
s.cleanupResources()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 多服务器模式:并行关闭所有 fasthttp.Server
if len(s.fastServers) > 0 {
return shutdownServers(ctx, s.fastServers)
}
// 单服务器模式:关闭单个 fasthttp.Server
if s.fastServer != nil {
done := make(chan struct{})
go func() {
_ = s.fastServer.Shutdown()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
// getProxyCacheStats 收集所有代理缓存的统计信息。
func (s *Server) getProxyCacheStats() ProxyCacheStats {
var total ProxyCacheStats
for _, p := range s.proxies {
if stats := p.GetCacheStats(); stats != nil {
total.Entries += stats.Entries
total.Pending += stats.Pending
}
}
return total
}

View File

@ -75,7 +75,7 @@ func TestBuildLuaMiddlewares_InvalidPhase(t *testing.T) {
_, err = s.buildLuaMiddlewares(luaCfg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "无效的阶段")
assert.Contains(t, err.Error(), "invalid phase")
}
// TestBuildLuaMiddlewares_WithTimeout 测试超时配置

View File

@ -0,0 +1,246 @@
package server
import (
"fmt"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/middleware/accesslog"
"rua.plus/lolly/internal/middleware/bodylimit"
"rua.plus/lolly/internal/middleware/compression"
"rua.plus/lolly/internal/middleware/errorintercept"
"rua.plus/lolly/internal/middleware/rewrite"
"rua.plus/lolly/internal/middleware/security"
)
// buildMiddlewareChain 构建中间件链。
//
// 根据服务器配置按顺序构建中间件链,顺序为:
//
// AccessLog -> AccessControl -> RateLimiter -> BasicAuth -> Rewrite -> Compression -> SecurityHeaders
//
// 参数:
// - serverCfg: 单个服务器的配置对象
//
// 返回值:
// - *middleware.Chain: 构建完成的中间件链
// - error: 构建过程中遇到的错误,如中间件创建失败
//
// 注意事项:
// - 各中间件按顺序依次包装请求处理器
// - 未配置的中间件不会添加到链中
func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) {
var middlewares []middleware.Middleware
// 1. AccessLog (已集成)
s.accessLogMiddleware = accesslog.New(&s.config.Logging)
middlewares = append(middlewares, s.accessLogMiddleware)
// 2. Security: AccessControl (IP 访问控制)
if len(serverCfg.Security.Access.Allow) > 0 || len(serverCfg.Security.Access.Deny) > 0 {
ac, err := security.NewAccessControl(&serverCfg.Security.Access)
if err != nil {
return nil, fmt.Errorf("failed to create access control middleware: %w", err)
}
middlewares = append(middlewares, ac)
s.accessControl = ac
}
// 3. Security: RateLimiter (速率限制)
if serverCfg.Security.RateLimit.RequestRate > 0 {
rl, err := security.NewRateLimiter(&serverCfg.Security.RateLimit)
if err != nil {
return nil, fmt.Errorf("failed to create rate limiter middleware: %w", err)
}
middlewares = append(middlewares, rl)
}
// 3.5 Security: ConnLimiter (连接数限制)
if serverCfg.Security.RateLimit.ConnLimit > 0 {
cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key)
if err != nil {
return nil, fmt.Errorf("failed to create connection limiter middleware: %w", err)
}
middlewares = append(middlewares, cl.Middleware())
}
// 4. Security: BasicAuth (认证)
if len(serverCfg.Security.Auth.Users) > 0 {
auth, err := security.NewBasicAuth(&serverCfg.Security.Auth)
if err != nil {
return nil, fmt.Errorf("failed to create auth middleware: %w", err)
}
middlewares = append(middlewares, auth)
}
// 4.3 Security: AuthRequest (外部认证子请求)
if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" {
authReq, err := security.NewAuthRequest(serverCfg.Security.AuthRequest)
if err != nil {
return nil, fmt.Errorf("failed to create auth request middleware: %w", err)
}
middlewares = append(middlewares, authReq)
}
// 4.5 BodyLimit (请求体大小限制)
// 创建 bodylimit 中间件,使用全局配置或默认值
bodyLimitMiddleware := bodylimit.NewWithDefault()
if serverCfg.ClientMaxBodySize != "" {
bl, err := bodylimit.New(serverCfg.ClientMaxBodySize)
if err != nil {
return nil, fmt.Errorf("failed to create body limit middleware: %w", err)
}
bodyLimitMiddleware = bl
}
// 添加路径级别的限制配置
for i := range serverCfg.Proxy {
if serverCfg.Proxy[i].ClientMaxBodySize != "" {
if err := bodyLimitMiddleware.AddPathLimit(
serverCfg.Proxy[i].Path,
serverCfg.Proxy[i].ClientMaxBodySize,
); err != nil {
return nil, fmt.Errorf("failed to add path body limit: %w", err)
}
}
}
middlewares = append(middlewares, bodyLimitMiddleware)
// 5. Rewrite (URL 重写)
if len(serverCfg.Rewrite) > 0 {
rw, err := rewrite.New(serverCfg.Rewrite)
if err != nil {
return nil, fmt.Errorf("failed to create rewrite middleware: %w", err)
}
middlewares = append(middlewares, rw)
}
// 6. Compression (响应压缩)
if serverCfg.Compression.Type != "" {
comp, err := compression.New(&serverCfg.Compression)
if err != nil {
return nil, fmt.Errorf("failed to create compression middleware: %w", err)
}
middlewares = append(middlewares, comp)
}
// 7. SecurityHeaders (安全头部)
// 如果有任何安全头部配置,则启用
if serverCfg.Security.Headers.XFrameOptions != "" ||
serverCfg.Security.Headers.XContentTypeOptions != "" ||
serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
serverCfg.Security.Headers.ReferrerPolicy != "" ||
serverCfg.Security.Headers.PermissionsPolicy != "" {
headers := security.NewHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS)
middlewares = append(middlewares, headers)
}
// 8. ErrorIntercept (错误页面拦截)
// 如果配置了错误页面,添加错误拦截中间件
if s.errorPageManager != nil && s.errorPageManager.IsConfigured() {
ei := errorintercept.New(s.errorPageManager)
middlewares = append(middlewares, ei)
}
// Lua 中间件(可选)
if s.luaEngine != nil && serverCfg.Lua != nil && serverCfg.Lua.Enabled {
luaMiddlewares, err := s.buildLuaMiddlewares(serverCfg.Lua)
if err != nil {
return nil, fmt.Errorf("failed to create Lua middleware: %w", err)
}
middlewares = append(middlewares, luaMiddlewares...)
}
return middleware.NewChain(middlewares...), nil
}
// buildLuaMiddlewares 根据 Lua 配置创建中间件。
//
// 根据 Scripts 配置创建 LuaMiddleware 或 MultiPhaseLuaMiddleware。
// 支持单脚本和多阶段脚本配置。
//
// 参数:
// - luaCfg: Lua 配置对象
//
// 返回值:
// - []middleware.Middleware: 创建的中间件列表
// - error: 创建过程中遇到的错误
func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) {
if s.luaEngine == nil {
return nil, nil
}
// 按阶段分组脚本
phaseScripts := make(map[string][]config.LuaScriptConfig)
for _, script := range luaCfg.Scripts {
// 默认启用
enabled := script.Enabled
if !enabled && script.Timeout == 0 && script.Path != "" {
enabled = true // 零值时默认启用
}
if enabled {
phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script)
}
}
var middlewares []middleware.Middleware
// 为每个阶段创建中间件
for phase, scripts := range phaseScripts {
if len(scripts) == 0 {
continue
}
// 单脚本:直接创建 LuaMiddleware
if len(scripts) == 1 {
script := scripts[0]
luaPhase, err := lua.ParsePhase(phase)
if err != nil {
return nil, fmt.Errorf("invalid phase '%s': %w", phase, err)
}
timeout := script.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
cfg := lua.LuaMiddlewareConfig{
ScriptPath: script.Path,
Phase: luaPhase,
Timeout: timeout,
Name: fmt.Sprintf("lua-%s", phase),
}
mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg)
if err != nil {
return nil, fmt.Errorf("failed to create Lua middleware (phase=%s): %w", phase, err)
}
middlewares = append(middlewares, mw)
} else {
// 多脚本:创建 MultiPhaseLuaMiddleware
multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase))
for _, script := range scripts {
luaPhase, err := lua.ParsePhase(phase)
if err != nil {
return nil, fmt.Errorf("invalid phase '%s': %w", phase, err)
}
timeout := script.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
err = multi.AddPhase(luaPhase, script.Path, timeout)
if err != nil {
return nil, fmt.Errorf("failed to add Lua phase (phase=%s): %w", phase, err)
}
}
middlewares = append(middlewares, multi)
}
}
return middlewares, nil
}

View File

@ -76,7 +76,7 @@ func NewPprofHandler(cfg *config.PprofConfig) (*PprofHandler, error) {
// 尝试解析 CIDR
_, net, err := net.ParseCIDR(ipStr)
if err != nil {
return nil, fmt.Errorf("解析 IP/CIDR 失败: %s: %w", ipStr, err)
return nil, fmt.Errorf("failed to parse IP/CIDR: %s: %w", ipStr, err)
}
h.allowedNets = append(h.allowedNets, net)
}

View File

@ -7,12 +7,11 @@ import (
"encoding/json"
"net"
"net/netip"
"strings"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/utils"
)
// PurgeHandler 缓存清理 API 处理器。
@ -104,26 +103,26 @@ func (h *PurgeHandler) Path() string {
func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 仅允许 POST 方法
if string(ctx.Method()) != "POST" {
h.sendError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
utils.SendJSONError(ctx, fasthttp.StatusMethodNotAllowed, "method not allowed")
return
}
// 检查 IP 访问权限
if !h.checkAccess(ctx) {
h.sendError(ctx, fasthttp.StatusForbidden, "forbidden")
if !utils.CheckIPAccess(ctx, h.allowed) {
utils.SendJSONError(ctx, fasthttp.StatusForbidden, "forbidden")
return
}
// 检查认证
if !h.checkAuth(ctx) {
h.sendError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
if !utils.CheckTokenAuth(ctx, h.auth) {
utils.SendJSONError(ctx, fasthttp.StatusUnauthorized, "unauthorized")
return
}
// 解析请求体
var req cache.PurgeRequest
if err := json.Unmarshal(ctx.PostBody(), &req); err != nil {
h.sendError(ctx, fasthttp.StatusBadRequest, "invalid request body")
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "invalid request body")
return
}
@ -134,7 +133,7 @@ func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
} else if req.Pattern != "" {
deleted = h.purgeByPattern(req.Pattern, req.Method)
} else {
h.sendError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
utils.SendJSONError(ctx, fasthttp.StatusBadRequest, "missing path or pattern")
return
}
@ -144,55 +143,6 @@ func (h *PurgeHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
_ = json.NewEncoder(ctx).Encode(cache.PurgeResponse{Deleted: deleted})
}
// checkAccess 检查客户端 IP 是否在允许列表中。
func (h *PurgeHandler) checkAccess(ctx *fasthttp.RequestCtx) bool {
// 如果没有配置允许列表,允许所有访问
if len(h.allowed) == 0 {
return true
}
clientIP := netutil.ExtractClientIPNet(ctx)
if clientIP == nil {
return false
}
// 检查是否在允许列表中
for _, network := range h.allowed {
if network.Contains(clientIP) {
return true
}
}
return false
}
// checkAuth 检查认证。
func (h *PurgeHandler) checkAuth(ctx *fasthttp.RequestCtx) bool {
// 无需认证
if h.auth.Type == "" || h.auth.Type == "none" {
return true
}
// Token 认证
if h.auth.Type == "token" {
authHeader := ctx.Request.Header.Peek("Authorization")
if len(authHeader) == 0 {
return false
}
authStr := string(authHeader)
// 支持 Bearer token 格式
if token, ok := strings.CutPrefix(authStr, "Bearer "); ok {
return token == h.auth.Token
}
// 也支持直接传递 token
return authStr == h.auth.Token
}
return false
}
// purgeByPath 按精确路径清理缓存。
func (h *PurgeHandler) purgeByPath(path string, method string) int {
if h.server == nil {
@ -229,13 +179,6 @@ func (h *PurgeHandler) purgeByPattern(pattern string, method string) int {
return deleted
}
// sendError 发送错误响应。
func (h *PurgeHandler) sendError(ctx *fasthttp.RequestCtx, status int, errMsg string) {
ctx.SetContentType("application/json; charset=utf-8")
ctx.SetStatusCode(status)
_ = json.NewEncoder(ctx).Encode(cache.PurgeErrorResponse{Error: errMsg})
}
// PurgeByPathForTest 测试用的导出方法。
func (h *PurgeHandler) PurgeByPathForTest(path string, method string) int {
return h.purgeByPath(path, method)

View File

@ -24,6 +24,7 @@ import (
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/proxy"
"rua.plus/lolly/internal/utils"
)
func TestPurgeHandler_Path(t *testing.T) {
@ -238,7 +239,7 @@ func TestPurgeHandler_checkAccess(t *testing.T) {
if len(h.allowed) == 0 {
// 无白名单时应允许所有访问
if !h.checkAccess(nil) {
if !utils.CheckIPAccess(nil, h.allowed) {
t.Error("expected access to be true when no allow list configured")
}
return
@ -281,7 +282,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
}
ctx := &fasthttp.RequestCtx{}
if !h.checkAuth(ctx) {
if !utils.CheckTokenAuth(ctx, h.auth) {
t.Error("expected auth to pass when no auth configured")
}
})
@ -301,7 +302,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
}
ctx := &fasthttp.RequestCtx{}
if !h.checkAuth(ctx) {
if !utils.CheckTokenAuth(ctx, h.auth) {
t.Error("expected auth to pass when type is none")
}
})
@ -323,7 +324,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("Authorization", "Bearer secret-token")
if !h.checkAuth(ctx) {
if !utils.CheckTokenAuth(ctx, h.auth) {
t.Error("expected auth to pass with correct Bearer token")
}
})
@ -345,7 +346,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("Authorization", "secret-token")
if !h.checkAuth(ctx) {
if !utils.CheckTokenAuth(ctx, h.auth) {
t.Error("expected auth to pass with correct direct token")
}
})
@ -367,7 +368,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("Authorization", "Bearer wrong-token")
if h.checkAuth(ctx) {
if utils.CheckTokenAuth(ctx, h.auth) {
t.Error("expected auth to fail with wrong token")
}
})
@ -388,7 +389,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
if h.checkAuth(ctx) {
if utils.CheckTokenAuth(ctx, h.auth) {
t.Error("expected auth to fail when Authorization header is missing")
}
})
@ -410,7 +411,7 @@ func TestPurgeHandler_checkAuth(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("Authorization", "Bearer secret-token")
if h.checkAuth(ctx) {
if utils.CheckTokenAuth(ctx, h.auth) {
t.Error("expected auth to fail for unknown auth type")
}
})
@ -570,7 +571,7 @@ func TestPurgeHandler_SendError(t *testing.T) {
Allow: []string{},
}
h, err := NewPurgeHandler(nil, cfg)
_, err := NewPurgeHandler(nil, cfg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@ -578,7 +579,7 @@ func TestPurgeHandler_SendError(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
h.sendError(ctx, tt.status, tt.errMsg)
utils.SendJSONError(ctx, tt.status, tt.errMsg)
if ctx.Response.StatusCode() != tt.status {
t.Errorf("expected status %d, got %d", tt.status, ctx.Response.StatusCode())
@ -690,7 +691,7 @@ func TestPurgeHandler_checkAccess_NilContext(t *testing.T) {
}
// Empty allow list should allow access (returns true even with nil context)
if !h.checkAccess(nil) {
if !utils.CheckIPAccess(nil, h.allowed) {
t.Error("expected checkAccess to return true with empty allow list")
}
})
@ -779,7 +780,7 @@ func TestPurgeHandler_checkAccess_WithAllowedIP(t *testing.T) {
ctx.Init(&fasthttp.Request{}, nil, nil)
// context with nil remote address - should return false (no client IP)
if h.checkAccess(ctx) {
if utils.CheckIPAccess(ctx, h.allowed) {
t.Error("expected checkAccess to return false with no client IP")
}
})

281
internal/server/router.go Normal file
View File

@ -0,0 +1,281 @@
package server
import (
"strings"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/handler"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/matcher"
"rua.plus/lolly/internal/proxy"
)
// registerProxyRoutesWithLocationEngine 使用 LocationEngine 注册代理路由。
//
// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。
// 支持通过 LocationType 配置不同的匹配方式。
func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) {
for i := range serverCfg.Proxy {
proxyCfg := &serverCfg.Proxy[i]
// 转换目标
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
for j, t := range proxyCfg.Targets {
failTimeout := t.FailTimeout
if t.MaxFails > 0 && failTimeout == 0 {
failTimeout = 10 * time.Second
}
targets[j] = loadbalance.NewTargetFromConfig(
t.URL, t.Weight,
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
t.Backup, t.Down, t.ProxyURI,
)
}
// 传递 Transport 配置和 Lua 引擎
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
if err != nil {
logging.Error().Msg("Failed to create proxy: " + err.Error())
continue
}
// 设置 DNS 解析器(如果已配置)
if s.resolver != nil {
p.SetResolver(s.resolver)
if err := p.Start(); err != nil {
logging.Error().Err(err).Msg("Failed to start proxy")
}
}
// 启动健康检查
if proxyCfg.HealthCheck.Interval > 0 {
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
hc.Start()
s.healthCheckers = append(s.healthCheckers, hc)
// 设置被动健康检查
p.SetHealthChecker(hc)
}
// 保存代理实例用于缓存统计
s.proxies = append(s.proxies, p)
// 根据 LocationType 注册路由
locType := proxyCfg.LocationType
if locType == "" {
locType = matcher.LocationTypePrefix
}
switch locType {
case matcher.LocationTypeExact:
_ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
case matcher.LocationTypePrefixPriority:
_ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless:
caseInsensitive := locType == matcher.LocationTypeRegexCaseless
_ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal)
case matcher.LocationTypeNamed:
if proxyCfg.LocationName != "" {
_ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP)
}
case matcher.LocationTypePrefix:
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
default:
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
}
}
}
// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。
func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) {
for _, static := range cfg.Static {
path := static.Path
if path == "" {
path = "/"
}
staticHandler := handler.NewStaticHandler(
static.Root,
path,
static.Index,
true, // useSendfile
)
// 设置 alias与 root 互斥)
if static.Alias != "" {
staticHandler.SetAlias(static.Alias)
}
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
// 设置默认缓存 TTL (5s)
staticHandler.SetCacheTTL(5 * time.Second)
}
if cfg.Compression.GzipStatic {
// extensions: 源文件类型,为空使用默认值
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
}
// 设置符号链接安全检查
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
// 设置 internal 限制
staticHandler.SetInternal(static.Internal)
// 设置缓存过期时间
if static.Expires != "" {
staticHandler.SetExpires(static.Expires)
}
// 根据 LocationType 注册路由
locType := static.LocationType
if locType == "" {
locType = matcher.LocationTypePrefix
}
switch locType {
case matcher.LocationTypeExact:
_ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal)
case matcher.LocationTypePrefixPriority:
_ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal)
case matcher.LocationTypePrefix:
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
default:
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
}
}
}
// registerProxyRoutes 注册代理路由。
//
// 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。
// 支持 GET、POST、PUT、DELETE、HEAD 等 HTTP 方法。
//
// 参数:
// - router: 路由器实例,用于注册路由规则
// - serverCfg: 服务器配置,包含代理目标、负载均衡、健康检查等设置
//
// 注意事项:
// - 代理目标初始状态默认为健康
// - 健康检查根据配置自动启动
func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.ServerConfig) {
for i := range serverCfg.Proxy {
proxyCfg := &serverCfg.Proxy[i]
// 转换目标
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
for j, t := range proxyCfg.Targets {
failTimeout := t.FailTimeout
if t.MaxFails > 0 && failTimeout == 0 {
failTimeout = 10 * time.Second
}
targets[j] = loadbalance.NewTargetFromConfig(
t.URL, t.Weight,
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
t.Backup, t.Down, t.ProxyURI,
)
}
// 传递 Transport 配置和 Lua 引擎
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
if err != nil {
logging.Error().Msg("Failed to create proxy: " + err.Error())
continue
}
// 设置 DNS 解析器(如果已配置)
if s.resolver != nil {
p.SetResolver(s.resolver)
if err := p.Start(); err != nil {
logging.Error().Err(err).Msg("Failed to start proxy")
}
}
// 启动健康检查
if proxyCfg.HealthCheck.Interval > 0 {
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
hc.Start()
s.healthCheckers = append(s.healthCheckers, hc)
// 设置被动健康检查
p.SetHealthChecker(hc)
}
// 保存代理实例用于缓存统计
s.proxies = append(s.proxies, p)
// 使用前缀匹配(通配符)注册代理路由
// path: / 匹配所有子路径如 /sorry/index
// path: /api/ 匹配 /api/* 所有子路径
routePath := proxyCfg.Path
// 确保通配符路由格式正确
if !strings.HasSuffix(routePath, "/") && routePath != "/" {
routePath += "/"
}
wildcardPath := routePath + "{path:*}"
router.GET(wildcardPath, p.ServeHTTP)
router.POST(wildcardPath, p.ServeHTTP)
router.PUT(wildcardPath, p.ServeHTTP)
router.DELETE(wildcardPath, p.ServeHTTP)
router.HEAD(wildcardPath, p.ServeHTTP)
}
}
// registerStaticHandlers 注册静态文件处理器。
//
// 为路由器注册静态文件服务,支持多个静态目录、文件缓存和预压缩文件。
//
// 参数:
// - router: 路由器实例,用于注册路由规则
// - cfg: 服务器配置,包含静态文件和压缩设置
func (s *Server) registerStaticHandlers(router *handler.Router, cfg *config.ServerConfig) {
for _, static := range cfg.Static {
path := static.Path
if path == "" {
path = "/"
}
staticHandler := handler.NewStaticHandler(
static.Root,
path,
static.Index,
true, // useSendfile
)
// 设置 alias与 root 互斥)
if static.Alias != "" {
staticHandler.SetAlias(static.Alias)
}
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
// 设置默认缓存 TTL (5s)
staticHandler.SetCacheTTL(5 * time.Second)
}
if cfg.Compression.GzipStatic {
// extensions: 源文件类型,为空使用默认值
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
}
// 设置符号链接安全检查
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
// 设置缓存过期时间
if static.Expires != "" {
staticHandler.SetExpires(static.Expires)
}
// 设置 try_files 配置
if len(static.TryFiles) > 0 {
// 注意tryFilesPass 需要路由器支持,当前实现传入 nil
// 如果 tryFilesPass 为 true需要额外处理
staticHandler.SetTryFiles(static.TryFiles, static.TryFilesPass, router)
}
// 注册路由:确保路径以 / 结尾
routePath := path
if !strings.HasSuffix(routePath, "/") {
routePath += "/"
}
router.GET(routePath+"{filepath:*}", staticHandler.Handle)
router.HEAD(routePath+"{filepath:*}", staticHandler.Handle)
}
}

View File

@ -20,7 +20,6 @@
package server
import (
"context"
"crypto/tls"
"fmt"
"net"
@ -34,16 +33,10 @@ import (
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/handler"
"rua.plus/lolly/internal/loadbalance"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/matcher"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/middleware/accesslog"
"rua.plus/lolly/internal/middleware/bodylimit"
"rua.plus/lolly/internal/middleware/compression"
"rua.plus/lolly/internal/middleware/errorintercept"
"rua.plus/lolly/internal/middleware/rewrite"
"rua.plus/lolly/internal/middleware/security"
"rua.plus/lolly/internal/mimeutil"
"rua.plus/lolly/internal/proxy"
@ -211,236 +204,6 @@ func (s *Server) GetHandler() fasthttp.RequestHandler {
return s.handler
}
// buildMiddlewareChain 构建中间件链。
//
// 根据服务器配置按顺序构建中间件链,顺序为:
//
// AccessLog -> AccessControl -> RateLimiter -> BasicAuth -> Rewrite -> Compression -> SecurityHeaders
//
// 参数:
// - serverCfg: 单个服务器的配置对象
//
// 返回值:
// - *middleware.Chain: 构建完成的中间件链
// - error: 构建过程中遇到的错误,如中间件创建失败
//
// 注意事项:
// - 各中间件按顺序依次包装请求处理器
// - 未配置的中间件不会添加到链中
func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) {
var middlewares []middleware.Middleware
// 1. AccessLog (已集成)
s.accessLogMiddleware = accesslog.New(&s.config.Logging)
middlewares = append(middlewares, s.accessLogMiddleware)
// 2. Security: AccessControl (IP 访问控制)
if len(serverCfg.Security.Access.Allow) > 0 || len(serverCfg.Security.Access.Deny) > 0 {
ac, err := security.NewAccessControl(&serverCfg.Security.Access)
if err != nil {
return nil, fmt.Errorf("创建访问控制中间件失败: %w", err)
}
middlewares = append(middlewares, ac)
s.accessControl = ac
}
// 3. Security: RateLimiter (速率限制)
if serverCfg.Security.RateLimit.RequestRate > 0 {
rl, err := security.NewRateLimiter(&serverCfg.Security.RateLimit)
if err != nil {
return nil, fmt.Errorf("创建限流中间件失败: %w", err)
}
middlewares = append(middlewares, rl)
}
// 3.5 Security: ConnLimiter (连接数限制)
if serverCfg.Security.RateLimit.ConnLimit > 0 {
cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key)
if err != nil {
return nil, fmt.Errorf("创建连接限制中间件失败: %w", err)
}
middlewares = append(middlewares, cl.Middleware())
}
// 4. Security: BasicAuth (认证)
if len(serverCfg.Security.Auth.Users) > 0 {
auth, err := security.NewBasicAuth(&serverCfg.Security.Auth)
if err != nil {
return nil, fmt.Errorf("创建认证中间件失败: %w", err)
}
middlewares = append(middlewares, auth)
}
// 4.3 Security: AuthRequest (外部认证子请求)
if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" {
authReq, err := security.NewAuthRequest(serverCfg.Security.AuthRequest)
if err != nil {
return nil, fmt.Errorf("创建外部认证中间件失败: %w", err)
}
middlewares = append(middlewares, authReq)
}
// 4.5 BodyLimit (请求体大小限制)
// 创建 bodylimit 中间件,使用全局配置或默认值
bodyLimitMiddleware := bodylimit.NewWithDefault()
if serverCfg.ClientMaxBodySize != "" {
bl, err := bodylimit.New(serverCfg.ClientMaxBodySize)
if err != nil {
return nil, fmt.Errorf("创建请求体限制中间件失败: %w", err)
}
bodyLimitMiddleware = bl
}
// 添加路径级别的限制配置
for i := range serverCfg.Proxy {
if serverCfg.Proxy[i].ClientMaxBodySize != "" {
if err := bodyLimitMiddleware.AddPathLimit(
serverCfg.Proxy[i].Path,
serverCfg.Proxy[i].ClientMaxBodySize,
); err != nil {
return nil, fmt.Errorf("添加路径请求体限制失败: %w", err)
}
}
}
middlewares = append(middlewares, bodyLimitMiddleware)
// 5. Rewrite (URL 重写)
if len(serverCfg.Rewrite) > 0 {
rw, err := rewrite.New(serverCfg.Rewrite)
if err != nil {
return nil, fmt.Errorf("创建重写中间件失败: %w", err)
}
middlewares = append(middlewares, rw)
}
// 6. Compression (响应压缩)
if serverCfg.Compression.Type != "" {
comp, err := compression.New(&serverCfg.Compression)
if err != nil {
return nil, fmt.Errorf("创建压缩中间件失败: %w", err)
}
middlewares = append(middlewares, comp)
}
// 7. SecurityHeaders (安全头部)
// 如果有任何安全头部配置,则启用
if serverCfg.Security.Headers.XFrameOptions != "" ||
serverCfg.Security.Headers.XContentTypeOptions != "" ||
serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
serverCfg.Security.Headers.ReferrerPolicy != "" ||
serverCfg.Security.Headers.PermissionsPolicy != "" {
headers := security.NewHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS)
middlewares = append(middlewares, headers)
}
// 8. ErrorIntercept (错误页面拦截)
// 如果配置了错误页面,添加错误拦截中间件
if s.errorPageManager != nil && s.errorPageManager.IsConfigured() {
ei := errorintercept.New(s.errorPageManager)
middlewares = append(middlewares, ei)
}
// Lua 中间件(可选)
if s.luaEngine != nil && serverCfg.Lua != nil && serverCfg.Lua.Enabled {
luaMiddlewares, err := s.buildLuaMiddlewares(serverCfg.Lua)
if err != nil {
return nil, fmt.Errorf("创建 Lua 中间件失败: %w", err)
}
middlewares = append(middlewares, luaMiddlewares...)
}
return middleware.NewChain(middlewares...), nil
}
// buildLuaMiddlewares 根据 Lua 配置创建中间件。
//
// 根据 Scripts 配置创建 LuaMiddleware 或 MultiPhaseLuaMiddleware。
// 支持单脚本和多阶段脚本配置。
//
// 参数:
// - luaCfg: Lua 配置对象
//
// 返回值:
// - []middleware.Middleware: 创建的中间件列表
// - error: 创建过程中遇到的错误
func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) {
if s.luaEngine == nil {
return nil, nil
}
// 按阶段分组脚本
phaseScripts := make(map[string][]config.LuaScriptConfig)
for _, script := range luaCfg.Scripts {
// 默认启用
enabled := script.Enabled
if !enabled && script.Timeout == 0 && script.Path != "" {
enabled = true // 零值时默认启用
}
if enabled {
phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script)
}
}
var middlewares []middleware.Middleware
// 为每个阶段创建中间件
for phase, scripts := range phaseScripts {
if len(scripts) == 0 {
continue
}
// 单脚本:直接创建 LuaMiddleware
if len(scripts) == 1 {
script := scripts[0]
luaPhase, err := lua.ParsePhase(phase)
if err != nil {
return nil, fmt.Errorf("无效的阶段 '%s': %w", phase, err)
}
timeout := script.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
cfg := lua.LuaMiddlewareConfig{
ScriptPath: script.Path,
Phase: luaPhase,
Timeout: timeout,
Name: fmt.Sprintf("lua-%s", phase),
}
mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg)
if err != nil {
return nil, fmt.Errorf("创建 Lua 中间件失败 (phase=%s): %w", phase, err)
}
middlewares = append(middlewares, mw)
} else {
// 多脚本:创建 MultiPhaseLuaMiddleware
multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase))
for _, script := range scripts {
luaPhase, err := lua.ParsePhase(phase)
if err != nil {
return nil, fmt.Errorf("无效的阶段 '%s': %w", phase, err)
}
timeout := script.Timeout
if timeout == 0 {
timeout = 30 * time.Second
}
err = multi.AddPhase(luaPhase, script.Path, timeout)
if err != nil {
return nil, fmt.Errorf("添加 Lua 阶段失败 (phase=%s): %w", phase, err)
}
}
middlewares = append(middlewares, multi)
}
}
return middlewares, nil
}
// Start 启动 HTTP 服务器。
//
// 初始化日志系统、性能优化组件Goroutine池、文件缓存
@ -549,13 +312,13 @@ func (s *Server) createListener(cfg *config.ServerConfig) (net.Listener, error)
mode = cfg.UnixSocket.Mode
}
if err := os.Chmod(socketPath, os.FileMode(mode)); err != nil {
logging.Warn().Err(err).Msg("设置 socket 文件权限失败")
logging.Warn().Err(err).Msg("Failed to set socket file permissions")
}
// 5. 设置文件所有权(需要 root 权限)
if cfg.UnixSocket.User != "" || cfg.UnixSocket.Group != "" {
// 简化处理:仅记录警告,实际实现需要 syscall.Chown
logging.Warn().Msg("Unix socket 用户/组配置需要 root 权限,已跳过")
logging.Warn().Msg("Unix socket user/group config requires root privileges, skipped")
}
return listener, nil
@ -590,7 +353,7 @@ func (s *Server) startSingleMode() error {
if s.config.Monitoring.Status.Path != "" || len(s.config.Monitoring.Status.Allow) > 0 {
statusHandler, err := NewStatusHandler(s, &s.config.Monitoring.Status)
if err != nil {
logging.Error().Msg("创建状态处理器失败: " + err.Error())
logging.Error().Msg("Failed to create status handler: " + err.Error())
} else {
_ = s.locationEngine.AddExact(statusHandler.Path(), statusHandler.ServeHTTP, false)
}
@ -600,7 +363,7 @@ func (s *Server) startSingleMode() error {
if s.config.Monitoring.Pprof.Enabled {
pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof)
if err != nil {
logging.Error().Msg("创建 pprof 处理器失败: " + err.Error())
logging.Error().Msg("Failed to create pprof handler: " + err.Error())
} else {
_ = s.locationEngine.AddExact(pprofHandler.Path(), pprofHandler.ServeHTTP, false)
_ = s.locationEngine.AddPrefixPriority(pprofHandler.Path()+"/", pprofHandler.ServeHTTP, false)
@ -611,7 +374,7 @@ func (s *Server) startSingleMode() error {
if serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled {
purgeHandler, err := NewPurgeHandler(s, serverCfg.CacheAPI)
if err != nil {
logging.Error().Msg("创建缓存清理处理器失败: " + err.Error())
logging.Error().Msg("Failed to create cache purge handler: " + err.Error())
} else {
_ = s.locationEngine.AddExact(purgeHandler.Path(), purgeHandler.ServeHTTP, false)
}
@ -685,7 +448,7 @@ func (s *Server) startSingleMode() error {
var err error
s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL)
if err != nil {
return fmt.Errorf("创建 TLS 管理器失败: %w", err)
return fmt.Errorf("failed to create TLS manager: %w", err)
}
s.fastServer.TLSConfig = s.tlsManager.GetTLSConfig()
return s.fastServer.ServeTLS(ln, "", "")
@ -747,7 +510,7 @@ func (s *Server) startVHostMode() error {
if s.config.Monitoring.Status.Enabled {
statusHandler, err := NewStatusHandler(s, &s.config.Monitoring.Status)
if err != nil {
logging.Error().Msg("创建状态处理器失败: " + err.Error())
logging.Error().Msg("Failed to create status handler: " + err.Error())
} else {
router.GET(statusHandler.Path(), statusHandler.ServeHTTP)
}
@ -757,7 +520,7 @@ func (s *Server) startVHostMode() error {
if s.config.Monitoring.Pprof.Enabled {
pprofHandler, err := NewPprofHandler(&s.config.Monitoring.Pprof)
if err != nil {
logging.Error().Msg("创建 pprof 处理器失败: " + err.Error())
logging.Error().Msg("Failed to create pprof handler: " + err.Error())
} else {
router.GET(pprofHandler.Path(), pprofHandler.ServeHTTP)
router.GET(pprofHandler.Path()+"/{profile:*}", pprofHandler.ServeHTTP)
@ -769,7 +532,7 @@ func (s *Server) startVHostMode() error {
if defaultSrv != nil && defaultSrv.CacheAPI != nil && defaultSrv.CacheAPI.Enabled {
purgeHandler, err := NewPurgeHandler(s, defaultSrv.CacheAPI)
if err != nil {
logging.Error().Msg("创建缓存清理处理器失败: " + err.Error())
logging.Error().Msg("Failed to create cache purge handler: " + err.Error())
} else {
router.POST(purgeHandler.Path(), purgeHandler.ServeHTTP)
}
@ -829,7 +592,7 @@ func (s *Server) startVHostMode() error {
var err error
s.tlsManager, err = ssl.NewTLSManager(&serverCfg.SSL)
if err != nil {
return fmt.Errorf("创建 TLS 管理器失败: %w", err)
return fmt.Errorf("failed to create TLS manager: %w", err)
}
s.fastServer.TLSConfig = s.tlsManager.GetTLSConfig()
return s.fastServer.ServeTLS(ln, "", "")
@ -853,7 +616,7 @@ func (s *Server) startVHostMode() error {
func (s *Server) startMultiServerMode() error {
// 热升级检测multi_server 热升级未实现,回退到 vhost 模式
if os.Getenv("GRACEFUL_UPGRADE") == "1" {
logging.Warn().Msg("热升级模式下 multi_server 模式未实现,回退到虚拟主机模式")
logging.Warn().Msg("multi_server mode not implemented for graceful upgrade, falling back to vhost mode")
return s.startVHostMode()
}
@ -874,7 +637,7 @@ func (s *Server) startMultiServerMode() error {
// 创建监听器
ln, err := s.createListener(serverCfg)
if err != nil {
errCh <- fmt.Errorf("监听地址 %s 失败: %w", serverCfg.Listen, err)
errCh <- fmt.Errorf("failed to listen on %s: %w", serverCfg.Listen, err)
return
}
s.listeners[idx] = ln
@ -886,7 +649,7 @@ func (s *Server) startMultiServerMode() error {
if idx == 0 && serverCfg.CacheAPI != nil && serverCfg.CacheAPI.Enabled {
purgeHandler, purgeErr := NewPurgeHandler(s, serverCfg.CacheAPI)
if purgeErr != nil {
errCh <- fmt.Errorf("创建缓存清理处理器失败 (server[%d]): %w", idx, purgeErr)
errCh <- fmt.Errorf("failed to create cache purge handler (server[%d]): %w", idx, purgeErr)
return
}
router.POST(purgeHandler.Path(), purgeHandler.ServeHTTP)
@ -900,7 +663,7 @@ func (s *Server) startMultiServerMode() error {
// 构建独立的中间件链
chain, err := s.buildMiddlewareChain(serverCfg)
if err != nil {
errCh <- fmt.Errorf("构建中间件链失败 (server[%d]): %w", idx, err)
errCh <- fmt.Errorf("failed to build middleware chain (server[%d]): %w", idx, err)
return
}
@ -927,7 +690,7 @@ func (s *Server) startMultiServerMode() error {
if serverCfg.SSL.Cert != "" && serverCfg.SSL.Key != "" {
tlsManager, err := ssl.NewTLSManager(&serverCfg.SSL)
if err != nil {
errCh <- fmt.Errorf("创建 TLS 管理器失败 (server[%d]): %w", idx, err)
errCh <- fmt.Errorf("failed to create TLS manager (server[%d]): %w", idx, err)
return
}
fastSrv.TLSConfig = tlsManager.GetTLSConfig()
@ -978,7 +741,7 @@ func (s *Server) startMultiServerMode() error {
serveErr = f.Serve(l)
}
if serveErr != nil {
logging.Error().Err(serveErr).Msgf("服务器 [%d] 监听 %s 时发生错误", i, l.Addr())
logging.Error().Err(serveErr).Msgf("Server [%d] error while listening on %s", i, l.Addr())
}
}(fastSrv, ln, idx)
}
@ -988,509 +751,6 @@ func (s *Server) startMultiServerMode() error {
return nil
}
// registerProxyRoutesWithLocationEngine 使用 LocationEngine 注册代理路由。
//
// 根据配置为 LocationEngine 注册代理路径,创建代理处理器和健康检查器。
// 支持通过 LocationType 配置不同的匹配方式。
func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerConfig) {
for i := range serverCfg.Proxy {
proxyCfg := &serverCfg.Proxy[i]
// 转换目标
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
for j, t := range proxyCfg.Targets {
failTimeout := t.FailTimeout
if t.MaxFails > 0 && failTimeout == 0 {
failTimeout = 10 * time.Second
}
targets[j] = loadbalance.NewTargetFromConfig(
t.URL, t.Weight,
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
t.Backup, t.Down, t.ProxyURI,
)
}
// 传递 Transport 配置和 Lua 引擎
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
if err != nil {
logging.Error().Msg("创建代理失败: " + err.Error())
continue
}
// 设置 DNS 解析器(如果已配置)
if s.resolver != nil {
p.SetResolver(s.resolver)
if err := p.Start(); err != nil {
logging.Error().Err(err).Msg("启动代理失败")
}
}
// 启动健康检查
if proxyCfg.HealthCheck.Interval > 0 {
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
hc.Start()
s.healthCheckers = append(s.healthCheckers, hc)
// 设置被动健康检查
p.SetHealthChecker(hc)
}
// 保存代理实例用于缓存统计
s.proxies = append(s.proxies, p)
// 根据 LocationType 注册路由
locType := proxyCfg.LocationType
if locType == "" {
locType = matcher.LocationTypePrefix
}
switch locType {
case matcher.LocationTypeExact:
_ = s.locationEngine.AddExact(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
case matcher.LocationTypePrefixPriority:
_ = s.locationEngine.AddPrefixPriority(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
case matcher.LocationTypeRegex, matcher.LocationTypeRegexCaseless:
caseInsensitive := locType == matcher.LocationTypeRegexCaseless
_ = s.locationEngine.AddRegex(proxyCfg.Path, p.ServeHTTP, caseInsensitive, proxyCfg.Internal)
case matcher.LocationTypeNamed:
if proxyCfg.LocationName != "" {
_ = s.locationEngine.AddNamed(proxyCfg.LocationName, p.ServeHTTP)
}
case matcher.LocationTypePrefix:
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
default:
_ = s.locationEngine.AddPrefix(proxyCfg.Path, p.ServeHTTP, proxyCfg.Internal)
}
}
}
// registerStaticHandlersWithLocationEngine 使用 LocationEngine 注册静态文件处理器。
func (s *Server) registerStaticHandlersWithLocationEngine(cfg *config.ServerConfig) {
for _, static := range cfg.Static {
path := static.Path
if path == "" {
path = "/"
}
staticHandler := handler.NewStaticHandler(
static.Root,
path,
static.Index,
true, // useSendfile
)
// 设置 alias与 root 互斥)
if static.Alias != "" {
staticHandler.SetAlias(static.Alias)
}
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
// 设置默认缓存 TTL (5s)
staticHandler.SetCacheTTL(5 * time.Second)
}
if cfg.Compression.GzipStatic {
// extensions: 源文件类型,为空使用默认值
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
}
// 设置符号链接安全检查
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
// 设置 internal 限制
staticHandler.SetInternal(static.Internal)
// 设置缓存过期时间
if static.Expires != "" {
staticHandler.SetExpires(static.Expires)
}
// 根据 LocationType 注册路由
locType := static.LocationType
if locType == "" {
locType = matcher.LocationTypePrefix
}
switch locType {
case matcher.LocationTypeExact:
_ = s.locationEngine.AddExact(path, staticHandler.Handle, static.Internal)
case matcher.LocationTypePrefixPriority:
_ = s.locationEngine.AddPrefixPriority(path, staticHandler.Handle, static.Internal)
case matcher.LocationTypePrefix:
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
default:
_ = s.locationEngine.AddPrefix(path, staticHandler.Handle, static.Internal)
}
}
}
// registerProxyRoutes 注册代理路由。
//
// 根据配置为路由器注册代理路径,创建代理处理器和健康检查器。
// 支持 GET、POST、PUT、DELETE、HEAD 等 HTTP 方法。
//
// 参数:
// - router: 路由器实例,用于注册路由规则
// - serverCfg: 服务器配置,包含代理目标、负载均衡、健康检查等设置
//
// 注意事项:
// - 代理目标初始状态默认为健康
// - 健康检查根据配置自动启动
func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.ServerConfig) {
for i := range serverCfg.Proxy {
proxyCfg := &serverCfg.Proxy[i]
// 转换目标
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
for j, t := range proxyCfg.Targets {
failTimeout := t.FailTimeout
if t.MaxFails > 0 && failTimeout == 0 {
failTimeout = 10 * time.Second
}
targets[j] = loadbalance.NewTargetFromConfig(
t.URL, t.Weight,
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
t.Backup, t.Down, t.ProxyURI,
)
}
// 传递 Transport 配置和 Lua 引擎
p, err := proxy.NewProxy(proxyCfg, targets, &s.config.Performance.Transport, s.luaEngine)
if err != nil {
logging.Error().Msg("创建代理失败: " + err.Error())
continue
}
// 设置 DNS 解析器(如果已配置)
if s.resolver != nil {
p.SetResolver(s.resolver)
if err := p.Start(); err != nil {
logging.Error().Err(err).Msg("启动代理失败")
}
}
// 启动健康检查
if proxyCfg.HealthCheck.Interval > 0 {
hc := proxy.NewHealthChecker(targets, &proxyCfg.HealthCheck)
hc.Start()
s.healthCheckers = append(s.healthCheckers, hc)
// 设置被动健康检查
p.SetHealthChecker(hc)
}
// 保存代理实例用于缓存统计
s.proxies = append(s.proxies, p)
// 使用前缀匹配(通配符)注册代理路由
// path: / 匹配所有子路径如 /sorry/index
// path: /api/ 匹配 /api/* 所有子路径
routePath := proxyCfg.Path
// 确保通配符路由格式正确
if !strings.HasSuffix(routePath, "/") && routePath != "/" {
routePath += "/"
}
wildcardPath := routePath + "{path:*}"
router.GET(wildcardPath, p.ServeHTTP)
router.POST(wildcardPath, p.ServeHTTP)
router.PUT(wildcardPath, p.ServeHTTP)
router.DELETE(wildcardPath, p.ServeHTTP)
router.HEAD(wildcardPath, p.ServeHTTP)
}
}
// shutdownServers 并行关闭多个 fasthttp.Server 实例。
//
// 使用 goroutine 并行关闭所有服务器,收集所有错误并返回聚合错误。
// 部分服务器关闭失败不会影响其他服务器的关闭。
//
// 参数:
// - ctx: 关闭上下文,用于控制超时和取消
// - servers: 要关闭的 fasthttp.Server 实例列表
//
// 返回值:
// - error: 聚合错误,无错误或全部成功时返回 nil
func shutdownServers(ctx context.Context, servers []*fasthttp.Server) error {
// 防御性检查nil ctx 使用默认背景
if ctx == nil {
ctx = context.Background()
}
if len(servers) == 0 {
return nil
}
var (
mu sync.Mutex
errs []error
wg sync.WaitGroup
)
for _, srv := range servers {
if srv == nil {
continue
}
wg.Add(1)
go func(s *fasthttp.Server) {
defer wg.Done()
if err := s.Shutdown(); err != nil {
mu.Lock()
errs = append(errs, err)
mu.Unlock()
}
}(srv)
}
// 等待所有关闭完成或上下文取消
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
if len(errs) == 0 {
return nil
}
if len(errs) == 1 {
return errs[0]
}
msgs := make([]string, len(errs))
for i, e := range errs {
msgs[i] = e.Error()
}
return fmt.Errorf("关闭服务器时发生 %d 个错误: %s", len(errs), strings.Join(msgs, "; "))
case <-ctx.Done():
return ctx.Err()
}
}
// StopWithTimeout 快速停止服务器(支持自定义超时)。
//
// 立即停止服务器,不等待正在处理的请求完成。
// 停止所有健康检查器和访问日志中间件。
//
// 参数:
// - timeout: 快速关闭的最大等待时间
//
// 返回值:
// - error: 停止过程中遇到的错误
//
// 注意事项:
// - 对于生产环境,建议使用 GracefulStop 实现优雅关闭
// - timeout <= 0 时会使用默认 5s 超时
func (s *Server) StopWithTimeout(timeout time.Duration) error {
// 防御性检查:如果 timeout <= 0使用默认值
if timeout <= 0 {
timeout = 5 * time.Second
}
s.running = false
// 停止 Goroutine 池
if s.pool != nil {
s.pool.Stop()
}
// 停止健康检查器
for _, hc := range s.healthCheckers {
hc.Stop()
}
// 关闭访问日志
if s.accessLogMiddleware != nil {
_ = s.accessLogMiddleware.Close()
}
// 关闭 TLS 管理器
if s.tlsManager != nil {
s.tlsManager.Close()
}
// 关闭 AccessControl (释放 GeoIP 资源)
if s.accessControl != nil {
if err := s.accessControl.Close(); err != nil {
logging.Warn().Err(err).Msg("关闭 AccessControl 失败")
}
}
// 关闭 Lua 引擎
if s.luaEngine != nil {
s.luaEngine.Close()
logging.Info().Msg("Lua 引擎已关闭")
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 多服务器模式:并行关闭所有 fasthttp.Server
if len(s.fastServers) > 0 {
return shutdownServers(ctx, s.fastServers)
}
// 单服务器模式:关闭单个 fasthttp.Server
if s.fastServer != nil {
done := make(chan struct{})
go func() {
_ = s.fastServer.Shutdown()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
// GracefulStop 优雅停止服务器。
//
// 等待正在处理的请求完成后再停止服务器,确保连接正常关闭。
// 如果超时时间到达仍有请求未完成,将返回超时错误。
//
// 参数:
// - timeout: 优雅关闭的最大等待时间
//
// 返回值:
// - error: 停止过程中遇到的错误,超时返回 context.DeadlineExceeded
//
// 注意事项:
// - 推荐在生产环境使用此方法关闭服务器
// - 超时后会强制关闭,可能导致部分请求中断
func (s *Server) GracefulStop(timeout time.Duration) error {
s.running = false
// 停止 Goroutine 池
if s.pool != nil {
s.pool.Stop()
}
// 停止健康检查器
for _, hc := range s.healthCheckers {
hc.Stop()
}
// 关闭访问日志
if s.accessLogMiddleware != nil {
_ = s.accessLogMiddleware.Close()
}
// 关闭 TLS 管理器
if s.tlsManager != nil {
s.tlsManager.Close()
}
// 关闭 AccessControl (释放 GeoIP 资源)
if s.accessControl != nil {
if err := s.accessControl.Close(); err != nil {
logging.Warn().Err(err).Msg("关闭 AccessControl 失败")
}
}
// 关闭 Lua 引擎
if s.luaEngine != nil {
s.luaEngine.Close()
logging.Info().Msg("Lua 引擎已关闭")
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 多服务器模式:并行关闭所有 fasthttp.Server
if len(s.fastServers) > 0 {
return shutdownServers(ctx, s.fastServers)
}
// 单服务器模式:关闭单个 fasthttp.Server
if s.fastServer != nil {
done := make(chan struct{})
go func() {
_ = s.fastServer.Shutdown()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
// getProxyCacheStats 收集所有代理缓存的统计信息。
func (s *Server) getProxyCacheStats() ProxyCacheStats {
var total ProxyCacheStats
for _, p := range s.proxies {
if stats := p.GetCacheStats(); stats != nil {
total.Entries += stats.Entries
total.Pending += stats.Pending
}
}
return total
}
// registerStaticHandlers 注册静态文件处理器。
//
// 为路由器注册静态文件服务,支持多个静态目录、文件缓存和预压缩文件。
//
// 参数:
// - router: 路由器实例,用于注册路由规则
// - cfg: 服务器配置,包含静态文件和压缩设置
func (s *Server) registerStaticHandlers(router *handler.Router, cfg *config.ServerConfig) {
for _, static := range cfg.Static {
path := static.Path
if path == "" {
path = "/"
}
staticHandler := handler.NewStaticHandler(
static.Root,
path,
static.Index,
true, // useSendfile
)
// 设置 alias与 root 互斥)
if static.Alias != "" {
staticHandler.SetAlias(static.Alias)
}
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
// 设置默认缓存 TTL (5s)
staticHandler.SetCacheTTL(5 * time.Second)
}
if cfg.Compression.GzipStatic {
// extensions: 源文件类型,为空使用默认值
// GzipStaticExtensions: 预压缩文件扩展名(如 .br, .gz
staticHandler.SetGzipStatic(true, nil, cfg.Compression.GzipStaticExtensions)
}
// 设置符号链接安全检查
staticHandler.SetSymlinkCheck(static.SymlinkCheck)
// 设置缓存过期时间
if static.Expires != "" {
staticHandler.SetExpires(static.Expires)
}
// 设置 try_files 配置
if len(static.TryFiles) > 0 {
// 注意tryFilesPass 需要路由器支持,当前实现传入 nil
// 如果 tryFilesPass 为 true需要额外处理
staticHandler.SetTryFiles(static.TryFiles, static.TryFilesPass, router)
}
// 注册路由:确保路径以 / 结尾
routePath := path
if !strings.HasSuffix(routePath, "/") {
routePath += "/"
}
router.GET(routePath+"{filepath:*}", staticHandler.Handle)
router.HEAD(routePath+"{filepath:*}", staticHandler.Handle)
}
}
// SetResolver 设置 DNS 解析器。
func (s *Server) SetResolver(r resolver.Resolver) {
s.resolver = r

View File

@ -21,7 +21,6 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/utils"
)
@ -194,7 +193,7 @@ func (h *StatusHandler) Path() string {
// - ctx: FastHTTP 请求上下文
func (h *StatusHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 步骤1: 检查 IP 访问权限
if !h.checkAccess(ctx) {
if !utils.CheckIPAccess(ctx, h.allowed) {
utils.SendErrorWithDetail(ctx, utils.ErrForbidden, "Access denied")
return
}
@ -509,34 +508,6 @@ func (h *StatusHandler) serveHTML(ctx *fasthttp.RequestCtx, status *Status) {
}
}
// checkAccess 检查客户端 IP 是否在允许列表中。
//
// 如果未配置允许列表,则允许所有访问。
// 检查时支持代理头部X-Forwarded-For、X-Real-IP
//
// 参数:
// - ctx: FastHTTP 请求上下文
//
// 返回值:
// - bool: true 表示允许访问false 表示拒绝
func (h *StatusHandler) checkAccess(ctx *fasthttp.RequestCtx) bool {
// 如果没有配置允许列表,允许所有访问
if len(h.allowed) == 0 {
return true
}
clientIP := netutil.ExtractClientIPNet(ctx)
// 检查是否在允许列表中
for _, network := range h.allowed {
if network.Contains(clientIP) {
return true
}
}
return false
}
// collectStatus 收集服务器状态数据。
//
// 从服务器实例读取各项统计指标,构建状态响应对象。

View File

@ -24,6 +24,7 @@ import (
"rua.plus/lolly/internal/cache"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/utils"
)
func TestNewStatusHandler_CIDR(t *testing.T) {
@ -896,7 +897,7 @@ func TestStatusHandler_checkAccess_AllowList(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.SetRemoteAddr(&net.TCPAddr{IP: tt.remoteIP, Port: 12345})
got := h.checkAccess(ctx)
got := utils.CheckIPAccess(ctx, h.allowed)
if got != tt.wantAccess {
t.Errorf("expected access %v, got %v", tt.wantAccess, got)
}

View File

@ -23,6 +23,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/utils"
)
// VHostManager 虚拟主机管理器。
@ -212,7 +213,7 @@ func (v *VHostManager) Handler() fasthttp.RequestHandler {
if vhost := v.FindHost(host); vhost != nil {
vhost.handler(ctx)
} else {
ctx.Error("Host not found", fasthttp.StatusNotFound)
utils.SendError(ctx, utils.ErrNotFound)
}
}
}

View File

@ -22,6 +22,7 @@ import (
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/logging"
)
const (
@ -192,9 +193,7 @@ func (m *SessionTicketManager) RotateKey() error {
// 如果有密钥文件,保存所有密钥
if m.config.KeyFile != "" {
if err := m.saveKeys(); err != nil {
// 保存失败不影响运行,记录错误即可
// 这里可以考虑添加日志
_ = err
logging.Warn().Err(err).Msg("Session Ticket 密钥保存失败")
}
}

View File

@ -47,6 +47,7 @@ import (
"sync"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/logging"
"rua.plus/lolly/internal/netutil"
)
@ -145,9 +146,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
if cfg.SessionTickets.Enabled {
sessionTicketMgr, err := NewSessionTicketManager(cfg.SessionTickets)
if err != nil {
// Session Tickets 初始化失败不阻止 TLS 工作
// 可以记录日志
_ = err
logging.Warn().Err(err).Msg("Session Ticket 初始化失败TLS 性能可能降级")
} else {
manager.sessionTicketMgr = sessionTicketMgr
// 应用 Session Tickets 到 TLS 配置
@ -174,9 +173,9 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
issuerCert, err := x509.ParseCertificate(cert.Certificate[1])
if err == nil {
manager.issuers[serial] = issuerCert
// 注册证书用于 OCSP Stapling
// 错误会记录日志但不会阻止 TLS 工作
_ = ocspMgr.RegisterCertificate(parsedCert, issuerCert)
if err := ocspMgr.RegisterCertificate(parsedCert, issuerCert); err != nil {
logging.Warn().Err(err).Msg("OCSP Stapling 注册失败")
}
}
}
@ -192,9 +191,7 @@ func NewTLSManager(cfg *config.SSLConfig) (*TLSManager, error) {
if cfg.ClientVerify.Enabled {
clientVerifier, err := NewClientVerifier(cfg.ClientVerify)
if err != nil {
// 客户端验证配置失败不阻止 TLS 工作
// 可以记录日志
_ = err
logging.Warn().Err(err).Msg("客户端证书验证配置失败")
} else {
manager.clientVerifier = clientVerifier
clientVerifier.ConfigureTLS(tlsCfg)

View File

@ -571,9 +571,22 @@ func (s *Server) handleConnection(clientConn net.Conn, _ string) {
}
defer func() { _ = targetConn.Close() }()
// 双向数据转发
go func() { _, _ = io.Copy(targetConn, clientConn) }()
// 双向数据转发:任一方向完成/出错时立即关闭双端连接,迫使另一方向退出
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(targetConn, clientConn)
_ = clientConn.Close()
_ = targetConn.Close()
}()
go func() {
defer wg.Done()
_, _ = io.Copy(clientConn, targetConn)
_ = targetConn.Close()
_ = clientConn.Close()
}()
wg.Wait()
}
// Select 选择健康的上游目标。

View File

@ -4,7 +4,16 @@
// the scattered pattern of ctx.Error throughout the codebase.
package utils
import "github.com/valyala/fasthttp"
import (
"crypto/subtle"
"encoding/json"
"net"
"strings"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/netutil"
)
// HTTPError represents an HTTP error with a message and status code.
type HTTPError struct {
@ -37,3 +46,57 @@ func SendErrorWithDetail(ctx *fasthttp.RequestCtx, err HTTPError, detail string)
SendError(ctx, err)
}
}
// SendJSONError sends a JSON error response.
func SendJSONError(ctx *fasthttp.RequestCtx, status int, errMsg string) {
ctx.SetContentType("application/json; charset=utf-8")
ctx.SetStatusCode(status)
_ = json.NewEncoder(ctx).Encode(struct {
Error string `json:"error"`
}{Error: errMsg})
}
// CheckIPAccess checks whether the client IP is in the allowed list.
// If allowed is empty, all access is permitted.
func CheckIPAccess(ctx *fasthttp.RequestCtx, allowed []net.IPNet) bool {
if len(allowed) == 0 {
return true
}
clientIP := netutil.ExtractClientIPNet(ctx)
if clientIP == nil {
return false
}
for _, network := range allowed {
if network.Contains(clientIP) {
return true
}
}
return false
}
// CheckTokenAuth checks token-based authentication.
// Returns true if auth is disabled or the token matches.
func CheckTokenAuth(ctx *fasthttp.RequestCtx, auth config.CacheAPIAuthConfig) bool {
if auth.Type == "" || auth.Type == "none" {
return true
}
if auth.Type == "token" {
authHeader := ctx.Request.Header.Peek("Authorization")
if len(authHeader) == 0 {
return false
}
authStr := string(authHeader)
if token, ok := strings.CutPrefix(authStr, "Bearer "); ok {
return subtle.ConstantTimeCompare([]byte(token), []byte(auth.Token)) == 1
}
return subtle.ConstantTimeCompare([]byte(authStr), []byte(auth.Token)) == 1
}
return false
}