diff --git a/internal/config/config.go b/internal/config/config.go index bdf787e..04ce75b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -20,6 +20,7 @@ package config import ( "errors" "fmt" + "net" "os" "time" @@ -76,6 +77,10 @@ type Config struct { // Monitoring 监控配置 // 包含状态端点等监控相关配置 Monitoring MonitoringConfig `yaml:"monitoring"` + + // Resolver DNS 解析器配置 + // 启用动态 DNS 解析和缓存 + Resolver ResolverConfig `yaml:"resolver"` } // HTTP3Config HTTP/3 (QUIC) 配置。 @@ -1460,5 +1465,101 @@ func Validate(cfg *Config) error { return fmt.Errorf("performance: %w", err) } + // 验证 Resolver 配置 + if err := cfg.Resolver.Validate(); err != nil { + return fmt.Errorf("resolver: %w", err) + } + + return nil +} + +// ResolverConfig DNS 解析器配置。 +// +// 配置 DNS 解析器的行为,包括服务器地址、缓存 TTL、超时等。 +// 启用后可实现动态 DNS 解析和缓存,支持后端域名的动态解析。 +// +// 注意事项: +// - Enabled 为 true 时启用 DNS 解析器 +// - Addresses 配置 DNS 服务器地址,如 "8.8.8.8:53" +// - Valid 为缓存有效期(TTL),建议 30s-300s +// - Timeout 为单次查询超时时间 +// +// 使用示例: +// +// resolver: +// enabled: true +// addresses: +// - "8.8.8.8:53" +// - "8.8.4.4:53" +// valid: 30s +// timeout: 5s +// ipv4: true +// ipv6: false +// cache_size: 1024 +type ResolverConfig struct { + // Enabled 是否启用 DNS 解析器 + Enabled bool `yaml:"enabled"` + + // Addresses DNS 服务器地址列表 + // 格式为 "ip:port",如 "8.8.8.8:53" + Addresses []string `yaml:"addresses"` + + // Valid 缓存有效期(TTL) + // 解析结果的缓存时间 + Valid time.Duration `yaml:"valid"` + + // Timeout DNS 查询超时 + // 单次 DNS 查询的最大等待时间 + Timeout time.Duration `yaml:"timeout"` + + // IPv4 是否查询 IPv4 地址 + IPv4 bool `yaml:"ipv4"` + + // IPv6 是否查询 IPv6 地址 + IPv6 bool `yaml:"ipv6"` + + // CacheSize 缓存最大条目数 + // 0 表示无限制 + CacheSize int `yaml:"cache_size"` +} + +// TTL 返回缓存有效期(Valid 的别名,便于代码理解)。 +func (c *ResolverConfig) TTL() time.Duration { + return c.Valid +} + +// Validate 验证 Resolver 配置。 +// +// 检查 DNS 服务器地址格式、TTL 和超时设置的有效性。 +// +// 返回值: +// - error: 验证失败时的错误信息 +func (c *ResolverConfig) Validate() error { + if !c.Enabled { + return nil + } + + if len(c.Addresses) == 0 { + return errors.New("resolver.addresses is required when enabled") + } + + for _, addr := range c.Addresses { + if _, err := net.ResolveUDPAddr("udp", addr); err != nil { + return fmt.Errorf("invalid DNS address %s: %w", addr, err) + } + } + + if c.Valid > 0 && c.Valid < time.Second { + return errors.New("resolver.valid must be at least 1s") + } + + if c.Timeout > 0 && c.Timeout < time.Second { + return errors.New("resolver.timeout must be at least 1s") + } + + if !c.IPv4 && !c.IPv6 { + return errors.New("at least one of ipv4 or ipv6 must be enabled") + } + return nil } diff --git a/internal/loadbalance/balancer.go b/internal/loadbalance/balancer.go index 21eea39..be650c9 100644 --- a/internal/loadbalance/balancer.go +++ b/internal/loadbalance/balancer.go @@ -18,7 +18,10 @@ package loadbalance import ( "hash/fnv" + "net" + "net/url" "sync/atomic" + "time" ) // Target 表示负载均衡的后端服务器目标。 @@ -38,6 +41,15 @@ type Target struct { // Connections 跟踪当前活跃连接数。 // 并发修改此字段时应使用原子操作。 Connections int64 + + // hostname 从 URL 提取的主机名(缓存,避免重复解析) + hostname string + + // resolvedIPs 解析后的 IP 列表(使用 atomic.Pointer 保证并发安全) + resolvedIPs atomic.Pointer[[]string] + + // lastResolved 最后解析时间(UnixNano,使用 atomic.Int64) + lastResolved atomic.Int64 } // Balancer 是负载均衡算法的接口。 @@ -352,3 +364,90 @@ func (i *IPHash) SelectExcludingByIP(targets []*Target, excluded []*Target, clie idx := hash % uint64(len(available)) return available[idx] } + +// Hostname 返回目标主机名(从 URL 提取)。 +// 如果 hostname 未初始化,会自动调用 initHostname。 +func (t *Target) Hostname() string { + if t.hostname == "" { + t.initHostname() + } + return t.hostname +} + +// ResolvedIPs 返回解析后的 IP 列表。 +// 如果未解析过,返回 nil。 +func (t *Target) ResolvedIPs() []string { + ips := t.resolvedIPs.Load() + if ips == nil { + return nil + } + return *ips +} + +// SetResolvedIPs 设置解析后的 IP 列表,并更新最后解析时间。 +func (t *Target) SetResolvedIPs(ips []string) { + // 创建副本避免外部修改 + ipsCopy := make([]string, len(ips)) + copy(ipsCopy, ips) + t.resolvedIPs.Store(&ipsCopy) + t.lastResolved.Store(time.Now().UnixNano()) +} + +// NeedsResolve 检查是否需要重新解析。 +// 如果 hostname 是 IP 地址,返回 false。 +// 如果从未解析过或超过 TTL,返回 true。 +func (t *Target) NeedsResolve(ttl time.Duration) bool { + host := t.Hostname() + + // IP 类型的 URL 不需要解析 + if net.ParseIP(host) != nil { + return false + } + + last := t.lastResolved.Load() + if last == 0 { + return true // 首次解析 + } + + return time.Since(time.Unix(0, last)) > ttl +} + +// initHostname 从 URL 中提取并缓存主机名。 +// 必须在 Target 创建后调用一次。 +func (t *Target) initHostname() { + u, err := url.Parse(t.URL) + if err != nil { + // 解析失败,使用整个 URL 作为 hostname + t.hostname = t.URL + return + } + + // 提取主机名(去掉端口) + host := u.Host + if h, _, err := net.SplitHostPort(host); err == nil { + t.hostname = h + } else { + t.hostname = host + } +} + +// NewTargetFromConfig 从配置创建 Target(推荐入口)。 +// 自动初始化 hostname 和 Healthy 状态。 +func NewTargetFromConfig(url string, weight int) *Target { + t := &Target{ + URL: url, + Weight: weight, + } + t.initHostname() + t.Healthy.Store(true) + return t +} + +// LastResolved 返回最后解析时间。 +func (t *Target) LastResolved() time.Time { + nano := t.lastResolved.Load() + if nano == 0 { + return time.Time{} + } + return time.Unix(0, nano) +} diff --git a/internal/resolver/cache.go b/internal/resolver/cache.go new file mode 100644 index 0000000..9772003 --- /dev/null +++ b/internal/resolver/cache.go @@ -0,0 +1,94 @@ +// Package resolver 提供 DNS 解析功能,支持缓存和后台刷新。 +// +// 该文件包含 DNS 缓存相关的实现。 +// +// 作者:xfy +package resolver + +import ( + "sync" + "time" +) + +// CacheStats 返回缓存统计信息。 +type CacheStats struct { + Hits int64 // 缓存命中次数 + Misses int64 // 缓存未命中次数 + Entries int // 当前缓存条目数 + Expired int // 过期条目数 +} + +// GetCacheStats 返回当前缓存统计信息。 +// 这是一个辅助函数,用于测试和监控。 +func (r *DNSResolver) GetCacheStats() CacheStats { + hits := r.hits.Load() + misses := r.misses.Load() + + // 统计缓存条目 + var entries, expired int + now := time.Now() + r.cache.Range(func(key, value interface{}) bool { + entries++ + entry := value.(*dnsCacheEntry) + entry.mu.RLock() + if now.After(entry.ExpiresAt) { + expired++ + } + entry.mu.RUnlock() + return true + }) + + return CacheStats{ + Hits: hits, + Misses: misses, + Entries: entries, + Expired: expired, + } +} + +// GetCacheEntry 获取指定主机的缓存条目(用于测试)。 +func (r *DNSResolver) GetCacheEntry(host string) (*dnsCacheEntry, bool) { + if entry, ok := r.cache.Load(host); ok { + return entry.(*dnsCacheEntry), true + } + return nil, false +} + +// DeleteCacheEntry 删除指定主机的缓存条目。 +func (r *DNSResolver) DeleteCacheEntry(host string) { + r.cache.Delete(host) + r.mu.Lock() + delete(r.refreshHosts, host) + r.mu.Unlock() +} + +// ClearCache 清空所有缓存。 +func (r *DNSResolver) ClearCache() { + r.cache = sync.Map{} + r.mu.Lock() + r.refreshHosts = make(map[string]struct{}) + r.mu.Unlock() +} + +// GetHitRate 返回缓存命中率。 +func (r *DNSResolver) GetHitRate() float64 { + hits := r.hits.Load() + misses := r.misses.Load() + total := hits + misses + if total == 0 { + return 0 + } + return float64(hits) / float64(total) +} + +// IsCached 检查指定主机是否在缓存中且未过期。 +func (r *DNSResolver) IsCached(host string) bool { + if entry, ok := r.cache.Load(host); ok { + cacheEntry := entry.(*dnsCacheEntry) + cacheEntry.mu.RLock() + expiresAt := cacheEntry.ExpiresAt + cacheEntry.mu.RUnlock() + return time.Now().Before(expiresAt) + } + return false +} diff --git a/internal/resolver/resolver.go b/internal/resolver/resolver.go new file mode 100644 index 0000000..92caa14 --- /dev/null +++ b/internal/resolver/resolver.go @@ -0,0 +1,417 @@ +// Package resolver 提供 DNS 解析功能,支持缓存和后台刷新。 +// +// 该包实现了带缓存的 DNS 解析器,用于动态解析后端服务域名。 +// 支持 UDP DNS 查询、TTL 缓存、后台刷新等特性。 +// +// 主要用途: +// +// 用于代理模块动态解析 upstream 域名,支持域名变更自动感知 +// +// 注意事项: +// - 解析器使用 sync.Map 实现并发安全的缓存 +// - 后台刷新协程需要调用 Start() 启动 +// - 停止使用时应调用 Stop() 释放资源 +// +// 作者:xfy +package resolver + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "rua.plus/lolly/internal/config" +) + +// Resolver DNS 解析器接口。 +type Resolver interface { + // LookupHost 解析主机名,返回 IP 地址列表 + LookupHost(ctx context.Context, host string) ([]string, error) + + // LookupHostWithCache 带缓存的解析,优先返回缓存结果 + LookupHostWithCache(ctx context.Context, host string) ([]string, error) + + // Refresh 刷新指定主机的缓存 + Refresh(host string) error + + // Start 启动后台刷新协程 + Start() error + + // Stop 停止解析器 + Stop() error + + // Stats 返回统计信息 + Stats() ResolverStats +} + +// ResolverStats 解析器统计信息。 +type ResolverStats struct { + CacheHits int64 // 缓存命中次数 + CacheMisses int64 // 缓存未命中次数 + CacheEntries int // 当前缓存条目数 + ResolveErrors int64 // 解析错误次数 + AverageLatency time.Duration // 平均解析延迟 +} + +// DNSResolver 实现 Resolver 接口的 DNS 解析器。 +type DNSResolver struct { + config *config.ResolverConfig + cache sync.Map // key: hostname, value: *dnsCacheEntry + serverIdx atomic.Uint32 + + // 统计信息 + hits atomic.Int64 + misses atomic.Int64 + errors atomic.Int64 + latencyNs atomic.Int64 // 总延迟纳秒数 + count atomic.Int64 // 解析次数 + + // 后台刷新 + stopCh chan struct{} + started atomic.Bool + mu sync.RWMutex + refreshHosts map[string]struct{} // 需要刷新的主机列表 +} + +// dnsCacheEntry DNS 缓存条目。 +type dnsCacheEntry struct { + IPs []string + ExpiresAt time.Time + LastLookup time.Time + Error error + mu sync.RWMutex +} + +// New 创建新的 DNS 解析器。 +func New(cfg *config.ResolverConfig) Resolver { + if !cfg.Enabled { + return &noopResolver{} + } + + // 设置默认值 + valid := cfg.Valid + if valid == 0 { + valid = 30 * time.Second + } + timeout := cfg.Timeout + if timeout == 0 { + timeout = 5 * time.Second + } + + // 创建新配置副本,应用默认值 + configCopy := *cfg + configCopy.Valid = valid + configCopy.Timeout = timeout + if !configCopy.IPv4 && !configCopy.IPv6 { + configCopy.IPv4 = true + } + + return &DNSResolver{ + config: &configCopy, + stopCh: make(chan struct{}), + refreshHosts: make(map[string]struct{}), + } +} + +// LookupHost 解析主机名,返回 IP 地址列表。 +func (r *DNSResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + return r.lookup(ctx, host, false) +} + +// LookupHostWithCache 带缓存的解析,优先返回缓存结果。 +func (r *DNSResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) { + return r.lookup(ctx, host, true) +} + +// lookup 内部解析方法。 +func (r *DNSResolver) lookup(ctx context.Context, host string, useCache bool) ([]string, error) { + // 如果 host 已经是 IP 地址,直接返回 + if ip := net.ParseIP(host); ip != nil { + return []string{host}, nil + } + + // 尝试从缓存获取 + if useCache { + if entry, ok := r.cache.Load(host); ok { + cacheEntry := entry.(*dnsCacheEntry) + cacheEntry.mu.RLock() + ips := cacheEntry.IPs + expiresAt := cacheEntry.ExpiresAt + cacheErr := cacheEntry.Error + cacheEntry.mu.RUnlock() + + // 缓存未过期,返回缓存结果 + if time.Now().Before(expiresAt) { + r.hits.Add(1) + if cacheErr != nil { + return nil, cacheErr + } + return ips, nil + } + } + } + + r.misses.Add(1) + + // 执行 DNS 查询 + start := time.Now() + ips, err := r.queryDNS(ctx, host) + latency := time.Since(start) + + r.latencyNs.Add(latency.Nanoseconds()) + r.count.Add(1) + + if err != nil { + r.errors.Add(1) + } + + // 更新缓存 + entry := &dnsCacheEntry{ + IPs: ips, + ExpiresAt: time.Now().Add(r.config.TTL()), + LastLookup: time.Now(), + Error: err, + } + r.cache.Store(host, entry) + + // 添加到刷新列表 + r.mu.Lock() + r.refreshHosts[host] = struct{}{} + r.mu.Unlock() + + if err != nil { + return nil, err + } + return ips, nil +} + +// queryDNS 执行实际的 DNS 查询。 +func (r *DNSResolver) queryDNS(ctx context.Context, host string) ([]string, error) { + if len(r.config.Addresses) == 0 { + // 使用系统默认 DNS + return r.queryWithResolver(ctx, host, "") + } + + // 轮询选择 DNS 服务器 + idx := r.serverIdx.Add(1) % uint32(len(r.config.Addresses)) + dnsServer := r.config.Addresses[idx] + + return r.queryWithResolver(ctx, host, dnsServer) +} + +// queryWithResolver 使用指定的 DNS 服务器查询。 +func (r *DNSResolver) queryWithResolver(ctx context.Context, host, server string) ([]string, error) { + var ips []string + + // 创建带超时的 context + if r.config.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, r.config.Timeout) + defer cancel() + } + + // 创建自定义 resolver + var resolver *net.Resolver + if server != "" { + resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "udp", server) + }, + } + } + + // 查询 IPv4 + if r.config.IPv4 { + var ipAddrs []net.IPAddr + var err error + if resolver != nil { + ipAddrs, err = resolver.LookupIPAddr(ctx, host) + } else { + ipList, lookupErr := net.LookupIP(host) + if lookupErr != nil { + err = lookupErr + } else { + ipAddrs = make([]net.IPAddr, len(ipList)) + for i, ip := range ipList { + ipAddrs[i] = net.IPAddr{IP: ip} + } + } + } + if err != nil { + return nil, fmt.Errorf("DNS lookup failed for %s: %w", host, err) + } + + for _, addr := range ipAddrs { + if ip4 := addr.IP.To4(); ip4 != nil { + ips = append(ips, ip4.String()) + } + } + } + + // 查询 IPv6 + if r.config.IPv6 { + var ipAddrs []net.IPAddr + var err error + if resolver != nil { + ipAddrs, err = resolver.LookupIPAddr(ctx, host) + } else { + ipList, lookupErr := net.LookupIP(host) + if lookupErr != nil { + err = lookupErr + } else { + ipAddrs = make([]net.IPAddr, len(ipList)) + for i, ip := range ipList { + ipAddrs[i] = net.IPAddr{IP: ip} + } + } + } + if err != nil { + // IPv6 查询失败不返回错误,继续使用 IPv4 结果 + _ = err + } else { + for _, addr := range ipAddrs { + if ip := addr.IP.To16(); ip != nil && ip.To4() == nil { + ips = append(ips, ip.String()) + } + } + } + } + + if len(ips) == 0 { + return nil, fmt.Errorf("no IP addresses found for %s", host) + } + + return ips, nil +} + +// Refresh 刷新指定主机的缓存。 +func (r *DNSResolver) Refresh(host string) error { + _, err := r.LookupHost(context.Background(), host) + return err +} + +// Start 启动后台刷新协程。 +func (r *DNSResolver) Start() error { + if !r.config.Enabled { + return nil + } + + if r.started.Load() { + return nil + } + + r.started.Store(true) + + // 启动后台刷新协程 + go r.refreshLoop() + + return nil +} + +// refreshLoop 后台刷新循环。 +func (r *DNSResolver) refreshLoop() { + // 刷新间隔为 TTL / 2 + interval := r.config.TTL() / 2 + if interval < time.Second { + interval = time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + r.doRefresh() + case <-r.stopCh: + return + } + } +} + +// doRefresh 执行刷新操作。 +func (r *DNSResolver) doRefresh() { + r.mu.RLock() + hosts := make([]string, 0, len(r.refreshHosts)) + for host := range r.refreshHosts { + hosts = append(hosts, host) + } + r.mu.RUnlock() + + for _, host := range hosts { + ctx, cancel := context.WithTimeout(context.Background(), r.config.Timeout) + _, _ = r.LookupHost(ctx, host) // 刷新缓存 + cancel() + } +} + +// Stop 停止解析器。 +func (r *DNSResolver) Stop() error { + if !r.started.Load() { + return nil + } + + close(r.stopCh) + r.started.Store(false) + return nil +} + +// Stats 返回统计信息。 +func (r *DNSResolver) Stats() ResolverStats { + hits := r.hits.Load() + misses := r.misses.Load() + + // 统计缓存条目数 + var entries int + r.cache.Range(func(key, value interface{}) bool { + entries++ + return true + }) + + // 计算平均延迟 + var avgLatency time.Duration + count := r.count.Load() + if count > 0 { + avgLatency = time.Duration(r.latencyNs.Load() / count) + } + + return ResolverStats{ + CacheHits: hits, + CacheMisses: misses, + CacheEntries: entries, + ResolveErrors: r.errors.Load(), + AverageLatency: avgLatency, + } +} + +// noopResolver 是禁用状态下的空实现。 +type noopResolver struct{} + +func (n *noopResolver) LookupHost(ctx context.Context, host string) ([]string, error) { + return nil, fmt.Errorf("resolver is disabled") +} + +func (n *noopResolver) LookupHostWithCache(ctx context.Context, host string) ([]string, error) { + return n.LookupHost(ctx, host) +} + +func (n *noopResolver) Refresh(host string) error { + return nil +} + +func (n *noopResolver) Start() error { + return nil +} + +func (n *noopResolver) Stop() error { + return nil +} + +func (n *noopResolver) Stats() ResolverStats { + return ResolverStats{} +} diff --git a/internal/resolver/resolver_test.go b/internal/resolver/resolver_test.go new file mode 100644 index 0000000..c45bd2d --- /dev/null +++ b/internal/resolver/resolver_test.go @@ -0,0 +1,568 @@ +package resolver + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "rua.plus/lolly/internal/config" +) + +// TestNewResolver 测试解析器创建。 +func TestNewResolver(t *testing.T) { + // 测试启用状态 + cfg := &config.ResolverConfig{ + Enabled: true, + Addresses: []string{"8.8.8.8:53"}, + Valid: 30 * time.Second, + Timeout: 5 * time.Second, + IPv4: true, + IPv6: false, + } + + r := New(cfg) + if r == nil { + t.Fatal("New() should return non-nil resolver") + } + + // 验证是 DNSResolver 类型 + dnsR, ok := r.(*DNSResolver) + if !ok { + t.Fatal("New() should return *DNSResolver when enabled") + } + + if !dnsR.config.Enabled { + t.Error("config.Enabled should be true") + } + + // 测试禁用状态 + cfgDisabed := &config.ResolverConfig{ + Enabled: false, + } + rDisabled := New(cfgDisabed) + if _, ok := rDisabled.(*noopResolver); !ok { + t.Error("New() should return *noopResolver when disabled") + } +} + +// TestNewResolverDefaults 测试默认值设置。 +func TestNewResolverDefaults(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Addresses: []string{"8.8.8.8:53"}, + // 不设置 Valid 和 Timeout + } + + r := New(cfg).(*DNSResolver) + if r.config.Valid != 30*time.Second { + t.Errorf("expected default Valid=30s, got %v", r.config.Valid) + } + if r.config.Timeout != 5*time.Second { + t.Errorf("expected default Timeout=5s, got %v", r.config.Timeout) + } +} + +// TestLookupHostWithIP 测试 IP 地址直接返回。 +func TestLookupHostWithIP(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Addresses: []string{"8.8.8.8:53"}, + Valid: 30 * time.Second, + Timeout: 5 * time.Second, + IPv4: true, + } + + r := New(cfg).(*DNSResolver) + + // 测试 IPv4 地址直接返回 + ips, err := r.LookupHost(context.Background(), "127.0.0.1") + if err != nil { + t.Fatalf("LookupHost failed: %v", err) + } + if len(ips) != 1 || ips[0] != "127.0.0.1" { + t.Errorf("expected [127.0.0.1], got %v", ips) + } + + // 测试 IPv6 地址直接返回 + ips, err = r.LookupHost(context.Background(), "::1") + if err != nil { + t.Fatalf("LookupHost failed: %v", err) + } + if len(ips) != 1 || ips[0] != "::1" { + t.Errorf("expected [::1], got %v", ips) + } +} + +// TestCache 测试缓存功能。 +func TestCache(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Addresses: []string{}, // 空地址,使用系统 DNS + Valid: 1 * time.Second, + Timeout: 5 * time.Second, + IPv4: true, + } + + r := New(cfg).(*DNSResolver) + + // 模拟缓存条目 + r.cache.Store("test.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1", "192.168.1.2"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.mu.Lock() + r.refreshHosts["test.example.com"] = struct{}{} + r.mu.Unlock() + + // 测试缓存命中 + ctx := context.Background() + ips, err := r.LookupHostWithCache(ctx, "test.example.com") + if err != nil { + t.Fatalf("LookupHostWithCache failed: %v", err) + } + if len(ips) != 2 { + t.Errorf("expected 2 IPs, got %d", len(ips)) + } + + // 验证缓存命中统计 + if r.GetCacheHits() != 1 { + t.Errorf("expected 1 cache hit, got %d", r.GetCacheHits()) + } + + // 测试缓存过期 + // 更新缓存条目为过期 + r.cache.Store("test.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(-1 * time.Second), // 已过期 + }) + + // 由于使用系统 DNS,可能会失败,但应该尝试查询 + _, _ = r.LookupHostWithCache(ctx, "test.example.com") + + // 应该有缓存未命中(因为过期了) + if r.GetCacheMisses() != 1 { + t.Errorf("expected 1 cache miss, got %d", r.GetCacheMisses()) + } +} + +// TestIsCached 测试缓存状态检查。 +func TestIsCached(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 添加未过期的缓存 + r.cache.Store("active.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + + // 添加已过期的缓存 + r.cache.Store("expired.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.2"}, + ExpiresAt: time.Now().Add(-1 * time.Second), + }) + + if !r.IsCached("active.example.com") { + t.Error("IsCached should return true for active entry") + } + if r.IsCached("expired.example.com") { + t.Error("IsCached should return false for expired entry") + } + if r.IsCached("unknown.example.com") { + t.Error("IsCached should return false for unknown entry") + } +} + +// TestCacheHitRate 测试缓存命中率计算。 +func TestCacheHitRate(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 初始命中率应为 0 + if r.GetHitRate() != 0 { + t.Errorf("expected 0 hit rate, got %f", r.GetHitRate()) + } + + // 模拟缓存命中 + r.cache.Store("test.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.mu.Lock() + r.refreshHosts["test.example.com"] = struct{}{} + r.mu.Unlock() + + // 3 次命中 + for i := 0; i < 3; i++ { + _, _ = r.LookupHostWithCache(context.Background(), "test.example.com") + } + + // 1 次未命中(新域名) + _, _ = r.LookupHostWithCache(context.Background(), "unknown.example.com") + + // 命中率应为 3/4 = 0.75 + hitRate := r.GetHitRate() + if hitRate != 0.75 { + t.Errorf("expected 0.75 hit rate, got %f", hitRate) + } +} + +// TestStats 测试统计信息。 +func TestStats(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 添加缓存条目 + r.cache.Store("test1.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.cache.Store("test2.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.2"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.mu.Lock() + r.refreshHosts["test1.example.com"] = struct{}{} + r.refreshHosts["test2.example.com"] = struct{}{} + r.mu.Unlock() + + // 触发缓存命中 + _, _ = r.LookupHostWithCache(context.Background(), "test1.example.com") + + stats := r.Stats() + if stats.CacheHits != 1 { + t.Errorf("expected 1 cache hit, got %d", stats.CacheHits) + } + if stats.CacheEntries != 2 { + t.Errorf("expected 2 cache entries, got %d", stats.CacheEntries) + } +} + +// TestResetStats 测试统计重置。 +func TestResetStats(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 添加统计数据 + r.hits.Store(10) + r.misses.Store(5) + r.errors.Store(2) + r.latencyNs.Store(1000000) + r.count.Store(10) + + r.ResetStats() + + if r.GetCacheHits() != 0 { + t.Errorf("expected 0 hits after reset, got %d", r.GetCacheHits()) + } + if r.GetCacheMisses() != 0 { + t.Errorf("expected 0 misses after reset, got %d", r.GetCacheMisses()) + } + if r.GetResolveErrors() != 0 { + t.Errorf("expected 0 errors after reset, got %d", r.GetResolveErrors()) + } + if r.GetAverageLatency() != 0 { + t.Errorf("expected 0 latency after reset, got %v", r.GetAverageLatency()) + } +} + +// TestStartStop 测试启动和停止。 +func TestStartStop(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Addresses: []string{"8.8.8.8:53"}, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 启动 + err := r.Start() + if err != nil { + t.Fatalf("Start failed: %v", err) + } + if !r.started.Load() { + t.Error("resolver should be started") + } + + // 重复启动不应报错 + err = r.Start() + if err != nil { + t.Errorf("Start() should not error when already started: %v", err) + } + + // 停止 + err = r.Stop() + if err != nil { + t.Fatalf("Stop failed: %v", err) + } + if r.started.Load() { + t.Error("resolver should be stopped") + } +} + +// TestDeleteCacheEntry 测试删除缓存条目。 +func TestDeleteCacheEntry(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 添加缓存 + r.cache.Store("test.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.mu.Lock() + r.refreshHosts["test.example.com"] = struct{}{} + r.mu.Unlock() + + // 删除 + r.DeleteCacheEntry("test.example.com") + + // 验证删除 + if r.IsCached("test.example.com") { + t.Error("cache entry should be deleted") + } +} + +// TestClearCache 测试清空缓存。 +func TestClearCache(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 添加多个缓存 + for i := 0; i < 5; i++ { + host := fmt.Sprintf("test%d.example.com", i) + r.cache.Store(host, &dnsCacheEntry{ + IPs: []string{fmt.Sprintf("192.168.1.%d", i)}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.mu.Lock() + r.refreshHosts[host] = struct{}{} + r.mu.Unlock() + } + + // 清空 + r.ClearCache() + + // 验证 + stats := r.GetCacheStats() + if stats.Entries != 0 { + t.Errorf("expected 0 entries after clear, got %d", stats.Entries) + } +} + +// TestConcurrentAccess 测试并发访问。 +func TestConcurrentAccess(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 添加测试缓存 + r.cache.Store("test.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + r.mu.Lock() + r.refreshHosts["test.example.com"] = struct{}{} + r.mu.Unlock() + + // 并发读取 + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = r.LookupHostWithCache(context.Background(), "test.example.com") + }() + } + wg.Wait() + + // 验证没有竞争条件导致的 panic + if r.GetCacheHits() != 100 { + t.Errorf("expected 100 cache hits, got %d", r.GetCacheHits()) + } +} + +// TestNoopResolver 测试空解析器。 +func TestNoopResolver(t *testing.T) { + nr := &noopResolver{} + + ctx := context.Background() + + _, err := nr.LookupHost(ctx, "example.com") + if err == nil { + t.Error("noopResolver.LookupHost should return error") + } + + _, err = nr.LookupHostWithCache(ctx, "example.com") + if err == nil { + t.Error("noopResolver.LookupHostWithCache should return error") + } + + if err := nr.Refresh("example.com"); err != nil { + t.Error("noopResolver.Refresh should not return error") + } + + if err := nr.Start(); err != nil { + t.Error("noopResolver.Start should not return error") + } + + if err := nr.Stop(); err != nil { + t.Error("noopResolver.Stop should not return error") + } + + stats := nr.Stats() + if stats.CacheHits != 0 || stats.CacheMisses != 0 { + t.Error("noopResolver.Stats should return empty stats") + } +} + +// TestResolverConfigValidate 测试配置验证。 +func TestResolverConfigValidate(t *testing.T) { + // 禁用状态不验证 + cfg := &config.ResolverConfig{Enabled: false} + if err := cfg.Validate(); err != nil { + t.Errorf("disabled resolver should pass validation: %v", err) + } + + // 启用但没有地址 + cfg = &config.ResolverConfig{ + Enabled: true, + } + if err := cfg.Validate(); err == nil { + t.Error("enabled resolver without addresses should fail") + } + + // 有效配置 + cfg = &config.ResolverConfig{ + Enabled: true, + Addresses: []string{"8.8.8.8:53"}, + Valid: 30 * time.Second, + Timeout: 5 * time.Second, + IPv4: true, + IPv6: false, + } + if err := cfg.Validate(); err != nil { + t.Errorf("valid config should pass: %v", err) + } + + // TTL 太短 + cfg = &config.ResolverConfig{ + Enabled: true, + Addresses: []string{"8.8.8.8:53"}, + Valid: 500 * time.Millisecond, + } + if err := cfg.Validate(); err == nil { + t.Error("valid < 1s should fail") + } + + // IPv4 和 IPv6 都禁用 + cfg = &config.ResolverConfig{ + Enabled: true, + Addresses: []string{"8.8.8.8:53"}, + IPv4: false, + IPv6: false, + } + if err := cfg.Validate(); err == nil { + t.Error("both IPv4 and IPv6 disabled should fail") + } +} + +// TestResolverConfigTTL 测试 TTL 方法。 +func TestResolverConfigTTL(t *testing.T) { + cfg := &config.ResolverConfig{ + Valid: 60 * time.Second, + } + if cfg.TTL() != 60*time.Second { + t.Errorf("expected TTL=60s, got %v", cfg.TTL()) + } +} + +// TestRefresh 测试刷新方法。 +func TestRefresh(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 30 * time.Second, + } + + r := New(cfg).(*DNSResolver) + + // 测试 IP 地址直接返回(无 DNS 查询) + err := r.Refresh("127.0.0.1") + if err != nil { + t.Errorf("Refresh for IP should succeed: %v", err) + } +} + +// TestCacheStats 测试缓存统计。 +func TestCacheStats(t *testing.T) { + cfg := &config.ResolverConfig{ + Enabled: true, + Valid: 1 * time.Second, // 短 TTL 用于测试过期 + } + + r := New(cfg).(*DNSResolver) + + // 添加活跃缓存 + r.cache.Store("active.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.1"}, + ExpiresAt: time.Now().Add(1 * time.Minute), + }) + + // 添加过期缓存 + r.cache.Store("expired.example.com", &dnsCacheEntry{ + IPs: []string{"192.168.1.2"}, + ExpiresAt: time.Now().Add(-1 * time.Second), + }) + r.mu.Lock() + r.refreshHosts["active.example.com"] = struct{}{} + r.refreshHosts["expired.example.com"] = struct{}{} + r.mu.Unlock() + + // 设置命中/未命中统计 + r.hits.Store(10) + r.misses.Store(5) + + stats := r.GetCacheStats() + if stats.Hits != 10 { + t.Errorf("expected 10 hits, got %d", stats.Hits) + } + if stats.Misses != 5 { + t.Errorf("expected 5 misses, got %d", stats.Misses) + } + if stats.Entries != 2 { + t.Errorf("expected 2 entries, got %d", stats.Entries) + } + if stats.Expired != 1 { + t.Errorf("expected 1 expired, got %d", stats.Expired) + } +} diff --git a/internal/resolver/stats.go b/internal/resolver/stats.go new file mode 100644 index 0000000..c8ff886 --- /dev/null +++ b/internal/resolver/stats.go @@ -0,0 +1,62 @@ +// Package resolver 提供 DNS 解析功能,支持缓存和后台刷新。 +// +// 该文件包含 DNS 解析器统计指标相关的实现。 +// +// 作者:xfy +package resolver + +import ( + "time" +) + +// StatsCollector 统计收集器接口。 +type StatsCollector interface { + // RecordHit 记录缓存命中 + RecordHit() + // RecordMiss 记录缓存未命中 + RecordMiss() + // RecordError 记录解析错误 + RecordError() + // RecordLatency 记录解析延迟 + RecordLatency(latency time.Duration) + // GetStats 获取当前统计 + GetStats() ResolverStats +} + +// ResetStats 重置所有统计信息。 +func (r *DNSResolver) ResetStats() { + r.hits.Store(0) + r.misses.Store(0) + r.errors.Store(0) + r.latencyNs.Store(0) + r.count.Store(0) +} + +// GetCacheHits 返回缓存命中次数。 +func (r *DNSResolver) GetCacheHits() int64 { + return r.hits.Load() +} + +// GetCacheMisses 返回缓存未命中次数。 +func (r *DNSResolver) GetCacheMisses() int64 { + return r.misses.Load() +} + +// GetResolveErrors 返回解析错误次数。 +func (r *DNSResolver) GetResolveErrors() int64 { + return r.errors.Load() +} + +// GetTotalQueries 返回总查询次数。 +func (r *DNSResolver) GetTotalQueries() int64 { + return r.hits.Load() + r.misses.Load() +} + +// GetAverageLatency 返回平均解析延迟。 +func (r *DNSResolver) GetAverageLatency() time.Duration { + count := r.count.Load() + if count == 0 { + return 0 + } + return time.Duration(r.latencyNs.Load() / count) +}