diff --git a/internal/middleware/security/access.go b/internal/middleware/security/access.go index 82867cd..88b8feb 100644 --- a/internal/middleware/security/access.go +++ b/internal/middleware/security/access.go @@ -34,6 +34,7 @@ import ( "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" "rua.plus/lolly/internal/middleware" + "rua.plus/lolly/internal/netutil" ) // Action 表示对 IP 的操作类型。 @@ -383,7 +384,7 @@ func parseCIDR(cidr string) (*net.IPNet, error) { // 返回值: // - net.IP: 客户端 IP 地址,无法获取时返回 nil func (ac *AccessControl) getClientIP(ctx *fasthttp.RequestCtx) net.IP { - remoteIP := getRemoteAddrIP(ctx) + remoteIP := netutil.GetRemoteAddrIP(ctx) // 仅当配置了可信代理且请求来自可信代理时,才解析 X-Forwarded-For if len(ac.trustedProxies) > 0 && remoteIP != nil { @@ -441,30 +442,6 @@ func (ac *AccessControl) getClientIP(ctx *fasthttp.RequestCtx) net.IP { return remoteIP } -// getRemoteAddrIP 从 RemoteAddr 提取 IP。 -// -// 参数: -// - ctx: FastHTTP 请求上下文 -// -// 返回值: -// - net.IP: 客户端 IP 地址,无法获取时返回 nil -func getRemoteAddrIP(ctx *fasthttp.RequestCtx) net.IP { - if addr := ctx.RemoteAddr(); addr != nil { - if tcpAddr, ok := addr.(*net.TCPAddr); ok { - return tcpAddr.IP - } - // 从字符串表示解析 - ipStr := addr.String() - if idx := strings.LastIndex(ipStr, ":"); idx != -1 { - ipStr = ipStr[:idx] - } - // 移除 IPv6 的方括号 - ipStr = strings.TrimPrefix(strings.TrimSuffix(ipStr, "]"), "[") - return net.ParseIP(ipStr) - } - return nil -} - // AccessStats 访问控制统计信息结构。 type AccessStats struct { Default string diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index b96de67..f68567f 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -193,24 +193,7 @@ func (p *Proxy) SetHealthChecker(hc *HealthChecker) { // createBalancer 根据配置的算法创建负载均衡器。 func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) { - switch cfg.LoadBalance { - case lbRoundRobin, "": - return loadbalance.NewRoundRobin(), nil - case lbWeightedRoundRobin: - return loadbalance.NewWeightedRoundRobin(), nil - case lbLeastConn: - return loadbalance.NewLeastConnections(), nil - case lbIPHash: - return loadbalance.NewIPHash(), nil - case lbConsistentHash: - virtualNodes := cfg.VirtualNodes - if virtualNodes <= 0 { - virtualNodes = 150 - } - return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil - default: - return nil, errors.New("unsupported load balance algorithm: " + cfg.LoadBalance) - } + return createBalancerByName(cfg.LoadBalance, cfg) } // createHostClient 为后台目标 URL 创建 fasthttp.HostClient。