refactor: 统一 IP 白名单解析和 excludeSet 构建
Batch 1 重构: - 新增 utils.ParseIPAllowList 统一 IP/CIDR 解析(含 localhost 特殊处理) - pprof.go/status.go/purge.go 改用统一函数,减少 ~66 行重复代码 - 新增 loadbalance.buildExcludeSet 统一排除集合构建 - 更新 pprof_test.go 适配统一字段 allowed []net.IPNet Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
0e8b99e17f
commit
91e04222b3
@ -399,12 +399,7 @@ func DecrementConnections(t *Target) {
|
||||
//
|
||||
// 排除判断基于目标的 URL 进行匹配。
|
||||
func filterHealthyAndExclude(targets []*Target, excluded []*Target) []*Target {
|
||||
excludeSet := make(map[string]bool, len(excluded))
|
||||
for _, t := range excluded {
|
||||
if t != nil {
|
||||
excludeSet[t.URL] = true
|
||||
}
|
||||
}
|
||||
excludeSet := buildExcludeSet(excluded)
|
||||
|
||||
available := make([]*Target, 0, len(targets))
|
||||
backups := make([]*Target, 0, len(targets))
|
||||
@ -485,12 +480,7 @@ func (w *WeightedRoundRobin) SelectExcluding(targets []*Target, excluded []*Targ
|
||||
// SelectExcluding 选择连接数最少的目标,排除指定的目标列表。
|
||||
// 优先选择非备份目标,仅当无可用非备份目标时选择备份目标。
|
||||
func (l *LeastConnections) SelectExcluding(targets []*Target, excluded []*Target) *Target {
|
||||
excludeSet := make(map[string]bool, len(excluded))
|
||||
for _, t := range excluded {
|
||||
if t != nil {
|
||||
excludeSet[t.URL] = true
|
||||
}
|
||||
}
|
||||
excludeSet := buildExcludeSet(excluded)
|
||||
|
||||
var selected *Target
|
||||
var selectedBackup *Target
|
||||
@ -638,3 +628,25 @@ func (t *Target) LastResolved() time.Time {
|
||||
}
|
||||
return time.Unix(0, nano)
|
||||
}
|
||||
|
||||
// buildExcludeSet 从排除列表构建 URL 集合。
|
||||
//
|
||||
// 用于负载均衡算法中快速检查目标是否应被排除。
|
||||
//
|
||||
// 参数:
|
||||
// - excluded: 需要排除的目标列表
|
||||
//
|
||||
// 返回值:
|
||||
// - map[string]bool: 目标 URL 到 true 的映射
|
||||
func buildExcludeSet(excluded []*Target) map[string]bool {
|
||||
if len(excluded) == 0 {
|
||||
return nil
|
||||
}
|
||||
excludeSet := make(map[string]bool, len(excluded))
|
||||
for _, t := range excluded {
|
||||
if t != nil {
|
||||
excludeSet[t.URL] = true
|
||||
}
|
||||
}
|
||||
return excludeSet
|
||||
}
|
||||
|
||||
@ -238,12 +238,7 @@ func (c *ConsistentHash) SelectExcludingByKey(targets []*Target, excluded []*Tar
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
// 构建排除集合
|
||||
excludeSet := make(map[string]bool, len(excluded))
|
||||
for _, t := range excluded {
|
||||
if t != nil {
|
||||
excludeSet[t.URL] = true
|
||||
}
|
||||
}
|
||||
excludeSet := buildExcludeSet(excluded)
|
||||
|
||||
// 如果没有排除的目标,使用正常选择
|
||||
if len(excludeSet) == 0 {
|
||||
|
||||
@ -29,6 +29,7 @@ import (
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/utils"
|
||||
)
|
||||
|
||||
// PprofHandler pprof 性能分析处理器。
|
||||
@ -38,11 +39,8 @@ type PprofHandler struct {
|
||||
// path 端点路径前缀
|
||||
path string
|
||||
|
||||
// allowedIPs 允许访问的 IP 列表
|
||||
allowedIPs []net.IP
|
||||
|
||||
// allowedNets 允许访问的 CIDR 网络
|
||||
allowedNets []*net.IPNet
|
||||
// allowed 允许访问的 IP 网络列表
|
||||
allowed []net.IPNet
|
||||
}
|
||||
|
||||
// NewPprofHandler 创建 pprof 处理器。
|
||||
@ -68,22 +66,15 @@ func NewPprofHandler(cfg *config.PprofConfig) (*PprofHandler, error) {
|
||||
h := &PprofHandler{path: path}
|
||||
|
||||
// 解析允许的 IP 列表
|
||||
for _, ipStr := range cfg.Allow {
|
||||
if ip := net.ParseIP(ipStr); ip != nil {
|
||||
h.allowedIPs = append(h.allowedIPs, ip)
|
||||
continue
|
||||
}
|
||||
// 尝试解析 CIDR
|
||||
_, net, err := net.ParseCIDR(ipStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse IP/CIDR: %s: %w", ipStr, err)
|
||||
}
|
||||
h.allowedNets = append(h.allowedNets, net)
|
||||
allowed, err := utils.ParseIPAllowList(cfg.Allow)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse IP/CIDR: %w", err)
|
||||
}
|
||||
h.allowed = allowed
|
||||
|
||||
// 默认只允许 localhost
|
||||
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 {
|
||||
h.allowedIPs = []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")}
|
||||
if len(h.allowed) == 0 {
|
||||
h.allowed, _ = utils.ParseIPAllowList([]string{"localhost"})
|
||||
}
|
||||
|
||||
return h, nil
|
||||
@ -157,7 +148,7 @@ func (h *PprofHandler) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
// 返回值:
|
||||
// - bool: true 表示允许访问,false 表示禁止访问
|
||||
func (h *PprofHandler) isAllowed(ctx *fasthttp.RequestCtx) bool {
|
||||
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 {
|
||||
if len(h.allowed) == 0 {
|
||||
return true // 无限制
|
||||
}
|
||||
|
||||
@ -167,21 +158,7 @@ func (h *PprofHandler) isAllowed(ctx *fasthttp.RequestCtx) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查精确 IP
|
||||
for _, ip := range h.allowedIPs {
|
||||
if ip.Equal(clientIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 CIDR 网络
|
||||
for _, net := range h.allowedNets {
|
||||
if net.Contains(clientIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return utils.IPInAllowList(clientIP, h.allowed)
|
||||
}
|
||||
|
||||
// handleIndex 处理索引页面。
|
||||
|
||||
@ -124,8 +124,8 @@ func TestNewPprofHandler_SingleIP(t *testing.T) {
|
||||
}
|
||||
// 空列表时应该默认允许 localhost
|
||||
if len(tt.allow) == 0 {
|
||||
if len(h.allowedIPs) != 2 {
|
||||
t.Errorf("expected 2 default allowed IPs (127.0.0.1 and ::1), got %d", len(h.allowedIPs))
|
||||
if len(h.allowed) != 2 {
|
||||
t.Errorf("expected 2 default allowed IPs (127.0.0.1 and ::1), got %d", len(h.allowed))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -274,56 +274,48 @@ func TestPprofHandler_isAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientIP string
|
||||
allowedIPs []string
|
||||
allowedNets []string
|
||||
wantAllowed bool
|
||||
}{
|
||||
{
|
||||
name: "empty allow list - allow all",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{},
|
||||
clientIP: "192.168.1.100",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "IP exact match",
|
||||
allowedIPs: []string{"127.0.0.1"},
|
||||
allowedNets: []string{},
|
||||
name: "IP exact match (as /32 CIDR)",
|
||||
allowedNets: []string{"127.0.0.1/32"},
|
||||
clientIP: "127.0.0.1",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "IP no match",
|
||||
allowedIPs: []string{"127.0.0.1"},
|
||||
allowedNets: []string{},
|
||||
allowedNets: []string{"127.0.0.1/32"},
|
||||
clientIP: "127.0.0.2",
|
||||
wantAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "CIDR match",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{"192.168.0.0/16"},
|
||||
clientIP: "192.168.1.100",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR no match",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{"10.0.0.0/8"},
|
||||
clientIP: "192.168.1.100",
|
||||
wantAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 CIDR match",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{"2001:db8::/32"},
|
||||
clientIP: "2001:db8::1",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 exact match",
|
||||
allowedIPs: []string{"::1"},
|
||||
allowedNets: []string{},
|
||||
name: "IPv6 exact match (as /128 CIDR)",
|
||||
allowedNets: []string{"::1/128"},
|
||||
clientIP: "::1",
|
||||
wantAllowed: true,
|
||||
},
|
||||
@ -332,8 +324,7 @@ func TestPprofHandler_isAllowed(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
allowedIPs: parseIPs(tt.allowedIPs),
|
||||
allowedNets: parseNets(tt.allowedNets),
|
||||
allowed: parseNets(tt.allowedNets),
|
||||
}
|
||||
|
||||
// 创建请求上下文,模拟客户端 IP
|
||||
@ -348,16 +339,10 @@ func TestPprofHandler_isAllowed(t *testing.T) {
|
||||
|
||||
// 复制 isAllowed 的逻辑进行测试
|
||||
allowed := false
|
||||
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 {
|
||||
if len(h.allowed) == 0 {
|
||||
allowed = true
|
||||
} else {
|
||||
for _, ip := range h.allowedIPs {
|
||||
if ip.Equal(clientIP) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, n := range h.allowedNets {
|
||||
for _, n := range h.allowed {
|
||||
if n.Contains(clientIP) {
|
||||
allowed = true
|
||||
break
|
||||
@ -372,24 +357,13 @@ func TestPprofHandler_isAllowed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// parseIPs 辅助函数,解析 IP 字符串列表
|
||||
func parseIPs(ips []string) []net.IP {
|
||||
result := make([]net.IP, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
if parsed := net.ParseIP(ip); parsed != nil {
|
||||
result = append(result, parsed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseNets 辅助函数,解析 CIDR 字符串列表
|
||||
func parseNets(cidrs []string) []*net.IPNet {
|
||||
result := make([]*net.IPNet, 0, len(cidrs))
|
||||
func parseNets(cidrs []string) []net.IPNet {
|
||||
result := make([]net.IPNet, 0, len(cidrs))
|
||||
for _, cidr := range cidrs {
|
||||
_, net, err := net.ParseCIDR(cidr)
|
||||
if err == nil {
|
||||
result = append(result, net)
|
||||
if err == nil && net != nil {
|
||||
result = append(result, *net)
|
||||
}
|
||||
}
|
||||
return result
|
||||
@ -398,9 +372,8 @@ func parseNets(cidrs []string) []*net.IPNet {
|
||||
func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) {
|
||||
// 测试空 allow 列表时允许所有访问
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -417,9 +390,8 @@ func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) {
|
||||
func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) {
|
||||
// 使用空 allow 列表允许所有访问
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@ -465,9 +437,8 @@ func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) {
|
||||
|
||||
func TestPprofHandler_handleIndex(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -510,9 +481,8 @@ func TestPprofHandler_handleIndex(t *testing.T) {
|
||||
|
||||
func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@ -564,30 +534,30 @@ func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) {
|
||||
|
||||
func TestPprofHandler_ServeHTTP_Forbidden(t *testing.T) {
|
||||
// 创建只允许特定 IP 的 handler
|
||||
allowedIP := net.ParseIP("10.0.0.1")
|
||||
_, ipNet, _ := net.ParseCIDR("10.0.0.1/32")
|
||||
h := &PprofHandler{
|
||||
allowedIPs: []net.IP{allowedIP},
|
||||
allowed: []net.IPNet{*ipNet},
|
||||
}
|
||||
|
||||
// 由于无法轻松设置 RemoteIP,我们直接测试 isAllowed 返回 false 的情况
|
||||
// 通过构造一个 allowedIPs 非空的情况来触发检查
|
||||
// 通过构造一个 allowed 非空的情况来触发检查
|
||||
|
||||
// 验证 handler 配置正确
|
||||
if len(h.allowedIPs) != 1 {
|
||||
t.Errorf("expected 1 allowed IP, got %d", len(h.allowedIPs))
|
||||
if len(h.allowed) != 1 {
|
||||
t.Errorf("expected 1 allowed IPNet, got %d", len(h.allowed))
|
||||
}
|
||||
|
||||
// 验证 allowed IPs 包含配置的 IP
|
||||
if !h.allowedIPs[0].Equal(allowedIP) {
|
||||
t.Error("expected allowedIPs to contain configured IP")
|
||||
// 验证 allowed 包含配置的 IP
|
||||
expectedIP := net.ParseIP("10.0.0.1")
|
||||
if !h.allowed[0].Contains(expectedIP) {
|
||||
t.Error("expected allowed to contain configured IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleCPU_Params(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@ -687,9 +657,8 @@ func TestPprofHandler_handleCPU_Execute(t *testing.T) {
|
||||
stopCPUProfile()
|
||||
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -781,35 +750,37 @@ func TestPprofHandler_ConfigWithCIDRAndIP(t *testing.T) {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
|
||||
// 验证 IP 和 CIDR 都被正确解析
|
||||
if len(h.allowedIPs) != 2 {
|
||||
t.Errorf("expected 2 allowed IPs, got %d", len(h.allowedIPs))
|
||||
}
|
||||
if len(h.allowedNets) != 1 {
|
||||
t.Errorf("expected 1 allowed net, got %d", len(h.allowedNets))
|
||||
// 验证 IP 和 CIDR 都被正确解析(现在统一存储在 allowed 中)
|
||||
// 127.0.0.1 -> 127.0.0.1/32
|
||||
// ::1 -> ::1/128
|
||||
// 192.168.0.0/24 保持不变
|
||||
if len(h.allowed) != 3 {
|
||||
t.Errorf("expected 3 allowed entries (2 IPs converted to /32 and /128, 1 CIDR), got %d", len(h.allowed))
|
||||
}
|
||||
|
||||
// 验证具体内容
|
||||
// 验证具体内容 - 使用 Contains 检查
|
||||
foundV4 := false
|
||||
foundV6 := false
|
||||
for _, ip := range h.allowedIPs {
|
||||
if ip.Equal(net.ParseIP("127.0.0.1")) {
|
||||
foundCIDR := false
|
||||
for _, ipNet := range h.allowed {
|
||||
if ipNet.Contains(net.ParseIP("127.0.0.1")) && ipNet.String() == "127.0.0.1/32" {
|
||||
foundV4 = true
|
||||
}
|
||||
if ip.Equal(net.ParseIP("::1")) {
|
||||
if ipNet.Contains(net.ParseIP("::1")) && ipNet.String() == "::1/128" {
|
||||
foundV6 = true
|
||||
}
|
||||
if ipNet.String() == "192.168.0.0/24" {
|
||||
foundCIDR = true
|
||||
}
|
||||
}
|
||||
if !foundV4 {
|
||||
t.Error("expected to find 127.0.0.1 in allowedIPs")
|
||||
t.Error("expected to find 127.0.0.1/32 in allowed")
|
||||
}
|
||||
if !foundV6 {
|
||||
t.Error("expected to find ::1 in allowedIPs")
|
||||
t.Error("expected to find ::1/128 in allowed")
|
||||
}
|
||||
|
||||
// 验证 CIDR
|
||||
if h.allowedNets[0].String() != "192.168.0.0/24" {
|
||||
t.Errorf("expected CIDR 192.168.0.0/24, got %s", h.allowedNets[0].String())
|
||||
if !foundCIDR {
|
||||
t.Error("expected to find 192.168.0.0/24 in allowed")
|
||||
}
|
||||
}
|
||||
|
||||
@ -829,35 +800,34 @@ func TestPprofHandler_DefaultLocalhostBehavior(t *testing.T) {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
|
||||
// 验证默认允许 localhost
|
||||
if len(h.allowedIPs) != 2 {
|
||||
t.Errorf("expected 2 default allowed IPs, got %d", len(h.allowedIPs))
|
||||
// 验证默认允许 localhost (解析为 127.0.0.1/32 和 ::1/128)
|
||||
if len(h.allowed) != 2 {
|
||||
t.Errorf("expected 2 default allowed entries (127.0.0.1/32 and ::1/128), got %d", len(h.allowed))
|
||||
}
|
||||
|
||||
// 验证包含 IPv4 和 IPv6 localhost
|
||||
hasV4 := false
|
||||
hasV6 := false
|
||||
for _, ip := range h.allowedIPs {
|
||||
if ip.Equal(net.ParseIP("127.0.0.1")) {
|
||||
for _, n := range h.allowed {
|
||||
if n.Contains(net.ParseIP("127.0.0.1")) && n.String() == "127.0.0.1/32" {
|
||||
hasV4 = true
|
||||
}
|
||||
if ip.Equal(net.ParseIP("::1")) {
|
||||
if n.Contains(net.ParseIP("::1")) && n.String() == "::1/128" {
|
||||
hasV6 = true
|
||||
}
|
||||
}
|
||||
if !hasV4 {
|
||||
t.Error("expected default to include 127.0.0.1")
|
||||
t.Error("expected default to include 127.0.0.1/32")
|
||||
}
|
||||
if !hasV6 {
|
||||
t.Error("expected default to include ::1")
|
||||
t.Error("expected default to include ::1/128")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleHeap(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -879,9 +849,8 @@ func TestPprofHandler_handleHeap(t *testing.T) {
|
||||
|
||||
func TestPprofHandler_handleGoroutine(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -903,9 +872,8 @@ func TestPprofHandler_handleGoroutine(t *testing.T) {
|
||||
|
||||
func TestPprofHandler_handleBlock(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -927,9 +895,8 @@ func TestPprofHandler_handleBlock(t *testing.T) {
|
||||
|
||||
func TestPprofHandler_handleMutex(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -953,9 +920,8 @@ func TestPprofHandler_handleMutex(t *testing.T) {
|
||||
func TestPprofHandler_isAllowed_RemoteIP(t *testing.T) {
|
||||
t.Run("empty allow lists - allow all", func(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
@ -966,11 +932,10 @@ func TestPprofHandler_isAllowed_RemoteIP(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("with allow list but cannot parse IP", func(t *testing.T) {
|
||||
allowedIP := net.ParseIP("192.168.1.1")
|
||||
_, ipNet, _ := net.ParseCIDR("192.168.1.1/32")
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{allowedIP},
|
||||
allowedNets: []*net.IPNet{},
|
||||
path: "/debug/pprof",
|
||||
allowed: []net.IPNet{*ipNet},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
@ -6,7 +6,6 @@ package server
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
@ -55,38 +54,11 @@ func NewPurgeHandler(server *Server, cfg *config.CacheAPIConfig) (*PurgeHandler,
|
||||
}
|
||||
|
||||
// 解析允许的 IP 列表
|
||||
for _, cidr := range cfg.Allow {
|
||||
// 处理 localhost 特殊情况
|
||||
if cidr == "localhost" {
|
||||
_, v4Network, _ := net.ParseCIDR("127.0.0.1/32")
|
||||
_, v6Network, _ := net.ParseCIDR("::1/128")
|
||||
if v4Network != nil {
|
||||
h.allowed = append(h.allowed, *v4Network)
|
||||
}
|
||||
if v6Network != nil {
|
||||
h.allowed = append(h.allowed, *v6Network)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
// 尝试作为单个 IP 解析
|
||||
ip, err := netip.ParseAddr(cidr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为 CIDR 格式
|
||||
if ip.Is4() {
|
||||
_, network, _ = net.ParseCIDR(cidr + "/32")
|
||||
} else {
|
||||
_, network, _ = net.ParseCIDR(cidr + "/128")
|
||||
}
|
||||
}
|
||||
if network != nil {
|
||||
h.allowed = append(h.allowed, *network)
|
||||
}
|
||||
allowed, err := utils.ParseIPAllowList(cfg.Allow)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.allowed = allowed
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
@ -134,39 +134,11 @@ func NewStatusHandler(server *Server, cfg *config.StatusConfig) (*StatusHandler,
|
||||
}
|
||||
|
||||
// 解析允许的 IP 列表
|
||||
for _, cidr := range cfg.Allow {
|
||||
// 处理 localhost 特殊情况
|
||||
if cidr == "localhost" {
|
||||
// localhost 解析为 127.0.0.1 和 ::1
|
||||
_, v4Network, _ := net.ParseCIDR("127.0.0.1/32")
|
||||
_, v6Network, _ := net.ParseCIDR("::1/128")
|
||||
if v4Network != nil {
|
||||
h.allowed = append(h.allowed, *v4Network)
|
||||
}
|
||||
if v6Network != nil {
|
||||
h.allowed = append(h.allowed, *v6Network)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
// 尝试作为单个 IP 解析
|
||||
ip := net.ParseIP(cidr)
|
||||
if ip == nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为 CIDR 格式
|
||||
if ip.To4() != nil {
|
||||
_, network, _ = net.ParseCIDR(cidr + "/32")
|
||||
} else {
|
||||
_, network, _ = net.ParseCIDR(cidr + "/128")
|
||||
}
|
||||
}
|
||||
if network != nil {
|
||||
h.allowed = append(h.allowed, *network)
|
||||
}
|
||||
allowed, err := utils.ParseIPAllowList(cfg.Allow)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h.allowed = allowed
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
93
internal/utils/ipallowlist.go
Normal file
93
internal/utils/ipallowlist.go
Normal file
@ -0,0 +1,93 @@
|
||||
// Package utils 提供通用工具函数
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// ParseIPAllowList 解析 IP/CIDR 白名单列表。
|
||||
//
|
||||
// 支持格式:
|
||||
// - CIDR 格式:192.168.1.0/24、::1/128
|
||||
// - 单个 IP:192.168.1.1(自动转换为 /32 或 /128)
|
||||
// - 特殊值 "localhost":映射为 127.0.0.1/32 和 ::1/128
|
||||
//
|
||||
// 参数:
|
||||
// - allow: IP/CIDR 字符串列表
|
||||
//
|
||||
// 返回值:
|
||||
// - []net.IPNet: 解析后的网络列表
|
||||
// - error: 解析失败时返回错误
|
||||
func ParseIPAllowList(allow []string) ([]net.IPNet, error) {
|
||||
if len(allow) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make([]net.IPNet, 0, len(allow)+2) // +2 for localhost expansion
|
||||
|
||||
for _, cidr := range allow {
|
||||
// 处理 localhost 特殊情况
|
||||
if cidr == "localhost" {
|
||||
// localhost 解析为 127.0.0.1 和 ::1
|
||||
if v4Net, err := parseCIDR("127.0.0.1/32"); err == nil {
|
||||
result = append(result, *v4Net)
|
||||
}
|
||||
if v6Net, err := parseCIDR("::1/128"); err == nil {
|
||||
result = append(result, *v6Net)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 尝试 CIDR 解析
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err == nil && network != nil {
|
||||
result = append(result, *network)
|
||||
continue
|
||||
}
|
||||
|
||||
// fallback: 尝试作为单个 IP 解析
|
||||
ip := net.ParseIP(cidr)
|
||||
if ip == nil {
|
||||
return nil, err // 返回原始 CIDR 解析错误
|
||||
}
|
||||
|
||||
// 转换为 CIDR 格式
|
||||
var ipNet *net.IPNet
|
||||
if ip.To4() != nil {
|
||||
ipNet, _ = parseCIDR(cidr + "/32")
|
||||
} else {
|
||||
ipNet, _ = parseCIDR(cidr + "/128")
|
||||
}
|
||||
if ipNet != nil {
|
||||
result = append(result, *ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseCIDR 是 net.ParseCIDR 的包装,返回 *net.IPNet 而不返回 net.IP
|
||||
func parseCIDR(cidr string) (*net.IPNet, error) {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
return network, err
|
||||
}
|
||||
|
||||
// IPInAllowList 检查 IP 是否在白名单中。
|
||||
//
|
||||
// 参数:
|
||||
// - ip: 要检查的 IP 地址
|
||||
// - allowList: 白名单网络列表
|
||||
//
|
||||
// 返回值:
|
||||
// - bool: IP 在白名单中返回 true
|
||||
func IPInAllowList(ip net.IP, allowList []net.IPNet) bool {
|
||||
if len(allowList) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, network := range allowList {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user