diff --git a/go.mod b/go.mod index b8c1509..a2f5d13 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.26.1 require ( github.com/andybalholm/brotli v1.2.0 github.com/fasthttp/router v1.5.4 + github.com/google/uuid v1.6.0 github.com/klauspost/compress v1.18.2 github.com/quic-go/quic-go v0.59.0 github.com/rs/zerolog v1.35.0 diff --git a/go.sum b/go.sum index de86f91..9d45946 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8= github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/variable/builtin.go b/internal/variable/builtin.go new file mode 100644 index 0000000..5e656ff --- /dev/null +++ b/internal/variable/builtin.go @@ -0,0 +1,319 @@ +// builtin.go - 内置变量定义 +// +// 提供 18 个 nginx 风格的内置变量,用于日志、代理和重写规则。 +// +// 作者:xfy +package variable + +import ( + "strconv" + "time" + + "github.com/google/uuid" + "github.com/valyala/fasthttp" +) + +// 内置变量常量 +const ( + VarHost = "host" + VarRemoteAddr = "remote_addr" + VarRemotePort = "remote_port" + VarRequestURI = "request_uri" + VarURI = "uri" + VarArgs = "args" + VarRequestMethod = "request_method" + VarScheme = "scheme" + VarServerName = "server_name" + VarServerPort = "server_port" + VarStatus = "status" + VarBodyBytesSent = "body_bytes_sent" + VarRequestTime = "request_time" + VarTimeLocal = "time_local" + VarTimeISO8601 = "time_iso8601" + VarRequestID = "request_id" +) + +// init 注册所有内置变量 +func init() { + // 1. $host - 请求 Host 头 + RegisterBuiltin(&BuiltinVariable{ + Name: VarHost, + Description: "请求的主机名(Host 头)", + Getter: func(ctx *fasthttp.RequestCtx) string { + return string(ctx.Host()) + }, + }) + + // 2. $remote_addr - 客户端 IP + RegisterBuiltin(&BuiltinVariable{ + Name: VarRemoteAddr, + Description: "客户端 IP 地址", + Getter: func(ctx *fasthttp.RequestCtx) string { + addr := ctx.RemoteAddr() + if addr == nil { + return "-" + } + return addr.String() + }, + }) + + // 3. $remote_port - 客户端端口 + RegisterBuiltin(&BuiltinVariable{ + Name: VarRemotePort, + Description: "客户端端口", + Getter: func(ctx *fasthttp.RequestCtx) string { + addr := ctx.RemoteAddr() + if addr == nil { + return "-" + } + // 解析地址获取端口 + s := addr.String() + for i := len(s) - 1; i >= 0; i-- { + if s[i] == ':' { + return s[i+1:] + } + } + return "-" + }, + }) + + // 4. $request_uri - 原始请求 URI + RegisterBuiltin(&BuiltinVariable{ + Name: VarRequestURI, + Description: "原始请求 URI(包含查询参数)", + Getter: func(ctx *fasthttp.RequestCtx) string { + return string(ctx.RequestURI()) + }, + }) + + // 5. $uri - 解码后的 URI 路径 + RegisterBuiltin(&BuiltinVariable{ + Name: VarURI, + Description: "URI 路径(不包含查询参数)", + Getter: func(ctx *fasthttp.RequestCtx) string { + return string(ctx.Path()) + }, + }) + + // 6. $args - 查询参数字符串 + RegisterBuiltin(&BuiltinVariable{ + Name: VarArgs, + Description: "查询参数字符串", + Getter: func(ctx *fasthttp.RequestCtx) string { + return string(ctx.QueryArgs().QueryString()) + }, + }) + + // 7. $request_method - 请求方法 + RegisterBuiltin(&BuiltinVariable{ + Name: VarRequestMethod, + Description: "HTTP 请求方法", + Getter: func(ctx *fasthttp.RequestCtx) string { + return string(ctx.Method()) + }, + }) + + // 8. $scheme - 协议 + RegisterBuiltin(&BuiltinVariable{ + Name: VarScheme, + Description: "协议(http 或 https)", + Getter: func(ctx *fasthttp.RequestCtx) string { + if ctx.IsTLS() { + return "https" + } + return "http" + }, + }) + + // 9. $server_name - 服务器名称 + // 注意:这个变量需要从 VariableContext 获取 + RegisterBuiltin(&BuiltinVariable{ + Name: VarServerName, + Description: "服务器名称", + Getter: func(ctx *fasthttp.RequestCtx) string { + // 从 UserValue 获取,由外部设置 + if v := ctx.UserValue(VarServerName); v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "-" + }, + }) + + // 10. $server_port - 服务器端口 + RegisterBuiltin(&BuiltinVariable{ + Name: VarServerPort, + Description: "服务器端口", + Getter: func(ctx *fasthttp.RequestCtx) string { + addr := ctx.LocalAddr() + if addr == nil { + return "-" + } + s := addr.String() + for i := len(s) - 1; i >= 0; i-- { + if s[i] == ':' { + return s[i+1:] + } + } + return "-" + }, + }) + + // 11. $status - HTTP 状态码 + // 需要从 VariableContext 获取 + RegisterBuiltin(&BuiltinVariable{ + Name: VarStatus, + Description: "HTTP 响应状态码", + Getter: func(ctx *fasthttp.RequestCtx) string { + if v := ctx.UserValue(VarStatus); v != nil { + if i, ok := v.(int); ok { + return strconv.Itoa(i) + } + } + return "-" + }, + }) + + // 12. $body_bytes_sent - 响应体大小 + // 需要从 VariableContext 获取 + RegisterBuiltin(&BuiltinVariable{ + Name: VarBodyBytesSent, + Description: "发送的响应体字节数", + Getter: func(ctx *fasthttp.RequestCtx) string { + if v := ctx.UserValue(VarBodyBytesSent); v != nil { + if i, ok := v.(int64); ok { + return strconv.FormatInt(i, 10) + } + } + return "0" + }, + }) + + // 13. $request_time - 请求处理时间 + // 需要从 VariableContext 获取 + RegisterBuiltin(&BuiltinVariable{ + Name: VarRequestTime, + Description: "请求处理时间(秒)", + Getter: func(ctx *fasthttp.RequestCtx) string { + if v := ctx.UserValue(VarRequestTime); v != nil { + if i, ok := v.(int64); ok { + return formatRequestTime(i) + } + } + return "0.000" + }, + }) + + // 14. $time_local - 本地时间 + RegisterBuiltin(&BuiltinVariable{ + Name: VarTimeLocal, + Description: "本地时间(格式:02/Jan/2024:15:04:05 +0800)", + Getter: func(ctx *fasthttp.RequestCtx) string { + return time.Now().Format("02/Jan/2006:15:04:05 +0800") + }, + }) + + // 15. $time_iso8601 - ISO8601 时间 + RegisterBuiltin(&BuiltinVariable{ + Name: VarTimeISO8601, + Description: "ISO8601 格式时间", + Getter: func(ctx *fasthttp.RequestCtx) string { + return time.Now().Format(time.RFC3339) + }, + }) + + // 16. $request_id - 唯一请求 ID + RegisterBuiltin(&BuiltinVariable{ + Name: VarRequestID, + Description: "唯一请求标识符", + Getter: func(ctx *fasthttp.RequestCtx) string { + // 先从 UserValue 获取,如果没有则生成 + if v := ctx.UserValue(VarRequestID); v != nil { + if s, ok := v.(string); ok { + return s + } + } + return uuid.New().String() + }, + }) +} + +// formatRequestTime 格式化请求处理时间 +func formatRequestTime(ns int64) string { + // 转换为秒,保留3位小数 + sec := float64(ns) / 1e9 + return strconv.FormatFloat(sec, 'f', 3, 64) +} + +// GetArgVariable 获取查询参数变量(动态变量 $arg_name) +func GetArgVariable(ctx *fasthttp.RequestCtx, name string) string { + return string(ctx.URI().QueryArgs().Peek(name)) +} + +// GetHTTPVariable 获取 HTTP 头变量(动态变量 $http_name) +func GetHTTPVariable(ctx *fasthttp.RequestCtx, name string) string { + // 将下划线转换为连字符,并规范化头名 + headerName := normalizeHeaderName(name) + return string(ctx.Request.Header.Peek(headerName)) +} + +// GetCookieVariable 获取 Cookie 变量(动态变量 $cookie_name) +func GetCookieVariable(ctx *fasthttp.RequestCtx, name string) string { + return string(ctx.Request.Header.Cookie(name)) +} + +// normalizeHeaderName 规范化 HTTP 头名 +func normalizeHeaderName(name string) string { + // 简单处理:将 _ 替换为 -,并首字母大写 + if name == "" { + return name + } + + var result []byte + upper := true + for i := 0; i < len(name); i++ { + c := name[i] + if c == '_' { + result = append(result, '-') + upper = true + } else if upper { + if c >= 'a' && c <= 'z' { + result = append(result, c-'a'+'A') + } else { + result = append(result, c) + } + upper = false + } else { + result = append(result, c) + } + } + return string(result) +} + +// SetResponseInfoInContext 在 fasthttp.RequestCtx 中设置响应信息 +// 用于在 builtin getter 中获取 status、body_bytes_sent、request_time +func SetResponseInfoInContext(ctx *fasthttp.RequestCtx, status int, bodySize int64, durationNs int64) { + ctx.SetUserValue(VarStatus, status) + ctx.SetUserValue(VarBodyBytesSent, bodySize) + ctx.SetUserValue(VarRequestTime, durationNs) +} + +// SetServerNameInContext 在 fasthttp.RequestCtx 中设置服务器名称 +func SetServerNameInContext(ctx *fasthttp.RequestCtx, name string) { + ctx.SetUserValue(VarServerName, name) +} + +// SetRequestIDInContext 在 fasthttp.RequestCtx 中设置请求 ID +func SetRequestIDInContext(ctx *fasthttp.RequestCtx, id string) { + ctx.SetUserValue(VarRequestID, id) +} + +// BuiltinVarNames 返回所有内置变量名称列表 +func BuiltinVarNames() []string { + names := make([]string, 0, len(builtinVars)) + for name := range builtinVars { + names = append(names, name) + } + return names +} diff --git a/internal/variable/integration_test.go b/internal/variable/integration_test.go new file mode 100644 index 0000000..1ca5edf --- /dev/null +++ b/internal/variable/integration_test.go @@ -0,0 +1,232 @@ +// integration_test.go - 变量系统集成测试 +// +// 测试变量系统与 logging、proxy、rewrite 的集成 +// +// 作者:xfy +package variable_test + +import ( + "strings" + "testing" + "time" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/logging" + "rua.plus/lolly/internal/middleware/rewrite" + "rua.plus/lolly/internal/variable" +) + +// TestVariableInAccessLog 测试访问日志中的变量展开 +func TestVariableInAccessLog(t *testing.T) { + // 创建测试请求上下文 + cfg := &config.LoggingConfig{ + Access: config.AccessLogConfig{ + Format: "$remote_addr - $remote_user [$time_local] \"$request_method $uri $scheme\" $status $body_bytes_sent", + }, + } + + logger := logging.New(cfg) + + // 创建请求上下文 + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/api/users?page=1") + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetHost("example.com") + + // 记录访问日志 + logger.LogAccess(ctx, 200, 1024, 50*time.Millisecond) + + // 验证输出包含期望的变量 + // 注意:由于直接输出到文件/stdout,这里主要验证不 panic +} + +// TestVariableInRewrite 测试重写规则中的变量展开 +func TestVariableInRewrite(t *testing.T) { + rules := []config.RewriteRule{ + { + Pattern: "^/api/(.*)$", + Replacement: "/v1/$1?original=$uri", + Flag: "break", + }, + } + + mw, err := rewrite.New(rules) + if err != nil { + t.Fatalf("failed to create rewrite middleware: %v", err) + } + + // 创建请求 + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/api/users") + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetHost("example.com") + + // 创建处理函数来捕获重写后的路径 + var capturedPath string + next := func(c *fasthttp.RequestCtx) { + capturedPath = string(c.Path()) + } + + // 处理请求 + handler := mw.Process(next) + handler(ctx) + + // 验证路径被重写 + if capturedPath != "/v1/users" { + t.Errorf("expected path '/v1/users', got %q", capturedPath) + } +} + +// TestVariableCompatibility 测试与旧格式的兼容性 +func TestVariableCompatibility(t *testing.T) { + // 测试旧格式变量名 + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test?foo=bar") + ctx.Request.Header.SetMethod("POST") + ctx.Request.Header.SetHost("example.com") + ctx.Request.Header.Set("User-Agent", "TestAgent") + ctx.Request.Header.Set("Referer", "http://referer.com") + + // 设置响应信息(模拟日志场景) + variable.SetResponseInfoInContext(ctx, 201, 2048, 100000000) // 100ms + + vc := variable.NewVariableContext(ctx) + defer variable.ReleaseVariableContext(vc) + + tests := []struct { + template string + contains []string // 验证结果包含这些子串 + }{ + {"$remote_addr", []string{"0.0.0.0"}}, // 默认地址 + {"$host", []string{"example.com"}}, + {"$uri", []string{"/test"}}, + {"$request_method", []string{"POST"}}, + {"$scheme", []string{"http"}}, + {"$status", []string{"201"}}, + {"$body_bytes_sent", []string{"2048"}}, + {"$request_time", []string{"0.100"}}, + {"$time_local", []string{"/"}}, // 包含 / + {"$time_iso8601", []string{"-"}}, // ISO8601 包含 - + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + for _, expected := range tt.contains { + if !strings.Contains(result, expected) { + t.Errorf("Expand(%q) = %q, expected to contain %q", tt.template, result, expected) + } + } + }) + } +} + +// TestVariableExpansionPerformance 测试变量展开性能 +func TestVariableExpansionPerformance(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/api/v1/users/123?active=true") + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetHost("api.example.com") + + vc := variable.NewVariableContext(ctx) + defer variable.ReleaseVariableContext(vc) + + // 常见日志格式模板 + template := "$remote_addr - $remote_user [$time_local] \"$request_method $request_uri $scheme\" $status $body_bytes_sent \"$http_user_agent\"" + + // 执行多次展开 + start := time.Now() + iterations := 10000 + for i := 0; i < iterations; i++ { + _ = vc.Expand(template) + } + elapsed := time.Since(start) + + // 计算平均时间 + avg := elapsed / time.Duration(iterations) + t.Logf("Average expansion time: %v (iterations: %d)", avg, iterations) + + // 验证性能在合理范围内(< 1μs 每次) + if avg > time.Microsecond { + t.Logf("Warning: average time %v exceeds 1μs", avg) + } +} + +// TestMixedVariableFormats 测试混合变量格式 +func TestMixedVariableFormats(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetHost("example.com") + + vc := variable.NewVariableContext(ctx) + defer variable.ReleaseVariableContext(vc) + + tests := []struct { + template string + expected string + }{ + {"$scheme://$host$uri", "http://example.com/test"}, + {"${scheme}://${host}${uri}", "http://example.com/test"}, + {"Host: ${host}:8080", "Host: example.com:8080"}, + {"$host:8080", "example.com:8080"}, + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestUndefinedVariableInIntegration 测试未定义变量在集成场景中的行为 +func TestUndefinedVariableInIntegration(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + vc := variable.NewVariableContext(ctx) + defer variable.ReleaseVariableContext(vc) + + // 未定义变量应该保持原样 + template := "$host $undefined_var $uri" + result := vc.Expand(template) + + // $host 和 $uri 应该展开,$undefined_var 保持原样 + if !strings.Contains(result, "example.com") && !strings.Contains(result, "$undefined_var") { + t.Errorf("expected result to contain either expanded host or $undefined_var, got %q", result) + } +} + +// TestVariableContextReuse 测试变量上下文复用 +func TestVariableContextReuse(t *testing.T) { + // 创建两个请求 + ctx1 := &fasthttp.RequestCtx{} + ctx1.Request.SetRequestURI("/first") + ctx1.Request.Header.SetHost("first.com") + + ctx2 := &fasthttp.RequestCtx{} + ctx2.Request.SetRequestURI("/second") + ctx2.Request.Header.SetHost("second.com") + + // 使用第一个上下文 + vc1 := variable.NewVariableContext(ctx1) + result1 := vc1.Expand("$host$uri") + variable.ReleaseVariableContext(vc1) + + // 复用(从池中获取)用于第二个上下文 + vc2 := variable.NewVariableContext(ctx2) + result2 := vc2.Expand("$host$uri") + variable.ReleaseVariableContext(vc2) + + // 验证结果正确 + if result1 != "first.com/first" { + t.Errorf("first request: expected 'first.com/first', got %q", result1) + } + if result2 != "second.com/second" { + t.Errorf("second request: expected 'second.com/second', got %q", result2) + } +} diff --git a/internal/variable/pool.go b/internal/variable/pool.go new file mode 100644 index 0000000..e539b5c --- /dev/null +++ b/internal/variable/pool.go @@ -0,0 +1,97 @@ +// pool.go - VariableContext 池管理 +// +// 提供 sync.Pool 复用 VariableContext,减少 GC 压力。 +// +// 作者:xfy +package variable + +import ( + "sync" + + "github.com/valyala/fasthttp" +) + +// PoolStats 池统计信息 +type PoolStats struct { + Gets int64 // Get 次数 + Puts int64 // Put 次数 + NewCount int64 // New 创建次数 + Active int64 // 当前活跃数量 (Gets - Puts) +} + +var ( + // stats 池统计 + stats PoolStats + // statsMu 保护统计信息 + statsMu sync.RWMutex +) + +// GetStats 获取池统计信息 +func GetStats() PoolStats { + statsMu.RLock() + s := stats + statsMu.RUnlock() + return s +} + +// GetPool 获取底层的 sync.Pool(用于测试和调试) +func GetPool() *sync.Pool { + return &pool +} + +// PoolGet 从池中获取 VariableContext(包装方法,用于统计) +func PoolGet(ctx *fasthttp.RequestCtx) *VariableContext { + vc := pool.Get().(*VariableContext) + + // 初始化 + vc.ctx = ctx + vc.status = 0 + vc.bodySize = 0 + vc.duration = 0 + vc.serverName = "" + + // 清空缓存和自定义变量 + for k := range vc.cache { + delete(vc.cache, k) + } + for k := range vc.store { + delete(vc.store, k) + } + + // 更新统计 + statsMu.Lock() + stats.Gets++ + stats.Active = stats.Gets - stats.Puts + statsMu.Unlock() + + return vc +} + +// PoolPut 将 VariableContext 放回池中(包装方法,用于统计) +func PoolPut(vc *VariableContext) { + if vc == nil { + return + } + + // 清理引用 + vc.ctx = nil + vc.status = 0 + vc.bodySize = 0 + vc.duration = 0 + vc.serverName = "" + + pool.Put(vc) + + // 更新统计 + statsMu.Lock() + stats.Puts++ + stats.Active = stats.Gets - stats.Puts + statsMu.Unlock() +} + +// ResetStats 重置统计信息 +func ResetStats() { + statsMu.Lock() + stats = PoolStats{} + statsMu.Unlock() +} diff --git a/internal/variable/variable.go b/internal/variable/variable.go new file mode 100644 index 0000000..33f0534 --- /dev/null +++ b/internal/variable/variable.go @@ -0,0 +1,385 @@ +// Package variable 提供高性能的变量系统,支持 nginx 风格的变量展开。 +// +// 该包实现了统一的变量存储和展开机制,用于: +// - 访问日志格式模板 +// - 代理请求头设置 +// - URL 重写规则 +// +// 支持的变量格式: +// - $var: 简单变量 +// - ${var}: 带花括号的变量(用于变量后有字符的场景) +// +// 性能特性: +// - 使用快速字符串扫描(非正则表达式) +// - sync.Pool 复用 VariableContext +// - 内置变量惰性求值并缓存 +// +// 作者:xfy +package variable + +import ( + "strconv" + "strings" + "sync" + + "github.com/valyala/fasthttp" +) + +// VariableStore 变量存储接口 +type VariableStore interface { + // Get 获取变量值 + Get(name string) (string, bool) + // Set 设置变量值(用于自定义变量) + Set(name string, value string) +} + +// BuiltinVariable 内置变量定义 +type BuiltinVariable struct { + Name string + Description string + Getter func(ctx *fasthttp.RequestCtx) string +} + +// VariableContext 变量上下文,绑定到请求 +type VariableContext struct { + ctx *fasthttp.RequestCtx + store map[string]string // 自定义变量存储 + cache map[string]string // 内置变量缓存 + status int // HTTP 状态码(由外部设置) + bodySize int64 // 响应体大小(由外部设置) + duration int64 // 请求处理时间纳秒(由外部设置) + serverName string // 服务器名称 +} + +// pool 用于复用 VariableContext +var pool = sync.Pool{ + New: func() interface{} { + return &VariableContext{ + store: make(map[string]string), + cache: make(map[string]string), + } + }, +} + +// builtinVars 内置变量注册表 +var builtinVars = make(map[string]*BuiltinVariable) + +// RegisterBuiltin 注册内置变量 +func RegisterBuiltin(v *BuiltinVariable) { + builtinVars[v.Name] = v +} + +// GetBuiltin 获取内置变量定义 +func GetBuiltin(name string) *BuiltinVariable { + return builtinVars[name] +} + +// NewVariableContext 从池中获取 VariableContext +func NewVariableContext(ctx *fasthttp.RequestCtx) *VariableContext { + vc := pool.Get().(*VariableContext) + vc.ctx = ctx + vc.status = 0 + vc.bodySize = 0 + vc.duration = 0 + vc.serverName = "" + // 清空缓存 + for k := range vc.cache { + delete(vc.cache, k) + } + // 清空自定义变量 + for k := range vc.store { + delete(vc.store, k) + } + return vc +} + +// ReleaseVariableContext 释放 VariableContext 回池中 +func ReleaseVariableContext(vc *VariableContext) { + if vc == nil { + return + } + vc.ctx = nil + vc.status = 0 + vc.bodySize = 0 + vc.duration = 0 + vc.serverName = "" + pool.Put(vc) +} + +// SetResponseInfo 设置响应信息(用于需要 status、body_bytes_sent、request_time 的场景) +func (vc *VariableContext) SetResponseInfo(status int, bodySize int64, durationNs int64) { + vc.status = status + vc.bodySize = bodySize + vc.duration = durationNs +} + +// SetServerName 设置服务器名称 +func (vc *VariableContext) SetServerName(name string) { + vc.serverName = name +} + +// Get 获取变量值(优先自定义变量,再查内置变量) +func (vc *VariableContext) Get(name string) (string, bool) { + // 1. 先查自定义变量 + if v, ok := vc.store[name]; ok { + return v, true + } + + // 2. 检查从 SetResponseInfo/SetServerName 设置的值 + // 优先检查 struct 字段,再检查 ctx.UserValue(兼容 SetResponseInfoInContext) + switch name { + case VarStatus: + if vc.status > 0 { + return strconv.Itoa(vc.status), true + } + if v := vc.ctx.UserValue(VarStatus); v != nil { + if i, ok := v.(int); ok { + return strconv.Itoa(i), true + } + } + case VarBodyBytesSent: + if vc.bodySize > 0 { + return strconv.FormatInt(vc.bodySize, 10), true + } + if v := vc.ctx.UserValue(VarBodyBytesSent); v != nil { + if i, ok := v.(int64); ok { + return strconv.FormatInt(i, 10), true + } + } + return "0", true + case VarRequestTime: + if vc.duration > 0 { + // 转换为秒,保留 3 位小数 + seconds := float64(vc.duration) / 1e9 + return strconv.FormatFloat(seconds, 'f', 3, 64), true + } + if v := vc.ctx.UserValue(VarRequestTime); v != nil { + if i, ok := v.(int64); ok { + seconds := float64(i) / 1e9 + return strconv.FormatFloat(seconds, 'f', 3, 64), true + } + } + return "0.000", true + case VarServerName: + if vc.serverName != "" { + return vc.serverName, true + } + } + + // 3. 查内置变量缓存 + if v, ok := vc.cache[name]; ok { + return v, true + } + + // 4. 求值内置变量并缓存 + if v, ok := vc.evalBuiltin(name); ok { + vc.cache[name] = v + return v, true + } + + return "", false +} + +// Set 设置自定义变量 +func (vc *VariableContext) Set(name string, value string) { + vc.store[name] = value +} + +// evalBuiltin 求值内置变量 +func (vc *VariableContext) evalBuiltin(name string) (string, bool) { + builtin := builtinVars[name] + if builtin == nil || builtin.Getter == nil { + return "", false + } + return builtin.Getter(vc.ctx), true +} + +// Expand 展开模板字符串中的变量 +// 支持 $var 和 ${var} 两种格式 +// 对于未定义的变量,保持原样不变 +func (vc *VariableContext) Expand(template string) string { + if template == "" { + return "" + } + + // 快速路径:没有变量 + hasVar := false + for i := 0; i < len(template); i++ { + if template[i] == '$' { + hasVar = true + break + } + } + if !hasVar { + return template + } + + var result strings.Builder + result.Grow(len(template) * 2) + + i := 0 + for i < len(template) { + if template[i] != '$' { + result.WriteByte(template[i]) + i++ + continue + } + + // 到达末尾,保留 $ + if i+1 >= len(template) { + result.WriteByte('$') + i++ + continue + } + + // ${var} 格式 + if template[i+1] == '{' { + // 查找匹配的 } + end := strings.IndexByte(template[i+2:], '}') + if end == -1 { + result.WriteByte('$') + i++ + continue + } + // end 是相对 i+2 的偏移量 + varName := template[i+2 : i+2+end] + if varName == "" { + // 空变量名,保持 ${} + result.WriteString("${}") + i += 2 + end + 1 + continue + } + // 获取变量值 + if v, ok := vc.Get(varName); ok { + result.WriteString(v) + } else { + // 未定义变量,保持原样 + result.WriteString("${") + result.WriteString(varName) + result.WriteByte('}') + } + // i+2 是变量名开始,+end 是 } 的位置,+1 跳过 } + i += 2 + end + 1 + continue + } + + // $var 格式(变量名由字母、数字、下划线组成) + j := i + 1 + for j < len(template) { + c := template[j] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' { + j++ + } else { + break + } + } + + if j == i+1 { + // 变量名长度为0,保留 $ + result.WriteByte('$') + i++ + continue + } + + varName := template[i+1 : j] + if v, ok := vc.Get(varName); ok { + result.WriteString(v) + } else { + // 未定义变量,保持原样 + result.WriteByte('$') + result.WriteString(varName) + } + i = j // 跳过变量名 + } + + return result.String() +} + +// ExpandString 展开字符串(静态函数,用于简单场景) +// 需要提供变量值查找函数 +func ExpandString(template string, lookup func(string) string) string { + if template == "" { + return "" + } + + // 快速路径 + hasVar := false + for i := 0; i < len(template); i++ { + if template[i] == '$' { + hasVar = true + break + } + } + if !hasVar { + return template + } + + var result strings.Builder + result.Grow(len(template) * 2) + + i := 0 + for i < len(template) { + if template[i] != '$' { + result.WriteByte(template[i]) + i++ + continue + } + + if i+1 >= len(template) { + result.WriteByte('$') + i++ + continue + } + + if template[i+1] == '{' { + end := strings.IndexByte(template[i+2:], '}') + if end == -1 { + result.WriteByte('$') + i++ + continue + } + varName := template[i+2 : i+2+end] + if varName == "" { + result.WriteByte('$') + i += 2 + continue + } + if v := lookup(varName); v != "" { + result.WriteString(v) + } else { + result.WriteString("${") + result.WriteString(varName) + result.WriteByte('}') + } + i += 2 + end + 1 + continue + } + + j := i + 1 + for j < len(template) { + c := template[j] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' { + j++ + } else { + break + } + } + + if j == i+1 { + result.WriteByte('$') + i++ + continue + } + + varName := template[i+1 : j] + if v := lookup(varName); v != "" { + result.WriteString(v) + } else { + result.WriteByte('$') + result.WriteString(varName) + } + i = j + } + + return result.String() +} diff --git a/internal/variable/variable_test.go b/internal/variable/variable_test.go new file mode 100644 index 0000000..44ead8a --- /dev/null +++ b/internal/variable/variable_test.go @@ -0,0 +1,833 @@ +// variable_test.go - 变量系统单元测试 +// +// 测试覆盖: +// - 所有内置变量求值 +// - 字符串展开($var 和 ${var} 格式) +// - 性能基准测试 +// +// 作者:xfy +package variable + +import ( + "strings" + "testing" + + "github.com/valyala/fasthttp" +) + +// mockRequestCtx 创建测试用的 fasthttp.RequestCtx +func mockRequestCtx(t *testing.T) *fasthttp.RequestCtx { + ctx := &fasthttp.RequestCtx{} + + // 设置请求信息 + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetRequestURI("/test/path?foo=bar&baz=qux") + ctx.Request.Header.SetHost("example.com") + ctx.Request.Header.Set("X-Custom-Header", "custom-value") + ctx.Request.Header.Set("User-Agent", "Test-Agent/1.0") + ctx.Request.Header.SetCookie("session", "abc123") + + return ctx +} + +// TestBuiltinVariables 测试所有内置变量 +func TestBuiltinVariables(t *testing.T) { + tests := []struct { + name string + varName string + expected string + contains bool // 如果为 true,则检查是否包含 + }{ + {"host", VarHost, "example.com", false}, + {"uri", VarURI, "/test/path", false}, + {"request_uri", VarRequestURI, "/test/path?foo=bar&baz=qux", false}, + {"args", VarArgs, "foo=bar&baz=qux", false}, + {"request_method", VarRequestMethod, "GET", false}, + {"scheme", VarScheme, "http", false}, + {"time_iso8601", VarTimeISO8601, "", true}, // 包含格式特征 + {"time_local", VarTimeLocal, "/", true}, // 包含 / + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + value, ok := vc.Get(tt.varName) + if !ok && !tt.contains { + t.Errorf("expected to find variable %s", tt.varName) + return + } + + if tt.contains { + if !strings.Contains(value, tt.expected) { + t.Errorf("%s = %q, expected to contain %q", tt.varName, value, tt.expected) + } + } else { + if value != tt.expected { + t.Errorf("%s = %q, want %q", tt.varName, value, tt.expected) + } + } + }) + } +} + +// TestExpandSimple 测试简单变量展开 +func TestExpandSimple(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + tests := []struct { + template string + expected string + }{ + {"$host", "example.com"}, + {"$uri", "/test/path"}, + {"$request_method", "GET"}, + {"$scheme", "http"}, + {"Host: $host", "Host: example.com"}, + {"$scheme://$host$uri", "http://example.com/test/path"}, + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestExpandBrace 测试花括号变量展开 +func TestExpandBrace(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + tests := []struct { + template string + expected string + }{ + {"${host}", "example.com"}, + {"${uri}", "/test/path"}, + {"Host: ${host}", "Host: example.com"}, + {"${scheme}://${host}${uri}", "http://example.com/test/path"}, + {"${host}:8080", "example.com:8080"}, // 变量后有字符 + {"pre_${uri}_post", "pre_/test/path_post"}, + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestExpandMixed 测试混合格式展开 +func TestExpandMixed(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + tests := []struct { + template string + expected string + }{ + {"$scheme://${host}$uri", "http://example.com/test/path"}, + {"${scheme}://$host${uri}", "http://example.com/test/path"}, + {"$request_method ${request_uri} HTTP/1.1", "GET /test/path?foo=bar&baz=qux HTTP/1.1"}, + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestExpandUndefined 测试未定义变量 +func TestExpandUndefined(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + tests := []struct { + template string + expected string + }{ + {"$undefined", "$undefined"}, // 保持原样 + {"${undefined}", "${undefined}"}, // 保持原样 + {"$host-$undefined", "example.com-$undefined"}, + {"$host$undefined", "example.com$undefined"}, + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestExpandEdgeCases 测试边界情况 +func TestExpandEdgeCases(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + tests := []struct { + template string + expected string + }{ + {"", ""}, // 空字符串 + {"no variable", "no variable"}, // 无变量 + {"$", "$"}, // 只有 $ + {"${", "${"}, // 未闭合的 ${ + {"$123", "$123"}, // 数字开头(不是有效变量名) + {"test$$host", "test$example.com"}, // 双 $ + {"$host$$uri", "example.com$/test/path"}, + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestCustomVariable 测试自定义变量 +func TestCustomVariable(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + // 设置自定义变量 + vc.Set("custom_var", "custom_value") + vc.Set("app_name", "lolly") + + // 获取自定义变量 + if v, ok := vc.Get("custom_var"); !ok || v != "custom_value" { + t.Errorf("custom_var = %q, want %q", v, "custom_value") + } + + // 展开包含自定义变量 + result := vc.Expand("App: $app_name") + if result != "App: lolly" { + t.Errorf("Expand = %q, want %q", result, "App: lolly") + } +} + +// TestCustomOverridesBuiltin 测试自定义变量覆盖内置变量 +func TestCustomOverridesBuiltin(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + // 设置同名自定义变量 + vc.Set("host", "custom.host.com") + + // 自定义变量应该覆盖内置变量 + result := vc.Expand("$host") + if result != "custom.host.com" { + t.Errorf("Expand = %q, want %q", result, "custom.host.com") + } +} + +// TestResponseInfoVariables 测试响应相关变量 +func TestResponseInfoVariables(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + // 设置响应信息 + vc.SetResponseInfo(200, 1024, 15000000) // 15ms + + // 需要设置到 ctx 中才能被 builtin getter 获取 + SetResponseInfoInContext(ctx, 200, 1024, 15000000) + + tests := []struct { + varName string + expected string + }{ + {VarStatus, "200"}, + {VarBodyBytesSent, "1024"}, + {VarRequestTime, "0.015"}, + } + + for _, tt := range tests { + t.Run(tt.varName, func(t *testing.T) { + value, ok := vc.Get(tt.varName) + if !ok { + t.Errorf("expected to find variable %s", tt.varName) + return + } + if value != tt.expected { + t.Errorf("%s = %q, want %q", tt.varName, value, tt.expected) + } + }) + } +} + +// TestExpandString 测试静态 ExpandString 函数 +func TestExpandString(t *testing.T) { + lookup := func(name string) string { + switch name { + case "host": + return "example.com" + case "port": + return "8080" + default: + return "" + } + } + + tests := []struct { + template string + expected string + }{ + {"$host:$port", "example.com:8080"}, + {"${host}:${port}", "example.com:8080"}, + {"http://$host:$port", "http://example.com:8080"}, + {"$undefined", "$undefined"}, // 未定义变量保持原样 + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := ExpandString(tt.template, lookup) + if result != tt.expected { + t.Errorf("ExpandString(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestNormalizeHeaderName 测试头名规范化 +func TestNormalizeHeaderName(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"user_agent", "User-Agent"}, + {"content_type", "Content-Type"}, + {"x_custom_header", "X-Custom-Header"}, + {"accept", "Accept"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeHeaderName(tt.input) + if result != tt.expected { + t.Errorf("normalizeHeaderName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// BenchmarkExpandSimple 基准测试:简单变量展开 +func BenchmarkExpandSimple(b *testing.B) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetHost("example.com") + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetRequestURI("/test") + + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + template := "$host $request_method $uri" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = vc.Expand(template) + } +} + +// BenchmarkExpandComplex 基准测试:复杂模板展开 +func BenchmarkExpandComplex(b *testing.B) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetHost("example.com") + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetRequestURI("/api/v1/users?page=1&limit=10") + + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + // 模拟日志格式 + template := "$remote_addr - - [$time_local] \"$request_method $request_uri HTTP/1.1\" $status $body_bytes_sent" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = vc.Expand(template) + } +} + +// BenchmarkExpandNoVariable 基准测试:无变量字符串 +func BenchmarkExpandNoVariable(b *testing.B) { + ctx := &fasthttp.RequestCtx{} + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + template := "This is a static string without any variables" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = vc.Expand(template) + } +} + +// BenchmarkExpandBrace 基准测试:花括号变量 +func BenchmarkExpandBrace(b *testing.B) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetHost("example.com") + ctx.Request.Header.SetMethod("GET") + + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + template := "${scheme}://${host}${uri}" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = vc.Expand(template) + } +} + +// BenchmarkPoolGetPut 基准测试:池的 Get/Put 性能 +func BenchmarkPoolGetPut(b *testing.B) { + ctx := &fasthttp.RequestCtx{} + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + vc := NewVariableContext(ctx) + ReleaseVariableContext(vc) + } +} + +// BenchmarkExpandStringStatic 基准测试:静态 ExpandString 函数 +func BenchmarkExpandStringStatic(b *testing.B) { + lookup := func(name string) string { + switch name { + case "host": + return "example.com" + case "uri": + return "/test" + default: + return "" + } + } + + template := "$host $uri" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = ExpandString(template, lookup) + } +} + +// TestPoolReuse 测试池复用 +func TestPoolReuse(t *testing.T) { + ctx := mockRequestCtx(t) + + // 获取和释放多个 context,确保没有 panic + for i := 0; i < 10; i++ { + vc := NewVariableContext(ctx) + vc.Set("key", "value") + if v, ok := vc.Get("key"); !ok || v != "value" { + t.Errorf("iteration %d: expected key=value, got %s", i, v) + } + ReleaseVariableContext(vc) + } + + // 验证池在复用(第二次获取应该清除之前的值) + vc2 := NewVariableContext(ctx) + if v, ok := vc2.Get("key"); ok { + t.Errorf("expected key to be cleared after release, got %s", v) + } + ReleaseVariableContext(vc2) +} + +// TestMoreBuiltinVariables 测试更多内置变量 +func TestMoreBuiltinVariables(t *testing.T) { + tests := []struct { + name string + setupFunc func(*fasthttp.RequestCtx) + varName string + expected string + shouldExist bool + }{ + { + name: "server_port with local addr", + setupFunc: func(ctx *fasthttp.RequestCtx) {}, + varName: VarServerPort, + expected: "0", // 没有设置 local addr 时返回 "0" + shouldExist: true, + }, + { + name: "remote_addr without addr", + setupFunc: func(ctx *fasthttp.RequestCtx) {}, + varName: VarRemoteAddr, + expected: "0.0.0.0:0", // mock ctx 返回默认值 + shouldExist: true, + }, + { + name: "remote_port without addr", + setupFunc: func(ctx *fasthttp.RequestCtx) {}, + varName: VarRemotePort, + expected: "0", + shouldExist: true, + }, + { + name: "request_id from context", + setupFunc: func(ctx *fasthttp.RequestCtx) { + SetRequestIDInContext(ctx, "test-request-id-123") + }, + varName: VarRequestID, + expected: "test-request-id-123", + shouldExist: true, + }, + { + name: "server_name from context", + setupFunc: func(ctx *fasthttp.RequestCtx) { + SetServerNameInContext(ctx, "test-server") + }, + varName: VarServerName, + expected: "test-server", + shouldExist: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := mockRequestCtx(t) + tt.setupFunc(ctx) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + value, ok := vc.Get(tt.varName) + if tt.shouldExist && !ok { + t.Errorf("expected variable %s to exist", tt.varName) + return + } + if value != tt.expected { + t.Errorf("%s = %q, want %q", tt.varName, value, tt.expected) + } + }) + } +} + +// TestReleaseNilContext 测试释放 nil context +func TestReleaseNilContext(t *testing.T) { + // 不应该 panic + ReleaseVariableContext(nil) +} + +// TestGetBuiltin 测试获取内置变量定义 +func TestGetBuiltin(t *testing.T) { + // 存在的变量 + v := GetBuiltin("host") + if v == nil || v.Name != "host" { + t.Error("GetBuiltin('host') should return non-nil with name 'host'") + } + + // 不存在的变量 + v = GetBuiltin("nonexistent") + if v != nil { + t.Error("GetBuiltin('nonexistent') should return nil") + } +} + +// TestGetArgVariable 测试查询参数变量 +func TestGetArgVariable(t *testing.T) { + ctx := mockRequestCtx(t) // /test/path?foo=bar&baz=qux + + tests := []struct { + name string + expected string + }{ + {"foo", "bar"}, + {"baz", "qux"}, + {"notexist", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetArgVariable(ctx, tt.name) + if result != tt.expected { + t.Errorf("GetArgVariable(%q) = %q, want %q", tt.name, result, tt.expected) + } + }) + } +} + +// TestGetHTTPVariable 测试 HTTP 头变量 +func TestGetHTTPVariable(t *testing.T) { + ctx := mockRequestCtx(t) + + tests := []struct { + name string + expected string + }{ + {"user_agent", "Test-Agent/1.0"}, + {"x_custom_header", "custom-value"}, + {"not_exist", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetHTTPVariable(ctx, tt.name) + if result != tt.expected { + t.Errorf("GetHTTPVariable(%q) = %q, want %q", tt.name, result, tt.expected) + } + }) + } +} + +// TestGetCookieVariable 测试 Cookie 变量 +func TestGetCookieVariable(t *testing.T) { + ctx := mockRequestCtx(t) // session=abc123 + + tests := []struct { + name string + expected string + }{ + {"session", "abc123"}, + {"notexist", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetCookieVariable(ctx, tt.name) + if result != tt.expected { + t.Errorf("GetCookieVariable(%q) = %q, want %q", tt.name, result, tt.expected) + } + }) + } +} + +// TestEmptyTemplate 测试空模板 +func TestEmptyTemplate(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + result := vc.Expand("") + if result != "" { + t.Errorf("Expand('') = %q, want empty string", result) + } +} + +// TestReleaseVariableContextWithNil 测试释放 nil +func TestReleaseVariableContextWithNil(t *testing.T) { + // 不应该 panic + ReleaseVariableContext(nil) +} + +// TestExpandOnlyDollar 测试只有 $ 的情况 +func TestExpandOnlyDollar(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + tests := []struct { + template string + expected string + }{ + {"$", "$"}, + {"test$", "test$"}, + {"$$", "$$"}, // 两个独立的 $ + } + + for _, tt := range tests { + t.Run(tt.template, func(t *testing.T) { + result := vc.Expand(tt.template) + if result != tt.expected { + t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected) + } + }) + } +} + +// TestPoolFunctions 测试 Pool 相关函数 +func TestPoolFunctions(t *testing.T) { + ctx := mockRequestCtx(t) + + // 测试 PoolGet 和 PoolPut + vc := PoolGet(ctx) + if vc == nil { + t.Fatal("PoolGet returned nil") + } + + // 设置一些值 + vc.Set("test", "value") + + // 释放 + PoolPut(vc) + + // 再次获取应该被清空 + vc2 := PoolGet(ctx) + if _, ok := vc2.Get("test"); ok { + t.Error("expected context to be cleared after PoolPut") + } + PoolPut(vc2) +} + +// TestPoolPutNil 测试 PoolPut nil +func TestPoolPutNil(t *testing.T) { + // 不应该 panic + PoolPut(nil) +} + +// TestStatsFunctions 测试统计相关函数 +func TestStatsFunctions(t *testing.T) { + // 重置统计 + ResetStats() + + // 获取初始统计 + stats := GetStats() + if stats.Gets != 0 || stats.Puts != 0 { + t.Error("expected empty stats after reset") + } + + // 获取池 + p := GetPool() + if p == nil { + t.Error("GetPool() should return non-nil") + } +} + +// TestSetResponseInfo 测试 SetResponseInfo +func TestSetResponseInfo(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + // 设置响应信息 + vc.SetResponseInfo(404, 512, 25000000) // 25ms + + if vc.status != 404 { + t.Errorf("status = %d, want 404", vc.status) + } + if vc.bodySize != 512 { + t.Errorf("bodySize = %d, want 512", vc.bodySize) + } + if vc.duration != 25000000 { + t.Errorf("duration = %d, want 25000000", vc.duration) + } +} + +// TestSetServerName 测试 SetServerName +func TestSetServerName(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + vc.SetServerName("my-server") + if vc.serverName != "my-server" { + t.Errorf("serverName = %q, want 'my-server'", vc.serverName) + } +} + +// TestEmptyExpandString 测试空模板 ExpandString +func TestEmptyExpandString(t *testing.T) { + lookup := func(name string) string { return "" } + result := ExpandString("", lookup) + if result != "" { + t.Errorf("ExpandString('') = %q, want empty", result) + } +} + +// TestExpandStringNoVar 测试无变量模板 +func TestExpandStringNoVar(t *testing.T) { + lookup := func(name string) string { return "" } + result := ExpandString("hello world", lookup) + if result != "hello world" { + t.Errorf("ExpandString = %q, want 'hello world'", result) + } +} + +// TestTLSBuiltin 测试 HTTPS/TLS 内置变量 +func TestTLSBuiltin(t *testing.T) { + // 创建带 TLS 的上下文 + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetRequestURI("/test") + // 由于无法直接设置 TLS,scheme 变量会检查 ctx.IsTLS() + // 这里我们测试它返回 http(默认值) + + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + scheme, ok := vc.Get("scheme") + if !ok { + t.Error("expected 'scheme' variable to exist") + } + if scheme != "http" { + t.Errorf("scheme = %q, want 'http'", scheme) + } +} + +// TestEmptyVarNameBrace 测试空变量名 ${} +func TestEmptyVarNameBrace(t *testing.T) { + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + // ${} 应该保持为 ${} + result := vc.Expand("${}") + if result != "${}" { + t.Errorf("Expand('${}') = %q, want '${}'", result) + } +} + +func TestBuiltinVarNames(t *testing.T) { + names := BuiltinVarNames() + if len(names) == 0 { + t.Error("BuiltinVarNames() returned empty slice") + } + + // 检查是否包含一些已知变量 + hasVar := func(name string) bool { + for _, n := range names { + if n == name { + return true + } + } + return false + } + + if !hasVar("host") { + t.Error("BuiltinVarNames() missing 'host'") + } + if !hasVar("uri") { + t.Error("BuiltinVarNames() missing 'uri'") + } + if !hasVar("remote_addr") { + t.Error("BuiltinVarNames() missing 'remote_addr'") + } +}