feat(proxy): 实现代理缓存功能与一致性哈希支持
- 新增缓存命中/过期/刷新逻辑 - 实现缓存锁防止缓存击穿 (stale-while-revalidate) - 支持一致性哈希按 uri/ip/header 选择目标 - 新增 getProxyCacheStats 收集缓存统计 - 集成连接数限制中间件 (ConnLimiter) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
351f477822
commit
03b0df2c69
@ -234,6 +234,34 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
// 修改请求头
|
||||
p.modifyRequestHeaders(ctx, target)
|
||||
|
||||
// 尝试从缓存获取(如果启用)
|
||||
if p.cache != nil {
|
||||
cacheKey := p.buildCacheKey(ctx)
|
||||
if entry, ok, stale := p.cache.Get(cacheKey); ok {
|
||||
// 缓存命中
|
||||
if !stale {
|
||||
// 新鲜缓存,直接返回
|
||||
p.writeCachedResponse(ctx, entry)
|
||||
return
|
||||
}
|
||||
// 过期缓存,尝试后台刷新,同时返回旧数据
|
||||
go p.backgroundRefresh(ctx, target, cacheKey)
|
||||
p.writeCachedResponse(ctx, entry)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否需要缓存锁(防止缓存击穿)
|
||||
if done := p.cache.AcquireLock(cacheKey); done != nil {
|
||||
// 有其他请求正在生成缓存,等待
|
||||
<-done
|
||||
// 重新尝试获取缓存
|
||||
if entry, ok, _ := p.cache.Get(cacheKey); ok {
|
||||
p.writeCachedResponse(ctx, entry)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 执行代理请求
|
||||
err := client.Do(req, &ctx.Response)
|
||||
if err != nil {
|
||||
@ -242,6 +270,11 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
p.healthChecker.MarkUnhealthy(target)
|
||||
}
|
||||
|
||||
// 释放缓存锁
|
||||
if p.cache != nil {
|
||||
p.cache.ReleaseLock(p.buildCacheKey(ctx), err)
|
||||
}
|
||||
|
||||
// 处理不同类型的错误
|
||||
if errors.Is(err, fasthttp.ErrTimeout) {
|
||||
ctx.Error("Gateway Timeout", fasthttp.StatusGatewayTimeout)
|
||||
@ -253,12 +286,28 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) {
|
||||
return
|
||||
}
|
||||
|
||||
// 存入缓存(如果启用且响应可缓存)
|
||||
if p.cache != nil {
|
||||
cacheKey := p.buildCacheKey(ctx)
|
||||
status := ctx.Response.StatusCode()
|
||||
if status >= 200 && status < 300 {
|
||||
// 提取响应头
|
||||
headers := make(map[string]string)
|
||||
ctx.Response.Header.VisitAll(func(key, value []byte) {
|
||||
headers[string(key)] = string(value)
|
||||
})
|
||||
p.cache.Set(cacheKey, ctx.Response.Body(), headers, status, p.config.Cache.MaxAge)
|
||||
}
|
||||
p.cache.ReleaseLock(cacheKey, nil)
|
||||
}
|
||||
|
||||
// 修改响应头
|
||||
p.modifyResponseHeaders(ctx)
|
||||
}
|
||||
|
||||
// selectTarget 使用配置的负载均衡器选择后端目标。
|
||||
// 对于 IP 哈希负载均衡,从请求中提取客户端 IP。
|
||||
// 对于一致性哈希,根据配置的 hash_key 选择目标。
|
||||
// 如果没有可用的健康目标则返回 nil。
|
||||
func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
|
||||
p.mu.RLock()
|
||||
@ -276,9 +325,35 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target {
|
||||
return ipHash.SelectByIP(targets, clientIP)
|
||||
}
|
||||
|
||||
// 对于一致性哈希,根据 hash_key 配置选择
|
||||
if ch, ok := balancer.(*loadbalance.ConsistentHash); ok {
|
||||
hashKey := ch.GetHashKey()
|
||||
key := p.extractHashKey(ctx, hashKey)
|
||||
return ch.SelectByKey(targets, key)
|
||||
}
|
||||
|
||||
return balancer.Select(targets)
|
||||
}
|
||||
|
||||
// extractHashKey 根据配置提取哈希键值。
|
||||
func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string {
|
||||
switch {
|
||||
case hashKey == "ip" || hashKey == "":
|
||||
return getClientIP(ctx)
|
||||
case hashKey == "uri":
|
||||
return string(ctx.RequestURI())
|
||||
case strings.HasPrefix(hashKey, "header:"):
|
||||
headerName := strings.TrimPrefix(hashKey, "header:")
|
||||
value := ctx.Request.Header.Peek(headerName)
|
||||
if len(value) > 0 {
|
||||
return string(value)
|
||||
}
|
||||
return getClientIP(ctx) // fallback to IP
|
||||
default:
|
||||
return getClientIP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// getClientIP 从请求上下文中提取客户端 IP 地址。
|
||||
func getClientIP(ctx *fasthttp.RequestCtx) string {
|
||||
// 首先检查 X-Forwarded-For 请求头
|
||||
@ -437,3 +512,63 @@ func (p *Proxy) GetConfig() *config.ProxyConfig {
|
||||
defer p.mu.RUnlock()
|
||||
return p.config
|
||||
}
|
||||
|
||||
// buildCacheKey 构建缓存键。
|
||||
func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string {
|
||||
// 使用请求方法和路径作为缓存键
|
||||
return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI())
|
||||
}
|
||||
|
||||
// writeCachedResponse 写入缓存的响应。
|
||||
func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) {
|
||||
ctx.Response.SetBody(entry.Data)
|
||||
ctx.Response.SetStatusCode(entry.Status)
|
||||
for key, value := range entry.Headers {
|
||||
ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
ctx.Response.Header.Set("X-Cache", "HIT")
|
||||
}
|
||||
|
||||
// backgroundRefresh 后台刷新缓存。
|
||||
func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, cacheKey string) {
|
||||
// 创建新的请求上下文副本
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// 复制原始请求
|
||||
ctx.Request.CopyTo(req)
|
||||
|
||||
// 获取客户端
|
||||
client := p.getClient(target.URL)
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
err := client.Do(req, resp)
|
||||
if err != nil {
|
||||
p.cache.ReleaseLock(cacheKey, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 提取响应头
|
||||
headers := make(map[string]string)
|
||||
resp.Header.VisitAll(func(key, value []byte) {
|
||||
headers[string(key)] = string(value)
|
||||
})
|
||||
|
||||
// 更新缓存
|
||||
p.cache.Set(cacheKey, resp.Body(), headers, resp.StatusCode(), p.config.Cache.MaxAge)
|
||||
}
|
||||
|
||||
// GetCacheStats 返回代理缓存的统计信息。
|
||||
// 如果缓存未启用,返回 nil。
|
||||
func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats {
|
||||
if p.cache == nil {
|
||||
return nil
|
||||
}
|
||||
stats := p.cache.Stats()
|
||||
return &stats
|
||||
}
|
||||
|
||||
@ -69,6 +69,9 @@ type Server struct {
|
||||
// healthCheckers 健康检查器列表,用于检查代理目标健康状态
|
||||
healthCheckers []*proxy.HealthChecker
|
||||
|
||||
// proxies 代理实例列表,用于收集缓存统计
|
||||
proxies []*proxy.Proxy
|
||||
|
||||
// accessLogMiddleware 访问日志中间件,记录请求详细信息
|
||||
accessLogMiddleware *accesslog.AccessLog
|
||||
|
||||
@ -217,6 +220,15 @@ func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middlewa
|
||||
middlewares = append(middlewares, rl)
|
||||
}
|
||||
|
||||
// 3.5 Security: ConnLimiter (连接数限制)
|
||||
if serverCfg.Security.RateLimit.ConnLimit > 0 {
|
||||
cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建连接限制中间件失败: %w", err)
|
||||
}
|
||||
middlewares = append(middlewares, cl.Middleware())
|
||||
}
|
||||
|
||||
// 4. Security: BasicAuth (认证)
|
||||
if len(serverCfg.Security.Auth.Users) > 0 {
|
||||
auth, err := security.NewBasicAuth(&serverCfg.Security.Auth)
|
||||
@ -339,6 +351,10 @@ func (s *Server) startSingleMode() error {
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
}
|
||||
// 设置预压缩文件支持
|
||||
if s.config.Server.Compression.GzipStatic {
|
||||
staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions)
|
||||
}
|
||||
router.GET("/{filepath:*}", staticHandler.Handle)
|
||||
router.HEAD("/{filepath:*}", staticHandler.Handle)
|
||||
|
||||
@ -417,6 +433,10 @@ func (s *Server) startVHostMode() error {
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
}
|
||||
// 设置预压缩文件支持
|
||||
if s.config.Servers[i].Compression.GzipStatic {
|
||||
staticHandler.SetGzipStatic(true, s.config.Servers[i].Compression.GzipStaticExtensions)
|
||||
}
|
||||
router.GET("/{filepath:*}", staticHandler.Handle)
|
||||
router.HEAD("/{filepath:*}", staticHandler.Handle)
|
||||
|
||||
@ -457,6 +477,10 @@ func (s *Server) startVHostMode() error {
|
||||
if s.fileCache != nil {
|
||||
staticHandler.SetFileCache(s.fileCache)
|
||||
}
|
||||
// 设置预压缩文件支持
|
||||
if s.config.Server.Compression.GzipStatic {
|
||||
staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions)
|
||||
}
|
||||
router.GET("/{filepath:*}", staticHandler.Handle)
|
||||
|
||||
chain, err := s.buildMiddlewareChain(&s.config.Server)
|
||||
@ -550,6 +574,9 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S
|
||||
p.SetHealthChecker(hc)
|
||||
}
|
||||
|
||||
// 保存代理实例用于缓存统计
|
||||
s.proxies = append(s.proxies, p)
|
||||
|
||||
router.GET(proxyCfg.Path, p.ServeHTTP)
|
||||
router.POST(proxyCfg.Path, p.ServeHTTP)
|
||||
router.PUT(proxyCfg.Path, p.ServeHTTP)
|
||||
@ -653,3 +680,15 @@ func (s *Server) GracefulStop(timeout time.Duration) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProxyCacheStats 收集所有代理缓存的统计信息。
|
||||
func (s *Server) getProxyCacheStats() ProxyCacheStats {
|
||||
var total ProxyCacheStats
|
||||
for _, p := range s.proxies {
|
||||
if stats := p.GetCacheStats(); stats != nil {
|
||||
total.Entries += stats.Entries
|
||||
total.Pending += stats.Pending
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
@ -250,6 +250,12 @@ func (h *StatusHandler) collectStatus() *Status {
|
||||
MaxSize: stats.MaxSize,
|
||||
},
|
||||
}
|
||||
|
||||
// 收集代理缓存统计
|
||||
proxyCacheStats := h.server.getProxyCacheStats()
|
||||
if proxyCacheStats.Entries > 0 || proxyCacheStats.Pending > 0 {
|
||||
status.Cache.ProxyCache = proxyCacheStats
|
||||
}
|
||||
}
|
||||
|
||||
// 收集 Goroutine 池统计
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user