refactor: 优化字符串构建方式,统一测试错误处理风格

- 使用 fmt.Fprintf 替代冗余的 WriteString(fmt.Sprintf) 组合
- 测试中 nil 检查使用 t.Fatal 替代 t.Error 立即终止
- .gitignore 添加 html/ 目录忽略

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-03 17:49:11 +08:00
parent a1d84c1144
commit cd2d1a8194
10 changed files with 62 additions and 63 deletions

3
.gitignore vendored
View File

@ -60,4 +60,5 @@ CLAUDE.md
lolly.yaml
config.yaml
lolly
coverage.html
coverage.html
html/

View File

@ -253,28 +253,28 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
buf.WriteString("\n")
buf.WriteString(" # 速率限制\n")
buf.WriteString(" rate_limit:\n")
buf.WriteString(fmt.Sprintf(" request_rate: %d # 每秒请求数0 表示不限制)\n", cfg.Server.Security.RateLimit.RequestRate))
buf.WriteString(fmt.Sprintf(" burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst))
buf.WriteString(fmt.Sprintf(" conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit))
buf.WriteString(fmt.Sprintf(" key: \"%s\" # 限流 key 来源(有效值: ip, header\n", cfg.Server.Security.RateLimit.Key))
buf.WriteString(fmt.Sprintf(" algorithm: \"%s\" # 限流算法(有效值: token_bucket, sliding_window\n", cfg.Server.Security.RateLimit.Algorithm))
buf.WriteString(fmt.Sprintf(" sliding_window_mode: \"%s\" # 滑动窗口模式(有效值: approximate, precise仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindowMode))
buf.WriteString(fmt.Sprintf(" sliding_window: %d # 滑动窗口大小(秒,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindow))
fmt.Fprintf(&buf, " request_rate: %d # 每秒请求数0 表示不限制)\n", cfg.Server.Security.RateLimit.RequestRate)
fmt.Fprintf(&buf, " burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst)
fmt.Fprintf(&buf, " conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit)
fmt.Fprintf(&buf, " key: \"%s\" # 限流 key 来源(有效值: ip, header\n", cfg.Server.Security.RateLimit.Key)
fmt.Fprintf(&buf, " algorithm: \"%s\" # 限流算法(有效值: token_bucket, sliding_window\n", cfg.Server.Security.RateLimit.Algorithm)
fmt.Fprintf(&buf, " sliding_window_mode: \"%s\" # 滑动窗口模式(有效值: approximate, precise仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindowMode)
fmt.Fprintf(&buf, " sliding_window: %d # 滑动窗口大小(秒,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindow)
buf.WriteString("\n")
buf.WriteString(" # 认证配置type 为空时禁用)\n")
buf.WriteString(" auth:\n")
buf.WriteString(" type: \"\" # 认证类型(有效值: basic空表示禁用\n")
buf.WriteString(fmt.Sprintf(" require_tls: %v # 启用时强制 HTTPS\n", cfg.Server.Security.Auth.RequireTLS))
buf.WriteString(fmt.Sprintf(" algorithm: \"%s\" # 密码哈希算法(有效值: bcrypt, argon2id\n", cfg.Server.Security.Auth.Algorithm))
fmt.Fprintf(&buf, " require_tls: %v # 启用时强制 HTTPS\n", cfg.Server.Security.Auth.RequireTLS)
fmt.Fprintf(&buf, " algorithm: \"%s\" # 密码哈希算法(有效值: bcrypt, argon2id\n", cfg.Server.Security.Auth.Algorithm)
buf.WriteString(" users: [] # 用户列表\n")
buf.WriteString(fmt.Sprintf(" realm: \"%s\" # 认证域\n", cfg.Server.Security.Auth.Realm))
buf.WriteString(fmt.Sprintf(" min_password_length: %d # 密码最小长度\n", cfg.Server.Security.Auth.MinPasswordLength))
fmt.Fprintf(&buf, " realm: \"%s\" # 认证域\n", cfg.Server.Security.Auth.Realm)
fmt.Fprintf(&buf, " min_password_length: %d # 密码最小长度\n", cfg.Server.Security.Auth.MinPasswordLength)
buf.WriteString("\n")
buf.WriteString(" # 安全头部\n")
buf.WriteString(" headers:\n")
buf.WriteString(fmt.Sprintf(" x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用)\n", cfg.Server.Security.Headers.XFrameOptions))
buf.WriteString(fmt.Sprintf(" x_content_type_options: \"%s\" # 防止 MIME 嗅探\n", cfg.Server.Security.Headers.XContentTypeOptions))
buf.WriteString(fmt.Sprintf(" referrer_policy: \"%s\" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url\n", cfg.Server.Security.Headers.ReferrerPolicy))
fmt.Fprintf(&buf, " x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用)\n", cfg.Server.Security.Headers.XFrameOptions)
fmt.Fprintf(&buf, " x_content_type_options: \"%s\" # 防止 MIME 嗅探\n", cfg.Server.Security.Headers.XContentTypeOptions)
fmt.Fprintf(&buf, " referrer_policy: \"%s\" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url\n", cfg.Server.Security.Headers.ReferrerPolicy)
buf.WriteString(" # content_security_policy: \"default-src 'self'\" # 内容安全策略 CSP\n")
buf.WriteString(" # permissions_policy: \"geolocation=(), microphone=()\" # 权限策略\n")
buf.WriteString("\n")
@ -290,17 +290,17 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
// compression 配置
buf.WriteString(" # 响应压缩配置\n")
buf.WriteString(" compression:\n")
buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both空表示禁用\n", cfg.Server.Compression.Type))
buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9值越大压缩率越高但速度越慢\n", cfg.Server.Compression.Level))
buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize))
buf.WriteString(fmt.Sprintf(" gzip_static: %v # 启用预压缩文件支持(自动查找 .gz/.br 文件)\n", cfg.Server.Compression.GzipStatic))
fmt.Fprintf(&buf, " type: \"%s\" # 压缩类型(有效值: gzip, brotli, both空表示禁用\n", cfg.Server.Compression.Type)
fmt.Fprintf(&buf, " level: %d # 压缩级别(范围 1-9值越大压缩率越高但速度越慢\n", cfg.Server.Compression.Level)
fmt.Fprintf(&buf, " min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize)
fmt.Fprintf(&buf, " gzip_static: %v # 启用预压缩文件支持(自动查找 .gz/.br 文件)\n", cfg.Server.Compression.GzipStatic)
buf.WriteString(" gzip_static_extensions: # 预压缩文件扩展名\n")
for _, ext := range cfg.Server.Compression.GzipStaticExtensions {
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", ext))
fmt.Fprintf(&buf, " - \"%s\"\n", ext)
}
buf.WriteString(" types: # 可压缩的 MIME 类型\n")
for _, t := range cfg.Server.Compression.Types {
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t))
fmt.Fprintf(&buf, " - \"%s\"\n", t)
}
buf.WriteString("\n")
@ -374,55 +374,55 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
// logging 配置
buf.WriteString("# 日志配置\n")
buf.WriteString("logging:\n")
buf.WriteString(fmt.Sprintf(" format: \"%s\" # 全局日志格式(有效值: text, json控制启动/停止日志格式\n", cfg.Logging.Format))
fmt.Fprintf(&buf, " format: \"%s\" # 全局日志格式(有效值: text, json控制启动/停止日志格式\n", cfg.Logging.Format)
buf.WriteString(" access:\n")
buf.WriteString(" path: \"\" # 日志文件路径(空表示输出到 stdout\n")
buf.WriteString(fmt.Sprintf(" format: '%s' # 访问日志格式,近似 nginx combined\n", cfg.Logging.Access.Format))
fmt.Fprintf(&buf, " format: '%s' # 访问日志格式,近似 nginx combined\n", cfg.Logging.Access.Format)
buf.WriteString(" # 支持变量: $remote_addr, $remote_user, $request, $status, $body_bytes_sent, $request_time, $http_referer, $http_user_agent, $time\n")
buf.WriteString(" # 特殊值 \"json\" 输出结构化 JSON\n")
buf.WriteString(" error:\n")
buf.WriteString(" path: \"\" # 日志文件路径(空表示输出到 stderr\n")
buf.WriteString(fmt.Sprintf(" level: \"%s\" # 日志级别(有效值: debug, info, warn, error级别越高日志越少\n", cfg.Logging.Error.Level))
fmt.Fprintf(&buf, " level: \"%s\" # 日志级别(有效值: debug, info, warn, error级别越高日志越少\n", cfg.Logging.Error.Level)
buf.WriteString("\n")
// performance 配置
buf.WriteString("# 性能配置\n")
buf.WriteString("performance:\n")
buf.WriteString(" goroutine_pool: # Goroutine 池(处理并发请求)\n")
buf.WriteString(fmt.Sprintf(" enabled: %v # 是否启用\n", cfg.Performance.GoroutinePool.Enabled))
buf.WriteString(fmt.Sprintf(" max_workers: %d # 最大 worker 数\n", cfg.Performance.GoroutinePool.MaxWorkers))
buf.WriteString(fmt.Sprintf(" min_workers: %d # 最小 worker 数(预热)\n", cfg.Performance.GoroutinePool.MinWorkers))
buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时\n", int(cfg.Performance.GoroutinePool.IdleTimeout.Seconds())))
fmt.Fprintf(&buf, " enabled: %v # 是否启用\n", cfg.Performance.GoroutinePool.Enabled)
fmt.Fprintf(&buf, " max_workers: %d # 最大 worker 数\n", cfg.Performance.GoroutinePool.MaxWorkers)
fmt.Fprintf(&buf, " min_workers: %d # 最小 worker 数(预热)\n", cfg.Performance.GoroutinePool.MinWorkers)
fmt.Fprintf(&buf, " idle_timeout: %ds # 空闲超时\n", int(cfg.Performance.GoroutinePool.IdleTimeout.Seconds()))
buf.WriteString(" file_cache: # 静态文件缓存\n")
buf.WriteString(fmt.Sprintf(" max_entries: %d # 最大缓存条目\n", cfg.Performance.FileCache.MaxEntries))
buf.WriteString(fmt.Sprintf(" max_size: %d # 内存上限(字节,%dMB\n", cfg.Performance.FileCache.MaxSize, cfg.Performance.FileCache.MaxSize/1024/1024))
buf.WriteString(fmt.Sprintf(" inactive: %ds # 未访问淘汰时间\n", int(cfg.Performance.FileCache.Inactive.Seconds())))
buf.WriteString(fmt.Sprintf(" lru_eviction: %v # 启用 LRU 淘汰\n", cfg.Performance.FileCache.LRUEviction))
fmt.Fprintf(&buf, " max_entries: %d # 最大缓存条目\n", cfg.Performance.FileCache.MaxEntries)
fmt.Fprintf(&buf, " max_size: %d # 内存上限(字节,%dMB\n", cfg.Performance.FileCache.MaxSize, cfg.Performance.FileCache.MaxSize/1024/1024)
fmt.Fprintf(&buf, " inactive: %ds # 未访问淘汰时间\n", int(cfg.Performance.FileCache.Inactive.Seconds()))
fmt.Fprintf(&buf, " lru_eviction: %v # 启用 LRU 淘汰\n", cfg.Performance.FileCache.LRUEviction)
buf.WriteString(" transport: # HTTP Transport 连接池\n")
buf.WriteString(fmt.Sprintf(" max_idle_conns: %d # 最大空闲连接\n", cfg.Performance.Transport.MaxIdleConns))
buf.WriteString(fmt.Sprintf(" max_idle_conns_per_host: %d # 每主机空闲连接\n", cfg.Performance.Transport.MaxIdleConnsPerHost))
buf.WriteString(fmt.Sprintf(" idle_conn_timeout: %ds # 空闲超时\n", int(cfg.Performance.Transport.IdleConnTimeout.Seconds())))
buf.WriteString(fmt.Sprintf(" max_conns_per_host: %d # 每主机最大连接0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost))
fmt.Fprintf(&buf, " max_idle_conns: %d # 最大空闲连接\n", cfg.Performance.Transport.MaxIdleConns)
fmt.Fprintf(&buf, " max_idle_conns_per_host: %d # 每主机空闲连接\n", cfg.Performance.Transport.MaxIdleConnsPerHost)
fmt.Fprintf(&buf, " idle_conn_timeout: %ds # 空闲超时\n", int(cfg.Performance.Transport.IdleConnTimeout.Seconds()))
fmt.Fprintf(&buf, " max_conns_per_host: %d # 每主机最大连接0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost)
buf.WriteString("\n")
// HTTP3 配置
buf.WriteString("# HTTP/3 (QUIC) 配置(需要 SSL 证书)\n")
buf.WriteString("http3:\n")
buf.WriteString(fmt.Sprintf(" enabled: %v # 是否启用 HTTP/3\n", cfg.HTTP3.Enabled))
buf.WriteString(fmt.Sprintf(" listen: \"%s\" # UDP 监听地址\n", cfg.HTTP3.Listen))
buf.WriteString(fmt.Sprintf(" max_streams: %d # 最大并发流\n", cfg.HTTP3.MaxStreams))
buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时\n", int(cfg.HTTP3.IdleTimeout.Seconds())))
buf.WriteString(fmt.Sprintf(" enable_0rtt: %v # 启用 0-RTT早期数据可能存在安全风险\n", cfg.HTTP3.Enable0RTT))
fmt.Fprintf(&buf, " enabled: %v # 是否启用 HTTP/3\n", cfg.HTTP3.Enabled)
fmt.Fprintf(&buf, " listen: \"%s\" # UDP 监听地址\n", cfg.HTTP3.Listen)
fmt.Fprintf(&buf, " max_streams: %d # 最大并发流\n", cfg.HTTP3.MaxStreams)
fmt.Fprintf(&buf, " idle_timeout: %ds # 空闲超时\n", int(cfg.HTTP3.IdleTimeout.Seconds()))
fmt.Fprintf(&buf, " enable_0rtt: %v # 启用 0-RTT早期数据可能存在安全风险\n", cfg.HTTP3.Enable0RTT)
buf.WriteString("\n")
// monitoring 配置
buf.WriteString("# 监控配置\n")
buf.WriteString("monitoring:\n")
buf.WriteString(" status:\n")
buf.WriteString(fmt.Sprintf(" path: \"%s\" # 状态端点路径\n", cfg.Monitoring.Status.Path))
fmt.Fprintf(&buf, " path: \"%s\" # 状态端点路径\n", cfg.Monitoring.Status.Path)
buf.WriteString(" allow: # 允许访问的 IP\n")
for _, ip := range cfg.Monitoring.Status.Allow {
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", ip))
fmt.Fprintf(&buf, " - \"%s\"\n", ip)
}
return buf.Bytes(), nil

View File

@ -14,7 +14,7 @@ import (
func TestNewAdapter(t *testing.T) {
adapter := NewAdapter()
if adapter == nil {
t.Error("Expected non-nil adapter")
t.Fatal("Expected non-nil adapter")
}
// 测试 ctxPool 是否初始化

View File

@ -88,7 +88,7 @@ func TestNewServer_Success(t *testing.T) {
t.Errorf("Unexpected error: %v", err)
}
if server == nil {
t.Error("Expected non-nil server")
t.Fatal("Expected non-nil server")
}
if server.config != cfg {

View File

@ -288,8 +288,8 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
// 构建请求头
var req strings.Builder
req.WriteString(fmt.Sprintf("GET %s HTTP/1.1\r\n", path))
req.WriteString(fmt.Sprintf("Host: %s\r\n", targetHost))
fmt.Fprintf(&req, "GET %s HTTP/1.1\r\n", path)
fmt.Fprintf(&req, "Host: %s\r\n", targetHost)
// 复制原始请求的关键头
copyHeaders := []string{
@ -304,27 +304,27 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
for _, header := range copyHeaders {
if value := ctx.Request.Header.Peek(header); len(value) > 0 {
req.WriteString(fmt.Sprintf("%s: %s\r\n", header, string(value)))
fmt.Fprintf(&req, "%s: %s\r\n", header, string(value))
}
}
// 添加 X-Forwarded 头
clientIP := getClientIP(ctx)
if clientIP != "" {
req.WriteString(fmt.Sprintf("X-Forwarded-For: %s\r\n", clientIP))
req.WriteString(fmt.Sprintf("X-Real-IP: %s\r\n", clientIP))
fmt.Fprintf(&req, "X-Forwarded-For: %s\r\n", clientIP)
fmt.Fprintf(&req, "X-Real-IP: %s\r\n", clientIP)
}
host := string(ctx.Host())
if host != "" {
req.WriteString(fmt.Sprintf("X-Forwarded-Host: %s\r\n", host))
fmt.Fprintf(&req, "X-Forwarded-Host: %s\r\n", host)
}
proto := "http"
if ctx.IsTLS() {
proto = "https"
}
req.WriteString(fmt.Sprintf("X-Forwarded-Proto: %s\r\n", proto))
fmt.Fprintf(&req, "X-Forwarded-Proto: %s\r\n", proto)
// 结束请求头
req.WriteString("\r\n")
@ -486,12 +486,12 @@ func extractHost(url string) string {
func writeUpgradeResponse(conn net.Conn, resp *http.Response) error {
// 构建响应行
var respStr strings.Builder
respStr.WriteString(fmt.Sprintf("HTTP/%d.%d %s\r\n", resp.ProtoMajor, resp.ProtoMinor, resp.Status))
fmt.Fprintf(&respStr, "HTTP/%d.%d %s\r\n", resp.ProtoMajor, resp.ProtoMinor, resp.Status)
// 写入响应头
for key, values := range resp.Header {
for _, value := range values {
respStr.WriteString(fmt.Sprintf("%s: %s\r\n", key, value))
fmt.Fprintf(&respStr, "%s: %s\r\n", key, value)
}
}

View File

@ -17,9 +17,8 @@ func TestNewWebSocketBridge(t *testing.T) {
defer func() { _ = targetConn.Close() }()
bridge := NewWebSocketBridge(clientConn, targetConn)
if bridge == nil {
t.Error("Expected non-nil bridge")
t.Fatal("Expected non-nil bridge")
}
if bridge.clientConn != clientConn {
t.Error("Expected clientConn to be set")
@ -27,7 +26,7 @@ func TestNewWebSocketBridge(t *testing.T) {
if bridge.targetConn != targetConn {
t.Error("Expected targetConn to be set")
}
if bridge.closed != false {
if bridge.closed {
t.Error("Expected closed to be false initially")
}
}

View File

@ -19,7 +19,7 @@ func TestNewGoroutinePool(t *testing.T) {
p := NewGoroutinePool(cfg)
if p == nil {
t.Error("Expected non-nil pool")
t.Fatal("Expected non-nil pool")
}
// 检查配置

View File

@ -54,7 +54,7 @@ func TestNewStatusHandler_CIDR(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
if h == nil {
t.Error("expected non-nil handler")
t.Fatal("expected non-nil handler")
}
}
})
@ -101,7 +101,7 @@ func TestNewStatusHandler_SingleIP(t *testing.T) {
t.Errorf("unexpected error: %v", err)
}
if h == nil {
t.Error("expected non-nil handler")
t.Fatal("expected non-nil handler")
}
if len(h.allowed) != len(tt.allow) {
t.Errorf("expected %d allowed networks, got %d", len(tt.allow), len(h.allowed))

View File

@ -10,9 +10,8 @@ import (
func TestNewUpgradeManager(t *testing.T) {
srv := New(nil)
mgr := NewUpgradeManager(srv)
if mgr == nil {
t.Error("Expected non-nil manager")
t.Fatal("Expected non-nil manager")
}
if mgr.server != srv {
t.Error("Expected server to be set")

View File

@ -749,7 +749,7 @@ func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
if err != nil {
return
}
conn.Close()
_ = conn.Close()
}
}()