diff --git a/internal/middleware/security/access.go b/internal/middleware/security/access.go index abd3dba..919b616 100644 --- a/internal/middleware/security/access.go +++ b/internal/middleware/security/access.go @@ -36,6 +36,7 @@ import ( "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/middleware" "rua.plus/lolly/internal/netutil" + "rua.plus/lolly/internal/utils" ) // Action 表示对 IP 的操作类型。 @@ -87,7 +88,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) { // 解析允许列表 for _, cidr := range cfg.Allow { - network, err := parseCIDR(cidr) + network, err := utils.ParseCIDR(cidr) if err != nil { return nil, fmt.Errorf("invalid allow CIDR %s: %w", cidr, err) } @@ -96,7 +97,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) { // 解析拒绝列表 for _, cidr := range cfg.Deny { - network, err := parseCIDR(cidr) + network, err := utils.ParseCIDR(cidr) if err != nil { return nil, fmt.Errorf("invalid deny CIDR %s: %w", cidr, err) } @@ -105,7 +106,7 @@ func NewAccessControl(cfg *config.AccessConfig) (*AccessControl, error) { // 解析可信代理列表 for _, cidr := range cfg.TrustedProxies { - network, err := parseCIDR(cidr) + network, err := utils.ParseCIDR(cidr) if err != nil { return nil, fmt.Errorf("invalid trusted_proxy CIDR %s: %w", cidr, err) } @@ -314,7 +315,7 @@ func (ac *AccessControl) UpdateDenyList(cidrs []string) error { func parseCIDRList(cidrs []string) ([]net.IPNet, error) { newList := make([]net.IPNet, 0, len(cidrs)) for _, cidr := range cidrs { - network, err := parseCIDR(cidr) + network, err := utils.ParseCIDR(cidr) if err != nil { return nil, fmt.Errorf("invalid CIDR %s: %w", cidr, err) } @@ -346,45 +347,6 @@ func (ac *AccessControl) SetDefault(action string) error { return nil } -// parseCIDR 解析 CIDR 字符串,支持 IPv4 和 IPv6。 -// -// 支持完整的 CIDR 表示法(如 192.168.1.0/24)和单个 IP(如 192.168.1.1)。 -// 单个 IP 会自动转换为 /32(IPv4)或 /128(IPv6)的 CIDR。 -// -// 参数: -// - cidr: CIDR 字符串或单个 IP 地址 -// -// 返回值: -// - *net.IPNet: 解析后的 IP 网络对象 -// - error: 解析失败时返回错误 -func parseCIDR(cidr string) (*net.IPNet, error) { - // 处理单个 IP(没有 /前缀) - if !strings.Contains(cidr, "/") { - ip := net.ParseIP(cidr) - if ip == nil { - return nil, fmt.Errorf("invalid IP address: %s", cidr) - } - - // 转换为完整掩码的 CIDR - if ip.To4() != nil { - cidr = cidr + "/32" - } else { - cidr = cidr + "/128" - } - } - - // 解析 CIDR - ip, network, err := net.ParseCIDR(cidr) - if err != nil { - return nil, err - } - - // 确保 IP 为规范形式 - network.IP = ip - - return network, nil -} - // getClientIP 从请求上下文安全提取客户端 IP。 // // 仅当请求来自可信代理时,才解析 X-Forwarded-For 头部。 diff --git a/internal/middleware/security/access_test.go b/internal/middleware/security/access_test.go index 824a540..fef68ba 100644 --- a/internal/middleware/security/access_test.go +++ b/internal/middleware/security/access_test.go @@ -16,6 +16,7 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/utils" ) func TestNewAccessControl(t *testing.T) { @@ -244,9 +245,9 @@ func TestParseCIDR(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - network, err := parseCIDR(tt.cidr) + network, err := utils.ParseCIDR(tt.cidr) if (err != nil) != tt.wantErr { - t.Errorf("parseCIDR() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("utils.ParseCIDR() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && network == nil { t.Error("Expected non-nil network") diff --git a/internal/utils/ipallowlist.go b/internal/utils/ipallowlist.go index a846220..4bec467 100644 --- a/internal/utils/ipallowlist.go +++ b/internal/utils/ipallowlist.go @@ -2,7 +2,9 @@ package utils import ( + "fmt" "net" + "strings" ) // ParseIPAllowList 解析 IP/CIDR 白名单列表。 @@ -66,10 +68,49 @@ func ParseIPAllowList(allow []string) ([]net.IPNet, error) { return result, nil } -// parseCIDR 是 net.ParseCIDR 的包装,返回 *net.IPNet 而不返回 net.IP +// ParseCIDR 解析 CIDR 字符串或单个 IP 地址。 +// +// 支持格式: +// - CIDR 格式:192.168.1.0/24、::1/128 +// - 单个 IP:192.168.1.1(自动转换为 /32 或 /128) +// +// 参数: +// - cidr: CIDR 字符串或单个 IP 地址 +// +// 返回值: +// - *net.IPNet: 解析后的 IP 网络对象 +// - error: 解析失败时返回错误 +func ParseCIDR(cidr string) (*net.IPNet, error) { + // 处理单个 IP(没有 /前缀) + if !strings.Contains(cidr, "/") { + ip := net.ParseIP(cidr) + if ip == nil { + return nil, fmt.Errorf("invalid IP address: %s", cidr) + } + + // 转换为完整掩码的 CIDR + if ip.To4() != nil { + cidr = cidr + "/32" + } else { + cidr = cidr + "/128" + } + } + + // 解析 CIDR + ip, network, err := net.ParseCIDR(cidr) + if err != nil { + return nil, err + } + + // 确保 IP 为规范形式 + network.IP = ip + + return network, nil +} + +// parseCIDR 是 ParseCIDR 的内部别名,保持向后兼容 func parseCIDR(cidr string) (*net.IPNet, error) { - _, network, err := net.ParseCIDR(cidr) - return network, err + return ParseCIDR(cidr) } // IPInAllowList 检查 IP 是否在白名单中。