refactor: 优化字符串构建方式,统一测试错误处理风格
- 使用 fmt.Fprintf 替代冗余的 WriteString(fmt.Sprintf) 组合 - 测试中 nil 检查使用 t.Fatal 替代 t.Error 立即终止 - .gitignore 添加 html/ 目录忽略 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a1d84c1144
commit
cd2d1a8194
3
.gitignore
vendored
3
.gitignore
vendored
@ -60,4 +60,5 @@ CLAUDE.md
|
|||||||
lolly.yaml
|
lolly.yaml
|
||||||
config.yaml
|
config.yaml
|
||||||
lolly
|
lolly
|
||||||
coverage.html
|
coverage.html
|
||||||
|
html/
|
||||||
@ -253,28 +253,28 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
buf.WriteString(" # 速率限制\n")
|
buf.WriteString(" # 速率限制\n")
|
||||||
buf.WriteString(" rate_limit:\n")
|
buf.WriteString(" rate_limit:\n")
|
||||||
buf.WriteString(fmt.Sprintf(" request_rate: %d # 每秒请求数(0 表示不限制)\n", cfg.Server.Security.RateLimit.RequestRate))
|
fmt.Fprintf(&buf, " request_rate: %d # 每秒请求数(0 表示不限制)\n", cfg.Server.Security.RateLimit.RequestRate)
|
||||||
buf.WriteString(fmt.Sprintf(" burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst))
|
fmt.Fprintf(&buf, " burst: %d # 突发上限\n", cfg.Server.Security.RateLimit.Burst)
|
||||||
buf.WriteString(fmt.Sprintf(" conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit))
|
fmt.Fprintf(&buf, " conn_limit: %d # 连接数限制\n", cfg.Server.Security.RateLimit.ConnLimit)
|
||||||
buf.WriteString(fmt.Sprintf(" key: \"%s\" # 限流 key 来源(有效值: ip, header)\n", cfg.Server.Security.RateLimit.Key))
|
fmt.Fprintf(&buf, " key: \"%s\" # 限流 key 来源(有效值: ip, header)\n", cfg.Server.Security.RateLimit.Key)
|
||||||
buf.WriteString(fmt.Sprintf(" algorithm: \"%s\" # 限流算法(有效值: token_bucket, sliding_window)\n", cfg.Server.Security.RateLimit.Algorithm))
|
fmt.Fprintf(&buf, " algorithm: \"%s\" # 限流算法(有效值: token_bucket, sliding_window)\n", cfg.Server.Security.RateLimit.Algorithm)
|
||||||
buf.WriteString(fmt.Sprintf(" sliding_window_mode: \"%s\" # 滑动窗口模式(有效值: approximate, precise,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindowMode))
|
fmt.Fprintf(&buf, " sliding_window_mode: \"%s\" # 滑动窗口模式(有效值: approximate, precise,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindowMode)
|
||||||
buf.WriteString(fmt.Sprintf(" sliding_window: %d # 滑动窗口大小(秒,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindow))
|
fmt.Fprintf(&buf, " sliding_window: %d # 滑动窗口大小(秒,仅 algorithm=sliding_window 时有效)\n", cfg.Server.Security.RateLimit.SlidingWindow)
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
buf.WriteString(" # 认证配置(type 为空时禁用)\n")
|
buf.WriteString(" # 认证配置(type 为空时禁用)\n")
|
||||||
buf.WriteString(" auth:\n")
|
buf.WriteString(" auth:\n")
|
||||||
buf.WriteString(" type: \"\" # 认证类型(有效值: basic,空表示禁用)\n")
|
buf.WriteString(" type: \"\" # 认证类型(有效值: basic,空表示禁用)\n")
|
||||||
buf.WriteString(fmt.Sprintf(" require_tls: %v # 启用时强制 HTTPS\n", cfg.Server.Security.Auth.RequireTLS))
|
fmt.Fprintf(&buf, " require_tls: %v # 启用时强制 HTTPS\n", cfg.Server.Security.Auth.RequireTLS)
|
||||||
buf.WriteString(fmt.Sprintf(" algorithm: \"%s\" # 密码哈希算法(有效值: bcrypt, argon2id)\n", cfg.Server.Security.Auth.Algorithm))
|
fmt.Fprintf(&buf, " algorithm: \"%s\" # 密码哈希算法(有效值: bcrypt, argon2id)\n", cfg.Server.Security.Auth.Algorithm)
|
||||||
buf.WriteString(" users: [] # 用户列表\n")
|
buf.WriteString(" users: [] # 用户列表\n")
|
||||||
buf.WriteString(fmt.Sprintf(" realm: \"%s\" # 认证域\n", cfg.Server.Security.Auth.Realm))
|
fmt.Fprintf(&buf, " realm: \"%s\" # 认证域\n", cfg.Server.Security.Auth.Realm)
|
||||||
buf.WriteString(fmt.Sprintf(" min_password_length: %d # 密码最小长度\n", cfg.Server.Security.Auth.MinPasswordLength))
|
fmt.Fprintf(&buf, " min_password_length: %d # 密码最小长度\n", cfg.Server.Security.Auth.MinPasswordLength)
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
buf.WriteString(" # 安全头部\n")
|
buf.WriteString(" # 安全头部\n")
|
||||||
buf.WriteString(" headers:\n")
|
buf.WriteString(" headers:\n")
|
||||||
buf.WriteString(fmt.Sprintf(" x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用)\n", cfg.Server.Security.Headers.XFrameOptions))
|
fmt.Fprintf(&buf, " x_frame_options: \"%s\" # 防止点击劫持(有效值: DENY, SAMEORIGIN, 空表示禁用)\n", cfg.Server.Security.Headers.XFrameOptions)
|
||||||
buf.WriteString(fmt.Sprintf(" x_content_type_options: \"%s\" # 防止 MIME 嗅探\n", cfg.Server.Security.Headers.XContentTypeOptions))
|
fmt.Fprintf(&buf, " x_content_type_options: \"%s\" # 防止 MIME 嗅探\n", cfg.Server.Security.Headers.XContentTypeOptions)
|
||||||
buf.WriteString(fmt.Sprintf(" referrer_policy: \"%s\" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url)\n", cfg.Server.Security.Headers.ReferrerPolicy))
|
fmt.Fprintf(&buf, " referrer_policy: \"%s\" # 引用策略(有效值: no-referrer, no-referrer-when-downgrade, origin, origin-when-cross-origin, same-origin, strict-origin, strict-origin-when-cross-origin, unsafe-url)\n", cfg.Server.Security.Headers.ReferrerPolicy)
|
||||||
buf.WriteString(" # content_security_policy: \"default-src 'self'\" # 内容安全策略 CSP\n")
|
buf.WriteString(" # content_security_policy: \"default-src 'self'\" # 内容安全策略 CSP\n")
|
||||||
buf.WriteString(" # permissions_policy: \"geolocation=(), microphone=()\" # 权限策略\n")
|
buf.WriteString(" # permissions_policy: \"geolocation=(), microphone=()\" # 权限策略\n")
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
@ -290,17 +290,17 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
// compression 配置
|
// compression 配置
|
||||||
buf.WriteString(" # 响应压缩配置\n")
|
buf.WriteString(" # 响应压缩配置\n")
|
||||||
buf.WriteString(" compression:\n")
|
buf.WriteString(" compression:\n")
|
||||||
buf.WriteString(fmt.Sprintf(" type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type))
|
fmt.Fprintf(&buf, " type: \"%s\" # 压缩类型(有效值: gzip, brotli, both,空表示禁用)\n", cfg.Server.Compression.Type)
|
||||||
buf.WriteString(fmt.Sprintf(" level: %d # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢)\n", cfg.Server.Compression.Level))
|
fmt.Fprintf(&buf, " level: %d # 压缩级别(范围 1-9,值越大压缩率越高但速度越慢)\n", cfg.Server.Compression.Level)
|
||||||
buf.WriteString(fmt.Sprintf(" min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize))
|
fmt.Fprintf(&buf, " min_size: %d # 最小压缩大小(字节,小于此值不压缩)\n", cfg.Server.Compression.MinSize)
|
||||||
buf.WriteString(fmt.Sprintf(" gzip_static: %v # 启用预压缩文件支持(自动查找 .gz/.br 文件)\n", cfg.Server.Compression.GzipStatic))
|
fmt.Fprintf(&buf, " gzip_static: %v # 启用预压缩文件支持(自动查找 .gz/.br 文件)\n", cfg.Server.Compression.GzipStatic)
|
||||||
buf.WriteString(" gzip_static_extensions: # 预压缩文件扩展名\n")
|
buf.WriteString(" gzip_static_extensions: # 预压缩文件扩展名\n")
|
||||||
for _, ext := range cfg.Server.Compression.GzipStaticExtensions {
|
for _, ext := range cfg.Server.Compression.GzipStaticExtensions {
|
||||||
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", ext))
|
fmt.Fprintf(&buf, " - \"%s\"\n", ext)
|
||||||
}
|
}
|
||||||
buf.WriteString(" types: # 可压缩的 MIME 类型\n")
|
buf.WriteString(" types: # 可压缩的 MIME 类型\n")
|
||||||
for _, t := range cfg.Server.Compression.Types {
|
for _, t := range cfg.Server.Compression.Types {
|
||||||
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", t))
|
fmt.Fprintf(&buf, " - \"%s\"\n", t)
|
||||||
}
|
}
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
|
|
||||||
@ -374,55 +374,55 @@ func GenerateConfigYAML(cfg *Config) ([]byte, error) {
|
|||||||
// logging 配置
|
// logging 配置
|
||||||
buf.WriteString("# 日志配置\n")
|
buf.WriteString("# 日志配置\n")
|
||||||
buf.WriteString("logging:\n")
|
buf.WriteString("logging:\n")
|
||||||
buf.WriteString(fmt.Sprintf(" format: \"%s\" # 全局日志格式(有效值: text, json),控制启动/停止日志格式\n", cfg.Logging.Format))
|
fmt.Fprintf(&buf, " format: \"%s\" # 全局日志格式(有效值: text, json),控制启动/停止日志格式\n", cfg.Logging.Format)
|
||||||
buf.WriteString(" access:\n")
|
buf.WriteString(" access:\n")
|
||||||
buf.WriteString(" path: \"\" # 日志文件路径(空表示输出到 stdout)\n")
|
buf.WriteString(" path: \"\" # 日志文件路径(空表示输出到 stdout)\n")
|
||||||
buf.WriteString(fmt.Sprintf(" format: '%s' # 访问日志格式,近似 nginx combined\n", cfg.Logging.Access.Format))
|
fmt.Fprintf(&buf, " format: '%s' # 访问日志格式,近似 nginx combined\n", cfg.Logging.Access.Format)
|
||||||
buf.WriteString(" # 支持变量: $remote_addr, $remote_user, $request, $status, $body_bytes_sent, $request_time, $http_referer, $http_user_agent, $time\n")
|
buf.WriteString(" # 支持变量: $remote_addr, $remote_user, $request, $status, $body_bytes_sent, $request_time, $http_referer, $http_user_agent, $time\n")
|
||||||
buf.WriteString(" # 特殊值 \"json\" 输出结构化 JSON\n")
|
buf.WriteString(" # 特殊值 \"json\" 输出结构化 JSON\n")
|
||||||
buf.WriteString(" error:\n")
|
buf.WriteString(" error:\n")
|
||||||
buf.WriteString(" path: \"\" # 日志文件路径(空表示输出到 stderr)\n")
|
buf.WriteString(" path: \"\" # 日志文件路径(空表示输出到 stderr)\n")
|
||||||
buf.WriteString(fmt.Sprintf(" level: \"%s\" # 日志级别(有效值: debug, info, warn, error,级别越高日志越少)\n", cfg.Logging.Error.Level))
|
fmt.Fprintf(&buf, " level: \"%s\" # 日志级别(有效值: debug, info, warn, error,级别越高日志越少)\n", cfg.Logging.Error.Level)
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
|
|
||||||
// performance 配置
|
// performance 配置
|
||||||
buf.WriteString("# 性能配置\n")
|
buf.WriteString("# 性能配置\n")
|
||||||
buf.WriteString("performance:\n")
|
buf.WriteString("performance:\n")
|
||||||
buf.WriteString(" goroutine_pool: # Goroutine 池(处理并发请求)\n")
|
buf.WriteString(" goroutine_pool: # Goroutine 池(处理并发请求)\n")
|
||||||
buf.WriteString(fmt.Sprintf(" enabled: %v # 是否启用\n", cfg.Performance.GoroutinePool.Enabled))
|
fmt.Fprintf(&buf, " enabled: %v # 是否启用\n", cfg.Performance.GoroutinePool.Enabled)
|
||||||
buf.WriteString(fmt.Sprintf(" max_workers: %d # 最大 worker 数\n", cfg.Performance.GoroutinePool.MaxWorkers))
|
fmt.Fprintf(&buf, " max_workers: %d # 最大 worker 数\n", cfg.Performance.GoroutinePool.MaxWorkers)
|
||||||
buf.WriteString(fmt.Sprintf(" min_workers: %d # 最小 worker 数(预热)\n", cfg.Performance.GoroutinePool.MinWorkers))
|
fmt.Fprintf(&buf, " min_workers: %d # 最小 worker 数(预热)\n", cfg.Performance.GoroutinePool.MinWorkers)
|
||||||
buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时\n", int(cfg.Performance.GoroutinePool.IdleTimeout.Seconds())))
|
fmt.Fprintf(&buf, " idle_timeout: %ds # 空闲超时\n", int(cfg.Performance.GoroutinePool.IdleTimeout.Seconds()))
|
||||||
buf.WriteString(" file_cache: # 静态文件缓存\n")
|
buf.WriteString(" file_cache: # 静态文件缓存\n")
|
||||||
buf.WriteString(fmt.Sprintf(" max_entries: %d # 最大缓存条目\n", cfg.Performance.FileCache.MaxEntries))
|
fmt.Fprintf(&buf, " max_entries: %d # 最大缓存条目\n", cfg.Performance.FileCache.MaxEntries)
|
||||||
buf.WriteString(fmt.Sprintf(" max_size: %d # 内存上限(字节,%dMB)\n", cfg.Performance.FileCache.MaxSize, cfg.Performance.FileCache.MaxSize/1024/1024))
|
fmt.Fprintf(&buf, " max_size: %d # 内存上限(字节,%dMB)\n", cfg.Performance.FileCache.MaxSize, cfg.Performance.FileCache.MaxSize/1024/1024)
|
||||||
buf.WriteString(fmt.Sprintf(" inactive: %ds # 未访问淘汰时间\n", int(cfg.Performance.FileCache.Inactive.Seconds())))
|
fmt.Fprintf(&buf, " inactive: %ds # 未访问淘汰时间\n", int(cfg.Performance.FileCache.Inactive.Seconds()))
|
||||||
buf.WriteString(fmt.Sprintf(" lru_eviction: %v # 启用 LRU 淘汰\n", cfg.Performance.FileCache.LRUEviction))
|
fmt.Fprintf(&buf, " lru_eviction: %v # 启用 LRU 淘汰\n", cfg.Performance.FileCache.LRUEviction)
|
||||||
buf.WriteString(" transport: # HTTP Transport 连接池\n")
|
buf.WriteString(" transport: # HTTP Transport 连接池\n")
|
||||||
buf.WriteString(fmt.Sprintf(" max_idle_conns: %d # 最大空闲连接\n", cfg.Performance.Transport.MaxIdleConns))
|
fmt.Fprintf(&buf, " max_idle_conns: %d # 最大空闲连接\n", cfg.Performance.Transport.MaxIdleConns)
|
||||||
buf.WriteString(fmt.Sprintf(" max_idle_conns_per_host: %d # 每主机空闲连接\n", cfg.Performance.Transport.MaxIdleConnsPerHost))
|
fmt.Fprintf(&buf, " max_idle_conns_per_host: %d # 每主机空闲连接\n", cfg.Performance.Transport.MaxIdleConnsPerHost)
|
||||||
buf.WriteString(fmt.Sprintf(" idle_conn_timeout: %ds # 空闲超时\n", int(cfg.Performance.Transport.IdleConnTimeout.Seconds())))
|
fmt.Fprintf(&buf, " idle_conn_timeout: %ds # 空闲超时\n", int(cfg.Performance.Transport.IdleConnTimeout.Seconds()))
|
||||||
buf.WriteString(fmt.Sprintf(" max_conns_per_host: %d # 每主机最大连接(0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost))
|
fmt.Fprintf(&buf, " max_conns_per_host: %d # 每主机最大连接(0 表示不限制)\n", cfg.Performance.Transport.MaxConnsPerHost)
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
|
|
||||||
// HTTP3 配置
|
// HTTP3 配置
|
||||||
buf.WriteString("# HTTP/3 (QUIC) 配置(需要 SSL 证书)\n")
|
buf.WriteString("# HTTP/3 (QUIC) 配置(需要 SSL 证书)\n")
|
||||||
buf.WriteString("http3:\n")
|
buf.WriteString("http3:\n")
|
||||||
buf.WriteString(fmt.Sprintf(" enabled: %v # 是否启用 HTTP/3\n", cfg.HTTP3.Enabled))
|
fmt.Fprintf(&buf, " enabled: %v # 是否启用 HTTP/3\n", cfg.HTTP3.Enabled)
|
||||||
buf.WriteString(fmt.Sprintf(" listen: \"%s\" # UDP 监听地址\n", cfg.HTTP3.Listen))
|
fmt.Fprintf(&buf, " listen: \"%s\" # UDP 监听地址\n", cfg.HTTP3.Listen)
|
||||||
buf.WriteString(fmt.Sprintf(" max_streams: %d # 最大并发流\n", cfg.HTTP3.MaxStreams))
|
fmt.Fprintf(&buf, " max_streams: %d # 最大并发流\n", cfg.HTTP3.MaxStreams)
|
||||||
buf.WriteString(fmt.Sprintf(" idle_timeout: %ds # 空闲超时\n", int(cfg.HTTP3.IdleTimeout.Seconds())))
|
fmt.Fprintf(&buf, " idle_timeout: %ds # 空闲超时\n", int(cfg.HTTP3.IdleTimeout.Seconds()))
|
||||||
buf.WriteString(fmt.Sprintf(" enable_0rtt: %v # 启用 0-RTT(早期数据,可能存在安全风险)\n", cfg.HTTP3.Enable0RTT))
|
fmt.Fprintf(&buf, " enable_0rtt: %v # 启用 0-RTT(早期数据,可能存在安全风险)\n", cfg.HTTP3.Enable0RTT)
|
||||||
buf.WriteString("\n")
|
buf.WriteString("\n")
|
||||||
|
|
||||||
// monitoring 配置
|
// monitoring 配置
|
||||||
buf.WriteString("# 监控配置\n")
|
buf.WriteString("# 监控配置\n")
|
||||||
buf.WriteString("monitoring:\n")
|
buf.WriteString("monitoring:\n")
|
||||||
buf.WriteString(" status:\n")
|
buf.WriteString(" status:\n")
|
||||||
buf.WriteString(fmt.Sprintf(" path: \"%s\" # 状态端点路径\n", cfg.Monitoring.Status.Path))
|
fmt.Fprintf(&buf, " path: \"%s\" # 状态端点路径\n", cfg.Monitoring.Status.Path)
|
||||||
buf.WriteString(" allow: # 允许访问的 IP\n")
|
buf.WriteString(" allow: # 允许访问的 IP\n")
|
||||||
for _, ip := range cfg.Monitoring.Status.Allow {
|
for _, ip := range cfg.Monitoring.Status.Allow {
|
||||||
buf.WriteString(fmt.Sprintf(" - \"%s\"\n", ip))
|
fmt.Fprintf(&buf, " - \"%s\"\n", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
|
|||||||
@ -14,7 +14,7 @@ import (
|
|||||||
func TestNewAdapter(t *testing.T) {
|
func TestNewAdapter(t *testing.T) {
|
||||||
adapter := NewAdapter()
|
adapter := NewAdapter()
|
||||||
if adapter == nil {
|
if adapter == nil {
|
||||||
t.Error("Expected non-nil adapter")
|
t.Fatal("Expected non-nil adapter")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 测试 ctxPool 是否初始化
|
// 测试 ctxPool 是否初始化
|
||||||
|
|||||||
@ -88,7 +88,7 @@ func TestNewServer_Success(t *testing.T) {
|
|||||||
t.Errorf("Unexpected error: %v", err)
|
t.Errorf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
if server == nil {
|
if server == nil {
|
||||||
t.Error("Expected non-nil server")
|
t.Fatal("Expected non-nil server")
|
||||||
}
|
}
|
||||||
|
|
||||||
if server.config != cfg {
|
if server.config != cfg {
|
||||||
|
|||||||
@ -288,8 +288,8 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
|
|||||||
|
|
||||||
// 构建请求头
|
// 构建请求头
|
||||||
var req strings.Builder
|
var req strings.Builder
|
||||||
req.WriteString(fmt.Sprintf("GET %s HTTP/1.1\r\n", path))
|
fmt.Fprintf(&req, "GET %s HTTP/1.1\r\n", path)
|
||||||
req.WriteString(fmt.Sprintf("Host: %s\r\n", targetHost))
|
fmt.Fprintf(&req, "Host: %s\r\n", targetHost)
|
||||||
|
|
||||||
// 复制原始请求的关键头
|
// 复制原始请求的关键头
|
||||||
copyHeaders := []string{
|
copyHeaders := []string{
|
||||||
@ -304,27 +304,27 @@ func buildWebSocketUpgradeRequest(ctx *fasthttp.RequestCtx, targetHost string) s
|
|||||||
|
|
||||||
for _, header := range copyHeaders {
|
for _, header := range copyHeaders {
|
||||||
if value := ctx.Request.Header.Peek(header); len(value) > 0 {
|
if value := ctx.Request.Header.Peek(header); len(value) > 0 {
|
||||||
req.WriteString(fmt.Sprintf("%s: %s\r\n", header, string(value)))
|
fmt.Fprintf(&req, "%s: %s\r\n", header, string(value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加 X-Forwarded 头
|
// 添加 X-Forwarded 头
|
||||||
clientIP := getClientIP(ctx)
|
clientIP := getClientIP(ctx)
|
||||||
if clientIP != "" {
|
if clientIP != "" {
|
||||||
req.WriteString(fmt.Sprintf("X-Forwarded-For: %s\r\n", clientIP))
|
fmt.Fprintf(&req, "X-Forwarded-For: %s\r\n", clientIP)
|
||||||
req.WriteString(fmt.Sprintf("X-Real-IP: %s\r\n", clientIP))
|
fmt.Fprintf(&req, "X-Real-IP: %s\r\n", clientIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
host := string(ctx.Host())
|
host := string(ctx.Host())
|
||||||
if host != "" {
|
if host != "" {
|
||||||
req.WriteString(fmt.Sprintf("X-Forwarded-Host: %s\r\n", host))
|
fmt.Fprintf(&req, "X-Forwarded-Host: %s\r\n", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := "http"
|
proto := "http"
|
||||||
if ctx.IsTLS() {
|
if ctx.IsTLS() {
|
||||||
proto = "https"
|
proto = "https"
|
||||||
}
|
}
|
||||||
req.WriteString(fmt.Sprintf("X-Forwarded-Proto: %s\r\n", proto))
|
fmt.Fprintf(&req, "X-Forwarded-Proto: %s\r\n", proto)
|
||||||
|
|
||||||
// 结束请求头
|
// 结束请求头
|
||||||
req.WriteString("\r\n")
|
req.WriteString("\r\n")
|
||||||
@ -486,12 +486,12 @@ func extractHost(url string) string {
|
|||||||
func writeUpgradeResponse(conn net.Conn, resp *http.Response) error {
|
func writeUpgradeResponse(conn net.Conn, resp *http.Response) error {
|
||||||
// 构建响应行
|
// 构建响应行
|
||||||
var respStr strings.Builder
|
var respStr strings.Builder
|
||||||
respStr.WriteString(fmt.Sprintf("HTTP/%d.%d %s\r\n", resp.ProtoMajor, resp.ProtoMinor, resp.Status))
|
fmt.Fprintf(&respStr, "HTTP/%d.%d %s\r\n", resp.ProtoMajor, resp.ProtoMinor, resp.Status)
|
||||||
|
|
||||||
// 写入响应头
|
// 写入响应头
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
respStr.WriteString(fmt.Sprintf("%s: %s\r\n", key, value))
|
fmt.Fprintf(&respStr, "%s: %s\r\n", key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -17,9 +17,8 @@ func TestNewWebSocketBridge(t *testing.T) {
|
|||||||
defer func() { _ = targetConn.Close() }()
|
defer func() { _ = targetConn.Close() }()
|
||||||
|
|
||||||
bridge := NewWebSocketBridge(clientConn, targetConn)
|
bridge := NewWebSocketBridge(clientConn, targetConn)
|
||||||
|
|
||||||
if bridge == nil {
|
if bridge == nil {
|
||||||
t.Error("Expected non-nil bridge")
|
t.Fatal("Expected non-nil bridge")
|
||||||
}
|
}
|
||||||
if bridge.clientConn != clientConn {
|
if bridge.clientConn != clientConn {
|
||||||
t.Error("Expected clientConn to be set")
|
t.Error("Expected clientConn to be set")
|
||||||
@ -27,7 +26,7 @@ func TestNewWebSocketBridge(t *testing.T) {
|
|||||||
if bridge.targetConn != targetConn {
|
if bridge.targetConn != targetConn {
|
||||||
t.Error("Expected targetConn to be set")
|
t.Error("Expected targetConn to be set")
|
||||||
}
|
}
|
||||||
if bridge.closed != false {
|
if bridge.closed {
|
||||||
t.Error("Expected closed to be false initially")
|
t.Error("Expected closed to be false initially")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,7 +19,7 @@ func TestNewGoroutinePool(t *testing.T) {
|
|||||||
|
|
||||||
p := NewGoroutinePool(cfg)
|
p := NewGoroutinePool(cfg)
|
||||||
if p == nil {
|
if p == nil {
|
||||||
t.Error("Expected non-nil pool")
|
t.Fatal("Expected non-nil pool")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查配置
|
// 检查配置
|
||||||
|
|||||||
@ -54,7 +54,7 @@ func TestNewStatusHandler_CIDR(t *testing.T) {
|
|||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
if h == nil {
|
if h == nil {
|
||||||
t.Error("expected non-nil handler")
|
t.Fatal("expected non-nil handler")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -101,7 +101,7 @@ func TestNewStatusHandler_SingleIP(t *testing.T) {
|
|||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
if h == nil {
|
if h == nil {
|
||||||
t.Error("expected non-nil handler")
|
t.Fatal("expected non-nil handler")
|
||||||
}
|
}
|
||||||
if len(h.allowed) != len(tt.allow) {
|
if len(h.allowed) != len(tt.allow) {
|
||||||
t.Errorf("expected %d allowed networks, got %d", len(tt.allow), len(h.allowed))
|
t.Errorf("expected %d allowed networks, got %d", len(tt.allow), len(h.allowed))
|
||||||
|
|||||||
@ -10,9 +10,8 @@ import (
|
|||||||
func TestNewUpgradeManager(t *testing.T) {
|
func TestNewUpgradeManager(t *testing.T) {
|
||||||
srv := New(nil)
|
srv := New(nil)
|
||||||
mgr := NewUpgradeManager(srv)
|
mgr := NewUpgradeManager(srv)
|
||||||
|
|
||||||
if mgr == nil {
|
if mgr == nil {
|
||||||
t.Error("Expected non-nil manager")
|
t.Fatal("Expected non-nil manager")
|
||||||
}
|
}
|
||||||
if mgr.server != srv {
|
if mgr.server != srv {
|
||||||
t.Error("Expected server to be set")
|
t.Error("Expected server to be set")
|
||||||
|
|||||||
@ -749,7 +749,7 @@ func TestHealthCheckerCheckWithHealthyTarget(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user