refactor(security): 重命名 HeadersMiddleware 移除冗余前缀
SecurityHeadersMiddleware → HeadersMiddleware NewSecurityHeaders → NewHeaders NewSecurityHeadersWithHSTS → NewHeadersWithHSTS Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3e153f5fe1
commit
649a6ed23f
@ -43,6 +43,9 @@ const (
|
|||||||
ActionAllow Action = iota
|
ActionAllow Action = iota
|
||||||
// ActionDeny 拒绝请求(返回 403 Forbidden)
|
// ActionDeny 拒绝请求(返回 403 Forbidden)
|
||||||
ActionDeny
|
ActionDeny
|
||||||
|
|
||||||
|
accessAllow = "allow"
|
||||||
|
accessDeny = "deny"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccessControl 实现 IP 访问控制中间件。
|
// AccessControl 实现 IP 访问控制中间件。
|
||||||
@ -112,9 +115,9 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
|
|||||||
|
|
||||||
// 设置默认操作
|
// 设置默认操作
|
||||||
switch strings.ToLower(cfg.Default) {
|
switch strings.ToLower(cfg.Default) {
|
||||||
case "allow", "":
|
case accessAllow, "":
|
||||||
ac.defaultAction = ActionAllow
|
ac.defaultAction = ActionAllow
|
||||||
case "deny":
|
case accessDeny:
|
||||||
ac.defaultAction = ActionDeny
|
ac.defaultAction = ActionDeny
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid default action: %s", cfg.Default)
|
return nil, fmt.Errorf("invalid default action: %s", cfg.Default)
|
||||||
@ -259,9 +262,9 @@ func (ac *AccessControl) SetDefault(action string) error {
|
|||||||
defer ac.mu.Unlock()
|
defer ac.mu.Unlock()
|
||||||
|
|
||||||
switch strings.ToLower(action) {
|
switch strings.ToLower(action) {
|
||||||
case "allow":
|
case accessAllow:
|
||||||
ac.defaultAction = ActionAllow
|
ac.defaultAction = ActionAllow
|
||||||
case "deny":
|
case accessDeny:
|
||||||
ac.defaultAction = ActionDeny
|
ac.defaultAction = ActionDeny
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid action: %s", action)
|
return fmt.Errorf("invalid action: %s", action)
|
||||||
@ -436,7 +439,7 @@ func actionToString(action Action) string {
|
|||||||
case ActionAllow:
|
case ActionAllow:
|
||||||
return "allow"
|
return "allow"
|
||||||
case ActionDeny:
|
case ActionDeny:
|
||||||
return "deny"
|
return accessDeny
|
||||||
default:
|
default:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -313,8 +313,8 @@ func (a *AuthRequest) expandVars(ctx *fasthttp.RequestCtx, template string) stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 创建变量上下文
|
// 创建变量上下文
|
||||||
vc := variable.NewVariableContext(ctx)
|
vc := variable.NewContext(ctx)
|
||||||
defer variable.ReleaseVariableContext(vc)
|
defer variable.ReleaseContext(vc)
|
||||||
|
|
||||||
return vc.Expand(template)
|
return vc.Expand(template)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,7 +35,7 @@ import (
|
|||||||
"rua.plus/lolly/internal/middleware"
|
"rua.plus/lolly/internal/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecurityHeadersMiddleware 安全响应头中间件。
|
// HeadersMiddleware 安全响应头中间件。
|
||||||
//
|
//
|
||||||
// 为 HTTP 响应添加安全相关的头部字段,防止常见的 Web 安全漏洞。
|
// 为 HTTP 响应添加安全相关的头部字段,防止常见的 Web 安全漏洞。
|
||||||
// 支持配置各种安全头的值,并提供安全的默认配置。
|
// 支持配置各种安全头的值,并提供安全的默认配置。
|
||||||
@ -43,13 +43,13 @@ import (
|
|||||||
// 注意事项:
|
// 注意事项:
|
||||||
// - 所有方法均为并发安全
|
// - 所有方法均为并发安全
|
||||||
// - HSTS 头仅在 TLS 连接时添加
|
// - HSTS 头仅在 TLS 连接时添加
|
||||||
type SecurityHeadersMiddleware struct {
|
type HeadersMiddleware struct {
|
||||||
config *config.SecurityHeaders // 安全头配置
|
config *config.SecurityHeaders // 安全头配置
|
||||||
hsts string // 预格式化的 HSTS 头值
|
hsts string // 预格式化的 HSTS 头值
|
||||||
mu sync.RWMutex // 读写锁,保护并发访问
|
mu sync.RWMutex // 读写锁,保护并发访问
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSecurityHeaders 创建新的安全响应头中间件。
|
// NewHeaders 创建新的安全响应头中间件。
|
||||||
//
|
//
|
||||||
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
||||||
//
|
//
|
||||||
@ -57,12 +57,12 @@ type SecurityHeadersMiddleware struct {
|
|||||||
// - cfg: 安全头配置,可以为 nil 使用默认配置
|
// - cfg: 安全头配置,可以为 nil 使用默认配置
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
// - *HeadersMiddleware: 配置好的中间件实例
|
||||||
func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware {
|
func NewHeaders(cfg *config.SecurityHeaders) *HeadersMiddleware {
|
||||||
return NewSecurityHeadersWithHSTS(cfg, nil)
|
return NewHeadersWithHSTS(cfg, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSecurityHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。
|
// NewHeadersWithHSTS 创建新的安全响应头中间件,支持 HSTS 配置。
|
||||||
//
|
//
|
||||||
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
// 根据配置创建中间件实例,如果配置为 nil 则使用安全的默认值。
|
||||||
//
|
//
|
||||||
@ -71,9 +71,9 @@ func NewSecurityHeaders(cfg *config.SecurityHeaders) *SecurityHeadersMiddleware
|
|||||||
// - hstsCfg: HSTS 配置,可以为 nil 使用默认值
|
// - hstsCfg: HSTS 配置,可以为 nil 使用默认值
|
||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - *SecurityHeadersMiddleware: 配置好的中间件实例
|
// - *HeadersMiddleware: 配置好的中间件实例
|
||||||
func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *SecurityHeadersMiddleware {
|
func NewHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HSTSConfig) *HeadersMiddleware {
|
||||||
sh := &SecurityHeadersMiddleware{}
|
sh := &HeadersMiddleware{}
|
||||||
|
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
sh.config = cfg
|
sh.config = cfg
|
||||||
@ -93,7 +93,7 @@ func NewSecurityHeadersWithHSTS(cfg *config.SecurityHeaders, hstsCfg *config.HST
|
|||||||
}
|
}
|
||||||
|
|
||||||
// formatHSTSFromConfig 根据配置格式化 HSTS 头值。
|
// formatHSTSFromConfig 根据配置格式化 HSTS 头值。
|
||||||
func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) {
|
func (sh *HeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSConfig) {
|
||||||
if hstsCfg != nil {
|
if hstsCfg != nil {
|
||||||
maxAge := hstsCfg.MaxAge
|
maxAge := hstsCfg.MaxAge
|
||||||
if maxAge <= 0 {
|
if maxAge <= 0 {
|
||||||
@ -109,7 +109,7 @@ func (sh *SecurityHeadersMiddleware) formatHSTSFromConfig(hstsCfg *config.HSTSCo
|
|||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - string: 中间件标识名 "security_headers"
|
// - string: 中间件标识名 "security_headers"
|
||||||
func (sh *SecurityHeadersMiddleware) Name() string {
|
func (sh *HeadersMiddleware) Name() string {
|
||||||
return "security_headers"
|
return "security_headers"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ func (sh *SecurityHeadersMiddleware) Name() string {
|
|||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - fasthttp.RequestHandler: 包装后的处理器
|
// - fasthttp.RequestHandler: 包装后的处理器
|
||||||
func (sh *SecurityHeadersMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
func (sh *HeadersMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||||
return func(ctx *fasthttp.RequestCtx) {
|
return func(ctx *fasthttp.RequestCtx) {
|
||||||
// 先调用下一个处理器
|
// 先调用下一个处理器
|
||||||
next(ctx)
|
next(ctx)
|
||||||
@ -138,7 +138,7 @@ func (sh *SecurityHeadersMiddleware) Process(next fasthttp.RequestHandler) fasth
|
|||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - ctx: FastHTTP 请求上下文
|
// - ctx: FastHTTP 请求上下文
|
||||||
func (sh *SecurityHeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) {
|
func (sh *HeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) {
|
||||||
headers := &ctx.Response.Header
|
headers := &ctx.Response.Header
|
||||||
|
|
||||||
sh.mu.RLock()
|
sh.mu.RLock()
|
||||||
@ -183,7 +183,7 @@ func (sh *SecurityHeadersMiddleware) addHeaders(ctx *fasthttp.RequestCtx) {
|
|||||||
//
|
//
|
||||||
// HSTS(HTTP Strict Transport Security)用于强制浏览器使用 HTTPS 连接。
|
// HSTS(HTTP Strict Transport Security)用于强制浏览器使用 HTTPS 连接。
|
||||||
// 默认配置为 1 年有效期,包含子域名。
|
// 默认配置为 1 年有效期,包含子域名。
|
||||||
func (sh *SecurityHeadersMiddleware) formatHSTS() {
|
func (sh *HeadersMiddleware) formatHSTS() {
|
||||||
// 默认 HSTS 值
|
// 默认 HSTS 值
|
||||||
maxAge := 31536000 // 1 年有效期(秒)
|
maxAge := 31536000 // 1 年有效期(秒)
|
||||||
includeSubDomains := true // 包含所有子域名
|
includeSubDomains := true // 包含所有子域名
|
||||||
@ -223,7 +223,7 @@ func formatHSTSValue(maxAge int, includeSubDomains bool, preload bool) string {
|
|||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - cfg: 新的安全头配置
|
// - cfg: 新的安全头配置
|
||||||
func (sh *SecurityHeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) {
|
func (sh *HeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) {
|
||||||
sh.mu.Lock()
|
sh.mu.Lock()
|
||||||
sh.config = cfg
|
sh.config = cfg
|
||||||
sh.formatHSTS()
|
sh.formatHSTS()
|
||||||
@ -234,7 +234,7 @@ func (sh *SecurityHeadersMiddleware) UpdateConfig(cfg *config.SecurityHeaders) {
|
|||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - value: 新的 X-Frame-Options 值(如 "DENY"、"SAMEORIGIN")
|
// - value: 新的 X-Frame-Options 值(如 "DENY"、"SAMEORIGIN")
|
||||||
func (sh *SecurityHeadersMiddleware) SetXFrameOptions(value string) {
|
func (sh *HeadersMiddleware) SetXFrameOptions(value string) {
|
||||||
sh.mu.Lock()
|
sh.mu.Lock()
|
||||||
if sh.config != nil {
|
if sh.config != nil {
|
||||||
sh.config.XFrameOptions = value
|
sh.config.XFrameOptions = value
|
||||||
@ -246,7 +246,7 @@ func (sh *SecurityHeadersMiddleware) SetXFrameOptions(value string) {
|
|||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - value: 新的 Content-Security-Policy 值
|
// - value: 新的 Content-Security-Policy 值
|
||||||
func (sh *SecurityHeadersMiddleware) SetContentSecurityPolicy(value string) {
|
func (sh *HeadersMiddleware) SetContentSecurityPolicy(value string) {
|
||||||
sh.mu.Lock()
|
sh.mu.Lock()
|
||||||
if sh.config != nil {
|
if sh.config != nil {
|
||||||
sh.config.ContentSecurityPolicy = value
|
sh.config.ContentSecurityPolicy = value
|
||||||
@ -258,7 +258,7 @@ func (sh *SecurityHeadersMiddleware) SetContentSecurityPolicy(value string) {
|
|||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - value: 新的 Referrer-Policy 值(如 "no-referrer"、"strict-origin")
|
// - value: 新的 Referrer-Policy 值(如 "no-referrer"、"strict-origin")
|
||||||
func (sh *SecurityHeadersMiddleware) SetReferrerPolicy(value string) {
|
func (sh *HeadersMiddleware) SetReferrerPolicy(value string) {
|
||||||
sh.mu.Lock()
|
sh.mu.Lock()
|
||||||
if sh.config != nil {
|
if sh.config != nil {
|
||||||
sh.config.ReferrerPolicy = value
|
sh.config.ReferrerPolicy = value
|
||||||
@ -270,7 +270,7 @@ func (sh *SecurityHeadersMiddleware) SetReferrerPolicy(value string) {
|
|||||||
//
|
//
|
||||||
// 参数:
|
// 参数:
|
||||||
// - value: 新的 Permissions-Policy 值
|
// - value: 新的 Permissions-Policy 值
|
||||||
func (sh *SecurityHeadersMiddleware) SetPermissionsPolicy(value string) {
|
func (sh *HeadersMiddleware) SetPermissionsPolicy(value string) {
|
||||||
sh.mu.Lock()
|
sh.mu.Lock()
|
||||||
if sh.config != nil {
|
if sh.config != nil {
|
||||||
sh.config.PermissionsPolicy = value
|
sh.config.PermissionsPolicy = value
|
||||||
@ -282,7 +282,7 @@ func (sh *SecurityHeadersMiddleware) SetPermissionsPolicy(value string) {
|
|||||||
//
|
//
|
||||||
// 返回值:
|
// 返回值:
|
||||||
// - *config.SecurityHeaders: 当前配置的副本
|
// - *config.SecurityHeaders: 当前配置的副本
|
||||||
func (sh *SecurityHeadersMiddleware) GetConfig() *config.SecurityHeaders {
|
func (sh *HeadersMiddleware) GetConfig() *config.SecurityHeaders {
|
||||||
sh.mu.RLock()
|
sh.mu.RLock()
|
||||||
defer sh.mu.RUnlock()
|
defer sh.mu.RUnlock()
|
||||||
return sh.config
|
return sh.config
|
||||||
@ -331,4 +331,4 @@ func DevelopmentSecurityHeaders() *config.SecurityHeaders {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证接口实现
|
// 验证接口实现
|
||||||
var _ middleware.Middleware = (*SecurityHeadersMiddleware)(nil)
|
var _ middleware.Middleware = (*HeadersMiddleware)(nil)
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import (
|
|||||||
"rua.plus/lolly/internal/config"
|
"rua.plus/lolly/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewSecurityHeaders(t *testing.T) {
|
func TestNewHeaders(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cfg *config.SecurityHeaders
|
cfg *config.SecurityHeaders
|
||||||
@ -38,22 +38,22 @@ func TestNewSecurityHeaders(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
sh := NewSecurityHeaders(tt.cfg)
|
sh := NewHeaders(tt.cfg)
|
||||||
if sh == nil {
|
if sh == nil {
|
||||||
t.Error("Expected non-nil SecurityHeadersMiddleware")
|
t.Error("Expected non-nil HeadersMiddleware")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSecurityHeadersName(t *testing.T) {
|
func TestHeadersName(t *testing.T) {
|
||||||
sh := NewSecurityHeaders(nil)
|
sh := NewHeaders(nil)
|
||||||
if sh.Name() != "security_headers" {
|
if sh.Name() != "security_headers" {
|
||||||
t.Errorf("Expected name 'security_headers', got %s", sh.Name())
|
t.Errorf("Expected name 'security_headers', got %s", sh.Name())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSecurityHeadersProcess(t *testing.T) {
|
func TestHeadersProcess(t *testing.T) {
|
||||||
cfg := &config.SecurityHeaders{
|
cfg := &config.SecurityHeaders{
|
||||||
XFrameOptions: "DENY",
|
XFrameOptions: "DENY",
|
||||||
XContentTypeOptions: "nosniff",
|
XContentTypeOptions: "nosniff",
|
||||||
@ -62,7 +62,7 @@ func TestSecurityHeadersProcess(t *testing.T) {
|
|||||||
PermissionsPolicy: "geolocation=()",
|
PermissionsPolicy: "geolocation=()",
|
||||||
}
|
}
|
||||||
|
|
||||||
sh := NewSecurityHeaders(cfg)
|
sh := NewHeaders(cfg)
|
||||||
|
|
||||||
handlerCalled := false
|
handlerCalled := false
|
||||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||||
@ -107,15 +107,14 @@ func TestSecurityHeadersProcess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSecurityHeadersHSTS(t *testing.T) {
|
func TestHeadersHSTS(_ *testing.T) {
|
||||||
cfg := &config.SecurityHeaders{
|
cfg := &config.SecurityHeaders{
|
||||||
XFrameOptions: "DENY",
|
XFrameOptions: "DENY",
|
||||||
}
|
}
|
||||||
|
|
||||||
sh := NewSecurityHeaders(cfg)
|
sh := NewHeaders(cfg)
|
||||||
|
|
||||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
nextHandler := func(_ *fasthttp.RequestCtx) {
|
||||||
_, _ = ctx.WriteString("OK")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := sh.Process(nextHandler)
|
handler := sh.Process(nextHandler)
|
||||||
@ -131,8 +130,8 @@ func TestSecurityHeadersHSTS(t *testing.T) {
|
|||||||
// In this test we verify the handler doesn't panic
|
// In this test we verify the handler doesn't panic
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSecurityHeadersUpdate(t *testing.T) {
|
func TestHeadersUpdate(t *testing.T) {
|
||||||
sh := NewSecurityHeaders(nil)
|
sh := NewHeaders(nil)
|
||||||
|
|
||||||
// Update X-Frame-Options
|
// Update X-Frame-Options
|
||||||
sh.SetXFrameOptions("SAMEORIGIN")
|
sh.SetXFrameOptions("SAMEORIGIN")
|
||||||
@ -164,7 +163,7 @@ func TestSecurityHeadersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateConfig(t *testing.T) {
|
func TestUpdateConfig(t *testing.T) {
|
||||||
sh := NewSecurityHeaders(nil)
|
sh := NewHeaders(nil)
|
||||||
|
|
||||||
newCfg := &config.SecurityHeaders{
|
newCfg := &config.SecurityHeaders{
|
||||||
XFrameOptions: "DENY",
|
XFrameOptions: "DENY",
|
||||||
|
|||||||
@ -42,6 +42,8 @@ import (
|
|||||||
"rua.plus/lolly/internal/netutil"
|
"rua.plus/lolly/internal/netutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const rateLimitHeader = "header"
|
||||||
|
|
||||||
// RateLimiter 基于令牌桶算法的请求速率限制器。
|
// RateLimiter 基于令牌桶算法的请求速率限制器。
|
||||||
//
|
//
|
||||||
// 实现请求限流功能,支持按 IP 或自定义键值进行限流。
|
// 实现请求限流功能,支持按 IP 或自定义键值进行限流。
|
||||||
@ -133,7 +135,7 @@ func newTokenBucketLimiter(cfg *config.RateLimitConfig) (*RateLimiter, error) {
|
|||||||
switch cfg.Key {
|
switch cfg.Key {
|
||||||
case "ip", "":
|
case "ip", "":
|
||||||
rl.keyFunc = keyByIP
|
rl.keyFunc = keyByIP
|
||||||
case "header":
|
case rateLimitHeader:
|
||||||
rl.keyFunc = keyByHeader
|
rl.keyFunc = keyByHeader
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
||||||
@ -157,7 +159,7 @@ func NewSlidingWindowLimiterWrapper(cfg *config.RateLimitConfig, window time.Dur
|
|||||||
switch cfg.Key {
|
switch cfg.Key {
|
||||||
case "ip", "":
|
case "ip", "":
|
||||||
keyFunc = keyByIP
|
keyFunc = keyByIP
|
||||||
case "header":
|
case rateLimitHeader:
|
||||||
keyFunc = keyByHeader
|
keyFunc = keyByHeader
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
return nil, fmt.Errorf("unknown key type: %s", cfg.Key)
|
||||||
@ -485,13 +487,13 @@ type ConnLimiter struct {
|
|||||||
// 返回值:
|
// 返回值:
|
||||||
// - *ConnLimiter: 配置好的连接限制器
|
// - *ConnLimiter: 配置好的连接限制器
|
||||||
// - error: 配置无效时返回错误
|
// - error: 配置无效时返回错误
|
||||||
func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error) {
|
func NewConnLimiter(maxConns int, perKey bool, keyType string) (*ConnLimiter, error) {
|
||||||
if max <= 0 {
|
if maxConns <= 0 {
|
||||||
return nil, errors.New("max connections must be positive")
|
return nil, errors.New("max connections must be positive")
|
||||||
}
|
}
|
||||||
|
|
||||||
cl := &ConnLimiter{
|
cl := &ConnLimiter{
|
||||||
max: max,
|
max: maxConns,
|
||||||
perKey: perKey,
|
perKey: perKey,
|
||||||
counts: make(map[string]int64),
|
counts: make(map[string]int64),
|
||||||
}
|
}
|
||||||
@ -500,7 +502,7 @@ func NewConnLimiter(max int, perKey bool, keyType string) (*ConnLimiter, error)
|
|||||||
switch keyType {
|
switch keyType {
|
||||||
case "ip", "":
|
case "ip", "":
|
||||||
cl.keyFunc = keyByIP
|
cl.keyFunc = keyByIP
|
||||||
case "header":
|
case rateLimitHeader:
|
||||||
cl.keyFunc = keyByHeader
|
cl.keyFunc = keyByHeader
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown key type: %s", keyType)
|
return nil, fmt.Errorf("unknown key type: %s", keyType)
|
||||||
|
|||||||
@ -206,7 +206,7 @@ func (s *SlidingWindowLimiter) Cleanup(maxAge time.Duration) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stats 返回限流器统计信息。
|
// SlidingWindowStats 返回限流器统计信息。
|
||||||
type SlidingWindowStats struct {
|
type SlidingWindowStats struct {
|
||||||
Window time.Duration // 窗口大小
|
Window time.Duration // 窗口大小
|
||||||
Limit int // 窗口内最大请求数
|
Limit int // 窗口内最大请求数
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user