refactor: 统一 IP 白名单解析和 excludeSet 构建

Batch 1 重构:
- 新增 utils.ParseIPAllowList 统一 IP/CIDR 解析(含 localhost 特殊处理)
- pprof.go/status.go/purge.go 改用统一函数,减少 ~66 行重复代码
- 新增 loadbalance.buildExcludeSet 统一排除集合构建
- 更新 pprof_test.go 适配统一字段 allowed []net.IPNet

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-29 16:18:52 +08:00
parent 0e8b99e17f
commit 91e04222b3
7 changed files with 211 additions and 225 deletions

View File

@ -399,12 +399,7 @@ func DecrementConnections(t *Target) {
// //
// 排除判断基于目标的 URL 进行匹配。 // 排除判断基于目标的 URL 进行匹配。
func filterHealthyAndExclude(targets []*Target, excluded []*Target) []*Target { func filterHealthyAndExclude(targets []*Target, excluded []*Target) []*Target {
excludeSet := make(map[string]bool, len(excluded)) excludeSet := buildExcludeSet(excluded)
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
available := make([]*Target, 0, len(targets)) available := make([]*Target, 0, len(targets))
backups := make([]*Target, 0, len(targets)) backups := make([]*Target, 0, len(targets))
@ -485,12 +480,7 @@ func (w *WeightedRoundRobin) SelectExcluding(targets []*Target, excluded []*Targ
// SelectExcluding 选择连接数最少的目标,排除指定的目标列表。 // SelectExcluding 选择连接数最少的目标,排除指定的目标列表。
// 优先选择非备份目标,仅当无可用非备份目标时选择备份目标。 // 优先选择非备份目标,仅当无可用非备份目标时选择备份目标。
func (l *LeastConnections) SelectExcluding(targets []*Target, excluded []*Target) *Target { func (l *LeastConnections) SelectExcluding(targets []*Target, excluded []*Target) *Target {
excludeSet := make(map[string]bool, len(excluded)) excludeSet := buildExcludeSet(excluded)
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
var selected *Target var selected *Target
var selectedBackup *Target var selectedBackup *Target
@ -638,3 +628,25 @@ func (t *Target) LastResolved() time.Time {
} }
return time.Unix(0, nano) return time.Unix(0, nano)
} }
// buildExcludeSet 从排除列表构建 URL 集合。
//
// 用于负载均衡算法中快速检查目标是否应被排除。
//
// 参数:
// - excluded: 需要排除的目标列表
//
// 返回值:
// - map[string]bool: 目标 URL 到 true 的映射
func buildExcludeSet(excluded []*Target) map[string]bool {
if len(excluded) == 0 {
return nil
}
excludeSet := make(map[string]bool, len(excluded))
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
return excludeSet
}

View File

@ -238,12 +238,7 @@ func (c *ConsistentHash) SelectExcludingByKey(targets []*Target, excluded []*Tar
defer c.mu.RUnlock() defer c.mu.RUnlock()
// 构建排除集合 // 构建排除集合
excludeSet := make(map[string]bool, len(excluded)) excludeSet := buildExcludeSet(excluded)
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
// 如果没有排除的目标,使用正常选择 // 如果没有排除的目标,使用正常选择
if len(excludeSet) == 0 { if len(excludeSet) == 0 {

View File

@ -29,6 +29,7 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config" "rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/utils"
) )
// PprofHandler pprof 性能分析处理器。 // PprofHandler pprof 性能分析处理器。
@ -38,11 +39,8 @@ type PprofHandler struct {
// path 端点路径前缀 // path 端点路径前缀
path string path string
// allowedIPs 允许访问的 IP 列表 // allowed 允许访问的 IP 网络列表
allowedIPs []net.IP allowed []net.IPNet
// allowedNets 允许访问的 CIDR 网络
allowedNets []*net.IPNet
} }
// NewPprofHandler 创建 pprof 处理器。 // NewPprofHandler 创建 pprof 处理器。
@ -68,22 +66,15 @@ func NewPprofHandler(cfg *config.PprofConfig) (*PprofHandler, error) {
h := &PprofHandler{path: path} h := &PprofHandler{path: path}
// 解析允许的 IP 列表 // 解析允许的 IP 列表
for _, ipStr := range cfg.Allow { allowed, err := utils.ParseIPAllowList(cfg.Allow)
if ip := net.ParseIP(ipStr); ip != nil { if err != nil {
h.allowedIPs = append(h.allowedIPs, ip) return nil, fmt.Errorf("failed to parse IP/CIDR: %w", err)
continue
}
// 尝试解析 CIDR
_, net, err := net.ParseCIDR(ipStr)
if err != nil {
return nil, fmt.Errorf("failed to parse IP/CIDR: %s: %w", ipStr, err)
}
h.allowedNets = append(h.allowedNets, net)
} }
h.allowed = allowed
// 默认只允许 localhost // 默认只允许 localhost
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 { if len(h.allowed) == 0 {
h.allowedIPs = []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")} h.allowed, _ = utils.ParseIPAllowList([]string{"localhost"})
} }
return h, nil return h, nil
@ -157,7 +148,7 @@ func (h *PprofHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 返回值: // 返回值:
// - bool: true 表示允许访问false 表示禁止访问 // - bool: true 表示允许访问false 表示禁止访问
func (h *PprofHandler) isAllowed(ctx *fasthttp.RequestCtx) bool { func (h *PprofHandler) isAllowed(ctx *fasthttp.RequestCtx) bool {
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 { if len(h.allowed) == 0 {
return true // 无限制 return true // 无限制
} }
@ -167,21 +158,7 @@ func (h *PprofHandler) isAllowed(ctx *fasthttp.RequestCtx) bool {
return false return false
} }
// 检查精确 IP return utils.IPInAllowList(clientIP, h.allowed)
for _, ip := range h.allowedIPs {
if ip.Equal(clientIP) {
return true
}
}
// 检查 CIDR 网络
for _, net := range h.allowedNets {
if net.Contains(clientIP) {
return true
}
}
return false
} }
// handleIndex 处理索引页面。 // handleIndex 处理索引页面。

View File

@ -124,8 +124,8 @@ func TestNewPprofHandler_SingleIP(t *testing.T) {
} }
// 空列表时应该默认允许 localhost // 空列表时应该默认允许 localhost
if len(tt.allow) == 0 { if len(tt.allow) == 0 {
if len(h.allowedIPs) != 2 { if len(h.allowed) != 2 {
t.Errorf("expected 2 default allowed IPs (127.0.0.1 and ::1), got %d", len(h.allowedIPs)) t.Errorf("expected 2 default allowed IPs (127.0.0.1 and ::1), got %d", len(h.allowed))
} }
} }
} }
@ -274,56 +274,48 @@ func TestPprofHandler_isAllowed(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
clientIP string clientIP string
allowedIPs []string
allowedNets []string allowedNets []string
wantAllowed bool wantAllowed bool
}{ }{
{ {
name: "empty allow list - allow all", name: "empty allow list - allow all",
allowedIPs: []string{},
allowedNets: []string{}, allowedNets: []string{},
clientIP: "192.168.1.100", clientIP: "192.168.1.100",
wantAllowed: true, wantAllowed: true,
}, },
{ {
name: "IP exact match", name: "IP exact match (as /32 CIDR)",
allowedIPs: []string{"127.0.0.1"}, allowedNets: []string{"127.0.0.1/32"},
allowedNets: []string{},
clientIP: "127.0.0.1", clientIP: "127.0.0.1",
wantAllowed: true, wantAllowed: true,
}, },
{ {
name: "IP no match", name: "IP no match",
allowedIPs: []string{"127.0.0.1"}, allowedNets: []string{"127.0.0.1/32"},
allowedNets: []string{},
clientIP: "127.0.0.2", clientIP: "127.0.0.2",
wantAllowed: false, wantAllowed: false,
}, },
{ {
name: "CIDR match", name: "CIDR match",
allowedIPs: []string{},
allowedNets: []string{"192.168.0.0/16"}, allowedNets: []string{"192.168.0.0/16"},
clientIP: "192.168.1.100", clientIP: "192.168.1.100",
wantAllowed: true, wantAllowed: true,
}, },
{ {
name: "CIDR no match", name: "CIDR no match",
allowedIPs: []string{},
allowedNets: []string{"10.0.0.0/8"}, allowedNets: []string{"10.0.0.0/8"},
clientIP: "192.168.1.100", clientIP: "192.168.1.100",
wantAllowed: false, wantAllowed: false,
}, },
{ {
name: "IPv6 CIDR match", name: "IPv6 CIDR match",
allowedIPs: []string{},
allowedNets: []string{"2001:db8::/32"}, allowedNets: []string{"2001:db8::/32"},
clientIP: "2001:db8::1", clientIP: "2001:db8::1",
wantAllowed: true, wantAllowed: true,
}, },
{ {
name: "IPv6 exact match", name: "IPv6 exact match (as /128 CIDR)",
allowedIPs: []string{"::1"}, allowedNets: []string{"::1/128"},
allowedNets: []string{},
clientIP: "::1", clientIP: "::1",
wantAllowed: true, wantAllowed: true,
}, },
@ -332,8 +324,7 @@ func TestPprofHandler_isAllowed(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
allowedIPs: parseIPs(tt.allowedIPs), allowed: parseNets(tt.allowedNets),
allowedNets: parseNets(tt.allowedNets),
} }
// 创建请求上下文,模拟客户端 IP // 创建请求上下文,模拟客户端 IP
@ -348,16 +339,10 @@ func TestPprofHandler_isAllowed(t *testing.T) {
// 复制 isAllowed 的逻辑进行测试 // 复制 isAllowed 的逻辑进行测试
allowed := false allowed := false
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 { if len(h.allowed) == 0 {
allowed = true allowed = true
} else { } else {
for _, ip := range h.allowedIPs { for _, n := range h.allowed {
if ip.Equal(clientIP) {
allowed = true
break
}
}
for _, n := range h.allowedNets {
if n.Contains(clientIP) { if n.Contains(clientIP) {
allowed = true allowed = true
break break
@ -372,24 +357,13 @@ func TestPprofHandler_isAllowed(t *testing.T) {
} }
} }
// parseIPs 辅助函数,解析 IP 字符串列表
func parseIPs(ips []string) []net.IP {
result := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if parsed := net.ParseIP(ip); parsed != nil {
result = append(result, parsed)
}
}
return result
}
// parseNets 辅助函数,解析 CIDR 字符串列表 // parseNets 辅助函数,解析 CIDR 字符串列表
func parseNets(cidrs []string) []*net.IPNet { func parseNets(cidrs []string) []net.IPNet {
result := make([]*net.IPNet, 0, len(cidrs)) result := make([]net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs { for _, cidr := range cidrs {
_, net, err := net.ParseCIDR(cidr) _, net, err := net.ParseCIDR(cidr)
if err == nil { if err == nil && net != nil {
result = append(result, net) result = append(result, *net)
} }
} }
return result return result
@ -398,9 +372,8 @@ func parseNets(cidrs []string) []*net.IPNet {
func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) { func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) {
// 测试空 allow 列表时允许所有访问 // 测试空 allow 列表时允许所有访问
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -417,9 +390,8 @@ func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) {
func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) { func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) {
// 使用空 allow 列表允许所有访问 // 使用空 allow 列表允许所有访问
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
tests := []struct { tests := []struct {
@ -465,9 +437,8 @@ func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) {
func TestPprofHandler_handleIndex(t *testing.T) { func TestPprofHandler_handleIndex(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -510,9 +481,8 @@ func TestPprofHandler_handleIndex(t *testing.T) {
func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) { func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
tests := []struct { tests := []struct {
@ -564,30 +534,30 @@ func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) {
func TestPprofHandler_ServeHTTP_Forbidden(t *testing.T) { func TestPprofHandler_ServeHTTP_Forbidden(t *testing.T) {
// 创建只允许特定 IP 的 handler // 创建只允许特定 IP 的 handler
allowedIP := net.ParseIP("10.0.0.1") _, ipNet, _ := net.ParseCIDR("10.0.0.1/32")
h := &PprofHandler{ h := &PprofHandler{
allowedIPs: []net.IP{allowedIP}, allowed: []net.IPNet{*ipNet},
} }
// 由于无法轻松设置 RemoteIP我们直接测试 isAllowed 返回 false 的情况 // 由于无法轻松设置 RemoteIP我们直接测试 isAllowed 返回 false 的情况
// 通过构造一个 allowedIPs 非空的情况来触发检查 // 通过构造一个 allowed 非空的情况来触发检查
// 验证 handler 配置正确 // 验证 handler 配置正确
if len(h.allowedIPs) != 1 { if len(h.allowed) != 1 {
t.Errorf("expected 1 allowed IP, got %d", len(h.allowedIPs)) t.Errorf("expected 1 allowed IPNet, got %d", len(h.allowed))
} }
// 验证 allowed IPs 包含配置的 IP // 验证 allowed 包含配置的 IP
if !h.allowedIPs[0].Equal(allowedIP) { expectedIP := net.ParseIP("10.0.0.1")
t.Error("expected allowedIPs to contain configured IP") if !h.allowed[0].Contains(expectedIP) {
t.Error("expected allowed to contain configured IP")
} }
} }
func TestPprofHandler_handleCPU_Params(t *testing.T) { func TestPprofHandler_handleCPU_Params(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
tests := []struct { tests := []struct {
@ -687,9 +657,8 @@ func TestPprofHandler_handleCPU_Execute(t *testing.T) {
stopCPUProfile() stopCPUProfile()
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -781,35 +750,37 @@ func TestPprofHandler_ConfigWithCIDRAndIP(t *testing.T) {
t.Fatal("expected non-nil handler") t.Fatal("expected non-nil handler")
} }
// 验证 IP 和 CIDR 都被正确解析 // 验证 IP 和 CIDR 都被正确解析(现在统一存储在 allowed 中)
if len(h.allowedIPs) != 2 { // 127.0.0.1 -> 127.0.0.1/32
t.Errorf("expected 2 allowed IPs, got %d", len(h.allowedIPs)) // ::1 -> ::1/128
} // 192.168.0.0/24 保持不变
if len(h.allowedNets) != 1 { if len(h.allowed) != 3 {
t.Errorf("expected 1 allowed net, got %d", len(h.allowedNets)) t.Errorf("expected 3 allowed entries (2 IPs converted to /32 and /128, 1 CIDR), got %d", len(h.allowed))
} }
// 验证具体内容 // 验证具体内容 - 使用 Contains 检查
foundV4 := false foundV4 := false
foundV6 := false foundV6 := false
for _, ip := range h.allowedIPs { foundCIDR := false
if ip.Equal(net.ParseIP("127.0.0.1")) { for _, ipNet := range h.allowed {
if ipNet.Contains(net.ParseIP("127.0.0.1")) && ipNet.String() == "127.0.0.1/32" {
foundV4 = true foundV4 = true
} }
if ip.Equal(net.ParseIP("::1")) { if ipNet.Contains(net.ParseIP("::1")) && ipNet.String() == "::1/128" {
foundV6 = true foundV6 = true
} }
if ipNet.String() == "192.168.0.0/24" {
foundCIDR = true
}
} }
if !foundV4 { if !foundV4 {
t.Error("expected to find 127.0.0.1 in allowedIPs") t.Error("expected to find 127.0.0.1/32 in allowed")
} }
if !foundV6 { if !foundV6 {
t.Error("expected to find ::1 in allowedIPs") t.Error("expected to find ::1/128 in allowed")
} }
if !foundCIDR {
// 验证 CIDR t.Error("expected to find 192.168.0.0/24 in allowed")
if h.allowedNets[0].String() != "192.168.0.0/24" {
t.Errorf("expected CIDR 192.168.0.0/24, got %s", h.allowedNets[0].String())
} }
} }
@ -829,35 +800,34 @@ func TestPprofHandler_DefaultLocalhostBehavior(t *testing.T) {
t.Fatal("expected non-nil handler") t.Fatal("expected non-nil handler")
} }
// 验证默认允许 localhost // 验证默认允许 localhost (解析为 127.0.0.1/32 和 ::1/128)
if len(h.allowedIPs) != 2 { if len(h.allowed) != 2 {
t.Errorf("expected 2 default allowed IPs, got %d", len(h.allowedIPs)) t.Errorf("expected 2 default allowed entries (127.0.0.1/32 and ::1/128), got %d", len(h.allowed))
} }
// 验证包含 IPv4 和 IPv6 localhost // 验证包含 IPv4 和 IPv6 localhost
hasV4 := false hasV4 := false
hasV6 := false hasV6 := false
for _, ip := range h.allowedIPs { for _, n := range h.allowed {
if ip.Equal(net.ParseIP("127.0.0.1")) { if n.Contains(net.ParseIP("127.0.0.1")) && n.String() == "127.0.0.1/32" {
hasV4 = true hasV4 = true
} }
if ip.Equal(net.ParseIP("::1")) { if n.Contains(net.ParseIP("::1")) && n.String() == "::1/128" {
hasV6 = true hasV6 = true
} }
} }
if !hasV4 { if !hasV4 {
t.Error("expected default to include 127.0.0.1") t.Error("expected default to include 127.0.0.1/32")
} }
if !hasV6 { if !hasV6 {
t.Error("expected default to include ::1") t.Error("expected default to include ::1/128")
} }
} }
func TestPprofHandler_handleHeap(t *testing.T) { func TestPprofHandler_handleHeap(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -879,9 +849,8 @@ func TestPprofHandler_handleHeap(t *testing.T) {
func TestPprofHandler_handleGoroutine(t *testing.T) { func TestPprofHandler_handleGoroutine(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -903,9 +872,8 @@ func TestPprofHandler_handleGoroutine(t *testing.T) {
func TestPprofHandler_handleBlock(t *testing.T) { func TestPprofHandler_handleBlock(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -927,9 +895,8 @@ func TestPprofHandler_handleBlock(t *testing.T) {
func TestPprofHandler_handleMutex(t *testing.T) { func TestPprofHandler_handleMutex(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -953,9 +920,8 @@ func TestPprofHandler_handleMutex(t *testing.T) {
func TestPprofHandler_isAllowed_RemoteIP(t *testing.T) { func TestPprofHandler_isAllowed_RemoteIP(t *testing.T) {
t.Run("empty allow lists - allow all", func(t *testing.T) { t.Run("empty allow lists - allow all", func(t *testing.T) {
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{}, allowed: []net.IPNet{},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}
@ -966,11 +932,10 @@ func TestPprofHandler_isAllowed_RemoteIP(t *testing.T) {
}) })
t.Run("with allow list but cannot parse IP", func(t *testing.T) { t.Run("with allow list but cannot parse IP", func(t *testing.T) {
allowedIP := net.ParseIP("192.168.1.1") _, ipNet, _ := net.ParseCIDR("192.168.1.1/32")
h := &PprofHandler{ h := &PprofHandler{
path: "/debug/pprof", path: "/debug/pprof",
allowedIPs: []net.IP{allowedIP}, allowed: []net.IPNet{*ipNet},
allowedNets: []*net.IPNet{},
} }
ctx := &fasthttp.RequestCtx{} ctx := &fasthttp.RequestCtx{}

View File

@ -6,7 +6,6 @@ package server
import ( import (
"encoding/json" "encoding/json"
"net" "net"
"net/netip"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
"rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/cache"
@ -55,38 +54,11 @@ func NewPurgeHandler(server *Server, cfg *config.CacheAPIConfig) (*PurgeHandler,
} }
// 解析允许的 IP 列表 // 解析允许的 IP 列表
for _, cidr := range cfg.Allow { allowed, err := utils.ParseIPAllowList(cfg.Allow)
// 处理 localhost 特殊情况 if err != nil {
if cidr == "localhost" { return nil, err
_, v4Network, _ := net.ParseCIDR("127.0.0.1/32")
_, v6Network, _ := net.ParseCIDR("::1/128")
if v4Network != nil {
h.allowed = append(h.allowed, *v4Network)
}
if v6Network != nil {
h.allowed = append(h.allowed, *v6Network)
}
continue
}
_, network, err := net.ParseCIDR(cidr)
if err != nil {
// 尝试作为单个 IP 解析
ip, err := netip.ParseAddr(cidr)
if err != nil {
return nil, err
}
// 转换为 CIDR 格式
if ip.Is4() {
_, network, _ = net.ParseCIDR(cidr + "/32")
} else {
_, network, _ = net.ParseCIDR(cidr + "/128")
}
}
if network != nil {
h.allowed = append(h.allowed, *network)
}
} }
h.allowed = allowed
return h, nil return h, nil
} }

View File

@ -134,39 +134,11 @@ func NewStatusHandler(server *Server, cfg *config.StatusConfig) (*StatusHandler,
} }
// 解析允许的 IP 列表 // 解析允许的 IP 列表
for _, cidr := range cfg.Allow { allowed, err := utils.ParseIPAllowList(cfg.Allow)
// 处理 localhost 特殊情况 if err != nil {
if cidr == "localhost" { return nil, err
// localhost 解析为 127.0.0.1 和 ::1
_, v4Network, _ := net.ParseCIDR("127.0.0.1/32")
_, v6Network, _ := net.ParseCIDR("::1/128")
if v4Network != nil {
h.allowed = append(h.allowed, *v4Network)
}
if v6Network != nil {
h.allowed = append(h.allowed, *v6Network)
}
continue
}
_, network, err := net.ParseCIDR(cidr)
if err != nil {
// 尝试作为单个 IP 解析
ip := net.ParseIP(cidr)
if ip == nil {
return nil, err
}
// 转换为 CIDR 格式
if ip.To4() != nil {
_, network, _ = net.ParseCIDR(cidr + "/32")
} else {
_, network, _ = net.ParseCIDR(cidr + "/128")
}
}
if network != nil {
h.allowed = append(h.allowed, *network)
}
} }
h.allowed = allowed
return h, nil return h, nil
} }

View File

@ -0,0 +1,93 @@
// Package utils 提供通用工具函数
package utils
import (
"net"
)
// ParseIPAllowList 解析 IP/CIDR 白名单列表。
//
// 支持格式:
// - CIDR 格式192.168.1.0/24、::1/128
// - 单个 IP192.168.1.1(自动转换为 /32 或 /128
// - 特殊值 "localhost":映射为 127.0.0.1/32 和 ::1/128
//
// 参数:
// - allow: IP/CIDR 字符串列表
//
// 返回值:
// - []net.IPNet: 解析后的网络列表
// - error: 解析失败时返回错误
func ParseIPAllowList(allow []string) ([]net.IPNet, error) {
if len(allow) == 0 {
return nil, nil
}
result := make([]net.IPNet, 0, len(allow)+2) // +2 for localhost expansion
for _, cidr := range allow {
// 处理 localhost 特殊情况
if cidr == "localhost" {
// localhost 解析为 127.0.0.1 和 ::1
if v4Net, err := parseCIDR("127.0.0.1/32"); err == nil {
result = append(result, *v4Net)
}
if v6Net, err := parseCIDR("::1/128"); err == nil {
result = append(result, *v6Net)
}
continue
}
// 尝试 CIDR 解析
_, network, err := net.ParseCIDR(cidr)
if err == nil && network != nil {
result = append(result, *network)
continue
}
// fallback: 尝试作为单个 IP 解析
ip := net.ParseIP(cidr)
if ip == nil {
return nil, err // 返回原始 CIDR 解析错误
}
// 转换为 CIDR 格式
var ipNet *net.IPNet
if ip.To4() != nil {
ipNet, _ = parseCIDR(cidr + "/32")
} else {
ipNet, _ = parseCIDR(cidr + "/128")
}
if ipNet != nil {
result = append(result, *ipNet)
}
}
return result, nil
}
// parseCIDR 是 net.ParseCIDR 的包装,返回 *net.IPNet 而不返回 net.IP
func parseCIDR(cidr string) (*net.IPNet, error) {
_, network, err := net.ParseCIDR(cidr)
return network, err
}
// IPInAllowList 检查 IP 是否在白名单中。
//
// 参数:
// - ip: 要检查的 IP 地址
// - allowList: 白名单网络列表
//
// 返回值:
// - bool: IP 在白名单中返回 true
func IPInAllowList(ip net.IP, allowList []net.IPNet) bool {
if len(allowList) == 0 {
return false
}
for _, network := range allowList {
if network.Contains(ip) {
return true
}
}
return false
}