refactor(utils): enhance ParseCIDR to support single IP

Enhance parseCIDR in utils/ipallowlist.go to support single IP addresses
(without CIDR prefix) and ensure IP is in canonical form. This matches
the functionality previously in access.go.

- Add ParseCIDR as public function supporting CIDR and single IP
- Update access.go to use utils.ParseCIDR instead of local implementation
- Remove duplicate parseCIDR function from access.go
- Update tests to use utils.ParseCIDR

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-05-08 18:20:09 +08:00
parent 70d6488fc6
commit edc135ae5f
3 changed files with 52 additions and 48 deletions

View File

@ -36,6 +36,7 @@ import (
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/netutil"
"rua.plus/lolly/internal/utils"
)
// Action 表示对 IP 的操作类型。
@ -87,7 +88,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
// 解析允许列表
for _, cidr := range cfg.Allow {
network, err := parseCIDR(cidr)
network, err := utils.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("invalid allow CIDR %s: %w", cidr, err)
}
@ -96,7 +97,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
// 解析拒绝列表
for _, cidr := range cfg.Deny {
network, err := parseCIDR(cidr)
network, err := utils.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("invalid deny CIDR %s: %w", cidr, err)
}
@ -105,7 +106,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) {
// 解析可信代理列表
for _, cidr := range cfg.TrustedProxies {
network, err := parseCIDR(cidr)
network, err := utils.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("invalid trusted_proxy CIDR %s: %w", cidr, err)
}
@ -314,7 +315,7 @@ func (ac *AccessControl) UpdateDenyList(cidrs []string) error {
func parseCIDRList(cidrs []string) ([]net.IPNet, error) {
newList := make([]net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs {
network, err := parseCIDR(cidr)
network, err := utils.ParseCIDR(cidr)
if err != nil {
return nil, fmt.Errorf("invalid CIDR %s: %w", cidr, err)
}
@ -346,45 +347,6 @@ func (ac *AccessControl) SetDefault(action string) error {
return nil
}
// parseCIDR 解析 CIDR 字符串,支持 IPv4 和 IPv6。
//
// 支持完整的 CIDR 表示法(如 192.168.1.0/24和单个 IP如 192.168.1.1)。
// 单个 IP 会自动转换为 /32IPv4或 /128IPv6的 CIDR。
//
// 参数:
// - cidr: CIDR 字符串或单个 IP 地址
//
// 返回值:
// - *net.IPNet: 解析后的 IP 网络对象
// - error: 解析失败时返回错误
func parseCIDR(cidr string) (*net.IPNet, error) {
// 处理单个 IP没有 /前缀)
if !strings.Contains(cidr, "/") {
ip := net.ParseIP(cidr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address: %s", cidr)
}
// 转换为完整掩码的 CIDR
if ip.To4() != nil {
cidr = cidr + "/32"
} else {
cidr = cidr + "/128"
}
}
// 解析 CIDR
ip, network, err := net.ParseCIDR(cidr)
if err != nil {
return nil, err
}
// 确保 IP 为规范形式
network.IP = ip
return network, nil
}
// getClientIP 从请求上下文安全提取客户端 IP。
//
// 仅当请求来自可信代理时,才解析 X-Forwarded-For 头部。

View File

@ -16,6 +16,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/utils"
)
func TestNewAccessControl(t *testing.T) {
@ -244,9 +245,9 @@ func TestParseCIDR(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
network, err := parseCIDR(tt.cidr)
network, err := utils.ParseCIDR(tt.cidr)
if (err != nil) != tt.wantErr {
t.Errorf("parseCIDR() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("utils.ParseCIDR() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr && network == nil {
t.Error("Expected non-nil network")

View File

@ -2,7 +2,9 @@
package utils
import (
"fmt"
"net"
"strings"
)
// ParseIPAllowList 解析 IP/CIDR 白名单列表。
@ -66,10 +68,49 @@ func ParseIPAllowList(allow []string) ([]net.IPNet, error) {
return result, nil
}
// parseCIDR 是 net.ParseCIDR 的包装,返回 *net.IPNet 而不返回 net.IP
// ParseCIDR 解析 CIDR 字符串或单个 IP 地址。
//
// 支持格式:
// - CIDR 格式192.168.1.0/24、::1/128
// - 单个 IP192.168.1.1(自动转换为 /32 或 /128
//
// 参数:
// - cidr: CIDR 字符串或单个 IP 地址
//
// 返回值:
// - *net.IPNet: 解析后的 IP 网络对象
// - error: 解析失败时返回错误
func ParseCIDR(cidr string) (*net.IPNet, error) {
// 处理单个 IP没有 /前缀)
if !strings.Contains(cidr, "/") {
ip := net.ParseIP(cidr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address: %s", cidr)
}
// 转换为完整掩码的 CIDR
if ip.To4() != nil {
cidr = cidr + "/32"
} else {
cidr = cidr + "/128"
}
}
// 解析 CIDR
ip, network, err := net.ParseCIDR(cidr)
if err != nil {
return nil, err
}
// 确保 IP 为规范形式
network.IP = ip
return network, nil
}
// parseCIDR 是 ParseCIDR 的内部别名,保持向后兼容
func parseCIDR(cidr string) (*net.IPNet, error) {
_, network, err := net.ParseCIDR(cidr)
return network, err
return ParseCIDR(cidr)
}
// IPInAllowList 检查 IP 是否在白名单中。