test(config): 添加配置模块单元测试

- 添加 internal/app 包测试(版本显示、配置生成)
- 添加 internal/config 包测试(加载、保存、验证、默认值)
- 更新 docs/plan.md 日志系统设计(选用 zerolog)
- 更新 .gitignore 添加 coverage.html

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-02 14:40:56 +08:00
parent f3e5aad21e
commit 06e8d55ef5
6 changed files with 1781 additions and 34 deletions

3
.gitignore vendored
View File

@ -59,4 +59,5 @@ temp/
CLAUDE.md
lolly.yaml
config.yaml
lolly
lolly
coverage.html

View File

@ -286,35 +286,62 @@ type VHostManager struct {
**原因**:调试 Phase 2-4 功能需要日志支持,将日志系统基础版本提前实现。
**选型**:使用 [zerolog](https://github.com/rs/zerolog) 作为日志库。
**选择理由**
- **零分配**:高并发场景 GC 压力最小,性能最优
- **JSON 输出**便于日志采集系统ELK、Loki解析
- **API 简洁**:链式调用风格,开发体验好
- **灵活输出**:支持 stdout/stderr/文件,开发模式可选 ConsoleWriter 美化
**性能对比**10 条日志,禁用输出):
| 库 | ns/op | allocs/op |
|----|-------|-----------|
| zerolog | ~40ns | **0** |
| zap (structured) | ~50ns | 0 |
| slog (Go 1.21+) | ~200ns | 5+ |
| logrus | ~2000ns | 23 |
**实现**
```go
// internal/logging/logging.go
// Logger 日志管理器
type Logger struct {
level LogLevel
output io.Writer
import "github.com/rs/zerolog"
// 全局日志实例
var log zerolog.Logger
// Init 初始化日志系统
func Init(level string, pretty bool) {
l := parseLevel(level)
if pretty {
// 开发模式:带颜色和格式化(性能较差,仅开发用)
log = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout})
} else {
// 生产模式JSON 输出
log = zerolog.New(os.Stdout).Level(l).With().Timestamp().Logger()
}
}
// LogLevel 日志级别
type LogLevel int
const (
LogLevelDebug LogLevel = iota
LogLevelInfo
LogLevelWarn
LogLevelError
)
// AccessLogger 访问日志(基础版)
func LogAccess(r *http.Request, status int, size int64, duration time.Duration)
func LogAccess(r *http.Request, status int, size int64, duration time.Duration) {
log.Info().
Str("method", r.Method).
Str("path", r.URL.Path).
Int("status", status).
Int64("size", size).
Dur("duration", duration).
Msg("request")
}
```
**Phase 2 实现范围**
- 基础请求日志:记录请求方法、路径、状态码
- 控制台输出:开发阶段便于调试
- Phase 5 将扩展为完整日志系统(文件输出、自定义格式)
- 控制台输出:开发阶段便于调试ConsoleWriter 美化)
- Phase 5 将扩展为完整日志系统(文件输出、自定义格式、访问/错误日志分离
### 验证方法
@ -906,42 +933,76 @@ cache:
#### 5.4 日志系统
**扩展 Phase 2 的 zerolog 实现**,增加文件输出和访问/错误日志分离。
**实现**
```go
// internal/logging/logging.go
import (
"io"
"github.com/rs/zerolog"
)
// Logger 日志管理器
type Logger struct {
accessLog *AccessLogger
errorLog *ErrorLogger
level LogLevel
accessLog zerolog.Logger // 访问日志
errorLog zerolog.Logger // 错误日志
}
// AccessLogger 访问日志
type AccessLogger struct {
format string // 日志格式
output io.Writer
// New 创建日志管理器
func New(cfg *LoggingConfig) *Logger {
// 访问日志stdout 或文件
accessOut := getOutput(cfg.Access.Path)
accessLog := zerolog.New(accessOut).With().Timestamp().Logger()
// 错误日志stderr 或文件
errorOut := getOutput(cfg.Error.Path)
errorLevel := parseLevel(cfg.Error.Level)
errorLog := zerolog.New(errorOut).Level(errorLevel).With().Timestamp().Logger()
return &Logger{accessLog: accessLog, errorLog: errorLog}
}
// LogFormat 日志格式变量
// $remote_addr - 客户端 IP
// $request - 请求行
// $status - 响应状态码
// $body_bytes_sent - 响应体大小
// $request_time - 请求耗时
// LogAccess 记录访问日志nginx 格式变量)
func (l *Logger) LogAccess(r *http.Request, status int, size int64, duration time.Duration) {
l.accessLog.Info().
Str("remote_addr", r.RemoteAddr).
Str("request", fmt.Sprintf("%s %s", r.Method, r.URL.Path)).
Int("status", status).
Int64("body_bytes_sent", size).
Dur("request_time", duration).
Msg("")
}
// getOutput 获取输出目标stdout/stderr/文件)
func getOutput(path string) io.Writer {
if path == "" {
return os.Stdout
}
f, _ := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
return f
}
```
**日志格式变量**(支持 nginx 风格配置):
- `$remote_addr` - 客户端 IP
- `$request` - 请求行(方法 + 路径)
- `$status` - 响应状态码
- `$body_bytes_sent` - 响应体大小
- `$request_time` - 请求耗时
**配置示例**
```yaml
logging:
access:
path: /var/log/lolly/access.log
format: "$remote_addr - $request - $status - $body_bytes_sent"
path: /var/log/lolly/access.log # 留空则输出到 stdout
format: json # json 或 textPhase 2 ConsoleWriter
error:
path: /var/log/lolly/error.log
level: info # debug/info/warn/error
path: /var/log/lolly/error.log # 留空则输出到 stderr
level: info # debug/info/warn/error
```
#### 5.5 状态监控端点

237
internal/app/app_test.go Normal file
View File

@ -0,0 +1,237 @@
// Package app 提供应用程序的启动和运行逻辑。
package app
import (
"bytes"
"os"
"path/filepath"
"strings"
"testing"
)
// captureStdout 捕获 stdout 输出,返回捕获的内容和恢复函数。
func captureStdout(t *testing.T) (func() string, func()) {
t.Helper()
old := os.Stdout
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("创建 pipe 失败: %v", err)
}
os.Stdout = w
return func() string {
w.Close()
os.Stdout = old
var buf bytes.Buffer
buf.ReadFrom(r)
return buf.String()
}, func() {
w.Close()
os.Stdout = old
}
}
// captureStderr 捕获 stderr 输出,返回捕获的内容和恢复函数。
func captureStderr(t *testing.T) (func() string, func()) {
t.Helper()
old := os.Stderr
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("创建 pipe 失败: %v", err)
}
os.Stderr = w
return func() string {
w.Close()
os.Stderr = old
var buf bytes.Buffer
buf.ReadFrom(r)
return buf.String()
}, func() {
w.Close()
os.Stderr = old
}
}
// TestRun 测试 Run 函数的各种场景。
func TestRun(t *testing.T) {
tests := []struct {
name string
cfgPath string
genConfig bool
outputPath string
showVersion bool
wantExitCode int
wantContains string // stdout 应包含的内容
wantErrContains string // stderr 应包含的内容(可选)
}{
{
name: "显示版本",
showVersion: true,
wantExitCode: 0,
wantContains: "lolly version",
},
{
name: "生成配置输出到 stdout",
genConfig: true,
outputPath: "",
wantExitCode: 0,
wantContains: "server:",
},
{
name: "生成配置输出到文件",
genConfig: true,
outputPath: filepath.Join(t.TempDir(), "config.yaml"),
wantExitCode: 0,
wantContains: "配置已写入:",
},
{
name: "配置文件不存在",
cfgPath: filepath.Join(t.TempDir(), "nonexistent.yaml"),
genConfig: false,
showVersion: false,
wantExitCode: 1,
wantErrContains: "加载配置失败",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
getStdout, restoreStdout := captureStdout(t)
getStderr, restoreStderr := captureStderr(t)
exitCode := Run(tt.cfgPath, tt.genConfig, tt.outputPath, tt.showVersion)
restoreStderr()
restoreStdout()
stdout := getStdout()
stderr := getStderr()
if exitCode != tt.wantExitCode {
t.Errorf("exit code = %d, want %d", exitCode, tt.wantExitCode)
}
if tt.wantContains != "" && !strings.Contains(stdout, tt.wantContains) {
t.Errorf("stdout 应包含 %q, 实际输出: %q", tt.wantContains, stdout)
}
if tt.wantErrContains != "" && !strings.Contains(stderr, tt.wantErrContains) {
t.Errorf("stderr 应包含 %q, 实际输出: %q", tt.wantErrContains, stderr)
}
// 验证生成配置文件的内容
if tt.outputPath != "" && tt.genConfig && exitCode == 0 {
data, err := os.ReadFile(tt.outputPath)
if err != nil {
t.Errorf("读取生成的配置文件失败: %v", err)
} else if !strings.Contains(string(data), "server:") {
t.Errorf("生成的配置文件应包含 'server:', 实际内容: %s", string(data)[:100])
}
}
})
}
}
// TestGenerateConfig 测试 generateConfig 函数。
func TestGenerateConfig(t *testing.T) {
t.Run("输出到 stdout", func(t *testing.T) {
getStdout, restoreStdout := captureStdout(t)
exitCode := generateConfig("")
restoreStdout()
stdout := getStdout()
if exitCode != 0 {
t.Errorf("exit code = %d, want 0", exitCode)
}
// 验证输出包含基本配置结构
expectedFields := []string{"server:", "listen:", "logging:", "performance:", "monitoring:"}
for _, field := range expectedFields {
if !strings.Contains(stdout, field) {
t.Errorf("输出应包含 %q", field)
}
}
})
t.Run("输出到文件", func(t *testing.T) {
tmpDir := t.TempDir()
outputPath := filepath.Join(tmpDir, "test-config.yaml")
getStdout, restoreStdout := captureStdout(t)
exitCode := generateConfig(outputPath)
restoreStdout()
stdout := getStdout()
if exitCode != 0 {
t.Errorf("exit code = %d, want 0", exitCode)
}
if !strings.Contains(stdout, outputPath) {
t.Errorf("stdout 应包含文件路径 %q, 实际输出: %q", outputPath, stdout)
}
// 验证文件存在且内容正确
data, err := os.ReadFile(outputPath)
if err != nil {
t.Fatalf("读取生成的配置文件失败: %v", err)
}
content := string(data)
expectedFields := []string{"server:", "listen:", "logging:", "performance:", "monitoring:"}
for _, field := range expectedFields {
if !strings.Contains(content, field) {
t.Errorf("配置文件应包含 %q", field)
}
}
})
t.Run("输出到无效路径", func(t *testing.T) {
// 使用一个无法写入的路径(如根目录下的文件)
invalidPath := "/root/cannot-write-here.yaml"
getStderr, restoreStderr := captureStderr(t)
exitCode := generateConfig(invalidPath)
restoreStderr()
stderr := getStderr()
if exitCode != 1 {
t.Errorf("exit code = %d, want 1", exitCode)
}
if !strings.Contains(stderr, "写入文件失败") {
t.Errorf("stderr 应包含 '写入文件失败', 实际输出: %q", stderr)
}
})
}
// TestPrintVersion 测试 printVersion 函数。
func TestPrintVersion(t *testing.T) {
getStdout, restoreStdout := captureStdout(t)
printVersion()
restoreStdout()
stdout := getStdout()
// 验证版本输出格式
expectedLines := []string{
"lolly version",
"Git:",
"Built:",
"Go:",
"Platform:",
}
for _, line := range expectedLines {
if !strings.Contains(stdout, line) {
t.Errorf("版本输出应包含 %q, 实际输出: %q", line, stdout)
}
}
}

View File

@ -0,0 +1,439 @@
// Package config 提供配置加载和管理的测试。
package config
import (
"os"
"path/filepath"
"testing"
)
// TestLoad 测试从文件加载配置。
func TestLoad(t *testing.T) {
t.Run("有效配置文件", func(t *testing.T) {
// 创建临时配置文件
content := `
server:
listen: ":8080"
static:
root: "/var/www"
index:
- "index.html"
logging:
access:
path: "/var/log/access.log"
format: "combined"
error:
path: "/var/log/error.log"
level: "info"
performance:
goroutine_pool:
enabled: true
max_workers: 100
file_cache:
max_entries: 1000
monitoring:
status:
path: "/status"
`
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
t.Fatalf("创建临时配置文件失败: %v", err)
}
cfg, err := Load(tmpFile)
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
if cfg.Server.Listen != ":8080" {
t.Errorf("Server.Listen = %q, want %q", cfg.Server.Listen, ":8080")
}
if cfg.Server.Static.Root != "/var/www" {
t.Errorf("Server.Static.Root = %q, want %q", cfg.Server.Static.Root, "/var/www")
}
if len(cfg.Server.Static.Index) != 1 || cfg.Server.Static.Index[0] != "index.html" {
t.Errorf("Server.Static.Index = %v, want [index.html]", cfg.Server.Static.Index)
}
})
t.Run("文件不存在", func(t *testing.T) {
_, err := Load("/nonexistent/path/config.yaml")
if err == nil {
t.Error("Load() 期望返回错误,但返回 nil")
}
})
t.Run("无效YAML", func(t *testing.T) {
content := `
server:
listen: ":8080"
static:
root: [invalid yaml structure
`
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "invalid.yaml")
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
t.Fatalf("创建临时配置文件失败: %v", err)
}
_, err := Load(tmpFile)
if err == nil {
t.Error("Load() 期望返回错误,但返回 nil")
}
})
t.Run("缺少必填字段(无服务器配置)", func(t *testing.T) {
content := `
logging:
access:
path: "/var/log/access.log"
`
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "no_server.yaml")
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
t.Fatalf("创建临时配置文件失败: %v", err)
}
_, err := Load(tmpFile)
if err == nil {
t.Error("Load() 期望返回错误,但返回 nil")
}
})
t.Run("多虚拟主机模式", func(t *testing.T) {
content := `
servers:
- listen: ":8080"
name: "server1"
- listen: ":8081"
name: "server2"
`
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "multi.yaml")
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
t.Fatalf("创建临时配置文件失败: %v", err)
}
cfg, err := Load(tmpFile)
if err != nil {
t.Fatalf("Load() 失败: %v", err)
}
if len(cfg.Servers) != 2 {
t.Fatalf("len(Servers) = %d, want 2", len(cfg.Servers))
}
if cfg.Servers[0].Name != "server1" {
t.Errorf("Servers[0].Name = %q, want %q", cfg.Servers[0].Name, "server1")
}
if cfg.Servers[1].Name != "server2" {
t.Errorf("Servers[1].Name = %q, want %q", cfg.Servers[1].Name, "server2")
}
})
}
// TestLoadFromString 测试从字符串加载配置。
func TestLoadFromString(t *testing.T) {
t.Run("有效字符串", func(t *testing.T) {
yamlStr := `
server:
listen: ":9090"
static:
root: "/app/public"
`
cfg, err := LoadFromString(yamlStr)
if err != nil {
t.Fatalf("LoadFromString() 失败: %v", err)
}
if cfg.Server.Listen != ":9090" {
t.Errorf("Server.Listen = %q, want %q", cfg.Server.Listen, ":9090")
}
if cfg.Server.Static.Root != "/app/public" {
t.Errorf("Server.Static.Root = %q, want %q", cfg.Server.Static.Root, "/app/public")
}
})
t.Run("无效YAML", func(t *testing.T) {
yamlStr := `
server:
listen: ":8080"
broken: [unclosed
`
_, err := LoadFromString(yamlStr)
if err == nil {
t.Error("LoadFromString() 期望返回错误,但返回 nil")
}
})
t.Run("缺少必填字段", func(t *testing.T) {
yamlStr := `
logging:
access:
path: "/var/log/access.log"
`
_, err := LoadFromString(yamlStr)
if err == nil {
t.Error("LoadFromString() 期望返回错误,但返回 nil")
}
})
t.Run("空字符串", func(t *testing.T) {
_, err := LoadFromString("")
if err == nil {
t.Error("LoadFromString() 期望返回错误,但返回 nil")
}
})
}
// TestSave 测试保存配置到文件。
func TestSave(t *testing.T) {
t.Run("正常保存", func(t *testing.T) {
cfg := &Config{
Server: ServerConfig{
Listen: ":8080",
Static: StaticConfig{
Root: "/var/www",
Index: []string{"index.html"},
},
},
}
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "saved_config.yaml")
if err := Save(cfg, tmpFile); err != nil {
t.Fatalf("Save() 失败: %v", err)
}
// 验证文件已创建并可重新加载
loaded, err := Load(tmpFile)
if err != nil {
t.Fatalf("重新加载配置失败: %v", err)
}
if loaded.Server.Listen != cfg.Server.Listen {
t.Errorf("loaded.Server.Listen = %q, want %q", loaded.Server.Listen, cfg.Server.Listen)
}
if loaded.Server.Static.Root != cfg.Server.Static.Root {
t.Errorf("loaded.Server.Static.Root = %q, want %q", loaded.Server.Static.Root, cfg.Server.Static.Root)
}
})
t.Run("无效路径", func(t *testing.T) {
cfg := &Config{
Server: ServerConfig{
Listen: ":8080",
},
}
err := Save(cfg, "/nonexistent/directory/config.yaml")
if err == nil {
t.Error("Save() 期望返回错误,但返回 nil")
}
})
t.Run("保存多虚拟主机配置", func(t *testing.T) {
cfg := &Config{
Servers: []ServerConfig{
{Listen: ":8080", Name: "server1"},
{Listen: ":8081", Name: "server2"},
},
}
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "multi.yaml")
if err := Save(cfg, tmpFile); err != nil {
t.Fatalf("Save() 失败: %v", err)
}
loaded, err := Load(tmpFile)
if err != nil {
t.Fatalf("重新加载配置失败: %v", err)
}
if len(loaded.Servers) != 2 {
t.Errorf("len(loaded.Servers) = %d, want 2", len(loaded.Servers))
}
})
t.Run("保存并加载完整配置", func(t *testing.T) {
cfg := &Config{
Server: ServerConfig{
Listen: ":8443",
Name: "default",
Static: StaticConfig{
Root: "/var/www/html",
Index: []string{"index.html", "index.htm"},
},
Proxy: []ProxyConfig{
{
Path: "/api",
Targets: []ProxyTarget{
{URL: "http://backend1:8080", Weight: 1},
{URL: "http://backend2:8080", Weight: 2},
},
LoadBalance: "round_robin",
},
},
SSL: SSLConfig{
Cert: "/etc/ssl/cert.pem",
Key: "/etc/ssl/key.pem",
Protocols: []string{"TLSv1.2", "TLSv1.3"},
},
Security: SecurityConfig{
RateLimit: RateLimitConfig{
RequestRate: 100,
Burst: 200,
},
},
},
Logging: LoggingConfig{
Access: AccessLogConfig{
Path: "/var/log/access.log",
Format: "combined",
},
Error: ErrorLogConfig{
Path: "/var/log/error.log",
Level: "warn",
},
},
Performance: PerformanceConfig{
GoroutinePool: GoroutinePoolConfig{
Enabled: true,
MaxWorkers: 1000,
MinWorkers: 10,
},
},
}
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "full.yaml")
if err := Save(cfg, tmpFile); err != nil {
t.Fatalf("Save() 失败: %v", err)
}
loaded, err := Load(tmpFile)
if err != nil {
t.Fatalf("重新加载配置失败: %v", err)
}
// 验证关键字段
if loaded.Server.Listen != cfg.Server.Listen {
t.Errorf("loaded.Server.Listen = %q, want %q", loaded.Server.Listen, cfg.Server.Listen)
}
if len(loaded.Server.Proxy) != 1 {
t.Errorf("len(loaded.Server.Proxy) = %d, want 1", len(loaded.Server.Proxy))
}
if loaded.Server.Proxy[0].LoadBalance != "round_robin" {
t.Errorf("loaded.Server.Proxy[0].LoadBalance = %q, want %q", loaded.Server.Proxy[0].LoadBalance, "round_robin")
}
})
}
// TestConfigMethods 测试 Config 结构体的方法。
func TestConfigMethods(t *testing.T) {
t.Run("HasServers_有服务器列表", func(t *testing.T) {
cfg := &Config{
Servers: []ServerConfig{
{Listen: ":8080"},
},
}
if !cfg.HasServers() {
t.Error("HasServers() = false, want true")
}
})
t.Run("HasServers_无服务器列表", func(t *testing.T) {
cfg := &Config{}
if cfg.HasServers() {
t.Error("HasServers() = true, want false")
}
})
t.Run("HasDefaultServer_有默认服务器", func(t *testing.T) {
cfg := &Config{
Server: ServerConfig{
Listen: ":8080",
},
}
if !cfg.HasDefaultServer() {
t.Error("HasDefaultServer() = false, want true")
}
})
t.Run("HasDefaultServer_无默认服务器", func(t *testing.T) {
cfg := &Config{}
if cfg.HasDefaultServer() {
t.Error("HasDefaultServer() = true, want false")
}
})
t.Run("GetDefaultServer_有默认服务器", func(t *testing.T) {
cfg := &Config{
Server: ServerConfig{
Listen: ":8080",
Name: "default",
},
}
server := cfg.GetDefaultServer()
if server == nil {
t.Fatal("GetDefaultServer() = nil, want non-nil")
}
if server.Listen != ":8080" {
t.Errorf("server.Listen = %q, want %q", server.Listen, ":8080")
}
})
t.Run("GetDefaultServer_无默认服务器", func(t *testing.T) {
cfg := &Config{}
server := cfg.GetDefaultServer()
if server != nil {
t.Errorf("GetDefaultServer() = %v, want nil", server)
}
})
t.Run("配置模式判断", func(t *testing.T) {
tests := []struct {
name string
cfg *Config
wantHasServers bool
wantHasDefault bool
}{
{
name: "仅默认服务器",
cfg: &Config{Server: ServerConfig{Listen: ":8080"}},
wantHasServers: false,
wantHasDefault: true,
},
{
name: "仅多虚拟主机",
cfg: &Config{Servers: []ServerConfig{{Listen: ":8080"}}},
wantHasServers: true,
wantHasDefault: false,
},
{
name: "混合模式",
cfg: &Config{
Server: ServerConfig{Listen: ":8080"},
Servers: []ServerConfig{{Listen: ":8081"}},
},
wantHasServers: true,
wantHasDefault: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.cfg.HasServers(); got != tt.wantHasServers {
t.Errorf("HasServers() = %v, want %v", got, tt.wantHasServers)
}
if got := tt.cfg.HasDefaultServer(); got != tt.wantHasDefault {
t.Errorf("HasDefaultServer() = %v, want %v", got, tt.wantHasDefault)
}
})
}
})
}

View File

@ -0,0 +1,196 @@
package config
import (
"strings"
"testing"
"time"
)
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
// 验证 Listen 默认值
if cfg.Server.Listen != ":8080" {
t.Errorf("Server.Listen 期望 :8080, 实际 %s", cfg.Server.Listen)
}
// 验证 SSL 默认版本
if len(cfg.Server.SSL.Protocols) != 2 {
t.Errorf("SSL.Protocols 期望 2 个版本, 实际 %d", len(cfg.Server.SSL.Protocols))
}
expectedProtocols := []string{"TLSv1.2", "TLSv1.3"}
for i, proto := range cfg.Server.SSL.Protocols {
if proto != expectedProtocols[i] {
t.Errorf("SSL.Protocols[%d] 期望 %s, 实际 %s", i, expectedProtocols[i], proto)
}
}
// 验证 HSTS 默认值
if cfg.Server.SSL.HSTS.MaxAge != 31536000 {
t.Errorf("HSTS.MaxAge 期望 31536000, 实际 %d", cfg.Server.SSL.HSTS.MaxAge)
}
if !cfg.Server.SSL.HSTS.IncludeSubDomains {
t.Errorf("HSTS.IncludeSubDomains 期望 true, 实际 %v", cfg.Server.SSL.HSTS.IncludeSubDomains)
}
if cfg.Server.SSL.HSTS.Preload {
t.Errorf("HSTS.Preload 期望 false, 实际 %v", cfg.Server.SSL.HSTS.Preload)
}
// 验证压缩默认值
if cfg.Server.Compression.Type != "gzip" {
t.Errorf("Compression.Type 期望 gzip, 实际 %s", cfg.Server.Compression.Type)
}
if cfg.Server.Compression.Level != 6 {
t.Errorf("Compression.Level 期望 6, 实际 %d", cfg.Server.Compression.Level)
}
if cfg.Server.Compression.MinSize != 1024 {
t.Errorf("Compression.MinSize 期望 1024, 实际 %d", cfg.Server.Compression.MinSize)
}
expectedTypes := []string{"text/html", "text/css", "text/javascript", "application/json", "application/javascript"}
for i, ct := range cfg.Server.Compression.Types {
if ct != expectedTypes[i] {
t.Errorf("Compression.Types[%d] 期望 %s, 实际 %s", i, expectedTypes[i], ct)
}
}
}
func TestGenerateConfigYAML(t *testing.T) {
cfg := DefaultConfig()
yamlData, err := GenerateConfigYAML(cfg)
if err != nil {
t.Fatalf("GenerateConfigYAML 返回错误: %v", err)
}
// 验证输出非空
if len(yamlData) == 0 {
t.Error("GenerateConfigYAML 输出为空")
}
yamlStr := string(yamlData)
// 验证包含注释
if !strings.Contains(yamlStr, "#") {
t.Error("YAML 输出未包含注释")
}
if !strings.Contains(yamlStr, "# Lolly 配置文件") {
t.Error("YAML 输出未包含文件头注释")
}
if !strings.Contains(yamlStr, "# 服务器配置") {
t.Error("YAML 输出未包含服务器配置注释")
}
// 验证可重新解析 - 使用 LoadFromString 解析生成的 YAML
// 注意GenerateConfigYAML 生成的 YAML 包含注释的示例配置(如 proxy、rewrite 等)
// 这些是注释掉的示例,不会被解析。需要提取实际生效的部分进行验证。
// 构建一个可解析的简化 YAML 进行验证
simpleYAML, err := GenerateSimpleYAML(cfg)
if err != nil {
t.Fatalf("GenerateSimpleYAML 返回错误: %v", err)
}
parsedCfg, err := LoadFromString(string(simpleYAML))
if err != nil {
t.Fatalf("解析生成的 YAML 失败: %v", err)
}
// 验证配置一致性
if parsedCfg.Server.Listen != cfg.Server.Listen {
t.Errorf("解析后 Server.Listen 不一致: 期望 %s, 实际 %s", cfg.Server.Listen, parsedCfg.Server.Listen)
}
if parsedCfg.Server.Name != cfg.Server.Name {
t.Errorf("解析后 Server.Name 不一致: 期望 %s, 实际 %s", cfg.Server.Name, parsedCfg.Server.Name)
}
if parsedCfg.Server.Compression.Type != cfg.Server.Compression.Type {
t.Errorf("解析后 Compression.Type 不一致: 期望 %s, 实际 %s", cfg.Server.Compression.Type, parsedCfg.Server.Compression.Type)
}
if parsedCfg.Server.Compression.Level != cfg.Server.Compression.Level {
t.Errorf("解析后 Compression.Level 不一致: 期望 %d, 实际 %d", cfg.Server.Compression.Level, parsedCfg.Server.Compression.Level)
}
// 验证性能配置一致性
if parsedCfg.Performance.GoroutinePool.MaxWorkers != cfg.Performance.GoroutinePool.MaxWorkers {
t.Errorf("解析后 GoroutinePool.MaxWorkers 不一致: 期望 %d, 实际 %d",
cfg.Performance.GoroutinePool.MaxWorkers, parsedCfg.Performance.GoroutinePool.MaxWorkers)
}
if parsedCfg.Performance.FileCache.MaxEntries != cfg.Performance.FileCache.MaxEntries {
t.Errorf("解析后 FileCache.MaxEntries 不一致: 期望 %d, 实际 %d",
cfg.Performance.FileCache.MaxEntries, parsedCfg.Performance.FileCache.MaxEntries)
}
// 验证时间.Duration 字段正确解析
if parsedCfg.Performance.GoroutinePool.IdleTimeout != cfg.Performance.GoroutinePool.IdleTimeout {
t.Errorf("解析后 GoroutinePool.IdleTimeout 不一致: 期望 %v, 实际 %v",
cfg.Performance.GoroutinePool.IdleTimeout, parsedCfg.Performance.GoroutinePool.IdleTimeout)
}
if parsedCfg.Performance.FileCache.Inactive != cfg.Performance.FileCache.Inactive {
t.Errorf("解析后 FileCache.Inactive 不一致: 期望 %v, 实际 %v",
cfg.Performance.FileCache.Inactive, parsedCfg.Performance.FileCache.Inactive)
}
}
func TestGenerateSimpleYAML(t *testing.T) {
cfg := DefaultConfig()
yamlData, err := GenerateSimpleYAML(cfg)
if err != nil {
t.Fatalf("GenerateSimpleYAML 返回错误: %v", err)
}
if len(yamlData) == 0 {
t.Error("GenerateSimpleYAML 输出为空")
}
// 验证不包含注释(简洁 YAML
yamlStr := string(yamlData)
if strings.Contains(yamlStr, "# Lolly 配置文件") {
t.Error("简洁 YAML 不应包含文件头注释")
}
}
func TestDefaultConfigPerformance(t *testing.T) {
cfg := DefaultConfig()
// 验证 GoroutinePool 默认值
if cfg.Performance.GoroutinePool.Enabled {
t.Errorf("GoroutinePool.Enabled 期望 false, 实际 %v", cfg.Performance.GoroutinePool.Enabled)
}
if cfg.Performance.GoroutinePool.MaxWorkers != 1000 {
t.Errorf("GoroutinePool.MaxWorkers 期望 1000, 实际 %d", cfg.Performance.GoroutinePool.MaxWorkers)
}
if cfg.Performance.GoroutinePool.MinWorkers != 10 {
t.Errorf("GoroutinePool.MinWorkers 期望 10, 实际 %d", cfg.Performance.GoroutinePool.MinWorkers)
}
if cfg.Performance.GoroutinePool.IdleTimeout != 60*time.Second {
t.Errorf("GoroutinePool.IdleTimeout 期望 60s, 实际 %v", cfg.Performance.GoroutinePool.IdleTimeout)
}
// 验证 FileCache 默认值
if cfg.Performance.FileCache.MaxEntries != 10000 {
t.Errorf("FileCache.MaxEntries 期望 10000, 实际 %d", cfg.Performance.FileCache.MaxEntries)
}
if cfg.Performance.FileCache.MaxSize != 256*1024*1024 {
t.Errorf("FileCache.MaxSize 期望 256MB, 实际 %d", cfg.Performance.FileCache.MaxSize)
}
if cfg.Performance.FileCache.Inactive != 20*time.Second {
t.Errorf("FileCache.Inactive 期望 20s, 实际 %v", cfg.Performance.FileCache.Inactive)
}
if !cfg.Performance.FileCache.LRUEviction {
t.Errorf("FileCache.LRUEviction 期望 true, 实际 %v", cfg.Performance.FileCache.LRUEviction)
}
// 验证 Transport 默认值
if cfg.Performance.Transport.MaxIdleConns != 100 {
t.Errorf("Transport.MaxIdleConns 期望 100, 实际 %d", cfg.Performance.Transport.MaxIdleConns)
}
if cfg.Performance.Transport.MaxIdleConnsPerHost != 32 {
t.Errorf("Transport.MaxIdleConnsPerHost 期望 32, 实际 %d", cfg.Performance.Transport.MaxIdleConnsPerHost)
}
if cfg.Performance.Transport.IdleConnTimeout != 90*time.Second {
t.Errorf("Transport.IdleConnTimeout 期望 90s, 实际 %v", cfg.Performance.Transport.IdleConnTimeout)
}
if cfg.Performance.Transport.MaxConnsPerHost != 0 {
t.Errorf("Transport.MaxConnsPerHost 期望 0 (不限制), 实际 %d", cfg.Performance.Transport.MaxConnsPerHost)
}
}

View File

@ -0,0 +1,813 @@
// Package config 提供 YAML 配置文件的解析、验证和默认配置生成功能。
package config
import (
"strings"
"testing"
)
func TestValidateServer(t *testing.T) {
tests := []struct {
name string
config ServerConfig
isDefault bool
wantErr bool
errMsg string
}{
{
name: "有效配置",
config: ServerConfig{
Listen: ":8080",
Static: StaticConfig{Root: "/var/www"},
Proxy: []ProxyConfig{
{Path: "/api", Targets: []ProxyTarget{{URL: "http://backend:8080"}}},
},
},
isDefault: false,
wantErr: false,
},
{
name: "默认服务器可省略Listen",
config: ServerConfig{
Static: StaticConfig{Root: "/var/www"},
},
isDefault: true,
wantErr: false,
},
{
name: "非默认服务器Listen缺失",
config: ServerConfig{
Static: StaticConfig{Root: "/var/www"},
},
isDefault: false,
wantErr: true,
errMsg: "listen 地址必填",
},
{
name: "无效Listen地址",
config: ServerConfig{
Listen: "invalid:address:format",
},
isDefault: false,
wantErr: true,
errMsg: "无效的监听地址",
},
{
name: "静态根目录含..",
config: ServerConfig{
Listen: ":8080",
Static: StaticConfig{Root: "/var/../www"},
},
isDefault: false,
wantErr: true,
errMsg: "根目录路径不能包含 '..'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateServer(&tt.config, tt.isDefault)
if tt.wantErr {
if err == nil {
t.Errorf("validateServer() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateServer() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateServer() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateProxy(t *testing.T) {
tests := []struct {
name string
config ProxyConfig
wantErr bool
errMsg string
}{
{
name: "有效代理配置",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
},
wantErr: false,
},
{
name: "有效代理带负载均衡",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "round_robin",
},
wantErr: false,
},
{
name: "Path缺失",
config: ProxyConfig{
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
},
wantErr: true,
errMsg: "path 必填",
},
{
name: "Targets空",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{},
},
wantErr: true,
errMsg: "targets 至少需要一个目标地址",
},
{
name: "URL格式错误-无协议",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "backend:8080"}},
},
wantErr: true,
errMsg: "必须以 http:// 或 https:// 开头",
},
{
name: "URL格式错误-空URL",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: ""}},
},
wantErr: true,
errMsg: "url 必填",
},
{
name: "无效负载均衡算法",
config: ProxyConfig{
Path: "/api",
Targets: []ProxyTarget{{URL: "http://backend:8080"}},
LoadBalance: "invalid_algorithm",
},
wantErr: true,
errMsg: "无效的负载均衡算法",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateProxy(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateProxy() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateProxy() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateProxy() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateSSL(t *testing.T) {
tests := []struct {
name string
config SSLConfig
wantErr bool
errMsg string
}{
{
name: "未配置SSL",
config: SSLConfig{},
wantErr: false,
},
{
name: "有效SSL配置",
config: SSLConfig{
Cert: "/path/to/cert.pem",
Key: "/path/to/key.pem",
Protocols: []string{"TLSv1.2", "TLSv1.3"},
},
wantErr: false,
},
{
name: "仅Cert配置",
config: SSLConfig{
Cert: "/path/to/cert.pem",
},
wantErr: true,
errMsg: "cert 和 key 必须同时配置",
},
{
name: "仅Key配置",
config: SSLConfig{
Key: "/path/to/key.pem",
},
wantErr: true,
errMsg: "cert 和 key 必须同时配置",
},
{
name: "TLSv1.0不安全",
config: SSLConfig{
Cert: "/path/to/cert.pem",
Key: "/path/to/key.pem",
Protocols: []string{"TLSv1.0"},
},
wantErr: true,
errMsg: "不安全的 TLS 版本: TLSv1.0",
},
{
name: "TLSv1.1不安全",
config: SSLConfig{
Cert: "/path/to/cert.pem",
Key: "/path/to/key.pem",
Protocols: []string{"TLSv1.1"},
},
wantErr: true,
errMsg: "不安全的 TLS 版本: TLSv1.1",
},
{
name: "不安全加密套件RC4",
config: SSLConfig{
Cert: "/path/to/cert.pem",
Key: "/path/to/key.pem",
Protocols: []string{"TLSv1.2"},
Ciphers: []string{"RC4-SHA"},
},
wantErr: true,
errMsg: "不安全的加密套件",
},
{
name: "不安全加密套件DES",
config: SSLConfig{
Cert: "/path/to/cert.pem",
Key: "/path/to/key.pem",
Protocols: []string{"TLSv1.2"},
Ciphers: []string{"DES-CBC3-SHA"},
},
wantErr: true,
errMsg: "不安全的加密套件",
},
{
name: "未知TLS版本",
config: SSLConfig{
Cert: "/path/to/cert.pem",
Key: "/path/to/key.pem",
Protocols: []string{"TLSv1.4"},
},
wantErr: true,
errMsg: "未知的 TLS 版本",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateSSL(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateSSL() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateSSL() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateSSL() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateAuth(t *testing.T) {
tests := []struct {
name string
config AuthConfig
wantErr bool
errMsg string
}{
{
name: "未配置认证",
config: AuthConfig{},
wantErr: false,
},
{
name: "有效Basic认证配置",
config: AuthConfig{
Type: "basic",
Algorithm: "bcrypt",
Users: []User{{Name: "admin", Password: "hashed_password"}},
},
wantErr: false,
},
{
name: "无效认证类型",
config: AuthConfig{
Type: "oauth",
Users: []User{{Name: "admin", Password: "hashed_password"}},
},
wantErr: true,
errMsg: "不支持的认证类型",
},
{
name: "启用认证但无用户",
config: AuthConfig{
Type: "basic",
Algorithm: "bcrypt",
Users: []User{},
},
wantErr: true,
errMsg: "启用认证时至少需要一个用户",
},
{
name: "用户名缺失",
config: AuthConfig{
Type: "basic",
Algorithm: "bcrypt",
Users: []User{{Name: "", Password: "hashed_password"}},
},
wantErr: true,
errMsg: "name 必填",
},
{
name: "密码缺失",
config: AuthConfig{
Type: "basic",
Algorithm: "bcrypt",
Users: []User{{Name: "admin", Password: ""}},
},
wantErr: true,
errMsg: "password 必填",
},
{
name: "无效哈希算法",
config: AuthConfig{
Type: "basic",
Algorithm: "md5",
Users: []User{{Name: "admin", Password: "hashed_password"}},
},
wantErr: true,
errMsg: "不支持的哈希算法",
},
{
name: "空算法默认有效",
config: AuthConfig{
Type: "basic",
Users: []User{{Name: "admin", Password: "hashed_password"}},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateAuth(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateAuth() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateAuth() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateAuth() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateRateLimit(t *testing.T) {
tests := []struct {
name string
config RateLimitConfig
wantErr bool
errMsg string
}{
{
name: "未配置速率限制",
config: RateLimitConfig{},
wantErr: false,
},
{
name: "有效速率限制配置",
config: RateLimitConfig{
RequestRate: 100,
Burst: 20,
Key: "ip",
},
wantErr: false,
},
{
name: "负数RequestRate",
config: RateLimitConfig{
RequestRate: -1,
},
wantErr: true,
errMsg: "request_rate 不能为负数",
},
{
name: "负数Burst",
config: RateLimitConfig{
RequestRate: 100,
Burst: -1,
},
wantErr: true,
errMsg: "burst 不能为负数",
},
{
name: "负数ConnLimit",
config: RateLimitConfig{
ConnLimit: -1,
},
wantErr: true,
errMsg: "conn_limit 不能为负数",
},
{
name: "无效Key来源",
config: RateLimitConfig{
RequestRate: 100,
Key: "invalid_key",
},
wantErr: true,
errMsg: "无效的 key 来源",
},
{
name: "仅ConnLimit配置",
config: RateLimitConfig{
ConnLimit: 10,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateRateLimit(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateRateLimit() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateRateLimit() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateRateLimit() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateCompression(t *testing.T) {
tests := []struct {
name string
config CompressionConfig
wantErr bool
errMsg string
}{
{
name: "未配置压缩",
config: CompressionConfig{},
wantErr: false,
},
{
name: "有效gzip压缩配置",
config: CompressionConfig{
Type: "gzip",
Level: 6,
MinSize: 1024,
},
wantErr: false,
},
{
name: "有效brotli压缩配置",
config: CompressionConfig{
Type: "brotli",
Level: 4,
MinSize: 512,
},
wantErr: false,
},
{
name: "无效压缩类型",
config: CompressionConfig{
Type: "lz4",
},
wantErr: true,
errMsg: "无效的压缩类型",
},
{
name: "级别过低",
config: CompressionConfig{
Type: "gzip",
Level: -1,
},
wantErr: true,
errMsg: "无效的压缩级别",
},
{
name: "级别过高",
config: CompressionConfig{
Type: "gzip",
Level: 10,
},
wantErr: true,
errMsg: "无效的压缩级别",
},
{
name: "负数MinSize",
config: CompressionConfig{
Type: "gzip",
MinSize: -100,
},
wantErr: true,
errMsg: "min_size 不能为负数",
},
{
name: "级别0有效",
config: CompressionConfig{
Type: "gzip",
Level: 0,
},
wantErr: false,
},
{
name: "级别9有效",
config: CompressionConfig{
Type: "gzip",
Level: 9,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateCompression(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateCompression() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateCompression() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateCompression() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateAccess(t *testing.T) {
tests := []struct {
name string
config AccessConfig
wantErr bool
errMsg string
}{
{
name: "空配置有效",
config: AccessConfig{},
wantErr: false,
},
{
name: "有效CIDR",
config: AccessConfig{
Allow: []string{"192.168.1.0/24", "10.0.0.0/8"},
},
wantErr: false,
},
{
name: "有效单个IP",
config: AccessConfig{
Allow: []string{"192.168.1.100"},
Deny: []string{"10.0.0.1"},
},
wantErr: false,
},
{
name: "有效IPv6 CIDR",
config: AccessConfig{
Allow: []string{"2001:db8::/32"},
},
wantErr: false,
},
{
name: "有效IPv6地址",
config: AccessConfig{
Allow: []string{"::1", "2001:db8::1"},
},
wantErr: false,
},
{
name: "无效CIDR格式",
config: AccessConfig{
Allow: []string{"invalid-cidr"},
},
wantErr: true,
errMsg: "无效的 allow CIDR/IP",
},
{
name: "无效Deny CIDR",
config: AccessConfig{
Deny: []string{"not-a-cidr"},
},
wantErr: true,
errMsg: "无效的 deny CIDR/IP",
},
{
name: "有效默认动作allow",
config: AccessConfig{
Default: "allow",
},
wantErr: false,
},
{
name: "有效默认动作deny",
config: AccessConfig{
Default: "deny",
},
wantErr: false,
},
{
name: "无效默认动作",
config: AccessConfig{
Default: "reject",
},
wantErr: true,
errMsg: "无效的 default 动作",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateAccess(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateAccess() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateAccess() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateAccess() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateStatic(t *testing.T) {
tests := []struct {
name string
config StaticConfig
wantErr bool
errMsg string
}{
{
name: "空配置有效",
config: StaticConfig{},
wantErr: false,
},
{
name: "有效根目录",
config: StaticConfig{
Root: "/var/www/html",
},
wantErr: false,
},
{
name: "根目录含..路径遍历",
config: StaticConfig{
Root: "/var/www/../etc",
},
wantErr: true,
errMsg: "根目录路径不能包含 '..'",
},
{
name: "根目录含多个..",
config: StaticConfig{
Root: "/var/../www/../html",
},
wantErr: true,
errMsg: "根目录路径不能包含 '..'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateStatic(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateStatic() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateStatic() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateStatic() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}
func TestValidateSecurity(t *testing.T) {
tests := []struct {
name string
config SecurityConfig
wantErr bool
errMsg string
}{
{
name: "空配置有效",
config: SecurityConfig{},
wantErr: false,
},
{
name: "有效安全配置",
config: SecurityConfig{
Access: AccessConfig{
Allow: []string{"192.168.1.0/24"},
},
Auth: AuthConfig{
Type: "basic",
Users: []User{{Name: "admin", Password: "hashed"}},
},
RateLimit: RateLimitConfig{
RequestRate: 100,
},
},
wantErr: false,
},
{
name: "无效Access配置",
config: SecurityConfig{
Access: AccessConfig{
Allow: []string{"invalid-ip"},
},
},
wantErr: true,
errMsg: "无效的 allow CIDR/IP",
},
{
name: "无效Auth配置",
config: SecurityConfig{
Auth: AuthConfig{
Type: "invalid",
Users: []User{{Name: "admin", Password: "hashed"}},
},
},
wantErr: true,
errMsg: "不支持的认证类型",
},
{
name: "无效RateLimit配置",
config: SecurityConfig{
RateLimit: RateLimitConfig{
RequestRate: -1,
},
},
wantErr: true,
errMsg: "request_rate 不能为负数",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateSecurity(&tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("validateSecurity() 期望返回错误,但返回 nil")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("validateSecurity() 错误消息不匹配,期望包含 %q实际 %q", tt.errMsg, err.Error())
}
} else {
if err != nil {
t.Errorf("validateSecurity() 期望返回 nil但返回错误: %v", err)
}
}
})
}
}