refactor(security): 重命名 HeadersMiddleware 移除冗余前缀
SecurityHeadersMiddleware → HeadersMiddleware NewSecurityHeaders → NewHeaders NewSecurityHeadersWithHSTS → NewHeadersWithHSTS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3e153f5fe1
commit
649a6ed23f
@ -43,6 +43,9 @@ const (
|
||||
ActionAllow Action = iota
|
||||
// ActionDeny 拒绝请求(返回 403 Forbidden)
|
||||
ActionDeny
|
||||
|
||||
accessAllow = "allow"
|
||||
accessDeny = "deny"
|
||||
)
|
||||
|
||||
// AccessControl 实现 IP 访问控制中间件。
|
||||
@ -112,9 +115,9 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
|
||||
|
||||
// 设置默认操作
|
||||
switch strings.ToLower(cfg.Default) {
|
||||
case "allow", "":
|
||||
case accessAllow, "":
|
||||
ac.defaultAction = ActionAllow
|
||||
case "deny":
|
||||
case accessDeny:
|
||||
ac.defaultAction = ActionDeny
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid default action: %s", cfg.Default)
|
||||
@ -259,9 +262,9 @@ func (ac *AccessControl) SetDefault(action string) error {
|
||||
defer ac.mu.Unlock()
|
||||
|
||||
switch strings.ToLower(action) {
|
||||
case "allow":
|
||||
case accessAllow:
|
||||
ac.defaultAction = ActionAllow
|
||||
case "deny":
|
||||
case accessDeny:
|
||||
ac.defaultAction = ActionDeny
|
||||
default:
|
||||
return fmt.Errorf("invalid action: %s", action)
|
||||
@ -436,7 +439,7 @@ func actionToString(action Action) string {
|
||||
case ActionAllow:
|
||||
return "allow"
|
||||
case ActionDeny:
|
||||
return "deny"
|
||||
return accessDeny
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
@ -313,8 +313,8 @@ func (a *AuthRequest) expandVars(ctx *fasthttp.RequestCtx, template string) stri
|
||||
}
|
||||
|
||||
// 创建变量上下文
|
||||
vc := variable.NewVariableContext(ctx)
|
||||
defer variable.ReleaseVariableContext(vc)
|
||||
vc := variable.NewContext(ctx)
|
||||
defer variable.ReleaseContext(vc)
|
||||
|
||||
return vc.Expand(template)
|
||||
}
|
||||
|
||||
@ -35,7 +35,7 @@ import (
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
)
|
||||
|
||||
// SecurityHeadersMiddleware 安全响应头中间件。
|
||||
// HeadersMiddleware 安全响应头中间件。
|
||||
//
|
||||
// 为 HTTP 响应添加安全相关的头部字段,防止常见的 Web 安全漏洞。
|
||||
// 支持配置各种安全头的值,并提供安全的默认配置。
|
||||
@ -43,13 +43,13 @@ import (
|
||||
// 注意事项:
|
||||
// - 所有方法均为并发安全
|
||||
// - HSTS 头仅在 TLS 连接时添加
|
||||
type SecurityHeadersMiddleware struct {
|
||||
type HeadersMiddleware struct {
|
||||
config *config.SecurityHeaders // 安全头配置
|
||||
hsts string // 预格式化的 HSTS 头值
|
||||
mu sync.RWMutex // 读写锁,保护并发访问
|
||||
}
|
||||
|
||||
// NewSecurityHeaders 创建新的安全响应头中间件。
|
||||
// NewHeaders 创建新的安全响应头中间件。
|
||||
//
|
||||
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
||||
//
|
||||
@ -57,12 +57,12 @@ type SecurityHeadersMiddleware struct {
|
||||
// - cfg: 安全头配置,可以为 nil 使用默认配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
||||
func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware {
|
||||
return NewSecurityHeadersWithHSTS(cfg, nil)
|
||||
// - *HeadersMiddleware: 配置好的中间件实例
|
||||
func NewHeaders(cfg *config.SecurityHeaders) *HeadersMiddleware {
|
||||
return NewHeadersWithHSTS(cfg, nil)
|
||||
}
|
||||
|
||||
// NewSecurityHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。
|
||||
// NewHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。
|
||||
//
|
||||
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
||||
//
|
||||
@ -71,9 +71,9 @@ func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware
|
||||
// - hstsCfg: HSTS 配置,可以为 nil 使用默认值
|
||||
//
|
||||
// 返回值:
|
||||
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
||||
func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *SecurityHeadersMiddleware {
|
||||
sh := &SecurityHeadersMiddleware{}
|
||||
// - *HeadersMiddleware: 配置好的中间件实例
|
||||
func NewHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *HeadersMiddleware {
|
||||
sh := &HeadersMiddleware{}
|
||||
|
||||
if cfg != nil {
|
||||
sh.config = cfg
|
||||
@ -93,7 +93,7 @@ func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HST
|
||||
}
|
||||
|
||||
// formatHSTSFromConfig 根据配置格式化 HSTS 头值。
|
||||
func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) {
|
||||
func (sh *HeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) {
|
||||
if hstsCfg != nil {
|
||||
maxAge := hstsCfg.MaxAge
|
||||
if maxAge <= 0 {
|
||||
@ -109,7 +109,7 @@ func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSCo
|
||||
//
|
||||
// 返回值:
|
||||
// - string: 中间件标识名 "security_headers"
|
||||
func (sh *SecurityHeadersMiddleware) Name() string {
|
||||
func (sh *HeadersMiddleware) Name() string {
|
||||
return "security_headers"
|
||||
}
|
||||
|
||||
@ -122,7 +122,7 @@ func (sh *SecurityHeadersMiddleware) Name() string {
|
||||
//
|
||||
// 返回值:
|
||||
// - fasthttp.RequestHandler: 包装后的处理器
|
||||
func (sh *SecurityHeadersMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
func (sh *HeadersMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
// 先调用下一个处理器
|
||||
next(ctx)
|
||||
@ -138,7 +138,7 @@ func (sh *SecurityHeadersMiddleware) Process(next fasthttp.RequestHandler) fasth
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: FastHTTP 请求上下文
|
||||
func (sh *SecurityHeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) {
|
||||
func (sh *HeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) {
|
||||
headers := &ctx.Response.Header
|
||||
|
||||
sh.mu.RLock()
|
||||
@ -183,7 +183,7 @@ func (sh *SecurityHeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) {
|
||||
//
|
||||
// HSTS(HTTP Strict Transport Security)用于强制浏览器使用 HTTPS 连接。
|
||||
// 默认配置为 1 年有效期,包含子域名。
|
||||
func (sh *SecurityHeadersMiddleware) formatHSTS() {
|
||||
func (sh *HeadersMiddleware) formatHSTS() {
|
||||
// 默认 HSTS 值
|
||||
maxAge := 31536000 // 1 年有效期(秒)
|
||||
includeSubDomains := true // 包含所有子域名
|
||||
@ -223,7 +223,7 @@ func formatHSTSValue(maxAge int, includeSubDomains bool, preload bool) string {
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 新的安全头配置
|
||||
func (sh *SecurityHeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) {
|
||||
func (sh *HeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) {
|
||||
sh.mu.Lock()
|
||||
sh.config = cfg
|
||||
sh.formatHSTS()
|
||||
@ -234,7 +234,7 @@ func (sh *SecurityHeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) {
|
||||
//
|
||||
// 参数:
|
||||
// - value: 新的 X-Frame-Options 值(如 "DENY"、"SAMEORIGIN")
|
||||
func (sh *SecurityHeadersMiddleware) SetXFrameOptions(value string) {
|
||||
func (sh *HeadersMiddleware) SetXFrameOptions(value string) {
|
||||
sh.mu.Lock()
|
||||
if sh.config != nil {
|
||||
sh.config.XFrameOptions = value
|
||||
@ -246,7 +246,7 @@ func (sh *SecurityHeadersMiddleware) SetXFrameOptions(value string) {
|
||||
//
|
||||
// 参数:
|
||||
// - value: 新的 Content-Security-Policy 值
|
||||
func (sh *SecurityHeadersMiddleware) SetContentSecurityPolicy(value string) {
|
||||
func (sh *HeadersMiddleware) SetContentSecurityPolicy(value string) {
|
||||
sh.mu.Lock()
|
||||
if sh.config != nil {
|
||||
sh.config.ContentSecurityPolicy = value
|
||||
@ -258,7 +258,7 @@ func (sh *SecurityHeadersMiddleware) SetContentSecurityPolicy(value string) {
|
||||
//
|
||||
// 参数:
|
||||
// - value: 新的 Referrer-Policy 值(如 "no-referrer"、"strict-origin")
|
||||
func (sh *SecurityHeadersMiddleware) SetReferrerPolicy(value string) {
|
||||
func (sh *HeadersMiddleware) SetReferrerPolicy(value string) {
|
||||
sh.mu.Lock()
|
||||
if sh.config != nil {
|
||||
sh.config.ReferrerPolicy = value
|
||||
@ -270,7 +270,7 @@ func (sh *SecurityHeadersMiddleware) SetReferrerPolicy(value string) {
|
||||
//
|
||||
// 参数:
|
||||
// - value: 新的 Permissions-Policy 值
|
||||
func (sh *SecurityHeadersMiddleware) SetPermissionsPolicy(value string) {
|
||||
func (sh *HeadersMiddleware) SetPermissionsPolicy(value string) {
|
||||
sh.mu.Lock()
|
||||
if sh.config != nil {
|
||||
sh.config.PermissionsPolicy = value
|
||||
@ -282,7 +282,7 @@ func (sh *SecurityHeadersMiddleware) SetPermissionsPolicy(value string) {
|
||||
//
|
||||
// 返回值:
|
||||
// - *config.SecurityHeaders: 当前配置的副本
|
||||
func (sh *SecurityHeadersMiddleware) GetConfig() *config.SecurityHeaders {
|
||||
func (sh *HeadersMiddleware) GetConfig() *config.SecurityHeaders {
|
||||
sh.mu.RLock()
|
||||
defer sh.mu.RUnlock()
|
||||
return sh.config
|
||||
@ -331,4 +331,4 @@ func DevelopmentSecurityHeaders() *config.SecurityHeaders {
|
||||
}
|
||||
|
||||
// 验证接口实现
|
||||
var _ middleware.Middleware = (*SecurityHeadersMiddleware)(nil)
|
||||
var _ middleware.Middleware = (*HeadersMiddleware)(nil)
|
||||
|
||||
@ -17,7 +17,7 @@ import (
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
func TestNewSecurityHeaders(t *testing.T) {
|
||||
func TestNewHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.SecurityHeaders
|
||||
@ -38,22 +38,22 @@ func TestNewSecurityHeaders(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sh := NewSecurityHeaders(tt.cfg)
|
||||
sh := NewHeaders(tt.cfg)
|
||||
if sh == nil {
|
||||
t.Error("Expected non-nil SecurityHeadersMiddleware")
|
||||
t.Error("Expected non-nil HeadersMiddleware")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersName(t *testing.T) {
|
||||
sh := NewSecurityHeaders(nil)
|
||||
func TestHeadersName(t *testing.T) {
|
||||
sh := NewHeaders(nil)
|
||||
if sh.Name() != "security_headers" {
|
||||
t.Errorf("Expected name 'security_headers', got %s", sh.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersProcess(t *testing.T) {
|
||||
func TestHeadersProcess(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
XContentTypeOptions: "nosniff",
|
||||
@ -62,7 +62,7 @@ func TestSecurityHeadersProcess(t *testing.T) {
|
||||
PermissionsPolicy: "geolocation=()",
|
||||
}
|
||||
|
||||
sh := NewSecurityHeaders(cfg)
|
||||
sh := NewHeaders(cfg)
|
||||
|
||||
handlerCalled := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
@ -107,15 +107,14 @@ func TestSecurityHeadersProcess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersHSTS(t *testing.T) {
|
||||
func TestHeadersHSTS(_ *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
}
|
||||
|
||||
sh := NewSecurityHeaders(cfg)
|
||||
sh := NewHeaders(cfg)
|
||||
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
_, _ = ctx.WriteString("OK")
|
||||
nextHandler := func(_ *fasthttp.RequestCtx) {
|
||||
}
|
||||
|
||||
handler := sh.Process(nextHandler)
|
||||
@ -131,8 +130,8 @@ func TestSecurityHeadersHSTS(t *testing.T) {
|
||||
// In this test we verify the handler doesn't panic
|
||||
}
|
||||
|
||||
func TestSecurityHeadersUpdate(t *testing.T) {
|
||||
sh := NewSecurityHeaders(nil)
|
||||
func TestHeadersUpdate(t *testing.T) {
|
||||
sh := NewHeaders(nil)
|
||||
|
||||
// Update X-Frame-Options
|
||||
sh.SetXFrameOptions("SAMEORIGIN")
|
||||
@ -164,7 +163,7 @@ func TestSecurityHeadersUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateConfig(t *testing.T) {
|
||||
sh := NewSecurityHeaders(nil)
|
||||
sh := NewHeaders(nil)
|
||||
|
||||
newCfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
|
||||
@ -42,6 +42,8 @@ import (
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
)
|
||||
|
||||
const rateLimitHeader = "header"
|
||||
|
||||
// RateLimiter 基于令牌桶算法的请求速率限制器。
|
||||
//
|
||||
// 实现请求限流功能,支持按 IP 或自定义键值进行限流。
|
||||
@ -133,7 +135,7 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
||||
switch cfg.Key {
|
||||
case "ip", "":
|
||||
rl.keyFunc = keyByIP
|
||||
case "header":
|
||||
case rateLimitHeader:
|
||||
rl.keyFunc = keyByHeader
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
||||
@ -157,7 +159,7 @@ func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Dur
|
||||
switch cfg.Key {
|
||||
case "ip", "":
|
||||
keyFunc = keyByIP
|
||||
case "header":
|
||||
case rateLimitHeader:
|
||||
keyFunc = keyByHeader
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
||||
@ -485,13 +487,13 @@ type ConnLimiter struct {
|
||||
// 返回值:
|
||||
// - *ConnLimiter: 配置好的连接限制器
|
||||
// - error: 配置无效时返回错误
|
||||
func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error) {
|
||||
if max <= 0 {
|
||||
func NewConnLimiter(maxConns int, perKey bool, keyType string) (*ConnLimiter, error) {
|
||||
if maxConns <= 0 {
|
||||
return nil, errors.New("max connections must be positive")
|
||||
}
|
||||
|
||||
cl := &ConnLimiter{
|
||||
max: max,
|
||||
max: maxConns,
|
||||
perKey: perKey,
|
||||
counts: make(map[string]int64),
|
||||
}
|
||||
@ -500,7 +502,7 @@ func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error)
|
||||
switch keyType {
|
||||
case "ip", "":
|
||||
cl.keyFunc = keyByIP
|
||||
case "header":
|
||||
case rateLimitHeader:
|
||||
cl.keyFunc = keyByHeader
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown key type: %s", keyType)
|
||||
|
||||
@ -206,7 +206,7 @@ func (s *SlidingWindowLimiter) Cleanup(maxAge time.Duration) {
|
||||
}
|
||||
}
|
||||
|
||||
// Stats 返回限流器统计信息。
|
||||
// SlidingWindowStats 返回限流器统计信息。
|
||||
type SlidingWindowStats struct {
|
||||
Window time.Duration // 窗口大小
|
||||
Limit int // 窗口内最大请求数
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user