lolly/internal/server/middleware_builder.go
xfy 6c538a1a56 feat(server,proxy): integrate Request-ID into middleware chain and proxy forwarding
- Register requestid.New() as first middleware in buildMiddlewareChain
  (before AccessLog) so the ID is available for logging
- Add SetRequestIDHeader() in proxy/headers.go to propagate X-Request-ID
  to upstream via proxy forwarding
- Call SetRequestIDHeader in header_modifier.go after SetForwardedHeaders
- Import requestid package in middleware_builder and proxy/headers
2026-06-11 23:41:30 +08:00

276 lines
8.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package server 提供 HTTP 服务器的核心实现,支持单服务器、虚拟主机和多服务器三种运行模式。
//
// 包含中间件构建器相关的逻辑,用于构建和管理 HTTP 中间件链。
//
// 作者xfy
package server
import (
"fmt"
"time"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/lua"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/middleware/bodylimit"
"rua.plus/lolly/internal/middleware/compression"
"rua.plus/lolly/internal/middleware/cors"
"rua.plus/lolly/internal/middleware/errorintercept"
"rua.plus/lolly/internal/middleware/requestid"
"rua.plus/lolly/internal/middleware/rewrite"
"rua.plus/lolly/internal/middleware/security"
)
// buildMiddlewareChain builds the middleware chain for a single server.
//
// Middlewares are appended in the following order; optional ones are only
// added when their configuration is present:
//
//	RequestID -> AccessLog -> AccessControl -> RateLimiter -> ConnLimiter ->
//	BasicAuth -> AuthRequest -> BodyLimit -> Rewrite -> Compression ->
//	SecurityHeaders -> CORS -> ErrorIntercept -> Lua
//
// RequestID, AccessLog and BodyLimit are always added. RequestID is placed
// first so that later middlewares and access logging can use $request_id.
//
// Parameters:
//   - serverCfg: configuration of the server whose chain is being built.
//
// Returns:
//   - *middleware.Chain: the assembled chain (created with middleware.NewChain).
//   - error: a wrapped error when any middleware constructor fails; the
//     returned chain is nil in that case.
//
// Side effects: when IP access control is configured, the created access
// control is stored in s.accessControl and appended to s.accessControls while
// holding s.accessControlsMu. This registration happens before the remaining
// middlewares are built, so it is kept even if a later step returns an error.
func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middleware.Chain, error) {
	var middlewares []middleware.Middleware

	// 0. Request-ID: added first so later middlewares and logs can use $request_id.
	middlewares = append(middlewares, requestid.New())

	// 1. AccessLog (always added).
	middlewares = append(middlewares, s.accessLogMiddleware)

	// 2. Security: AccessControl (IP allow/deny lists).
	if len(serverCfg.Security.Access.Allow) > 0 || len(serverCfg.Security.Access.Deny) > 0 {
		ac, err := security.NewAccessControl(&serverCfg.Security.Access)
		if err != nil {
			return nil, fmt.Errorf("failed to create access control middleware: %w", err)
		}
		middlewares = append(middlewares, ac)

		// s.accessControl only keeps the most recently built instance, while
		// s.accessControls accumulates all of them across servers.
		s.accessControlsMu.Lock()
		s.accessControl = ac
		s.accessControls = append(s.accessControls, ac)
		s.accessControlsMu.Unlock()
	}

	// 3. Security: RateLimiter (request rate limiting).
	if serverCfg.Security.RateLimit.RequestRate > 0 {
		rl, err := security.NewRateLimiter(&serverCfg.Security.RateLimit)
		if err != nil {
			return nil, fmt.Errorf("failed to create rate limiter middleware: %w", err)
		}
		middlewares = append(middlewares, rl)
	}

	// 3.5 Security: ConnLimiter (concurrent connection limiting).
	if serverCfg.Security.RateLimit.ConnLimit > 0 {
		cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key)
		if err != nil {
			return nil, fmt.Errorf("failed to create connection limiter middleware: %w", err)
		}
		middlewares = append(middlewares, cl.Middleware())
	}

	// 4. Security: BasicAuth (authentication).
	if len(serverCfg.Security.Auth.Users) > 0 {
		auth, err := security.NewBasicAuth(&serverCfg.Security.Auth)
		if err != nil {
			return nil, fmt.Errorf("failed to create auth middleware: %w", err)
		}
		middlewares = append(middlewares, auth)
	}

	// 4.3 Security: AuthRequest (external authentication sub-request).
	if serverCfg.Security.AuthRequest.Enabled && serverCfg.Security.AuthRequest.URI != "" {
		authReq, err := security.NewAuthRequest(serverCfg.Security.AuthRequest)
		if err != nil {
			return nil, fmt.Errorf("failed to create auth request middleware: %w", err)
		}
		middlewares = append(middlewares, authReq)
	}

	// 4.5 BodyLimit (request body size limit). Always added: uses the
	// server-level ClientMaxBodySize when set, otherwise the package default.
	bodyLimitMiddleware := bodylimit.NewWithDefault()
	if serverCfg.ClientMaxBodySize != "" {
		bl, err := bodylimit.New(serverCfg.ClientMaxBodySize)
		if err != nil {
			return nil, fmt.Errorf("failed to create body limit middleware: %w", err)
		}
		bodyLimitMiddleware = bl
	}
	// Per-path overrides taken from proxy locations.
	for i := range serverCfg.Proxy {
		if serverCfg.Proxy[i].ClientMaxBodySize != "" {
			if err := bodyLimitMiddleware.AddPathLimit(
				serverCfg.Proxy[i].Path,
				serverCfg.Proxy[i].ClientMaxBodySize,
			); err != nil {
				return nil, fmt.Errorf("failed to add path body limit: %w", err)
			}
		}
	}
	middlewares = append(middlewares, bodyLimitMiddleware)

	// 5. Rewrite (URL rewriting).
	if len(serverCfg.Rewrite) > 0 {
		rw, err := rewrite.New(serverCfg.Rewrite)
		if err != nil {
			return nil, fmt.Errorf("failed to create rewrite middleware: %w", err)
		}
		middlewares = append(middlewares, rw)
	}

	// 6. Compression (response compression).
	if serverCfg.Compression.Type != "" {
		comp, err := compression.New(&serverCfg.Compression)
		if err != nil {
			return nil, fmt.Errorf("failed to create compression middleware: %w", err)
		}
		middlewares = append(middlewares, comp)
	}

	// 7. SecurityHeaders: enabled when any of the listed header fields is set.
	// NOTE(review): HSTS settings (serverCfg.SSL.HSTS) are passed in but do not
	// by themselves enable this middleware, so an HSTS-only configuration adds
	// no headers — confirm this is intended.
	if serverCfg.Security.Headers.XFrameOptions != "" ||
		serverCfg.Security.Headers.XContentTypeOptions != "" ||
		serverCfg.Security.Headers.ContentSecurityPolicy != "" ||
		serverCfg.Security.Headers.ReferrerPolicy != "" ||
		serverCfg.Security.Headers.PermissionsPolicy != "" {
		headers := security.NewHeadersWithHSTS(&serverCfg.Security.Headers, &serverCfg.SSL.HSTS)
		middlewares = append(middlewares, headers)
	}

	// 7.5 CORS (cross-origin resource sharing).
	if serverCfg.Security.CORS.Enabled {
		middlewares = append(middlewares, cors.New(&cors.CORSConfig{
			Enabled:          serverCfg.Security.CORS.Enabled,
			AllowedOrigins:   serverCfg.Security.CORS.AllowedOrigins,
			AllowedMethods:   serverCfg.Security.CORS.AllowedMethods,
			AllowedHeaders:   serverCfg.Security.CORS.AllowedHeaders,
			ExposeHeaders:    serverCfg.Security.CORS.ExposeHeaders,
			AllowCredentials: serverCfg.Security.CORS.AllowCredentials,
			MaxAge:           serverCfg.Security.CORS.MaxAge,
		}))
	}

	// 8. ErrorIntercept: only when custom error pages are configured.
	if s.errorPageManager != nil && s.errorPageManager.IsConfigured() {
		ei := errorintercept.New(s.errorPageManager)
		middlewares = append(middlewares, ei)
	}

	// 9. Lua middlewares (optional): require a Lua engine and an enabled Lua config.
	if s.luaEngine != nil && serverCfg.Lua != nil && serverCfg.Lua.Enabled {
		luaMiddlewares, err := s.buildLuaMiddlewares(serverCfg.Lua)
		if err != nil {
			return nil, fmt.Errorf("failed to create Lua middleware: %w", err)
		}
		middlewares = append(middlewares, luaMiddlewares...)
	}

	return middleware.NewChain(middlewares...), nil
}
// buildLuaMiddlewares creates Lua middlewares from a Lua configuration.
//
// Scripts with a non-empty Route are skipped because route scripts are
// handled by the LocationEngine. The remaining enabled scripts are grouped
// by Phase:
//   - a phase with exactly one script yields a lua.LuaMiddleware;
//   - a phase with several scripts yields a single lua.MultiPhaseLuaMiddleware
//     holding all of them, in configuration order.
//
// Phases are emitted in the order in which they first appear in
// luaCfg.Scripts. This makes the resulting middleware order deterministic;
// it does not depend on Go's randomized map iteration order.
//
// Scripts without a timeout use a default of 30 seconds.
//
// Parameters:
//   - luaCfg: the Lua middleware configuration; nil yields no middlewares.
//
// Returns:
//   - []middleware.Middleware: the created middlewares. It is nil when none
//     were created, when s.luaEngine is nil, or when luaCfg is nil.
//   - error: a wrapped error for an invalid phase name or a failed middleware
//     construction.
func (s *Server) buildLuaMiddlewares(luaCfg *config.LuaMiddlewareConfig) ([]middleware.Middleware, error) {
	if s.luaEngine == nil || luaCfg == nil {
		return nil, nil
	}

	// defaultTimeout applies to scripts whose Timeout is zero.
	const defaultTimeout = 30 * time.Second

	// Group enabled scripts by phase. phaseOrder remembers each phase the
	// first time it is seen so iteration below is deterministic.
	phaseScripts := make(map[string][]config.LuaScriptConfig)
	var phaseOrder []string
	for _, script := range luaCfg.Scripts {
		// Route scripts are handled by the LocationEngine.
		if script.Route != "" {
			continue
		}

		// A zero-value entry that only sets Path is enabled by default.
		// NOTE(review): this also enables a script explicitly configured as
		// `enabled: false` when it has no timeout, because a bool cannot tell
		// "unset" from "false" — consider a *bool in the config.
		enabled := script.Enabled
		if !enabled && script.Timeout == 0 && script.Path != "" {
			enabled = true
		}
		if !enabled {
			continue
		}

		if _, seen := phaseScripts[script.Phase]; !seen {
			phaseOrder = append(phaseOrder, script.Phase)
		}
		phaseScripts[script.Phase] = append(phaseScripts[script.Phase], script)
	}

	var middlewares []middleware.Middleware
	for _, phase := range phaseOrder {
		scripts := phaseScripts[phase]

		// Parse the phase once per group; it is identical for every script.
		luaPhase, err := lua.ParsePhase(phase)
		if err != nil {
			return nil, fmt.Errorf("invalid phase '%s': %w", phase, err)
		}

		// Single script: a plain LuaMiddleware.
		if len(scripts) == 1 {
			script := scripts[0]
			timeout := script.Timeout
			if timeout == 0 {
				timeout = defaultTimeout
			}
			cfg := lua.LuaMiddlewareConfig{
				ScriptPath: script.Path,
				Phase:      luaPhase,
				Timeout:    timeout,
				Name:       fmt.Sprintf("lua-%s", phase),
			}
			mw, err := lua.NewLuaMiddleware(s.luaEngine, cfg)
			if err != nil {
				return nil, fmt.Errorf("failed to create Lua middleware (phase=%s): %w", phase, err)
			}
			middlewares = append(middlewares, mw)
			continue
		}

		// Several scripts: one MultiPhaseLuaMiddleware holding all of them.
		multi := lua.NewMultiPhaseLuaMiddleware(s.luaEngine, fmt.Sprintf("lua-multi-%s", phase))
		for _, script := range scripts {
			timeout := script.Timeout
			if timeout == 0 {
				timeout = defaultTimeout
			}
			if err := multi.AddPhase(luaPhase, script.Path, timeout); err != nil {
				return nil, fmt.Errorf("failed to add Lua phase (phase=%s): %w", phase, err)
			}
		}
		middlewares = append(middlewares, multi)
	}

	return middlewares, nil
}