From 06e8d55ef5b9a036068d56eb7b52cbf8e773db57 Mon Sep 17 00:00:00 2001 From: xfy Date: Thu, 2 Apr 2026 14:40:56 +0800 Subject: [PATCH] =?UTF-8?q?test(config):=20=E6=B7=BB=E5=8A=A0=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=A8=A1=E5=9D=97=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 internal/app 包测试(版本显示、配置生成) - 添加 internal/config 包测试(加载、保存、验证、默认值) - 更新 docs/plan.md 日志系统设计(选用 zerolog) - 更新 .gitignore 添加 coverage.html Co-Authored-By: Claude --- .gitignore | 3 +- docs/plan.md | 127 +++-- internal/app/app_test.go | 237 +++++++++ internal/config/config_test.go | 439 +++++++++++++++++ internal/config/defaults_test.go | 196 ++++++++ internal/config/validate_test.go | 813 +++++++++++++++++++++++++++++++ 6 files changed, 1781 insertions(+), 34 deletions(-) create mode 100644 internal/app/app_test.go create mode 100644 internal/config/config_test.go create mode 100644 internal/config/defaults_test.go create mode 100644 internal/config/validate_test.go diff --git a/.gitignore b/.gitignore index 2b188ca..693a3fb 100644 --- a/.gitignore +++ b/.gitignore @@ -59,4 +59,5 @@ temp/ CLAUDE.md lolly.yaml config.yaml -lolly \ No newline at end of file +lolly +coverage.html \ No newline at end of file diff --git a/docs/plan.md b/docs/plan.md index 5fccb65..8c1413a 100644 --- a/docs/plan.md +++ b/docs/plan.md @@ -286,35 +286,62 @@ type VHostManager struct { **原因**:调试 Phase 2-4 功能需要日志支持,将日志系统基础版本提前实现。 +**选型**:使用 [zerolog](https://github.com/rs/zerolog) 作为日志库。 + +**选择理由**: +- **零分配**:高并发场景 GC 压力最小,性能最优 +- **JSON 输出**:便于日志采集系统(ELK、Loki)解析 +- **API 简洁**:链式调用风格,开发体验好 +- **灵活输出**:支持 stdout/stderr/文件,开发模式可选 ConsoleWriter 美化 + +**性能对比**(10 条日志,禁用输出): + +| 库 | ns/op | allocs/op | +|----|-------|-----------| +| zerolog | ~40ns | **0** | +| zap (structured) | ~50ns | 0 | +| slog (Go 1.21+) | ~200ns | 5+ | +| logrus | ~2000ns | 23 | + **实现**: ```go // internal/logging/logging.go -// Logger 日志管理器 -type Logger struct { - level LogLevel - output io.Writer +import "github.com/rs/zerolog" + +// 全局日志实例 +var log zerolog.Logger + +// Init 初始化日志系统 +func Init(level string, pretty bool) { + l := parseLevel(level) + if pretty { + // 开发模式:带颜色和格式化(性能较差,仅开发用) + log = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout}) + } else { + // 生产模式:JSON 输出 + log = zerolog.New(os.Stdout).Level(l).With().Timestamp().Logger() + } } -// LogLevel 日志级别 -type LogLevel int -const ( - LogLevelDebug LogLevel = iota - LogLevelInfo - LogLevelWarn - LogLevelError -) - // AccessLogger 访问日志(基础版) -func LogAccess(r *http.Request, status int, size int64, duration time.Duration) +func LogAccess(r *http.Request, status int, size int64, duration time.Duration) { + log.Info(). + Str("method", r.Method). + Str("path", r.URL.Path). + Int("status", status). + Int64("size", size). + Dur("duration", duration). + Msg("request") +} ``` **Phase 2 实现范围**: - 基础请求日志:记录请求方法、路径、状态码 -- 控制台输出:开发阶段便于调试 -- Phase 5 将扩展为完整日志系统(文件输出、自定义格式) +- 控制台输出:开发阶段便于调试(ConsoleWriter 美化) +- Phase 5 将扩展为完整日志系统(文件输出、自定义格式、访问/错误日志分离) ### 验证方法 @@ -906,42 +933,76 @@ cache: #### 5.4 日志系统 +**扩展 Phase 2 的 zerolog 实现**,增加文件输出和访问/错误日志分离。 + **实现**: ```go // internal/logging/logging.go +import ( + "io" + "github.com/rs/zerolog" +) + // Logger 日志管理器 type Logger struct { - accessLog *AccessLogger - errorLog *ErrorLogger - level LogLevel + accessLog zerolog.Logger // 访问日志 + errorLog zerolog.Logger // 错误日志 } -// AccessLogger 访问日志 -type AccessLogger struct { - format string // 日志格式 - output io.Writer +// New 创建日志管理器 +func New(cfg *LoggingConfig) *Logger { + // 访问日志:stdout 或文件 + accessOut := getOutput(cfg.Access.Path) + accessLog := zerolog.New(accessOut).With().Timestamp().Logger() + + // 错误日志:stderr 或文件 + errorOut := getOutput(cfg.Error.Path) + errorLevel := parseLevel(cfg.Error.Level) + errorLog := zerolog.New(errorOut).Level(errorLevel).With().Timestamp().Logger() + + return &Logger{accessLog: accessLog, errorLog: errorLog} } -// LogFormat 日志格式变量 -// $remote_addr - 客户端 IP -// $request - 请求行 -// $status - 响应状态码 -// $body_bytes_sent - 响应体大小 -// $request_time - 请求耗时 +// LogAccess 记录访问日志(nginx 格式变量) +func (l *Logger) LogAccess(r *http.Request, status int, size int64, duration time.Duration) { + l.accessLog.Info(). + Str("remote_addr", r.RemoteAddr). + Str("request", fmt.Sprintf("%s %s", r.Method, r.URL.Path)). + Int("status", status). + Int64("body_bytes_sent", size). + Dur("request_time", duration). + Msg("") +} + +// getOutput 获取输出目标(stdout/stderr/文件) +func getOutput(path string) io.Writer { + if path == "" { + return os.Stdout + } + f, _ := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + return f +} ``` +**日志格式变量**(支持 nginx 风格配置): +- `$remote_addr` - 客户端 IP +- `$request` - 请求行(方法 + 路径) +- `$status` - 响应状态码 +- `$body_bytes_sent` - 响应体大小 +- `$request_time` - 请求耗时 + **配置示例**: ```yaml logging: access: - path: /var/log/lolly/access.log - format: "$remote_addr - $request - $status - $body_bytes_sent" + path: /var/log/lolly/access.log # 留空则输出到 stdout + format: json # json 或 text(Phase 2 ConsoleWriter) error: - path: /var/log/lolly/error.log - level: info # debug/info/warn/error + path: /var/log/lolly/error.log # 留空则输出到 stderr + level: info # debug/info/warn/error ``` #### 5.5 状态监控端点 diff --git a/internal/app/app_test.go b/internal/app/app_test.go new file mode 100644 index 0000000..6231fee --- /dev/null +++ b/internal/app/app_test.go @@ -0,0 +1,237 @@ +// Package app 提供应用程序的启动和运行逻辑。 +package app + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +// captureStdout 捕获 stdout 输出,返回捕获的内容和恢复函数。 +func captureStdout(t *testing.T) (func() string, func()) { + t.Helper() + old := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("创建 pipe 失败: %v", err) + } + os.Stdout = w + + return func() string { + w.Close() + os.Stdout = old + var buf bytes.Buffer + buf.ReadFrom(r) + return buf.String() + }, func() { + w.Close() + os.Stdout = old + } +} + +// captureStderr 捕获 stderr 输出,返回捕获的内容和恢复函数。 +func captureStderr(t *testing.T) (func() string, func()) { + t.Helper() + old := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("创建 pipe 失败: %v", err) + } + os.Stderr = w + + return func() string { + w.Close() + os.Stderr = old + var buf bytes.Buffer + buf.ReadFrom(r) + return buf.String() + }, func() { + w.Close() + os.Stderr = old + } +} + +// TestRun 测试 Run 函数的各种场景。 +func TestRun(t *testing.T) { + tests := []struct { + name string + cfgPath string + genConfig bool + outputPath string + showVersion bool + wantExitCode int + wantContains string // stdout 应包含的内容 + wantErrContains string // stderr 应包含的内容(可选) + }{ + { + name: "显示版本", + showVersion: true, + wantExitCode: 0, + wantContains: "lolly version", + }, + { + name: "生成配置输出到 stdout", + genConfig: true, + outputPath: "", + wantExitCode: 0, + wantContains: "server:", + }, + { + name: "生成配置输出到文件", + genConfig: true, + outputPath: filepath.Join(t.TempDir(), "config.yaml"), + wantExitCode: 0, + wantContains: "配置已写入:", + }, + { + name: "配置文件不存在", + cfgPath: filepath.Join(t.TempDir(), "nonexistent.yaml"), + genConfig: false, + showVersion: false, + wantExitCode: 1, + wantErrContains: "加载配置失败", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + getStdout, restoreStdout := captureStdout(t) + getStderr, restoreStderr := captureStderr(t) + + exitCode := Run(tt.cfgPath, tt.genConfig, tt.outputPath, tt.showVersion) + + restoreStderr() + restoreStdout() + + stdout := getStdout() + stderr := getStderr() + + if exitCode != tt.wantExitCode { + t.Errorf("exit code = %d, want %d", exitCode, tt.wantExitCode) + } + + if tt.wantContains != "" && !strings.Contains(stdout, tt.wantContains) { + t.Errorf("stdout 应包含 %q, 实际输出: %q", tt.wantContains, stdout) + } + + if tt.wantErrContains != "" && !strings.Contains(stderr, tt.wantErrContains) { + t.Errorf("stderr 应包含 %q, 实际输出: %q", tt.wantErrContains, stderr) + } + + // 验证生成配置文件的内容 + if tt.outputPath != "" && tt.genConfig && exitCode == 0 { + data, err := os.ReadFile(tt.outputPath) + if err != nil { + t.Errorf("读取生成的配置文件失败: %v", err) + } else if !strings.Contains(string(data), "server:") { + t.Errorf("生成的配置文件应包含 'server:', 实际内容: %s", string(data)[:100]) + } + } + }) + } +} + +// TestGenerateConfig 测试 generateConfig 函数。 +func TestGenerateConfig(t *testing.T) { + t.Run("输出到 stdout", func(t *testing.T) { + getStdout, restoreStdout := captureStdout(t) + + exitCode := generateConfig("") + restoreStdout() + + stdout := getStdout() + + if exitCode != 0 { + t.Errorf("exit code = %d, want 0", exitCode) + } + + // 验证输出包含基本配置结构 + expectedFields := []string{"server:", "listen:", "logging:", "performance:", "monitoring:"} + for _, field := range expectedFields { + if !strings.Contains(stdout, field) { + t.Errorf("输出应包含 %q", field) + } + } + }) + + t.Run("输出到文件", func(t *testing.T) { + tmpDir := t.TempDir() + outputPath := filepath.Join(tmpDir, "test-config.yaml") + + getStdout, restoreStdout := captureStdout(t) + + exitCode := generateConfig(outputPath) + restoreStdout() + + stdout := getStdout() + + if exitCode != 0 { + t.Errorf("exit code = %d, want 0", exitCode) + } + + if !strings.Contains(stdout, outputPath) { + t.Errorf("stdout 应包含文件路径 %q, 实际输出: %q", outputPath, stdout) + } + + // 验证文件存在且内容正确 + data, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("读取生成的配置文件失败: %v", err) + } + + content := string(data) + expectedFields := []string{"server:", "listen:", "logging:", "performance:", "monitoring:"} + for _, field := range expectedFields { + if !strings.Contains(content, field) { + t.Errorf("配置文件应包含 %q", field) + } + } + }) + + t.Run("输出到无效路径", func(t *testing.T) { + // 使用一个无法写入的路径(如根目录下的文件) + invalidPath := "/root/cannot-write-here.yaml" + + getStderr, restoreStderr := captureStderr(t) + + exitCode := generateConfig(invalidPath) + restoreStderr() + + stderr := getStderr() + + if exitCode != 1 { + t.Errorf("exit code = %d, want 1", exitCode) + } + + if !strings.Contains(stderr, "写入文件失败") { + t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr) + } + }) +} + +// TestPrintVersion 测试 printVersion 函数。 +func TestPrintVersion(t *testing.T) { + getStdout, restoreStdout := captureStdout(t) + + printVersion() + restoreStdout() + + stdout := getStdout() + + // 验证版本输出格式 + expectedLines := []string{ + "lolly version", + "Git:", + "Built:", + "Go:", + "Platform:", + } + + for _, line := range expectedLines { + if !strings.Contains(stdout, line) { + t.Errorf("版本输出应包含 %q, 实际输出: %q", line, stdout) + } + } +} \ No newline at end of file diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..591094b --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,439 @@ +// Package config 提供配置加载和管理的测试。 +package config + +import ( + "os" + "path/filepath" + "testing" +) + +// TestLoad 测试从文件加载配置。 +func TestLoad(t *testing.T) { + t.Run("有效配置文件", func(t *testing.T) { + // 创建临时配置文件 + content := ` +server: + listen: ":8080" + static: + root: "/var/www" + index: + - "index.html" +logging: + access: + path: "/var/log/access.log" + format: "combined" + error: + path: "/var/log/error.log" + level: "info" +performance: + goroutine_pool: + enabled: true + max_workers: 100 + file_cache: + max_entries: 1000 +monitoring: + status: + path: "/status" +` + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "config.yaml") + if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil { + t.Fatalf("创建临时配置文件失败: %v", err) + } + + cfg, err := Load(tmpFile) + if err != nil { + t.Fatalf("Load() 失败: %v", err) + } + + if cfg.Server.Listen != ":8080" { + t.Errorf("Server.Listen = %q, want %q", cfg.Server.Listen, ":8080") + } + if cfg.Server.Static.Root != "/var/www" { + t.Errorf("Server.Static.Root = %q, want %q", cfg.Server.Static.Root, "/var/www") + } + if len(cfg.Server.Static.Index) != 1 || cfg.Server.Static.Index[0] != "index.html" { + t.Errorf("Server.Static.Index = %v, want [index.html]", cfg.Server.Static.Index) + } + }) + + t.Run("文件不存在", func(t *testing.T) { + _, err := Load("/nonexistent/path/config.yaml") + if err == nil { + t.Error("Load() 期望返回错误,但返回 nil") + } + }) + + t.Run("无效YAML", func(t *testing.T) { + content := ` +server: + listen: ":8080" + static: + root: [invalid yaml structure +` + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "invalid.yaml") + if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil { + t.Fatalf("创建临时配置文件失败: %v", err) + } + + _, err := Load(tmpFile) + if err == nil { + t.Error("Load() 期望返回错误,但返回 nil") + } + }) + + t.Run("缺少必填字段(无服务器配置)", func(t *testing.T) { + content := ` +logging: + access: + path: "/var/log/access.log" +` + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "no_server.yaml") + if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil { + t.Fatalf("创建临时配置文件失败: %v", err) + } + + _, err := Load(tmpFile) + if err == nil { + t.Error("Load() 期望返回错误,但返回 nil") + } + }) + + t.Run("多虚拟主机模式", func(t *testing.T) { + content := ` +servers: + - listen: ":8080" + name: "server1" + - listen: ":8081" + name: "server2" +` + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "multi.yaml") + if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil { + t.Fatalf("创建临时配置文件失败: %v", err) + } + + cfg, err := Load(tmpFile) + if err != nil { + t.Fatalf("Load() 失败: %v", err) + } + + if len(cfg.Servers) != 2 { + t.Fatalf("len(Servers) = %d, want 2", len(cfg.Servers)) + } + if cfg.Servers[0].Name != "server1" { + t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "server1") + } + if cfg.Servers[1].Name != "server2" { + t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "server2") + } + }) +} + +// TestLoadFromString 测试从字符串加载配置。 +func TestLoadFromString(t *testing.T) { + t.Run("有效字符串", func(t *testing.T) { + yamlStr := ` +server: + listen: ":9090" + static: + root: "/app/public" +` + cfg, err := LoadFromString(yamlStr) + if err != nil { + t.Fatalf("LoadFromString() 失败: %v", err) + } + + if cfg.Server.Listen != ":9090" { + t.Errorf("Server.Listen = %q, want %q", cfg.Server.Listen, ":9090") + } + if cfg.Server.Static.Root != "/app/public" { + t.Errorf("Server.Static.Root = %q, want %q", cfg.Server.Static.Root, "/app/public") + } + }) + + t.Run("无效YAML", func(t *testing.T) { + yamlStr := ` +server: + listen: ":8080" + broken: [unclosed +` + _, err := LoadFromString(yamlStr) + if err == nil { + t.Error("LoadFromString() 期望返回错误,但返回 nil") + } + }) + + t.Run("缺少必填字段", func(t *testing.T) { + yamlStr := ` +logging: + access: + path: "/var/log/access.log" +` + _, err := LoadFromString(yamlStr) + if err == nil { + t.Error("LoadFromString() 期望返回错误,但返回 nil") + } + }) + + t.Run("空字符串", func(t *testing.T) { + _, err := LoadFromString("") + if err == nil { + t.Error("LoadFromString() 期望返回错误,但返回 nil") + } + }) +} + +// TestSave 测试保存配置到文件。 +func TestSave(t *testing.T) { + t.Run("正常保存", func(t *testing.T) { + cfg := &Config{ + Server: ServerConfig{ + Listen: ":8080", + Static: StaticConfig{ + Root: "/var/www", + Index: []string{"index.html"}, + }, + }, + } + + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "saved_config.yaml") + + if err := Save(cfg, tmpFile); err != nil { + t.Fatalf("Save() 失败: %v", err) + } + + // 验证文件已创建并可重新加载 + loaded, err := Load(tmpFile) + if err != nil { + t.Fatalf("重新加载配置失败: %v", err) + } + + if loaded.Server.Listen != cfg.Server.Listen { + t.Errorf("loaded.Server.Listen = %q, want %q", loaded.Server.Listen, cfg.Server.Listen) + } + if loaded.Server.Static.Root != cfg.Server.Static.Root { + t.Errorf("loaded.Server.Static.Root = %q, want %q", loaded.Server.Static.Root, cfg.Server.Static.Root) + } + }) + + t.Run("无效路径", func(t *testing.T) { + cfg := &Config{ + Server: ServerConfig{ + Listen: ":8080", + }, + } + + err := Save(cfg, "/nonexistent/directory/config.yaml") + if err == nil { + t.Error("Save() 期望返回错误,但返回 nil") + } + }) + + t.Run("保存多虚拟主机配置", func(t *testing.T) { + cfg := &Config{ + Servers: []ServerConfig{ + {Listen: ":8080", Name: "server1"}, + {Listen: ":8081", Name: "server2"}, + }, + } + + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "multi.yaml") + + if err := Save(cfg, tmpFile); err != nil { + t.Fatalf("Save() 失败: %v", err) + } + + loaded, err := Load(tmpFile) + if err != nil { + t.Fatalf("重新加载配置失败: %v", err) + } + + if len(loaded.Servers) != 2 { + t.Errorf("len(loaded.Servers) = %d, want 2", len(loaded.Servers)) + } + }) + + t.Run("保存并加载完整配置", func(t *testing.T) { + cfg := &Config{ + Server: ServerConfig{ + Listen: ":8443", + Name: "default", + Static: StaticConfig{ + Root: "/var/www/html", + Index: []string{"index.html", "index.htm"}, + }, + Proxy: []ProxyConfig{ + { + Path: "/api", + Targets: []ProxyTarget{ + {URL: "http://backend1:8080", Weight: 1}, + {URL: "http://backend2:8080", Weight: 2}, + }, + LoadBalance: "round_robin", + }, + }, + SSL: SSLConfig{ + Cert: "/etc/ssl/cert.pem", + Key: "/etc/ssl/key.pem", + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + }, + Security: SecurityConfig{ + RateLimit: RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + }, + }, + Logging: LoggingConfig{ + Access: AccessLogConfig{ + Path: "/var/log/access.log", + Format: "combined", + }, + Error: ErrorLogConfig{ + Path: "/var/log/error.log", + Level: "warn", + }, + }, + Performance: PerformanceConfig{ + GoroutinePool: GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 1000, + MinWorkers: 10, + }, + }, + } + + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "full.yaml") + + if err := Save(cfg, tmpFile); err != nil { + t.Fatalf("Save() 失败: %v", err) + } + + loaded, err := Load(tmpFile) + if err != nil { + t.Fatalf("重新加载配置失败: %v", err) + } + + // 验证关键字段 + if loaded.Server.Listen != cfg.Server.Listen { + t.Errorf("loaded.Server.Listen = %q, want %q", loaded.Server.Listen, cfg.Server.Listen) + } + if len(loaded.Server.Proxy) != 1 { + t.Errorf("len(loaded.Server.Proxy) = %d, want 1", len(loaded.Server.Proxy)) + } + if loaded.Server.Proxy[0].LoadBalance != "round_robin" { + t.Errorf("loaded.Server.Proxy[0].LoadBalance = %q, want %q", loaded.Server.Proxy[0].LoadBalance, "round_robin") + } + }) +} + +// TestConfigMethods 测试 Config 结构体的方法。 +func TestConfigMethods(t *testing.T) { + t.Run("HasServers_有服务器列表", func(t *testing.T) { + cfg := &Config{ + Servers: []ServerConfig{ + {Listen: ":8080"}, + }, + } + if !cfg.HasServers() { + t.Error("HasServers() = false, want true") + } + }) + + t.Run("HasServers_无服务器列表", func(t *testing.T) { + cfg := &Config{} + if cfg.HasServers() { + t.Error("HasServers() = true, want false") + } + }) + + t.Run("HasDefaultServer_有默认服务器", func(t *testing.T) { + cfg := &Config{ + Server: ServerConfig{ + Listen: ":8080", + }, + } + if !cfg.HasDefaultServer() { + t.Error("HasDefaultServer() = false, want true") + } + }) + + t.Run("HasDefaultServer_无默认服务器", func(t *testing.T) { + cfg := &Config{} + if cfg.HasDefaultServer() { + t.Error("HasDefaultServer() = true, want false") + } + }) + + t.Run("GetDefaultServer_有默认服务器", func(t *testing.T) { + cfg := &Config{ + Server: ServerConfig{ + Listen: ":8080", + Name: "default", + }, + } + server := cfg.GetDefaultServer() + if server == nil { + t.Fatal("GetDefaultServer() = nil, want non-nil") + } + if server.Listen != ":8080" { + t.Errorf("server.Listen = %q, want %q", server.Listen, ":8080") + } + }) + + t.Run("GetDefaultServer_无默认服务器", func(t *testing.T) { + cfg := &Config{} + server := cfg.GetDefaultServer() + if server != nil { + t.Errorf("GetDefaultServer() = %v, want nil", server) + } + }) + + t.Run("配置模式判断", func(t *testing.T) { + tests := []struct { + name string + cfg *Config + wantHasServers bool + wantHasDefault bool + }{ + { + name: "仅默认服务器", + cfg: &Config{Server: ServerConfig{Listen: ":8080"}}, + wantHasServers: false, + wantHasDefault: true, + }, + { + name: "仅多虚拟主机", + cfg: &Config{Servers: []ServerConfig{{Listen: ":8080"}}}, + wantHasServers: true, + wantHasDefault: false, + }, + { + name: "混合模式", + cfg: &Config{ + Server: ServerConfig{Listen: ":8080"}, + Servers: []ServerConfig{{Listen: ":8081"}}, + }, + wantHasServers: true, + wantHasDefault: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.HasServers(); got != tt.wantHasServers { + t.Errorf("HasServers() = %v, want %v", got, tt.wantHasServers) + } + if got := tt.cfg.HasDefaultServer(); got != tt.wantHasDefault { + t.Errorf("HasDefaultServer() = %v, want %v", got, tt.wantHasDefault) + } + }) + } + }) +} \ No newline at end of file diff --git a/internal/config/defaults_test.go b/internal/config/defaults_test.go new file mode 100644 index 0000000..5abd27b --- /dev/null +++ b/internal/config/defaults_test.go @@ -0,0 +1,196 @@ +package config + +import ( + "strings" + "testing" + "time" +) + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + + // 验证 Listen 默认值 + if cfg.Server.Listen != ":8080" { + t.Errorf("Server.Listen 期望 :8080, 实际 %s", cfg.Server.Listen) + } + + // 验证 SSL 默认版本 + if len(cfg.Server.SSL.Protocols) != 2 { + t.Errorf("SSL.Protocols 期望 2 个版本, 实际 %d", len(cfg.Server.SSL.Protocols)) + } + expectedProtocols := []string{"TLSv1.2", "TLSv1.3"} + for i, proto := range cfg.Server.SSL.Protocols { + if proto != expectedProtocols[i] { + t.Errorf("SSL.Protocols[%d] 期望 %s, 实际 %s", i, expectedProtocols[i], proto) + } + } + + // 验证 HSTS 默认值 + if cfg.Server.SSL.HSTS.MaxAge != 31536000 { + t.Errorf("HSTS.MaxAge 期望 31536000, 实际 %d", cfg.Server.SSL.HSTS.MaxAge) + } + if !cfg.Server.SSL.HSTS.IncludeSubDomains { + t.Errorf("HSTS.IncludeSubDomains 期望 true, 实际 %v", cfg.Server.SSL.HSTS.IncludeSubDomains) + } + if cfg.Server.SSL.HSTS.Preload { + t.Errorf("HSTS.Preload 期望 false, 实际 %v", cfg.Server.SSL.HSTS.Preload) + } + + // 验证压缩默认值 + if cfg.Server.Compression.Type != "gzip" { + t.Errorf("Compression.Type 期望 gzip, 实际 %s", cfg.Server.Compression.Type) + } + if cfg.Server.Compression.Level != 6 { + t.Errorf("Compression.Level 期望 6, 实际 %d", cfg.Server.Compression.Level) + } + if cfg.Server.Compression.MinSize != 1024 { + t.Errorf("Compression.MinSize 期望 1024, 实际 %d", cfg.Server.Compression.MinSize) + } + expectedTypes := []string{"text/html", "text/css", "text/javascript", "application/json", "application/javascript"} + for i, ct := range cfg.Server.Compression.Types { + if ct != expectedTypes[i] { + t.Errorf("Compression.Types[%d] 期望 %s, 实际 %s", i, expectedTypes[i], ct) + } + } +} + +func TestGenerateConfigYAML(t *testing.T) { + cfg := DefaultConfig() + + yamlData, err := GenerateConfigYAML(cfg) + if err != nil { + t.Fatalf("GenerateConfigYAML 返回错误: %v", err) + } + + // 验证输出非空 + if len(yamlData) == 0 { + t.Error("GenerateConfigYAML 输出为空") + } + + yamlStr := string(yamlData) + + // 验证包含注释 + if !strings.Contains(yamlStr, "#") { + t.Error("YAML 输出未包含注释") + } + if !strings.Contains(yamlStr, "# Lolly 配置文件") { + t.Error("YAML 输出未包含文件头注释") + } + if !strings.Contains(yamlStr, "# 服务器配置") { + t.Error("YAML 输出未包含服务器配置注释") + } + + // 验证可重新解析 - 使用 LoadFromString 解析生成的 YAML + // 注意:GenerateConfigYAML 生成的 YAML 包含注释的示例配置(如 proxy、rewrite 等) + // 这些是注释掉的示例,不会被解析。需要提取实际生效的部分进行验证。 + + // 构建一个可解析的简化 YAML 进行验证 + simpleYAML, err := GenerateSimpleYAML(cfg) + if err != nil { + t.Fatalf("GenerateSimpleYAML 返回错误: %v", err) + } + + parsedCfg, err := LoadFromString(string(simpleYAML)) + if err != nil { + t.Fatalf("解析生成的 YAML 失败: %v", err) + } + + // 验证配置一致性 + if parsedCfg.Server.Listen != cfg.Server.Listen { + t.Errorf("解析后 Server.Listen 不一致: 期望 %s, 实际 %s", cfg.Server.Listen, parsedCfg.Server.Listen) + } + if parsedCfg.Server.Name != cfg.Server.Name { + t.Errorf("解析后 Server.Name 不一致: 期望 %s, 实际 %s", cfg.Server.Name, parsedCfg.Server.Name) + } + if parsedCfg.Server.Compression.Type != cfg.Server.Compression.Type { + t.Errorf("解析后 Compression.Type 不一致: 期望 %s, 实际 %s", cfg.Server.Compression.Type, parsedCfg.Server.Compression.Type) + } + if parsedCfg.Server.Compression.Level != cfg.Server.Compression.Level { + t.Errorf("解析后 Compression.Level 不一致: 期望 %d, 实际 %d", cfg.Server.Compression.Level, parsedCfg.Server.Compression.Level) + } + + // 验证性能配置一致性 + if parsedCfg.Performance.GoroutinePool.MaxWorkers != cfg.Performance.GoroutinePool.MaxWorkers { + t.Errorf("解析后 GoroutinePool.MaxWorkers 不一致: 期望 %d, 实际 %d", + cfg.Performance.GoroutinePool.MaxWorkers, parsedCfg.Performance.GoroutinePool.MaxWorkers) + } + if parsedCfg.Performance.FileCache.MaxEntries != cfg.Performance.FileCache.MaxEntries { + t.Errorf("解析后 FileCache.MaxEntries 不一致: 期望 %d, 实际 %d", + cfg.Performance.FileCache.MaxEntries, parsedCfg.Performance.FileCache.MaxEntries) + } + + // 验证时间.Duration 字段正确解析 + if parsedCfg.Performance.GoroutinePool.IdleTimeout != cfg.Performance.GoroutinePool.IdleTimeout { + t.Errorf("解析后 GoroutinePool.IdleTimeout 不一致: 期望 %v, 实际 %v", + cfg.Performance.GoroutinePool.IdleTimeout, parsedCfg.Performance.GoroutinePool.IdleTimeout) + } + if parsedCfg.Performance.FileCache.Inactive != cfg.Performance.FileCache.Inactive { + t.Errorf("解析后 FileCache.Inactive 不一致: 期望 %v, 实际 %v", + cfg.Performance.FileCache.Inactive, parsedCfg.Performance.FileCache.Inactive) + } +} + +func TestGenerateSimpleYAML(t *testing.T) { + cfg := DefaultConfig() + + yamlData, err := GenerateSimpleYAML(cfg) + if err != nil { + t.Fatalf("GenerateSimpleYAML 返回错误: %v", err) + } + + if len(yamlData) == 0 { + t.Error("GenerateSimpleYAML 输出为空") + } + + // 验证不包含注释(简洁 YAML) + yamlStr := string(yamlData) + if strings.Contains(yamlStr, "# Lolly 配置文件") { + t.Error("简洁 YAML 不应包含文件头注释") + } +} + +func TestDefaultConfigPerformance(t *testing.T) { + cfg := DefaultConfig() + + // 验证 GoroutinePool 默认值 + if cfg.Performance.GoroutinePool.Enabled { + t.Errorf("GoroutinePool.Enabled 期望 false, 实际 %v", cfg.Performance.GoroutinePool.Enabled) + } + if cfg.Performance.GoroutinePool.MaxWorkers != 1000 { + t.Errorf("GoroutinePool.MaxWorkers 期望 1000, 实际 %d", cfg.Performance.GoroutinePool.MaxWorkers) + } + if cfg.Performance.GoroutinePool.MinWorkers != 10 { + t.Errorf("GoroutinePool.MinWorkers 期望 10, 实际 %d", cfg.Performance.GoroutinePool.MinWorkers) + } + if cfg.Performance.GoroutinePool.IdleTimeout != 60*time.Second { + t.Errorf("GoroutinePool.IdleTimeout 期望 60s, 实际 %v", cfg.Performance.GoroutinePool.IdleTimeout) + } + + // 验证 FileCache 默认值 + if cfg.Performance.FileCache.MaxEntries != 10000 { + t.Errorf("FileCache.MaxEntries 期望 10000, 实际 %d", cfg.Performance.FileCache.MaxEntries) + } + if cfg.Performance.FileCache.MaxSize != 256*1024*1024 { + t.Errorf("FileCache.MaxSize 期望 256MB, 实际 %d", cfg.Performance.FileCache.MaxSize) + } + if cfg.Performance.FileCache.Inactive != 20*time.Second { + t.Errorf("FileCache.Inactive 期望 20s, 实际 %v", cfg.Performance.FileCache.Inactive) + } + if !cfg.Performance.FileCache.LRUEviction { + t.Errorf("FileCache.LRUEviction 期望 true, 实际 %v", cfg.Performance.FileCache.LRUEviction) + } + + // 验证 Transport 默认值 + if cfg.Performance.Transport.MaxIdleConns != 100 { + t.Errorf("Transport.MaxIdleConns 期望 100, 实际 %d", cfg.Performance.Transport.MaxIdleConns) + } + if cfg.Performance.Transport.MaxIdleConnsPerHost != 32 { + t.Errorf("Transport.MaxIdleConnsPerHost 期望 32, 实际 %d", cfg.Performance.Transport.MaxIdleConnsPerHost) + } + if cfg.Performance.Transport.IdleConnTimeout != 90*time.Second { + t.Errorf("Transport.IdleConnTimeout 期望 90s, 实际 %v", cfg.Performance.Transport.IdleConnTimeout) + } + if cfg.Performance.Transport.MaxConnsPerHost != 0 { + t.Errorf("Transport.MaxConnsPerHost 期望 0 (不限制), 实际 %d", cfg.Performance.Transport.MaxConnsPerHost) + } +} \ No newline at end of file diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go new file mode 100644 index 0000000..5ac5e8c --- /dev/null +++ b/internal/config/validate_test.go @@ -0,0 +1,813 @@ +// Package config 提供 YAML 配置文件的解析、验证和默认配置生成功能。 +package config + +import ( + "strings" + "testing" +) + +func TestValidateServer(t *testing.T) { + tests := []struct { + name string + config ServerConfig + isDefault bool + wantErr bool + errMsg string + }{ + { + name: "有效配置", + config: ServerConfig{ + Listen: ":8080", + Static: StaticConfig{Root: "/var/www"}, + Proxy: []ProxyConfig{ + {Path: "/api", Targets: []ProxyTarget{{URL: "http://backend:8080"}}}, + }, + }, + isDefault: false, + wantErr: false, + }, + { + name: "默认服务器可省略Listen", + config: ServerConfig{ + Static: StaticConfig{Root: "/var/www"}, + }, + isDefault: true, + wantErr: false, + }, + { + name: "非默认服务器Listen缺失", + config: ServerConfig{ + Static: StaticConfig{Root: "/var/www"}, + }, + isDefault: false, + wantErr: true, + errMsg: "listen 地址必填", + }, + { + name: "无效Listen地址", + config: ServerConfig{ + Listen: "invalid:address:format", + }, + isDefault: false, + wantErr: true, + errMsg: "无效的监听地址", + }, + { + name: "静态根目录含..", + config: ServerConfig{ + Listen: ":8080", + Static: StaticConfig{Root: "/var/../www"}, + }, + isDefault: false, + wantErr: true, + errMsg: "根目录路径不能包含 '..'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateServer(&tt.config, tt.isDefault) + if tt.wantErr { + if err == nil { + t.Errorf("validateServer() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateServer() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateServer() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateProxy(t *testing.T) { + tests := []struct { + name string + config ProxyConfig + wantErr bool + errMsg string + }{ + { + name: "有效代理配置", + config: ProxyConfig{ + Path: "/api", + Targets: []ProxyTarget{{URL: "http://backend:8080"}}, + }, + wantErr: false, + }, + { + name: "有效代理带负载均衡", + config: ProxyConfig{ + Path: "/api", + Targets: []ProxyTarget{{URL: "http://backend:8080"}}, + LoadBalance: "round_robin", + }, + wantErr: false, + }, + { + name: "Path缺失", + config: ProxyConfig{ + Targets: []ProxyTarget{{URL: "http://backend:8080"}}, + }, + wantErr: true, + errMsg: "path 必填", + }, + { + name: "Targets空", + config: ProxyConfig{ + Path: "/api", + Targets: []ProxyTarget{}, + }, + wantErr: true, + errMsg: "targets 至少需要一个目标地址", + }, + { + name: "URL格式错误-无协议", + config: ProxyConfig{ + Path: "/api", + Targets: []ProxyTarget{{URL: "backend:8080"}}, + }, + wantErr: true, + errMsg: "必须以 http:// 或 https:// 开头", + }, + { + name: "URL格式错误-空URL", + config: ProxyConfig{ + Path: "/api", + Targets: []ProxyTarget{{URL: ""}}, + }, + wantErr: true, + errMsg: "url 必填", + }, + { + name: "无效负载均衡算法", + config: ProxyConfig{ + Path: "/api", + Targets: []ProxyTarget{{URL: "http://backend:8080"}}, + LoadBalance: "invalid_algorithm", + }, + wantErr: true, + errMsg: "无效的负载均衡算法", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateProxy(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateProxy() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateProxy() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateProxy() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateSSL(t *testing.T) { + tests := []struct { + name string + config SSLConfig + wantErr bool + errMsg string + }{ + { + name: "未配置SSL", + config: SSLConfig{}, + wantErr: false, + }, + { + name: "有效SSL配置", + config: SSLConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + }, + wantErr: false, + }, + { + name: "仅Cert配置", + config: SSLConfig{ + Cert: "/path/to/cert.pem", + }, + wantErr: true, + errMsg: "cert 和 key 必须同时配置", + }, + { + name: "仅Key配置", + config: SSLConfig{ + Key: "/path/to/key.pem", + }, + wantErr: true, + errMsg: "cert 和 key 必须同时配置", + }, + { + name: "TLSv1.0不安全", + config: SSLConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + Protocols: []string{"TLSv1.0"}, + }, + wantErr: true, + errMsg: "不安全的 TLS 版本: TLSv1.0", + }, + { + name: "TLSv1.1不安全", + config: SSLConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + Protocols: []string{"TLSv1.1"}, + }, + wantErr: true, + errMsg: "不安全的 TLS 版本: TLSv1.1", + }, + { + name: "不安全加密套件RC4", + config: SSLConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + Protocols: []string{"TLSv1.2"}, + Ciphers: []string{"RC4-SHA"}, + }, + wantErr: true, + errMsg: "不安全的加密套件", + }, + { + name: "不安全加密套件DES", + config: SSLConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + Protocols: []string{"TLSv1.2"}, + Ciphers: []string{"DES-CBC3-SHA"}, + }, + wantErr: true, + errMsg: "不安全的加密套件", + }, + { + name: "未知TLS版本", + config: SSLConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + Protocols: []string{"TLSv1.4"}, + }, + wantErr: true, + errMsg: "未知的 TLS 版本", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSSL(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateSSL() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateSSL() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateSSL() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateAuth(t *testing.T) { + tests := []struct { + name string + config AuthConfig + wantErr bool + errMsg string + }{ + { + name: "未配置认证", + config: AuthConfig{}, + wantErr: false, + }, + { + name: "有效Basic认证配置", + config: AuthConfig{ + Type: "basic", + Algorithm: "bcrypt", + Users: []User{{Name: "admin", Password: "hashed_password"}}, + }, + wantErr: false, + }, + { + name: "无效认证类型", + config: AuthConfig{ + Type: "oauth", + Users: []User{{Name: "admin", Password: "hashed_password"}}, + }, + wantErr: true, + errMsg: "不支持的认证类型", + }, + { + name: "启用认证但无用户", + config: AuthConfig{ + Type: "basic", + Algorithm: "bcrypt", + Users: []User{}, + }, + wantErr: true, + errMsg: "启用认证时至少需要一个用户", + }, + { + name: "用户名缺失", + config: AuthConfig{ + Type: "basic", + Algorithm: "bcrypt", + Users: []User{{Name: "", Password: "hashed_password"}}, + }, + wantErr: true, + errMsg: "name 必填", + }, + { + name: "密码缺失", + config: AuthConfig{ + Type: "basic", + Algorithm: "bcrypt", + Users: []User{{Name: "admin", Password: ""}}, + }, + wantErr: true, + errMsg: "password 必填", + }, + { + name: "无效哈希算法", + config: AuthConfig{ + Type: "basic", + Algorithm: "md5", + Users: []User{{Name: "admin", Password: "hashed_password"}}, + }, + wantErr: true, + errMsg: "不支持的哈希算法", + }, + { + name: "空算法默认有效", + config: AuthConfig{ + Type: "basic", + Users: []User{{Name: "admin", Password: "hashed_password"}}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateAuth(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateAuth() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateAuth() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateAuth() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateRateLimit(t *testing.T) { + tests := []struct { + name string + config RateLimitConfig + wantErr bool + errMsg string + }{ + { + name: "未配置速率限制", + config: RateLimitConfig{}, + wantErr: false, + }, + { + name: "有效速率限制配置", + config: RateLimitConfig{ + RequestRate: 100, + Burst: 20, + Key: "ip", + }, + wantErr: false, + }, + { + name: "负数RequestRate", + config: RateLimitConfig{ + RequestRate: -1, + }, + wantErr: true, + errMsg: "request_rate 不能为负数", + }, + { + name: "负数Burst", + config: RateLimitConfig{ + RequestRate: 100, + Burst: -1, + }, + wantErr: true, + errMsg: "burst 不能为负数", + }, + { + name: "负数ConnLimit", + config: RateLimitConfig{ + ConnLimit: -1, + }, + wantErr: true, + errMsg: "conn_limit 不能为负数", + }, + { + name: "无效Key来源", + config: RateLimitConfig{ + RequestRate: 100, + Key: "invalid_key", + }, + wantErr: true, + errMsg: "无效的 key 来源", + }, + { + name: "仅ConnLimit配置", + config: RateLimitConfig{ + ConnLimit: 10, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateRateLimit(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateRateLimit() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateRateLimit() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateRateLimit() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateCompression(t *testing.T) { + tests := []struct { + name string + config CompressionConfig + wantErr bool + errMsg string + }{ + { + name: "未配置压缩", + config: CompressionConfig{}, + wantErr: false, + }, + { + name: "有效gzip压缩配置", + config: CompressionConfig{ + Type: "gzip", + Level: 6, + MinSize: 1024, + }, + wantErr: false, + }, + { + name: "有效brotli压缩配置", + config: CompressionConfig{ + Type: "brotli", + Level: 4, + MinSize: 512, + }, + wantErr: false, + }, + { + name: "无效压缩类型", + config: CompressionConfig{ + Type: "lz4", + }, + wantErr: true, + errMsg: "无效的压缩类型", + }, + { + name: "级别过低", + config: CompressionConfig{ + Type: "gzip", + Level: -1, + }, + wantErr: true, + errMsg: "无效的压缩级别", + }, + { + name: "级别过高", + config: CompressionConfig{ + Type: "gzip", + Level: 10, + }, + wantErr: true, + errMsg: "无效的压缩级别", + }, + { + name: "负数MinSize", + config: CompressionConfig{ + Type: "gzip", + MinSize: -100, + }, + wantErr: true, + errMsg: "min_size 不能为负数", + }, + { + name: "级别0有效", + config: CompressionConfig{ + Type: "gzip", + Level: 0, + }, + wantErr: false, + }, + { + name: "级别9有效", + config: CompressionConfig{ + Type: "gzip", + Level: 9, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCompression(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateCompression() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateCompression() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateCompression() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateAccess(t *testing.T) { + tests := []struct { + name string + config AccessConfig + wantErr bool + errMsg string + }{ + { + name: "空配置有效", + config: AccessConfig{}, + wantErr: false, + }, + { + name: "有效CIDR", + config: AccessConfig{ + Allow: []string{"192.168.1.0/24", "10.0.0.0/8"}, + }, + wantErr: false, + }, + { + name: "有效单个IP", + config: AccessConfig{ + Allow: []string{"192.168.1.100"}, + Deny: []string{"10.0.0.1"}, + }, + wantErr: false, + }, + { + name: "有效IPv6 CIDR", + config: AccessConfig{ + Allow: []string{"2001:db8::/32"}, + }, + wantErr: false, + }, + { + name: "有效IPv6地址", + config: AccessConfig{ + Allow: []string{"::1", "2001:db8::1"}, + }, + wantErr: false, + }, + { + name: "无效CIDR格式", + config: AccessConfig{ + Allow: []string{"invalid-cidr"}, + }, + wantErr: true, + errMsg: "无效的 allow CIDR/IP", + }, + { + name: "无效Deny CIDR", + config: AccessConfig{ + Deny: []string{"not-a-cidr"}, + }, + wantErr: true, + errMsg: "无效的 deny CIDR/IP", + }, + { + name: "有效默认动作allow", + config: AccessConfig{ + Default: "allow", + }, + wantErr: false, + }, + { + name: "有效默认动作deny", + config: AccessConfig{ + Default: "deny", + }, + wantErr: false, + }, + { + name: "无效默认动作", + config: AccessConfig{ + Default: "reject", + }, + wantErr: true, + errMsg: "无效的 default 动作", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateAccess(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateAccess() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateAccess() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateAccess() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateStatic(t *testing.T) { + tests := []struct { + name string + config StaticConfig + wantErr bool + errMsg string + }{ + { + name: "空配置有效", + config: StaticConfig{}, + wantErr: false, + }, + { + name: "有效根目录", + config: StaticConfig{ + Root: "/var/www/html", + }, + wantErr: false, + }, + { + name: "根目录含..路径遍历", + config: StaticConfig{ + Root: "/var/www/../etc", + }, + wantErr: true, + errMsg: "根目录路径不能包含 '..'", + }, + { + name: "根目录含多个..", + config: StaticConfig{ + Root: "/var/../www/../html", + }, + wantErr: true, + errMsg: "根目录路径不能包含 '..'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateStatic(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateStatic() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateStatic() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateStatic() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} + +func TestValidateSecurity(t *testing.T) { + tests := []struct { + name string + config SecurityConfig + wantErr bool + errMsg string + }{ + { + name: "空配置有效", + config: SecurityConfig{}, + wantErr: false, + }, + { + name: "有效安全配置", + config: SecurityConfig{ + Access: AccessConfig{ + Allow: []string{"192.168.1.0/24"}, + }, + Auth: AuthConfig{ + Type: "basic", + Users: []User{{Name: "admin", Password: "hashed"}}, + }, + RateLimit: RateLimitConfig{ + RequestRate: 100, + }, + }, + wantErr: false, + }, + { + name: "无效Access配置", + config: SecurityConfig{ + Access: AccessConfig{ + Allow: []string{"invalid-ip"}, + }, + }, + wantErr: true, + errMsg: "无效的 allow CIDR/IP", + }, + { + name: "无效Auth配置", + config: SecurityConfig{ + Auth: AuthConfig{ + Type: "invalid", + Users: []User{{Name: "admin", Password: "hashed"}}, + }, + }, + wantErr: true, + errMsg: "不支持的认证类型", + }, + { + name: "无效RateLimit配置", + config: SecurityConfig{ + RateLimit: RateLimitConfig{ + RequestRate: -1, + }, + }, + wantErr: true, + errMsg: "request_rate 不能为负数", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSecurity(&tt.config) + if tt.wantErr { + if err == nil { + t.Errorf("validateSecurity() 期望返回错误,但返回 nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("validateSecurity() 错误消息不匹配,期望包含 %q,实际 %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("validateSecurity() 期望返回 nil,但返回错误: %v", err) + } + } + }) + } +} \ No newline at end of file