feat(lua): 实现 Cosocket API 和响应拦截器

- ngx.req API 双层边界验证原型
- TCP Cosocket API (connect/send/receive/close)
- Cosocket 状态管理器和连接池
- ResponseInterceptor 响应拦截器
- 完整单元测试覆盖

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-10 18:29:42 +08:00
parent e7ee9e717d
commit 8bac2fdcfa
6 changed files with 4332 additions and 0 deletions

330
internal/lua/api_req.go Normal file
View File

@ -0,0 +1,330 @@
// Package lua 提供 ngx.req API 实现
// 本文件实现双层 API 边界验证原型,用于测量直接映射层 vs 兼容层的性能差异
package lua
import (
"strconv"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
glua "github.com/yuin/gopher-lua"
)
// ngxReqAPILayer 定义 API 层级类型
type ngxReqAPILayer int
const (
// APILayerDirect 直接映射层fasthttp -> Lua无中间层
// 性能最优,延迟最低
APILayerDirect ngxReqAPILayer = 1
// APILayerCompatible 兼容层:需要模拟 nginx 语义
// 增加了少量转换开销
APILayerCompatible ngxReqAPILayer = 2
// APILayerPseudoNonBlocking 伪非阻塞层yield/resume
// 支持在 Lua 中调用异步操作
APILayerPseudoNonBlocking ngxReqAPILayer = 3
)
func (l ngxReqAPILayer) String() string {
switch l {
case APILayerDirect:
return "direct"
case APILayerCompatible:
return "compatible"
case APILayerPseudoNonBlocking:
return "pseudo_non_blocking"
default:
return "unknown"
}
}
// ngxReqMetrics 收集 API 调用指标
type ngxReqMetrics struct {
// 调用计数
DirectCallCount uint64
CompatibleCallCount uint64
PseudoBlockingCallCount uint64
// 累积延迟(纳秒)
DirectTotalNs uint64
CompatibleTotalNs uint64
PseudoBlockingTotalNs uint64
// 最大值(用于识别异常)
DirectMaxNs uint64
CompatibleMaxNs uint64
PseudoBlockingMaxNs uint64
}
// ngxReqAPI ngx.req API 实现
type ngxReqAPI struct {
// 请求上下文
ctx *fasthttp.RequestCtx
// 指标收集
metrics ngxReqMetrics
// 缓存URI args 解析结果(兼容层使用)
uriArgsCache map[string][]string
uriArgsCacheOnce sync.Once
}
// newNgxReqAPI 创建 ngx.req API 实例
func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI {
return &ngxReqAPI{
ctx: ctx,
uriArgsCache: nil, // 延迟初始化
}
}
// RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API
// 这是主入口函数,由 LuaEngine 在初始化时调用
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI) {
// 创建 ngx 表
ngx := L.NewTable()
// 创建 ngx.req 子表
ngxReq := L.NewTable()
// 直接映射层 APIget_method
// 特点:直接访问 fasthttp.RequestCtx零拷贝最小开销
ngxReq.RawSetString("get_method", L.NewFunction(api.luaGetMethod))
// 直接映射层 APIget_uri
// 特点:直接返回请求的 URI 路径(不含 query string
ngxReq.RawSetString("get_uri", L.NewFunction(api.luaGetURI))
// 兼容层 APIget_uri_args
// 特点:需要解析 query string 为 nginx 兼容的表结构
// 增加了解析开销,但保持 API 兼容性
ngxReq.RawSetString("get_uri_args", L.NewFunction(api.luaGetURIArgs))
// 伪非阻塞层 APIread_body
// 特点:使用 yield/resume 模式支持异步读取
// 这是实验性 API展示非阻塞调用模式
ngxReq.RawSetString("read_body", L.NewFunction(api.luaReadBodyAsync))
// 将 ngx.req 添加到 ngx
ngx.RawSetString("req", ngxReq)
// 注册 ngx 全局变量
L.SetGlobal("ngx", ngx)
}
// ==================== 直接映射层 API ====================
// luaGetMethod 实现 ngx.req.get_method() - 直接映射层
// Lua 调用: local method = ngx.req.get_method()
// 返回: string (如 "GET", "POST", "PUT" 等)
func (api *ngxReqAPI) luaGetMethod(L *glua.LState) int {
start := time.Now()
// 直接访问 fasthttp零拷贝最小开销
method := string(api.ctx.Method())
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.DirectCallCount++
api.metrics.DirectTotalNs += elapsed
if elapsed > api.metrics.DirectMaxNs {
api.metrics.DirectMaxNs = elapsed
}
L.Push(glua.LString(method))
return 1
}
// luaGetURI 实现 ngx.req.get_uri() - 直接映射层
// Lua 调用: local uri = ngx.req.get_uri()
// 返回: string (如 "/path/to/resource")
func (api *ngxReqAPI) luaGetURI(L *glua.LState) int {
start := time.Now()
// 直接访问 fasthttp URI 路径
uri := string(api.ctx.Request.URI().Path())
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.DirectCallCount++
api.metrics.DirectTotalNs += elapsed
if elapsed > api.metrics.DirectMaxNs {
api.metrics.DirectMaxNs = elapsed
}
L.Push(glua.LString(uri))
return 1
}
// ==================== 兼容层 API ====================
// luaGetURIArgs 实现 ngx.req.get_uri_args() - 兼容层
// Lua 调用: local args = ngx.req.get_uri_args()
// 返回: table (如 { name = "value", arr = { "v1", "v2" } })
// 注意:兼容层需要解析 query string模拟 nginx 的参数表结构
func (api *ngxReqAPI) luaGetURIArgs(L *glua.LState) int {
start := time.Now()
// 延迟初始化缓存
api.uriArgsCacheOnce.Do(func() {
api.uriArgsCache = api.parseURIArgs()
})
// 构建 Lua 表(兼容 nginx 的 ngx.req.get_uri_args 格式)
result := L.NewTable()
for key, values := range api.uriArgsCache {
if len(values) == 1 {
// 单值:直接存储为字符串
result.RawSetString(key, glua.LString(values[0]))
} else {
// 多值存储为数组table
arr := L.NewTable()
for i, v := range values {
arr.RawSetInt(i+1, glua.LString(v)) // Lua 数组从 1 开始
}
result.RawSetString(key, arr)
}
}
// 记录指标
elapsed := uint64(time.Since(start).Nanoseconds())
api.metrics.CompatibleCallCount++
api.metrics.CompatibleTotalNs += elapsed
if elapsed > api.metrics.CompatibleMaxNs {
api.metrics.CompatibleMaxNs = elapsed
}
L.Push(result)
return 1
}
// parseURIArgs 解析 URI query string 为 map
// 这是兼容层的核心转换逻辑,模拟 nginx 的参数解析
func (api *ngxReqAPI) parseURIArgs() map[string][]string {
args := make(map[string][]string)
// 获取 query string
query := api.ctx.QueryArgs()
// 遍历所有参数 - 使用 All() 替代已弃用的 VisitAll()
for key, value := range query.All() {
keyStr := string(key)
valueStr := string(value)
if existing, ok := args[keyStr]; ok {
args[keyStr] = append(existing, valueStr)
} else {
args[keyStr] = []string{valueStr}
}
}
return args
}
// ==================== 伪非阻塞层 API实验性 ====================
// luaReadBodyAsync 实现 ngx.req.read_body() - 伪非阻塞层
// Lua 调用: ngx.req.read_body() -- 会 yield完成后 resume
// 这是实验性 API展示如何使用 yield/resume 实现非阻塞调用
func (api *ngxReqAPI) luaReadBodyAsync(L *glua.LState) int {
// 伪非阻塞层:使用 yield 暂停协程,由引擎异步处理后 resume
// 这种模式允许在 Lua 中编写看似同步的代码,实际是异步执行
// 记录开始时间
start := time.Now()
// Yield 协程 - 控制权交回 Go 层
// TODO: 实现真正的非阻塞 yield目前使用同步模拟
L.Push(glua.LString("read_body"))
L.Push(glua.LString(strconv.FormatInt(start.UnixNano(), 10)))
// Note: 在 gopher-lua v1.1.2 中L.Yield 需要 LValue 参数,返回 int
// 这里返回 2 表示有 2 个返回值已在栈上
return 2 // 使用 return 代替 L.Yield
}
// ==================== 辅助函数 ====================
// GetMetrics 返回 API 调用指标
// 用于基准测试和性能监控
func (api *ngxReqAPI) GetMetrics() ngxReqMetrics {
return api.metrics
}
// GetDirectLayerAvgNs 返回直接映射层平均延迟(纳秒)
func (api *ngxReqAPI) GetDirectLayerAvgNs() float64 {
if api.metrics.DirectCallCount == 0 {
return 0
}
return float64(api.metrics.DirectTotalNs) / float64(api.metrics.DirectCallCount)
}
// GetCompatibleLayerAvgNs 返回兼容层平均延迟(纳秒)
func (api *ngxReqAPI) GetCompatibleLayerAvgNs() float64 {
if api.metrics.CompatibleCallCount == 0 {
return 0
}
return float64(api.metrics.CompatibleTotalNs) / float64(api.metrics.CompatibleCallCount)
}
// GetPerformanceRatio 返回兼容层/直接映射层的性能比率
// ratio > 1.2 表示兼容层比直接映射层慢 20% 以上
func (api *ngxReqAPI) GetPerformanceRatio() float64 {
directAvg := api.GetDirectLayerAvgNs()
compatibleAvg := api.GetCompatibleLayerAvgNs()
if directAvg == 0 {
return 0
}
return compatibleAvg / directAvg
}
// ResetMetrics 重置指标(用于基准测试)
func (api *ngxReqAPI) ResetMetrics() {
api.metrics = ngxReqMetrics{}
}
// ==================== 辅助方法 ====================
// getRequestHeader 获取请求头(辅助函数,供 Lua 绑定使用)
func (api *ngxReqAPI) getRequestHeader(name string) string {
return string(api.ctx.Request.Header.Peek(name))
}
// setResponseHeader 设置响应头(辅助函数,供 Lua 绑定使用)
func (api *ngxReqAPI) setResponseHeader(name, value string) {
api.ctx.Response.Header.Set(name, value)
}
// parseQueryString 手动解析 query string用于对比测试
// 这是纯 Go 实现,不依赖 fasthttp 的解析器
func parseQueryString(query []byte) map[string][]string {
result := make(map[string][]string)
if len(query) == 0 {
return result
}
pairs := strings.Split(string(query), "&")
for _, pair := range pairs {
if len(pair) == 0 {
continue
}
parts := strings.SplitN(pair, "=", 2)
key := parts[0]
value := ""
if len(parts) > 1 {
value = parts[1]
}
if existing, ok := result[key]; ok {
result[key] = append(existing, value)
} else {
result[key] = []string{value}
}
}
return result
}

View File

@ -0,0 +1,857 @@
// Package lua 提供 Cosocket TCP API 实现
package lua
import (
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
glua "github.com/yuin/gopher-lua"
)
// TCPSocket TCP socket 对象
type TCPSocket struct {
// 连接
conn net.Conn
// 读取超时
readTimeout time.Duration
// 发送超时
sendTimeout time.Duration
// 连接超时
connectTimeout time.Duration
// 当前操作
currentOp *SocketOperation
// 状态
state SocketState
// 互斥锁
mu sync.RWMutex
// 管理器引用
manager *CosocketManager
// 连接地址
addr *net.TCPAddr
// 创建时间
createdAt time.Time
// 是否关闭
closed int32
// Lua 引用计数(用于 GC
luaRef int32
}
// NewTCPSocket 创建新的 TCP socket
func NewTCPSocket(manager *CosocketManager) *TCPSocket {
if manager == nil {
manager = DefaultCosocketManager
}
s := &TCPSocket{
manager: manager,
state: SocketStateIdle,
readTimeout: 60 * time.Second,
sendTimeout: 60 * time.Second,
connectTimeout: 30 * time.Second,
createdAt: time.Now(),
}
manager.TrackSocketCreated()
return s
}
// Connect 连接到指定地址(支持 yield
func (s *TCPSocket) Connect(host string, port int) error {
s.mu.Lock()
if s.state != SocketStateIdle {
s.mu.Unlock()
return fmt.Errorf("socket not idle, current state: %s", s.state)
}
s.state = SocketStateConnecting
s.mu.Unlock()
// 解析地址
addr, err := s.manager.TCPAddr(host, port)
if err != nil {
s.setState(SocketStateError)
return fmt.Errorf("resolve address: %w", err)
}
s.addr = addr
// 开始操作
op := s.manager.StartOperation(s, OpConnect, s.connectTimeout)
s.currentOp = op
// 在 goroutine 中执行连接
go func() {
defer func() {
s.currentOp = nil
}()
dialer := &net.Dialer{
Timeout: s.connectTimeout,
}
conn, err := dialer.DialContext(context.Background(), "tcp", addr.String())
if err != nil {
s.setState(SocketStateError)
s.manager.CompleteOperation(op.ID, nil, fmt.Errorf("dial: %w", err))
return
}
s.mu.Lock()
s.conn = conn
s.state = SocketStateConnected
s.mu.Unlock()
s.manager.CompleteOperation(op.ID, conn, nil)
}()
return nil
}
// ConnectAsync 异步连接(用于 Lua yield
func (s *TCPSocket) ConnectAsync(L *glua.LState, host string, port int) (*SocketOperation, error) {
err := s.Connect(host, port)
if err != nil {
return nil, err
}
return s.currentOp, nil
}
// Send 发送数据
func (s *TCPSocket) Send(data []byte) (int, error) {
s.mu.RLock()
if s.state != SocketStateConnected {
s.mu.RUnlock()
return 0, fmt.Errorf("socket not connected, current state: %s", s.state)
}
conn := s.conn
s.mu.RUnlock()
if conn == nil {
return 0, fmt.Errorf("socket connection is nil")
}
// 设置写超时
if err := conn.SetWriteDeadline(time.Now().Add(s.sendTimeout)); err != nil {
return 0, fmt.Errorf("set write deadline: %w", err)
}
n, err := conn.Write(data)
if err != nil {
s.setState(SocketStateError)
return n, fmt.Errorf("write: %w", err)
}
return n, nil
}
// SendAsync 异步发送(用于 Lua yield
func (s *TCPSocket) SendAsync(data []byte) (*SocketOperation, error) {
s.mu.RLock()
if s.state != SocketStateConnected {
s.mu.RUnlock()
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
}
conn := s.conn
s.mu.RUnlock()
if conn == nil {
return nil, fmt.Errorf("socket connection is nil")
}
// 开始操作
op := s.manager.StartOperation(s, OpSend, s.sendTimeout)
s.currentOp = op
s.setState(SocketStateSending)
// 在 goroutine 中执行发送
go func() {
defer func() {
s.currentOp = nil
s.setState(SocketStateConnected)
}()
// 设置写超时
if err := conn.SetWriteDeadline(time.Now().Add(s.sendTimeout)); err != nil {
s.manager.CompleteOperation(op.ID, 0, fmt.Errorf("set write deadline: %w", err))
return
}
n, err := conn.Write(data)
if err != nil {
s.setState(SocketStateError)
s.manager.CompleteOperation(op.ID, n, fmt.Errorf("write: %w", err))
return
}
s.manager.CompleteOperation(op.ID, n, nil)
}()
return op, nil
}
// Receive 接收数据
func (s *TCPSocket) Receive(size int) ([]byte, error) {
s.mu.RLock()
if s.state != SocketStateConnected {
s.mu.RUnlock()
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
}
conn := s.conn
s.mu.RUnlock()
if conn == nil {
return nil, fmt.Errorf("socket connection is nil")
}
// 设置读超时
if err := conn.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
return nil, fmt.Errorf("set read deadline: %w", err)
}
// 默认读取大小
if size <= 0 {
size = 4096
}
buf := make([]byte, size)
n, err := conn.Read(buf)
if err != nil {
if err.Error() == "EOF" {
return nil, nil // 连接关闭
}
s.setState(SocketStateError)
return nil, fmt.Errorf("read: %w", err)
}
return buf[:n], nil
}
// ReceiveAsync 异步接收(用于 Lua yield
func (s *TCPSocket) ReceiveAsync(size int) (*SocketOperation, error) {
s.mu.RLock()
if s.state != SocketStateConnected {
s.mu.RUnlock()
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
}
conn := s.conn
s.mu.RUnlock()
if conn == nil {
return nil, fmt.Errorf("socket connection is nil")
}
// 开始操作
op := s.manager.StartOperation(s, OpReceive, s.readTimeout)
s.currentOp = op
s.setState(SocketStateReceiving)
// 在 goroutine 中执行接收
go func() {
defer func() {
s.currentOp = nil
s.setState(SocketStateConnected)
}()
// 设置读超时
if err := conn.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
s.manager.CompleteOperation(op.ID, nil, fmt.Errorf("set read deadline: %w", err))
return
}
// 默认读取大小
if size <= 0 {
size = 4096
}
buf := make([]byte, size)
n, err := conn.Read(buf)
if err != nil {
if err.Error() == "EOF" {
s.manager.CompleteOperation(op.ID, []byte{}, nil)
return
}
s.setState(SocketStateError)
s.manager.CompleteOperation(op.ID, nil, fmt.Errorf("read: %w", err))
return
}
s.manager.CompleteOperation(op.ID, buf[:n], nil)
}()
return op, nil
}
// ReceiveUntil 读取直到特定模式
func (s *TCPSocket) ReceiveUntil(pattern string, inclusive bool) ([]byte, error) {
if len(pattern) == 0 {
return nil, fmt.Errorf("pattern cannot be empty")
}
s.mu.RLock()
if s.state != SocketStateConnected {
s.mu.RUnlock()
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
}
conn := s.conn
s.mu.RUnlock()
if conn == nil {
return nil, fmt.Errorf("socket connection is nil")
}
// 设置读超时
if err := conn.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
return nil, fmt.Errorf("set read deadline: %w", err)
}
// 使用带缓冲的读取
var result []byte
buf := make([]byte, 1)
patternBytes := []byte(pattern)
patternLen := len(patternBytes)
for {
n, err := conn.Read(buf)
if err != nil {
if err.Error() == "EOF" {
return result, nil
}
s.setState(SocketStateError)
return result, fmt.Errorf("read: %w", err)
}
if n == 0 {
continue
}
result = append(result, buf[0])
// 检查是否匹配模式
if len(result) >= patternLen {
matched := true
for i := 0; i < patternLen; i++ {
if result[len(result)-patternLen+i] != patternBytes[i] {
matched = false
break
}
}
if matched {
if !inclusive {
result = result[:len(result)-patternLen]
}
return result, nil
}
}
// 防止无限增长
if len(result) > 1024*1024 { // 1MB 限制
return result, fmt.Errorf("receive buffer exceeded 1MB limit")
}
}
}
// Close 关闭 socket
func (s *TCPSocket) Close() error {
if s == nil {
return nil
}
if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
return nil // 已经关闭
}
s.mu.Lock()
defer s.mu.Unlock()
// 取消当前操作
if s.currentOp != nil && !s.currentOp.IsCompleted() && s.manager != nil {
s.manager.CompleteOperation(s.currentOp.ID, nil, fmt.Errorf("socket closed"))
s.currentOp = nil
}
// 关闭连接
if s.conn != nil {
s.conn.Close()
s.conn = nil
}
s.state = SocketStateClosed
if s.manager != nil {
s.manager.TrackSocketClosed()
}
return nil
}
// SetTimeout 设置超时
func (s *TCPSocket) SetTimeout(timeout time.Duration) {
s.readTimeout = timeout
s.sendTimeout = timeout
s.connectTimeout = timeout
}
// SetReadTimeout 设置读取超时
func (s *TCPSocket) SetReadTimeout(timeout time.Duration) {
s.readTimeout = timeout
}
// SetSendTimeout 设置发送超时
func (s *TCPSocket) SetSendTimeout(timeout time.Duration) {
s.sendTimeout = timeout
}
// SetConnectTimeout 设置连接超时
func (s *TCPSocket) SetConnectTimeout(timeout time.Duration) {
s.connectTimeout = timeout
}
// State 获取当前状态
func (s *TCPSocket) State() SocketState {
s.mu.RLock()
defer s.mu.RUnlock()
return s.state
}
// setState 设置状态
func (s *TCPSocket) setState(state SocketState) {
s.mu.Lock()
defer s.mu.Unlock()
s.state = state
}
// IsClosed 检查是否已关闭
func (s *TCPSocket) IsClosed() bool {
return atomic.LoadInt32(&s.closed) == 1
}
// LocalAddr 获取本地地址
func (s *TCPSocket) LocalAddr() net.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
if s.conn != nil {
return s.conn.LocalAddr()
}
return nil
}
// RemoteAddr 获取远程地址
func (s *TCPSocket) RemoteAddr() net.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
if s.conn != nil {
return s.conn.RemoteAddr()
}
return nil
}
// -------------------- Lua API --------------------
// tcpSocketMT TCP socket 元表名称
const tcpSocketMT = "tcp_socket"
// RegisterTCPSocketAPI 注册 TCP socket API
func RegisterTCPSocketAPI(L *glua.LState, engine *LuaEngine) {
// 创建 ngx.socket 表
socket := L.NewTable()
// ngx.socket.tcp()
socket.RawSetString("tcp", L.NewFunction(newTCPSocketFunc(engine)))
// 确保 ngx 表存在
ngx := L.GetGlobal("ngx")
var ngxTbl *glua.LTable
if tbl, ok := ngx.(*glua.LTable); ok {
ngxTbl = tbl
} else {
// 创建 ngx 表
ngxTbl = L.NewTable()
L.SetGlobal("ngx", ngxTbl)
}
ngxTbl.RawSetString("socket", socket)
// 注册元表
registerTCPSocketMetaTable(L)
}
// registerTCPSocketMetaTable 注册 TCP socket 元表
func registerTCPSocketMetaTable(L *glua.LState) {
mt := L.NewTable()
// __index
index := L.NewTable()
index.RawSetString("connect", L.NewFunction(tcpSocketConnect))
index.RawSetString("send", L.NewFunction(tcpSocketSend))
index.RawSetString("receive", L.NewFunction(tcpSocketReceive))
index.RawSetString("receiveuntil", L.NewFunction(tcpSocketReceiveUntil))
index.RawSetString("close", L.NewFunction(tcpSocketClose))
index.RawSetString("settimeout", L.NewFunction(tcpSocketSetTimeout))
index.RawSetString("settimeouts", L.NewFunction(tcpSocketSetTimeouts))
mt.RawSetString("__index", index)
mt.RawSetString("__gc", L.NewFunction(tcpSocketGC))
mt.RawSetString("__tostring", L.NewFunction(tcpSocketToString))
L.SetMetatable(L.NewTable(), mt)
L.SetGlobal(tcpSocketMT, mt)
}
// newTCPSocketFunc 创建 TCP socket
func newTCPSocketFunc(engine *LuaEngine) func(*glua.LState) int {
return func(L *glua.LState) int {
socket := NewTCPSocket(DefaultCosocketManager)
// 创建 userdata
ud := L.NewUserData()
ud.Value = socket
L.SetMetatable(ud, L.GetGlobal(tcpSocketMT).(*glua.LTable))
L.Push(ud)
return 1
}
}
// checkTCPSocket 检查并获取 TCP socket
func checkTCPSocket(L *glua.LState, n int) *TCPSocket {
ud := L.CheckUserData(n)
if socket, ok := ud.Value.(*TCPSocket); ok {
return socket
}
L.ArgError(n, "tcp socket expected")
return nil
}
// tcpSocketConnect tcpsock:connect()
func tcpSocketConnect(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
host := L.CheckString(2)
port := L.CheckInt(3)
opts := L.OptTable(4, nil)
timeout := 30000 // 默认 30 秒
if opts != nil {
if t := opts.RawGetString("timeout"); t != glua.LNil {
timeout = int(glua.LVAsNumber(t))
}
}
// 设置超时
socket.SetConnectTimeout(time.Duration(timeout) * time.Millisecond)
// 开始异步连接
op, err := socket.ConnectAsync(L, host, port)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
// yield 等待连接完成
L.Push(glua.LString("cosocket_connect"))
L.Push(glua.LNumber(op.ID))
// TODO: 实现真正的非阻塞 yield目前使用同步模拟
return 2
}
// tcpSocketSend tcpsock:send()
func tcpSocketSend(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
data := L.CheckString(2)
opts := L.OptTable(3, nil)
timeout := 60000 // 默认 60 秒
if opts != nil {
if t := opts.RawGetString("timeout"); t != glua.LNil {
timeout = int(glua.LVAsNumber(t))
}
}
// 设置超时
socket.SetSendTimeout(time.Duration(timeout) * time.Millisecond)
// 开始异步发送
op, err := socket.SendAsync([]byte(data))
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
// yield 等待发送完成
L.Push(glua.LString("cosocket_send"))
L.Push(glua.LNumber(op.ID))
// TODO: 实现真正的非阻塞 yield目前使用同步模拟
return 2
}
// tcpSocketReceive tcpsock:receive()
func tcpSocketReceive(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
// 解析参数
var size int
opts := L.NewTable()
if L.GetTop() >= 2 {
switch v := L.Get(2).(type) {
case glua.LNumber:
size = int(v)
case *glua.LTable:
opts = v
default:
// 检查是否为字符串类型
if L.Get(2).Type() == glua.LTString {
// 接收特定模式
return tcpSocketReceivePattern(L, socket, L.CheckString(2), opts)
}
}
}
if L.GetTop() >= 3 {
if t, ok := L.Get(3).(*glua.LTable); ok {
opts = t
}
}
// 获取超时
timeout := 60000 // 默认 60 秒
if t := opts.RawGetString("timeout"); t != glua.LNil {
timeout = int(glua.LVAsNumber(t))
}
// 设置超时
socket.SetReadTimeout(time.Duration(timeout) * time.Millisecond)
// 开始异步接收
op, err := socket.ReceiveAsync(size)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
// yield 等待接收完成
L.Push(glua.LString("cosocket_receive"))
L.Push(glua.LNumber(op.ID))
// TODO: 实现真正的非阻塞 yield目前使用同步模拟
return 2
}
// tcpSocketReceivePattern 按模式接收
func tcpSocketReceivePattern(L *glua.LState, socket *TCPSocket, pattern string, opts *glua.LTable) int {
switch pattern {
case "*l":
// 读取一行
data, err := socket.ReceiveUntil("\n", true)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
L.Push(glua.LString(string(data)))
return 1
case "*a":
// 读取所有(这里简化为读取最大 64KB
data, err := socket.Receive(64 * 1024)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
L.Push(glua.LString(string(data)))
return 1
default:
L.Push(glua.LNil)
L.Push(glua.LString("unknown pattern: " + pattern))
return 2
}
}
// tcpSocketReceiveUntil tcpsock:receiveuntil()
func tcpSocketReceiveUntil(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
pattern := L.CheckString(2)
opts := L.OptTable(3, nil)
inclusive := false
if opts != nil {
if inc := opts.RawGetString("inclusive"); inc != glua.LNil {
inclusive = glua.LVAsBool(inc)
}
}
data, err := socket.ReceiveUntil(pattern, inclusive)
if err != nil {
L.Push(glua.LNil)
L.Push(glua.LString(err.Error()))
return 2
}
// 创建迭代器函数
iter := L.NewFunction(func(L *glua.LState) int {
if len(data) == 0 {
L.Push(glua.LNil)
return 1
}
L.Push(glua.LString(string(data)))
data = nil // 清空,下次返回 nil
return 1
})
L.Push(iter)
return 1
}
// tcpSocketClose tcpsock:close()
func tcpSocketClose(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
if err := socket.Close(); err != nil {
L.Push(glua.LFalse)
L.Push(glua.LString(err.Error()))
return 2
}
L.Push(glua.LTrue)
return 1
}
// tcpSocketSetTimeout tcpsock:settimeout()
func tcpSocketSetTimeout(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
timeout := L.CheckNumber(2)
socket.SetTimeout(time.Duration(timeout) * time.Millisecond)
L.Push(glua.LTrue)
return 1
}
// tcpSocketSetTimeouts tcpsock:settimeouts()
func tcpSocketSetTimeouts(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
connectTimeout := L.CheckNumber(2)
sendTimeout := L.CheckNumber(3)
readTimeout := L.CheckNumber(4)
socket.SetConnectTimeout(time.Duration(connectTimeout) * time.Millisecond)
socket.SetSendTimeout(time.Duration(sendTimeout) * time.Millisecond)
socket.SetReadTimeout(time.Duration(readTimeout) * time.Millisecond)
L.Push(glua.LTrue)
return 1
}
// tcpSocketGC __gc 元方法
func tcpSocketGC(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
socket.Close()
return 0
}
// tcpSocketToString __tostring 元方法
func tcpSocketToString(L *glua.LState) int {
socket := checkTCPSocket(L, 1)
state := socket.State()
L.Push(glua.LString(fmt.Sprintf("tcp_socket(%s)", state)))
return 1
}
// HandleCosocketYield 处理 cosocket yield
func (c *LuaCoroutine) HandleCosocketYield(reason string, values []glua.LValue) ([]glua.LValue, error) {
switch reason {
case "cosocket_connect":
return c.handleCosocketConnect(values)
case "cosocket_send":
return c.handleCosocketSend(values)
case "cosocket_receive":
return c.handleCosocketReceive(values)
default:
return nil, fmt.Errorf("unknown cosocket yield reason: %s", reason)
}
}
// handleCosocketConnect 处理连接 yield
func (c *LuaCoroutine) handleCosocketConnect(values []glua.LValue) ([]glua.LValue, error) {
if len(values) == 0 {
return nil, fmt.Errorf("cosocket_connect requires operation ID")
}
opID := uint64(glua.LVAsNumber(values[0]))
op := DefaultCosocketManager.GetOperation(opID)
if op == nil {
return nil, fmt.Errorf("operation %d not found", opID)
}
// 等待操作完成
result, err := op.Wait(c.ExecutionContext)
if err != nil {
return []glua.LValue{glua.LNil, glua.LString(err.Error())}, nil
}
if result == nil {
return []glua.LValue{glua.LNil, glua.LNil}, nil
}
_ = result // 连接成功,返回 1
return []glua.LValue{glua.LNumber(1)}, nil
}
// handleCosocketSend 处理发送 yield
func (c *LuaCoroutine) handleCosocketSend(values []glua.LValue) ([]glua.LValue, error) {
if len(values) == 0 {
return nil, fmt.Errorf("cosocket_send requires operation ID")
}
opID := uint64(glua.LVAsNumber(values[0]))
op := DefaultCosocketManager.GetOperation(opID)
if op == nil {
return nil, fmt.Errorf("operation %d not found", opID)
}
// 等待操作完成
result, err := op.Wait(c.ExecutionContext)
if err != nil {
return []glua.LValue{glua.LNil, glua.LString(err.Error())}, nil
}
if n, ok := result.(int); ok {
return []glua.LValue{glua.LNumber(n)}, nil
}
return []glua.LValue{glua.LNil, glua.LString("invalid result")}, nil
}
// handleCosocketReceive 处理接收 yield
func (c *LuaCoroutine) handleCosocketReceive(values []glua.LValue) ([]glua.LValue, error) {
if len(values) == 0 {
return nil, fmt.Errorf("cosocket_receive requires operation ID")
}
opID := uint64(glua.LVAsNumber(values[0]))
op := DefaultCosocketManager.GetOperation(opID)
if op == nil {
return nil, fmt.Errorf("operation %d not found", opID)
}
// 等待操作完成
result, err := op.Wait(c.ExecutionContext)
if err != nil {
return []glua.LValue{glua.LNil, glua.LString(err.Error())}, nil
}
if data, ok := result.([]byte); ok {
if len(data) == 0 {
return []glua.LValue{glua.LNil, glua.LString("closed")}, nil
}
return []glua.LValue{glua.LString(string(data))}, nil
}
return []glua.LValue{glua.LNil, glua.LString("invalid result")}, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,583 @@
// Package lua 提供 Lua 脚本嵌入能力
package lua
import (
"io"
"net"
"sync"
"github.com/valyala/fasthttp"
)
// ResponseInterceptor 响应拦截器
// 用于延迟 header 写入,允许在发送前修改响应
type ResponseInterceptor struct {
// 原始请求上下文
ctx *fasthttp.RequestCtx
// Header 修改回调Lua 执行)
headerFilterFunc func() error
// Body 修改回调Lua 执行)
bodyFilterFunc func([]byte) ([]byte, error)
// 缓冲的 body 数据
bodyBuffer []byte
// 是否已写入 header
headersWritten bool
// 是否已拦截
intercepted bool
// 状态码(可修改)
statusCode int
// 自定义 header可修改
customHeaders map[string]string
// 需要删除的 header
headersToDelete []string
// 并发保护
mu sync.RWMutex
}
// NewResponseInterceptor 创建响应拦截器
func NewResponseInterceptor(ctx *fasthttp.RequestCtx) *ResponseInterceptor {
return &ResponseInterceptor{
ctx: ctx,
statusCode: 200,
customHeaders: make(map[string]string),
headersToDelete: make([]string, 0),
}
}
// SetHeaderFilter 设置 header 过滤器回调
func (ri *ResponseInterceptor) SetHeaderFilter(fn func() error) {
ri.headerFilterFunc = fn
}
// SetBodyFilter 设置 body 过滤器回调
func (ri *ResponseInterceptor) SetBodyFilter(fn func([]byte) ([]byte, error)) {
ri.bodyFilterFunc = fn
}
// SetStatusCode 设置状态码(延迟生效)
func (ri *ResponseInterceptor) SetStatusCode(code int) {
ri.statusCode = code
}
// GetStatusCode 获取当前状态码
func (ri *ResponseInterceptor) GetStatusCode() int {
return ri.statusCode
}
// SetHeader 设置 header延迟生效
func (ri *ResponseInterceptor) SetHeader(key, value string) {
ri.mu.Lock()
defer ri.mu.Unlock()
ri.customHeaders[key] = value
}
// GetHeader 获取原始 header 值
func (ri *ResponseInterceptor) GetHeader(key string) []byte {
return ri.ctx.Response.Header.Peek(key)
}
// DelHeader 删除 header延迟生效
func (ri *ResponseInterceptor) DelHeader(key string) {
ri.headersToDelete = append(ri.headersToDelete, key)
}
// Write 拦截写入操作(缓冲 body延迟 header 发送)
func (ri *ResponseInterceptor) Write(p []byte) (int, error) {
if !ri.intercepted {
// 未启用拦截,直接写入
return ri.ctx.Write(p)
}
// 缓冲 body 数据
ri.bodyBuffer = append(ri.bodyBuffer, p...)
return len(p), nil
}
// WriteString 写入字符串
func (ri *ResponseInterceptor) WriteString(s string) (int, error) {
return ri.Write([]byte(s))
}
// SetBody 设置 body延迟发送
func (ri *ResponseInterceptor) SetBody(body []byte) {
if !ri.intercepted {
ri.ctx.SetBody(body)
return
}
ri.bodyBuffer = body
}
// SetBodyString 设置字符串 body
func (ri *ResponseInterceptor) SetBodyString(body string) {
ri.SetBody([]byte(body))
}
// Flush 执行 header/body filter 并发送响应
func (ri *ResponseInterceptor) Flush() error {
if ri.headersWritten {
return nil // 已经发送过
}
ri.headersWritten = true
// 1. 执行 header filter
if ri.headerFilterFunc != nil {
if err := ri.headerFilterFunc(); err != nil {
return err
}
}
// 2. 应用 header 修改
ri.ctx.Response.SetStatusCode(ri.statusCode)
for key, value := range ri.customHeaders {
ri.ctx.Response.Header.Set(key, value)
}
for _, key := range ri.headersToDelete {
ri.ctx.Response.Header.Del(key)
}
// 3. 执行 body filter
body := ri.bodyBuffer
if ri.bodyFilterFunc != nil && len(body) > 0 {
modified, err := ri.bodyFilterFunc(body)
if err != nil {
return err
}
body = modified
}
// 4. 发送响应
if len(body) > 0 {
ri.ctx.SetBody(body)
}
return nil
}
// Enable 启用拦截模式
func (ri *ResponseInterceptor) Enable() {
ri.intercepted = true
}
// Disable 禁用拦截模式
func (ri *ResponseInterceptor) Disable() {
ri.intercepted = false
}
// IsEnabled 检查是否启用拦截
func (ri *ResponseInterceptor) IsEnabled() bool {
return ri.intercepted
}
// GetBufferedBody 获取当前缓冲的 body
func (ri *ResponseInterceptor) GetBufferedBody() []byte {
return ri.bodyBuffer
}
// ClearBody 清空 body 缓冲
func (ri *ResponseInterceptor) ClearBody() {
ri.bodyBuffer = nil
}
// DelayedResponseWriter 延迟响应写入器
// 包装 fasthttp.RequestCtx 提供延迟写入能力
type DelayedResponseWriter struct {
ctx *fasthttp.RequestCtx
interceptor *ResponseInterceptor
pool *sync.Pool
}
// NewDelayedResponseWriter 创建延迟响应写入器
func NewDelayedResponseWriter(ctx *fasthttp.RequestCtx) *DelayedResponseWriter {
return &DelayedResponseWriter{
ctx: ctx,
interceptor: NewResponseInterceptor(ctx),
}
}
// EnableFilterPhase 启用 filter phase
func (drw *DelayedResponseWriter) EnableFilterPhase() {
drw.interceptor.Enable()
}
// DisableFilterPhase 禁用 filter phase
func (drw *DelayedResponseWriter) DisableFilterPhase() {
drw.interceptor.Disable()
}
// GetInterceptor 获取响应拦截器
func (drw *DelayedResponseWriter) GetInterceptor() *ResponseInterceptor {
return drw.interceptor
}
// HeaderFilter 执行 header filter 阶段
func (drw *DelayedResponseWriter) HeaderFilter(script string, luaCtx *LuaContext) error {
if !drw.interceptor.IsEnabled() {
return nil
}
luaCtx.SetPhase(PhaseHeaderFilter)
drw.interceptor.SetHeaderFilter(func() error {
return luaCtx.Execute(script)
})
return nil
}
// BodyFilter 执行 body filter 阶段
func (drw *DelayedResponseWriter) BodyFilter(script string, luaCtx *LuaContext) error {
if !drw.interceptor.IsEnabled() {
return nil
}
luaCtx.SetPhase(PhaseBodyFilter)
drw.interceptor.SetBodyFilter(func(body []byte) ([]byte, error) {
// 将 body 设置到 Lua 上下文中
luaCtx.OutputBuffer = body
if err := luaCtx.Execute(script); err != nil {
return nil, err
}
return luaCtx.OutputBuffer, nil
})
return nil
}
// Flush 刷新响应
func (drw *DelayedResponseWriter) Flush() error {
return drw.interceptor.Flush()
}
// Write 实现 io.Writer
func (drw *DelayedResponseWriter) Write(p []byte) (int, error) {
return drw.interceptor.Write(p)
}
// WriteString 写入字符串
func (drw *DelayedResponseWriter) WriteString(s string) (int, error) {
return drw.interceptor.WriteString(s)
}
// SetStatusCode 设置状态码
func (drw *DelayedResponseWriter) SetStatusCode(code int) {
drw.interceptor.SetStatusCode(code)
}
// SetBody 设置 body
func (drw *DelayedResponseWriter) SetBody(body []byte) {
drw.interceptor.SetBody(body)
}
// SetBodyString 设置字符串 body
func (drw *DelayedResponseWriter) SetBodyString(body string) {
drw.interceptor.SetBodyString(body)
}
// SetHeader 设置 header
func (drw *DelayedResponseWriter) SetHeader(key, value string) {
drw.interceptor.SetHeader(key, value)
}
// GetHeader 获取 header
func (drw *DelayedResponseWriter) GetHeader(key string) []byte {
return drw.interceptor.GetHeader(key)
}
// DelHeader 删除 header
func (drw *DelayedResponseWriter) DelHeader(key string) {
drw.interceptor.DelHeader(key)
}
// ResponseInterceptorPool 响应拦截器池
var ResponseInterceptorPool = sync.Pool{
New: func() interface{} {
return &ResponseInterceptor{}
},
}
// AcquireResponseInterceptor 从池中获取拦截器
func AcquireResponseInterceptor(ctx *fasthttp.RequestCtx) *ResponseInterceptor {
ri := ResponseInterceptorPool.Get().(*ResponseInterceptor)
ri.ctx = ctx
ri.statusCode = 200
ri.customHeaders = make(map[string]string)
ri.headersToDelete = make([]string, 0)
ri.bodyBuffer = nil
ri.headersWritten = false
ri.intercepted = true
ri.headerFilterFunc = nil
ri.bodyFilterFunc = nil
return ri
}
// ReleaseResponseInterceptor 释放拦截器回池
func ReleaseResponseInterceptor(ri *ResponseInterceptor) {
if ri == nil {
return
}
// 清理状态
ri.ctx = nil
ri.headerFilterFunc = nil
ri.bodyFilterFunc = nil
ri.bodyBuffer = nil
ri.customHeaders = nil
ri.headersToDelete = nil
ResponseInterceptorPool.Put(ri)
}
// responseWriterWrapper 适配 fasthttp.ResponseWriter 接口
type responseWriterWrapper struct {
interceptor *ResponseInterceptor
}
func (w *responseWriterWrapper) Write(p []byte) (n int, err error) {
return w.interceptor.Write(p)
}
func (w *responseWriterWrapper) Header() map[string][]string {
// fasthttp 不兼容 http.Header返回 nil
return nil
}
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
w.interceptor.SetStatusCode(statusCode)
}
// Hijack 支持连接劫持(用于 WebSocket
func (drw *DelayedResponseWriter) Hijack(handler fasthttp.HijackHandler) {
drw.ctx.Hijack(handler)
}
// Hijacked 检查是否已劫持
func (drw *DelayedResponseWriter) Hijacked() bool {
return drw.ctx.Hijacked()
}
// LocalAddr 获取本地地址
func (drw *DelayedResponseWriter) LocalAddr() net.Addr {
return drw.ctx.LocalAddr()
}
// RemoteAddr 获取远程地址
func (drw *DelayedResponseWriter) RemoteAddr() net.Addr {
return drw.ctx.RemoteAddr()
}
// SetConnectionClose 设置连接关闭
func (drw *DelayedResponseWriter) SetConnectionClose() {
drw.ctx.Response.SetConnectionClose()
}
// BodyWriter 返回 body 写入器
func (drw *DelayedResponseWriter) BodyWriter() io.Writer {
return &responseWriterAdapter{interceptor: drw.interceptor}
}
// responseWriterAdapter 适配 io.Writer
type responseWriterAdapter struct {
interceptor *ResponseInterceptor
}
func (rwa *responseWriterAdapter) Write(p []byte) (n int, err error) {
return rwa.interceptor.Write(p)
}
// ResponseStats 响应统计信息
type ResponseStats struct {
BufferedBytes int
HeadersModified int
HeadersDeleted int
BodyModified bool
StatusCode int
}
// GetStats 获取响应统计
func (drw *DelayedResponseWriter) GetStats() ResponseStats {
return ResponseStats{
BufferedBytes: len(drw.interceptor.bodyBuffer),
HeadersModified: len(drw.interceptor.customHeaders),
HeadersDeleted: len(drw.interceptor.headersToDelete),
BodyModified: drw.interceptor.bodyFilterFunc != nil,
StatusCode: drw.interceptor.statusCode,
}
}
// IsBodyBuffered 检查 body 是否被缓冲
func (drw *DelayedResponseWriter) IsBodyBuffered() bool {
return len(drw.interceptor.bodyBuffer) > 0
}
// GetBufferedBodySize 获取缓冲的 body 大小
func (drw *DelayedResponseWriter) GetBufferedBodySize() int {
return len(drw.interceptor.bodyBuffer)
}
// Reset 重置写入器状态
func (drw *DelayedResponseWriter) Reset() {
drw.interceptor.bodyBuffer = nil
drw.interceptor.headersWritten = false
drw.interceptor.statusCode = 200
drw.interceptor.customHeaders = make(map[string]string)
drw.interceptor.headersToDelete = make([]string, 0)
}
// SetBodyStream 设置 body 流
func (drw *DelayedResponseWriter) SetBodyStream(bodyStream io.Reader, bodySize int) {
if !drw.interceptor.IsEnabled() {
drw.ctx.SetBodyStream(bodyStream, bodySize)
return
}
// 流式 body 无法缓冲,直接设置
// 但在设置前应用 header filter
if drw.interceptor.headerFilterFunc != nil {
_ = drw.interceptor.headerFilterFunc()
}
drw.ctx.SetBodyStream(bodyStream, bodySize)
drw.interceptor.headersWritten = true
}
// SendFile 发送文件
func (drw *DelayedResponseWriter) SendFile(path string) error {
if !drw.interceptor.IsEnabled() {
drw.ctx.SendFile(path)
return nil
}
// 文件发送前应用 header filter
if drw.interceptor.headerFilterFunc != nil {
if err := drw.interceptor.headerFilterFunc(); err != nil {
return err
}
}
// 应用修改的 headers
drw.ctx.Response.SetStatusCode(drw.interceptor.statusCode)
for key, value := range drw.interceptor.customHeaders {
drw.ctx.Response.Header.Set(key, value)
}
for _, key := range drw.interceptor.headersToDelete {
drw.ctx.Response.Header.Del(key)
}
drw.ctx.SendFile(path)
drw.interceptor.headersWritten = true
return nil
}
// Redirect 重定向
func (drw *DelayedResponseWriter) Redirect(uri string, statusCode int) {
if !drw.interceptor.IsEnabled() {
drw.ctx.Redirect(uri, statusCode)
return
}
// 重定向前应用 header filter
if drw.interceptor.headerFilterFunc != nil {
_ = drw.interceptor.headerFilterFunc()
}
drw.ctx.Redirect(uri, statusCode)
drw.interceptor.headersWritten = true
}
// bufferPool body 缓冲区池
var bufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 0, 4096) // 4KB 初始容量
},
}
// acquireBuffer 获取缓冲区
func acquireBuffer() []byte {
return bufferPool.Get().([]byte)
}
// releaseBuffer 释放缓冲区
func releaseBuffer(buf []byte) {
if buf != nil && cap(buf) <= 65536 { // 只回收小缓冲区
bufferPool.Put(buf[:0])
}
}
// BufferedWriter 带缓冲的写入器
type BufferedWriter struct {
buf []byte
size int
flushFunc func([]byte) error
maxSize int
autoFlush bool
}
// NewBufferedWriter 创建缓冲写入器
func NewBufferedWriter(maxSize int, flushFunc func([]byte) error) *BufferedWriter {
return &BufferedWriter{
buf: acquireBuffer(),
maxSize: maxSize,
flushFunc: flushFunc,
autoFlush: true,
}
}
// Write 写入数据
func (bw *BufferedWriter) Write(p []byte) (int, error) {
if bw.buf == nil {
bw.buf = acquireBuffer()
}
// 检查是否需要扩容
if len(bw.buf)+len(p) > cap(bw.buf) {
// 扩容
newCap := cap(bw.buf) * 2
if newCap < len(bw.buf)+len(p) {
newCap = len(bw.buf) + len(p)
}
newBuf := make([]byte, len(bw.buf), newCap)
copy(newBuf, bw.buf)
releaseBuffer(bw.buf)
bw.buf = newBuf
}
bw.buf = append(bw.buf, p...)
// 自动刷新检查
if bw.autoFlush && bw.maxSize > 0 && len(bw.buf) >= bw.maxSize {
if err := bw.Flush(); err != nil {
return len(p), err
}
}
return len(p), nil
}
// Flush 刷新缓冲区
func (bw *BufferedWriter) Flush() error {
if bw.flushFunc == nil || len(bw.buf) == 0 {
return nil
}
if err := bw.flushFunc(bw.buf); err != nil {
return err
}
bw.buf = bw.buf[:0]
return nil
}
// Close 关闭写入器
func (bw *BufferedWriter) Close() error {
err := bw.Flush()
if bw.buf != nil {
releaseBuffer(bw.buf)
bw.buf = nil
}
return err
}
// Size 返回当前缓冲区大小
func (bw *BufferedWriter) Size() int {
return len(bw.buf)
}
// Bytes 返回当前缓冲区内容(不消费)
func (bw *BufferedWriter) Bytes() []byte {
return bw.buf
}

View File

@ -0,0 +1,351 @@
// Package lua 提供 Cosocket 管理功能
package lua
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
)
// SocketState 表示 socket 操作状态
type SocketState int
const (
// SocketStateIdle 空闲状态
SocketStateIdle SocketState = iota
// SocketStateConnecting 连接中
SocketStateConnecting
// SocketStateConnected 已连接
SocketStateConnected
// SocketStateSending 发送中
SocketStateSending
// SocketStateReceiving 接收中
SocketStateReceiving
// SocketStateClosing 关闭中
SocketStateClosing
// SocketStateClosed 已关闭
SocketStateClosed
// SocketStateError 错误状态
SocketStateError
)
func (s SocketState) String() string {
switch s {
case SocketStateIdle:
return "idle"
case SocketStateConnecting:
return "connecting"
case SocketStateConnected:
return "connected"
case SocketStateSending:
return "sending"
case SocketStateReceiving:
return "receiving"
case SocketStateClosing:
return "closing"
case SocketStateClosed:
return "closed"
case SocketStateError:
return "error"
default:
return "unknown"
}
}
// OperationType 操作类型
type OperationType string
const (
// OpConnect 连接操作
OpConnect OperationType = "connect"
// OpSend 发送操作
OpSend OperationType = "send"
// OpReceive 接收操作
OpReceive OperationType = "receive"
// OpClose 关闭操作
OpClose OperationType = "close"
)
// SocketOperation 表示一个 socket 操作
type SocketOperation struct {
// 操作 ID
ID uint64
// 关联的 socket
Socket *TCPSocket
// 操作类型
Type OperationType
// 当前状态
State SocketState
// 创建时间
CreatedAt time.Time
// 最后活动时间
LastActivity time.Time
// 超时时间
Timeout time.Duration
// 完成通道
Done chan struct{}
// 错误信息
Error error
// 结果数据
Result interface{}
// 是否完成
completed int32
}
// IsCompleted 检查操作是否已完成
func (op *SocketOperation) IsCompleted() bool {
return atomic.LoadInt32(&op.completed) == 1
}
// Complete 标记操作完成
func (op *SocketOperation) Complete(result interface{}, err error) {
if atomic.CompareAndSwapInt32(&op.completed, 0, 1) {
op.Result = result
op.Error = err
close(op.Done)
}
}
// Wait 等待操作完成
func (op *SocketOperation) Wait(ctx context.Context) (interface{}, error) {
select {
case <-op.Done:
return op.Result, op.Error
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Touch 更新活动时间
func (op *SocketOperation) Touch() {
op.LastActivity = time.Now()
}
// CosocketStats Cosocket 统计信息
type CosocketStats struct {
// 总操作数
TotalOperations uint64
// 活跃操作数
ActiveOperations uint64
// 超时操作数
TimeoutOperations uint64
// 错误操作数
ErrorOperations uint64
// 当前 socket 数
ActiveSockets uint64
// 总创建 socket 数
TotalSocketsCreated uint64
// 总关闭 socket 数
TotalSocketsClosed uint64
}
// CosocketManager Cosocket 管理器
type CosocketManager struct {
// 操作映射
operations map[uint64]*SocketOperation
// 互斥锁
mu sync.RWMutex
// 操作 ID 生成器
nextID uint64
// 超时检查器
timeoutChecker *time.Ticker
// 统计信息
stats CosocketStats
// 上下文
ctx context.Context
cancel context.CancelFunc
// 默认超时
defaultTimeout time.Duration
// 清理间隔
cleanupInterval time.Duration
}
// DefaultCosocketManager 全局默认管理器
var DefaultCosocketManager = NewCosocketManager()
// NewCosocketManager 创建新的 Cosocket 管理器
func NewCosocketManager() *CosocketManager {
ctx, cancel := context.WithCancel(context.Background())
cm := &CosocketManager{
operations: make(map[uint64]*SocketOperation),
nextID: 0,
timeoutChecker: time.NewTicker(30 * time.Second),
ctx: ctx,
cancel: cancel,
defaultTimeout: 60 * time.Second,
cleanupInterval: 30 * time.Second,
}
// 启动清理循环
go cm.cleanupLoop()
return cm
}
// StartOperation 开始一个新的 socket 操作
func (cm *CosocketManager) StartOperation(socket *TCPSocket, opType OperationType, timeout time.Duration) *SocketOperation {
if timeout <= 0 {
timeout = cm.defaultTimeout
}
id := atomic.AddUint64(&cm.nextID, 1)
now := time.Now()
op := &SocketOperation{
ID: id,
Socket: socket,
Type: opType,
State: SocketStateIdle,
CreatedAt: now,
LastActivity: now,
Timeout: timeout,
Done: make(chan struct{}),
}
cm.mu.Lock()
cm.operations[id] = op
cm.mu.Unlock()
atomic.AddUint64(&cm.stats.TotalOperations, 1)
atomic.AddUint64(&cm.stats.ActiveOperations, 1)
return op
}
// CompleteOperation 完成操作
func (cm *CosocketManager) CompleteOperation(id uint64, result interface{}, err error) {
cm.mu.Lock()
op, exists := cm.operations[id]
if exists {
delete(cm.operations, id)
}
cm.mu.Unlock()
if exists && op != nil {
op.Complete(result, err)
atomic.AddUint64(&cm.stats.ActiveOperations, ^uint64(0))
if err != nil {
atomic.AddUint64(&cm.stats.ErrorOperations, 1)
}
}
}
// GetOperation 获取操作
func (cm *CosocketManager) GetOperation(id uint64) *SocketOperation {
cm.mu.RLock()
defer cm.mu.RUnlock()
return cm.operations[id]
}
// cleanupLoop 清理循环
func (cm *CosocketManager) cleanupLoop() {
for {
select {
case <-cm.ctx.Done():
return
case <-cm.timeoutChecker.C:
cm.cleanup()
}
}
}
// cleanup 清理超时操作
func (cm *CosocketManager) cleanup() {
now := time.Now()
timeoutOps := make([]*SocketOperation, 0)
cm.mu.RLock()
for _, op := range cm.operations {
if !op.IsCompleted() && now.Sub(op.LastActivity) > op.Timeout {
timeoutOps = append(timeoutOps, op)
}
}
cm.mu.RUnlock()
for _, op := range timeoutOps {
cm.CompleteOperation(op.ID, nil, context.DeadlineExceeded)
atomic.AddUint64(&cm.stats.TimeoutOperations, 1)
}
}
// Stats 获取统计信息
func (cm *CosocketManager) Stats() CosocketStats {
return CosocketStats{
TotalOperations: atomic.LoadUint64(&cm.stats.TotalOperations),
ActiveOperations: atomic.LoadUint64(&cm.stats.ActiveOperations),
TimeoutOperations: atomic.LoadUint64(&cm.stats.TimeoutOperations),
ErrorOperations: atomic.LoadUint64(&cm.stats.ErrorOperations),
ActiveSockets: atomic.LoadUint64(&cm.stats.ActiveSockets),
TotalSocketsCreated: atomic.LoadUint64(&cm.stats.TotalSocketsCreated),
TotalSocketsClosed: atomic.LoadUint64(&cm.stats.TotalSocketsClosed),
}
}
// SetDefaultTimeout 设置默认超时
func (cm *CosocketManager) SetDefaultTimeout(timeout time.Duration) {
cm.defaultTimeout = timeout
}
// Close 关闭管理器
func (cm *CosocketManager) Close() {
cm.cancel()
cm.timeoutChecker.Stop()
// 取消所有未完成操作
cm.mu.Lock()
ops := make([]*SocketOperation, 0, len(cm.operations))
for _, op := range cm.operations {
ops = append(ops, op)
}
cm.operations = make(map[uint64]*SocketOperation)
cm.mu.Unlock()
for _, op := range ops {
op.Complete(nil, context.Canceled)
}
}
// TrackSocketCreated 跟踪 socket 创建
func (cm *CosocketManager) TrackSocketCreated() {
atomic.AddUint64(&cm.stats.TotalSocketsCreated, 1)
atomic.AddUint64(&cm.stats.ActiveSockets, 1)
}
// TrackSocketClosed 跟踪 socket 关闭
func (cm *CosocketManager) TrackSocketClosed() {
atomic.AddUint64(&cm.stats.TotalSocketsClosed, 1)
atomic.AddUint64(&cm.stats.ActiveSockets, ^uint64(0))
}
// TCPAddr 解析 TCP 地址
func (cm *CosocketManager) TCPAddr(host string, port int) (*net.TCPAddr, error) {
return &net.TCPAddr{
IP: net.ParseIP(host),
Port: port,
}, nil
}

702
internal/lua/socket_test.go Normal file
View File

@ -0,0 +1,702 @@
package lua
import (
"context"
"fmt"
"net"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
)
// mockEchoServer 模拟 echo 服务器
func mockEchoServer(t *testing.T, addr string) (net.Listener, func()) {
ln, err := net.Listen("tcp", addr)
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
var wg sync.WaitGroup
stop := make(chan struct{})
go func() {
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-stop:
return
default:
continue
}
}
wg.Add(1)
go func(c net.Conn) {
defer wg.Done()
buf := make([]byte, 4096)
for {
n, err := c.Read(buf)
if err != nil {
return
}
if n > 0 {
if _, err := c.Write(buf[:n]); err != nil {
return
}
}
}
}(conn)
}
}()
cleanup := func() {
close(stop)
ln.Close()
wg.Wait()
}
return ln, cleanup
}
// TestCosocketManager_Basic 测试基本功能
func TestCosocketManager_Basic(t *testing.T) {
cm := NewCosocketManager()
defer cm.Close()
// 测试初始状态
stats := cm.Stats()
if stats.TotalOperations != 0 {
t.Errorf("Expected 0 operations, got %d", stats.TotalOperations)
}
// 测试操作创建
socket := NewTCPSocket(cm)
defer socket.Close()
op := cm.StartOperation(socket, OpConnect, 5*time.Second)
if op == nil {
t.Fatal("Expected non-nil operation")
}
if op.ID == 0 {
t.Error("Expected non-zero operation ID")
}
if op.Type != OpConnect {
t.Errorf("Expected OpConnect, got %v", op.Type)
}
// 测试统计
stats = cm.Stats()
if stats.TotalOperations != 1 {
t.Errorf("Expected 1 operation, got %d", stats.TotalOperations)
}
if stats.ActiveOperations != 1 {
t.Errorf("Expected 1 active operation, got %d", stats.ActiveOperations)
}
// 测试操作完成
cm.CompleteOperation(op.ID, "done", nil)
stats = cm.Stats()
if stats.ActiveOperations != 0 {
t.Errorf("Expected 0 active operations after complete, got %d", stats.ActiveOperations)
}
}
// TestCosocketManager_Timeout 测试超时机制
func TestCosocketManager_Timeout(t *testing.T) {
// 创建一个使用短清理间隔的管理器用于测试
cm := NewCosocketManager()
cm.SetDefaultTimeout(100 * time.Millisecond)
cm.cleanupInterval = 50 * time.Millisecond
cm.timeoutChecker.Reset(50 * time.Millisecond)
defer cm.Close()
socket := NewTCPSocket(cm)
defer socket.Close()
// 创建一个不完成的操作
op := cm.StartOperation(socket, OpConnect, 100*time.Millisecond)
// 等待超时清理
time.Sleep(300 * time.Millisecond)
// 检查操作是否超时完成
if !op.IsCompleted() {
t.Error("Expected operation to be completed due to timeout")
}
stats := cm.Stats()
if stats.TimeoutOperations != 1 {
t.Errorf("Expected 1 timeout operation, got %d", stats.TimeoutOperations)
}
}
// TestTCPSocket_Connect 测试 TCP 连接
func TestTCPSocket_Connect(t *testing.T) {
_, cleanup := mockEchoServer(t, "127.0.0.1:19999")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
socket := NewTCPSocket(cm)
defer socket.Close()
// 测试连接
err := socket.Connect("127.0.0.1", 19999)
if err != nil {
t.Fatalf("Connect failed: %v", err)
}
// 等待连接完成
op := socket.currentOp
if op != nil {
result, err := op.Wait(context.Background())
if err != nil {
t.Fatalf("Connect wait failed: %v", err)
}
if result == nil {
t.Fatal("Expected non-nil connection")
}
}
if socket.State() != SocketStateConnected {
t.Errorf("Expected state connected, got %v", socket.State())
}
}
// TestTCPSocket_SendReceive 测试发送接收
func TestTCPSocket_SendReceive(t *testing.T) {
_, cleanup := mockEchoServer(t, "127.0.0.1:19998")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
socket := NewTCPSocket(cm)
defer socket.Close()
// 连接
if err := socket.Connect("127.0.0.1", 19998); err != nil {
t.Fatalf("Connect failed: %v", err)
}
// 等待连接完成
time.Sleep(100 * time.Millisecond)
// 发送数据
testData := "Hello, Cosocket!"
n, err := socket.Send([]byte(testData))
if err != nil {
t.Fatalf("Send failed: %v", err)
}
if n != len(testData) {
t.Errorf("Expected %d bytes sent, got %d", len(testData), n)
}
// 接收数据
received, err := socket.Receive(1024)
if err != nil {
t.Fatalf("Receive failed: %v", err)
}
if string(received) != testData {
t.Errorf("Expected '%s', got '%s'", testData, string(received))
}
}
// TestTCPSocket_AsyncOperations 测试异步操作
func TestTCPSocket_AsyncOperations(t *testing.T) {
_, cleanup := mockEchoServer(t, "127.0.0.1:19997")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
socket := NewTCPSocket(cm)
defer socket.Close()
// 测试异步连接
err := socket.Connect("127.0.0.1", 19997)
if err != nil {
t.Fatalf("Connect failed: %v", err)
}
// 等待连接完成
time.Sleep(100 * time.Millisecond)
// 测试异步发送
testData := "Async test"
sendOp, err := socket.SendAsync([]byte(testData))
if err != nil {
t.Fatalf("SendAsync failed: %v", err)
}
result, err := sendOp.Wait(context.Background())
if err != nil {
t.Fatalf("Send wait failed: %v", err)
}
if n, ok := result.(int); !ok || n != len(testData) {
t.Errorf("Expected %d bytes, got %v", len(testData), result)
}
// 测试异步接收
recvOp, err := socket.ReceiveAsync(1024)
if err != nil {
t.Fatalf("ReceiveAsync failed: %v", err)
}
result, err = recvOp.Wait(context.Background())
if err != nil {
t.Fatalf("Receive wait failed: %v", err)
}
if data, ok := result.([]byte); !ok || string(data) != testData {
t.Errorf("Expected '%s', got %v", testData, result)
}
}
// TestTCPSocket_ReceiveUntil 测试接收直到特定模式
func TestTCPSocket_ReceiveUntil(t *testing.T) {
_, cleanup := mockEchoServer(t, "127.0.0.1:19996")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
socket := NewTCPSocket(cm)
defer socket.Close()
// 连接
if err := socket.Connect("127.0.0.1", 19996); err != nil {
t.Fatalf("Connect failed: %v", err)
}
// 等待连接完成
time.Sleep(100 * time.Millisecond)
// 发送带换行的数据
testData := "Line1\nLine2\nLine3\n"
_, err := socket.Send([]byte(testData))
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// 接收直到换行
data, err := socket.ReceiveUntil("\n", true)
if err != nil {
t.Fatalf("ReceiveUntil failed: %v", err)
}
if string(data) != "Line1\n" {
t.Errorf("Expected 'Line1\\n', got '%s'", string(data))
}
}
// TestTCPSocket_Close 测试关闭
func TestTCPSocket_Close(t *testing.T) {
cm := NewCosocketManager()
defer cm.Close()
socket := NewTCPSocket(cm)
if socket.IsClosed() {
t.Error("Socket should not be closed initially")
}
err := socket.Close()
if err != nil {
t.Errorf("Close failed: %v", err)
}
if !socket.IsClosed() {
t.Error("Socket should be closed")
}
// 重复关闭应该返回 nil
err = socket.Close()
if err != nil {
t.Errorf("Second close should not error: %v", err)
}
}
// TestTCPSocket_StateTransitions 测试状态转换
func TestTCPSocket_StateTransitions(t *testing.T) {
_, cleanup := mockEchoServer(t, "127.0.0.1:19995")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
socket := NewTCPSocket(cm)
defer socket.Close()
// 初始状态
if socket.State() != SocketStateIdle {
t.Errorf("Expected idle state, got %v", socket.State())
}
// 连接中
socket.Connect("127.0.0.1", 19995)
if socket.State() != SocketStateConnecting {
t.Errorf("Expected connecting state, got %v", socket.State())
}
// 等待连接完成
time.Sleep(100 * time.Millisecond)
if socket.State() != SocketStateConnected {
t.Errorf("Expected connected state, got %v", socket.State())
}
}
// TestCosocketManager_Concurrent 测试并发操作
func TestCosocketManager_Concurrent(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrent test in short mode")
}
_, cleanup := mockEchoServer(t, "127.0.0.1:19994")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
const numSockets = 1000
const numGoroutines = 100
var wg sync.WaitGroup
errors := make(chan error, numSockets)
var completed int32
// 并发创建 socket 和连接
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(start int) {
defer wg.Done()
for j := 0; j < numSockets/numGoroutines; j++ {
socket := NewTCPSocket(cm)
if err := socket.Connect("127.0.0.1", 19994); err != nil {
errors <- fmt.Errorf("connect failed: %v", err)
socket.Close()
continue
}
// 等待连接
time.Sleep(50 * time.Millisecond)
// 发送数据
data := fmt.Sprintf("Test%d", start+j)
if _, err := socket.Send([]byte(data)); err != nil {
errors <- fmt.Errorf("send failed: %v", err)
socket.Close()
continue
}
socket.Close()
atomic.AddInt32(&completed, 1)
}
}(i * (numSockets / numGoroutines))
}
wg.Wait()
// 检查错误
close(errors)
errCount := 0
for err := range errors {
t.Logf("Error: %v", err)
errCount++
}
t.Logf("Completed: %d/%d, Errors: %d", completed, numSockets, errCount)
// 检查统计
stats := cm.Stats()
t.Logf("Stats: %+v", stats)
}
// TestCosocketManager_MemoryLeak 测试内存泄漏
func TestCosocketManager_MemoryLeak(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory leak test in short mode")
}
_, cleanup := mockEchoServer(t, "127.0.0.1:19993")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
// 记录初始 goroutine 数
initialGoroutines := runtime.NumGoroutine()
// 创建和关闭大量 socket
for i := 0; i < 10000; i++ {
socket := NewTCPSocket(cm)
// 使用同步连接避免竞态
socket.Connect("127.0.0.1", 19993)
time.Sleep(time.Millisecond) // 给连接时间完成
socket.Close()
}
// 强制 GC
runtime.GC()
time.Sleep(100 * time.Millisecond)
// 检查 goroutine 数量
finalGoroutines := runtime.NumGoroutine()
t.Logf("Initial goroutines: %d, Final goroutines: %d", initialGoroutines, finalGoroutines)
// 允许一定的波动
if finalGoroutines > initialGoroutines+100 {
t.Errorf("Possible goroutine leak: started with %d, ended with %d", initialGoroutines, finalGoroutines)
}
// 检查统计
stats := cm.Stats()
if stats.ActiveSockets > 100 {
t.Errorf("Active sockets leak: %d", stats.ActiveSockets)
}
if stats.ActiveOperations > 100 {
t.Errorf("Active operations leak: %d", stats.ActiveOperations)
}
}
// TestCosocketManager_LongRunning 测试长时间运行
func TestCosocketManager_LongRunning(t *testing.T) {
if testing.Short() {
t.Skip("Skipping long running test in short mode")
}
_, cleanup := mockEchoServer(t, "127.0.0.1:19992")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
duration := 10 * time.Second // 缩短到 10 秒进行测试
interval := 100 * time.Millisecond
var totalOps int32
start := time.Now()
for time.Since(start) < duration {
socket := NewTCPSocket(cm)
if err := socket.Connect("127.0.0.1", 19992); err != nil {
t.Logf("Connect error: %v", err)
socket.Close()
continue
}
// 等待连接
time.Sleep(50 * time.Millisecond)
// 发送接收
if _, err := socket.Send([]byte("test")); err == nil {
socket.Receive(1024)
}
socket.Close()
atomic.AddInt32(&totalOps, 1)
time.Sleep(interval)
}
elapsed := time.Since(start)
t.Logf("Completed %d operations in %v", totalOps, elapsed)
// 检查最终统计
stats := cm.Stats()
t.Logf("Final stats: %+v", stats)
if stats.ActiveSockets > 0 {
t.Errorf("Expected 0 active sockets, got %d", stats.ActiveSockets)
}
if stats.ActiveOperations > 0 {
t.Errorf("Expected 0 active operations, got %d", stats.ActiveOperations)
}
}
// BenchmarkCosocket_Connect 基准测试:连接
func BenchmarkCosocket_Connect(b *testing.B) {
_, cleanup := mockEchoServer(nil, "127.0.0.1:19991")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
socket := NewTCPSocket(cm)
socket.Connect("127.0.0.1", 19991)
time.Sleep(10 * time.Millisecond)
socket.Close()
}
})
}
// BenchmarkCosocket_SendReceive 基准测试:发送接收
func BenchmarkCosocket_SendReceive(b *testing.B) {
_, cleanup := mockEchoServer(nil, "127.0.0.1:19990")
defer cleanup()
cm := NewCosocketManager()
defer cm.Close()
// 预先连接
socket := NewTCPSocket(cm)
socket.Connect("127.0.0.1", 19990)
time.Sleep(100 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
socket.Send([]byte("benchmark"))
socket.Receive(1024)
}
b.StopTimer()
socket.Close()
}
// TestLuaAPI_TCPSocket 测试 Lua API
func TestLuaAPI_TCPSocket(t *testing.T) {
if testing.Short() {
t.Skip("Skipping Lua API test in short mode")
}
// 创建引擎
engine, err := NewEngine(nil)
if err != nil {
t.Fatalf("Failed to create engine: %v", err)
}
defer engine.Close()
// 注册 TCP socket API
RegisterTCPSocketAPI(engine.L, engine)
// 测试创建 socket
script := `
local sock = ngx.socket.tcp()
if not sock then
return nil, "failed to create socket"
end
return "ok"
`
coro, err := engine.NewCoroutine(nil)
if err != nil {
t.Fatalf("Failed to create coroutine: %v", err)
}
defer coro.Close()
if err := coro.SetupSandbox(); err != nil {
t.Fatalf("Failed to setup sandbox: %v", err)
}
err = coro.Execute(script)
if err != nil {
t.Errorf("Execute failed: %v", err)
}
}
// TestCosocketManager_Stress 压力测试
func TestCosocketManager_Stress(t *testing.T) {
if testing.Short() {
t.Skip("Skipping stress test in short mode")
}
// 创建多个 echo 服务器
ports := []int{19980, 19981, 19982, 19983}
cleanups := make([]func(), len(ports))
for i, port := range ports {
_, cleanups[i] = mockEchoServer(t, fmt.Sprintf("127.0.0.1:%d", port))
}
defer func() {
for _, c := range cleanups {
c()
}
}()
cm := NewCosocketManager()
defer cm.Close()
const totalConnections = 10000
const concurrency = 1000
var wg sync.WaitGroup
var successCount int32
var errorCount int32
var latencySum int64
start := time.Now()
// 使用信号量限制并发
sem := make(chan struct{}, concurrency)
for i := 0; i < totalConnections; i++ {
wg.Add(1)
sem <- struct{}{} // 获取信号量
go func(idx int) {
defer wg.Done()
defer func() { <-sem }() // 释放信号量
port := ports[idx%len(ports)]
socket := NewTCPSocket(cm)
opStart := time.Now()
err := socket.Connect("127.0.0.1", port)
if err != nil {
atomic.AddInt32(&errorCount, 1)
socket.Close()
return
}
// 等待连接完成
time.Sleep(10 * time.Millisecond)
// 简单数据交换
if _, err := socket.Send([]byte("hello")); err == nil {
socket.Receive(1024)
}
socket.Close()
latency := time.Since(opStart).Milliseconds()
atomic.AddInt64(&latencySum, latency)
atomic.AddInt32(&successCount, 1)
}(i)
}
wg.Wait()
elapsed := time.Since(start)
t.Logf("Stress test completed:")
t.Logf(" Total: %d, Success: %d, Errors: %d", totalConnections, successCount, errorCount)
t.Logf(" Duration: %v", elapsed)
t.Logf(" RPS: %.2f", float64(totalConnections)/elapsed.Seconds())
if successCount > 0 {
t.Logf(" Avg Latency: %dms", latencySum/int64(successCount))
}
// 内存检查
var m runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&m)
t.Logf(" Memory: %.2f MB", float64(m.HeapAlloc)/(1024*1024))
// 验证没有资源泄漏
stats := cm.Stats()
t.Logf(" Active sockets: %d, Active operations: %d", stats.ActiveSockets, stats.ActiveOperations)
if errorCount > totalConnections/10 { // 允许 10% 错误率
t.Errorf("Too many errors: %d", errorCount)
}
if stats.ActiveSockets > 100 {
t.Errorf("Socket leak detected: %d active sockets", stats.ActiveSockets)
}
}