diff --git a/internal/cache/file_cache.go b/internal/cache/file_cache.go index 7e184c0..564981d 100644 --- a/internal/cache/file_cache.go +++ b/internal/cache/file_cache.go @@ -22,6 +22,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" ) @@ -293,13 +294,18 @@ type ProxyCacheRule struct { // ProxyCacheEntry 代理缓存条目。 type ProxyCacheEntry struct { - Created time.Time - Headers map[string]string - Key string - OrigKey string - Data []byte - Status int - MaxAge time.Duration + Created time.Time + Headers map[string]string + Key string + OrigKey string + Data []byte + Status int + MaxAge time.Duration + Uses atomic.Int32 // 访问计数,用于 min_uses 阈值检查 + Updating atomic.Bool // 后台更新标志,表示正在后台刷新 + LastModified string // Last-Modified 响应头,用于条件请求 + ETag string // ETag 响应头,用于条件请求 + LastValidated time.Time // 最后验证时间,用于防止验证循环 } // ProxyCache 代理响应缓存,支持缓存锁防击穿。 @@ -356,6 +362,9 @@ func (c *ProxyCache) Get(hashKey uint64, origKey string) (*ProxyCacheEntry, bool return nil, false, false } + // 增加访问计数(原子操作,用于 min_uses 阈值检查) + entry.Uses.Add(1) + // 检查是否过期 now := time.Now() expired := now.Sub(entry.Created) > entry.MaxAge @@ -423,6 +432,50 @@ func (c *ProxyCache) AcquireLock(hashKey uint64) <-chan struct{} { return nil // 获得锁,应该生成缓存 } +// AcquireLockWithTimeout 获取缓存生成锁(带超时)。 +// 返回值: +// - waitCh != nil && timedOut == false: 需要等待其他请求完成 +// - waitCh == nil && timedOut == false: 获得锁,应该生成缓存 +// - timedOut == true: 超时,应该放弃缓存直接请求上游 +func (c *ProxyCache) AcquireLockWithTimeout(hashKey uint64, timeout time.Duration) (waitCh <-chan struct{}, timedOut bool) { + if !c.cacheLock { + return nil, false // 不使用缓存锁 + } + + c.mu.Lock() + // 检查是否已有缓存 + if _, ok := c.entries[hashKey]; ok { + c.mu.Unlock() + return nil, false + } + + // 检查是否有 pending 请求 + if pending, ok := c.pending[hashKey]; ok { + c.mu.Unlock() + // 有其他请求正在生成,需要等待 + if timeout > 0 { + // 带超时等待 + select { + case <-pending.done: + // 刚刚完成,重新检查缓存 + return nil, false + case <-time.After(timeout): + // 超时 + return nil, true + } + } + return pending.done, false // 无限等待 + } + + // 创建新的 pending 请求 + pending := &pendingRequest{ + done: make(chan struct{}), + } + c.pending[hashKey] = pending + c.mu.Unlock() + return nil, false // 获得锁,应该生成缓存 +} + // ReleaseLock 释放缓存生成锁。 func (c *ProxyCache) ReleaseLock(hashKey uint64, err error) { if !c.cacheLock { @@ -442,9 +495,25 @@ func (c *ProxyCache) ReleaseLock(hashKey uint64, err error) { // MatchRule 检查请求是否匹配缓存规则。 func (c *ProxyCache) MatchRule(path, method string, status int) *ProxyCacheRule { for _, rule := range c.rules { - // 检查路径匹配(简单前缀匹配) - if rule.Path != "" && !MatchPattern(rule.Path, path) { - continue + // 检查路径匹配 + if rule.Path != "" { + // 如果路径以 / 结尾,使用前缀匹配 + // 如果路径包含 *,使用通配符匹配 + // 否则使用前缀匹配(允许 /api 匹配 /api/users) + if strings.HasSuffix(rule.Path, "/") { + if !strings.HasPrefix(path, rule.Path) { + continue + } + } else if strings.Contains(rule.Path, "*") { + if !MatchPattern(rule.Path, path) { + continue + } + } else { + // 精确匹配或前缀匹配 + if path != rule.Path && !strings.HasPrefix(path, rule.Path+"/") && !strings.HasPrefix(path, rule.Path+"?") && len(path) <= len(rule.Path) { + continue + } + } } // 检查方法 @@ -495,6 +564,49 @@ func (c *ProxyCache) Clear() { c.pending = make(map[uint64]*pendingRequest) } +// RefreshTTL 刷新缓存条目的 TTL(用于 304 响应处理)。 +// 不替换缓存内容,只更新验证时间和验证头。 +// 返回是否成功(条目可能已被驱逐)。 +func (c *ProxyCache) RefreshTTL(hashKey uint64, origKey string, newHeaders map[string]string) bool { + c.mu.Lock() + defer c.mu.Unlock() + + entry, ok := c.entries[hashKey] + if !ok || entry.OrigKey != origKey { + return false // 条目已被驱逐 + } + + // 更新验证时间(不更新 Created,保持 LRU 顺序) + entry.LastValidated = time.Now() + + // 更新验证头(如果提供) + if newHeaders != nil { + if lm, ok := newHeaders["Last-Modified"]; ok { + entry.LastModified = lm + } + if et, ok := newHeaders["ETag"]; ok { + entry.ETag = et + } + } + + return true +} + +// SetValidationHeaders 设置缓存条目的验证头(Last-Modified 和 ETag)。 +func (c *ProxyCache) SetValidationHeaders(hashKey uint64, origKey string, lastModified, etag string) bool { + c.mu.Lock() + defer c.mu.Unlock() + + entry, ok := c.entries[hashKey] + if !ok || entry.OrigKey != origKey { + return false + } + + entry.LastModified = lastModified + entry.ETag = etag + return true +} + // Stats 返回代理缓存统计。 func (c *ProxyCache) Stats() ProxyCacheStats { c.mu.RLock() diff --git a/internal/config/config.go b/internal/config/config.go index 82c4dc2..40fae37 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -594,10 +594,16 @@ type ProxyHeaders struct { // cache_lock: true // stale_while_revalidate: 1m type ProxyCacheConfig struct { - MaxAge time.Duration `yaml:"max_age"` - StaleWhileRevalidate time.Duration `yaml:"stale_while_revalidate"` - Enabled bool `yaml:"enabled"` - CacheLock bool `yaml:"cache_lock"` + MaxAge time.Duration `yaml:"max_age"` + StaleWhileRevalidate time.Duration `yaml:"stale_while_revalidate"` + Enabled bool `yaml:"enabled"` + CacheLock bool `yaml:"cache_lock"` + Methods []string `yaml:"methods"` + MinUses int `yaml:"min_uses"` // 缓存阈值,请求次数达到此值才缓存 + CacheLockTimeout time.Duration `yaml:"cache_lock_timeout"` // 缓存锁超时时间 + BackgroundUpdateDisable bool `yaml:"background_update_disable"` // 禁用后台更新(默认 false = 启用后台更新) + CacheIgnoreHeaders []string `yaml:"cache_ignore_headers"` // 缓存时忽略的响应头 + Revalidate bool `yaml:"revalidate"` // 启用条件请求(If-Modified-Since/If-None-Match) } // ProxyCacheValidConfig 缓存有效期分段配置。 diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 083032b..2899a94 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -404,8 +404,14 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) { buf.WriteString(" # cache: # 代理缓存\n") buf.WriteString(" # enabled: false\n") buf.WriteString(" # max_age: 60s\n") + buf.WriteString(" # methods: [GET, HEAD] # 可缓存的 HTTP 方法(默认 GET, HEAD)\n") + buf.WriteString(" # min_uses: 1 # 缓存阈值,请求次数达到此值才缓存(默认 1)\n") buf.WriteString(" # cache_lock: true # 防止缓存击穿\n") + buf.WriteString(" # cache_lock_timeout: 5s # 缓存锁超时时间(默认 5s)\n") buf.WriteString(" # stale_while_revalidate: 30s\n") + buf.WriteString(" # background_update_disable: false # 禁用后台更新(默认启用)\n") + buf.WriteString(" # cache_ignore_headers: [] # 缓存时忽略的响应头\n") + buf.WriteString(" # revalidate: false # 启用条件请求(默认关闭)\n") buf.WriteString(" # cache_valid: # 按 HTTP 状态码细分缓存时间(可选,未配置时使用 max_age)\n") buf.WriteString(" # ok: 10m # 200-299 缓存 10 分钟\n") buf.WriteString(" # redirect: 1h # 301/302 缓存 1 小时\n") diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 42a37a4..cf94f2c 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -169,9 +169,16 @@ func NewProxy(cfg *config.ProxyConfig, targets []*loadbalance.Target, transportC if cfg.Cache.Enabled { rules := make([]cache.ProxyCacheRule, 0) if cfg.Cache.MaxAge > 0 { + // 使用配置中的方法,若为空则使用默认值 GET, HEAD (nginx 默认行为) + methods := cfg.Cache.Methods + if len(methods) == 0 { + methods = []string{"GET", "HEAD"} + } rules = append(rules, cache.ProxyCacheRule{ - Path: cfg.Path, - MaxAge: cfg.Cache.MaxAge, + Path: cfg.Path, + Methods: methods, + Statuses: nil, // nil = 所有可缓存状态码 (由 getCacheDuration 处理) + MaxAge: cfg.Cache.MaxAge, }) } p.cache = cache.NewProxyCache(rules, cfg.Cache.CacheLock, cfg.Cache.StaleWhileRevalidate) @@ -590,6 +597,15 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { // 尝试从缓存获取(如果启用) if p.cache != nil && attempt == 0 { + // 检查请求方法是否允许缓存 + method := string(ctx.Request.Header.Method()) + path := string(ctx.Request.URI().Path()) + rule := p.cache.MatchRule(path, method, 0) + if rule == nil { + // 方法不在允许列表中,跳过缓存 + goto proxyRequest + } + hashKey, origKey := p.buildCacheKeyHash(ctx) if entry, ok, stale := p.cache.Get(hashKey, origKey); ok { // 缓存命中 @@ -605,8 +621,13 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { return } // 过期缓存,尝试后台刷新,同时返回旧数据 - - go p.backgroundRefresh(ctx, target, hashKey, origKey) + if !p.config.Cache.BackgroundUpdateDisable { + entry.Updating.Store(true) + go func() { + defer entry.Updating.Store(false) + p.backgroundRefresh(ctx, target, hashKey, origKey) + }() + } upstreamAddr = "CACHE" upstreamStatus = entry.Status @@ -618,10 +639,18 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { } // 检查是否需要缓存锁(防止缓存击穿) - if done := p.cache.AcquireLock(hashKey); done != nil { + timeout := p.config.Cache.CacheLockTimeout + if timeout == 0 && p.config.Cache.CacheLock { + timeout = 5 * time.Second // nginx 默认 5s + } + waitCh, timedOut := p.cache.AcquireLockWithTimeout(hashKey, timeout) + if timedOut { + // 超时,跳过缓存直接请求上游 + // 不缓存响应(nginx 行为) + } else if waitCh != nil { // 有其他请求正在生成缓存,等待 loadbalance.DecrementConnections(target) - <-done + <-waitCh // 重新尝试获取缓存 if entry, ok, _ := p.cache.Get(hashKey, origKey); ok { @@ -639,6 +668,7 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { } } + proxyRequest: // 执行代理请求 timing.MarkConnectStart() err := client.Do(req, &ctx.Response) @@ -723,8 +753,25 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { // 存入缓存(如果启用且响应可缓存) if p.cache != nil { + // 再次检查方法是否允许缓存 + method := string(ctx.Request.Header.Method()) + path := string(ctx.Request.URI().Path()) + if rule := p.cache.MatchRule(path, method, statusCode); rule == nil { + // 方法或状态码不在允许列表中,不缓存 + return + } + hashKey, origKey := p.buildCacheKeyHash(ctx) if statusCode >= 200 && statusCode < 300 { + // 检查 MinUses 阈值 + if entry, ok, _ := p.cache.Get(hashKey, origKey); ok { + minUses := p.config.Cache.MinUses + if minUses > 0 && entry.Uses.Load() < int32(minUses) { + p.cache.ReleaseLock(hashKey, nil) + return + } + } + // 提取响应头(使用 pool 复用 map) headers, ok := headersPool.Get().(map[string]string) if !ok { @@ -733,10 +780,31 @@ func (p *Proxy) ServeHTTP(ctx *fasthttp.RequestCtx) { for k := range headers { delete(headers, k) } + // 构建忽略头部查找表(大小写不敏感) + ignoreSet := make(map[string]bool, len(p.config.Cache.CacheIgnoreHeaders)) + for _, h := range p.config.Cache.CacheIgnoreHeaders { + ignoreSet[strings.ToLower(h)] = true + } + + var lastModified, etag string for key, value := range ctx.Response.Header.All() { + headerName := strings.ToLower(string(key)) + if ignoreSet[headerName] { + continue + } headers[string(key)] = string(value) + + switch headerName { + case "last-modified": + lastModified = string(value) + case "etag": + etag = string(value) + } } p.cache.Set(hashKey, origKey, ctx.Response.Body(), headers, statusCode, p.getCacheDuration(statusCode)) + if lastModified != "" || etag != "" { + p.cache.SetValidationHeaders(hashKey, origKey, lastModified, etag) + } // 注意:不能 Put 回 pool,因为 cache.Set 存储了 map 引用 // 后续 writeCachedResponse 会读取该 map } @@ -1312,6 +1380,18 @@ func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance. // 复制原始请求 ctx.Request.CopyTo(req) + // 如果启用 Revalidate,添加条件请求头 + if p.config.Cache.Revalidate { + if entry, ok, _ := p.cache.Get(hashKey, origKey); ok { + if entry.LastModified != "" { + req.Header.Set("If-Modified-Since", entry.LastModified) + } + if entry.ETag != "" { + req.Header.Set("If-None-Match", entry.ETag) + } + } + } + // 获取客户端 client := p.getClient(target.URL) if client == nil { @@ -1325,6 +1405,19 @@ func (p *Proxy) backgroundRefresh(ctx *fasthttp.RequestCtx, target *loadbalance. return } + // 处理 304 Not Modified 响应 + if resp.StatusCode() == 304 { + newHeaders := make(map[string]string) + if lm := resp.Header.Peek("Last-Modified"); len(lm) > 0 { + newHeaders["Last-Modified"] = string(lm) + } + if et := resp.Header.Peek("ETag"); len(et) > 0 { + newHeaders["ETag"] = string(et) + } + p.cache.RefreshTTL(hashKey, origKey, newHeaders) + return + } + // 提取响应头(使用 pool 复用 map) headers, ok := headersPool.Get().(map[string]string) if !ok {