diff --git a/internal/middleware/security/access.go b/internal/middleware/security/access.go index 4dd3a37..d193990 100644 --- a/internal/middleware/security/access.go +++ b/internal/middleware/security/access.go @@ -43,6 +43,9 @@ const ( ActionAllow Action = iota // ActionDeny 拒绝请求(返回 403 Forbidden) ActionDeny + + accessAllow = "allow" + accessDeny = "deny" ) // AccessControl 实现 IP 访问控制中间件。 @@ -112,9 +115,9 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) { // 设置默认操作 switch strings.ToLower(cfg.Default) { - case "allow", "": + case accessAllow, "": ac.defaultAction = ActionAllow - case "deny": + case accessDeny: ac.defaultAction = ActionDeny default: return nil, fmt.Errorf("invalid default action: %s", cfg.Default) @@ -259,9 +262,9 @@ func (ac *AccessControl) SetDefault(action string) error { defer ac.mu.Unlock() switch strings.ToLower(action) { - case "allow": + case accessAllow: ac.defaultAction = ActionAllow - case "deny": + case accessDeny: ac.defaultAction = ActionDeny default: return fmt.Errorf("invalid action: %s", action) @@ -436,7 +439,7 @@ func actionToString(action Action) string { case ActionAllow: return "allow" case ActionDeny: - return "deny" + return accessDeny default: return "unknown" } diff --git a/internal/middleware/security/auth_request.go b/internal/middleware/security/auth_request.go index 9f059fb..d6d4bc8 100644 --- a/internal/middleware/security/auth_request.go +++ b/internal/middleware/security/auth_request.go @@ -313,8 +313,8 @@ func (a *AuthRequest) expandVars(ctx *fasthttp.RequestCtx, template string) stri } // 创建变量上下文 - vc := variable.NewVariableContext(ctx) - defer variable.ReleaseVariableContext(vc) + vc := variable.NewContext(ctx) + defer variable.ReleaseContext(vc) return vc.Expand(template) } diff --git a/internal/middleware/security/headers.go b/internal/middleware/security/headers.go index 5791f07..679d81c 100644 --- a/internal/middleware/security/headers.go +++ b/internal/middleware/security/headers.go @@ -35,7 +35,7 @@ import ( "rua.plus/lolly/internal/middleware" ) -// SecurityHeadersMiddleware 安全响应头中间件。 +// HeadersMiddleware 安全响应头中间件。 // // 为 HTTP 响应添加安全相关的头部字段,防止常见的 Web 安全漏洞。 // 支持配置各种安全头的值,并提供安全的默认配置。 @@ -43,13 +43,13 @@ import ( // 注意事项: // - 所有方法均为并发安全 // - HSTS 头仅在 TLS 连接时添加 -type SecurityHeadersMiddleware struct { +type HeadersMiddleware struct { config *config.SecurityHeaders // 安全头配置 hsts string // 预格式化的 HSTS 头值 mu sync.RWMutex // 读写锁,保护并发访问 } -// NewSecurityHeaders 创建新的安全响应头中间件。 +// NewHeaders 创建新的安全响应头中间件。 // // 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。 // @@ -57,12 +57,12 @@ type SecurityHeadersMiddleware struct { // - cfg: 安全头配置,可以为 nil 使用默认配置 // // 返回值: -// - *SecurityHeadersMiddleware: 配置好的中间件实例 -func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware { - return NewSecurityHeadersWithHSTS(cfg, nil) +// - *HeadersMiddleware: 配置好的中间件实例 +func NewHeaders(cfg *config.SecurityHeaders) *HeadersMiddleware { + return NewHeadersWithHSTS(cfg, nil) } -// NewSecurityHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。 +// NewHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。 // // 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。 // @@ -71,9 +71,9 @@ func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware // - hstsCfg: HSTS 配置,可以为 nil 使用默认值 // // 返回值: -// - *SecurityHeadersMiddleware: 配置好的中间件实例 -func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *SecurityHeadersMiddleware { - sh := &SecurityHeadersMiddleware{} +// - *HeadersMiddleware: 配置好的中间件实例 +func NewHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *HeadersMiddleware { + sh := &HeadersMiddleware{} if cfg != nil { sh.config = cfg @@ -93,7 +93,7 @@ func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HST } // formatHSTSFromConfig 根据配置格式化 HSTS 头值。 -func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) { +func (sh *HeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) { if hstsCfg != nil { maxAge := hstsCfg.MaxAge if maxAge <= 0 { @@ -109,7 +109,7 @@ func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSCo // // 返回值: // - string: 中间件标识名 "security_headers" -func (sh *SecurityHeadersMiddleware) Name() string { +func (sh *HeadersMiddleware) Name() string { return "security_headers" } @@ -122,7 +122,7 @@ func (sh *SecurityHeadersMiddleware) Name() string { // // 返回值: // - fasthttp.RequestHandler: 包装后的处理器 -func (sh *SecurityHeadersMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { +func (sh *HeadersMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { // 先调用下一个处理器 next(ctx) @@ -138,7 +138,7 @@ func (sh *SecurityHeadersMiddleware) Process(next fasthttp.RequestHandler) fasth // // 参数: // - ctx: FastHTTP 请求上下文 -func (sh *SecurityHeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) { +func (sh *HeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) { headers := &ctx.Response.Header sh.mu.RLock() @@ -183,7 +183,7 @@ func (sh *SecurityHeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) { // // HSTS(HTTP Strict Transport Security)用于强制浏览器使用 HTTPS 连接。 // 默认配置为 1 年有效期,包含子域名。 -func (sh *SecurityHeadersMiddleware) formatHSTS() { +func (sh *HeadersMiddleware) formatHSTS() { // 默认 HSTS 值 maxAge := 31536000 // 1 年有效期(秒) includeSubDomains := true // 包含所有子域名 @@ -223,7 +223,7 @@ func formatHSTSValue(maxAge int, includeSubDomains bool, preload bool) string { // // 参数: // - cfg: 新的安全头配置 -func (sh *SecurityHeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) { +func (sh *HeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) { sh.mu.Lock() sh.config = cfg sh.formatHSTS() @@ -234,7 +234,7 @@ func (sh *SecurityHeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) { // // 参数: // - value: 新的 X-Frame-Options 值(如 "DENY"、"SAMEORIGIN") -func (sh *SecurityHeadersMiddleware) SetXFrameOptions(value string) { +func (sh *HeadersMiddleware) SetXFrameOptions(value string) { sh.mu.Lock() if sh.config != nil { sh.config.XFrameOptions = value @@ -246,7 +246,7 @@ func (sh *SecurityHeadersMiddleware) SetXFrameOptions(value string) { // // 参数: // - value: 新的 Content-Security-Policy 值 -func (sh *SecurityHeadersMiddleware) SetContentSecurityPolicy(value string) { +func (sh *HeadersMiddleware) SetContentSecurityPolicy(value string) { sh.mu.Lock() if sh.config != nil { sh.config.ContentSecurityPolicy = value @@ -258,7 +258,7 @@ func (sh *SecurityHeadersMiddleware) SetContentSecurityPolicy(value string) { // // 参数: // - value: 新的 Referrer-Policy 值(如 "no-referrer"、"strict-origin") -func (sh *SecurityHeadersMiddleware) SetReferrerPolicy(value string) { +func (sh *HeadersMiddleware) SetReferrerPolicy(value string) { sh.mu.Lock() if sh.config != nil { sh.config.ReferrerPolicy = value @@ -270,7 +270,7 @@ func (sh *SecurityHeadersMiddleware) SetReferrerPolicy(value string) { // // 参数: // - value: 新的 Permissions-Policy 值 -func (sh *SecurityHeadersMiddleware) SetPermissionsPolicy(value string) { +func (sh *HeadersMiddleware) SetPermissionsPolicy(value string) { sh.mu.Lock() if sh.config != nil { sh.config.PermissionsPolicy = value @@ -282,7 +282,7 @@ func (sh *SecurityHeadersMiddleware) SetPermissionsPolicy(value string) { // // 返回值: // - *config.SecurityHeaders: 当前配置的副本 -func (sh *SecurityHeadersMiddleware) GetConfig() *config.SecurityHeaders { +func (sh *HeadersMiddleware) GetConfig() *config.SecurityHeaders { sh.mu.RLock() defer sh.mu.RUnlock() return sh.config @@ -331,4 +331,4 @@ func DevelopmentSecurityHeaders() *config.SecurityHeaders { } // 验证接口实现 -var _ middleware.Middleware = (*SecurityHeadersMiddleware)(nil) +var _ middleware.Middleware = (*HeadersMiddleware)(nil) diff --git a/internal/middleware/security/headers_test.go b/internal/middleware/security/headers_test.go index 1c15d76..75a67bb 100644 --- a/internal/middleware/security/headers_test.go +++ b/internal/middleware/security/headers_test.go @@ -17,7 +17,7 @@ import ( "rua.plus/lolly/internal/config" ) -func TestNewSecurityHeaders(t *testing.T) { +func TestNewHeaders(t *testing.T) { tests := []struct { name string cfg *config.SecurityHeaders @@ -38,22 +38,22 @@ func TestNewSecurityHeaders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sh := NewSecurityHeaders(tt.cfg) + sh := NewHeaders(tt.cfg) if sh == nil { - t.Error("Expected non-nil SecurityHeadersMiddleware") + t.Error("Expected non-nil HeadersMiddleware") } }) } } -func TestSecurityHeadersName(t *testing.T) { - sh := NewSecurityHeaders(nil) +func TestHeadersName(t *testing.T) { + sh := NewHeaders(nil) if sh.Name() != "security_headers" { t.Errorf("Expected name 'security_headers', got %s", sh.Name()) } } -func TestSecurityHeadersProcess(t *testing.T) { +func TestHeadersProcess(t *testing.T) { cfg := &config.SecurityHeaders{ XFrameOptions: "DENY", XContentTypeOptions: "nosniff", @@ -62,7 +62,7 @@ func TestSecurityHeadersProcess(t *testing.T) { PermissionsPolicy: "geolocation=()", } - sh := NewSecurityHeaders(cfg) + sh := NewHeaders(cfg) handlerCalled := false nextHandler := func(ctx *fasthttp.RequestCtx) { @@ -107,15 +107,14 @@ func TestSecurityHeadersProcess(t *testing.T) { } } -func TestSecurityHeadersHSTS(t *testing.T) { +func TestHeadersHSTS(_ *testing.T) { cfg := &config.SecurityHeaders{ XFrameOptions: "DENY", } - sh := NewSecurityHeaders(cfg) + sh := NewHeaders(cfg) - nextHandler := func(ctx *fasthttp.RequestCtx) { - _, _ = ctx.WriteString("OK") + nextHandler := func(_ *fasthttp.RequestCtx) { } handler := sh.Process(nextHandler) @@ -131,8 +130,8 @@ func TestSecurityHeadersHSTS(t *testing.T) { // In this test we verify the handler doesn't panic } -func TestSecurityHeadersUpdate(t *testing.T) { - sh := NewSecurityHeaders(nil) +func TestHeadersUpdate(t *testing.T) { + sh := NewHeaders(nil) // Update X-Frame-Options sh.SetXFrameOptions("SAMEORIGIN") @@ -164,7 +163,7 @@ func TestSecurityHeadersUpdate(t *testing.T) { } func TestUpdateConfig(t *testing.T) { - sh := NewSecurityHeaders(nil) + sh := NewHeaders(nil) newCfg := &config.SecurityHeaders{ XFrameOptions: "DENY", diff --git a/internal/middleware/security/ratelimit.go b/internal/middleware/security/ratelimit.go index bcc4fd3..673e3e9 100644 --- a/internal/middleware/security/ratelimit.go +++ b/internal/middleware/security/ratelimit.go @@ -42,6 +42,8 @@ import ( "rua.plus/lolly/internal/netutil" ) +const rateLimitHeader = "header" + // RateLimiter 基于令牌桶算法的请求速率限制器。 // // 实现请求限流功能,支持按 IP 或自定义键值进行限流。 @@ -133,7 +135,7 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) { switch cfg.Key { case "ip", "": rl.keyFunc = keyByIP - case "header": + case rateLimitHeader: rl.keyFunc = keyByHeader default: return nil, fmt.Errorf("unknown key type: %s", cfg.Key) @@ -157,7 +159,7 @@ func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Dur switch cfg.Key { case "ip", "": keyFunc = keyByIP - case "header": + case rateLimitHeader: keyFunc = keyByHeader default: return nil, fmt.Errorf("unknown key type: %s", cfg.Key) @@ -485,13 +487,13 @@ type ConnLimiter struct { // 返回值: // - *ConnLimiter: 配置好的连接限制器 // - error: 配置无效时返回错误 -func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error) { - if max <= 0 { +func NewConnLimiter(maxConns int, perKey bool, keyType string) (*ConnLimiter, error) { + if maxConns <= 0 { return nil, errors.New("max connections must be positive") } cl := &ConnLimiter{ - max: max, + max: maxConns, perKey: perKey, counts: make(map[string]int64), } @@ -500,7 +502,7 @@ func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error) switch keyType { case "ip", "": cl.keyFunc = keyByIP - case "header": + case rateLimitHeader: cl.keyFunc = keyByHeader default: return nil, fmt.Errorf("unknown key type: %s", keyType) diff --git a/internal/middleware/security/sliding_window.go b/internal/middleware/security/sliding_window.go index 24d939f..660df16 100644 --- a/internal/middleware/security/sliding_window.go +++ b/internal/middleware/security/sliding_window.go @@ -206,7 +206,7 @@ func (s *SlidingWindowLimiter) Cleanup(maxAge time.Duration) { } } -// Stats 返回限流器统计信息。 +// SlidingWindowStats 返回限流器统计信息。 type SlidingWindowStats struct { Window time.Duration // 窗口大小 Limit int // 窗口内最大请求数