diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..5fa2223 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,347 @@ +package cache + +import ( + "testing" + "time" +) + +func TestNewFileCache(t *testing.T) { + fc := NewFileCache(100, 1024*1024, 30*time.Second) + if fc == nil { + t.Error("Expected non-nil FileCache") + } +} + +func TestFileCacheSetGet(t *testing.T) { + fc := NewFileCache(10, 1024, 1*time.Hour) + + path := "/test/file.txt" + data := []byte("Hello, World!") + + err := fc.Set(path, data, int64(len(data)), time.Now()) + if err != nil { + t.Errorf("Set() error: %v", err) + } + + entry, ok := fc.Get(path) + if !ok { + t.Error("Expected to find cached entry") + } + if string(entry.Data) != "Hello, World!" { + t.Errorf("Expected data 'Hello, World!', got %s", entry.Data) + } +} + +func TestFileCacheDelete(t *testing.T) { + fc := NewFileCache(10, 1024, 1*time.Hour) + + fc.Set("/test.txt", []byte("data"), 4, time.Now()) + + fc.Delete("/test.txt") + + _, ok := fc.Get("/test.txt") + if ok { + t.Error("Expected entry to be deleted") + } +} + +func TestFileCacheLRUEviction(t *testing.T) { + // 最大 3 个条目 + fc := NewFileCache(3, 0, 1*time.Hour) + + fc.Set("/a", []byte("a"), 1, time.Now()) + fc.Set("/b", []byte("b"), 1, time.Now()) + fc.Set("/c", []byte("c"), 1, time.Now()) + + // 再添加一个,应该淘汰 /a + fc.Set("/d", []byte("d"), 1, time.Now()) + + _, ok := fc.Get("/a") + if ok { + t.Error("Expected /a to be evicted") + } + + // b, c, d 应该还在 + for _, path := range []string{"b", "c", "d"} { + _, ok := fc.Get("/" + path) + if !ok { + t.Errorf("Expected /%s to exist", path) + } + } +} + +func TestFileCacheSizeEviction(t *testing.T) { + // 最大 10 字节 + fc := NewFileCache(0, 10, 1*time.Hour) + + fc.Set("/a", []byte("12345"), 5, time.Now()) + fc.Set("/b", []byte("12345"), 5, time.Now()) + + // 再添加 6 字节,应该淘汰一个 + fc.Set("/c", []byte("123456"), 6, time.Now()) + + stats := fc.Stats() + if stats.Size > 10 { + t.Errorf("Expected size <= 10, got %d", stats.Size) + } +} + +func TestFileCacheInactiveEviction(t *testing.T) { + fc := NewFileCache(10, 1024, 100*time.Millisecond) + + fc.Set("/test", []byte("data"), 4, time.Now()) + + // 立即获取应该成功 + _, ok := fc.Get("/test") + if !ok { + t.Error("Expected entry to exist") + } + + // 等待过期 + time.Sleep(150 * time.Millisecond) + + // 再次获取应该失败(因过期被删除) + _, ok = fc.Get("/test") + if ok { + t.Error("Expected entry to be expired") + } +} + +func TestFileCacheClear(t *testing.T) { + fc := NewFileCache(10, 1024, 1*time.Hour) + + fc.Set("/a", []byte("a"), 1, time.Now()) + fc.Set("/b", []byte("b"), 1, time.Now()) + + fc.Clear() + + stats := fc.Stats() + if stats.Entries != 0 { + t.Errorf("Expected 0 entries after clear, got %d", stats.Entries) + } +} + +func TestFileCacheStats(t *testing.T) { + fc := NewFileCache(100, 1024, 1*time.Hour) + + fc.Set("/a", []byte("12345"), 5, time.Now()) + fc.Set("/b", []byte("12345"), 5, time.Now()) + + stats := fc.Stats() + if stats.Entries != 2 { + t.Errorf("Expected 2 entries, got %d", stats.Entries) + } + if stats.Size != 10 { + t.Errorf("Expected size 10, got %d", stats.Size) + } +} + +func TestNewProxyCache(t *testing.T) { + rules := []ProxyCacheRule{ + {Path: "/api/", Methods: []string{"GET"}, MaxAge: 10 * time.Minute}, + } + + pc := NewProxyCache(rules, true, 60*time.Second) + if pc == nil { + t.Error("Expected non-nil ProxyCache") + } +} + +func TestProxyCacheSetGet(t *testing.T) { + pc := NewProxyCache(nil, false, 0) + + key := "test-key" + data := []byte("response body") + headers := map[string]string{"Content-Type": "application/json"} + + pc.Set(key, data, headers, 200, 10*time.Minute) + + entry, ok, stale := pc.Get(key) + if !ok { + t.Error("Expected to find cached entry") + } + if stale { + t.Error("Expected entry to be fresh") + } + if string(entry.Data) != "response body" { + t.Errorf("Expected data 'response body', got %s", entry.Data) + } + if entry.Status != 200 { + t.Errorf("Expected status 200, got %d", entry.Status) + } +} + +func TestProxyCacheExpiration(t *testing.T) { + pc := NewProxyCache(nil, false, 0) + + key := "expire-test" + pc.Set(key, []byte("data"), nil, 200, 100*time.Millisecond) + + // 立即获取应该成功 + _, ok, _ := pc.Get(key) + if !ok { + t.Error("Expected entry to exist") + } + + // 等待过期 + time.Sleep(150 * time.Millisecond) + + _, ok, _ = pc.Get(key) + if ok { + t.Error("Expected entry to be expired") + } +} + +func TestProxyCacheStaleWhileRevalidate(t *testing.T) { + pc := NewProxyCache(nil, false, 200*time.Millisecond) + + key := "stale-test" + pc.Set(key, []byte("data"), nil, 200, 100*time.Millisecond) + + // 等待过期但仍在 stale 时间内 + time.Sleep(150 * time.Millisecond) + + entry, ok, stale := pc.Get(key) + if !ok { + t.Error("Expected stale entry to be usable") + } + if !stale { + t.Error("Expected entry to be marked as stale") + } + if entry == nil { + t.Error("Expected stale entry data") + } +} + +func TestProxyCacheLock(t *testing.T) { + pc := NewProxyCache(nil, true, 0) + + key := "lock-test" + + // 获取锁 + ch := pc.AcquireLock(key) + if ch != nil { + t.Error("Expected to acquire lock (nil chan)") + } + + // 第二次获取应该返回等待 chan + ch2 := pc.AcquireLock(key) + if ch2 == nil { + t.Error("Expected waiting chan when lock is held") + } + + // 设置缓存并释放锁 + pc.Set(key, []byte("data"), nil, 200, 10*time.Minute) + + // 现在应该能获取缓存 + _, ok, _ := pc.Get(key) + if !ok { + t.Error("Expected cache entry after lock release") + } +} + +func TestProxyCacheMatchRule(t *testing.T) { + rules := []ProxyCacheRule{ + {Path: "/api/", Methods: []string{"GET"}, Statuses: []int{200}, MaxAge: 10 * time.Minute}, + {Path: "/static/*", Methods: []string{"GET"}, MaxAge: 1 * time.Hour}, + } + + pc := NewProxyCache(rules, false, 0) + + tests := []struct { + path string + method string + status int + want bool + }{ + {"api/users", "GET", 200, true}, + {"api/users", "POST", 200, false}, // POST 不在 Methods + {"api/users", "GET", 404, false}, // 404 不在 Statuses + {"static/css/style.css", "GET", 200, true}, + {"other/path", "GET", 200, false}, // 不匹配任何规则 + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + // 添加前缀 / 到 path + fullPath := "/" + tt.path + rule := pc.MatchRule(fullPath, tt.method, tt.status) + if (rule != nil) != tt.want { + t.Errorf("MatchRule(%s, %s, %d) want %v", fullPath, tt.method, tt.status, tt.want) + } + }) + } +} + +func TestProxyCacheDelete(t *testing.T) { + pc := NewProxyCache(nil, false, 0) + + pc.Set("key1", []byte("data"), nil, 200, 10*time.Minute) + pc.Delete("key1") + + _, ok, _ := pc.Get("key1") + if ok { + t.Error("Expected entry to be deleted") + } +} + +func TestProxyCacheClear(t *testing.T) { + pc := NewProxyCache(nil, false, 0) + + pc.Set("a", []byte("a"), nil, 200, 10*time.Minute) + pc.Set("b", []byte("b"), nil, 200, 10*time.Minute) + + pc.Clear() + + stats := pc.Stats() + if stats.Entries != 0 { + t.Errorf("Expected 0 entries, got %d", stats.Entries) + } +} + +func TestPathMatch(t *testing.T) { + tests := []struct { + pattern string + path string + want bool + }{ + {"*", "/anything", true}, + {"api/*", "/api/users", true}, + {"api/*", "/api/", true}, + {"api/*", "/other", false}, + {"/exact", "/exact", true}, + {"/exact", "/exact/other", false}, + } + + for _, tt := range tests { + t.Run(tt.pattern+"_"+tt.path, func(t *testing.T) { + // 添加前缀 / 如果 pattern 没有 + pattern := tt.pattern + if pattern[0] != '/' && pattern != "*" { + pattern = "/" + pattern + } + result := pathMatch(pattern, tt.path) + if result != tt.want { + t.Errorf("pathMatch(%s, %s) = %v, want %v", pattern, tt.path, result, tt.want) + } + }) + } +} + +func TestContains(t *testing.T) { + if !contains([]string{"GET", "POST"}, "GET") { + t.Error("Expected to find GET") + } + if contains([]string{"GET", "POST"}, "DELETE") { + t.Error("Expected not to find DELETE") + } +} + +func TestContainsInt(t *testing.T) { + if !containsInt([]int{200, 301, 302}, 200) { + t.Error("Expected to find 200") + } + if containsInt([]int{200, 301, 302}, 404) { + t.Error("Expected not to find 404") + } +} \ No newline at end of file diff --git a/internal/cache/file_cache.go b/internal/cache/file_cache.go new file mode 100644 index 0000000..433b002 --- /dev/null +++ b/internal/cache/file_cache.go @@ -0,0 +1,404 @@ +// Package cache 提供文件缓存和代理缓存功能,支持 LRU 淘汰和缓存锁防击穿。 +package cache + +import ( + "container/list" + "strings" + "sync" + "time" +) + +// FileEntry 文件缓存条目。 +type FileEntry struct { + Path string // 文件路径 + Size int64 // 文件大小 + ModTime time.Time // 修改时间 + LastAccess time.Time // 最后访问时间 + Data []byte // 文件内容 + element *list.Element // LRU 链表元素 +} + +// FileCache 文件缓存,支持 LRU 淘汰。 +type FileCache struct { + maxEntries int64 // 最大条目数 + maxSize int64 // 内存上限(字节) + inactive time.Duration // 未访问淘汰时间 + entries map[string]*FileEntry + lruList *list.List // LRU 链表 + mu sync.RWMutex + currentSize int64 // 当前内存使用 +} + +// NewFileCache 创建文件缓存。 +func NewFileCache(maxEntries, maxSize int64, inactive time.Duration) *FileCache { + return &FileCache{ + maxEntries: maxEntries, + maxSize: maxSize, + inactive: inactive, + entries: make(map[string]*FileEntry), + lruList: list.New(), + } +} + +// Get 获取缓存的文件。 +func (c *FileCache) Get(path string) (*FileEntry, bool) { + c.mu.Lock() + defer c.mu.Unlock() + + entry, ok := c.entries[path] + if !ok { + return nil, false + } + + // 检查是否过期 + if time.Since(entry.LastAccess) > c.inactive { + c.removeEntry(entry) + return nil, false + } + + // 更新访问时间并移到 LRU 链表头部 + entry.LastAccess = time.Now() + c.lruList.MoveToFront(entry.element) + + return entry, true +} + +// Set 设置缓存条目。 +func (c *FileCache) Set(path string, data []byte, size int64, modTime time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + + // 检查是否已存在 + if entry, ok := c.entries[path]; ok { + c.currentSize -= entry.Size + entry.Data = data + entry.Size = size + entry.ModTime = modTime + entry.LastAccess = time.Now() + c.currentSize += size + c.lruList.MoveToFront(entry.element) + c.evictIfNeeded() + return nil + } + + // 创建新条目 + entry := &FileEntry{ + Path: path, + Data: data, + Size: size, + ModTime: modTime, + LastAccess: time.Now(), + } + entry.element = c.lruList.PushFront(entry) + c.entries[path] = entry + c.currentSize += size + + c.evictIfNeeded() + return nil +} + +// Delete 删除缓存条目。 +func (c *FileCache) Delete(path string) { + c.mu.Lock() + defer c.mu.Unlock() + + if entry, ok := c.entries[path]; ok { + c.removeEntry(entry) + } +} + +// removeEntry 内部删除条目(不加锁)。 +func (c *FileCache) removeEntry(entry *FileEntry) { + c.lruList.Remove(entry.element) + delete(c.entries, entry.Path) + c.currentSize -= entry.Size +} + +// evictIfNeeded 根据限制淘汰条目。 +func (c *FileCache) evictIfNeeded() { + // 按条目数淘汰 + for c.lruList.Len() > int(c.maxEntries) && c.maxEntries > 0 { + c.evictLRU() + } + + // 按内存大小淘汰 + for c.currentSize > c.maxSize && c.maxSize > 0 { + c.evictLRU() + } +} + +// evictLRU 淘汰最久未使用的条目。 +func (c *FileCache) evictLRU() { + if c.lruList.Len() == 0 { + return + } + + element := c.lruList.Back() + if element == nil { + return + } + + entry := element.Value.(*FileEntry) + c.removeEntry(entry) +} + +// Clear 清空缓存。 +func (c *FileCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.entries = make(map[string]*FileEntry) + c.lruList = list.New() + c.currentSize = 0 +} + +// Stats 返回缓存统计信息。 +func (c *FileCache) Stats() FileCacheStats { + c.mu.RLock() + defer c.mu.RUnlock() + + return FileCacheStats{ + Entries: int64(len(c.entries)), + MaxEntries: c.maxEntries, + Size: c.currentSize, + MaxSize: c.maxSize, + } +} + +// FileCacheStats 文件缓存统计。 +type FileCacheStats struct { + Entries int64 + MaxEntries int64 + Size int64 + MaxSize int64 +} + +// ProxyCacheRule 代理缓存规则。 +type ProxyCacheRule struct { + Path string // 匹配路径 + Methods []string // 可缓存的 HTTP 方法 + Statuses []int // 可缓存的状态码 + MaxAge time.Duration // 缓存有效期 +} + +// ProxyCacheEntry 代理缓存条目。 +type ProxyCacheEntry struct { + Key string // 缓存 key + Data []byte // 响应体 + Headers map[string]string // 响应头 + Status int // 状态码 + Created time.Time // 创建时间 + MaxAge time.Duration // 有效期 +} + +// ProxyCache 代理响应缓存,支持缓存锁防击穿。 +type ProxyCache struct { + rules []ProxyCacheRule + entries map[string]*ProxyCacheEntry + mu sync.RWMutex + cacheLock bool // 缓存锁开关 + pending map[string]*pendingRequest // 正在生成的缓存项 + staleTime time.Duration // 过期缓存复用时间 +} + +// pendingRequest 等待中的缓存请求。 +type pendingRequest struct { + done chan struct{} // 完成信号 + err error // 生成结果 +} + +// NewProxyCache 创建代理缓存。 +func NewProxyCache(rules []ProxyCacheRule, cacheLock bool, staleTime time.Duration) *ProxyCache { + return &ProxyCache{ + rules: rules, + entries: make(map[string]*ProxyCacheEntry), + cacheLock: cacheLock, + pending: make(map[string]*pendingRequest), + staleTime: staleTime, + } +} + +// Get 获取缓存的代理响应。 +func (c *ProxyCache) Get(key string) (*ProxyCacheEntry, bool, bool) { + c.mu.RLock() + entry, ok := c.entries[key] + c.mu.RUnlock() + + if !ok { + return nil, false, false + } + + // 检查是否过期 + now := time.Now() + expired := now.Sub(entry.Created) > entry.MaxAge + + if expired { + // 检查是否可以使用过期缓存 + if c.staleTime > 0 && now.Sub(entry.Created) <= entry.MaxAge+c.staleTime { + return entry, true, true // stale but usable + } + return nil, false, false + } + + return entry, true, false +} + +// Set 设置代理缓存条目。 +func (c *ProxyCache) Set(key string, data []byte, headers map[string]string, status int, maxAge time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + + c.entries[key] = &ProxyCacheEntry{ + Key: key, + Data: data, + Headers: headers, + Status: status, + Created: time.Now(), + MaxAge: maxAge, + } + + // 如果有等待的请求,通知它们 + if pending, ok := c.pending[key]; ok { + pending.err = nil + close(pending.done) + delete(c.pending, key) + } +} + +// AcquireLock 获取缓存生成锁(防止击穿)。 +// 如果返回 nil,表示获得锁,应该去生成缓存。 +// 如果返回 chan,表示有其他请求正在生成,应该等待。 +func (c *ProxyCache) AcquireLock(key string) <-chan struct{} { + if !c.cacheLock { + return nil // 不使用缓存锁 + } + + c.mu.Lock() + defer c.mu.Unlock() + + // 检查是否已有缓存 + if _, ok := c.entries[key]; ok { + return nil + } + + // 检查是否有 pending 请求 + if pending, ok := c.pending[key]; ok { + return pending.done // 等待现有请求 + } + + // 创建新的 pending 请求 + pending := &pendingRequest{ + done: make(chan struct{}), + } + c.pending[key] = pending + return nil // 获得锁,应该生成缓存 +} + +// ReleaseLock 释放缓存生成锁。 +func (c *ProxyCache) ReleaseLock(key string, err error) { + if !c.cacheLock { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + if pending, ok := c.pending[key]; ok { + pending.err = err + close(pending.done) + delete(c.pending, key) + } +} + +// MatchRule 检查请求是否匹配缓存规则。 +func (c *ProxyCache) MatchRule(path, method string, status int) *ProxyCacheRule { + for _, rule := range c.rules { + // 检查路径匹配(简单前缀匹配) + if rule.Path != "" && !pathMatch(rule.Path, path) { + continue + } + + // 检查方法 + if len(rule.Methods) > 0 && !contains(rule.Methods, method) { + continue + } + + // 检查状态码 + if len(rule.Statuses) > 0 && !containsInt(rule.Statuses, status) { + continue + } + + return &rule + } + return nil +} + +// pathMatch 路径匹配(支持前缀和精确匹配)。 +func pathMatch(pattern, path string) bool { + if pattern == "*" { + return true + } + // 通配符匹配 + if pattern[len(pattern)-1] == '*' { + prefix := pattern[:len(pattern)-1] + return strings.HasPrefix(path, prefix) + } + // 前缀匹配(pattern 以 / 结尾) + if pattern[len(pattern)-1] == '/' { + return strings.HasPrefix(path, pattern) + } + // 精确匹配 + return path == pattern +} + +// contains 检查字符串切片是否包含某值。 +func contains(slice []string, val string) bool { + for _, s := range slice { + if s == val { + return true + } + } + return false +} + +// containsInt 检查整数切片是否包含某值。 +func containsInt(slice []int, val int) bool { + for _, i := range slice { + if i == val { + return true + } + } + return false +} + +// Delete 删除缓存条目。 +func (c *ProxyCache) Delete(key string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.entries, key) +} + +// Clear 清空代理缓存。 +func (c *ProxyCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + c.entries = make(map[string]*ProxyCacheEntry) + c.pending = make(map[string]*pendingRequest) +} + +// Stats 返回代理缓存统计。 +func (c *ProxyCache) Stats() ProxyCacheStats { + c.mu.RLock() + defer c.mu.RUnlock() + + return ProxyCacheStats{ + Entries: len(c.entries), + Pending: len(c.pending), + } +} + +// ProxyCacheStats 代理缓存统计。 +type ProxyCacheStats struct { + Entries int + Pending int +} \ No newline at end of file