feat(proxy): 实现代理缓存功能与一致性哈希支持

- 新增缓存命中/过期/刷新逻辑
- 实现缓存锁防止缓存击穿 (stale-while-revalidate)
- 支持一致性哈希按 uri/ip/header 选择目标
- 新增 getProxyCacheStats 收集缓存统计
- 集成连接数限制中间件 (ConnLimiter)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-03 16:57:48 +08:00
parent 351f477822
commit 03b0df2c69
3 changed files with 180 additions and 0 deletions

View File

@ -234,6 +234,34 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
// 修改请求头
p.modifyRequestHeaders(ctx, target)
// 尝试从缓存获取(如果启用)
if p.cache != nil {
cacheKey := p.buildCacheKey(ctx)
if entry, ok, stale := p.cache.Get(cacheKey); ok {
// 缓存命中
if !stale {
// 新鲜缓存,直接返回
p.writeCachedResponse(ctx, entry)
return
}
// 过期缓存,尝试后台刷新,同时返回旧数据
go p.backgroundRefresh(ctx, target, cacheKey)
p.writeCachedResponse(ctx, entry)
return
}
// 检查是否需要缓存锁(防止缓存击穿)
if done := p.cache.AcquireLock(cacheKey); done != nil {
// 有其他请求正在生成缓存,等待
<-done
// 重新尝试获取缓存
if entry, ok, _ := p.cache.Get(cacheKey); ok {
p.writeCachedResponse(ctx, entry)
return
}
}
}
// 执行代理请求
err := client.Do(req, &ctx.Response)
if err != nil {
@ -242,6 +270,11 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
p.healthChecker.MarkUnhealthy(target)
}
// 释放缓存锁
if p.cache != nil {
p.cache.ReleaseLock(p.buildCacheKey(ctx), err)
}
// 处理不同类型的错误
if errors.Is(err, fasthttp.ErrTimeout) {
ctx.Error("Gateway Timeout", fasthttp.StatusGatewayTimeout)
@ -253,12 +286,28 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
return
}
// 存入缓存(如果启用且响应可缓存)
if p.cache != nil {
cacheKey := p.buildCacheKey(ctx)
status := ctx.Response.StatusCode()
if status >= 200 && status < 300 {
// 提取响应头
headers := make(map[string]string)
ctx.Response.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
p.cache.Set(cacheKey, ctx.Response.Body(), headers, status, p.config.Cache.MaxAge)
}
p.cache.ReleaseLock(cacheKey, nil)
}
// 修改响应头
p.modifyResponseHeaders(ctx)
}
// selectTarget 使用配置的负载均衡器选择后端目标。
// 对于 IP 哈希负载均衡,从请求中提取客户端 IP。
// 对于一致性哈希,根据配置的 hash_key 选择目标。
// 如果没有可用的健康目标则返回 nil。
func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
p.mu.RLock()
@ -276,9 +325,35 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
return ipHash.SelectByIP(targets, clientIP)
}
// 对于一致性哈希,根据 hash_key 配置选择
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
hashKey := ch.GetHashKey()
key := p.extractHashKey(ctx, hashKey)
return ch.SelectByKey(targets, key)
}
return balancer.Select(targets)
}
// extractHashKey 根据配置提取哈希键值。
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
switch {
case hashKey == "ip" || hashKey == "":
return getClientIP(ctx)
case hashKey == "uri":
return string(ctx.RequestURI())
case strings.HasPrefix(hashKey, "header:"):
headerName := strings.TrimPrefix(hashKey, "header:")
value := ctx.Request.Header.Peek(headerName)
if len(value) > 0 {
return string(value)
}
return getClientIP(ctx) // fallback to IP
default:
return getClientIP(ctx)
}
}
// getClientIP 从请求上下文中提取客户端 IP 地址。
func getClientIP(ctx *fasthttp.RequestCtx) string {
// 首先检查 X-Forwarded-For 请求头
@ -437,3 +512,63 @@ func (p *Proxy) GetConfig() *config.ProxyConfig {
defer p.mu.RUnlock()
return p.config
}
// buildCacheKey 构建缓存键。
func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string {
// 使用请求方法和路径作为缓存键
return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI())
}
// writeCachedResponse 写入缓存的响应。
func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) {
ctx.Response.SetBody(entry.Data)
ctx.Response.SetStatusCode(entry.Status)
for key, value := range entry.Headers {
ctx.Response.Header.Set(key, value)
}
ctx.Response.Header.Set("X-Cache", "HIT")
}
// backgroundRefresh 后台刷新缓存。
func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, cacheKey string) {
// 创建新的请求上下文副本
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// 复制原始请求
ctx.Request.CopyTo(req)
// 获取客户端
client := p.getClient(target.URL)
if client == nil {
return
}
// 执行请求
err := client.Do(req, resp)
if err != nil {
p.cache.ReleaseLock(cacheKey, err)
return
}
// 提取响应头
headers := make(map[string]string)
resp.Header.VisitAll(func(key, value []byte) {
headers[string(key)] = string(value)
})
// 更新缓存
p.cache.Set(cacheKey, resp.Body(), headers, resp.StatusCode(), p.config.Cache.MaxAge)
}
// GetCacheStats 返回代理缓存的统计信息。
// 如果缓存未启用,返回 nil。
func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats {
if p.cache == nil {
return nil
}
stats := p.cache.Stats()
return &stats
}

View File

@ -69,6 +69,9 @@ type Server struct {
// healthCheckers 健康检查器列表,用于检查代理目标健康状态
healthCheckers []*proxy.HealthChecker
// proxies 代理实例列表,用于收集缓存统计
proxies []*proxy.Proxy
// accessLogMiddleware 访问日志中间件,记录请求详细信息
accessLogMiddleware *accesslog.AccessLog
@ -217,6 +220,15 @@ func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middlewa
middlewares = append(middlewares, rl)
}
// 3.5 Security: ConnLimiter (连接数限制)
if serverCfg.Security.RateLimit.ConnLimit > 0 {
cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key)
if err != nil {
return nil, fmt.Errorf("创建连接限制中间件失败: %w", err)
}
middlewares = append(middlewares, cl.Middleware())
}
// 4. Security: BasicAuth (认证)
if len(serverCfg.Security.Auth.Users) > 0 {
auth, err := security.NewBasicAuth(&serverCfg.Security.Auth)
@ -339,6 +351,10 @@ func (s *Server) startSingleMode() error {
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
}
// 设置预压缩文件支持
if s.config.Server.Compression.GzipStatic {
staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions)
}
router.GET("/{filepath:*}", staticHandler.Handle)
router.HEAD("/{filepath:*}", staticHandler.Handle)
@ -417,6 +433,10 @@ func (s *Server) startVHostMode() error {
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
}
// 设置预压缩文件支持
if s.config.Servers[i].Compression.GzipStatic {
staticHandler.SetGzipStatic(true, s.config.Servers[i].Compression.GzipStaticExtensions)
}
router.GET("/{filepath:*}", staticHandler.Handle)
router.HEAD("/{filepath:*}", staticHandler.Handle)
@ -457,6 +477,10 @@ func (s *Server) startVHostMode() error {
if s.fileCache != nil {
staticHandler.SetFileCache(s.fileCache)
}
// 设置预压缩文件支持
if s.config.Server.Compression.GzipStatic {
staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions)
}
router.GET("/{filepath:*}", staticHandler.Handle)
chain, err := s.buildMiddlewareChain(&s.config.Server)
@ -550,6 +574,9 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S
p.SetHealthChecker(hc)
}
// 保存代理实例用于缓存统计
s.proxies = append(s.proxies, p)
router.GET(proxyCfg.Path, p.ServeHTTP)
router.POST(proxyCfg.Path, p.ServeHTTP)
router.PUT(proxyCfg.Path, p.ServeHTTP)
@ -653,3 +680,15 @@ func (s *Server) GracefulStop(timeout time.Duration) error {
}
return nil
}
// getProxyCacheStats 收集所有代理缓存的统计信息。
func (s *Server) getProxyCacheStats() ProxyCacheStats {
var total ProxyCacheStats
for _, p := range s.proxies {
if stats := p.GetCacheStats(); stats != nil {
total.Entries += stats.Entries
total.Pending += stats.Pending
}
}
return total
}

View File

@ -250,6 +250,12 @@ func (h *StatusHandler) collectStatus() *Status {
MaxSize: stats.MaxSize,
},
}
// 收集代理缓存统计
proxyCacheStats := h.server.getProxyCacheStats()
if proxyCacheStats.Entries > 0 || proxyCacheStats.Pending > 0 {
status.Cache.ProxyCache = proxyCacheStats
}
}
// 收集 Goroutine 池统计