feat(geoip): 添加基于国家代码的 GeoIP 访问控制功能

- 新增 GeoIPConfig 配置结构,支持 MaxMind MMDB 数据库
- 实现 GeoIPLookup 查询器,带 LRU 缓存和 TTL 支持
- AccessControl 集成 GeoIP 检查,按国家代码过滤请求
- 支持私有 IP 特殊处理策略 (allow/deny)
- 添加完整的单元测试和配置验证测试
- 新增 stream-udp.conf 示例配置文档

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-13 16:14:44 +08:00
parent 4f6a7be44c
commit 103e8ff0cf
7 changed files with 995 additions and 25 deletions

View File

@ -0,0 +1,110 @@
# ============================================================
# Lolly UDP Stream 代理配置示例
# ============================================================
#
# 功能说明:
# - UDP 四层代理DNS、游戏服务器、VoIP 等)
# - 会话管理和超时控制
# - 负载均衡支持
#
# Lolly 对应配置:
# stream:
# - listen: ":53"
# protocol: "udp"
# timeout: 60s # 默认值
# upstream:
# targets:
# - addr: "dns1:53"
# weight: 1
# - addr: "dns2:53"
# weight: 1
# load_balance: "round_robin"
# ============================================================
# ------------------------------------------------------------
# DNS UDP 代理配置示例
# ------------------------------------------------------------
#
# YAML 格式:
stream:
# DNS 服务器代理
- listen: ":53"
protocol: "udp"
timeout: 60s # 会话超时,默认 60 秒
# DNS 请求通常很快,可设置较短如 30s
upstream:
targets:
- addr: "8.8.8.8:53"
weight: 1
- addr: "8.8.4.4:53"
weight: 1
load_balance: "round_robin"
# ------------------------------------------------------------
# 游戏服务器 UDP 代理配置示例
# ------------------------------------------------------------
#
# 游戏服务器需要长会话,建议使用 ip_hash 保持会话一致性
#
stream:
- listen: ":27015"
protocol: "udp"
timeout: 300s # 游戏会话较长,设置 5 分钟超时
upstream:
targets:
- addr: "game1:27015"
weight: 3
- addr: "game2:27015"
weight: 1
load_balance: "ip_hash" # IP Hash 保持会话一致性
# ------------------------------------------------------------
# VoIP SIP UDP 代理配置示例
# ------------------------------------------------------------
#
# VoIP 服务需要稳定连接
#
stream:
- listen: ":5060"
protocol: "udp"
timeout: 180s # VoIP 会话超时 3 分钟
upstream:
targets:
- addr: "sip1:5060"
weight: 1
- addr: "sip2:5060"
weight: 1
load_balance: "least_conn" # 最少连接,均匀分配
# ============================================================
# UDP Stream 配置参数详解
# ============================================================
#
# 1. UDP vs TCP:
# - UDP: 无连接,数据报协议,适合实时应用
# - TCP: 有连接,流协议,适合可靠传输
#
# 2. 会话管理:
# - Lolly 自动管理 UDP 会话
# - 同一客户端 IP 映射到同一后端
# - 会话空闲超时后自动清理
#
# 3. timeout 参数:
# - 默认值: 60 秒 (未配置时使用)
# - DNS: 建议 30-60 秒
# - 游戏服务器: 建议 300 秒 (5 分钟)
# - VoIP: 建议 180 秒 (3 分钟)
# - Syslog: 建议 60 秒
#
# 4. 负载均衡算法:
# - round_robin: 轮询(默认)
# - weighted_round_robin: 加权轮询
# - least_conn: 最少连接
# - ip_hash: IP 哈希推荐游戏服务器、VoIP
#
# 5. 适用场景:
# - DNS 服务器代理
# - 游戏服务器代理 (CS2, Minecraft 等)
# - VoIP 服务代理 (SIP, RTP)
# - 日志收集服务 (Syslog)
# - NTP 时间服务器

View File

@ -275,17 +275,50 @@ type StaticConfig struct {
// interval: 10s
// path: "/health"
type ProxyConfig struct {
Path string `yaml:"path"`
LoadBalance string `yaml:"load_balance"`
HashKey string `yaml:"hash_key"`
ClientMaxBodySize string `yaml:"client_max_body_size"`
Headers ProxyHeaders `yaml:"headers"`
Targets []ProxyTarget `yaml:"targets"`
HealthCheck HealthCheckConfig `yaml:"health_check"`
NextUpstream NextUpstreamConfig `yaml:"next_upstream"`
Cache ProxyCacheConfig `yaml:"cache"`
Timeout ProxyTimeout `yaml:"timeout"`
VirtualNodes int `yaml:"virtual_nodes"`
Path string `yaml:"path"`
LoadBalance string `yaml:"load_balance"`
HashKey string `yaml:"hash_key"`
ClientMaxBodySize string `yaml:"client_max_body_size"`
Headers ProxyHeaders `yaml:"headers"`
Targets []ProxyTarget `yaml:"targets"`
BalancerByLua BalancerByLuaConfig `yaml:"balancer_by_lua"`
HealthCheck HealthCheckConfig `yaml:"health_check"`
NextUpstream NextUpstreamConfig `yaml:"next_upstream"`
Cache ProxyCacheConfig `yaml:"cache"`
Timeout ProxyTimeout `yaml:"timeout"`
VirtualNodes int `yaml:"virtual_nodes"`
}
// BalancerByLuaConfig Lua 负载均衡配置
//
// 使用 Lua 脚本动态选择后端目标,支持自定义负载均衡逻辑。
//
// 注意事项:
// - Script 为 Lua 脚本文件路径
// - Timeout 控制脚本执行超时
// - Fallback 指定 Lua 失败时的备用算法
//
// 使用示例:
//
// balancer_by_lua:
// enabled: true
// script: "/etc/lolly/scripts/balancer.lua"
// timeout: 100ms
// fallback: "round_robin"
type BalancerByLuaConfig struct {
// Script Lua 脚本路径
Script string `yaml:"script"`
// Fallback 失败时使用的默认负载均衡算法
// 默认值: "round_robin"
Fallback string `yaml:"fallback"`
// Timeout 执行超时
// 默认值: 100ms
Timeout time.Duration `yaml:"timeout"`
// Enabled 是否启用
Enabled bool `yaml:"enabled"`
}
// ProxyTarget 后端目标配置。
@ -597,12 +630,13 @@ type SecurityConfig struct {
// AccessConfig IP 访问控制配置。
//
// 通过 IP 地址或 CIDR 范围控制访问权限
// 通过 IP 地址或 CIDR 范围控制访问权限,支持基于 GeoIP 的国家代码访问控制
//
// 注意事项:
// - Allow 和 Deny 列表按配置顺序匹配
// - Default 指定未匹配时的默认动作
// - TrustedProxies 用于正确获取客户端真实 IP
// - GeoIP 配置启用后,会基于国家代码进行二次检查
// - 支持 IPv4 和 IPv6 地址格式
//
// 使用示例:
@ -612,6 +646,15 @@ type SecurityConfig struct {
// deny: ["192.168.1.100"]
// default: "deny"
// trusted_proxies: ["172.16.0.0/16"]
// geoip:
// enabled: true
// database: "/var/lib/geoip/GeoIP2-Country.mmdb"
// allow_countries: ["US", "JP", "GB"]
// deny_countries: ["CN", "RU"]
// default: "deny"
// cache_size: 10000
// cache_ttl: 1h
// private_ip_behavior: "allow"
type AccessConfig struct {
// Allow 允许的 IP/CIDR 列表
// 配置允许访问的 IP 地址或网段
@ -621,13 +664,49 @@ type AccessConfig struct {
// 配置拒绝访问的 IP 地址或网段
Deny []string `yaml:"deny"`
// TrustedProxies 可信代理 CIDR 列表
// 用于正确解析 X-Forwarded-For 头部获取真实客户端 IP
TrustedProxies []string `yaml:"trusted_proxies"`
// Default 默认动作
// 未匹配任何规则时的处理方式allow 或 deny
Default string `yaml:"default"`
// TrustedProxies 可信代理 CIDR 列表
// 用于正确解析 X-Forwarded-For 头部获取真实客户端 IP
TrustedProxies []string `yaml:"trusted_proxies"`
// GeoIP GeoIP 国家代码访问控制配置
GeoIP GeoIPConfig `yaml:"geoip"`
}
// GeoIPConfig GeoIP 访问控制配置。
//
// 通过 MaxMind GeoIP2 数据库查询 IP 所属国家,实现基于国家代码的访问控制。
//
// 注意事项:
// - Database 为 GeoIP2 数据库文件路径(.mmdb 格式)
// - AllowCountries 和 DenyCountries 使用 ISO 3166-1 alpha-2 国家代码
// - CacheSize 设置 LRU 缓存最大条目数0 表示使用默认值 10000
// - CacheTTL 设置缓存有效期0 表示使用默认值 1 小时
// - PrivateIPBehavior 控制私有 IP 的处理策略
//
// 使用示例:
//
// geoip:
// enabled: true
// database: "/var/lib/geoip/GeoIP2-Country.mmdb"
// allow_countries: ["US", "JP", "GB"]
// deny_countries: ["CN", "RU"]
// default: "deny"
// cache_size: 10000
// cache_ttl: 1h
// private_ip_behavior: "allow"
type GeoIPConfig struct {
Database string `yaml:"database"`
Default string `yaml:"default"`
PrivateIPBehavior string `yaml:"private_ip_behavior"`
AllowCountries []string `yaml:"allow_countries"`
DenyCountries []string `yaml:"deny_countries"`
CacheSize int `yaml:"cache_size"`
CacheTTL time.Duration `yaml:"cache_ttl"`
Enabled bool `yaml:"enabled"`
}
// RateLimitConfig 速率限制配置。
@ -1536,7 +1615,7 @@ func Save(cfg *Config, path string) error {
return fmt.Errorf("序列化配置失败: %w", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
if err := os.WriteFile(path, data, 0o644); err != nil {
return fmt.Errorf("写入配置文件失败: %w", err)
}

View File

@ -543,9 +543,89 @@ func validateAccess(a *AccessConfig) error {
return fmt.Errorf("无效的 default 动作: %s仅允许 allow 或 deny", a.Default)
}
// 验证 GeoIP 配置
if err := validateGeoIP(&a.GeoIP); err != nil {
return fmt.Errorf("geoip: %w", err)
}
return nil
}
// validateGeoIP 验证 GeoIP 配置。
//
// 检查 GeoIP 数据库路径、国家代码格式、缓存设置等。
//
// 参数:
// - g: GeoIP 配置对象
//
// 返回值:
// - error: 验证失败时返回错误信息,成功返回 nil
func validateGeoIP(g *GeoIPConfig) error {
// 未启用时跳过验证
if !g.Enabled {
return nil
}
// 验证数据库路径
if g.Database == "" {
return errors.New("database 是必填项(启用 GeoIP 时)")
}
// 验证国家代码格式 (ISO 3166-1 alpha-2)
for _, c := range g.AllowCountries {
if !isValidCountryCode(c) {
return fmt.Errorf("无效的 allow_countries 国家代码: %s应为 2 位大写字母)", c)
}
}
for _, c := range g.DenyCountries {
if !isValidCountryCode(c) {
return fmt.Errorf("无效的 deny_countries 国家代码: %s应为 2 位大写字母)", c)
}
}
// 验证 PrivateIPBehavior
validBehaviors := []string{"", "allow", "deny", "bypass"}
if !slices.Contains(validBehaviors, g.PrivateIPBehavior) {
return fmt.Errorf("无效的 private_ip_behavior: %s仅支持 allow, deny, bypass", g.PrivateIPBehavior)
}
// 验证缓存大小
if g.CacheSize < 0 {
return errors.New("cache_size 不能为负数")
}
// 验证缓存 TTL
if g.CacheTTL < 0 {
return errors.New("cache_ttl 不能为负数")
}
// 验证默认动作
if g.Default != "" && g.Default != "allow" && g.Default != "deny" {
return fmt.Errorf("无效的 default 动作: %s仅允许 allow 或 deny", g.Default)
}
return nil
}
// isValidCountryCode 验证 ISO 3166-1 alpha-2 国家代码。
//
// 参数:
// - code: 国家代码字符串
//
// 返回值:
// - bool: true 表示有效的国家代码
func isValidCountryCode(code string) bool {
if len(code) != 2 {
return false
}
for _, c := range code {
if c < 'A' || c > 'Z' {
return false
}
}
return true
}
// validateAuth 验证认证配置。
//
// 检查认证类型、哈希算法和用户列表的有效性。

View File

@ -0,0 +1,171 @@
// Package config 提供 YAML 配置文件的解析、验证和默认配置生成功能测试。
//
// 该文件包含 GeoIP 配置验证相关的测试。
//
// 作者xfy
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestValidateGeoIP 测试 GeoIP 配置验证。
func TestValidateGeoIP(t *testing.T) {
tests := []struct {
name string
errMsg string
config GeoIPConfig
wantErr bool
}{
{
name: "未启用时跳过验证",
config: GeoIPConfig{
Enabled: false,
},
wantErr: false,
},
{
name: "启用但缺少数据库路径",
config: GeoIPConfig{
Enabled: true,
Database: "",
},
wantErr: true,
errMsg: "database 是必填项",
},
{
name: "有效的 GeoIP 配置",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
AllowCountries: []string{"US", "JP"},
DenyCountries: []string{"CN"},
Default: "deny",
CacheSize: 10000,
PrivateIPBehavior: "allow",
},
wantErr: false,
},
{
name: "无效的国家代码(小写)",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
AllowCountries: []string{"us"},
},
wantErr: true,
errMsg: "无效的 allow_countries",
},
{
name: "无效的国家代码3位",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
DenyCountries: []string{"USA"},
},
wantErr: true,
errMsg: "无效的 deny_countries",
},
{
name: "无效的 private_ip_behavior",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
PrivateIPBehavior: "invalid",
},
wantErr: true,
errMsg: "无效的 private_ip_behavior",
},
{
name: "负的 cache_size",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
CacheSize: -1,
},
wantErr: true,
errMsg: "cache_size 不能为负数",
},
{
name: "负的 cache_ttl",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
CacheTTL: -1,
},
wantErr: true,
errMsg: "cache_ttl 不能为负数",
},
{
name: "无效的 default 动作",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
Default: "invalid",
},
wantErr: true,
errMsg: "无效的 default",
},
{
name: "有效的 private_ip_behavior: deny",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
PrivateIPBehavior: "deny",
},
wantErr: false,
},
{
name: "有效的 private_ip_behavior: bypass",
config: GeoIPConfig{
Enabled: true,
Database: "/var/lib/geoip/GeoIP2-Country.mmdb",
PrivateIPBehavior: "bypass",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateGeoIP(&tt.config)
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
// TestIsValidCountryCode 测试国家代码验证。
func TestIsValidCountryCode(t *testing.T) {
tests := []struct {
code string
expected bool
}{
{"US", true},
{"JP", true},
{"GB", true},
{"CN", true},
{"us", false}, // 小写
{"Us", false}, // 混合大小写
{"USA", false}, // 3位
{"U", false}, // 1位
{"U1", false}, // 包含数字
{"U-", false}, // 包含连字符
{"", false}, // 空字符串
}
for _, tt := range tests {
t.Run(tt.code, func(t *testing.T) {
result := isValidCountryCode(tt.code)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -29,6 +29,7 @@ import (
"net"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
@ -44,18 +45,22 @@ const (
// ActionDeny 拒绝请求(返回 403 Forbidden
ActionDeny
accessAllow = "allow"
accessDeny = "deny"
accessAllow = "allow"
accessDeny = "deny"
geoPrivateAllow = "PRIVATE_ALLOW"
geoPrivateDeny = "PRIVATE_DENY"
)
// AccessControl 实现 IP 访问控制中间件。
//
// 根据配置的允许/拒绝 CIDR 列表检查入站请求。
// 支持动态更新访问控制列表
// 根据配置的允许/拒绝 CIDR 列表和 GeoIP 国家代码检查入站请求。
// 支持动态更新访问控制列表和 GeoIP 配置
type AccessControl struct {
geoip *GeoIPLookup
allowList []net.IPNet
denyList []net.IPNet
trustedProxies []net.IPNet
geoipConfig config.GeoIPConfig
defaultAction Action
mu sync.RWMutex
}
@ -114,6 +119,31 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
return nil, fmt.Errorf("invalid default action: %s", cfg.Default)
}
// 初始化 GeoIP如果启用
if cfg.GeoIP.Enabled && cfg.GeoIP.Database != "" {
// 设置默认值
cacheSize := cfg.GeoIP.CacheSize
if cacheSize <= 0 {
cacheSize = 10000 // 默认 10000 条
}
ttl := cfg.GeoIP.CacheTTL
if ttl <= 0 {
ttl = time.Hour // 默认 1 小时
}
geoip, err := NewGeoIPLookup(
cfg.GeoIP.Database,
cacheSize,
ttl,
cfg.GeoIP.PrivateIPBehavior,
)
if err != nil {
return nil, fmt.Errorf("init geoip: %w", err)
}
ac.geoip = geoip
ac.geoipConfig = cfg.GeoIP
}
return ac, nil
}
@ -150,7 +180,12 @@ func (ac *AccessControl) Process(next fasthttp.RequestHandler) fasthttp.RequestH
// Check 检查 IP 地址是否允许访问。
//
// 评估顺序:先检查拒绝列表,再检查允许列表,最后使用默认操作。
// 评估顺序:
// 1. 检查 CIDR 拒绝列表(显式拒绝优先)
// 2. 检查 GeoIP 国家拒绝(如果启用)
// 3. 检查 CIDR 允许列表
// 4. 检查 GeoIP 国家允许(如果启用)
// 5. 返回默认操作
//
// 参数:
// - ip: 待检查的 IP 地址
@ -161,21 +196,55 @@ func (ac *AccessControl) Check(ip net.IP) bool {
ac.mu.RLock()
defer ac.mu.RUnlock()
// 先检查拒绝列表(显式拒绝优先)
// 1. 先检查 CIDR 拒绝列表(显式拒绝优先)
for _, network := range ac.denyList {
if network.Contains(ip) {
return false
}
}
// 检查允许列表
// 2. 检查 GeoIP 国家拒绝(如果启用)
if ac.geoip != nil && ac.geoipConfig.Enabled {
country, err := ac.geoip.LookupCountry(ip)
if err == nil {
// 处理私有 IP 特殊标记
if country == geoPrivateAllow {
// 私有 IP 自动允许,跳过国家检查
goto checkAllow
}
if country == geoPrivateDeny {
return false
}
for _, c := range ac.geoipConfig.DenyCountries {
if country == c {
return false
}
}
}
}
checkAllow:
// 3. 检查 CIDR 允许列表
for _, network := range ac.allowList {
if network.Contains(ip) {
return true
}
}
// 返回默认操作
// 4. 检查 GeoIP 国家允许(如果启用)
if ac.geoip != nil && ac.geoipConfig.Enabled {
country, err := ac.geoip.LookupCountry(ip)
if err == nil && country != geoPrivateDeny {
for _, c := range ac.geoipConfig.AllowCountries {
if country == c {
return true
}
}
}
}
// 5. 返回默认操作
return ac.defaultAction == ActionAllow
}
@ -428,7 +497,7 @@ func (ac *AccessControl) GetStats() AccessStats {
func actionToString(action Action) string {
switch action {
case ActionAllow:
return "allow"
return accessAllow
case ActionDeny:
return accessDeny
default:
@ -436,5 +505,21 @@ func actionToString(action Action) string {
}
}
// Close 释放资源。
//
// 必须在服务器停止时调用,释放 GeoIP 数据库连接。
//
// 返回值:
// - error: 关闭失败时返回错误
func (ac *AccessControl) Close() error {
ac.mu.Lock()
defer ac.mu.Unlock()
if ac.geoip != nil {
return ac.geoip.Close()
}
return nil
}
// 验证接口实现
var _ middleware.Middleware = (*AccessControl)(nil)

View File

@ -0,0 +1,241 @@
// Package security 提供安全相关的 HTTP 中间件。
//
// 该文件实现 GeoIP 查询功能,支持基于国家代码的访问控制,
// 使用 LRU 缓存提高查询性能。
//
// 使用示例:
//
// geoip, err := security.NewGeoIPLookup("/var/lib/geoip/GeoIP2-Country.mmdb", 10000, time.Hour, "allow")
// if err != nil {
// log.Fatal(err)
// }
// defer geoip.Close()
//
// country, err := geoip.LookupCountry(ip)
// if err != nil {
// log.Printf("GeoIP lookup failed: %v", err)
// }
//
// 作者xfy
package security
import (
"errors"
"fmt"
"net"
"sync"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/oschwald/geoip2-golang"
)
// GeoIPLookup GeoIP 查询器(带 LRU 缓存)。
//
// 使用 MaxMind GeoIP2 数据库查询 IP 地址所属国家,
// 通过 LRU 缓存减少数据库查询次数,提高性能。
type GeoIPLookup struct {
db *geoip2.Reader
cache *lru.Cache[string, *cachedCountry]
privateIPBehavior string
ttl time.Duration
mu sync.RWMutex
}
// cachedCountry 缓存的国家代码条目。
type cachedCountry struct {
expires time.Time
country string
}
// GeoIPStats GeoIP 缓存统计信息。
type GeoIPStats struct {
CacheSize int
CacheMaxSize int
}
// NewGeoIPLookup 创建 GeoIP 查询器。
//
// 打开 GeoIP2 数据库文件并初始化 LRU 缓存。
//
// 参数:
// - dbPath: GeoIP2 数据库文件路径(.mmdb 格式)
// - cacheSize: LRU 缓存最大条目数(硬限制)
// - ttl: 缓存条目有效期
// - privateIPBehavior: 私有 IP 处理策略("allow", "deny", "bypass"
//
// 返回值:
// - *GeoIPLookup: 查询器实例
// - error: 数据库打开失败或缓存创建失败时返回错误
func NewGeoIPLookup(dbPath string, cacheSize int, ttl time.Duration, privateIPBehavior string) (*GeoIPLookup, error) {
if dbPath == "" {
return nil, errors.New("geoip database path is required")
}
// 打开 GeoIP2 数据库
db, err := geoip2.Open(dbPath)
if err != nil {
return nil, fmt.Errorf("open geoip database: %w", err)
}
// 创建 LRU 缓存
cache, err := lru.New[string, *cachedCountry](cacheSize)
if err != nil {
db.Close()
return nil, fmt.Errorf("create lru cache: %w", err)
}
// 默认私有 IP 行为
if privateIPBehavior == "" {
privateIPBehavior = "allow"
}
return &GeoIPLookup{
db: db,
cache: cache,
ttl: ttl,
privateIPBehavior: privateIPBehavior,
}, nil
}
// LookupCountry 查询 IP 所属国家。
//
// 返回 ISO 3166-1 alpha-2 国家代码(如 "CN", "US")。
// 查询结果会被缓存,减少数据库访问。
//
// 参数:
// - ip: 待查询的 IP 地址
//
// 返回值:
// - string: ISO 3166-1 alpha-2 国家代码
// - error: 查询失败时返回错误
func (g *GeoIPLookup) LookupCountry(ip net.IP) (string, error) {
// 检查私有 IP
if isPrivateIP(ip) {
switch g.privateIPBehavior {
case "allow":
return "PRIVATE_ALLOW", nil // 特殊标记,表示允许
case accessDeny:
return "PRIVATE_DENY", nil // 特殊标记,表示拒绝
case "bypass":
return "", errors.New("private IP bypassed")
}
}
ipStr := ip.String()
// 检查缓存(读锁)
g.mu.RLock()
if cached, ok := g.cache.Get(ipStr); ok {
if time.Now().Before(cached.expires) {
g.mu.RUnlock()
return cached.country, nil
}
}
g.mu.RUnlock()
// 查询数据库(写锁)
g.mu.Lock()
defer g.mu.Unlock()
// 双重检查(可能已被其他 goroutine 更新)
if cached, ok := g.cache.Get(ipStr); ok {
if time.Now().Before(cached.expires) {
return cached.country, nil
}
}
// 查询数据库
country, err := g.db.Country(ip)
if err != nil {
return "", fmt.Errorf("geoip lookup: %w", err)
}
isoCode := country.Country.IsoCode
if isoCode == "" {
isoCode = "UNKNOWN"
}
// 存入缓存
g.cache.Add(ipStr, &cachedCountry{
country: isoCode,
expires: time.Now().Add(g.ttl),
})
return isoCode, nil
}
// Close 关闭数据库连接。
//
// 必须在服务器停止时调用,释放 GeoIP2 数据库资源。
//
// 返回值:
// - error: 关闭失败时返回错误
func (g *GeoIPLookup) Close() error {
g.mu.Lock()
defer g.mu.Unlock()
if g.db != nil {
return g.db.Close()
}
return nil
}
// GetStats 返回缓存统计信息。
//
// 返回值:
// - GeoIPStats: 包含当前缓存大小和最大缓存大小的统计对象
func (g *GeoIPLookup) GetStats() GeoIPStats {
g.mu.RLock()
defer g.mu.RUnlock()
return GeoIPStats{
CacheSize: g.cache.Len(),
CacheMaxSize: g.cache.Len(), // LRU 缓存容量与当前大小相同(已淘汰的已被移除)
}
}
// isPrivateIP 检查是否为私有 IP 地址。
//
// 支持的私有地址范围:
// - 10.0.0.0/8
// - 172.16.0.0/12
// - 192.168.0.0/16
// - 127.0.0.0/8回环
// - IPv6 本地地址
//
// 参数:
// - ip: 待检查的 IP 地址
//
// 返回值:
// - bool: true 表示是私有 IP
func isPrivateIP(ip net.IP) bool {
// IPv4 私有地址范围
privateBlocks := []string{
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"127.0.0.0/8",
}
for _, cidr := range privateBlocks {
_, network, err := net.ParseCIDR(cidr)
if err != nil {
continue
}
if network.Contains(ip) {
return true
}
}
// IPv6 私有地址
if ip.To4() == nil {
// IPv6 本地地址
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
}
return false
}

View File

@ -0,0 +1,204 @@
// Package security 提供安全相关的 HTTP 中间件测试。
//
// 该文件包含 GeoIP 查询功能的单元测试。
//
// 作者xfy
package security
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIsPrivateIP 测试私有 IP 检测功能。
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
ip string
expected bool
}{
{"IPv4 私有地址 10.x.x.x", "10.0.0.1", true},
{"IPv4 私有地址 172.16.x.x", "172.16.0.1", true},
{"IPv4 私有地址 172.31.x.x", "172.31.255.1", true},
{"IPv4 私有地址 192.168.x.x", "192.168.1.1", true},
{"IPv4 回环地址", "127.0.0.1", true},
{"IPv4 公网地址", "8.8.8.8", false},
{"IPv4 公网地址", "1.1.1.1", false},
{"IPv6 回环地址", "::1", true},
{"IPv6 本地链路地址", "fe80::1", true},
{"IPv6 公网地址", "2001:4860:4860::8888", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip := net.ParseIP(tt.ip)
require.NotNil(t, ip, "failed to parse IP: %s", tt.ip)
result := isPrivateIP(ip)
assert.Equal(t, tt.expected, result, "isPrivateIP(%s)", tt.ip)
})
}
}
// TestNewGeoIPLookup_InvalidPath 测试无效数据库路径。
func TestNewGeoIPLookup_InvalidPath(t *testing.T) {
_, err := NewGeoIPLookup("", 1000, time.Hour, "allow")
assert.Error(t, err)
assert.Contains(t, err.Error(), "database path is required")
}
// TestNewGeoIPLookup_NonExistentDB 测试不存在的数据库文件。
func TestNewGeoIPLookup_NonExistentDB(t *testing.T) {
_, err := NewGeoIPLookup("/nonexistent/path/to/geoip.mmdb", 1000, time.Hour, "allow")
assert.Error(t, err)
assert.Contains(t, err.Error(), "open geoip database")
}
// TestGeoIPLookup_PrivateIPBehavior 测试私有 IP 处理策略。
func TestGeoIPLookup_PrivateIPBehavior(t *testing.T) {
// 注意:这个测试需要有效的 GeoIP2 数据库文件
// 如果没有数据库文件,测试会被跳过
dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb"
// 尝试创建 GeoIPLookup
geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "allow")
if err != nil {
t.Skipf("Skipping test: GeoIP database not available: %v", err)
}
defer geoip.Close()
privateIP := net.ParseIP("192.168.1.1")
// 测试 allow 策略
country, err := geoip.LookupCountry(privateIP)
require.NoError(t, err)
assert.Equal(t, "PRIVATE_ALLOW", country)
}
// TestGeoIPLookup_PrivateIPBehavior_Deny 测试私有 IP deny 策略。
func TestGeoIPLookup_PrivateIPBehavior_Deny(t *testing.T) {
dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb"
geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "deny")
if err != nil {
t.Skipf("Skipping test: GeoIP database not available: %v", err)
}
defer geoip.Close()
privateIP := net.ParseIP("10.0.0.1")
country, err := geoip.LookupCountry(privateIP)
require.NoError(t, err)
assert.Equal(t, "PRIVATE_DENY", country)
}
// TestGeoIPLookup_PrivateIPBehavior_Bypass 测试私有 IP bypass 策略。
func TestGeoIPLookup_PrivateIPBehavior_Bypass(t *testing.T) {
dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb"
geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "bypass")
if err != nil {
t.Skipf("Skipping test: GeoIP database not available: %v", err)
}
defer geoip.Close()
privateIP := net.ParseIP("172.16.0.1")
_, err = geoip.LookupCountry(privateIP)
assert.Error(t, err)
assert.Contains(t, err.Error(), "private IP bypassed")
}
// TestGeoIPLookup_DefaultPrivateIPBehavior 测试默认私有 IP 行为。
func TestGeoIPLookup_DefaultPrivateIPBehavior(t *testing.T) {
dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb"
// 空字符串应该使用默认的 "allow"
geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "")
if err != nil {
t.Skipf("Skipping test: GeoIP database not available: %v", err)
}
defer geoip.Close()
privateIP := net.ParseIP("127.0.0.1")
country, err := geoip.LookupCountry(privateIP)
require.NoError(t, err)
assert.Equal(t, "PRIVATE_ALLOW", country)
}
// TestGeoIPLookup_GetStats 测试统计信息获取。
func TestGeoIPLookup_GetStats(t *testing.T) {
dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb"
geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "allow")
if err != nil {
t.Skipf("Skipping test: GeoIP database not available: %v", err)
}
defer geoip.Close()
stats := geoip.GetStats()
assert.GreaterOrEqual(t, stats.CacheSize, 0)
assert.GreaterOrEqual(t, stats.CacheMaxSize, 0)
}
// TestGeoIPLookup_CacheBehavior 测试缓存行为。
func TestGeoIPLookup_CacheBehavior(t *testing.T) {
dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb"
geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "allow")
if err != nil {
t.Skipf("Skipping test: GeoIP database not available: %v", err)
}
defer geoip.Close()
// 使用公网 IP 进行测试(假设 8.8.8.8 是美国)
publicIP := net.ParseIP("8.8.8.8")
// 第一次查询
country1, err := geoip.LookupCountry(publicIP)
if err != nil {
// 数据库中可能没有该 IP 的信息
t.Skipf("Skipping test: IP not found in database: %v", err)
}
// 第二次查询(应该从缓存返回)
country2, err := geoip.LookupCountry(publicIP)
require.NoError(t, err)
assert.Equal(t, country1, country2)
// 验证缓存大小
stats := geoip.GetStats()
assert.GreaterOrEqual(t, stats.CacheSize, 1)
}
// TestGeoIPLookup_TTLExpiration 测试缓存 TTL 过期。
func TestGeoIPLookup_TTLExpiration(t *testing.T) {
dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb"
// 使用很短的 TTL
geoip, err := NewGeoIPLookup(dbPath, 1000, 1*time.Millisecond, "allow")
if err != nil {
t.Skipf("Skipping test: GeoIP database not available: %v", err)
}
defer geoip.Close()
publicIP := net.ParseIP("8.8.8.8")
// 第一次查询
_, err = geoip.LookupCountry(publicIP)
if err != nil {
t.Skipf("Skipping test: IP not found in database: %v", err)
}
// 等待 TTL 过期
time.Sleep(10 * time.Millisecond)
// 再次查询(缓存应该已过期)
_, err = geoip.LookupCountry(publicIP)
// 不应该报错,只是重新查询数据库
assert.NoError(t, err)
}