diff --git a/internal/proxy/health.go b/internal/proxy/health.go index df06fa9..07fd51d 100644 --- a/internal/proxy/health.go +++ b/internal/proxy/health.go @@ -210,33 +210,21 @@ func (h *HealthChecker) checkTarget(target *loadbalance.Target) { // 此方法用于被动健康检查,代理根据请求处理过程中 // 观察到的失败将目标标记为不健康。 // -// 在代理错误处理中的使用示例: -// -// if err := forwardRequest(target, req, resp); err != nil { -// healthChecker.MarkUnhealthy(target) -// // 尝试其他目标或返回错误 -// } -// -// 注意:要再次将目标标记为健康,主动健康检查 -// 必须成功。没有 MarkHealthy 方法 - 健康状态只能通过 -// 成功的健康检查积极恢复。 +// 同时调用 RecordFailure 记录软失败状态,配合 MaxFails/FailTimeout +// 实现失败计数和冷却机制。 func (h *HealthChecker) MarkUnhealthy(target *loadbalance.Target) { target.Healthy.Store(false) + target.RecordFailure() } // MarkHealthy 将目标标记为健康。 // 此方法用于故障转移成功后,将之前失败的目标恢复为健康状态。 // -// 在故障转移成功后的使用示例: -// -// if err := retryRequest(target, req, resp); err == nil { -// healthChecker.MarkHealthy(target) -// } -// -// 注意:此方法与主动健康检查独立运作,用于快速恢复 -// 故障转移场景中已恢复的目标。 +// 同时调用 RecordSuccess 重置软失败状态(failCount/failedUntil), +// 但不修改 Healthy 标志——健康检查器对 Healthy 拥有权威。 func (h *HealthChecker) MarkHealthy(target *loadbalance.Target) { target.Healthy.Store(true) + target.RecordSuccess() } // IsRunning 如果健康检查器当前正在运行,则返回 true。 diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index aa1341b..befc5bd 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -36,6 +36,7 @@ import ( "errors" "fmt" "hash/fnv" + "net" "slices" "strings" "sync" @@ -66,6 +67,7 @@ const ( lbLeastConn = "least_conn" // 最少连接 lbIPHash = "ip_hash" // IP 哈希 lbConsistentHash = "consistent_hash" // 一致性哈希 + lbRandom = "random" // 随机(Power of Two Choices) ) // headersPool 复用缓存 headers map,减少分配。 @@ -155,8 +157,12 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC continue } - client := createHostClient(target.URL, cfg.Timeout, transportCfg, cfg.ProxySSL) - p.clients[target.URL] = client + client := createHostClient(target.URL, cfg.Timeout, transportCfg, cfg.ProxySSL, cfg.ProxyBind, cfg.Buffering) + clientKey := target.URL + if cfg.ProxyBind != "" { + clientKey = target.URL + "|" + cfg.ProxyBind + } + p.clients[clientKey] = client } // 初始化代理缓存(如果启用) @@ -213,6 +219,8 @@ func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Bal virtualNodes = 150 } return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil + case lbRandom: + return loadbalance.NewRandom(), nil default: return nil, errors.New("unsupported load balance algorithm: " + name) } @@ -255,7 +263,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) { // // 返回值: // - *fasthttp.HostClient: 配置完成的 HostClient 实例 -func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig, sslCfg *config.ProxySSLConfig) *fasthttp.HostClient { +func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig, sslCfg *config.ProxySSLConfig, proxyBind string, buffering *config.ProxyBufferingConfig) *fasthttp.HostClient { // 从目标 URL 解析主机和协议 // addDefaultPort=true 确保 HostClient.Addr 包含端口(host:port 格式) addr, isTLS := netutil.ParseTargetURL(targetURL, true) @@ -287,6 +295,27 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCf SecureErrorLogMessage: false, } + // ProxyBind:使用指定本地地址作为出站连接源 + if proxyBind != "" { + localAddr := proxyBind + client.Dial = func(addr string) (net.Conn, error) { + dialer := &net.Dialer{ + LocalAddr: &net.TCPAddr{IP: net.ParseIP(localAddr)}, + Timeout: client.MaxConnWaitTimeout, + } + return dialer.Dial("tcp", addr) + } + } + + // Buffering 控制 + if buffering != nil && buffering.Mode == "off" { + client.StreamResponseBody = true + } + if buffering != nil && buffering.BufferSize > 0 { + client.ReadBufferSize = buffering.BufferSize + client.WriteBufferSize = buffering.BufferSize + } + // 上游 SSL 配置(使用原生 TLSConfig) if sslCfg != nil && sslCfg.Enabled && isTLS { tlsCfg, err := CreateTLSConfig(sslCfg, extractHostFromURL(targetURL)) @@ -534,6 +563,14 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { // SAFETY: lifetime=ephemeral - consumed immediately by SetRequestURIBytes path := ctx.URI().Path() query := ctx.URI().QueryString() + + // ProxyURI 语义:当 target.ProxyURI 设置时,替换请求路径 + // 这实现了 nginx proxy_pass URI 传递语义: + // proxy_pass http://backend/v2/ → 请求路径替换为 /v2/ + if target.ProxyURI != "" { + path = []byte(target.ProxyURI) + } + targetURI := make([]byte, 0, len(target.URL)+len(path)+len(query)+1) targetURI = append(targetURI, target.URL...) targetURI = append(targetURI, path...) @@ -642,6 +679,9 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { // 请求成功,减少连接计数 loadbalance.DecrementConnections(target) + // 记录成功,重置软失败状态 + target.RecordSuccess() + // 检测 X-Accel-Redirect 头,支持内部重定向 if redirectPath := ctx.Response.Header.Peek("X-Accel-Redirect"); len(redirectPath) > 0 { utils.SetInternalRedirect(ctx, string(redirectPath)) @@ -949,15 +989,13 @@ func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string // getClient 返回指定目标 URL 对应的 HostClient 连接池实例。 // 如果目标 URL 不存在于连接池中,返回 nil。 -// -// 参数: -// - targetURL: 后端目标 URL -// -// 返回值: -// - *fasthttp.HostClient: 对应的连接池实例 func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient { + key := targetURL + if p.config.ProxyBind != "" { + key = targetURL + "|" + p.config.ProxyBind + } p.mu.RLock() - client := p.clients[targetURL] + client := p.clients[key] p.mu.RUnlock() return client } @@ -1012,17 +1050,103 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan // 参数: // - ctx: FastHTTP 请求上下文 func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) { + respHeaders := &ctx.Response.Header + + // 构建 PassResponse 集合(多处使用) + passSet := make(map[string]bool, len(p.config.Headers.PassResponse)) + for _, h := range p.config.Headers.PassResponse { + passSet[h] = true + } + + // PassResponse 白名单模式:仅传递列出的头部 + if len(passSet) > 0 { + var toDelete []string + for key := range respHeaders.All() { + if !passSet[string(key)] { + toDelete = append(toDelete, string(key)) + } + } + for _, k := range toDelete { + respHeaders.Del(k) + } + } + + // HideResponse:移除指定的响应头(PassResponse 优先,跳过已传递的头部) + for _, key := range p.config.Headers.HideResponse { + if !passSet[key] { + respHeaders.Del(key) + } + } + + // IgnoreHeaders:从请求和响应中移除(PassResponse 优先) + for _, key := range p.config.Headers.IgnoreHeaders { + ctx.Request.Header.Del(key) + if !passSet[key] { + respHeaders.Del(key) + } + } + + // Cookie 域/路径重写 + if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" { + p.rewriteCookies(respHeaders) + } + // 从配置设置自定义响应头(支持变量展开) if p.config.Headers.SetResponse != nil { vc := variable.NewContext(ctx) defer variable.ReleaseContext(vc) for key, value := range p.config.Headers.SetResponse { expanded := vc.Expand(value) - ctx.Response.Header.Set(key, expanded) + respHeaders.Set(key, expanded) } } } +// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。 +func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) { + cookieDomain := p.config.Headers.CookieDomain + cookiePath := p.config.Headers.CookiePath + if cookieDomain == "" && cookiePath == "" { + return + } + + var cookies []string + for _, value := range respHeaders.Cookies() { + cookie := string(value) + if cookieDomain != "" { + cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain) + } + if cookiePath != "" { + cookie = rewriteCookieAttr(cookie, "Path", cookiePath) + } + cookies = append(cookies, cookie) + } + + if len(cookies) > 0 { + respHeaders.Del("Set-Cookie") + for _, c := range cookies { + respHeaders.Add("Set-Cookie", c) + } + } +} + +// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值。 +func rewriteCookieAttr(cookie, attr, newValue string) string { + prefix := attr + "=" + idx := strings.Index(cookie, prefix) + if idx == -1 { + return cookie + } + + start := idx + len(prefix) + end := start + for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' { + end++ + } + + return cookie[:start] + newValue + cookie[end:] +} + // isWebSocketRequest 检查请求是否为 WebSocket 升级请求。 // // 通过检查 Connection 和 Upgrade 请求头判断: @@ -1076,8 +1200,12 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error { continue } - client := createHostClient(target.URL, p.config.Timeout, nil, p.config.ProxySSL) - p.clients[target.URL] = client + client := createHostClient(target.URL, p.config.Timeout, nil, p.config.ProxySSL, p.config.ProxyBind, p.config.Buffering) + clientKey := target.URL + if p.config.ProxyBind != "" { + clientKey = target.URL + "|" + p.config.ProxyBind + } + p.clients[clientKey] = client } p.targets = targets diff --git a/internal/server/server.go b/internal/server/server.go index 4d004cd..d3439a3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -999,7 +999,15 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC // 转换目标 targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) for j, t := range proxyCfg.Targets { - targets[j] = loadbalance.NewTargetFromConfig(t.URL, t.Weight) + failTimeout := t.FailTimeout + if t.MaxFails > 0 && failTimeout == 0 { + failTimeout = 10 * time.Second + } + targets[j] = loadbalance.NewTargetFromConfig( + t.URL, t.Weight, + int64(t.MaxConns), int64(t.MaxFails), failTimeout, + t.Backup, t.Down, t.ProxyURI, + ) } // 传递 Transport 配置和 Lua 引擎 @@ -1122,7 +1130,15 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S // 转换目标 targets := make([]*loadbalance.Target, len(proxyCfg.Targets)) for j, t := range proxyCfg.Targets { - targets[j] = loadbalance.NewTargetFromConfig(t.URL, t.Weight) + failTimeout := t.FailTimeout + if t.MaxFails > 0 && failTimeout == 0 { + failTimeout = 10 * time.Second + } + targets[j] = loadbalance.NewTargetFromConfig( + t.URL, t.Weight, + int64(t.MaxConns), int64(t.MaxFails), failTimeout, + t.Backup, t.Down, t.ProxyURI, + ) } // 传递 Transport 配置和 Lua 引擎