lolly/internal/server/middleware_builder.go
xfy dff88449d5 feat(lua): skip routed scripts in middleware builder
Scripts with Route config are handled by LocationEngine, so skip them
in buildLuaMiddlewares to avoid duplicate processing.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-09 11:38:40 +08:00

252 lines
7.7 KiB
Go

package server
import (
"fmt"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/middleware/accesslog"
"rua.plus/lolly/internal/middleware/bodylimit"
"rua.plus/lolly/internal/middleware/compression"
"rua.plus/lolly/internal/middleware/errorintercept"
"rua.plus/lolly/internal/middleware/rewrite"
"rua.plus/lolly/internal/middleware/security"
)
// buildMiddlewareChain assembles the middleware chain for a single server.
//
// Middlewares are appended in the fixed order below. AccessLog and BodyLimit
// are always present; every other entry is added only when its configuration
// is set:
//
//	AccessLog -> AccessControl -> RateLimiter -> ConnLimiter -> BasicAuth ->
//	AuthRequest -> BodyLimit -> Rewrite -> Compression -> SecurityHeaders ->
//	ErrorIntercept -> Lua
//
// Parameters:
//   - serverCfg: configuration of one server
//
// Returns:
//   - *middleware.Chain: the assembled chain
//   - error: the wrapped error of the first middleware that failed to build
//
// Side effects:
//   - s.accessLogMiddleware is always replaced with a freshly created instance
//   - s.accessControl is set when allow/deny rules are configured
func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) {
	// 1. AccessLog: built from the global logging configuration, always first.
	s.accessLogMiddleware = accesslog.New(&s.config.Logging)
	mws := []middleware.Middleware{s.accessLogMiddleware}

	// 2. Security: AccessControl (IP allow/deny rules).
	accessCfg := &serverCfg.Security.Access
	if len(accessCfg.Allow) > 0 || len(accessCfg.Deny) > 0 {
		acl, err := security.NewAccessControl(accessCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to create access control middleware: %w", err)
		}
		// Keep a handle on the server in addition to placing it in the chain.
		s.accessControl = acl
		mws = append(mws, acl)
	}

	// 3. Security: RateLimiter (request rate limiting).
	rateCfg := &serverCfg.Security.RateLimit
	if rateCfg.RequestRate > 0 {
		limiter, err := security.NewRateLimiter(rateCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to create rate limiter middleware: %w", err)
		}
		mws = append(mws, limiter)
	}

	// 3.5 Security: ConnLimiter (concurrent connection limiting).
	if rateCfg.ConnLimit > 0 {
		connLimiter, err := security.NewConnLimiter(rateCfg.ConnLimit, true, rateCfg.Key)
		if err != nil {
			return nil, fmt.Errorf("failed to create connection limiter middleware: %w", err)
		}
		mws = append(mws, connLimiter.Middleware())
	}

	// 4. Security: BasicAuth (enabled once at least one user is configured).
	authCfg := &serverCfg.Security.Auth
	if len(authCfg.Users) > 0 {
		basicAuth, err := security.NewBasicAuth(authCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to create auth middleware: %w", err)
		}
		mws = append(mws, basicAuth)
	}

	// 4.3 Security: AuthRequest (external authentication sub-request).
	// Requires both the enabled flag and a non-empty URI.
	if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" {
		subAuth, err := security.NewAuthRequest(serverCfg.Security.AuthRequest)
		if err != nil {
			return nil, fmt.Errorf("failed to create auth request middleware: %w", err)
		}
		mws = append(mws, subAuth)
	}

	// 4.5 BodyLimit (request body size limit).
	// Start from the package default and switch to the server-level size when
	// one is configured.
	limit := bodylimit.NewWithDefault()
	if size := serverCfg.ClientMaxBodySize; size != "" {
		custom, err := bodylimit.New(size)
		if err != nil {
			return nil, fmt.Errorf("failed to create body limit middleware: %w", err)
		}
		limit = custom
	}
	// Register per-path overrides declared on proxy entries.
	for i := range serverCfg.Proxy {
		size := serverCfg.Proxy[i].ClientMaxBodySize
		if size == "" {
			continue
		}
		if err := limit.AddPathLimit(serverCfg.Proxy[i].Path, size); err != nil {
			return nil, fmt.Errorf("failed to add path body limit: %w", err)
		}
	}
	mws = append(mws, limit)

	// 5. Rewrite (URL rewriting).
	if len(serverCfg.Rewrite) > 0 {
		rewriter, err := rewrite.New(serverCfg.Rewrite)
		if err != nil {
			return nil, fmt.Errorf("failed to create rewrite middleware: %w", err)
		}
		mws = append(mws, rewriter)
	}

	// 6. Compression (response compression).
	if serverCfg.Compression.Type != "" {
		compressor, err := compression.New(&serverCfg.Compression)
		if err != nil {
			return nil, fmt.Errorf("failed to create compression middleware: %w", err)
		}
		mws = append(mws, compressor)
	}

	// 7. SecurityHeaders: enabled as soon as any of the header fields below is
	// set. HSTS settings are handed to the constructor but are not part of the
	// trigger condition.
	hdrCfg := &serverCfg.Security.Headers
	anyHeaderSet := hdrCfg.XFrameOptions != "" ||
		hdrCfg.XContentTypeOptions != "" ||
		hdrCfg.ContentSecurityPolicy != "" ||
		hdrCfg.ReferrerPolicy != "" ||
		hdrCfg.PermissionsPolicy != ""
	if anyHeaderSet {
		mws = append(mws, security.NewHeadersWithHSTS(hdrCfg, &serverCfg.SSL.HSTS))
	}

	// 8. ErrorIntercept: only when an error page manager exists and reports
	// itself as configured.
	if s.errorPageManager != nil && s.errorPageManager.IsConfigured() {
		mws = append(mws, errorintercept.New(s.errorPageManager))
	}

	// 9. Lua middlewares (optional): need a Lua engine and an enabled Lua
	// section in the server configuration.
	if luaCfg := serverCfg.Lua; s.luaEngine != nil && luaCfg != nil && luaCfg.Enabled {
		luaMws, err := s.buildLuaMiddlewares(luaCfg)
		if err != nil {
			return nil, fmt.Errorf("failed to create Lua middleware: %w", err)
		}
		mws = append(mws, luaMws...)
	}

	return middleware.NewChain(mws...), nil
}
// buildLuaMiddlewares creates the Lua middlewares described by luaCfg.
//
// Scripts are grouped by their Phase string. A phase with exactly one script
// yields a LuaMiddleware; a phase with several scripts yields one
// MultiPhaseLuaMiddleware to which every script is added. Scripts that carry
// a Route are skipped here because they are handled by the LocationEngine.
//
// The returned middlewares follow the order in which each phase first appears
// in luaCfg.Scripts, so the resulting chain is identical on every start-up.
// (Ranging over the grouping map directly would make the order random.)
//
// A script is treated as enabled when Enabled is true, or when Timeout is zero
// and Path is non-empty (the "zero value means enabled" default).
// NOTE(review): because Enabled is a plain bool, a script explicitly set to
// Enabled=false without a Timeout is still enabled by this rule - confirm this
// is intended.
//
// Parameters:
//   - luaCfg: Lua configuration; nil yields no middlewares
//
// Returns:
//   - []middleware.Middleware: created middlewares (nil when there are none)
//   - error: an invalid phase name or a middleware construction failure
func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) {
	if s.luaEngine == nil || luaCfg == nil {
		return nil, nil
	}

	// Timeout applied to scripts that do not configure one.
	const defaultTimeout = 30 * time.Second

	// Group enabled scripts by phase while remembering the order in which
	// phases first appear, so the output order does not depend on map
	// iteration.
	phaseScripts := make(map[string][]config.LuaScriptConfig)
	var phaseOrder []string
	for _, script := range luaCfg.Scripts {
		// Routed scripts are handled by the LocationEngine; skip them.
		if script.Route != "" {
			continue
		}
		// Enabled explicitly, or by default when only a path was given.
		enabled := script.Enabled || (script.Timeout == 0 && script.Path != "")
		if !enabled {
			continue
		}
		if _, seen := phaseScripts[script.Phase]; !seen {
			phaseOrder = append(phaseOrder, script.Phase)
		}
		phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script)
	}

	var middlewares []middleware.Middleware
	for _, phase := range phaseOrder {
		scripts := phaseScripts[phase]

		// The phase name is shared by every script of the group; parse it once.
		luaPhase, err := lua.ParsePhase(phase)
		if err != nil {
			return nil, fmt.Errorf("invalid phase '%s': %w", phase, err)
		}

		// Single script: create a plain LuaMiddleware.
		if len(scripts) == 1 {
			script := scripts[0]
			timeout := script.Timeout
			if timeout == 0 {
				timeout = defaultTimeout
			}
			cfg := lua.LuaMiddlewareConfig{
				ScriptPath: script.Path,
				Phase:      luaPhase,
				Timeout:    timeout,
				Name:       fmt.Sprintf("lua-%s", phase),
			}
			mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg)
			if err != nil {
				return nil, fmt.Errorf("failed to create Lua middleware (phase=%s): %w", phase, err)
			}
			middlewares = append(middlewares, mw)
			continue
		}

		// Several scripts: register all of them on one MultiPhaseLuaMiddleware.
		// NOTE(review): every AddPhase call uses the same phase value; confirm
		// AddPhase accumulates scripts per phase rather than replacing them.
		multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase))
		for _, script := range scripts {
			timeout := script.Timeout
			if timeout == 0 {
				timeout = defaultTimeout
			}
			if err := multi.AddPhase(luaPhase, script.Path, timeout); err != nil {
				return nil, fmt.Errorf("failed to add Lua phase (phase=%s): %w", phase, err)
			}
		}
		middlewares = append(middlewares, multi)
	}
	return middlewares, nil
}