From 668ecde6db248d3e6026b00cc3d7c6ea778302de Mon Sep 17 00:00:00 2001 From: xfy Date: Thu, 9 Apr 2026 12:18:56 +0800 Subject: [PATCH] =?UTF-8?q?feat(variable):=20=E6=96=B0=E5=A2=9E=E5=85=A8?= =?UTF-8?q?=E5=B1=80=E5=8F=98=E9=87=8F=E6=94=AF=E6=8C=81=EF=BC=8C=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E6=B3=A8=E5=85=A5=E8=AF=B7=E6=B1=82=E4=B8=8A=E4=B8=8B?= =?UTF-8?q?=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/variable/variable.go | 53 +++++++++++++++- internal/variable/variable_test.go | 98 ++++++++++++++++++++++++++++++ 2 files changed, 149 insertions(+), 2 deletions(-) diff --git a/internal/variable/variable.go b/internal/variable/variable.go index ec78b3e..23e6690 100644 --- a/internal/variable/variable.go +++ b/internal/variable/variable.go @@ -67,6 +67,50 @@ var pool = sync.Pool{ }, } +// 全局自定义变量存储 +var ( + globalVariables map[string]string + globalVariablesLock sync.RWMutex +) + +// SetGlobalVariables 设置全局自定义变量。 +// 在应用启动或配置重载时调用。 +func SetGlobalVariables(vars map[string]string) { + globalVariablesLock.Lock() + defer globalVariablesLock.Unlock() + globalVariables = make(map[string]string, len(vars)) + for k, v := range vars { + globalVariables[k] = v + } +} + +// GetGlobalVariable 获取全局变量值。 +func GetGlobalVariable(name string) (string, bool) { + globalVariablesLock.RLock() + defer globalVariablesLock.RUnlock() + if globalVariables == nil { + return "", false + } + v, ok := globalVariables[name] + return v, ok +} + +// GetAllGlobalVariables 获取所有全局变量的副本。 +// 用于在 NewVariableContext 中批量注入。 +func GetAllGlobalVariables() map[string]string { + globalVariablesLock.RLock() + defer globalVariablesLock.RUnlock() + if globalVariables == nil { + return nil + } + // 返回副本,避免外部修改影响全局存储 + result := make(map[string]string, len(globalVariables)) + for k, v := range globalVariables { + result[k] = v + } + return result +} + // builtinVars 内置变量注册表 var builtinVars = make(map[string]*BuiltinVariable) @@ -80,7 +124,7 @@ func GetBuiltin(name string) *BuiltinVariable { return builtinVars[name] } -// NewVariableContext 从池中获取 VariableContext +// NewVariableContext 从池中获取 VariableContext,并注入全局变量。 func NewVariableContext(ctx *fasthttp.RequestCtx) *VariableContext { vc := pool.Get().(*VariableContext) vc.ctx = ctx @@ -97,10 +141,15 @@ func NewVariableContext(ctx *fasthttp.RequestCtx) *VariableContext { for k := range vc.cache { delete(vc.cache, k) } - // 清空自定义变量 + // 清空自定义变量 store,然后注入全局变量 for k := range vc.store { delete(vc.store, k) } + // 注入全局变量 + globals := GetAllGlobalVariables() + for name, value := range globals { + vc.store[name] = value + } return vc } diff --git a/internal/variable/variable_test.go b/internal/variable/variable_test.go index cd0431b..1df155c 100644 --- a/internal/variable/variable_test.go +++ b/internal/variable/variable_test.go @@ -1035,3 +1035,101 @@ func BenchmarkUpstreamVariables(b *testing.B) { _, _ = vc.Get(VarUpstreamResponseTime) } } + +// TestGlobalVariables 测试全局变量功能 +func TestGlobalVariables(t *testing.T) { + // 清理 + SetGlobalVariables(nil) + + // 测试设置全局变量 + SetGlobalVariables(map[string]string{ + "app_name": "lolly", + "version": "1.0.0", + }) + + // 测试 GetGlobalVariable + if v, ok := GetGlobalVariable("app_name"); !ok || v != "lolly" { + t.Errorf("GetGlobalVariable('app_name') = %q, %v, want 'lolly', true", v, ok) + } + + if v, ok := GetGlobalVariable("notexist"); ok { + t.Errorf("GetGlobalVariable('notexist') = %q, %v, want '', false", v, ok) + } + + // 测试 GetAllGlobalVariables + globals := GetAllGlobalVariables() + if globals == nil { + t.Error("GetAllGlobalVariables() returned nil") + } + if globals["app_name"] != "lolly" { + t.Errorf("globals['app_name'] = %q, want 'lolly'", globals["app_name"]) + } + + // 测试返回副本而非引用 + globals["app_name"] = "modified" + if v, _ := GetGlobalVariable("app_name"); v != "lolly" { + t.Error("GetAllGlobalVariables() should return a copy, not a reference") + } + + // 清理 + SetGlobalVariables(nil) +} + +// TestNewVariableContextWithGlobals 测试全局变量注入到请求上下文 +func TestNewVariableContextWithGlobals(t *testing.T) { + // 设置全局变量 + SetGlobalVariables(map[string]string{ + "global_var": "global_value", + }) + defer SetGlobalVariables(nil) + + ctx := mockRequestCtx(t) + vc := NewVariableContext(ctx) + defer ReleaseVariableContext(vc) + + // 全局变量应该被注入 + if v, ok := vc.Get("global_var"); !ok || v != "global_value" { + t.Errorf("Get('global_var') = %q, %v, want 'global_value', true", v, ok) + } + + // 展开应该包含全局变量 + result := vc.Expand("$global_var") + if result != "global_value" { + t.Errorf("Expand('$global_var') = %q, want 'global_value'", result) + } +} + +// TestGlobalVariablesConcurrent 测试全局变量并发访问 +func TestGlobalVariablesConcurrent(t *testing.T) { + SetGlobalVariables(map[string]string{ + "counter": "0", + }) + defer SetGlobalVariables(nil) + + done := make(chan bool) + + // 并发读取 + for i := 0; i < 10; i++ { + go func() { + for j := 0; j < 100; j++ { + _, _ = GetGlobalVariable("counter") + } + done <- true + }() + } + + // 并发写入 + for i := 0; i < 5; i++ { + go func() { + for j := 0; j < 50; j++ { + SetGlobalVariables(map[string]string{"counter": "updated"}) + } + done <- true + }() + } + + // 等待所有 goroutine 完成 + for i := 0; i < 15; i++ { + <-done + } +}