feat(proxy,server): 实现 ProxyBind/Buffering/ProxyURI 和响应头控制

ProxyBind 支持指定本地地址出站连接,Buffering 控制响应缓冲模式,
ProxyURI 实现 nginx proxy_pass URI 替换语义,
响应头新增 HideResponse/PassResponse/IgnoreHeaders/Cookie 域路径重写,
健康检查集成 RecordFailure/RecordSuccess 软失败状态。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-21 11:28:32 +08:00
parent c0b7e30bf0
commit 2b63666ea5
3 changed files with 165 additions and 33 deletions

View File

@ -210,33 +210,21 @@ func (h *HealthChecker) checkTarget(target *loadbalance.Target) {
// 此方法用于被动健康检查,代理根据请求处理过程中
// 观察到的失败将目标标记为不健康。
//
// 在代理错误处理中的使用示例:
//
// if err := forwardRequest(target, req, resp); err != nil {
// healthChecker.MarkUnhealthy(target)
// // 尝试其他目标或返回错误
// }
//
// 注意:要再次将目标标记为健康,主动健康检查
// 必须成功。没有 MarkHealthy 方法 - 健康状态只能通过
// 成功的健康检查积极恢复。
// 同时调用 RecordFailure 记录软失败状态,配合 MaxFails/FailTimeout
// 实现失败计数和冷却机制。
func (h *HealthChecker) MarkUnhealthy(target *loadbalance.Target) {
target.Healthy.Store(false)
target.RecordFailure()
}
// MarkHealthy 将目标标记为健康。
// 此方法用于故障转移成功后,将之前失败的目标恢复为健康状态。
//
// 在故障转移成功后的使用示例:
//
// if err := retryRequest(target, req, resp); err == nil {
// healthChecker.MarkHealthy(target)
// }
//
// 注意:此方法与主动健康检查独立运作,用于快速恢复
// 故障转移场景中已恢复的目标。
// 同时调用 RecordSuccess 重置软失败状态failCount/failedUntil
// 但不修改 Healthy 标志——健康检查器对 Healthy 拥有权威。
func (h *HealthChecker) MarkHealthy(target *loadbalance.Target) {
target.Healthy.Store(true)
target.RecordSuccess()
}
// IsRunning 如果健康检查器当前正在运行,则返回 true。

View File

@ -36,6 +36,7 @@ import (
"errors"
"fmt"
"hash/fnv"
"net"
"slices"
"strings"
"sync"
@ -66,6 +67,7 @@ const (
lbLeastConn = "least_conn" // 最少连接
lbIPHash = "ip_hash" // IP 哈希
lbConsistentHash = "consistent_hash" // 一致性哈希
lbRandom = "random" // 随机Power of Two Choices
)
// headersPool 复用缓存 headers map减少分配。
@ -155,8 +157,12 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC
continue
}
client := createHostClient(target.URL, cfg.Timeout, transportCfg, cfg.ProxySSL)
p.clients[target.URL] = client
client := createHostClient(target.URL, cfg.Timeout, transportCfg, cfg.ProxySSL, cfg.ProxyBind, cfg.Buffering)
clientKey := target.URL
if cfg.ProxyBind != "" {
clientKey = target.URL + "|" + cfg.ProxyBind
}
p.clients[clientKey] = client
}
// 初始化代理缓存(如果启用)
@ -213,6 +219,8 @@ func createBalancerByName(name string, cfg *config.ProxyConfig) (loadbalance.Bal
virtualNodes = 150
}
return loadbalance.NewConsistentHash(virtualNodes, cfg.HashKey), nil
case lbRandom:
return loadbalance.NewRandom(), nil
default:
return nil, errors.New("unsupported load balance algorithm: " + name)
}
@ -255,7 +263,7 @@ func createBalancer(cfg *config.ProxyConfig) (loadbalance.Balancer, error) {
//
// 返回值:
// - *fasthttp.HostClient: 配置完成的 HostClient 实例
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig, sslCfg *config.ProxySSLConfig) *fasthttp.HostClient {
func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCfg *config.TransportConfig, sslCfg *config.ProxySSLConfig, proxyBind string, buffering *config.ProxyBufferingConfig) *fasthttp.HostClient {
// 从目标 URL 解析主机和协议
// addDefaultPort=true 确保 HostClient.Addr 包含端口host:port 格式)
addr, isTLS := netutil.ParseTargetURL(targetURL, true)
@ -287,6 +295,27 @@ func createHostClient(targetURL string, timeout config.ProxyTimeout, transportCf
SecureErrorLogMessage: false,
}
// ProxyBind使用指定本地地址作为出站连接源
if proxyBind != "" {
localAddr := proxyBind
client.Dial = func(addr string) (net.Conn, error) {
dialer := &net.Dialer{
LocalAddr: &net.TCPAddr{IP: net.ParseIP(localAddr)},
Timeout: client.MaxConnWaitTimeout,
}
return dialer.Dial("tcp", addr)
}
}
// Buffering 控制
if buffering != nil && buffering.Mode == "off" {
client.StreamResponseBody = true
}
if buffering != nil && buffering.BufferSize > 0 {
client.ReadBufferSize = buffering.BufferSize
client.WriteBufferSize = buffering.BufferSize
}
// 上游 SSL 配置(使用原生 TLSConfig
if sslCfg != nil && sslCfg.Enabled && isTLS {
tlsCfg, err := CreateTLSConfig(sslCfg, extractHostFromURL(targetURL))
@ -534,6 +563,14 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// SAFETY: lifetime=ephemeral - consumed immediately by SetRequestURIBytes
path := ctx.URI().Path()
query := ctx.URI().QueryString()
// ProxyURI 语义:当 target.ProxyURI 设置时,替换请求路径
// 这实现了 nginx proxy_pass URI 传递语义:
// proxy_pass http://backend/v2/ → 请求路径替换为 /v2/
if target.ProxyURI != "" {
path = []byte(target.ProxyURI)
}
targetURI := make([]byte, 0, len(target.URL)+len(path)+len(query)+1)
targetURI = append(targetURI, target.URL...)
targetURI = append(targetURI, path...)
@ -642,6 +679,9 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 请求成功,减少连接计数
loadbalance.DecrementConnections(target)
// 记录成功,重置软失败状态
target.RecordSuccess()
// 检测 X-Accel-Redirect 头,支持内部重定向
if redirectPath := ctx.Response.Header.Peek("X-Accel-Redirect"); len(redirectPath) > 0 {
utils.SetInternalRedirect(ctx, string(redirectPath))
@ -949,15 +989,13 @@ func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string
// getClient 返回指定目标 URL 对应的 HostClient 连接池实例。
// 如果目标 URL 不存在于连接池中,返回 nil。
//
// 参数:
// - targetURL: 后端目标 URL
//
// 返回值:
// - *fasthttp.HostClient: 对应的连接池实例
func (p *Proxy) getClient(targetURL string) *fasthttp.HostClient {
key := targetURL
if p.config.ProxyBind != "" {
key = targetURL + "|" + p.config.ProxyBind
}
p.mu.RLock()
client := p.clients[targetURL]
client := p.clients[key]
p.mu.RUnlock()
return client
}
@ -1012,17 +1050,103 @@ func (p *Proxy) modifyRequestHeaders(ctx *fasthttp.RequestCtx, target *loadbalan
// 参数:
// - ctx: FastHTTP 请求上下文
func (p *Proxy) modifyResponseHeaders(ctx *fasthttp.RequestCtx) {
respHeaders := &ctx.Response.Header
// 构建 PassResponse 集合(多处使用)
passSet := make(map[string]bool, len(p.config.Headers.PassResponse))
for _, h := range p.config.Headers.PassResponse {
passSet[h] = true
}
// PassResponse 白名单模式:仅传递列出的头部
if len(passSet) > 0 {
var toDelete []string
for key := range respHeaders.All() {
if !passSet[string(key)] {
toDelete = append(toDelete, string(key))
}
}
for _, k := range toDelete {
respHeaders.Del(k)
}
}
// HideResponse移除指定的响应头PassResponse 优先,跳过已传递的头部)
for _, key := range p.config.Headers.HideResponse {
if !passSet[key] {
respHeaders.Del(key)
}
}
// IgnoreHeaders从请求和响应中移除PassResponse 优先)
for _, key := range p.config.Headers.IgnoreHeaders {
ctx.Request.Header.Del(key)
if !passSet[key] {
respHeaders.Del(key)
}
}
// Cookie 域/路径重写
if p.config.Headers.CookieDomain != "" || p.config.Headers.CookiePath != "" {
p.rewriteCookies(respHeaders)
}
// 从配置设置自定义响应头(支持变量展开)
if p.config.Headers.SetResponse != nil {
vc := variable.NewContext(ctx)
defer variable.ReleaseContext(vc)
for key, value := range p.config.Headers.SetResponse {
expanded := vc.Expand(value)
ctx.Response.Header.Set(key, expanded)
respHeaders.Set(key, expanded)
}
}
}
// rewriteCookies 重写响应中 Set-Cookie 头的 domain 和 path。
func (p *Proxy) rewriteCookies(respHeaders *fasthttp.ResponseHeader) {
cookieDomain := p.config.Headers.CookieDomain
cookiePath := p.config.Headers.CookiePath
if cookieDomain == "" && cookiePath == "" {
return
}
var cookies []string
for _, value := range respHeaders.Cookies() {
cookie := string(value)
if cookieDomain != "" {
cookie = rewriteCookieAttr(cookie, "Domain", cookieDomain)
}
if cookiePath != "" {
cookie = rewriteCookieAttr(cookie, "Path", cookiePath)
}
cookies = append(cookies, cookie)
}
if len(cookies) > 0 {
respHeaders.Del("Set-Cookie")
for _, c := range cookies {
respHeaders.Add("Set-Cookie", c)
}
}
}
// rewriteCookieAttr 替换 Cookie 字符串中指定属性的值。
func rewriteCookieAttr(cookie, attr, newValue string) string {
prefix := attr + "="
idx := strings.Index(cookie, prefix)
if idx == -1 {
return cookie
}
start := idx + len(prefix)
end := start
for end < len(cookie) && cookie[end] != ';' && cookie[end] != ' ' {
end++
}
return cookie[:start] + newValue + cookie[end:]
}
// isWebSocketRequest 检查请求是否为 WebSocket 升级请求。
//
// 通过检查 Connection 和 Upgrade 请求头判断:
@ -1076,8 +1200,12 @@ func (p *Proxy) UpdateTargets(targets []*loadbalance.Target) error {
continue
}
client := createHostClient(target.URL, p.config.Timeout, nil, p.config.ProxySSL)
p.clients[target.URL] = client
client := createHostClient(target.URL, p.config.Timeout, nil, p.config.ProxySSL, p.config.ProxyBind, p.config.Buffering)
clientKey := target.URL
if p.config.ProxyBind != "" {
clientKey = target.URL + "|" + p.config.ProxyBind
}
p.clients[clientKey] = client
}
p.targets = targets

View File

@ -999,7 +999,15 @@ func (s *Server) registerProxyRoutesWithLocationEngine(serverCfg *config.ServerC
// 转换目标
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
for j, t := range proxyCfg.Targets {
targets[j] = loadbalance.NewTargetFromConfig(t.URL, t.Weight)
failTimeout := t.FailTimeout
if t.MaxFails > 0 && failTimeout == 0 {
failTimeout = 10 * time.Second
}
targets[j] = loadbalance.NewTargetFromConfig(
t.URL, t.Weight,
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
t.Backup, t.Down, t.ProxyURI,
)
}
// 传递 Transport 配置和 Lua 引擎
@ -1122,7 +1130,15 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S
// 转换目标
targets := make([]*loadbalance.Target, len(proxyCfg.Targets))
for j, t := range proxyCfg.Targets {
targets[j] = loadbalance.NewTargetFromConfig(t.URL, t.Weight)
failTimeout := t.FailTimeout
if t.MaxFails > 0 && failTimeout == 0 {
failTimeout = 10 * time.Second
}
targets[j] = loadbalance.NewTargetFromConfig(
t.URL, t.Weight,
int64(t.MaxConns), int64(t.MaxFails), failTimeout,
t.Backup, t.Down, t.ProxyURI,
)
}
// 传递 Transport 配置和 Lua 引擎