diff --git a/internal/middleware/security/access.go b/internal/middleware/security/access.go index b8ec439..d5d5d70 100644 --- a/internal/middleware/security/access.go +++ b/internal/middleware/security/access.go @@ -200,51 +200,41 @@ func (ac *AccessControl) Check(ip net.IP) bool { ac.mu.RLock() defer ac.mu.RUnlock() - // 1. 先检查 CIDR 拒绝列表(显式拒绝优先) for _, network := range ac.denyList { if network.Contains(ip) { return false } } - // 2. 检查 GeoIP 国家拒绝(如果启用) + var country string + var hasCountry bool if ac.geoip != nil && ac.geoipConfig.Enabled { - country, err := ac.geoip.LookupCountry(ip) - if err == nil { - // 处理私有 IP 特殊标记 - if country == geoPrivateAllow { - // 私有 IP 自动允许,跳过国家检查 - goto checkAllow - } - if country == geoPrivateDeny { + if c, err := ac.geoip.LookupCountry(ip); err == nil { + if c == geoPrivateDeny { return false } - - if slices.Contains(ac.geoipConfig.DenyCountries, country) { - return false + if c != geoPrivateAllow { + country = c + hasCountry = true + if slices.Contains(ac.geoipConfig.DenyCountries, country) { + return false + } } } } -checkAllow: - // 3. 检查 CIDR 允许列表 for _, network := range ac.allowList { if network.Contains(ip) { return true } } - // 4. 检查 GeoIP 国家允许(如果启用) - if ac.geoip != nil && ac.geoipConfig.Enabled { - country, err := ac.geoip.LookupCountry(ip) - if err == nil && country != geoPrivateDeny { - if slices.Contains(ac.geoipConfig.AllowCountries, country) { - return true - } + if hasCountry { + if slices.Contains(ac.geoipConfig.AllowCountries, country) { + return true } } - // 5. 返回默认操作 return ac.defaultAction == ActionAllow } @@ -360,56 +350,43 @@ func (ac *AccessControl) SetDefault(action string) error { func (ac *AccessControl) getClientIP(ctx *fasthttp.RequestCtx) net.IP { remoteIP := netutil.GetRemoteAddrIP(ctx) - // 仅当配置了可信代理且请求来自可信代理时,才解析 X-Forwarded-For - if len(ac.trustedProxies) > 0 && remoteIP != nil { - isTrusted := false - for _, network := range ac.trustedProxies { - if network.Contains(remoteIP) { - isTrusted = true - break - } - } + if len(ac.trustedProxies) == 0 || remoteIP == nil { + return remoteIP + } - if isTrusted { - // 使用右侧(最接近客户端)的非可信 IP - if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 { - ips := strings.Split(string(xff), ",") - for _, v := range slices.Backward(ips) { - ipStr := strings.TrimSpace(v) - if ip := net.ParseIP(ipStr); ip != nil { - // 检查该 IP 是否在可信代理列表中 - trusted := false - for _, network := range ac.trustedProxies { - if network.Contains(ip) { - trusted = true - break - } - } - if !trusted { - return ip - } + isTrusted := false + for _, network := range ac.trustedProxies { + if network.Contains(remoteIP) { + isTrusted = true + break + } + } + if !isTrusted { + return remoteIP + } + + if xff := ctx.Request.Header.Peek("X-Forwarded-For"); len(xff) > 0 { + ips := strings.Split(string(xff), ",") + for _, v := range slices.Backward(ips) { + ipStr := strings.TrimSpace(v) + if ip := net.ParseIP(ipStr); ip != nil { + trusted := false + for _, network := range ac.trustedProxies { + if network.Contains(ip) { + trusted = true + break } } + if !trusted { + return ip + } } } } - // 检查 X-Real-IP 头部(仅来自可信代理时) - if len(ac.trustedProxies) > 0 && remoteIP != nil { - isTrusted := false - for _, network := range ac.trustedProxies { - if network.Contains(remoteIP) { - isTrusted = true - break - } - } - - if isTrusted { - if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 { - if ip := net.ParseIP(string(xri)); ip != nil { - return ip - } - } + if xri := ctx.Request.Header.Peek("X-Real-IP"); len(xri) > 0 { + if ip := net.ParseIP(string(xri)); ip != nil { + return ip } }