diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 3fe2fde..a95cc25 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -234,6 +234,34 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { // 修改请求头 p.modifyRequestHeaders(ctx, target) + // 尝试从缓存获取(如果启用) + if p.cache != nil { + cacheKey := p.buildCacheKey(ctx) + if entry, ok, stale := p.cache.Get(cacheKey); ok { + // 缓存命中 + if !stale { + // 新鲜缓存,直接返回 + p.writeCachedResponse(ctx, entry) + return + } + // 过期缓存,尝试后台刷新,同时返回旧数据 + go p.backgroundRefresh(ctx, target, cacheKey) + p.writeCachedResponse(ctx, entry) + return + } + + // 检查是否需要缓存锁(防止缓存击穿) + if done := p.cache.AcquireLock(cacheKey); done != nil { + // 有其他请求正在生成缓存,等待 + <-done + // 重新尝试获取缓存 + if entry, ok, _ := p.cache.Get(cacheKey); ok { + p.writeCachedResponse(ctx, entry) + return + } + } + } + // 执行代理请求 err := client.Do(req, &ctx.Response) if err != nil { @@ -242,6 +270,11 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { p.healthChecker.MarkUnhealthy(target) } + // 释放缓存锁 + if p.cache != nil { + p.cache.ReleaseLock(p.buildCacheKey(ctx), err) + } + // 处理不同类型的错误 if errors.Is(err, fasthttp.ErrTimeout) { ctx.Error("Gateway Timeout", fasthttp.StatusGatewayTimeout) @@ -253,12 +286,28 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { return } + // 存入缓存(如果启用且响应可缓存) + if p.cache != nil { + cacheKey := p.buildCacheKey(ctx) + status := ctx.Response.StatusCode() + if status >= 200 && status < 300 { + // 提取响应头 + headers := make(map[string]string) + ctx.Response.Header.VisitAll(func(key, value []byte) { + headers[string(key)] = string(value) + }) + p.cache.Set(cacheKey, ctx.Response.Body(), headers, status, p.config.Cache.MaxAge) + } + p.cache.ReleaseLock(cacheKey, nil) + } + // 修改响应头 p.modifyResponseHeaders(ctx) } // selectTarget 使用配置的负载均衡器选择后端目标。 // 对于 IP 哈希负载均衡,从请求中提取客户端 IP。 +// 对于一致性哈希,根据配置的 hash_key 选择目标。 // 如果没有可用的健康目标则返回 nil。 func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target { p.mu.RLock() @@ -276,9 +325,35 @@ func (p *Proxy) selectTarget(ctx *fasthttp.RequestCtx) *loadbalance.Target { return ipHash.SelectByIP(targets, clientIP) } + // 对于一致性哈希,根据 hash_key 配置选择 + if ch, ok := balancer.(*loadbalance.ConsistentHash); ok { + hashKey := ch.GetHashKey() + key := p.extractHashKey(ctx, hashKey) + return ch.SelectByKey(targets, key) + } + return balancer.Select(targets) } +// extractHashKey 根据配置提取哈希键值。 +func (p *Proxy) extractHashKey(ctx *fasthttp.RequestCtx, hashKey string) string { + switch { + case hashKey == "ip" || hashKey == "": + return getClientIP(ctx) + case hashKey == "uri": + return string(ctx.RequestURI()) + case strings.HasPrefix(hashKey, "header:"): + headerName := strings.TrimPrefix(hashKey, "header:") + value := ctx.Request.Header.Peek(headerName) + if len(value) > 0 { + return string(value) + } + return getClientIP(ctx) // fallback to IP + default: + return getClientIP(ctx) + } +} + // getClientIP 从请求上下文中提取客户端 IP 地址。 func getClientIP(ctx *fasthttp.RequestCtx) string { // 首先检查 X-Forwarded-For 请求头 @@ -437,3 +512,63 @@ func (p *Proxy) GetConfig() *config.ProxyConfig { defer p.mu.RUnlock() return p.config } + +// buildCacheKey 构建缓存键。 +func (p *Proxy) buildCacheKey(ctx *fasthttp.RequestCtx) string { + // 使用请求方法和路径作为缓存键 + return string(ctx.Request.Header.Method()) + ":" + string(ctx.Request.URI().RequestURI()) +} + +// writeCachedResponse 写入缓存的响应。 +func (p *Proxy) writeCachedResponse(ctx *fasthttp.RequestCtx, entry *cache.ProxyCacheEntry) { + ctx.Response.SetBody(entry.Data) + ctx.Response.SetStatusCode(entry.Status) + for key, value := range entry.Headers { + ctx.Response.Header.Set(key, value) + } + ctx.Response.Header.Set("X-Cache", "HIT") +} + +// backgroundRefresh 后台刷新缓存。 +func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance.Target, cacheKey string) { + // 创建新的请求上下文副本 + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // 复制原始请求 + ctx.Request.CopyTo(req) + + // 获取客户端 + client := p.getClient(target.URL) + if client == nil { + return + } + + // 执行请求 + err := client.Do(req, resp) + if err != nil { + p.cache.ReleaseLock(cacheKey, err) + return + } + + // 提取响应头 + headers := make(map[string]string) + resp.Header.VisitAll(func(key, value []byte) { + headers[string(key)] = string(value) + }) + + // 更新缓存 + p.cache.Set(cacheKey, resp.Body(), headers, resp.StatusCode(), p.config.Cache.MaxAge) +} + +// GetCacheStats 返回代理缓存的统计信息。 +// 如果缓存未启用,返回 nil。 +func (p *Proxy) GetCacheStats() *cache.ProxyCacheStats { + if p.cache == nil { + return nil + } + stats := p.cache.Stats() + return &stats +} diff --git a/internal/server/server.go b/internal/server/server.go index 80fb9b5..fe0e9a3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -69,6 +69,9 @@ type Server struct { // healthCheckers 健康检查器列表,用于检查代理目标健康状态 healthCheckers []*proxy.HealthChecker + // proxies 代理实例列表,用于收集缓存统计 + proxies []*proxy.Proxy + // accessLogMiddleware 访问日志中间件,记录请求详细信息 accessLogMiddleware *accesslog.AccessLog @@ -217,6 +220,15 @@ func (s *Server) buildMiddlewareChain(serverCfg *config.ServerConfig) (*middlewa middlewares = append(middlewares, rl) } + // 3.5 Security: ConnLimiter (连接数限制) + if serverCfg.Security.RateLimit.ConnLimit > 0 { + cl, err := security.NewConnLimiter(serverCfg.Security.RateLimit.ConnLimit, true, serverCfg.Security.RateLimit.Key) + if err != nil { + return nil, fmt.Errorf("创建连接限制中间件失败: %w", err) + } + middlewares = append(middlewares, cl.Middleware()) + } + // 4. Security: BasicAuth (认证) if len(serverCfg.Security.Auth.Users) > 0 { auth, err := security.NewBasicAuth(&serverCfg.Security.Auth) @@ -339,6 +351,10 @@ func (s *Server) startSingleMode() error { if s.fileCache != nil { staticHandler.SetFileCache(s.fileCache) } + // 设置预压缩文件支持 + if s.config.Server.Compression.GzipStatic { + staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions) + } router.GET("/{filepath:*}", staticHandler.Handle) router.HEAD("/{filepath:*}", staticHandler.Handle) @@ -417,6 +433,10 @@ func (s *Server) startVHostMode() error { if s.fileCache != nil { staticHandler.SetFileCache(s.fileCache) } + // 设置预压缩文件支持 + if s.config.Servers[i].Compression.GzipStatic { + staticHandler.SetGzipStatic(true, s.config.Servers[i].Compression.GzipStaticExtensions) + } router.GET("/{filepath:*}", staticHandler.Handle) router.HEAD("/{filepath:*}", staticHandler.Handle) @@ -457,6 +477,10 @@ func (s *Server) startVHostMode() error { if s.fileCache != nil { staticHandler.SetFileCache(s.fileCache) } + // 设置预压缩文件支持 + if s.config.Server.Compression.GzipStatic { + staticHandler.SetGzipStatic(true, s.config.Server.Compression.GzipStaticExtensions) + } router.GET("/{filepath:*}", staticHandler.Handle) chain, err := s.buildMiddlewareChain(&s.config.Server) @@ -550,6 +574,9 @@ func (s *Server) registerProxyRoutes(router *handler.Router, serverCfg *config.S p.SetHealthChecker(hc) } + // 保存代理实例用于缓存统计 + s.proxies = append(s.proxies, p) + router.GET(proxyCfg.Path, p.ServeHTTP) router.POST(proxyCfg.Path, p.ServeHTTP) router.PUT(proxyCfg.Path, p.ServeHTTP) @@ -653,3 +680,15 @@ func (s *Server) GracefulStop(timeout time.Duration) error { } return nil } + +// getProxyCacheStats 收集所有代理缓存的统计信息。 +func (s *Server) getProxyCacheStats() ProxyCacheStats { + var total ProxyCacheStats + for _, p := range s.proxies { + if stats := p.GetCacheStats(); stats != nil { + total.Entries += stats.Entries + total.Pending += stats.Pending + } + } + return total +} diff --git a/internal/server/status.go b/internal/server/status.go index f3f6ab7..bd7166e 100644 --- a/internal/server/status.go +++ b/internal/server/status.go @@ -250,6 +250,12 @@ func (h *StatusHandler) collectStatus() *Status { MaxSize: stats.MaxSize, }, } + + // 收集代理缓存统计 + proxyCacheStats := h.server.getProxyCacheStats() + if proxyCacheStats.Entries > 0 || proxyCacheStats.Pending > 0 { + status.Cache.ProxyCache = proxyCacheStats + } } // 收集 Goroutine 池统计