From bd5a2c02021dfe534a519b35d49bb43a0f0d3cd2 Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 17 Apr 2026 11:31:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(resolver):=20=E4=B8=BA=20DNS=20=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E6=B7=BB=E5=8A=A0=20LRU=20=E6=B7=98=E6=B1=B0=E6=9C=BA?= =?UTF-8?q?=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 sync.Map 替换为 map + RWMutex,实现基于 cache_size 的 LRU 淘汰: - 添加 lruOrder 链表追踪访问顺序 - 新增 storeCache 方法处理缓存存储和淘汰 - 添加 evictLRULocked、moveToFrontLocked 辅助方法 - 新增 TestCacheSizeLimit、TestLRUEvictionOrder 等测试 Co-Authored-By: Claude Opus 4.7 --- internal/resolver/cache.go | 56 +++---- internal/resolver/resolver.go | 88 ++++++++--- internal/resolver/resolver_bench_test.go | 8 +- internal/resolver/resolver_test.go | 192 +++++++++++++++++++++-- 4 files changed, 279 insertions(+), 65 deletions(-) diff --git a/internal/resolver/cache.go b/internal/resolver/cache.go index 7e9b028..196e83d 100644 --- a/internal/resolver/cache.go +++ b/internal/resolver/cache.go @@ -6,7 +6,6 @@ package resolver import ( - "sync" "time" ) @@ -27,19 +26,16 @@ func (r *DNSResolver) GetCacheStats() CacheStats { // 统计缓存条目 var entries, expired int now := time.Now() - r.cache.Range(func(_ interface{}, value interface{}) bool { + r.mu.RLock() + for _, entry := range r.cache { entries++ - entry, ok := value.(*DNSCacheEntry) - if !ok { - return true - } entry.mu.RLock() if now.After(entry.ExpiresAt) { expired++ } entry.mu.RUnlock() - return true - }) + } + r.mu.RUnlock() return CacheStats{ Hits: hits, @@ -51,28 +47,35 @@ func (r *DNSResolver) GetCacheStats() CacheStats { // GetCacheEntry 获取指定主机的缓存条目(用于测试)。 func (r *DNSResolver) GetCacheEntry(host string) (*DNSCacheEntry, bool) { - if entry, ok := r.cache.Load(host); ok { - cacheEntry, ok := entry.(*DNSCacheEntry) - if !ok { - return nil, false - } - return cacheEntry, true + r.mu.RLock() + entry, ok := r.cache[host] + r.mu.RUnlock() + if !ok { + return nil, false } - return nil, false + return entry, true } // DeleteCacheEntry 删除指定主机的缓存条目。 func (r *DNSResolver) DeleteCacheEntry(host string) { - r.cache.Delete(host) r.mu.Lock() + delete(r.cache, host) + // 从 LRU 链表中移除 + for i, h := range r.lruOrder { + if h == host { + r.lruOrder = append(r.lruOrder[:i], r.lruOrder[i+1:]...) + break + } + } delete(r.refreshHosts, host) r.mu.Unlock() } // ClearCache 清空所有缓存。 func (r *DNSResolver) ClearCache() { - r.cache = sync.Map{} r.mu.Lock() + r.cache = make(map[string]*DNSCacheEntry) + r.lruOrder = make([]string, 0, r.config.CacheSize) r.refreshHosts = make(map[string]struct{}) r.mu.Unlock() } @@ -90,15 +93,14 @@ func (r *DNSResolver) GetHitRate() float64 { // IsCached 检查指定主机是否在缓存中且未过期。 func (r *DNSResolver) IsCached(host string) bool { - if entry, ok := r.cache.Load(host); ok { - cacheEntry, ok := entry.(*DNSCacheEntry) - if !ok { - return false - } - cacheEntry.mu.RLock() - expiresAt := cacheEntry.ExpiresAt - cacheEntry.mu.RUnlock() - return time.Now().Before(expiresAt) + r.mu.RLock() + entry, ok := r.cache[host] + r.mu.RUnlock() + if !ok { + return false } - return false + entry.mu.RLock() + expiresAt := entry.ExpiresAt + entry.mu.RUnlock() + return time.Now().Before(expiresAt) } diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go index 219580c..10ba990 100644 --- a/internal/resolver/resolver.go +++ b/internal/resolver/resolver.go @@ -56,12 +56,55 @@ type Stats struct { AverageLatency time.Duration // 平均解析延迟 } +// storeCache 存入缓存(带 LRU 淘汰)。 +func (r *DNSResolver) storeCache(host string, entry *DNSCacheEntry) { + r.mu.Lock() + defer r.mu.Unlock() + + // 已存在则更新并移到头部 + if _, ok := r.cache[host]; ok { + r.cache[host] = entry + r.moveToFrontLocked(host) + return + } + + // 检查是否需要淘汰 + if r.config.CacheSize > 0 && len(r.cache) >= r.config.CacheSize { + r.evictLRULocked() + } + + r.cache[host] = entry + r.lruOrder = append(r.lruOrder, host) +} + +// evictLRULocked 淘汰最久未使用的条目(需持有锁)。 +func (r *DNSResolver) evictLRULocked() { + if len(r.lruOrder) == 0 { + return + } + oldest := r.lruOrder[0] + delete(r.cache, oldest) + r.lruOrder = r.lruOrder[1:] +} + +// moveToFrontLocked 将条目移到 LRU 链表尾部(最新)(需持有锁)。 +func (r *DNSResolver) moveToFrontLocked(host string) { + for i, h := range r.lruOrder { + if h == host { + r.lruOrder = append(r.lruOrder[:i], r.lruOrder[i+1:]...) + r.lruOrder = append(r.lruOrder, host) + return + } + } +} + // DNSResolver 实现 Resolver 接口的 DNS 解析器。 type DNSResolver struct { config *config.ResolverConfig stopCh chan struct{} refreshHosts map[string]struct{} - cache sync.Map + cache map[string]*DNSCacheEntry // DNS 缓存 + lruOrder []string // LRU 访问顺序(最旧在前) hits atomic.Int64 misses atomic.Int64 errors atomic.Int64 @@ -109,6 +152,8 @@ func New(cfg *config.ResolverConfig) Resolver { config: &configCopy, stopCh: make(chan struct{}), refreshHosts: make(map[string]struct{}), + cache: make(map[string]*DNSCacheEntry), + lruOrder: make([]string, 0, cfg.CacheSize), } } @@ -131,23 +176,24 @@ func (r *DNSResolver) lookup(ctx context.Context, host string, useCache bool) ([ // 尝试从缓存获取 if useCache { - if entry, ok := r.cache.Load(host); ok { - cacheEntry, ok := entry.(*DNSCacheEntry) - if ok { - cacheEntry.mu.RLock() - ips := cacheEntry.IPs - expiresAt := cacheEntry.ExpiresAt - cacheErr := cacheEntry.Error - cacheEntry.mu.RUnlock() + r.mu.RLock() + entry, ok := r.cache[host] + r.mu.RUnlock() - // 缓存未过期,返回缓存结果 - if time.Now().Before(expiresAt) { - r.hits.Add(1) - if cacheErr != nil { - return nil, cacheErr - } - return ips, nil + if ok { + entry.mu.RLock() + ips := entry.IPs + expiresAt := entry.ExpiresAt + cacheErr := entry.Error + entry.mu.RUnlock() + + // 缓存未过期,返回缓存结果 + if time.Now().Before(expiresAt) { + r.hits.Add(1) + if cacheErr != nil { + return nil, cacheErr } + return ips, nil } } } @@ -173,7 +219,7 @@ func (r *DNSResolver) lookup(ctx context.Context, host string, useCache bool) ([ LastLookup: time.Now(), Error: err, } - r.cache.Store(host, entry) + r.storeCache(host, entry) // 添加到刷新列表 r.mu.Lock() @@ -365,11 +411,9 @@ func (r *DNSResolver) Stats() Stats { misses := r.misses.Load() // 统计缓存条目数 - var entries int - r.cache.Range(func(_, _ interface{}) bool { - entries++ - return true - }) + r.mu.RLock() + entries := len(r.cache) + r.mu.RUnlock() // 计算平均延迟 var avgLatency time.Duration diff --git a/internal/resolver/resolver_bench_test.go b/internal/resolver/resolver_bench_test.go index 694bb4b..66be109 100644 --- a/internal/resolver/resolver_bench_test.go +++ b/internal/resolver/resolver_bench_test.go @@ -29,7 +29,7 @@ func createTestResolver() *DNSResolver { // 预填充缓存条目,模拟真实的解析场景 for i := 0; i < 100; i++ { host := fmt.Sprintf("host%d.example.com", i) - r.cache.Store(host, &DNSCacheEntry{ + r.storeCache(host, &DNSCacheEntry{ IPs: []string{fmt.Sprintf("192.168.1.%d", i%256), fmt.Sprintf("192.168.2.%d", i%256)}, ExpiresAt: time.Now().Add(30 * time.Second), LastLookup: time.Now(), @@ -106,7 +106,7 @@ func BenchmarkDNSResolverConcurrent(b *testing.B) { // 只添加一个缓存条目,所有 goroutine 都访问同一个条目 targetHost := "concurrent.example.com" - r.cache.Store(targetHost, &DNSCacheEntry{ + r.storeCache(targetHost, &DNSCacheEntry{ IPs: []string{"10.0.0.1", "10.0.0.2", "10.0.0.3"}, ExpiresAt: time.Now().Add(30 * time.Second), LastLookup: time.Now(), @@ -158,7 +158,7 @@ func BenchmarkDNSResolverCacheExpiry(b *testing.B) { host := "127.0.0.1" // 预存储一个已过期的条目 - r.cache.Store(host, &DNSCacheEntry{ + r.storeCache(host, &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(-1 * time.Second), // 已过期 LastLookup: time.Now().Add(-2 * time.Second), @@ -216,7 +216,7 @@ func BenchmarkDNSResolverMixedWorkload(b *testing.B) { // 预填充一些缓存 for i := 0; i < 50; i++ { host := fmt.Sprintf("cached%d.example.com", i) - r.cache.Store(host, &DNSCacheEntry{ + r.storeCache(host, &DNSCacheEntry{ IPs: []string{fmt.Sprintf("192.168.1.%d", i%256)}, ExpiresAt: time.Now().Add(30 * time.Second), }) diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go index a7f5778..bf374ca 100644 --- a/internal/resolver/resolver_test.go +++ b/internal/resolver/resolver_test.go @@ -108,7 +108,7 @@ func TestCache(t *testing.T) { r := New(cfg).(*DNSResolver) // 模拟缓存条目 - r.cache.Store("test.example.com", &DNSCacheEntry{ + r.storeCache("test.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1", "192.168.1.2"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) @@ -133,7 +133,7 @@ func TestCache(t *testing.T) { // 测试缓存过期 // 更新缓存条目为过期 - r.cache.Store("test.example.com", &DNSCacheEntry{ + r.storeCache("test.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(-1 * time.Second), // 已过期 }) @@ -157,13 +157,13 @@ func TestIsCached(t *testing.T) { r := New(cfg).(*DNSResolver) // 添加未过期的缓存 - r.cache.Store("active.example.com", &DNSCacheEntry{ + r.storeCache("active.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) // 添加已过期的缓存 - r.cache.Store("expired.example.com", &DNSCacheEntry{ + r.storeCache("expired.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.2"}, ExpiresAt: time.Now().Add(-1 * time.Second), }) @@ -194,7 +194,7 @@ func TestCacheHitRate(t *testing.T) { } // 模拟缓存命中 - r.cache.Store("test.example.com", &DNSCacheEntry{ + r.storeCache("test.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) @@ -227,11 +227,11 @@ func TestStats(t *testing.T) { r := New(cfg).(*DNSResolver) // 添加缓存条目 - r.cache.Store("test1.example.com", &DNSCacheEntry{ + r.storeCache("test1.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) - r.cache.Store("test2.example.com", &DNSCacheEntry{ + r.storeCache("test2.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.2"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) @@ -329,7 +329,7 @@ func TestDeleteCacheEntry(t *testing.T) { r := New(cfg).(*DNSResolver) // 添加缓存 - r.cache.Store("test.example.com", &DNSCacheEntry{ + r.storeCache("test.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) @@ -358,7 +358,7 @@ func TestClearCache(t *testing.T) { // 添加多个缓存 for i := 0; i < 5; i++ { host := fmt.Sprintf("test%d.example.com", i) - r.cache.Store(host, &DNSCacheEntry{ + r.storeCache(host, &DNSCacheEntry{ IPs: []string{fmt.Sprintf("192.168.1.%d", i)}, ExpiresAt: time.Now().Add(1 * time.Minute), }) @@ -387,7 +387,7 @@ func TestConcurrentAccess(t *testing.T) { r := New(cfg).(*DNSResolver) // 添加测试缓存 - r.cache.Store("test.example.com", &DNSCacheEntry{ + r.storeCache("test.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) @@ -533,13 +533,13 @@ func TestCacheStats(t *testing.T) { r := New(cfg).(*DNSResolver) // 添加活跃缓存 - r.cache.Store("active.example.com", &DNSCacheEntry{ + r.storeCache("active.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.1"}, ExpiresAt: time.Now().Add(1 * time.Minute), }) // 添加过期缓存 - r.cache.Store("expired.example.com", &DNSCacheEntry{ + r.storeCache("expired.example.com", &DNSCacheEntry{ IPs: []string{"192.168.1.2"}, ExpiresAt: time.Now().Add(-1 * time.Second), }) @@ -566,3 +566,171 @@ func TestCacheStats(t *testing.T) { t.Errorf("expected 1 expired, got %d", stats.Expired) } } + +// TestCacheSizeLimit 测试缓存大小限制。 +func TestCacheSizeLimit(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + CacheSize: 3, // 限制 3 个条目 + } + + r := New(cfg).(*DNSResolver) + + // 添加 5 个缓存条目,应淘汰 2 个 + for i := 0; i < 5; i++ { + host := fmt.Sprintf("host%d.example.com", i) + r.storeCache(host, &DNSCacheEntry{ + IPs: []string{fmt.Sprintf("192.168.1.%d", i)}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + } + + // 验证缓存条目数不超过限制 + stats := r.GetCacheStats() + if stats.Entries > 3 { + t.Errorf("expected at most 3 entries with CacheSize=3, got %d", stats.Entries) + } + + // 验证最早添加的条目被淘汰(LRU) + if r.IsCached("host0.example.com") { + t.Error("host0.example.com should be evicted (oldest)") + } + if r.IsCached("host1.example.com") { + t.Error("host1.example.com should be evicted (second oldest)") + } + + // 验证最新添加的条目存在 + if !r.IsCached("host2.example.com") { + t.Error("host2.example.com should be cached") + } + if !r.IsCached("host3.example.com") { + t.Error("host3.example.com should be cached") + } + if !r.IsCached("host4.example.com") { + t.Error("host4.example.com should be cached") + } +} + +// TestCacheSizeZero 测试 cache_size=0 时无限制。 +func TestCacheSizeZero(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + CacheSize: 0, // 无限制 + } + + r := New(cfg).(*DNSResolver) + + // 添加大量缓存条目 + for i := 0; i < 100; i++ { + host := fmt.Sprintf("host%d.example.com", i) + r.storeCache(host, &DNSCacheEntry{ + IPs: []string{fmt.Sprintf("192.168.1.%d", i%256)}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + } + + // 验证所有条目都存在 + stats := r.GetCacheStats() + if stats.Entries != 100 { + t.Errorf("expected 100 entries with CacheSize=0, got %d", stats.Entries) + } +} + +// TestLRUEvictionOrder 测试 LRU 淘汰顺序。 +func TestLRUEvictionOrder(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + CacheSize: 3, + } + + r := New(cfg).(*DNSResolver) + + // 添加 3 个条目填满缓存 + r.storeCache("a.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.storeCache("b.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.2"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.storeCache("c.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.3"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + + // 访问 a.example.com 使其变为最新 + r.storeCache("a.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + + // 添加新条目,应淘汰 b.example.com(最久未使用) + r.storeCache("d.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.4"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + + // 验证淘汰顺序 + if r.IsCached("b.example.com") { + t.Error("b.example.com should be evicted (least recently used)") + } + if !r.IsCached("a.example.com") { + t.Error("a.example.com should be cached (recently accessed)") + } + if !r.IsCached("c.example.com") { + t.Error("c.example.com should be cached") + } + if !r.IsCached("d.example.com") { + t.Error("d.example.com should be cached (newly added)") + } +} + +// TestCacheUpdatePreservesOrder 测试更新已存在条目不触发淘汰。 +func TestCacheUpdatePreservesOrder(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + CacheSize: 3, + } + + r := New(cfg).(*DNSResolver) + + // 添加 3 个条目填满缓存 + r.storeCache("a.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.storeCache("b.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.2"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.storeCache("c.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.3"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + + // 更新已存在的条目(不应触发淘汰) + r.storeCache("b.example.com", &DNSCacheEntry{ + IPs: []string{"192.168.1.20"}, // 新 IP + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + + // 验证所有条目仍然存在 + stats := r.GetCacheStats() + if stats.Entries != 3 { + t.Errorf("expected 3 entries after update, got %d", stats.Entries) + } + + // 验证更新生效 + entry, ok := r.GetCacheEntry("b.example.com") + if !ok { + t.Fatal("b.example.com should exist") + } + if len(entry.IPs) != 1 || entry.IPs[0] != "192.168.1.20" { + t.Errorf("expected IP 192.168.1.20, got %v", entry.IPs) + } +}