refactor: 统一 IP 白名单解析和 excludeSet 构建

Batch 1 重构:
- 新增 utils.ParseIPAllowList 统一 IP/CIDR 解析(含 localhost 特殊处理)
- pprof.go/status.go/purge.go 改用统一函数,减少 ~66 行重复代码
- 新增 loadbalance.buildExcludeSet 统一排除集合构建
- 更新 pprof_test.go 适配统一字段 allowed []net.IPNet

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-29 16:18:52 +08:00
parent 0e8b99e17f
commit 91e04222b3
7 changed files with 211 additions and 225 deletions

View File

@ -399,12 +399,7 @@ func DecrementConnections(t *Target) {
//
// 排除判断基于目标的 URL 进行匹配。
func filterHealthyAndExclude(targets []*Target, excluded []*Target) []*Target {
excludeSet := make(map[string]bool, len(excluded))
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
excludeSet := buildExcludeSet(excluded)
available := make([]*Target, 0, len(targets))
backups := make([]*Target, 0, len(targets))
@ -485,12 +480,7 @@ func (w *WeightedRoundRobin) SelectExcluding(targets []*Target, excluded []*Targ
// SelectExcluding 选择连接数最少的目标,排除指定的目标列表。
// 优先选择非备份目标,仅当无可用非备份目标时选择备份目标。
func (l *LeastConnections) SelectExcluding(targets []*Target, excluded []*Target) *Target {
excludeSet := make(map[string]bool, len(excluded))
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
excludeSet := buildExcludeSet(excluded)
var selected *Target
var selectedBackup *Target
@ -638,3 +628,25 @@ func (t *Target) LastResolved() time.Time {
}
return time.Unix(0, nano)
}
// buildExcludeSet 从排除列表构建 URL 集合。
//
// 用于负载均衡算法中快速检查目标是否应被排除。
//
// 参数:
// - excluded: 需要排除的目标列表
//
// 返回值:
// - map[string]bool: 目标 URL 到 true 的映射
func buildExcludeSet(excluded []*Target) map[string]bool {
if len(excluded) == 0 {
return nil
}
excludeSet := make(map[string]bool, len(excluded))
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
return excludeSet
}

View File

@ -238,12 +238,7 @@ func (c *ConsistentHash) SelectExcludingByKey(targets []*Target, excluded []*Tar
defer c.mu.RUnlock()
// 构建排除集合
excludeSet := make(map[string]bool, len(excluded))
for _, t := range excluded {
if t != nil {
excludeSet[t.URL] = true
}
}
excludeSet := buildExcludeSet(excluded)
// 如果没有排除的目标,使用正常选择
if len(excludeSet) == 0 {

View File

@ -29,6 +29,7 @@ import (
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/utils"
)
// PprofHandler pprof 性能分析处理器。
@ -38,11 +39,8 @@ type PprofHandler struct {
// path 端点路径前缀
path string
// allowedIPs 允许访问的 IP 列表
allowedIPs []net.IP
// allowedNets 允许访问的 CIDR 网络
allowedNets []*net.IPNet
// allowed 允许访问的 IP 网络列表
allowed []net.IPNet
}
// NewPprofHandler 创建 pprof 处理器。
@ -68,22 +66,15 @@ func NewPprofHandler(cfg *config.PprofConfig) (*PprofHandler, error) {
h := &PprofHandler{path: path}
// 解析允许的 IP 列表
for _, ipStr := range cfg.Allow {
if ip := net.ParseIP(ipStr); ip != nil {
h.allowedIPs = append(h.allowedIPs, ip)
continue
}
// 尝试解析 CIDR
_, net, err := net.ParseCIDR(ipStr)
if err != nil {
return nil, fmt.Errorf("failed to parse IP/CIDR: %s: %w", ipStr, err)
}
h.allowedNets = append(h.allowedNets, net)
allowed, err := utils.ParseIPAllowList(cfg.Allow)
if err != nil {
return nil, fmt.Errorf("failed to parse IP/CIDR: %w", err)
}
h.allowed = allowed
// 默认只允许 localhost
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 {
h.allowedIPs = []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
if len(h.allowed) == 0 {
h.allowed, _ = utils.ParseIPAllowList([]string{"localhost"})
}
return h, nil
@ -157,7 +148,7 @@ func (h *PprofHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 返回值:
// - bool: true 表示允许访问false 表示禁止访问
func (h *PprofHandler) isAllowed(ctx *fasthttp.RequestCtx) bool {
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 {
if len(h.allowed) == 0 {
return true // 无限制
}
@ -167,21 +158,7 @@ func (h *PprofHandler) isAllowed(ctx *fasthttp.RequestCtx) bool {
return false
}
// 检查精确 IP
for _, ip := range h.allowedIPs {
if ip.Equal(clientIP) {
return true
}
}
// 检查 CIDR 网络
for _, net := range h.allowedNets {
if net.Contains(clientIP) {
return true
}
}
return false
return utils.IPInAllowList(clientIP, h.allowed)
}
// handleIndex 处理索引页面。

View File

@ -124,8 +124,8 @@ func TestNewPprofHandler_SingleIP(t *testing.T) {
}
// 空列表时应该默认允许 localhost
if len(tt.allow) == 0 {
if len(h.allowedIPs) != 2 {
t.Errorf("expected 2 default allowed IPs (127.0.0.1 and ::1), got %d", len(h.allowedIPs))
if len(h.allowed) != 2 {
t.Errorf("expected 2 default allowed IPs (127.0.0.1 and ::1), got %d", len(h.allowed))
}
}
}
@ -274,56 +274,48 @@ func TestPprofHandler_isAllowed(t *testing.T) {
tests := []struct {
name string
clientIP string
allowedIPs []string
allowedNets []string
wantAllowed bool
}{
{
name: "empty allow list - allow all",
allowedIPs: []string{},
allowedNets: []string{},
clientIP: "192.168.1.100",
wantAllowed: true,
},
{
name: "IP exact match",
allowedIPs: []string{"127.0.0.1"},
allowedNets: []string{},
name: "IP exact match (as /32 CIDR)",
allowedNets: []string{"127.0.0.1/32"},
clientIP: "127.0.0.1",
wantAllowed: true,
},
{
name: "IP no match",
allowedIPs: []string{"127.0.0.1"},
allowedNets: []string{},
allowedNets: []string{"127.0.0.1/32"},
clientIP: "127.0.0.2",
wantAllowed: false,
},
{
name: "CIDR match",
allowedIPs: []string{},
allowedNets: []string{"192.168.0.0/16"},
clientIP: "192.168.1.100",
wantAllowed: true,
},
{
name: "CIDR no match",
allowedIPs: []string{},
allowedNets: []string{"10.0.0.0/8"},
clientIP: "192.168.1.100",
wantAllowed: false,
},
{
name: "IPv6 CIDR match",
allowedIPs: []string{},
allowedNets: []string{"2001:db8::/32"},
clientIP: "2001:db8::1",
wantAllowed: true,
},
{
name: "IPv6 exact match",
allowedIPs: []string{"::1"},
allowedNets: []string{},
name: "IPv6 exact match (as /128 CIDR)",
allowedNets: []string{"::1/128"},
clientIP: "::1",
wantAllowed: true,
},
@ -332,8 +324,7 @@ func TestPprofHandler_isAllowed(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &PprofHandler{
allowedIPs: parseIPs(tt.allowedIPs),
allowedNets: parseNets(tt.allowedNets),
allowed: parseNets(tt.allowedNets),
}
// 创建请求上下文,模拟客户端 IP
@ -348,16 +339,10 @@ func TestPprofHandler_isAllowed(t *testing.T) {
// 复制 isAllowed 的逻辑进行测试
allowed := false
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 {
if len(h.allowed) == 0 {
allowed = true
} else {
for _, ip := range h.allowedIPs {
if ip.Equal(clientIP) {
allowed = true
break
}
}
for _, n := range h.allowedNets {
for _, n := range h.allowed {
if n.Contains(clientIP) {
allowed = true
break
@ -372,24 +357,13 @@ func TestPprofHandler_isAllowed(t *testing.T) {
}
}
// parseIPs 辅助函数,解析 IP 字符串列表
func parseIPs(ips []string) []net.IP {
result := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if parsed := net.ParseIP(ip); parsed != nil {
result = append(result, parsed)
}
}
return result
}
// parseNets 辅助函数,解析 CIDR 字符串列表
func parseNets(cidrs []string) []*net.IPNet {
result := make([]*net.IPNet, 0, len(cidrs))
func parseNets(cidrs []string) []net.IPNet {
result := make([]net.IPNet, 0, len(cidrs))
for _, cidr := range cidrs {
_, net, err := net.ParseCIDR(cidr)
if err == nil {
result = append(result, net)
if err == nil && net != nil {
result = append(result, *net)
}
}
return result
@ -398,9 +372,8 @@ func parseNets(cidrs []string) []*net.IPNet {
func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) {
// 测试空 allow 列表时允许所有访问
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -417,9 +390,8 @@ func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) {
func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) {
// 使用空 allow 列表允许所有访问
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
tests := []struct {
@ -465,9 +437,8 @@ func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) {
func TestPprofHandler_handleIndex(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -510,9 +481,8 @@ func TestPprofHandler_handleIndex(t *testing.T) {
func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
tests := []struct {
@ -564,30 +534,30 @@ func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) {
func TestPprofHandler_ServeHTTP_Forbidden(t *testing.T) {
// 创建只允许特定 IP 的 handler
allowedIP := net.ParseIP("10.0.0.1")
_, ipNet, _ := net.ParseCIDR("10.0.0.1/32")
h := &PprofHandler{
allowedIPs: []net.IP{allowedIP},
allowed: []net.IPNet{*ipNet},
}
// 由于无法轻松设置 RemoteIP我们直接测试 isAllowed 返回 false 的情况
// 通过构造一个 allowedIPs 非空的情况来触发检查
// 通过构造一个 allowed 非空的情况来触发检查
// 验证 handler 配置正确
if len(h.allowedIPs) != 1 {
t.Errorf("expected 1 allowed IP, got %d", len(h.allowedIPs))
if len(h.allowed) != 1 {
t.Errorf("expected 1 allowed IPNet, got %d", len(h.allowed))
}
// 验证 allowed IPs 包含配置的 IP
if !h.allowedIPs[0].Equal(allowedIP) {
t.Error("expected allowedIPs to contain configured IP")
// 验证 allowed 包含配置的 IP
expectedIP := net.ParseIP("10.0.0.1")
if !h.allowed[0].Contains(expectedIP) {
t.Error("expected allowed to contain configured IP")
}
}
func TestPprofHandler_handleCPU_Params(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
tests := []struct {
@ -687,9 +657,8 @@ func TestPprofHandler_handleCPU_Execute(t *testing.T) {
stopCPUProfile()
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -781,35 +750,37 @@ func TestPprofHandler_ConfigWithCIDRAndIP(t *testing.T) {
t.Fatal("expected non-nil handler")
}
// 验证 IP 和 CIDR 都被正确解析
if len(h.allowedIPs) != 2 {
t.Errorf("expected 2 allowed IPs, got %d", len(h.allowedIPs))
}
if len(h.allowedNets) != 1 {
t.Errorf("expected 1 allowed net, got %d", len(h.allowedNets))
// 验证 IP 和 CIDR 都被正确解析(现在统一存储在 allowed 中)
// 127.0.0.1 -> 127.0.0.1/32
// ::1 -> ::1/128
// 192.168.0.0/24 保持不变
if len(h.allowed) != 3 {
t.Errorf("expected 3 allowed entries (2 IPs converted to /32 and /128, 1 CIDR), got %d", len(h.allowed))
}
// 验证具体内容
// 验证具体内容 - 使用 Contains 检查
foundV4 := false
foundV6 := false
for _, ip := range h.allowedIPs {
if ip.Equal(net.ParseIP("127.0.0.1")) {
foundCIDR := false
for _, ipNet := range h.allowed {
if ipNet.Contains(net.ParseIP("127.0.0.1")) && ipNet.String() == "127.0.0.1/32" {
foundV4 = true
}
if ip.Equal(net.ParseIP("::1")) {
if ipNet.Contains(net.ParseIP("::1")) && ipNet.String() == "::1/128" {
foundV6 = true
}
if ipNet.String() == "192.168.0.0/24" {
foundCIDR = true
}
}
if !foundV4 {
t.Error("expected to find 127.0.0.1 in allowedIPs")
t.Error("expected to find 127.0.0.1/32 in allowed")
}
if !foundV6 {
t.Error("expected to find ::1 in allowedIPs")
t.Error("expected to find ::1/128 in allowed")
}
// 验证 CIDR
if h.allowedNets[0].String() != "192.168.0.0/24" {
t.Errorf("expected CIDR 192.168.0.0/24, got %s", h.allowedNets[0].String())
if !foundCIDR {
t.Error("expected to find 192.168.0.0/24 in allowed")
}
}
@ -829,35 +800,34 @@ func TestPprofHandler_DefaultLocalhostBehavior(t *testing.T) {
t.Fatal("expected non-nil handler")
}
// 验证默认允许 localhost
if len(h.allowedIPs) != 2 {
t.Errorf("expected 2 default allowed IPs, got %d", len(h.allowedIPs))
// 验证默认允许 localhost (解析为 127.0.0.1/32 和 ::1/128)
if len(h.allowed) != 2 {
t.Errorf("expected 2 default allowed entries (127.0.0.1/32 and ::1/128), got %d", len(h.allowed))
}
// 验证包含 IPv4 和 IPv6 localhost
hasV4 := false
hasV6 := false
for _, ip := range h.allowedIPs {
if ip.Equal(net.ParseIP("127.0.0.1")) {
for _, n := range h.allowed {
if n.Contains(net.ParseIP("127.0.0.1")) && n.String() == "127.0.0.1/32" {
hasV4 = true
}
if ip.Equal(net.ParseIP("::1")) {
if n.Contains(net.ParseIP("::1")) && n.String() == "::1/128" {
hasV6 = true
}
}
if !hasV4 {
t.Error("expected default to include 127.0.0.1")
t.Error("expected default to include 127.0.0.1/32")
}
if !hasV6 {
t.Error("expected default to include ::1")
t.Error("expected default to include ::1/128")
}
}
func TestPprofHandler_handleHeap(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -879,9 +849,8 @@ func TestPprofHandler_handleHeap(t *testing.T) {
func TestPprofHandler_handleGoroutine(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -903,9 +872,8 @@ func TestPprofHandler_handleGoroutine(t *testing.T) {
func TestPprofHandler_handleBlock(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -927,9 +895,8 @@ func TestPprofHandler_handleBlock(t *testing.T) {
func TestPprofHandler_handleMutex(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -953,9 +920,8 @@ func TestPprofHandler_handleMutex(t *testing.T) {
func TestPprofHandler_isAllowed_RemoteIP(t *testing.T) {
t.Run("empty allow lists - allow all", func(t *testing.T) {
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{},
}
ctx := &fasthttp.RequestCtx{}
@ -966,11 +932,10 @@ func TestPprofHandler_isAllowed_RemoteIP(t *testing.T) {
})
t.Run("with allow list but cannot parse IP", func(t *testing.T) {
allowedIP := net.ParseIP("192.168.1.1")
_, ipNet, _ := net.ParseCIDR("192.168.1.1/32")
h := &PprofHandler{
path: "/debug/pprof",
allowedIPs: []net.IP{allowedIP},
allowedNets: []*net.IPNet{},
path: "/debug/pprof",
allowed: []net.IPNet{*ipNet},
}
ctx := &fasthttp.RequestCtx{}

View File

@ -6,7 +6,6 @@ package server
import (
"encoding/json"
"net"
"net/netip"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/cache"
@ -55,38 +54,11 @@ func NewPurgeHandler(server *Server, cfg *config.CacheAPIConfig) (*PurgeHandler,
}
// 解析允许的 IP 列表
for _, cidr := range cfg.Allow {
// 处理 localhost 特殊情况
if cidr == "localhost" {
_, v4Network, _ := net.ParseCIDR("127.0.0.1/32")
_, v6Network, _ := net.ParseCIDR("::1/128")
if v4Network != nil {
h.allowed = append(h.allowed, *v4Network)
}
if v6Network != nil {
h.allowed = append(h.allowed, *v6Network)
}
continue
}
_, network, err := net.ParseCIDR(cidr)
if err != nil {
// 尝试作为单个 IP 解析
ip, err := netip.ParseAddr(cidr)
if err != nil {
return nil, err
}
// 转换为 CIDR 格式
if ip.Is4() {
_, network, _ = net.ParseCIDR(cidr + "/32")
} else {
_, network, _ = net.ParseCIDR(cidr + "/128")
}
}
if network != nil {
h.allowed = append(h.allowed, *network)
}
allowed, err := utils.ParseIPAllowList(cfg.Allow)
if err != nil {
return nil, err
}
h.allowed = allowed
return h, nil
}

View File

@ -134,39 +134,11 @@ func NewStatusHandler(server *Server, cfg *config.StatusConfig) (*StatusHandler,
}
// 解析允许的 IP 列表
for _, cidr := range cfg.Allow {
// 处理 localhost 特殊情况
if cidr == "localhost" {
// localhost 解析为 127.0.0.1 和 ::1
_, v4Network, _ := net.ParseCIDR("127.0.0.1/32")
_, v6Network, _ := net.ParseCIDR("::1/128")
if v4Network != nil {
h.allowed = append(h.allowed, *v4Network)
}
if v6Network != nil {
h.allowed = append(h.allowed, *v6Network)
}
continue
}
_, network, err := net.ParseCIDR(cidr)
if err != nil {
// 尝试作为单个 IP 解析
ip := net.ParseIP(cidr)
if ip == nil {
return nil, err
}
// 转换为 CIDR 格式
if ip.To4() != nil {
_, network, _ = net.ParseCIDR(cidr + "/32")
} else {
_, network, _ = net.ParseCIDR(cidr + "/128")
}
}
if network != nil {
h.allowed = append(h.allowed, *network)
}
allowed, err := utils.ParseIPAllowList(cfg.Allow)
if err != nil {
return nil, err
}
h.allowed = allowed
return h, nil
}

View File

@ -0,0 +1,93 @@
// Package utils 提供通用工具函数
package utils
import (
"net"
)
// ParseIPAllowList 解析 IP/CIDR 白名单列表。
//
// 支持格式:
// - CIDR 格式192.168.1.0/24、::1/128
// - 单个 IP192.168.1.1(自动转换为 /32 或 /128
// - 特殊值 "localhost":映射为 127.0.0.1/32 和 ::1/128
//
// 参数:
// - allow: IP/CIDR 字符串列表
//
// 返回值:
// - []net.IPNet: 解析后的网络列表
// - error: 解析失败时返回错误
func ParseIPAllowList(allow []string) ([]net.IPNet, error) {
if len(allow) == 0 {
return nil, nil
}
result := make([]net.IPNet, 0, len(allow)+2) // +2 for localhost expansion
for _, cidr := range allow {
// 处理 localhost 特殊情况
if cidr == "localhost" {
// localhost 解析为 127.0.0.1 和 ::1
if v4Net, err := parseCIDR("127.0.0.1/32"); err == nil {
result = append(result, *v4Net)
}
if v6Net, err := parseCIDR("::1/128"); err == nil {
result = append(result, *v6Net)
}
continue
}
// 尝试 CIDR 解析
_, network, err := net.ParseCIDR(cidr)
if err == nil && network != nil {
result = append(result, *network)
continue
}
// fallback: 尝试作为单个 IP 解析
ip := net.ParseIP(cidr)
if ip == nil {
return nil, err // 返回原始 CIDR 解析错误
}
// 转换为 CIDR 格式
var ipNet *net.IPNet
if ip.To4() != nil {
ipNet, _ = parseCIDR(cidr + "/32")
} else {
ipNet, _ = parseCIDR(cidr + "/128")
}
if ipNet != nil {
result = append(result, *ipNet)
}
}
return result, nil
}
// parseCIDR 是 net.ParseCIDR 的包装,返回 *net.IPNet 而不返回 net.IP
func parseCIDR(cidr string) (*net.IPNet, error) {
_, network, err := net.ParseCIDR(cidr)
return network, err
}
// IPInAllowList 检查 IP 是否在白名单中。
//
// 参数:
// - ip: 要检查的 IP 地址
// - allowList: 白名单网络列表
//
// 返回值:
// - bool: IP 在白名单中返回 true
func IPInAllowList(ip net.IP, allowList []net.IPNet) bool {
if len(allowList) == 0 {
return false
}
for _, network := range allowList {
if network.Contains(ip) {
return true
}
}
return false
}