refactor(utils): enhance ParseCIDR to support single IP
Enhance parseCIDR in utils/ipallowlist.go to support single IP addresses (without CIDR prefix) and ensure IP is in canonical form. This matches the functionality previously in access.go. - Add ParseCIDR as public function supporting CIDR and single IP - Update access.go to use utils.ParseCIDR instead of local implementation - Remove duplicate parseCIDR function from access.go - Update tests to use utils.ParseCIDR Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
70d6488fc6
commit
edc135ae5f
@ -36,6 +36,7 @@ import (
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
"rua.plus/lolly/internal/netutil"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
// Action 表示对 IP 的操作类型。
|
||||
@ -87,7 +88,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
|
||||
|
||||
// 解析允许列表
|
||||
for _, cidr := range cfg.Allow {
|
||||
network, err := parseCIDR(cidr)
|
||||
network, err := utils.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid allow CIDR %s: %w", cidr, err)
|
||||
}
|
||||
@ -96,7 +97,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
|
||||
|
||||
// 解析拒绝列表
|
||||
for _, cidr := range cfg.Deny {
|
||||
network, err := parseCIDR(cidr)
|
||||
network, err := utils.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid deny CIDR %s: %w", cidr, err)
|
||||
}
|
||||
@ -105,7 +106,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
|
||||
|
||||
// 解析可信代理列表
|
||||
for _, cidr := range cfg.TrustedProxies {
|
||||
network, err := parseCIDR(cidr)
|
||||
network, err := utils.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid trusted_proxy CIDR %s: %w", cidr, err)
|
||||
}
|
||||
@ -314,7 +315,7 @@ func (ac *AccessControl) UpdateDenyList(cidrs []string) error {
|
||||
func parseCIDRList(cidrs []string) ([]net.IPNet, error) {
|
||||
newList := make([]net.IPNet, 0, len(cidrs))
|
||||
for _, cidr := range cidrs {
|
||||
network, err := parseCIDR(cidr)
|
||||
network, err := utils.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid CIDR %s: %w", cidr, err)
|
||||
}
|
||||
@ -346,45 +347,6 @@ func (ac *AccessControl) SetDefault(action string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseCIDR 解析 CIDR 字符串,支持 IPv4 和 IPv6。
|
||||
//
|
||||
// 支持完整的 CIDR 表示法(如 192.168.1.0/24)和单个 IP(如 192.168.1.1)。
|
||||
// 单个 IP 会自动转换为 /32(IPv4)或 /128(IPv6)的 CIDR。
|
||||
//
|
||||
// 参数:
|
||||
// - cidr: CIDR 字符串或单个 IP 地址
|
||||
//
|
||||
// 返回值:
|
||||
// - *net.IPNet: 解析后的 IP 网络对象
|
||||
// - error: 解析失败时返回错误
|
||||
func parseCIDR(cidr string) (*net.IPNet, error) {
|
||||
// 处理单个 IP(没有 /前缀)
|
||||
if !strings.Contains(cidr, "/") {
|
||||
ip := net.ParseIP(cidr)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP address: %s", cidr)
|
||||
}
|
||||
|
||||
// 转换为完整掩码的 CIDR
|
||||
if ip.To4() != nil {
|
||||
cidr = cidr + "/32"
|
||||
} else {
|
||||
cidr = cidr + "/128"
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 CIDR
|
||||
ip, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 确保 IP 为规范形式
|
||||
network.IP = ip
|
||||
|
||||
return network, nil
|
||||
}
|
||||
|
||||
// getClientIP 从请求上下文安全提取客户端 IP。
|
||||
//
|
||||
// 仅当请求来自可信代理时,才解析 X-Forwarded-For 头部。
|
||||
|
||||
@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
func TestNewAccessControl(t *testing.T) {
|
||||
@ -244,9 +245,9 @@ func TestParseCIDR(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
network, err := parseCIDR(tt.cidr)
|
||||
network, err := utils.ParseCIDR(tt.cidr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseCIDR() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("utils.ParseCIDR() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !tt.wantErr && network == nil {
|
||||
t.Error("Expected non-nil network")
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseIPAllowList 解析 IP/CIDR 白名单列表。
|
||||
@ -66,10 +68,49 @@ func ParseIPAllowList(allow []string) ([]net.IPNet, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseCIDR 是 net.ParseCIDR 的包装,返回 *net.IPNet 而不返回 net.IP
|
||||
// ParseCIDR 解析 CIDR 字符串或单个 IP 地址。
|
||||
//
|
||||
// 支持格式:
|
||||
// - CIDR 格式:192.168.1.0/24、::1/128
|
||||
// - 单个 IP:192.168.1.1(自动转换为 /32 或 /128)
|
||||
//
|
||||
// 参数:
|
||||
// - cidr: CIDR 字符串或单个 IP 地址
|
||||
//
|
||||
// 返回值:
|
||||
// - *net.IPNet: 解析后的 IP 网络对象
|
||||
// - error: 解析失败时返回错误
|
||||
func ParseCIDR(cidr string) (*net.IPNet, error) {
|
||||
// 处理单个 IP(没有 /前缀)
|
||||
if !strings.Contains(cidr, "/") {
|
||||
ip := net.ParseIP(cidr)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid IP address: %s", cidr)
|
||||
}
|
||||
|
||||
// 转换为完整掩码的 CIDR
|
||||
if ip.To4() != nil {
|
||||
cidr = cidr + "/32"
|
||||
} else {
|
||||
cidr = cidr + "/128"
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 CIDR
|
||||
ip, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 确保 IP 为规范形式
|
||||
network.IP = ip
|
||||
|
||||
return network, nil
|
||||
}
|
||||
|
||||
// parseCIDR 是 ParseCIDR 的内部别名,保持向后兼容
|
||||
func parseCIDR(cidr string) (*net.IPNet, error) {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
return network, err
|
||||
return ParseCIDR(cidr)
|
||||
}
|
||||
|
||||
// IPInAllowList 检查 IP 是否在白名单中。
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user