feat(lua): 添加 Lua 脚本嵌入支持
- 基于 gopher-lua 实现类似 OpenResty 的脚本嵌入能力 - LuaEngine: server 级单 LState + 请求级临时协程 - LuaContext: 请求上下文,变量存储和阶段管理 - LuaCoroutine: 沙箱隔离,Yield/Resume 循环,执行超时 - CodeCache: 字节码缓存,LRU 淘汰 + TTL 过期 - 新增 testify 用于测试断言 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d22c20cbbb
commit
7a66e350f0
4
go.mod
4
go.mod
@ -9,16 +9,20 @@ require (
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/quic-go/quic-go v0.59.0
|
||||
github.com/rs/zerolog v1.35.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/valyala/fasthttp v1.69.0
|
||||
github.com/yuin/gopher-lua v1.1.2
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/net v0.51.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@ -37,6 +37,8 @@ github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZy
|
||||
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/yuin/gopher-lua v1.1.2 h1:yF/FjE3hD65tBbt0VXLE13HWS9h34fdzJmrWRXwobGA=
|
||||
github.com/yuin/gopher-lua v1.1.2/go.mod h1:7aRmXIWl37SqRf0koeyylBEzJ+aPt8A+mmkQ4f1ntR8=
|
||||
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
|
||||
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
|
||||
263
internal/lua/cache.go
Normal file
263
internal/lua/cache.go
Normal file
@ -0,0 +1,263 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"github.com/yuin/gopher-lua/parse"
|
||||
)
|
||||
|
||||
// CacheKeyType 缓存键类型
|
||||
type CacheKeyType int
|
||||
|
||||
const (
|
||||
CacheKeyInline CacheKeyType = iota // 内联脚本
|
||||
CacheKeyFile // 文件脚本
|
||||
)
|
||||
|
||||
// CachedProto 缓存的字节码
|
||||
type CachedProto struct {
|
||||
Proto *glua.FunctionProto // 编译后的字节码
|
||||
SourceType CacheKeyType // 来源类型
|
||||
SourcePath string // 文件路径(仅 file 类型)
|
||||
ModTime time.Time // 文件修改时间(仅 file 类型)
|
||||
CachedAt time.Time // 缓存时间
|
||||
AccessAt atomic.Value // 最后访问时间
|
||||
}
|
||||
|
||||
// CodeCache 字节码缓存
|
||||
type CodeCache struct {
|
||||
mu sync.RWMutex
|
||||
protos map[string]*CachedProto // 缓存键 -> 字节码
|
||||
order []string // LRU 顺序
|
||||
maxSize int // 最大缓存数
|
||||
ttl time.Duration // 缓存 TTL
|
||||
fileWatch bool // 是否监控文件变更
|
||||
|
||||
// 统计
|
||||
hits uint64
|
||||
misses uint64
|
||||
}
|
||||
|
||||
// NewCodeCache 创建字节码缓存
|
||||
func NewCodeCache(maxSize int, ttl time.Duration, fileWatch bool) *CodeCache {
|
||||
return &CodeCache{
|
||||
protos: make(map[string]*CachedProto),
|
||||
order: make([]string, 0, maxSize),
|
||||
maxSize: maxSize,
|
||||
ttl: ttl,
|
||||
fileWatch: fileWatch,
|
||||
}
|
||||
}
|
||||
|
||||
// generateInlineKey 生成内联脚本缓存键
|
||||
func (c *CodeCache) generateInlineKey(src string) string {
|
||||
hash := sha256.Sum256([]byte(src))
|
||||
return "nhli_" + hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// generateFileKey 生成文件脚本缓存键
|
||||
func (c *CodeCache) generateFileKey(path string) string {
|
||||
hash := sha256.Sum256([]byte(path))
|
||||
return "nhlf_" + hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// GetOrCompileInline 获取或编译内联脚本
|
||||
func (c *CodeCache) GetOrCompileInline(src string) (*glua.FunctionProto, error) {
|
||||
key := c.generateInlineKey(src)
|
||||
|
||||
c.mu.RLock()
|
||||
cached, ok := c.protos[key]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if ok && !c.isExpired(cached) {
|
||||
atomic.AddUint64(&c.hits, 1)
|
||||
cached.AccessAt.Store(time.Now())
|
||||
return cached.Proto, nil
|
||||
}
|
||||
|
||||
atomic.AddUint64(&c.misses, 1)
|
||||
|
||||
// 编译脚本
|
||||
chunk, err := parse.Parse(strings.NewReader(src), "<inline>")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse inline script: %w", err)
|
||||
}
|
||||
proto, err := glua.Compile(chunk, "<inline>")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compile inline script: %w", err)
|
||||
}
|
||||
|
||||
// 存入缓存
|
||||
cached = &CachedProto{
|
||||
Proto: proto,
|
||||
SourceType: CacheKeyInline,
|
||||
CachedAt: time.Now(),
|
||||
}
|
||||
cached.AccessAt.Store(time.Now())
|
||||
|
||||
c.mu.Lock()
|
||||
c.storeLocked(key, cached)
|
||||
c.mu.Unlock()
|
||||
|
||||
return proto, nil
|
||||
}
|
||||
|
||||
// GetOrCompileFile 获取或编译文件脚本
|
||||
func (c *CodeCache) GetOrCompileFile(path string) (*glua.FunctionProto, error) {
|
||||
key := c.generateFileKey(path)
|
||||
|
||||
c.mu.RLock()
|
||||
cached, ok := c.protos[key]
|
||||
c.mu.RUnlock()
|
||||
|
||||
// 检查是否需要重新加载
|
||||
if ok && !c.isExpired(cached) && !c.isFileChanged(cached) {
|
||||
atomic.AddUint64(&c.hits, 1)
|
||||
cached.AccessAt.Store(time.Now())
|
||||
return cached.Proto, nil
|
||||
}
|
||||
|
||||
atomic.AddUint64(&c.misses, 1)
|
||||
|
||||
// 读取文件
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file %s: %w", path, err)
|
||||
}
|
||||
|
||||
// 获取文件信息
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat file %s: %w", path, err)
|
||||
}
|
||||
|
||||
// 编译脚本
|
||||
reader := bufio.NewReader(strings.NewReader(string(content)))
|
||||
chunk, err := parse.Parse(reader, path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse file %s: %w", path, err)
|
||||
}
|
||||
proto, err := glua.Compile(chunk, path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compile file %s: %w", path, err)
|
||||
}
|
||||
|
||||
// 存入缓存
|
||||
cached = &CachedProto{
|
||||
Proto: proto,
|
||||
SourceType: CacheKeyFile,
|
||||
SourcePath: path,
|
||||
ModTime: info.ModTime(),
|
||||
CachedAt: time.Now(),
|
||||
}
|
||||
cached.AccessAt.Store(time.Now())
|
||||
|
||||
c.mu.Lock()
|
||||
c.storeLocked(key, cached)
|
||||
c.mu.Unlock()
|
||||
|
||||
return proto, nil
|
||||
}
|
||||
|
||||
// storeLocked 存入缓存(需持有锁)
|
||||
func (c *CodeCache) storeLocked(key string, cached *CachedProto) {
|
||||
// 如果已存在,更新
|
||||
if _, ok := c.protos[key]; ok {
|
||||
c.protos[key] = cached
|
||||
return
|
||||
}
|
||||
|
||||
// LRU 淘汰
|
||||
if len(c.protos) >= c.maxSize {
|
||||
c.evictLocked()
|
||||
}
|
||||
|
||||
c.protos[key] = cached
|
||||
c.order = append(c.order, key)
|
||||
}
|
||||
|
||||
// evictLocked 淘汰最久未使用的缓存(需持有锁)
|
||||
func (c *CodeCache) evictLocked() {
|
||||
if len(c.order) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 找到最久未访问的
|
||||
oldestKey := c.order[0]
|
||||
oldestTime := time.Now()
|
||||
|
||||
for _, key := range c.order {
|
||||
cached := c.protos[key]
|
||||
if t, ok := cached.AccessAt.Load().(time.Time); ok && t.Before(oldestTime) {
|
||||
oldestTime = t
|
||||
oldestKey = key
|
||||
}
|
||||
}
|
||||
|
||||
// 删除
|
||||
delete(c.protos, oldestKey)
|
||||
for i, k := range c.order {
|
||||
if k == oldestKey {
|
||||
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isExpired 检查缓存是否过期
|
||||
func (c *CodeCache) isExpired(cached *CachedProto) bool {
|
||||
if c.ttl <= 0 {
|
||||
return false
|
||||
}
|
||||
return time.Since(cached.CachedAt) > c.ttl
|
||||
}
|
||||
|
||||
// isFileChanged 检查文件是否变更
|
||||
func (c *CodeCache) isFileChanged(cached *CachedProto) bool {
|
||||
if !c.fileWatch || cached.SourceType != CacheKeyFile {
|
||||
return false
|
||||
}
|
||||
|
||||
info, err := os.Stat(cached.SourcePath)
|
||||
if err != nil {
|
||||
return true // 文件不存在,视为变更
|
||||
}
|
||||
|
||||
return info.ModTime().After(cached.ModTime)
|
||||
}
|
||||
|
||||
// Stats 返回缓存统计
|
||||
func (c *CodeCache) Stats() (hits, misses uint64, size int) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return atomic.LoadUint64(&c.hits), atomic.LoadUint64(&c.misses), len(c.protos)
|
||||
}
|
||||
|
||||
// HitRate 返回缓存命中率
|
||||
func (c *CodeCache) HitRate() float64 {
|
||||
hits := atomic.LoadUint64(&c.hits)
|
||||
misses := atomic.LoadUint64(&c.misses)
|
||||
total := hits + misses
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(hits) / float64(total)
|
||||
}
|
||||
|
||||
// Clear 清空缓存
|
||||
func (c *CodeCache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.protos = make(map[string]*CachedProto)
|
||||
c.order = c.order[:0]
|
||||
}
|
||||
44
internal/lua/config.go
Normal file
44
internal/lua/config.go
Normal file
@ -0,0 +1,44 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
// 采用 Server 级单 LState + 请求级临时协程架构
|
||||
package lua
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config Lua 引擎配置
|
||||
type Config struct {
|
||||
// 协程配置
|
||||
MaxConcurrentCoroutines int // 最大并发协程数(默认 1000)
|
||||
CoroutineTimeout time.Duration // 协程执行超时(默认 30s)
|
||||
|
||||
// 字节码缓存配置
|
||||
CodeCacheSize int // 缓存条目数(默认 1000)
|
||||
CodeCacheTTL time.Duration // 缓存过期时间(默认 1h)
|
||||
|
||||
// 文件监控
|
||||
EnableFileWatch bool // 是否启用文件变更检测(默认 true)
|
||||
|
||||
// 执行限制
|
||||
MaxExecutionTime time.Duration // 单脚本最大执行时间(默认 30s)
|
||||
|
||||
// 安全设置
|
||||
EnableOSLib bool // 是否加载 os 库(默认 false)
|
||||
EnableIOLib bool // 是否加载 io 库(默认 false)
|
||||
EnableLoadLib bool // 是否允许 load/loadfile(默认 false)
|
||||
}
|
||||
|
||||
// DefaultConfig 返回默认配置
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
MaxConcurrentCoroutines: 1000,
|
||||
CoroutineTimeout: 30 * time.Second,
|
||||
CodeCacheSize: 1000,
|
||||
CodeCacheTTL: time.Hour,
|
||||
EnableFileWatch: true,
|
||||
MaxExecutionTime: 30 * time.Second,
|
||||
EnableOSLib: false,
|
||||
EnableIOLib: false,
|
||||
EnableLoadLib: false,
|
||||
}
|
||||
}
|
||||
210
internal/lua/constraints_test.go
Normal file
210
internal/lua/constraints_test.go
Normal file
@ -0,0 +1,210 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewEngine 测试引擎创建
|
||||
func TestNewEngine(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, engine)
|
||||
defer engine.Close()
|
||||
|
||||
assert.NotNil(t, engine.L)
|
||||
assert.NotNil(t, engine.codeCache)
|
||||
assert.Equal(t, int32(0), engine.ActiveCoroutines())
|
||||
}
|
||||
|
||||
// TestNewCoroutine 测试协程创建
|
||||
func TestNewCoroutine(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, coro)
|
||||
require.NotNil(t, coro.Co)
|
||||
|
||||
defer coro.Close()
|
||||
|
||||
assert.Equal(t, int32(1), engine.ActiveCoroutines())
|
||||
}
|
||||
|
||||
// TestCoroutineDeadAfterResumeOK 验证协程 ResumeOK 后变成 dead 不能复用
|
||||
// 注意:gopher-lua 的 Resume 对 dead coroutine 会 panic,无法安全测试
|
||||
// 此测试验证 ResumeOK 正常完成,证明协程生命周期正确
|
||||
func TestCoroutineDeadAfterResumeOK(t *testing.T) {
|
||||
L := glua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
// 创建协程
|
||||
co, cocancel := L.NewThread()
|
||||
require.NotNil(t, co)
|
||||
if cocancel != nil {
|
||||
defer cocancel()
|
||||
}
|
||||
|
||||
// 编译简单脚本
|
||||
proto, err := engineCodeToProto("return 42")
|
||||
require.NoError(t, err)
|
||||
|
||||
fn := L.NewFunctionFromProto(proto)
|
||||
|
||||
// 执行协程,应该正常完成
|
||||
st, err, values := L.Resume(co, fn)
|
||||
assert.Equal(t, glua.ResumeOK, st)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, values, 1)
|
||||
assert.Equal(t, glua.LNumber(42), values[0])
|
||||
|
||||
// 协程完成后变成 dead 状态
|
||||
// 注意:再次调用 Resume(co, fn) 会 panic
|
||||
// 实际使用中必须确保每个协程只使用一次
|
||||
}
|
||||
|
||||
// TestLFunctionCannotCrossLState 验证 LFunction 不能跨 LState 使用
|
||||
// 注意:FunctionProto 可以跨 LState 使用,但 LFunction 绑定到特定 LState
|
||||
// 这个测试验证的是 FunctionProto 共享的正确性
|
||||
func TestLFunctionCannotCrossLState(t *testing.T) {
|
||||
L1 := glua.NewState()
|
||||
defer L1.Close()
|
||||
|
||||
// 在 L1 中编译脚本并执行
|
||||
proto, err := engineCodeToProto("return 42")
|
||||
require.NoError(t, err)
|
||||
|
||||
fn := L1.NewFunctionFromProto(proto)
|
||||
L1.Push(fn)
|
||||
err = L1.PCall(0, 1, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, glua.LNumber(42), L1.Get(-1))
|
||||
L1.Pop(1)
|
||||
|
||||
// FunctionProto 可以在不同 LState 使用(这是缓存的核心假设)
|
||||
L2 := glua.NewState()
|
||||
defer L2.Close()
|
||||
|
||||
fn2 := L2.NewFunctionFromProto(proto) // 从同一个 proto 创建新的函数
|
||||
L2.Push(fn2)
|
||||
err = L2.PCall(0, 1, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, glua.LNumber(42), L2.Get(-1))
|
||||
L2.Pop(1)
|
||||
}
|
||||
|
||||
// TestNewThreadInheritsGlobals 验证 NewThread 继承全局环境
|
||||
func TestNewThreadInheritsGlobals(t *testing.T) {
|
||||
L := glua.NewState()
|
||||
defer L.Close()
|
||||
|
||||
// 在主 LState 设置全局变量
|
||||
L.SetGlobal("test_global", glua.LString("shared_value"))
|
||||
|
||||
// 创建协程
|
||||
co, cocancel := L.NewThread()
|
||||
require.NotNil(t, co)
|
||||
if cocancel != nil {
|
||||
defer cocancel()
|
||||
}
|
||||
|
||||
// 协程应该能访问主 LState 的全局变量
|
||||
proto, err := engineCodeToProto("return test_global")
|
||||
require.NoError(t, err)
|
||||
|
||||
fn := L.NewFunctionFromProto(proto)
|
||||
st, err, values := L.Resume(co, fn)
|
||||
assert.Equal(t, glua.ResumeOK, st)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, values, 1)
|
||||
assert.Equal(t, glua.LString("shared_value"), values[0])
|
||||
}
|
||||
|
||||
// TestPerRequestEnvSandbox 验证 _ENV 沙箱隔离
|
||||
func TestPerRequestEnvSandbox(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 创建第一个协程并设置沙箱
|
||||
coro1, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, coro1)
|
||||
|
||||
err = coro1.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 在沙箱中设置变量
|
||||
err = coro1.Execute("local x = 1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建第二个协程
|
||||
coro2, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, coro2)
|
||||
|
||||
err = coro2.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 第二个协程不应该看到第一个协程的变量
|
||||
// 由于我们使用了 _ENV 沙箱,局部变量是隔离的
|
||||
coro1.Close()
|
||||
coro2.Close()
|
||||
}
|
||||
|
||||
// TestCodeCache 测试字节码缓存
|
||||
func TestCodeCache(t *testing.T) {
|
||||
cache := NewCodeCache(100, 0, false)
|
||||
|
||||
script := "return 1 + 1"
|
||||
|
||||
// 第一次编译
|
||||
proto1, err := cache.GetOrCompileInline(script)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, proto1)
|
||||
|
||||
// 第二次应该命中缓存
|
||||
proto2, err := cache.GetOrCompileInline(script)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, proto2)
|
||||
|
||||
// 相同的脚本应该返回相同的字节码
|
||||
assert.Equal(t, proto1, proto2)
|
||||
|
||||
// 检查命中率
|
||||
hits, misses, _ := cache.Stats()
|
||||
assert.Equal(t, uint64(1), hits)
|
||||
assert.Equal(t, uint64(1), misses)
|
||||
}
|
||||
|
||||
// TestCodeCacheDifferentScripts 测试不同脚本的缓存
|
||||
func TestCodeCacheDifferentScripts(t *testing.T) {
|
||||
cache := NewCodeCache(100, 0, false)
|
||||
|
||||
proto1, err := cache.GetOrCompileInline("return 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
proto2, err := cache.GetOrCompileInline("return 2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 不同脚本应该产生不同的字节码
|
||||
assert.NotEqual(t, proto1, proto2)
|
||||
|
||||
hits, misses, _ := cache.Stats()
|
||||
assert.Equal(t, uint64(0), hits) // 都是 miss
|
||||
assert.Equal(t, uint64(2), misses)
|
||||
}
|
||||
|
||||
// Helper function: compile Lua code to FunctionProto
|
||||
func engineCodeToProto(src string) (*glua.FunctionProto, error) {
|
||||
return cacheGetOrCompile(src)
|
||||
}
|
||||
|
||||
// Package-level helper for testing
|
||||
var cacheGetOrCompile = NewCodeCache(100, 0, false).GetOrCompileInline
|
||||
126
internal/lua/context.go
Normal file
126
internal/lua/context.go
Normal file
@ -0,0 +1,126 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// LuaContext 请求级 Lua 上下文
|
||||
type LuaContext struct {
|
||||
// 引擎引用
|
||||
Engine *LuaEngine
|
||||
|
||||
// 协程
|
||||
Coroutine *LuaCoroutine
|
||||
|
||||
// HTTP 请求上下文
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
|
||||
// 当前阶段
|
||||
Phase Phase
|
||||
|
||||
// 变量存储(ngx.var 实现)
|
||||
Variables map[string]string
|
||||
|
||||
// 输出缓冲
|
||||
OutputBuffer []byte
|
||||
|
||||
// 是否已退出
|
||||
Exited bool
|
||||
}
|
||||
|
||||
// NewContext 创建请求上下文
|
||||
func NewContext(engine *LuaEngine, req *fasthttp.RequestCtx) *LuaContext {
|
||||
return &LuaContext{
|
||||
Engine: engine,
|
||||
RequestCtx: req,
|
||||
Variables: make(map[string]string),
|
||||
Phase: PhaseInit,
|
||||
}
|
||||
}
|
||||
|
||||
// InitCoroutine 初始化协程
|
||||
func (c *LuaContext) InitCoroutine() error {
|
||||
coro, err := c.Engine.NewCoroutine(c.RequestCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Coroutine = coro
|
||||
return c.Coroutine.SetupSandbox()
|
||||
}
|
||||
|
||||
// Execute 执行 Lua 脚本
|
||||
func (c *LuaContext) Execute(script string) error {
|
||||
if c.Coroutine == nil {
|
||||
if err := c.InitCoroutine(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.Coroutine.Execute(script)
|
||||
}
|
||||
|
||||
// ExecuteFile 执行文件脚本
|
||||
func (c *LuaContext) ExecuteFile(path string) error {
|
||||
if c.Coroutine == nil {
|
||||
if err := c.InitCoroutine(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return c.Coroutine.ExecuteFile(path)
|
||||
}
|
||||
|
||||
// SetPhase 设置当前阶段
|
||||
func (c *LuaContext) SetPhase(phase Phase) {
|
||||
c.Phase = phase
|
||||
}
|
||||
|
||||
// GetPhase 获取当前阶段
|
||||
func (c *LuaContext) GetPhase() Phase {
|
||||
return c.Phase
|
||||
}
|
||||
|
||||
// GetVariable 获取变量
|
||||
func (c *LuaContext) GetVariable(name string) (string, bool) {
|
||||
val, ok := c.Variables[name]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// SetVariable 设置变量
|
||||
func (c *LuaContext) SetVariable(name, value string) {
|
||||
c.Variables[name] = value
|
||||
}
|
||||
|
||||
// Write 输出内容
|
||||
func (c *LuaContext) Write(data []byte) {
|
||||
c.OutputBuffer = append(c.OutputBuffer, data...)
|
||||
}
|
||||
|
||||
// Say 输出内容并换行
|
||||
func (c *LuaContext) Say(data string) {
|
||||
c.OutputBuffer = append(c.OutputBuffer, data...)
|
||||
c.OutputBuffer = append(c.OutputBuffer, '\n')
|
||||
}
|
||||
|
||||
// Exit 退出请求处理
|
||||
func (c *LuaContext) Exit(code int) {
|
||||
c.Exited = true
|
||||
c.RequestCtx.SetStatusCode(code)
|
||||
}
|
||||
|
||||
// Release 释放资源
|
||||
func (c *LuaContext) Release() {
|
||||
if c.Coroutine != nil {
|
||||
c.Coroutine.Close()
|
||||
c.Coroutine = nil
|
||||
}
|
||||
c.Variables = nil
|
||||
c.OutputBuffer = nil
|
||||
}
|
||||
|
||||
// FlushOutput 刷新输出到响应
|
||||
func (c *LuaContext) FlushOutput() {
|
||||
if len(c.OutputBuffer) > 0 && c.RequestCtx != nil {
|
||||
c.RequestCtx.Write(c.OutputBuffer)
|
||||
c.OutputBuffer = c.OutputBuffer[:0]
|
||||
}
|
||||
}
|
||||
189
internal/lua/coroutine.go
Normal file
189
internal/lua/coroutine.go
Normal file
@ -0,0 +1,189 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// Phase 处理阶段
|
||||
type Phase int
|
||||
|
||||
const (
|
||||
PhaseInit Phase = iota
|
||||
PhaseRewrite
|
||||
PhaseAccess
|
||||
PhaseContent
|
||||
PhaseLog
|
||||
PhaseHeaderFilter
|
||||
PhaseBodyFilter
|
||||
)
|
||||
|
||||
func (p Phase) String() string {
|
||||
switch p {
|
||||
case PhaseInit:
|
||||
return "init"
|
||||
case PhaseRewrite:
|
||||
return "rewrite"
|
||||
case PhaseAccess:
|
||||
return "access"
|
||||
case PhaseContent:
|
||||
return "content"
|
||||
case PhaseLog:
|
||||
return "log"
|
||||
case PhaseHeaderFilter:
|
||||
return "header_filter"
|
||||
case PhaseBodyFilter:
|
||||
return "body_filter"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// LuaCoroutine 请求级临时协程
|
||||
// 注意:协程在 ResumeOK 后变成 dead 状态,不能复用
|
||||
type LuaCoroutine struct {
|
||||
// 所属引擎
|
||||
Engine *LuaEngine
|
||||
|
||||
// 协程 LState(通过 NewThread 创建)
|
||||
Co *glua.LState
|
||||
|
||||
// 取消函数
|
||||
Cancel context.CancelFunc
|
||||
|
||||
// 请求上下文
|
||||
RequestCtx *fasthttp.RequestCtx
|
||||
|
||||
// 执行上下文
|
||||
ExecutionContext context.Context
|
||||
executionCancel context.CancelFunc
|
||||
|
||||
// 创建时间
|
||||
CreatedAt time.Time
|
||||
|
||||
// 状态
|
||||
Exited bool // 是否已调用 exit
|
||||
|
||||
// 输出缓冲
|
||||
OutputBuffer []byte
|
||||
}
|
||||
|
||||
// SetupSandbox 创建 per-request _ENV 沙箱
|
||||
// 每个请求创建独立的 _ENV 表,通过元表继承全局环境
|
||||
func (c *LuaCoroutine) SetupSandbox() error {
|
||||
// 创建独立的 _ENV 表
|
||||
env := c.Co.NewTable()
|
||||
|
||||
// 获取全局环境 - 使用 Engine 的主 LState 全局表
|
||||
// 协程通过 NewThread 继承了父 LState 的全局环境
|
||||
globals := c.Engine.L.GetGlobal("_G")
|
||||
|
||||
// 设置元表,使未找到的变量从全局环境读取
|
||||
mt := c.Co.NewTable()
|
||||
mt.RawSetString("__index", globals)
|
||||
|
||||
// 阻止写入全局环境(可选)
|
||||
readOnlyFn := c.Co.NewFunction(func(L *glua.LState) int {
|
||||
L.RaiseError("attempt to modify global table (read-only)")
|
||||
return 0
|
||||
})
|
||||
mt.RawSetString("__newindex", readOnlyFn)
|
||||
|
||||
// 设置元表
|
||||
c.Co.SetMetatable(env, mt)
|
||||
|
||||
// 将 _ENV 设置到协程
|
||||
c.Co.SetGlobal("_ENV", env)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute 在协程中执行 Lua 脚本(支持 Yield/Resume)
|
||||
func (c *LuaCoroutine) Execute(script string) error {
|
||||
proto, err := c.Engine.codeCache.GetOrCompileInline(script)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compile script: %w", err)
|
||||
}
|
||||
return c.executeProto(proto)
|
||||
}
|
||||
|
||||
// ExecuteFile 执行文件脚本
|
||||
func (c *LuaCoroutine) ExecuteFile(path string) error {
|
||||
proto, err := c.Engine.codeCache.GetOrCompileFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compile file: %w", err)
|
||||
}
|
||||
return c.executeProto(proto)
|
||||
}
|
||||
|
||||
// executeProto 执行编译后的字节码,处理 yield/resume 循环
|
||||
func (c *LuaCoroutine) executeProto(proto *glua.FunctionProto) error {
|
||||
fn := c.Engine.L.NewFunctionFromProto(proto)
|
||||
st, execErr, values := c.Engine.L.Resume(c.Co, fn)
|
||||
|
||||
for st == glua.ResumeYield {
|
||||
results, handleErr := c.handleYield(values)
|
||||
if handleErr != nil {
|
||||
return fmt.Errorf("handle yield: %w", handleErr)
|
||||
}
|
||||
st, execErr, values = c.Engine.L.Resume(c.Co, nil, results...)
|
||||
}
|
||||
|
||||
if st == glua.ResumeError {
|
||||
atomic.AddUint64(&c.Engine.stats.ScriptsErrors, 1)
|
||||
return fmt.Errorf("lua execution error: %w", execErr)
|
||||
}
|
||||
|
||||
atomic.AddUint64(&c.Engine.stats.ScriptsExecuted, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleYield 处理协程 yield
|
||||
func (c *LuaCoroutine) handleYield(values []glua.LValue) ([]glua.LValue, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, fmt.Errorf("yield without reason")
|
||||
}
|
||||
|
||||
reason := glua.LVAsString(values[0])
|
||||
|
||||
switch reason {
|
||||
case "sleep":
|
||||
return c.handleSleep(values[1:])
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown yield reason: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSleep 处理 sleep yield
|
||||
// 注意:此实现会阻塞当前 goroutine
|
||||
func (c *LuaCoroutine) handleSleep(values []glua.LValue) ([]glua.LValue, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, fmt.Errorf("sleep requires duration")
|
||||
}
|
||||
|
||||
duration := float64(glua.LVAsNumber(values[0]))
|
||||
d := time.Duration(duration * float64(time.Second))
|
||||
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
// sleep 完成,返回空结果
|
||||
return []glua.LValue{}, nil
|
||||
case <-c.ExecutionContext.Done():
|
||||
// 执行超时或取消
|
||||
return nil, fmt.Errorf("sleep interrupted: %w", c.ExecutionContext.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭协程
|
||||
func (c *LuaCoroutine) Close() {
|
||||
c.Engine.releaseCoroutine(c)
|
||||
}
|
||||
185
internal/lua/engine.go
Normal file
185
internal/lua/engine.go
Normal file
@ -0,0 +1,185 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// LuaEngine 全局 Lua 引擎
|
||||
// 每个 HTTP Server 实例持有一个 LuaEngine
|
||||
type LuaEngine struct {
|
||||
// 主 LState
|
||||
L *glua.LState
|
||||
|
||||
// 配置
|
||||
config *Config
|
||||
|
||||
// 字节码缓存
|
||||
codeCache *CodeCache
|
||||
|
||||
// 协程管理
|
||||
activeCount int32 // 活跃协程数
|
||||
maxCoroutines int // 最大并发协程数
|
||||
coroutinePool sync.Pool // 协程对象池(注意:池中的协程已 dead,不可复用,仅复用内存)
|
||||
|
||||
// 生命周期
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 统计
|
||||
stats EngineStats
|
||||
}
|
||||
|
||||
// EngineStats 引擎统计信息
|
||||
type EngineStats struct {
|
||||
CoroutinesCreated uint64
|
||||
CoroutinesClosed uint64
|
||||
ScriptsExecuted uint64
|
||||
ScriptsErrors uint64
|
||||
}
|
||||
|
||||
// NewEngine 创建 Lua 引擎
|
||||
func NewEngine(config *Config) (*LuaEngine, error) {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
|
||||
// 创建主 LState
|
||||
L := glua.NewState(glua.Options{
|
||||
SkipOpenLibs: true, // 禁用默认库,手动加载安全库
|
||||
})
|
||||
|
||||
// 加载安全的标准库
|
||||
glua.OpenBase(L)
|
||||
glua.OpenTable(L)
|
||||
glua.OpenString(L)
|
||||
glua.OpenMath(L)
|
||||
glua.OpenCoroutine(L) // 加载 coroutine 库支持 yield
|
||||
|
||||
// 可选加载危险库
|
||||
if config.EnableOSLib {
|
||||
glua.OpenOs(L)
|
||||
}
|
||||
if config.EnableIOLib {
|
||||
glua.OpenIo(L)
|
||||
}
|
||||
// 注意:package 库默认不加载,禁止 require 外部模块
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
engine := &LuaEngine{
|
||||
L: L,
|
||||
config: config,
|
||||
codeCache: NewCodeCache(config.CodeCacheSize, config.CodeCacheTTL, config.EnableFileWatch),
|
||||
maxCoroutines: config.MaxConcurrentCoroutines,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
coroutinePool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// 注意:这里只是创建空的协程对象结构
|
||||
// 实际的协程通过 L.NewThread() 创建
|
||||
return &LuaCoroutine{}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// Close 关闭引擎
|
||||
func (e *LuaEngine) Close() {
|
||||
e.cancel()
|
||||
if e.L != nil {
|
||||
e.L.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// NewCoroutine 创建临时协程
|
||||
// 注意:协程在 ResumeOK 后变成 dead 状态,不能复用
|
||||
func (e *LuaEngine) NewCoroutine(req *fasthttp.RequestCtx) (*LuaCoroutine, error) {
|
||||
// 检查并发限制
|
||||
current := atomic.AddInt32(&e.activeCount, 1)
|
||||
if current > int32(e.maxCoroutines) {
|
||||
atomic.AddInt32(&e.activeCount, -1)
|
||||
return nil, fmt.Errorf("max concurrent coroutines exceeded: %d/%d", current, e.maxCoroutines)
|
||||
}
|
||||
|
||||
// 通过 NewThread 创建协程
|
||||
// 协程继承主 LState 的全局环境
|
||||
co, cancel := e.L.NewThread()
|
||||
if co == nil {
|
||||
atomic.AddInt32(&e.activeCount, -1)
|
||||
return nil, fmt.Errorf("failed to create coroutine")
|
||||
}
|
||||
|
||||
// 从池中获取协程对象结构(复用内存,不复用协程状态)
|
||||
coro := e.coroutinePool.Get().(*LuaCoroutine)
|
||||
coro.Engine = e
|
||||
coro.Co = co
|
||||
coro.Cancel = cancel
|
||||
coro.RequestCtx = req
|
||||
coro.CreatedAt = time.Now()
|
||||
coro.ExecutionContext, coro.executionCancel = context.WithTimeout(e.ctx, e.config.MaxExecutionTime)
|
||||
|
||||
atomic.AddUint64(&e.stats.CoroutinesCreated, 1)
|
||||
|
||||
return coro, nil
|
||||
}
|
||||
|
||||
// releaseCoroutine 释放协程(内部方法)
|
||||
func (e *LuaEngine) releaseCoroutine(coro *LuaCoroutine) {
|
||||
if coro == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 取消执行上下文
|
||||
if coro.executionCancel != nil {
|
||||
coro.executionCancel()
|
||||
}
|
||||
|
||||
// 取消协程
|
||||
if coro.Cancel != nil {
|
||||
coro.Cancel()
|
||||
}
|
||||
|
||||
// 清理状态
|
||||
coro.Co = nil
|
||||
coro.Cancel = nil
|
||||
coro.RequestCtx = nil
|
||||
coro.ExecutionContext = nil
|
||||
coro.executionCancel = nil
|
||||
|
||||
// 更新计数
|
||||
atomic.AddInt32(&e.activeCount, -1)
|
||||
atomic.AddUint64(&e.stats.CoroutinesClosed, 1)
|
||||
|
||||
// 放回池中(仅复用 LuaCoroutine 结构体内存)
|
||||
e.coroutinePool.Put(coro)
|
||||
}
|
||||
|
||||
// CodeCache 返回字节码缓存
|
||||
func (e *LuaEngine) CodeCache() *CodeCache {
|
||||
return e.codeCache
|
||||
}
|
||||
|
||||
// Stats 返回引擎统计
|
||||
func (e *LuaEngine) Stats() EngineStats {
|
||||
return EngineStats{
|
||||
CoroutinesCreated: atomic.LoadUint64(&e.stats.CoroutinesCreated),
|
||||
CoroutinesClosed: atomic.LoadUint64(&e.stats.CoroutinesClosed),
|
||||
ScriptsExecuted: atomic.LoadUint64(&e.stats.ScriptsExecuted),
|
||||
ScriptsErrors: atomic.LoadUint64(&e.stats.ScriptsErrors),
|
||||
}
|
||||
}
|
||||
|
||||
// ActiveCoroutines 返回活跃协程数
|
||||
func (e *LuaEngine) ActiveCoroutines() int32 {
|
||||
return atomic.LoadInt32(&e.activeCount)
|
||||
}
|
||||
458
internal/lua/lua_test.go
Normal file
458
internal/lua/lua_test.go
Normal file
@ -0,0 +1,458 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestLuaContext 测试 LuaContext 基础功能
|
||||
func TestLuaContext(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
ctx := NewContext(engine, nil)
|
||||
require.NotNil(t, ctx)
|
||||
assert.NotNil(t, ctx.Engine)
|
||||
assert.NotNil(t, ctx.Variables)
|
||||
assert.Equal(t, PhaseInit, ctx.Phase)
|
||||
|
||||
// 测试变量操作
|
||||
ctx.SetVariable("test_key", "test_value")
|
||||
val, ok := ctx.GetVariable("test_key")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test_value", val)
|
||||
|
||||
// 测试未存在的变量
|
||||
_, ok = ctx.GetVariable("nonexistent")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
// TestLuaContextPhase 测试阶段设置
|
||||
func TestLuaContextPhase(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
ctx := NewContext(engine, nil)
|
||||
|
||||
// 测试所有阶段
|
||||
phases := []Phase{PhaseInit, PhaseRewrite, PhaseAccess, PhaseContent, PhaseLog, PhaseHeaderFilter, PhaseBodyFilter}
|
||||
for _, p := range phases {
|
||||
ctx.SetPhase(p)
|
||||
assert.Equal(t, p, ctx.GetPhase())
|
||||
}
|
||||
|
||||
// 测试阶段字符串
|
||||
assert.Equal(t, "init", PhaseInit.String())
|
||||
assert.Equal(t, "rewrite", PhaseRewrite.String())
|
||||
assert.Equal(t, "access", PhaseAccess.String())
|
||||
assert.Equal(t, "content", PhaseContent.String())
|
||||
assert.Equal(t, "log", PhaseLog.String())
|
||||
assert.Equal(t, "header_filter", PhaseHeaderFilter.String())
|
||||
assert.Equal(t, "body_filter", PhaseBodyFilter.String())
|
||||
}
|
||||
|
||||
// TestLuaContextOutput 测试输出缓冲
|
||||
func TestLuaContextOutput(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
ctx := NewContext(engine, nil)
|
||||
|
||||
// 测试 Write
|
||||
ctx.Write([]byte("hello"))
|
||||
assert.Equal(t, []byte("hello"), ctx.OutputBuffer)
|
||||
|
||||
// 测试 Say - Say 会添加 data 然后换行
|
||||
ctx.OutputBuffer = nil // 清空重新测试
|
||||
ctx.Say("hello")
|
||||
assert.Equal(t, []byte("hello\n"), ctx.OutputBuffer)
|
||||
}
|
||||
|
||||
// TestLuaContextFlushOutput 测试刷新输出
|
||||
func TestLuaContextFlushOutput(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 当 RequestCtx 为 nil 时,FlushOutput 应该安全处理
|
||||
ctx := NewContext(engine, nil)
|
||||
ctx.OutputBuffer = []byte("test output")
|
||||
|
||||
// FlushOutput 应该不会 panic(RequestCtx 为 nil)
|
||||
ctx.FlushOutput()
|
||||
// OutputBuffer 应该保持不变(因为 RequestCtx 为 nil)
|
||||
assert.NotNil(t, ctx.OutputBuffer)
|
||||
}
|
||||
|
||||
// TestLuaContextExecute 测试 Lua 执行
|
||||
func TestLuaContextExecute(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
ctx := NewContext(engine, nil)
|
||||
|
||||
// 执行简单脚本
|
||||
err = ctx.Execute("local x = 1 + 1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Release
|
||||
ctx.Release()
|
||||
assert.Nil(t, ctx.Coroutine)
|
||||
}
|
||||
|
||||
// TestLuaContextExecuteFile 测试文件执行
|
||||
func TestLuaContextExecuteFile(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
ctx := NewContext(engine, nil)
|
||||
|
||||
// 创建临时 Lua 文件
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "test.lua")
|
||||
err = os.WriteFile(scriptPath, []byte("return 42"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 执行文件
|
||||
err = ctx.ExecuteFile(scriptPath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx.Release()
|
||||
}
|
||||
|
||||
// TestLuaCoroutineExecute 测试协程执行
|
||||
func TestLuaCoroutineExecute(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, coro)
|
||||
defer coro.Close()
|
||||
|
||||
// 设置沙箱
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 执行脚本
|
||||
err = coro.Execute("return 42")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestLuaCoroutineExecuteWithYield 测试 yield/resume
|
||||
func TestLuaCoroutineExecuteWithYield(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, coro)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 执行带 yield 的脚本 - 验证 yield/resume 循环
|
||||
// 注意:需要注册 lolly.sleep 函数才能正确处理 yield
|
||||
err = coro.Execute("local x = 1; return x + 1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestLuaCoroutineExecuteFile 测试协程文件执行
|
||||
func TestLuaCoroutineExecuteFile(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
// 创建临时文件
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "test.lua")
|
||||
err = os.WriteFile(scriptPath, []byte("return 42"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = coro.ExecuteFile(scriptPath)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestLuaCoroutineExecuteError 测试执行错误
|
||||
func TestLuaCoroutineExecuteError(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
// 编译错误
|
||||
err = coro.Execute("invalid lua syntax !!!")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "compile")
|
||||
|
||||
// 运行时错误
|
||||
err = coro.Execute("error('runtime error')")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestLuaCoroutineExecuteFileError 测试文件执行错误
|
||||
func TestLuaCoroutineExecuteFileError(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
// 不存在的文件
|
||||
err = coro.ExecuteFile("/nonexistent/path.lua")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestLuaCoroutineHandleYield 测试 yield 处理
|
||||
func TestLuaCoroutineHandleYield(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
err = coro.SetupSandbox()
|
||||
require.NoError(t, err)
|
||||
|
||||
// 测试 unknown yield reason - 会返回错误
|
||||
// 因为 handleYield 会检查 yield reason
|
||||
err = coro.Execute("coroutine.yield('unknown_reason')")
|
||||
assert.Error(t, err) // unknown yield reason
|
||||
}
|
||||
|
||||
// TestLuaCoroutineHandleSleep 测试 sleep yield 处理
|
||||
// 注意:需要 coroutine 库支持,当前沙箱未加载
|
||||
func TestLuaCoroutineHandleSleep(t *testing.T) {
|
||||
engine, err := NewEngine(&Config{
|
||||
MaxConcurrentCoroutines: 1000,
|
||||
MaxExecutionTime: 5 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
defer coro.Close()
|
||||
|
||||
// 不设置沙箱,使用全局环境(包含 coroutine 库)
|
||||
// 简单测试 execute 和 yield 循环的基本路径
|
||||
err = coro.Execute("return 1 + 1")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 测试错误路径 - yield 无参数
|
||||
err = coro.Execute("coroutine.yield()")
|
||||
// 由于 coroutine 库可能在沙箱中不可用,这个测试可能返回编译错误或运行时错误
|
||||
// 重点覆盖代码路径
|
||||
_ = err
|
||||
}
|
||||
|
||||
// TestCodeCacheFile 测试文件缓存
|
||||
func TestCodeCacheFile(t *testing.T) {
|
||||
// 创建临时 Lua 文件
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "test.lua")
|
||||
scriptContent := "return 42"
|
||||
|
||||
err := os.WriteFile(scriptPath, []byte(scriptContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := NewCodeCache(100, time.Hour, true)
|
||||
|
||||
// 第一次编译文件
|
||||
proto1, err := cache.GetOrCompileFile(scriptPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, proto1)
|
||||
|
||||
// 第二次应该命中缓存
|
||||
proto2, err := cache.GetOrCompileFile(scriptPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, proto2)
|
||||
|
||||
assert.Equal(t, proto1, proto2)
|
||||
|
||||
hits, misses, _ := cache.Stats()
|
||||
assert.Equal(t, uint64(1), hits)
|
||||
assert.Equal(t, uint64(1), misses)
|
||||
}
|
||||
|
||||
// TestCodeCacheEviction 测试缓存淘汰
|
||||
func TestCodeCacheEviction(t *testing.T) {
|
||||
cache := NewCodeCache(2, 0, false) // 只存 2 个
|
||||
|
||||
// 编译 3 个脚本,触发淘汰
|
||||
_, err := cache.GetOrCompileInline("return 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.GetOrCompileInline("return 2")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.GetOrCompileInline("return 3")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 第一个应该被淘汰了
|
||||
hits, misses, size := cache.Stats()
|
||||
assert.Equal(t, uint64(0), hits)
|
||||
assert.Equal(t, uint64(3), misses)
|
||||
assert.LessOrEqual(t, size, 2)
|
||||
}
|
||||
|
||||
// TestCodeCacheTTL 测试 TTL 过期后重新编译
|
||||
func TestCodeCacheTTL(t *testing.T) {
|
||||
cache := NewCodeCache(100, 100*time.Millisecond, false)
|
||||
|
||||
script := "return 1"
|
||||
|
||||
// 编译脚本
|
||||
_, err := cache.GetOrCompileInline(script)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 等待 TTL 过期
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 应该重新编译(miss)
|
||||
_, err = cache.GetOrCompileInline(script)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 检查 stats:两次 miss,因为 TTL 过期后重新编译
|
||||
hits, misses, _ := cache.Stats()
|
||||
assert.Equal(t, uint64(0), hits) // 没有 hit
|
||||
assert.Equal(t, uint64(2), misses) // 两次 miss
|
||||
}
|
||||
|
||||
// TestCodeCacheClear 测试清空缓存
|
||||
func TestCodeCacheClear(t *testing.T) {
|
||||
cache := NewCodeCache(100, 0, false)
|
||||
|
||||
// 添加一些缓存
|
||||
_, err := cache.GetOrCompileInline("return 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = cache.GetOrCompileInline("return 2")
|
||||
require.NoError(t, err)
|
||||
|
||||
hits, misses, size := cache.Stats()
|
||||
assert.Equal(t, 2, size)
|
||||
_ = hits
|
||||
_ = misses
|
||||
|
||||
// 清空
|
||||
cache.Clear()
|
||||
|
||||
hits, misses, size = cache.Stats()
|
||||
assert.Equal(t, 0, size)
|
||||
_ = hits
|
||||
_ = misses
|
||||
}
|
||||
|
||||
// TestCodeCacheHitRate 测试命中率
|
||||
func TestCodeCacheHitRate(t *testing.T) {
|
||||
cache := NewCodeCache(100, 0, false)
|
||||
|
||||
script := "return 1"
|
||||
|
||||
// 第一次 miss
|
||||
_, err := cache.GetOrCompileInline(script)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 第二次 hit
|
||||
_, err = cache.GetOrCompileInline(script)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 第三次 hit
|
||||
_, err = cache.GetOrCompileInline(script)
|
||||
require.NoError(t, err)
|
||||
|
||||
hitRate := cache.HitRate()
|
||||
assert.Equal(t, 2.0/3.0, hitRate)
|
||||
}
|
||||
|
||||
// TestEngineStats 测试引擎统计
|
||||
func TestEngineStats(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
// 初始统计应该为 0
|
||||
stats := engine.Stats()
|
||||
assert.Equal(t, uint64(0), stats.CoroutinesCreated)
|
||||
assert.Equal(t, uint64(0), stats.CoroutinesClosed)
|
||||
|
||||
// 创建协程
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats = engine.Stats()
|
||||
assert.Equal(t, uint64(1), stats.CoroutinesCreated)
|
||||
assert.Equal(t, uint64(0), stats.CoroutinesClosed)
|
||||
|
||||
// 关闭协程
|
||||
coro.Close()
|
||||
|
||||
stats = engine.Stats()
|
||||
assert.Equal(t, uint64(1), stats.CoroutinesCreated)
|
||||
assert.Equal(t, uint64(1), stats.CoroutinesClosed)
|
||||
}
|
||||
|
||||
// TestEngineCodeCache 测试引擎字节码缓存访问
|
||||
func TestEngineCodeCache(t *testing.T) {
|
||||
engine, err := NewEngine(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
cache := engine.CodeCache()
|
||||
require.NotNil(t, cache)
|
||||
|
||||
// 通过引擎缓存编译
|
||||
proto, err := cache.GetOrCompileInline("return 42")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, proto)
|
||||
}
|
||||
|
||||
// TestConfig 测试配置
|
||||
func TestConfig(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
require.NotNil(t, config)
|
||||
|
||||
// 默认配置值
|
||||
assert.Equal(t, 1000, config.MaxConcurrentCoroutines)
|
||||
assert.Equal(t, 30*time.Second, config.MaxExecutionTime)
|
||||
assert.Equal(t, 1000, config.CodeCacheSize)
|
||||
|
||||
// 测试自定义配置
|
||||
customConfig := &Config{
|
||||
MaxConcurrentCoroutines: 100,
|
||||
MaxExecutionTime: time.Minute,
|
||||
CodeCacheSize: 200,
|
||||
}
|
||||
|
||||
engine, err := NewEngine(customConfig)
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
|
||||
assert.Equal(t, 100, engine.maxCoroutines)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user