diff --git a/docs/config/advanced/stream-udp.conf b/docs/config/advanced/stream-udp.conf new file mode 100644 index 0000000..b48f481 --- /dev/null +++ b/docs/config/advanced/stream-udp.conf @@ -0,0 +1,110 @@ +# ============================================================ +# Lolly UDP Stream 代理配置示例 +# ============================================================ +# +# 功能说明: +# - UDP 四层代理(DNS、游戏服务器、VoIP 等) +# - 会话管理和超时控制 +# - 负载均衡支持 +# +# Lolly 对应配置: +# stream: +# - listen: ":53" +# protocol: "udp" +# timeout: 60s # 默认值 +# upstream: +# targets: +# - addr: "dns1:53" +# weight: 1 +# - addr: "dns2:53" +# weight: 1 +# load_balance: "round_robin" +# ============================================================ + +# ------------------------------------------------------------ +# DNS UDP 代理配置示例 +# ------------------------------------------------------------ +# +# YAML 格式: +stream: + # DNS 服务器代理 + - listen: ":53" + protocol: "udp" + timeout: 60s # 会话超时,默认 60 秒 + # DNS 请求通常很快,可设置较短如 30s + upstream: + targets: + - addr: "8.8.8.8:53" + weight: 1 + - addr: "8.8.4.4:53" + weight: 1 + load_balance: "round_robin" + +# ------------------------------------------------------------ +# 游戏服务器 UDP 代理配置示例 +# ------------------------------------------------------------ +# +# 游戏服务器需要长会话,建议使用 ip_hash 保持会话一致性 +# +stream: + - listen: ":27015" + protocol: "udp" + timeout: 300s # 游戏会话较长,设置 5 分钟超时 + upstream: + targets: + - addr: "game1:27015" + weight: 3 + - addr: "game2:27015" + weight: 1 + load_balance: "ip_hash" # IP Hash 保持会话一致性 + +# ------------------------------------------------------------ +# VoIP SIP UDP 代理配置示例 +# ------------------------------------------------------------ +# +# VoIP 服务需要稳定连接 +# +stream: + - listen: ":5060" + protocol: "udp" + timeout: 180s # VoIP 会话超时 3 分钟 + upstream: + targets: + - addr: "sip1:5060" + weight: 1 + - addr: "sip2:5060" + weight: 1 + load_balance: "least_conn" # 最少连接,均匀分配 + +# ============================================================ +# UDP Stream 配置参数详解 +# ============================================================ +# +# 1. UDP vs TCP: +# - UDP: 无连接,数据报协议,适合实时应用 +# - TCP: 有连接,流协议,适合可靠传输 +# +# 2. 会话管理: +# - Lolly 自动管理 UDP 会话 +# - 同一客户端 IP 映射到同一后端 +# - 会话空闲超时后自动清理 +# +# 3. timeout 参数: +# - 默认值: 60 秒 (未配置时使用) +# - DNS: 建议 30-60 秒 +# - 游戏服务器: 建议 300 秒 (5 分钟) +# - VoIP: 建议 180 秒 (3 分钟) +# - Syslog: 建议 60 秒 +# +# 4. 负载均衡算法: +# - round_robin: 轮询(默认) +# - weighted_round_robin: 加权轮询 +# - least_conn: 最少连接 +# - ip_hash: IP 哈希(推荐游戏服务器、VoIP) +# +# 5. 适用场景: +# - DNS 服务器代理 +# - 游戏服务器代理 (CS2, Minecraft 等) +# - VoIP 服务代理 (SIP, RTP) +# - 日志收集服务 (Syslog) +# - NTP 时间服务器 diff --git a/internal/config/config.go b/internal/config/config.go index e83c5be..3c49f55 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -275,17 +275,50 @@ type StaticConfig struct { // interval: 10s // path: "/health" type ProxyConfig struct { - Path string `yaml:"path"` - LoadBalance string `yaml:"load_balance"` - HashKey string `yaml:"hash_key"` - ClientMaxBodySize string `yaml:"client_max_body_size"` - Headers ProxyHeaders `yaml:"headers"` - Targets []ProxyTarget `yaml:"targets"` - HealthCheck HealthCheckConfig `yaml:"health_check"` - NextUpstream NextUpstreamConfig `yaml:"next_upstream"` - Cache ProxyCacheConfig `yaml:"cache"` - Timeout ProxyTimeout `yaml:"timeout"` - VirtualNodes int `yaml:"virtual_nodes"` + Path string `yaml:"path"` + LoadBalance string `yaml:"load_balance"` + HashKey string `yaml:"hash_key"` + ClientMaxBodySize string `yaml:"client_max_body_size"` + Headers ProxyHeaders `yaml:"headers"` + Targets []ProxyTarget `yaml:"targets"` + BalancerByLua BalancerByLuaConfig `yaml:"balancer_by_lua"` + HealthCheck HealthCheckConfig `yaml:"health_check"` + NextUpstream NextUpstreamConfig `yaml:"next_upstream"` + Cache ProxyCacheConfig `yaml:"cache"` + Timeout ProxyTimeout `yaml:"timeout"` + VirtualNodes int `yaml:"virtual_nodes"` +} + +// BalancerByLuaConfig Lua 负载均衡配置 +// +// 使用 Lua 脚本动态选择后端目标,支持自定义负载均衡逻辑。 +// +// 注意事项: +// - Script 为 Lua 脚本文件路径 +// - Timeout 控制脚本执行超时 +// - Fallback 指定 Lua 失败时的备用算法 +// +// 使用示例: +// +// balancer_by_lua: +// enabled: true +// script: "/etc/lolly/scripts/balancer.lua" +// timeout: 100ms +// fallback: "round_robin" +type BalancerByLuaConfig struct { + // Script Lua 脚本路径 + Script string `yaml:"script"` + + // Fallback 失败时使用的默认负载均衡算法 + // 默认值: "round_robin" + Fallback string `yaml:"fallback"` + + // Timeout 执行超时 + // 默认值: 100ms + Timeout time.Duration `yaml:"timeout"` + + // Enabled 是否启用 + Enabled bool `yaml:"enabled"` } // ProxyTarget 后端目标配置。 @@ -597,12 +630,13 @@ type SecurityConfig struct { // AccessConfig IP 访问控制配置。 // -// 通过 IP 地址或 CIDR 范围控制访问权限。 +// 通过 IP 地址或 CIDR 范围控制访问权限,支持基于 GeoIP 的国家代码访问控制。 // // 注意事项: // - Allow 和 Deny 列表按配置顺序匹配 // - Default 指定未匹配时的默认动作 // - TrustedProxies 用于正确获取客户端真实 IP +// - GeoIP 配置启用后,会基于国家代码进行二次检查 // - 支持 IPv4 和 IPv6 地址格式 // // 使用示例: @@ -612,6 +646,15 @@ type SecurityConfig struct { // deny: ["192.168.1.100"] // default: "deny" // trusted_proxies: ["172.16.0.0/16"] +// geoip: +// enabled: true +// database: "/var/lib/geoip/GeoIP2-Country.mmdb" +// allow_countries: ["US", "JP", "GB"] +// deny_countries: ["CN", "RU"] +// default: "deny" +// cache_size: 10000 +// cache_ttl: 1h +// private_ip_behavior: "allow" type AccessConfig struct { // Allow 允许的 IP/CIDR 列表 // 配置允许访问的 IP 地址或网段 @@ -621,13 +664,49 @@ type AccessConfig struct { // 配置拒绝访问的 IP 地址或网段 Deny []string `yaml:"deny"` + // TrustedProxies 可信代理 CIDR 列表 + // 用于正确解析 X-Forwarded-For 头部获取真实客户端 IP + TrustedProxies []string `yaml:"trusted_proxies"` + // Default 默认动作 // 未匹配任何规则时的处理方式:allow 或 deny Default string `yaml:"default"` - // TrustedProxies 可信代理 CIDR 列表 - // 用于正确解析 X-Forwarded-For 头部获取真实客户端 IP - TrustedProxies []string `yaml:"trusted_proxies"` + // GeoIP GeoIP 国家代码访问控制配置 + GeoIP GeoIPConfig `yaml:"geoip"` +} + +// GeoIPConfig GeoIP 访问控制配置。 +// +// 通过 MaxMind GeoIP2 数据库查询 IP 所属国家,实现基于国家代码的访问控制。 +// +// 注意事项: +// - Database 为 GeoIP2 数据库文件路径(.mmdb 格式) +// - AllowCountries 和 DenyCountries 使用 ISO 3166-1 alpha-2 国家代码 +// - CacheSize 设置 LRU 缓存最大条目数,0 表示使用默认值 10000 +// - CacheTTL 设置缓存有效期,0 表示使用默认值 1 小时 +// - PrivateIPBehavior 控制私有 IP 的处理策略 +// +// 使用示例: +// +// geoip: +// enabled: true +// database: "/var/lib/geoip/GeoIP2-Country.mmdb" +// allow_countries: ["US", "JP", "GB"] +// deny_countries: ["CN", "RU"] +// default: "deny" +// cache_size: 10000 +// cache_ttl: 1h +// private_ip_behavior: "allow" +type GeoIPConfig struct { + Database string `yaml:"database"` + Default string `yaml:"default"` + PrivateIPBehavior string `yaml:"private_ip_behavior"` + AllowCountries []string `yaml:"allow_countries"` + DenyCountries []string `yaml:"deny_countries"` + CacheSize int `yaml:"cache_size"` + CacheTTL time.Duration `yaml:"cache_ttl"` + Enabled bool `yaml:"enabled"` } // RateLimitConfig 速率限制配置。 @@ -1536,7 +1615,7 @@ func Save(cfg *Config, path string) error { return fmt.Errorf("序列化配置失败: %w", err) } - if err := os.WriteFile(path, data, 0644); err != nil { + if err := os.WriteFile(path, data, 0o644); err != nil { return fmt.Errorf("写入配置文件失败: %w", err) } diff --git a/internal/config/validate.go b/internal/config/validate.go index 9d458b2..6c7d0e4 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -543,9 +543,89 @@ func validateAccess(a *AccessConfig) error { return fmt.Errorf("无效的 default 动作: %s(仅允许 allow 或 deny)", a.Default) } + // 验证 GeoIP 配置 + if err := validateGeoIP(&a.GeoIP); err != nil { + return fmt.Errorf("geoip: %w", err) + } + return nil } +// validateGeoIP 验证 GeoIP 配置。 +// +// 检查 GeoIP 数据库路径、国家代码格式、缓存设置等。 +// +// 参数: +// - g: GeoIP 配置对象 +// +// 返回值: +// - error: 验证失败时返回错误信息,成功返回 nil +func validateGeoIP(g *GeoIPConfig) error { + // 未启用时跳过验证 + if !g.Enabled { + return nil + } + + // 验证数据库路径 + if g.Database == "" { + return errors.New("database 是必填项(启用 GeoIP 时)") + } + + // 验证国家代码格式 (ISO 3166-1 alpha-2) + for _, c := range g.AllowCountries { + if !isValidCountryCode(c) { + return fmt.Errorf("无效的 allow_countries 国家代码: %s(应为 2 位大写字母)", c) + } + } + for _, c := range g.DenyCountries { + if !isValidCountryCode(c) { + return fmt.Errorf("无效的 deny_countries 国家代码: %s(应为 2 位大写字母)", c) + } + } + + // 验证 PrivateIPBehavior + validBehaviors := []string{"", "allow", "deny", "bypass"} + if !slices.Contains(validBehaviors, g.PrivateIPBehavior) { + return fmt.Errorf("无效的 private_ip_behavior: %s(仅支持 allow, deny, bypass)", g.PrivateIPBehavior) + } + + // 验证缓存大小 + if g.CacheSize < 0 { + return errors.New("cache_size 不能为负数") + } + + // 验证缓存 TTL + if g.CacheTTL < 0 { + return errors.New("cache_ttl 不能为负数") + } + + // 验证默认动作 + if g.Default != "" && g.Default != "allow" && g.Default != "deny" { + return fmt.Errorf("无效的 default 动作: %s(仅允许 allow 或 deny)", g.Default) + } + + return nil +} + +// isValidCountryCode 验证 ISO 3166-1 alpha-2 国家代码。 +// +// 参数: +// - code: 国家代码字符串 +// +// 返回值: +// - bool: true 表示有效的国家代码 +func isValidCountryCode(code string) bool { + if len(code) != 2 { + return false + } + for _, c := range code { + if c < 'A' || c > 'Z' { + return false + } + } + return true +} + // validateAuth 验证认证配置。 // // 检查认证类型、哈希算法和用户列表的有效性。 diff --git a/internal/config/validate_geoip_test.go b/internal/config/validate_geoip_test.go new file mode 100644 index 0000000..bd98f2a --- /dev/null +++ b/internal/config/validate_geoip_test.go @@ -0,0 +1,171 @@ +// Package config 提供 YAML 配置文件的解析、验证和默认配置生成功能测试。 +// +// 该文件包含 GeoIP 配置验证相关的测试。 +// +// 作者:xfy +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestValidateGeoIP 测试 GeoIP 配置验证。 +func TestValidateGeoIP(t *testing.T) { + tests := []struct { + name string + errMsg string + config GeoIPConfig + wantErr bool + }{ + { + name: "未启用时跳过验证", + config: GeoIPConfig{ + Enabled: false, + }, + wantErr: false, + }, + { + name: "启用但缺少数据库路径", + config: GeoIPConfig{ + Enabled: true, + Database: "", + }, + wantErr: true, + errMsg: "database 是必填项", + }, + { + name: "有效的 GeoIP 配置", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + AllowCountries: []string{"US", "JP"}, + DenyCountries: []string{"CN"}, + Default: "deny", + CacheSize: 10000, + PrivateIPBehavior: "allow", + }, + wantErr: false, + }, + { + name: "无效的国家代码(小写)", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + AllowCountries: []string{"us"}, + }, + wantErr: true, + errMsg: "无效的 allow_countries", + }, + { + name: "无效的国家代码(3位)", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + DenyCountries: []string{"USA"}, + }, + wantErr: true, + errMsg: "无效的 deny_countries", + }, + { + name: "无效的 private_ip_behavior", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + PrivateIPBehavior: "invalid", + }, + wantErr: true, + errMsg: "无效的 private_ip_behavior", + }, + { + name: "负的 cache_size", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + CacheSize: -1, + }, + wantErr: true, + errMsg: "cache_size 不能为负数", + }, + { + name: "负的 cache_ttl", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + CacheTTL: -1, + }, + wantErr: true, + errMsg: "cache_ttl 不能为负数", + }, + { + name: "无效的 default 动作", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + Default: "invalid", + }, + wantErr: true, + errMsg: "无效的 default", + }, + { + name: "有效的 private_ip_behavior: deny", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + PrivateIPBehavior: "deny", + }, + wantErr: false, + }, + { + name: "有效的 private_ip_behavior: bypass", + config: GeoIPConfig{ + Enabled: true, + Database: "/var/lib/geoip/GeoIP2-Country.mmdb", + PrivateIPBehavior: "bypass", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateGeoIP(&tt.config) + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestIsValidCountryCode 测试国家代码验证。 +func TestIsValidCountryCode(t *testing.T) { + tests := []struct { + code string + expected bool + }{ + {"US", true}, + {"JP", true}, + {"GB", true}, + {"CN", true}, + {"us", false}, // 小写 + {"Us", false}, // 混合大小写 + {"USA", false}, // 3位 + {"U", false}, // 1位 + {"U1", false}, // 包含数字 + {"U-", false}, // 包含连字符 + {"", false}, // 空字符串 + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + result := isValidCountryCode(tt.code) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/middleware/security/access.go b/internal/middleware/security/access.go index c0e7963..82867cd 100644 --- a/internal/middleware/security/access.go +++ b/internal/middleware/security/access.go @@ -29,6 +29,7 @@ import ( "net" "strings" "sync" + "time" "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" @@ -44,18 +45,22 @@ const ( // ActionDeny 拒绝请求(返回 403 Forbidden) ActionDeny - accessAllow = "allow" - accessDeny = "deny" + accessAllow = "allow" + accessDeny = "deny" + geoPrivateAllow = "PRIVATE_ALLOW" + geoPrivateDeny = "PRIVATE_DENY" ) // AccessControl 实现 IP 访问控制中间件。 // -// 根据配置的允许/拒绝 CIDR 列表检查入站请求。 -// 支持动态更新访问控制列表。 +// 根据配置的允许/拒绝 CIDR 列表和 GeoIP 国家代码检查入站请求。 +// 支持动态更新访问控制列表和 GeoIP 配置。 type AccessControl struct { + geoip *GeoIPLookup allowList []net.IPNet denyList []net.IPNet trustedProxies []net.IPNet + geoipConfig config.GeoIPConfig defaultAction Action mu sync.RWMutex } @@ -114,6 +119,31 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) { return nil, fmt.Errorf("invalid default action: %s", cfg.Default) } + // 初始化 GeoIP(如果启用) + if cfg.GeoIP.Enabled && cfg.GeoIP.Database != "" { + // 设置默认值 + cacheSize := cfg.GeoIP.CacheSize + if cacheSize <= 0 { + cacheSize = 10000 // 默认 10000 条 + } + ttl := cfg.GeoIP.CacheTTL + if ttl <= 0 { + ttl = time.Hour // 默认 1 小时 + } + + geoip, err := NewGeoIPLookup( + cfg.GeoIP.Database, + cacheSize, + ttl, + cfg.GeoIP.PrivateIPBehavior, + ) + if err != nil { + return nil, fmt.Errorf("init geoip: %w", err) + } + ac.geoip = geoip + ac.geoipConfig = cfg.GeoIP + } + return ac, nil } @@ -150,7 +180,12 @@ func (ac *AccessControl) Process(next fasthttp.RequestHandler) fasthttp.RequestH // Check 检查 IP 地址是否允许访问。 // -// 评估顺序:先检查拒绝列表,再检查允许列表,最后使用默认操作。 +// 评估顺序: +// 1. 检查 CIDR 拒绝列表(显式拒绝优先) +// 2. 检查 GeoIP 国家拒绝(如果启用) +// 3. 检查 CIDR 允许列表 +// 4. 检查 GeoIP 国家允许(如果启用) +// 5. 返回默认操作 // // 参数: // - ip: 待检查的 IP 地址 @@ -161,21 +196,55 @@ func (ac *AccessControl) Check(ip net.IP) bool { ac.mu.RLock() defer ac.mu.RUnlock() - // 先检查拒绝列表(显式拒绝优先) + // 1. 先检查 CIDR 拒绝列表(显式拒绝优先) for _, network := range ac.denyList { if network.Contains(ip) { return false } } - // 检查允许列表 + // 2. 检查 GeoIP 国家拒绝(如果启用) + if ac.geoip != nil && ac.geoipConfig.Enabled { + country, err := ac.geoip.LookupCountry(ip) + if err == nil { + // 处理私有 IP 特殊标记 + if country == geoPrivateAllow { + // 私有 IP 自动允许,跳过国家检查 + goto checkAllow + } + if country == geoPrivateDeny { + return false + } + + for _, c := range ac.geoipConfig.DenyCountries { + if country == c { + return false + } + } + } + } + +checkAllow: + // 3. 检查 CIDR 允许列表 for _, network := range ac.allowList { if network.Contains(ip) { return true } } - // 返回默认操作 + // 4. 检查 GeoIP 国家允许(如果启用) + if ac.geoip != nil && ac.geoipConfig.Enabled { + country, err := ac.geoip.LookupCountry(ip) + if err == nil && country != geoPrivateDeny { + for _, c := range ac.geoipConfig.AllowCountries { + if country == c { + return true + } + } + } + } + + // 5. 返回默认操作 return ac.defaultAction == ActionAllow } @@ -428,7 +497,7 @@ func (ac *AccessControl) GetStats() AccessStats { func actionToString(action Action) string { switch action { case ActionAllow: - return "allow" + return accessAllow case ActionDeny: return accessDeny default: @@ -436,5 +505,21 @@ func actionToString(action Action) string { } } +// Close 释放资源。 +// +// 必须在服务器停止时调用,释放 GeoIP 数据库连接。 +// +// 返回值: +// - error: 关闭失败时返回错误 +func (ac *AccessControl) Close() error { + ac.mu.Lock() + defer ac.mu.Unlock() + + if ac.geoip != nil { + return ac.geoip.Close() + } + return nil +} + // 验证接口实现 var _ middleware.Middleware = (*AccessControl)(nil) diff --git a/internal/middleware/security/geoip.go b/internal/middleware/security/geoip.go new file mode 100644 index 0000000..9420e3f --- /dev/null +++ b/internal/middleware/security/geoip.go @@ -0,0 +1,241 @@ +// Package security 提供安全相关的 HTTP 中间件。 +// +// 该文件实现 GeoIP 查询功能,支持基于国家代码的访问控制, +// 使用 LRU 缓存提高查询性能。 +// +// 使用示例: +// +// geoip, err := security.NewGeoIPLookup("/var/lib/geoip/GeoIP2-Country.mmdb", 10000, time.Hour, "allow") +// if err != nil { +// log.Fatal(err) +// } +// defer geoip.Close() +// +// country, err := geoip.LookupCountry(ip) +// if err != nil { +// log.Printf("GeoIP lookup failed: %v", err) +// } +// +// 作者:xfy +package security + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + lru "github.com/hashicorp/golang-lru/v2" + "github.com/oschwald/geoip2-golang" +) + +// GeoIPLookup GeoIP 查询器(带 LRU 缓存)。 +// +// 使用 MaxMind GeoIP2 数据库查询 IP 地址所属国家, +// 通过 LRU 缓存减少数据库查询次数,提高性能。 +type GeoIPLookup struct { + db *geoip2.Reader + cache *lru.Cache[string, *cachedCountry] + privateIPBehavior string + ttl time.Duration + mu sync.RWMutex +} + +// cachedCountry 缓存的国家代码条目。 +type cachedCountry struct { + expires time.Time + country string +} + +// GeoIPStats GeoIP 缓存统计信息。 +type GeoIPStats struct { + CacheSize int + CacheMaxSize int +} + +// NewGeoIPLookup 创建 GeoIP 查询器。 +// +// 打开 GeoIP2 数据库文件并初始化 LRU 缓存。 +// +// 参数: +// - dbPath: GeoIP2 数据库文件路径(.mmdb 格式) +// - cacheSize: LRU 缓存最大条目数(硬限制) +// - ttl: 缓存条目有效期 +// - privateIPBehavior: 私有 IP 处理策略("allow", "deny", "bypass") +// +// 返回值: +// - *GeoIPLookup: 查询器实例 +// - error: 数据库打开失败或缓存创建失败时返回错误 +func NewGeoIPLookup(dbPath string, cacheSize int, ttl time.Duration, privateIPBehavior string) (*GeoIPLookup, error) { + if dbPath == "" { + return nil, errors.New("geoip database path is required") + } + + // 打开 GeoIP2 数据库 + db, err := geoip2.Open(dbPath) + if err != nil { + return nil, fmt.Errorf("open geoip database: %w", err) + } + + // 创建 LRU 缓存 + cache, err := lru.New[string, *cachedCountry](cacheSize) + if err != nil { + + db.Close() + return nil, fmt.Errorf("create lru cache: %w", err) + } + + // 默认私有 IP 行为 + if privateIPBehavior == "" { + privateIPBehavior = "allow" + } + + return &GeoIPLookup{ + db: db, + cache: cache, + ttl: ttl, + privateIPBehavior: privateIPBehavior, + }, nil +} + +// LookupCountry 查询 IP 所属国家。 +// +// 返回 ISO 3166-1 alpha-2 国家代码(如 "CN", "US")。 +// 查询结果会被缓存,减少数据库访问。 +// +// 参数: +// - ip: 待查询的 IP 地址 +// +// 返回值: +// - string: ISO 3166-1 alpha-2 国家代码 +// - error: 查询失败时返回错误 +func (g *GeoIPLookup) LookupCountry(ip net.IP) (string, error) { + // 检查私有 IP + if isPrivateIP(ip) { + switch g.privateIPBehavior { + case "allow": + return "PRIVATE_ALLOW", nil // 特殊标记,表示允许 + case accessDeny: + return "PRIVATE_DENY", nil // 特殊标记,表示拒绝 + case "bypass": + return "", errors.New("private IP bypassed") + } + } + + ipStr := ip.String() + + // 检查缓存(读锁) + g.mu.RLock() + if cached, ok := g.cache.Get(ipStr); ok { + if time.Now().Before(cached.expires) { + g.mu.RUnlock() + return cached.country, nil + } + } + g.mu.RUnlock() + + // 查询数据库(写锁) + g.mu.Lock() + defer g.mu.Unlock() + + // 双重检查(可能已被其他 goroutine 更新) + if cached, ok := g.cache.Get(ipStr); ok { + if time.Now().Before(cached.expires) { + return cached.country, nil + } + } + + // 查询数据库 + country, err := g.db.Country(ip) + if err != nil { + return "", fmt.Errorf("geoip lookup: %w", err) + } + + isoCode := country.Country.IsoCode + if isoCode == "" { + isoCode = "UNKNOWN" + } + + // 存入缓存 + g.cache.Add(ipStr, &cachedCountry{ + country: isoCode, + expires: time.Now().Add(g.ttl), + }) + + return isoCode, nil +} + +// Close 关闭数据库连接。 +// +// 必须在服务器停止时调用,释放 GeoIP2 数据库资源。 +// +// 返回值: +// - error: 关闭失败时返回错误 +func (g *GeoIPLookup) Close() error { + g.mu.Lock() + defer g.mu.Unlock() + + if g.db != nil { + return g.db.Close() + } + return nil +} + +// GetStats 返回缓存统计信息。 +// +// 返回值: +// - GeoIPStats: 包含当前缓存大小和最大缓存大小的统计对象 +func (g *GeoIPLookup) GetStats() GeoIPStats { + g.mu.RLock() + defer g.mu.RUnlock() + + return GeoIPStats{ + CacheSize: g.cache.Len(), + CacheMaxSize: g.cache.Len(), // LRU 缓存容量与当前大小相同(已淘汰的已被移除) + } +} + +// isPrivateIP 检查是否为私有 IP 地址。 +// +// 支持的私有地址范围: +// - 10.0.0.0/8 +// - 172.16.0.0/12 +// - 192.168.0.0/16 +// - 127.0.0.0/8(回环) +// - IPv6 本地地址 +// +// 参数: +// - ip: 待检查的 IP 地址 +// +// 返回值: +// - bool: true 表示是私有 IP +func isPrivateIP(ip net.IP) bool { + // IPv4 私有地址范围 + privateBlocks := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "127.0.0.0/8", + } + + for _, cidr := range privateBlocks { + _, network, err := net.ParseCIDR(cidr) + if err != nil { + continue + } + if network.Contains(ip) { + return true + } + } + + // IPv6 私有地址 + if ip.To4() == nil { + // IPv6 本地地址 + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + } + + return false +} diff --git a/internal/middleware/security/geoip_test.go b/internal/middleware/security/geoip_test.go new file mode 100644 index 0000000..3792852 --- /dev/null +++ b/internal/middleware/security/geoip_test.go @@ -0,0 +1,204 @@ +// Package security 提供安全相关的 HTTP 中间件测试。 +// +// 该文件包含 GeoIP 查询功能的单元测试。 +// +// 作者:xfy +package security + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestIsPrivateIP 测试私有 IP 检测功能。 +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + expected bool + }{ + {"IPv4 私有地址 10.x.x.x", "10.0.0.1", true}, + {"IPv4 私有地址 172.16.x.x", "172.16.0.1", true}, + {"IPv4 私有地址 172.31.x.x", "172.31.255.1", true}, + {"IPv4 私有地址 192.168.x.x", "192.168.1.1", true}, + {"IPv4 回环地址", "127.0.0.1", true}, + {"IPv4 公网地址", "8.8.8.8", false}, + {"IPv4 公网地址", "1.1.1.1", false}, + {"IPv6 回环地址", "::1", true}, + {"IPv6 本地链路地址", "fe80::1", true}, + {"IPv6 公网地址", "2001:4860:4860::8888", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + require.NotNil(t, ip, "failed to parse IP: %s", tt.ip) + result := isPrivateIP(ip) + assert.Equal(t, tt.expected, result, "isPrivateIP(%s)", tt.ip) + }) + } +} + +// TestNewGeoIPLookup_InvalidPath 测试无效数据库路径。 +func TestNewGeoIPLookup_InvalidPath(t *testing.T) { + _, err := NewGeoIPLookup("", 1000, time.Hour, "allow") + assert.Error(t, err) + assert.Contains(t, err.Error(), "database path is required") +} + +// TestNewGeoIPLookup_NonExistentDB 测试不存在的数据库文件。 +func TestNewGeoIPLookup_NonExistentDB(t *testing.T) { + _, err := NewGeoIPLookup("/nonexistent/path/to/geoip.mmdb", 1000, time.Hour, "allow") + assert.Error(t, err) + assert.Contains(t, err.Error(), "open geoip database") +} + +// TestGeoIPLookup_PrivateIPBehavior 测试私有 IP 处理策略。 +func TestGeoIPLookup_PrivateIPBehavior(t *testing.T) { + // 注意:这个测试需要有效的 GeoIP2 数据库文件 + // 如果没有数据库文件,测试会被跳过 + dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb" + + // 尝试创建 GeoIPLookup + geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "allow") + if err != nil { + t.Skipf("Skipping test: GeoIP database not available: %v", err) + } + defer geoip.Close() + + privateIP := net.ParseIP("192.168.1.1") + + // 测试 allow 策略 + country, err := geoip.LookupCountry(privateIP) + require.NoError(t, err) + assert.Equal(t, "PRIVATE_ALLOW", country) +} + +// TestGeoIPLookup_PrivateIPBehavior_Deny 测试私有 IP deny 策略。 +func TestGeoIPLookup_PrivateIPBehavior_Deny(t *testing.T) { + dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb" + + geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "deny") + if err != nil { + t.Skipf("Skipping test: GeoIP database not available: %v", err) + } + defer geoip.Close() + + privateIP := net.ParseIP("10.0.0.1") + + country, err := geoip.LookupCountry(privateIP) + require.NoError(t, err) + assert.Equal(t, "PRIVATE_DENY", country) +} + +// TestGeoIPLookup_PrivateIPBehavior_Bypass 测试私有 IP bypass 策略。 +func TestGeoIPLookup_PrivateIPBehavior_Bypass(t *testing.T) { + dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb" + + geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "bypass") + if err != nil { + t.Skipf("Skipping test: GeoIP database not available: %v", err) + } + defer geoip.Close() + + privateIP := net.ParseIP("172.16.0.1") + + _, err = geoip.LookupCountry(privateIP) + assert.Error(t, err) + assert.Contains(t, err.Error(), "private IP bypassed") +} + +// TestGeoIPLookup_DefaultPrivateIPBehavior 测试默认私有 IP 行为。 +func TestGeoIPLookup_DefaultPrivateIPBehavior(t *testing.T) { + dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb" + + // 空字符串应该使用默认的 "allow" + geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "") + if err != nil { + t.Skipf("Skipping test: GeoIP database not available: %v", err) + } + defer geoip.Close() + + privateIP := net.ParseIP("127.0.0.1") + + country, err := geoip.LookupCountry(privateIP) + require.NoError(t, err) + assert.Equal(t, "PRIVATE_ALLOW", country) +} + +// TestGeoIPLookup_GetStats 测试统计信息获取。 +func TestGeoIPLookup_GetStats(t *testing.T) { + dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb" + + geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "allow") + if err != nil { + t.Skipf("Skipping test: GeoIP database not available: %v", err) + } + defer geoip.Close() + + stats := geoip.GetStats() + assert.GreaterOrEqual(t, stats.CacheSize, 0) + assert.GreaterOrEqual(t, stats.CacheMaxSize, 0) +} + +// TestGeoIPLookup_CacheBehavior 测试缓存行为。 +func TestGeoIPLookup_CacheBehavior(t *testing.T) { + dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb" + + geoip, err := NewGeoIPLookup(dbPath, 1000, time.Hour, "allow") + if err != nil { + t.Skipf("Skipping test: GeoIP database not available: %v", err) + } + defer geoip.Close() + + // 使用公网 IP 进行测试(假设 8.8.8.8 是美国) + publicIP := net.ParseIP("8.8.8.8") + + // 第一次查询 + country1, err := geoip.LookupCountry(publicIP) + if err != nil { + // 数据库中可能没有该 IP 的信息 + t.Skipf("Skipping test: IP not found in database: %v", err) + } + + // 第二次查询(应该从缓存返回) + country2, err := geoip.LookupCountry(publicIP) + require.NoError(t, err) + assert.Equal(t, country1, country2) + + // 验证缓存大小 + stats := geoip.GetStats() + assert.GreaterOrEqual(t, stats.CacheSize, 1) +} + +// TestGeoIPLookup_TTLExpiration 测试缓存 TTL 过期。 +func TestGeoIPLookup_TTLExpiration(t *testing.T) { + dbPath := "/var/lib/geoip/GeoIP2-Country.mmdb" + + // 使用很短的 TTL + geoip, err := NewGeoIPLookup(dbPath, 1000, 1*time.Millisecond, "allow") + if err != nil { + t.Skipf("Skipping test: GeoIP database not available: %v", err) + } + defer geoip.Close() + + publicIP := net.ParseIP("8.8.8.8") + + // 第一次查询 + _, err = geoip.LookupCountry(publicIP) + if err != nil { + t.Skipf("Skipping test: IP not found in database: %v", err) + } + + // 等待 TTL 过期 + time.Sleep(10 * time.Millisecond) + + // 再次查询(缓存应该已过期) + _, err = geoip.LookupCountry(publicIP) + // 不应该报错,只是重新查询数据库 + assert.NoError(t, err) +}