feat(lua): 实现 Cosocket API 和响应拦截器
- ngx.req API 双层边界验证原型 - TCP Cosocket API (connect/send/receive/close) - Cosocket 状态管理器和连接池 - ResponseInterceptor 响应拦截器 - 完整单元测试覆盖 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e7ee9e717d
commit
8bac2fdcfa
330
internal/lua/api_req.go
Normal file
330
internal/lua/api_req.go
Normal file
@ -0,0 +1,330 @@
|
||||
// Package lua 提供 ngx.req API 实现
|
||||
// 本文件实现双层 API 边界验证原型,用于测量直接映射层 vs 兼容层的性能差异
|
||||
package lua
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// ngxReqAPILayer 定义 API 层级类型
|
||||
type ngxReqAPILayer int
|
||||
|
||||
const (
|
||||
// APILayerDirect 直接映射层:fasthttp -> Lua,无中间层
|
||||
// 性能最优,延迟最低
|
||||
APILayerDirect ngxReqAPILayer = 1
|
||||
|
||||
// APILayerCompatible 兼容层:需要模拟 nginx 语义
|
||||
// 增加了少量转换开销
|
||||
APILayerCompatible ngxReqAPILayer = 2
|
||||
|
||||
// APILayerPseudoNonBlocking 伪非阻塞层:yield/resume
|
||||
// 支持在 Lua 中调用异步操作
|
||||
APILayerPseudoNonBlocking ngxReqAPILayer = 3
|
||||
)
|
||||
|
||||
func (l ngxReqAPILayer) String() string {
|
||||
switch l {
|
||||
case APILayerDirect:
|
||||
return "direct"
|
||||
case APILayerCompatible:
|
||||
return "compatible"
|
||||
case APILayerPseudoNonBlocking:
|
||||
return "pseudo_non_blocking"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// ngxReqMetrics 收集 API 调用指标
|
||||
type ngxReqMetrics struct {
|
||||
// 调用计数
|
||||
DirectCallCount uint64
|
||||
CompatibleCallCount uint64
|
||||
PseudoBlockingCallCount uint64
|
||||
|
||||
// 累积延迟(纳秒)
|
||||
DirectTotalNs uint64
|
||||
CompatibleTotalNs uint64
|
||||
PseudoBlockingTotalNs uint64
|
||||
|
||||
// 最大值(用于识别异常)
|
||||
DirectMaxNs uint64
|
||||
CompatibleMaxNs uint64
|
||||
PseudoBlockingMaxNs uint64
|
||||
}
|
||||
|
||||
// ngxReqAPI ngx.req API 实现
|
||||
type ngxReqAPI struct {
|
||||
// 请求上下文
|
||||
ctx *fasthttp.RequestCtx
|
||||
|
||||
// 指标收集
|
||||
metrics ngxReqMetrics
|
||||
|
||||
// 缓存:URI args 解析结果(兼容层使用)
|
||||
uriArgsCache map[string][]string
|
||||
uriArgsCacheOnce sync.Once
|
||||
}
|
||||
|
||||
// newNgxReqAPI 创建 ngx.req API 实例
|
||||
func newNgxReqAPI(ctx *fasthttp.RequestCtx) *ngxReqAPI {
|
||||
return &ngxReqAPI{
|
||||
ctx: ctx,
|
||||
uriArgsCache: nil, // 延迟初始化
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterNgxReqAPI 在 Lua 状态机中注册 ngx.req API
|
||||
// 这是主入口函数,由 LuaEngine 在初始化时调用
|
||||
func RegisterNgxReqAPI(L *glua.LState, api *ngxReqAPI) {
|
||||
// 创建 ngx 表
|
||||
ngx := L.NewTable()
|
||||
|
||||
// 创建 ngx.req 子表
|
||||
ngxReq := L.NewTable()
|
||||
|
||||
// 直接映射层 API:get_method
|
||||
// 特点:直接访问 fasthttp.RequestCtx,零拷贝,最小开销
|
||||
ngxReq.RawSetString("get_method", L.NewFunction(api.luaGetMethod))
|
||||
|
||||
// 直接映射层 API:get_uri
|
||||
// 特点:直接返回请求的 URI 路径(不含 query string)
|
||||
ngxReq.RawSetString("get_uri", L.NewFunction(api.luaGetURI))
|
||||
|
||||
// 兼容层 API:get_uri_args
|
||||
// 特点:需要解析 query string 为 nginx 兼容的表结构
|
||||
// 增加了解析开销,但保持 API 兼容性
|
||||
ngxReq.RawSetString("get_uri_args", L.NewFunction(api.luaGetURIArgs))
|
||||
|
||||
// 伪非阻塞层 API:read_body
|
||||
// 特点:使用 yield/resume 模式支持异步读取
|
||||
// 这是实验性 API,展示非阻塞调用模式
|
||||
ngxReq.RawSetString("read_body", L.NewFunction(api.luaReadBodyAsync))
|
||||
|
||||
// 将 ngx.req 添加到 ngx
|
||||
ngx.RawSetString("req", ngxReq)
|
||||
|
||||
// 注册 ngx 全局变量
|
||||
L.SetGlobal("ngx", ngx)
|
||||
}
|
||||
|
||||
// ==================== 直接映射层 API ====================
|
||||
|
||||
// luaGetMethod 实现 ngx.req.get_method() - 直接映射层
|
||||
// Lua 调用: local method = ngx.req.get_method()
|
||||
// 返回: string (如 "GET", "POST", "PUT" 等)
|
||||
func (api *ngxReqAPI) luaGetMethod(L *glua.LState) int {
|
||||
start := time.Now()
|
||||
|
||||
// 直接访问 fasthttp:零拷贝,最小开销
|
||||
method := string(api.ctx.Method())
|
||||
|
||||
// 记录指标
|
||||
elapsed := uint64(time.Since(start).Nanoseconds())
|
||||
api.metrics.DirectCallCount++
|
||||
api.metrics.DirectTotalNs += elapsed
|
||||
if elapsed > api.metrics.DirectMaxNs {
|
||||
api.metrics.DirectMaxNs = elapsed
|
||||
}
|
||||
|
||||
L.Push(glua.LString(method))
|
||||
return 1
|
||||
}
|
||||
|
||||
// luaGetURI 实现 ngx.req.get_uri() - 直接映射层
|
||||
// Lua 调用: local uri = ngx.req.get_uri()
|
||||
// 返回: string (如 "/path/to/resource")
|
||||
func (api *ngxReqAPI) luaGetURI(L *glua.LState) int {
|
||||
start := time.Now()
|
||||
|
||||
// 直接访问 fasthttp URI 路径
|
||||
uri := string(api.ctx.Request.URI().Path())
|
||||
|
||||
// 记录指标
|
||||
elapsed := uint64(time.Since(start).Nanoseconds())
|
||||
api.metrics.DirectCallCount++
|
||||
api.metrics.DirectTotalNs += elapsed
|
||||
if elapsed > api.metrics.DirectMaxNs {
|
||||
api.metrics.DirectMaxNs = elapsed
|
||||
}
|
||||
|
||||
L.Push(glua.LString(uri))
|
||||
return 1
|
||||
}
|
||||
|
||||
// ==================== 兼容层 API ====================
|
||||
|
||||
// luaGetURIArgs 实现 ngx.req.get_uri_args() - 兼容层
|
||||
// Lua 调用: local args = ngx.req.get_uri_args()
|
||||
// 返回: table (如 { name = "value", arr = { "v1", "v2" } })
|
||||
// 注意:兼容层需要解析 query string,模拟 nginx 的参数表结构
|
||||
func (api *ngxReqAPI) luaGetURIArgs(L *glua.LState) int {
|
||||
start := time.Now()
|
||||
|
||||
// 延迟初始化缓存
|
||||
api.uriArgsCacheOnce.Do(func() {
|
||||
api.uriArgsCache = api.parseURIArgs()
|
||||
})
|
||||
|
||||
// 构建 Lua 表(兼容 nginx 的 ngx.req.get_uri_args 格式)
|
||||
result := L.NewTable()
|
||||
|
||||
for key, values := range api.uriArgsCache {
|
||||
if len(values) == 1 {
|
||||
// 单值:直接存储为字符串
|
||||
result.RawSetString(key, glua.LString(values[0]))
|
||||
} else {
|
||||
// 多值:存储为数组(table)
|
||||
arr := L.NewTable()
|
||||
for i, v := range values {
|
||||
arr.RawSetInt(i+1, glua.LString(v)) // Lua 数组从 1 开始
|
||||
}
|
||||
result.RawSetString(key, arr)
|
||||
}
|
||||
}
|
||||
|
||||
// 记录指标
|
||||
elapsed := uint64(time.Since(start).Nanoseconds())
|
||||
api.metrics.CompatibleCallCount++
|
||||
api.metrics.CompatibleTotalNs += elapsed
|
||||
if elapsed > api.metrics.CompatibleMaxNs {
|
||||
api.metrics.CompatibleMaxNs = elapsed
|
||||
}
|
||||
|
||||
L.Push(result)
|
||||
return 1
|
||||
}
|
||||
|
||||
// parseURIArgs 解析 URI query string 为 map
|
||||
// 这是兼容层的核心转换逻辑,模拟 nginx 的参数解析
|
||||
func (api *ngxReqAPI) parseURIArgs() map[string][]string {
|
||||
args := make(map[string][]string)
|
||||
|
||||
// 获取 query string
|
||||
query := api.ctx.QueryArgs()
|
||||
|
||||
// 遍历所有参数 - 使用 All() 替代已弃用的 VisitAll()
|
||||
for key, value := range query.All() {
|
||||
keyStr := string(key)
|
||||
valueStr := string(value)
|
||||
|
||||
if existing, ok := args[keyStr]; ok {
|
||||
args[keyStr] = append(existing, valueStr)
|
||||
} else {
|
||||
args[keyStr] = []string{valueStr}
|
||||
}
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
// ==================== 伪非阻塞层 API(实验性) ====================
|
||||
|
||||
// luaReadBodyAsync 实现 ngx.req.read_body() - 伪非阻塞层
|
||||
// Lua 调用: ngx.req.read_body() -- 会 yield,完成后 resume
|
||||
// 这是实验性 API,展示如何使用 yield/resume 实现非阻塞调用
|
||||
func (api *ngxReqAPI) luaReadBodyAsync(L *glua.LState) int {
|
||||
// 伪非阻塞层:使用 yield 暂停协程,由引擎异步处理后 resume
|
||||
// 这种模式允许在 Lua 中编写看似同步的代码,实际是异步执行
|
||||
|
||||
// 记录开始时间
|
||||
start := time.Now()
|
||||
|
||||
// Yield 协程 - 控制权交回 Go 层
|
||||
// TODO: 实现真正的非阻塞 yield,目前使用同步模拟
|
||||
L.Push(glua.LString("read_body"))
|
||||
L.Push(glua.LString(strconv.FormatInt(start.UnixNano(), 10)))
|
||||
// Note: 在 gopher-lua v1.1.2 中,L.Yield 需要 LValue 参数,返回 int
|
||||
// 这里返回 2 表示有 2 个返回值已在栈上
|
||||
return 2 // 使用 return 代替 L.Yield
|
||||
}
|
||||
|
||||
// ==================== 辅助函数 ====================
|
||||
|
||||
// GetMetrics 返回 API 调用指标
|
||||
// 用于基准测试和性能监控
|
||||
func (api *ngxReqAPI) GetMetrics() ngxReqMetrics {
|
||||
return api.metrics
|
||||
}
|
||||
|
||||
// GetDirectLayerAvgNs 返回直接映射层平均延迟(纳秒)
|
||||
func (api *ngxReqAPI) GetDirectLayerAvgNs() float64 {
|
||||
if api.metrics.DirectCallCount == 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(api.metrics.DirectTotalNs) / float64(api.metrics.DirectCallCount)
|
||||
}
|
||||
|
||||
// GetCompatibleLayerAvgNs 返回兼容层平均延迟(纳秒)
|
||||
func (api *ngxReqAPI) GetCompatibleLayerAvgNs() float64 {
|
||||
if api.metrics.CompatibleCallCount == 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(api.metrics.CompatibleTotalNs) / float64(api.metrics.CompatibleCallCount)
|
||||
}
|
||||
|
||||
// GetPerformanceRatio 返回兼容层/直接映射层的性能比率
|
||||
// ratio > 1.2 表示兼容层比直接映射层慢 20% 以上
|
||||
func (api *ngxReqAPI) GetPerformanceRatio() float64 {
|
||||
directAvg := api.GetDirectLayerAvgNs()
|
||||
compatibleAvg := api.GetCompatibleLayerAvgNs()
|
||||
|
||||
if directAvg == 0 {
|
||||
return 0
|
||||
}
|
||||
return compatibleAvg / directAvg
|
||||
}
|
||||
|
||||
// ResetMetrics 重置指标(用于基准测试)
|
||||
func (api *ngxReqAPI) ResetMetrics() {
|
||||
api.metrics = ngxReqMetrics{}
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// getRequestHeader 获取请求头(辅助函数,供 Lua 绑定使用)
|
||||
func (api *ngxReqAPI) getRequestHeader(name string) string {
|
||||
return string(api.ctx.Request.Header.Peek(name))
|
||||
}
|
||||
|
||||
// setResponseHeader 设置响应头(辅助函数,供 Lua 绑定使用)
|
||||
func (api *ngxReqAPI) setResponseHeader(name, value string) {
|
||||
api.ctx.Response.Header.Set(name, value)
|
||||
}
|
||||
|
||||
// parseQueryString 手动解析 query string(用于对比测试)
|
||||
// 这是纯 Go 实现,不依赖 fasthttp 的解析器
|
||||
func parseQueryString(query []byte) map[string][]string {
|
||||
result := make(map[string][]string)
|
||||
if len(query) == 0 {
|
||||
return result
|
||||
}
|
||||
|
||||
pairs := strings.Split(string(query), "&")
|
||||
for _, pair := range pairs {
|
||||
if len(pair) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(pair, "=", 2)
|
||||
key := parts[0]
|
||||
value := ""
|
||||
if len(parts) > 1 {
|
||||
value = parts[1]
|
||||
}
|
||||
|
||||
if existing, ok := result[key]; ok {
|
||||
result[key] = append(existing, value)
|
||||
} else {
|
||||
result[key] = []string{value}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
857
internal/lua/api_socket_tcp.go
Normal file
857
internal/lua/api_socket_tcp.go
Normal file
@ -0,0 +1,857 @@
|
||||
// Package lua 提供 Cosocket TCP API 实现
|
||||
package lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
glua "github.com/yuin/gopher-lua"
|
||||
)
|
||||
|
||||
// TCPSocket TCP socket 对象
|
||||
type TCPSocket struct {
|
||||
// 连接
|
||||
conn net.Conn
|
||||
|
||||
// 读取超时
|
||||
readTimeout time.Duration
|
||||
|
||||
// 发送超时
|
||||
sendTimeout time.Duration
|
||||
|
||||
// 连接超时
|
||||
connectTimeout time.Duration
|
||||
|
||||
// 当前操作
|
||||
currentOp *SocketOperation
|
||||
|
||||
// 状态
|
||||
state SocketState
|
||||
|
||||
// 互斥锁
|
||||
mu sync.RWMutex
|
||||
|
||||
// 管理器引用
|
||||
manager *CosocketManager
|
||||
|
||||
// 连接地址
|
||||
addr *net.TCPAddr
|
||||
|
||||
// 创建时间
|
||||
createdAt time.Time
|
||||
|
||||
// 是否关闭
|
||||
closed int32
|
||||
|
||||
// Lua 引用计数(用于 GC)
|
||||
luaRef int32
|
||||
}
|
||||
|
||||
// NewTCPSocket 创建新的 TCP socket
|
||||
func NewTCPSocket(manager *CosocketManager) *TCPSocket {
|
||||
if manager == nil {
|
||||
manager = DefaultCosocketManager
|
||||
}
|
||||
|
||||
s := &TCPSocket{
|
||||
manager: manager,
|
||||
state: SocketStateIdle,
|
||||
readTimeout: 60 * time.Second,
|
||||
sendTimeout: 60 * time.Second,
|
||||
connectTimeout: 30 * time.Second,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
manager.TrackSocketCreated()
|
||||
return s
|
||||
}
|
||||
|
||||
// Connect 连接到指定地址(支持 yield)
|
||||
func (s *TCPSocket) Connect(host string, port int) error {
|
||||
s.mu.Lock()
|
||||
if s.state != SocketStateIdle {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("socket not idle, current state: %s", s.state)
|
||||
}
|
||||
s.state = SocketStateConnecting
|
||||
s.mu.Unlock()
|
||||
|
||||
// 解析地址
|
||||
addr, err := s.manager.TCPAddr(host, port)
|
||||
if err != nil {
|
||||
s.setState(SocketStateError)
|
||||
return fmt.Errorf("resolve address: %w", err)
|
||||
}
|
||||
s.addr = addr
|
||||
|
||||
// 开始操作
|
||||
op := s.manager.StartOperation(s, OpConnect, s.connectTimeout)
|
||||
s.currentOp = op
|
||||
|
||||
// 在 goroutine 中执行连接
|
||||
go func() {
|
||||
defer func() {
|
||||
s.currentOp = nil
|
||||
}()
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: s.connectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(context.Background(), "tcp", addr.String())
|
||||
if err != nil {
|
||||
s.setState(SocketStateError)
|
||||
s.manager.CompleteOperation(op.ID, nil, fmt.Errorf("dial: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.conn = conn
|
||||
s.state = SocketStateConnected
|
||||
s.mu.Unlock()
|
||||
|
||||
s.manager.CompleteOperation(op.ID, conn, nil)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConnectAsync 异步连接(用于 Lua yield)
|
||||
func (s *TCPSocket) ConnectAsync(L *glua.LState, host string, port int) (*SocketOperation, error) {
|
||||
err := s.Connect(host, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.currentOp, nil
|
||||
}
|
||||
|
||||
// Send 发送数据
|
||||
func (s *TCPSocket) Send(data []byte) (int, error) {
|
||||
s.mu.RLock()
|
||||
if s.state != SocketStateConnected {
|
||||
s.mu.RUnlock()
|
||||
return 0, fmt.Errorf("socket not connected, current state: %s", s.state)
|
||||
}
|
||||
conn := s.conn
|
||||
s.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return 0, fmt.Errorf("socket connection is nil")
|
||||
}
|
||||
|
||||
// 设置写超时
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(s.sendTimeout)); err != nil {
|
||||
return 0, fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
n, err := conn.Write(data)
|
||||
if err != nil {
|
||||
s.setState(SocketStateError)
|
||||
return n, fmt.Errorf("write: %w", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// SendAsync 异步发送(用于 Lua yield)
|
||||
func (s *TCPSocket) SendAsync(data []byte) (*SocketOperation, error) {
|
||||
s.mu.RLock()
|
||||
if s.state != SocketStateConnected {
|
||||
s.mu.RUnlock()
|
||||
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
|
||||
}
|
||||
conn := s.conn
|
||||
s.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("socket connection is nil")
|
||||
}
|
||||
|
||||
// 开始操作
|
||||
op := s.manager.StartOperation(s, OpSend, s.sendTimeout)
|
||||
s.currentOp = op
|
||||
s.setState(SocketStateSending)
|
||||
|
||||
// 在 goroutine 中执行发送
|
||||
go func() {
|
||||
defer func() {
|
||||
s.currentOp = nil
|
||||
s.setState(SocketStateConnected)
|
||||
}()
|
||||
|
||||
// 设置写超时
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(s.sendTimeout)); err != nil {
|
||||
s.manager.CompleteOperation(op.ID, 0, fmt.Errorf("set write deadline: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
n, err := conn.Write(data)
|
||||
if err != nil {
|
||||
s.setState(SocketStateError)
|
||||
s.manager.CompleteOperation(op.ID, n, fmt.Errorf("write: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.manager.CompleteOperation(op.ID, n, nil)
|
||||
}()
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// Receive 接收数据
|
||||
func (s *TCPSocket) Receive(size int) ([]byte, error) {
|
||||
s.mu.RLock()
|
||||
if s.state != SocketStateConnected {
|
||||
s.mu.RUnlock()
|
||||
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
|
||||
}
|
||||
conn := s.conn
|
||||
s.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("socket connection is nil")
|
||||
}
|
||||
|
||||
// 设置读超时
|
||||
if err := conn.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
|
||||
return nil, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// 默认读取大小
|
||||
if size <= 0 {
|
||||
size = 4096
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
return nil, nil // 连接关闭
|
||||
}
|
||||
s.setState(SocketStateError)
|
||||
return nil, fmt.Errorf("read: %w", err)
|
||||
}
|
||||
|
||||
return buf[:n], nil
|
||||
}
|
||||
|
||||
// ReceiveAsync 异步接收(用于 Lua yield)
|
||||
func (s *TCPSocket) ReceiveAsync(size int) (*SocketOperation, error) {
|
||||
s.mu.RLock()
|
||||
if s.state != SocketStateConnected {
|
||||
s.mu.RUnlock()
|
||||
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
|
||||
}
|
||||
conn := s.conn
|
||||
s.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("socket connection is nil")
|
||||
}
|
||||
|
||||
// 开始操作
|
||||
op := s.manager.StartOperation(s, OpReceive, s.readTimeout)
|
||||
s.currentOp = op
|
||||
s.setState(SocketStateReceiving)
|
||||
|
||||
// 在 goroutine 中执行接收
|
||||
go func() {
|
||||
defer func() {
|
||||
s.currentOp = nil
|
||||
s.setState(SocketStateConnected)
|
||||
}()
|
||||
|
||||
// 设置读超时
|
||||
if err := conn.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
|
||||
s.manager.CompleteOperation(op.ID, nil, fmt.Errorf("set read deadline: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 默认读取大小
|
||||
if size <= 0 {
|
||||
size = 4096
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
s.manager.CompleteOperation(op.ID, []byte{}, nil)
|
||||
return
|
||||
}
|
||||
s.setState(SocketStateError)
|
||||
s.manager.CompleteOperation(op.ID, nil, fmt.Errorf("read: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
s.manager.CompleteOperation(op.ID, buf[:n], nil)
|
||||
}()
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
// ReceiveUntil 读取直到特定模式
|
||||
func (s *TCPSocket) ReceiveUntil(pattern string, inclusive bool) ([]byte, error) {
|
||||
if len(pattern) == 0 {
|
||||
return nil, fmt.Errorf("pattern cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
if s.state != SocketStateConnected {
|
||||
s.mu.RUnlock()
|
||||
return nil, fmt.Errorf("socket not connected, current state: %s", s.state)
|
||||
}
|
||||
conn := s.conn
|
||||
s.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("socket connection is nil")
|
||||
}
|
||||
|
||||
// 设置读超时
|
||||
if err := conn.SetReadDeadline(time.Now().Add(s.readTimeout)); err != nil {
|
||||
return nil, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
// 使用带缓冲的读取
|
||||
var result []byte
|
||||
buf := make([]byte, 1)
|
||||
patternBytes := []byte(pattern)
|
||||
patternLen := len(patternBytes)
|
||||
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
return result, nil
|
||||
}
|
||||
s.setState(SocketStateError)
|
||||
return result, fmt.Errorf("read: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, buf[0])
|
||||
|
||||
// 检查是否匹配模式
|
||||
if len(result) >= patternLen {
|
||||
matched := true
|
||||
for i := 0; i < patternLen; i++ {
|
||||
if result[len(result)-patternLen+i] != patternBytes[i] {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
if !inclusive {
|
||||
result = result[:len(result)-patternLen]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 防止无限增长
|
||||
if len(result) > 1024*1024 { // 1MB 限制
|
||||
return result, fmt.Errorf("receive buffer exceeded 1MB limit")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭 socket
|
||||
func (s *TCPSocket) Close() error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
|
||||
return nil // 已经关闭
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 取消当前操作
|
||||
if s.currentOp != nil && !s.currentOp.IsCompleted() && s.manager != nil {
|
||||
s.manager.CompleteOperation(s.currentOp.ID, nil, fmt.Errorf("socket closed"))
|
||||
s.currentOp = nil
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
if s.conn != nil {
|
||||
s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
|
||||
s.state = SocketStateClosed
|
||||
if s.manager != nil {
|
||||
s.manager.TrackSocketClosed()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTimeout 设置超时
|
||||
func (s *TCPSocket) SetTimeout(timeout time.Duration) {
|
||||
s.readTimeout = timeout
|
||||
s.sendTimeout = timeout
|
||||
s.connectTimeout = timeout
|
||||
}
|
||||
|
||||
// SetReadTimeout 设置读取超时
|
||||
func (s *TCPSocket) SetReadTimeout(timeout time.Duration) {
|
||||
s.readTimeout = timeout
|
||||
}
|
||||
|
||||
// SetSendTimeout 设置发送超时
|
||||
func (s *TCPSocket) SetSendTimeout(timeout time.Duration) {
|
||||
s.sendTimeout = timeout
|
||||
}
|
||||
|
||||
// SetConnectTimeout 设置连接超时
|
||||
func (s *TCPSocket) SetConnectTimeout(timeout time.Duration) {
|
||||
s.connectTimeout = timeout
|
||||
}
|
||||
|
||||
// State 获取当前状态
|
||||
func (s *TCPSocket) State() SocketState {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.state
|
||||
}
|
||||
|
||||
// setState 设置状态
|
||||
func (s *TCPSocket) setState(state SocketState) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.state = state
|
||||
}
|
||||
|
||||
// IsClosed 检查是否已关闭
|
||||
func (s *TCPSocket) IsClosed() bool {
|
||||
return atomic.LoadInt32(&s.closed) == 1
|
||||
}
|
||||
|
||||
// LocalAddr 获取本地地址
|
||||
func (s *TCPSocket) LocalAddr() net.Addr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.conn != nil {
|
||||
return s.conn.LocalAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr 获取远程地址
|
||||
func (s *TCPSocket) RemoteAddr() net.Addr {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.conn != nil {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// -------------------- Lua API --------------------
|
||||
|
||||
// tcpSocketMT TCP socket 元表名称
|
||||
const tcpSocketMT = "tcp_socket"
|
||||
|
||||
// RegisterTCPSocketAPI 注册 TCP socket API
|
||||
func RegisterTCPSocketAPI(L *glua.LState, engine *LuaEngine) {
|
||||
// 创建 ngx.socket 表
|
||||
socket := L.NewTable()
|
||||
|
||||
// ngx.socket.tcp()
|
||||
socket.RawSetString("tcp", L.NewFunction(newTCPSocketFunc(engine)))
|
||||
|
||||
// 确保 ngx 表存在
|
||||
ngx := L.GetGlobal("ngx")
|
||||
var ngxTbl *glua.LTable
|
||||
if tbl, ok := ngx.(*glua.LTable); ok {
|
||||
ngxTbl = tbl
|
||||
} else {
|
||||
// 创建 ngx 表
|
||||
ngxTbl = L.NewTable()
|
||||
L.SetGlobal("ngx", ngxTbl)
|
||||
}
|
||||
ngxTbl.RawSetString("socket", socket)
|
||||
|
||||
// 注册元表
|
||||
registerTCPSocketMetaTable(L)
|
||||
}
|
||||
|
||||
// registerTCPSocketMetaTable 注册 TCP socket 元表
|
||||
func registerTCPSocketMetaTable(L *glua.LState) {
|
||||
mt := L.NewTable()
|
||||
|
||||
// __index
|
||||
index := L.NewTable()
|
||||
index.RawSetString("connect", L.NewFunction(tcpSocketConnect))
|
||||
index.RawSetString("send", L.NewFunction(tcpSocketSend))
|
||||
index.RawSetString("receive", L.NewFunction(tcpSocketReceive))
|
||||
index.RawSetString("receiveuntil", L.NewFunction(tcpSocketReceiveUntil))
|
||||
index.RawSetString("close", L.NewFunction(tcpSocketClose))
|
||||
index.RawSetString("settimeout", L.NewFunction(tcpSocketSetTimeout))
|
||||
index.RawSetString("settimeouts", L.NewFunction(tcpSocketSetTimeouts))
|
||||
|
||||
mt.RawSetString("__index", index)
|
||||
mt.RawSetString("__gc", L.NewFunction(tcpSocketGC))
|
||||
mt.RawSetString("__tostring", L.NewFunction(tcpSocketToString))
|
||||
|
||||
L.SetMetatable(L.NewTable(), mt)
|
||||
L.SetGlobal(tcpSocketMT, mt)
|
||||
}
|
||||
|
||||
// newTCPSocketFunc 创建 TCP socket
|
||||
func newTCPSocketFunc(engine *LuaEngine) func(*glua.LState) int {
|
||||
return func(L *glua.LState) int {
|
||||
socket := NewTCPSocket(DefaultCosocketManager)
|
||||
|
||||
// 创建 userdata
|
||||
ud := L.NewUserData()
|
||||
ud.Value = socket
|
||||
L.SetMetatable(ud, L.GetGlobal(tcpSocketMT).(*glua.LTable))
|
||||
|
||||
L.Push(ud)
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
// checkTCPSocket 检查并获取 TCP socket
|
||||
func checkTCPSocket(L *glua.LState, n int) *TCPSocket {
|
||||
ud := L.CheckUserData(n)
|
||||
if socket, ok := ud.Value.(*TCPSocket); ok {
|
||||
return socket
|
||||
}
|
||||
L.ArgError(n, "tcp socket expected")
|
||||
return nil
|
||||
}
|
||||
|
||||
// tcpSocketConnect tcpsock:connect()
|
||||
func tcpSocketConnect(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
host := L.CheckString(2)
|
||||
port := L.CheckInt(3)
|
||||
|
||||
opts := L.OptTable(4, nil)
|
||||
timeout := 30000 // 默认 30 秒
|
||||
if opts != nil {
|
||||
if t := opts.RawGetString("timeout"); t != glua.LNil {
|
||||
timeout = int(glua.LVAsNumber(t))
|
||||
}
|
||||
}
|
||||
|
||||
// 设置超时
|
||||
socket.SetConnectTimeout(time.Duration(timeout) * time.Millisecond)
|
||||
|
||||
// 开始异步连接
|
||||
op, err := socket.ConnectAsync(L, host, port)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
// yield 等待连接完成
|
||||
L.Push(glua.LString("cosocket_connect"))
|
||||
L.Push(glua.LNumber(op.ID))
|
||||
// TODO: 实现真正的非阻塞 yield,目前使用同步模拟
|
||||
return 2
|
||||
}
|
||||
|
||||
// tcpSocketSend tcpsock:send()
|
||||
func tcpSocketSend(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
data := L.CheckString(2)
|
||||
|
||||
opts := L.OptTable(3, nil)
|
||||
timeout := 60000 // 默认 60 秒
|
||||
if opts != nil {
|
||||
if t := opts.RawGetString("timeout"); t != glua.LNil {
|
||||
timeout = int(glua.LVAsNumber(t))
|
||||
}
|
||||
}
|
||||
|
||||
// 设置超时
|
||||
socket.SetSendTimeout(time.Duration(timeout) * time.Millisecond)
|
||||
|
||||
// 开始异步发送
|
||||
op, err := socket.SendAsync([]byte(data))
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
// yield 等待发送完成
|
||||
L.Push(glua.LString("cosocket_send"))
|
||||
L.Push(glua.LNumber(op.ID))
|
||||
// TODO: 实现真正的非阻塞 yield,目前使用同步模拟
|
||||
return 2
|
||||
}
|
||||
|
||||
// tcpSocketReceive tcpsock:receive()
|
||||
func tcpSocketReceive(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
|
||||
// 解析参数
|
||||
var size int
|
||||
opts := L.NewTable()
|
||||
|
||||
if L.GetTop() >= 2 {
|
||||
switch v := L.Get(2).(type) {
|
||||
case glua.LNumber:
|
||||
size = int(v)
|
||||
case *glua.LTable:
|
||||
opts = v
|
||||
default:
|
||||
// 检查是否为字符串类型
|
||||
if L.Get(2).Type() == glua.LTString {
|
||||
// 接收特定模式
|
||||
return tcpSocketReceivePattern(L, socket, L.CheckString(2), opts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if L.GetTop() >= 3 {
|
||||
if t, ok := L.Get(3).(*glua.LTable); ok {
|
||||
opts = t
|
||||
}
|
||||
}
|
||||
|
||||
// 获取超时
|
||||
timeout := 60000 // 默认 60 秒
|
||||
if t := opts.RawGetString("timeout"); t != glua.LNil {
|
||||
timeout = int(glua.LVAsNumber(t))
|
||||
}
|
||||
|
||||
// 设置超时
|
||||
socket.SetReadTimeout(time.Duration(timeout) * time.Millisecond)
|
||||
|
||||
// 开始异步接收
|
||||
op, err := socket.ReceiveAsync(size)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
// yield 等待接收完成
|
||||
L.Push(glua.LString("cosocket_receive"))
|
||||
L.Push(glua.LNumber(op.ID))
|
||||
// TODO: 实现真正的非阻塞 yield,目前使用同步模拟
|
||||
return 2
|
||||
}
|
||||
|
||||
// tcpSocketReceivePattern 按模式接收
|
||||
func tcpSocketReceivePattern(L *glua.LState, socket *TCPSocket, pattern string, opts *glua.LTable) int {
|
||||
switch pattern {
|
||||
case "*l":
|
||||
// 读取一行
|
||||
data, err := socket.ReceiveUntil("\n", true)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LString(string(data)))
|
||||
return 1
|
||||
case "*a":
|
||||
// 读取所有(这里简化为读取最大 64KB)
|
||||
data, err := socket.Receive(64 * 1024)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LString(string(data)))
|
||||
return 1
|
||||
default:
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString("unknown pattern: " + pattern))
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
// tcpSocketReceiveUntil tcpsock:receiveuntil()
|
||||
func tcpSocketReceiveUntil(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
pattern := L.CheckString(2)
|
||||
|
||||
opts := L.OptTable(3, nil)
|
||||
inclusive := false
|
||||
if opts != nil {
|
||||
if inc := opts.RawGetString("inclusive"); inc != glua.LNil {
|
||||
inclusive = glua.LVAsBool(inc)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := socket.ReceiveUntil(pattern, inclusive)
|
||||
if err != nil {
|
||||
L.Push(glua.LNil)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
|
||||
// 创建迭代器函数
|
||||
iter := L.NewFunction(func(L *glua.LState) int {
|
||||
if len(data) == 0 {
|
||||
L.Push(glua.LNil)
|
||||
return 1
|
||||
}
|
||||
L.Push(glua.LString(string(data)))
|
||||
data = nil // 清空,下次返回 nil
|
||||
return 1
|
||||
})
|
||||
|
||||
L.Push(iter)
|
||||
return 1
|
||||
}
|
||||
|
||||
// tcpSocketClose tcpsock:close()
|
||||
func tcpSocketClose(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
if err := socket.Close(); err != nil {
|
||||
L.Push(glua.LFalse)
|
||||
L.Push(glua.LString(err.Error()))
|
||||
return 2
|
||||
}
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// tcpSocketSetTimeout tcpsock:settimeout()
|
||||
func tcpSocketSetTimeout(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
timeout := L.CheckNumber(2)
|
||||
socket.SetTimeout(time.Duration(timeout) * time.Millisecond)
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// tcpSocketSetTimeouts tcpsock:settimeouts()
|
||||
func tcpSocketSetTimeouts(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
connectTimeout := L.CheckNumber(2)
|
||||
sendTimeout := L.CheckNumber(3)
|
||||
readTimeout := L.CheckNumber(4)
|
||||
|
||||
socket.SetConnectTimeout(time.Duration(connectTimeout) * time.Millisecond)
|
||||
socket.SetSendTimeout(time.Duration(sendTimeout) * time.Millisecond)
|
||||
socket.SetReadTimeout(time.Duration(readTimeout) * time.Millisecond)
|
||||
|
||||
L.Push(glua.LTrue)
|
||||
return 1
|
||||
}
|
||||
|
||||
// tcpSocketGC __gc 元方法
|
||||
func tcpSocketGC(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
socket.Close()
|
||||
return 0
|
||||
}
|
||||
|
||||
// tcpSocketToString __tostring 元方法
|
||||
func tcpSocketToString(L *glua.LState) int {
|
||||
socket := checkTCPSocket(L, 1)
|
||||
state := socket.State()
|
||||
L.Push(glua.LString(fmt.Sprintf("tcp_socket(%s)", state)))
|
||||
return 1
|
||||
}
|
||||
|
||||
// HandleCosocketYield 处理 cosocket yield
|
||||
func (c *LuaCoroutine) HandleCosocketYield(reason string, values []glua.LValue) ([]glua.LValue, error) {
|
||||
switch reason {
|
||||
case "cosocket_connect":
|
||||
return c.handleCosocketConnect(values)
|
||||
case "cosocket_send":
|
||||
return c.handleCosocketSend(values)
|
||||
case "cosocket_receive":
|
||||
return c.handleCosocketReceive(values)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown cosocket yield reason: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCosocketConnect 处理连接 yield
|
||||
func (c *LuaCoroutine) handleCosocketConnect(values []glua.LValue) ([]glua.LValue, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, fmt.Errorf("cosocket_connect requires operation ID")
|
||||
}
|
||||
|
||||
opID := uint64(glua.LVAsNumber(values[0]))
|
||||
op := DefaultCosocketManager.GetOperation(opID)
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation %d not found", opID)
|
||||
}
|
||||
|
||||
// 等待操作完成
|
||||
result, err := op.Wait(c.ExecutionContext)
|
||||
if err != nil {
|
||||
return []glua.LValue{glua.LNil, glua.LString(err.Error())}, nil
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return []glua.LValue{glua.LNil, glua.LNil}, nil
|
||||
}
|
||||
|
||||
_ = result // 连接成功,返回 1
|
||||
return []glua.LValue{glua.LNumber(1)}, nil
|
||||
}
|
||||
|
||||
// handleCosocketSend 处理发送 yield
|
||||
func (c *LuaCoroutine) handleCosocketSend(values []glua.LValue) ([]glua.LValue, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, fmt.Errorf("cosocket_send requires operation ID")
|
||||
}
|
||||
|
||||
opID := uint64(glua.LVAsNumber(values[0]))
|
||||
op := DefaultCosocketManager.GetOperation(opID)
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation %d not found", opID)
|
||||
}
|
||||
|
||||
// 等待操作完成
|
||||
result, err := op.Wait(c.ExecutionContext)
|
||||
if err != nil {
|
||||
return []glua.LValue{glua.LNil, glua.LString(err.Error())}, nil
|
||||
}
|
||||
|
||||
if n, ok := result.(int); ok {
|
||||
return []glua.LValue{glua.LNumber(n)}, nil
|
||||
}
|
||||
|
||||
return []glua.LValue{glua.LNil, glua.LString("invalid result")}, nil
|
||||
}
|
||||
|
||||
// handleCosocketReceive 处理接收 yield
|
||||
func (c *LuaCoroutine) handleCosocketReceive(values []glua.LValue) ([]glua.LValue, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, fmt.Errorf("cosocket_receive requires operation ID")
|
||||
}
|
||||
|
||||
opID := uint64(glua.LVAsNumber(values[0]))
|
||||
op := DefaultCosocketManager.GetOperation(opID)
|
||||
if op == nil {
|
||||
return nil, fmt.Errorf("operation %d not found", opID)
|
||||
}
|
||||
|
||||
// 等待操作完成
|
||||
result, err := op.Wait(c.ExecutionContext)
|
||||
if err != nil {
|
||||
return []glua.LValue{glua.LNil, glua.LString(err.Error())}, nil
|
||||
}
|
||||
|
||||
if data, ok := result.([]byte); ok {
|
||||
if len(data) == 0 {
|
||||
return []glua.LValue{glua.LNil, glua.LString("closed")}, nil
|
||||
}
|
||||
return []glua.LValue{glua.LString(string(data))}, nil
|
||||
}
|
||||
|
||||
return []glua.LValue{glua.LNil, glua.LString("invalid result")}, nil
|
||||
}
|
||||
1509
internal/lua/filter_phase_test.go
Normal file
1509
internal/lua/filter_phase_test.go
Normal file
File diff suppressed because it is too large
Load Diff
583
internal/lua/filter_writer.go
Normal file
583
internal/lua/filter_writer.go
Normal file
@ -0,0 +1,583 @@
|
||||
// Package lua 提供 Lua 脚本嵌入能力
|
||||
package lua
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// ResponseInterceptor 响应拦截器
|
||||
// 用于延迟 header 写入,允许在发送前修改响应
|
||||
type ResponseInterceptor struct {
|
||||
// 原始请求上下文
|
||||
ctx *fasthttp.RequestCtx
|
||||
|
||||
// Header 修改回调(Lua 执行)
|
||||
headerFilterFunc func() error
|
||||
|
||||
// Body 修改回调(Lua 执行)
|
||||
bodyFilterFunc func([]byte) ([]byte, error)
|
||||
|
||||
// 缓冲的 body 数据
|
||||
bodyBuffer []byte
|
||||
|
||||
// 是否已写入 header
|
||||
headersWritten bool
|
||||
|
||||
// 是否已拦截
|
||||
intercepted bool
|
||||
|
||||
// 状态码(可修改)
|
||||
statusCode int
|
||||
|
||||
// 自定义 header(可修改)
|
||||
customHeaders map[string]string
|
||||
|
||||
// 需要删除的 header
|
||||
headersToDelete []string
|
||||
|
||||
// 并发保护
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewResponseInterceptor 创建响应拦截器
|
||||
func NewResponseInterceptor(ctx *fasthttp.RequestCtx) *ResponseInterceptor {
|
||||
return &ResponseInterceptor{
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
customHeaders: make(map[string]string),
|
||||
headersToDelete: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// SetHeaderFilter 设置 header 过滤器回调
|
||||
func (ri *ResponseInterceptor) SetHeaderFilter(fn func() error) {
|
||||
ri.headerFilterFunc = fn
|
||||
}
|
||||
|
||||
// SetBodyFilter 设置 body 过滤器回调
|
||||
func (ri *ResponseInterceptor) SetBodyFilter(fn func([]byte) ([]byte, error)) {
|
||||
ri.bodyFilterFunc = fn
|
||||
}
|
||||
|
||||
// SetStatusCode 设置状态码(延迟生效)
|
||||
func (ri *ResponseInterceptor) SetStatusCode(code int) {
|
||||
ri.statusCode = code
|
||||
}
|
||||
|
||||
// GetStatusCode 获取当前状态码
|
||||
func (ri *ResponseInterceptor) GetStatusCode() int {
|
||||
return ri.statusCode
|
||||
}
|
||||
|
||||
// SetHeader 设置 header(延迟生效)
|
||||
func (ri *ResponseInterceptor) SetHeader(key, value string) {
|
||||
ri.mu.Lock()
|
||||
defer ri.mu.Unlock()
|
||||
ri.customHeaders[key] = value
|
||||
}
|
||||
|
||||
// GetHeader 获取原始 header 值
|
||||
func (ri *ResponseInterceptor) GetHeader(key string) []byte {
|
||||
return ri.ctx.Response.Header.Peek(key)
|
||||
}
|
||||
|
||||
// DelHeader 删除 header(延迟生效)
|
||||
func (ri *ResponseInterceptor) DelHeader(key string) {
|
||||
ri.headersToDelete = append(ri.headersToDelete, key)
|
||||
}
|
||||
|
||||
// Write 拦截写入操作(缓冲 body,延迟 header 发送)
|
||||
func (ri *ResponseInterceptor) Write(p []byte) (int, error) {
|
||||
if !ri.intercepted {
|
||||
// 未启用拦截,直接写入
|
||||
return ri.ctx.Write(p)
|
||||
}
|
||||
|
||||
// 缓冲 body 数据
|
||||
ri.bodyBuffer = append(ri.bodyBuffer, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// WriteString 写入字符串
|
||||
func (ri *ResponseInterceptor) WriteString(s string) (int, error) {
|
||||
return ri.Write([]byte(s))
|
||||
}
|
||||
|
||||
// SetBody 设置 body(延迟发送)
|
||||
func (ri *ResponseInterceptor) SetBody(body []byte) {
|
||||
if !ri.intercepted {
|
||||
ri.ctx.SetBody(body)
|
||||
return
|
||||
}
|
||||
ri.bodyBuffer = body
|
||||
}
|
||||
|
||||
// SetBodyString 设置字符串 body
|
||||
func (ri *ResponseInterceptor) SetBodyString(body string) {
|
||||
ri.SetBody([]byte(body))
|
||||
}
|
||||
|
||||
// Flush 执行 header/body filter 并发送响应
|
||||
func (ri *ResponseInterceptor) Flush() error {
|
||||
if ri.headersWritten {
|
||||
return nil // 已经发送过
|
||||
}
|
||||
ri.headersWritten = true
|
||||
|
||||
// 1. 执行 header filter
|
||||
if ri.headerFilterFunc != nil {
|
||||
if err := ri.headerFilterFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 应用 header 修改
|
||||
ri.ctx.Response.SetStatusCode(ri.statusCode)
|
||||
for key, value := range ri.customHeaders {
|
||||
ri.ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
for _, key := range ri.headersToDelete {
|
||||
ri.ctx.Response.Header.Del(key)
|
||||
}
|
||||
|
||||
// 3. 执行 body filter
|
||||
body := ri.bodyBuffer
|
||||
if ri.bodyFilterFunc != nil && len(body) > 0 {
|
||||
modified, err := ri.bodyFilterFunc(body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = modified
|
||||
}
|
||||
|
||||
// 4. 发送响应
|
||||
if len(body) > 0 {
|
||||
ri.ctx.SetBody(body)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Enable 启用拦截模式
|
||||
func (ri *ResponseInterceptor) Enable() {
|
||||
ri.intercepted = true
|
||||
}
|
||||
|
||||
// Disable 禁用拦截模式
|
||||
func (ri *ResponseInterceptor) Disable() {
|
||||
ri.intercepted = false
|
||||
}
|
||||
|
||||
// IsEnabled 检查是否启用拦截
|
||||
func (ri *ResponseInterceptor) IsEnabled() bool {
|
||||
return ri.intercepted
|
||||
}
|
||||
|
||||
// GetBufferedBody 获取当前缓冲的 body
|
||||
func (ri *ResponseInterceptor) GetBufferedBody() []byte {
|
||||
return ri.bodyBuffer
|
||||
}
|
||||
|
||||
// ClearBody 清空 body 缓冲
|
||||
func (ri *ResponseInterceptor) ClearBody() {
|
||||
ri.bodyBuffer = nil
|
||||
}
|
||||
|
||||
// DelayedResponseWriter 延迟响应写入器
|
||||
// 包装 fasthttp.RequestCtx 提供延迟写入能力
|
||||
type DelayedResponseWriter struct {
|
||||
ctx *fasthttp.RequestCtx
|
||||
interceptor *ResponseInterceptor
|
||||
pool *sync.Pool
|
||||
}
|
||||
|
||||
// NewDelayedResponseWriter 创建延迟响应写入器
|
||||
func NewDelayedResponseWriter(ctx *fasthttp.RequestCtx) *DelayedResponseWriter {
|
||||
return &DelayedResponseWriter{
|
||||
ctx: ctx,
|
||||
interceptor: NewResponseInterceptor(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
// EnableFilterPhase 启用 filter phase
|
||||
func (drw *DelayedResponseWriter) EnableFilterPhase() {
|
||||
drw.interceptor.Enable()
|
||||
}
|
||||
|
||||
// DisableFilterPhase 禁用 filter phase
|
||||
func (drw *DelayedResponseWriter) DisableFilterPhase() {
|
||||
drw.interceptor.Disable()
|
||||
}
|
||||
|
||||
// GetInterceptor 获取响应拦截器
|
||||
func (drw *DelayedResponseWriter) GetInterceptor() *ResponseInterceptor {
|
||||
return drw.interceptor
|
||||
}
|
||||
|
||||
// HeaderFilter 执行 header filter 阶段
|
||||
func (drw *DelayedResponseWriter) HeaderFilter(script string, luaCtx *LuaContext) error {
|
||||
if !drw.interceptor.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
luaCtx.SetPhase(PhaseHeaderFilter)
|
||||
drw.interceptor.SetHeaderFilter(func() error {
|
||||
return luaCtx.Execute(script)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// BodyFilter 执行 body filter 阶段
|
||||
func (drw *DelayedResponseWriter) BodyFilter(script string, luaCtx *LuaContext) error {
|
||||
if !drw.interceptor.IsEnabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
luaCtx.SetPhase(PhaseBodyFilter)
|
||||
drw.interceptor.SetBodyFilter(func(body []byte) ([]byte, error) {
|
||||
// 将 body 设置到 Lua 上下文中
|
||||
luaCtx.OutputBuffer = body
|
||||
if err := luaCtx.Execute(script); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return luaCtx.OutputBuffer, nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush 刷新响应
|
||||
func (drw *DelayedResponseWriter) Flush() error {
|
||||
return drw.interceptor.Flush()
|
||||
}
|
||||
|
||||
// Write 实现 io.Writer
|
||||
func (drw *DelayedResponseWriter) Write(p []byte) (int, error) {
|
||||
return drw.interceptor.Write(p)
|
||||
}
|
||||
|
||||
// WriteString 写入字符串
|
||||
func (drw *DelayedResponseWriter) WriteString(s string) (int, error) {
|
||||
return drw.interceptor.WriteString(s)
|
||||
}
|
||||
|
||||
// SetStatusCode 设置状态码
|
||||
func (drw *DelayedResponseWriter) SetStatusCode(code int) {
|
||||
drw.interceptor.SetStatusCode(code)
|
||||
}
|
||||
|
||||
// SetBody 设置 body
|
||||
func (drw *DelayedResponseWriter) SetBody(body []byte) {
|
||||
drw.interceptor.SetBody(body)
|
||||
}
|
||||
|
||||
// SetBodyString 设置字符串 body
|
||||
func (drw *DelayedResponseWriter) SetBodyString(body string) {
|
||||
drw.interceptor.SetBodyString(body)
|
||||
}
|
||||
|
||||
// SetHeader 设置 header
|
||||
func (drw *DelayedResponseWriter) SetHeader(key, value string) {
|
||||
drw.interceptor.SetHeader(key, value)
|
||||
}
|
||||
|
||||
// GetHeader 获取 header
|
||||
func (drw *DelayedResponseWriter) GetHeader(key string) []byte {
|
||||
return drw.interceptor.GetHeader(key)
|
||||
}
|
||||
|
||||
// DelHeader 删除 header
|
||||
func (drw *DelayedResponseWriter) DelHeader(key string) {
|
||||
drw.interceptor.DelHeader(key)
|
||||
}
|
||||
|
||||
// ResponseInterceptorPool 响应拦截器池
|
||||
var ResponseInterceptorPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &ResponseInterceptor{}
|
||||
},
|
||||
}
|
||||
|
||||
// AcquireResponseInterceptor 从池中获取拦截器
|
||||
func AcquireResponseInterceptor(ctx *fasthttp.RequestCtx) *ResponseInterceptor {
|
||||
ri := ResponseInterceptorPool.Get().(*ResponseInterceptor)
|
||||
ri.ctx = ctx
|
||||
ri.statusCode = 200
|
||||
ri.customHeaders = make(map[string]string)
|
||||
ri.headersToDelete = make([]string, 0)
|
||||
ri.bodyBuffer = nil
|
||||
ri.headersWritten = false
|
||||
ri.intercepted = true
|
||||
ri.headerFilterFunc = nil
|
||||
ri.bodyFilterFunc = nil
|
||||
return ri
|
||||
}
|
||||
|
||||
// ReleaseResponseInterceptor 释放拦截器回池
|
||||
func ReleaseResponseInterceptor(ri *ResponseInterceptor) {
|
||||
if ri == nil {
|
||||
return
|
||||
}
|
||||
// 清理状态
|
||||
ri.ctx = nil
|
||||
ri.headerFilterFunc = nil
|
||||
ri.bodyFilterFunc = nil
|
||||
ri.bodyBuffer = nil
|
||||
ri.customHeaders = nil
|
||||
ri.headersToDelete = nil
|
||||
ResponseInterceptorPool.Put(ri)
|
||||
}
|
||||
|
||||
// responseWriterWrapper 适配 fasthttp.ResponseWriter 接口
|
||||
type responseWriterWrapper struct {
|
||||
interceptor *ResponseInterceptor
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) Write(p []byte) (n int, err error) {
|
||||
return w.interceptor.Write(p)
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) Header() map[string][]string {
|
||||
// fasthttp 不兼容 http.Header,返回 nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
|
||||
w.interceptor.SetStatusCode(statusCode)
|
||||
}
|
||||
|
||||
// Hijack 支持连接劫持(用于 WebSocket)
|
||||
func (drw *DelayedResponseWriter) Hijack(handler fasthttp.HijackHandler) {
|
||||
drw.ctx.Hijack(handler)
|
||||
}
|
||||
|
||||
// Hijacked 检查是否已劫持
|
||||
func (drw *DelayedResponseWriter) Hijacked() bool {
|
||||
return drw.ctx.Hijacked()
|
||||
}
|
||||
|
||||
// LocalAddr 获取本地地址
|
||||
func (drw *DelayedResponseWriter) LocalAddr() net.Addr {
|
||||
return drw.ctx.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr 获取远程地址
|
||||
func (drw *DelayedResponseWriter) RemoteAddr() net.Addr {
|
||||
return drw.ctx.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetConnectionClose 设置连接关闭
|
||||
func (drw *DelayedResponseWriter) SetConnectionClose() {
|
||||
drw.ctx.Response.SetConnectionClose()
|
||||
}
|
||||
|
||||
// BodyWriter 返回 body 写入器
|
||||
func (drw *DelayedResponseWriter) BodyWriter() io.Writer {
|
||||
return &responseWriterAdapter{interceptor: drw.interceptor}
|
||||
}
|
||||
|
||||
// responseWriterAdapter 适配 io.Writer
|
||||
type responseWriterAdapter struct {
|
||||
interceptor *ResponseInterceptor
|
||||
}
|
||||
|
||||
func (rwa *responseWriterAdapter) Write(p []byte) (n int, err error) {
|
||||
return rwa.interceptor.Write(p)
|
||||
}
|
||||
|
||||
// ResponseStats 响应统计信息
|
||||
type ResponseStats struct {
|
||||
BufferedBytes int
|
||||
HeadersModified int
|
||||
HeadersDeleted int
|
||||
BodyModified bool
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
// GetStats 获取响应统计
|
||||
func (drw *DelayedResponseWriter) GetStats() ResponseStats {
|
||||
return ResponseStats{
|
||||
BufferedBytes: len(drw.interceptor.bodyBuffer),
|
||||
HeadersModified: len(drw.interceptor.customHeaders),
|
||||
HeadersDeleted: len(drw.interceptor.headersToDelete),
|
||||
BodyModified: drw.interceptor.bodyFilterFunc != nil,
|
||||
StatusCode: drw.interceptor.statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// IsBodyBuffered 检查 body 是否被缓冲
|
||||
func (drw *DelayedResponseWriter) IsBodyBuffered() bool {
|
||||
return len(drw.interceptor.bodyBuffer) > 0
|
||||
}
|
||||
|
||||
// GetBufferedBodySize 获取缓冲的 body 大小
|
||||
func (drw *DelayedResponseWriter) GetBufferedBodySize() int {
|
||||
return len(drw.interceptor.bodyBuffer)
|
||||
}
|
||||
|
||||
// Reset 重置写入器状态
|
||||
func (drw *DelayedResponseWriter) Reset() {
|
||||
drw.interceptor.bodyBuffer = nil
|
||||
drw.interceptor.headersWritten = false
|
||||
drw.interceptor.statusCode = 200
|
||||
drw.interceptor.customHeaders = make(map[string]string)
|
||||
drw.interceptor.headersToDelete = make([]string, 0)
|
||||
}
|
||||
|
||||
// SetBodyStream 设置 body 流
|
||||
func (drw *DelayedResponseWriter) SetBodyStream(bodyStream io.Reader, bodySize int) {
|
||||
if !drw.interceptor.IsEnabled() {
|
||||
drw.ctx.SetBodyStream(bodyStream, bodySize)
|
||||
return
|
||||
}
|
||||
// 流式 body 无法缓冲,直接设置
|
||||
// 但在设置前应用 header filter
|
||||
if drw.interceptor.headerFilterFunc != nil {
|
||||
_ = drw.interceptor.headerFilterFunc()
|
||||
}
|
||||
drw.ctx.SetBodyStream(bodyStream, bodySize)
|
||||
drw.interceptor.headersWritten = true
|
||||
}
|
||||
|
||||
// SendFile 发送文件
|
||||
func (drw *DelayedResponseWriter) SendFile(path string) error {
|
||||
if !drw.interceptor.IsEnabled() {
|
||||
drw.ctx.SendFile(path)
|
||||
return nil
|
||||
}
|
||||
// 文件发送前应用 header filter
|
||||
if drw.interceptor.headerFilterFunc != nil {
|
||||
if err := drw.interceptor.headerFilterFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// 应用修改的 headers
|
||||
drw.ctx.Response.SetStatusCode(drw.interceptor.statusCode)
|
||||
for key, value := range drw.interceptor.customHeaders {
|
||||
drw.ctx.Response.Header.Set(key, value)
|
||||
}
|
||||
for _, key := range drw.interceptor.headersToDelete {
|
||||
drw.ctx.Response.Header.Del(key)
|
||||
}
|
||||
drw.ctx.SendFile(path)
|
||||
drw.interceptor.headersWritten = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Redirect 重定向
|
||||
func (drw *DelayedResponseWriter) Redirect(uri string, statusCode int) {
|
||||
if !drw.interceptor.IsEnabled() {
|
||||
drw.ctx.Redirect(uri, statusCode)
|
||||
return
|
||||
}
|
||||
// 重定向前应用 header filter
|
||||
if drw.interceptor.headerFilterFunc != nil {
|
||||
_ = drw.interceptor.headerFilterFunc()
|
||||
}
|
||||
drw.ctx.Redirect(uri, statusCode)
|
||||
drw.interceptor.headersWritten = true
|
||||
}
|
||||
|
||||
// bufferPool body 缓冲区池
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 0, 4096) // 4KB 初始容量
|
||||
},
|
||||
}
|
||||
|
||||
// acquireBuffer 获取缓冲区
|
||||
func acquireBuffer() []byte {
|
||||
return bufferPool.Get().([]byte)
|
||||
}
|
||||
|
||||
// releaseBuffer 释放缓冲区
|
||||
func releaseBuffer(buf []byte) {
|
||||
if buf != nil && cap(buf) <= 65536 { // 只回收小缓冲区
|
||||
bufferPool.Put(buf[:0])
|
||||
}
|
||||
}
|
||||
|
||||
// BufferedWriter 带缓冲的写入器
|
||||
type BufferedWriter struct {
|
||||
buf []byte
|
||||
size int
|
||||
flushFunc func([]byte) error
|
||||
maxSize int
|
||||
autoFlush bool
|
||||
}
|
||||
|
||||
// NewBufferedWriter 创建缓冲写入器
|
||||
func NewBufferedWriter(maxSize int, flushFunc func([]byte) error) *BufferedWriter {
|
||||
return &BufferedWriter{
|
||||
buf: acquireBuffer(),
|
||||
maxSize: maxSize,
|
||||
flushFunc: flushFunc,
|
||||
autoFlush: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Write 写入数据
|
||||
func (bw *BufferedWriter) Write(p []byte) (int, error) {
|
||||
if bw.buf == nil {
|
||||
bw.buf = acquireBuffer()
|
||||
}
|
||||
|
||||
// 检查是否需要扩容
|
||||
if len(bw.buf)+len(p) > cap(bw.buf) {
|
||||
// 扩容
|
||||
newCap := cap(bw.buf) * 2
|
||||
if newCap < len(bw.buf)+len(p) {
|
||||
newCap = len(bw.buf) + len(p)
|
||||
}
|
||||
newBuf := make([]byte, len(bw.buf), newCap)
|
||||
copy(newBuf, bw.buf)
|
||||
releaseBuffer(bw.buf)
|
||||
bw.buf = newBuf
|
||||
}
|
||||
|
||||
bw.buf = append(bw.buf, p...)
|
||||
|
||||
// 自动刷新检查
|
||||
if bw.autoFlush && bw.maxSize > 0 && len(bw.buf) >= bw.maxSize {
|
||||
if err := bw.Flush(); err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Flush 刷新缓冲区
|
||||
func (bw *BufferedWriter) Flush() error {
|
||||
if bw.flushFunc == nil || len(bw.buf) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := bw.flushFunc(bw.buf); err != nil {
|
||||
return err
|
||||
}
|
||||
bw.buf = bw.buf[:0]
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭写入器
|
||||
func (bw *BufferedWriter) Close() error {
|
||||
err := bw.Flush()
|
||||
if bw.buf != nil {
|
||||
releaseBuffer(bw.buf)
|
||||
bw.buf = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Size 返回当前缓冲区大小
|
||||
func (bw *BufferedWriter) Size() int {
|
||||
return len(bw.buf)
|
||||
}
|
||||
|
||||
// Bytes 返回当前缓冲区内容(不消费)
|
||||
func (bw *BufferedWriter) Bytes() []byte {
|
||||
return bw.buf
|
||||
}
|
||||
351
internal/lua/socket_manager.go
Normal file
351
internal/lua/socket_manager.go
Normal file
@ -0,0 +1,351 @@
|
||||
// Package lua 提供 Cosocket 管理功能
|
||||
package lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SocketState 表示 socket 操作状态
|
||||
type SocketState int
|
||||
|
||||
const (
|
||||
// SocketStateIdle 空闲状态
|
||||
SocketStateIdle SocketState = iota
|
||||
// SocketStateConnecting 连接中
|
||||
SocketStateConnecting
|
||||
// SocketStateConnected 已连接
|
||||
SocketStateConnected
|
||||
// SocketStateSending 发送中
|
||||
SocketStateSending
|
||||
// SocketStateReceiving 接收中
|
||||
SocketStateReceiving
|
||||
// SocketStateClosing 关闭中
|
||||
SocketStateClosing
|
||||
// SocketStateClosed 已关闭
|
||||
SocketStateClosed
|
||||
// SocketStateError 错误状态
|
||||
SocketStateError
|
||||
)
|
||||
|
||||
func (s SocketState) String() string {
|
||||
switch s {
|
||||
case SocketStateIdle:
|
||||
return "idle"
|
||||
case SocketStateConnecting:
|
||||
return "connecting"
|
||||
case SocketStateConnected:
|
||||
return "connected"
|
||||
case SocketStateSending:
|
||||
return "sending"
|
||||
case SocketStateReceiving:
|
||||
return "receiving"
|
||||
case SocketStateClosing:
|
||||
return "closing"
|
||||
case SocketStateClosed:
|
||||
return "closed"
|
||||
case SocketStateError:
|
||||
return "error"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// OperationType 操作类型
|
||||
type OperationType string
|
||||
|
||||
const (
|
||||
// OpConnect 连接操作
|
||||
OpConnect OperationType = "connect"
|
||||
// OpSend 发送操作
|
||||
OpSend OperationType = "send"
|
||||
// OpReceive 接收操作
|
||||
OpReceive OperationType = "receive"
|
||||
// OpClose 关闭操作
|
||||
OpClose OperationType = "close"
|
||||
)
|
||||
|
||||
// SocketOperation 表示一个 socket 操作
|
||||
type SocketOperation struct {
|
||||
// 操作 ID
|
||||
ID uint64
|
||||
|
||||
// 关联的 socket
|
||||
Socket *TCPSocket
|
||||
|
||||
// 操作类型
|
||||
Type OperationType
|
||||
|
||||
// 当前状态
|
||||
State SocketState
|
||||
|
||||
// 创建时间
|
||||
CreatedAt time.Time
|
||||
|
||||
// 最后活动时间
|
||||
LastActivity time.Time
|
||||
|
||||
// 超时时间
|
||||
Timeout time.Duration
|
||||
|
||||
// 完成通道
|
||||
Done chan struct{}
|
||||
|
||||
// 错误信息
|
||||
Error error
|
||||
|
||||
// 结果数据
|
||||
Result interface{}
|
||||
|
||||
// 是否完成
|
||||
completed int32
|
||||
}
|
||||
|
||||
// IsCompleted 检查操作是否已完成
|
||||
func (op *SocketOperation) IsCompleted() bool {
|
||||
return atomic.LoadInt32(&op.completed) == 1
|
||||
}
|
||||
|
||||
// Complete 标记操作完成
|
||||
func (op *SocketOperation) Complete(result interface{}, err error) {
|
||||
if atomic.CompareAndSwapInt32(&op.completed, 0, 1) {
|
||||
op.Result = result
|
||||
op.Error = err
|
||||
close(op.Done)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait 等待操作完成
|
||||
func (op *SocketOperation) Wait(ctx context.Context) (interface{}, error) {
|
||||
select {
|
||||
case <-op.Done:
|
||||
return op.Result, op.Error
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Touch 更新活动时间
|
||||
func (op *SocketOperation) Touch() {
|
||||
op.LastActivity = time.Now()
|
||||
}
|
||||
|
||||
// CosocketStats Cosocket 统计信息
|
||||
type CosocketStats struct {
|
||||
// 总操作数
|
||||
TotalOperations uint64
|
||||
|
||||
// 活跃操作数
|
||||
ActiveOperations uint64
|
||||
|
||||
// 超时操作数
|
||||
TimeoutOperations uint64
|
||||
|
||||
// 错误操作数
|
||||
ErrorOperations uint64
|
||||
|
||||
// 当前 socket 数
|
||||
ActiveSockets uint64
|
||||
|
||||
// 总创建 socket 数
|
||||
TotalSocketsCreated uint64
|
||||
|
||||
// 总关闭 socket 数
|
||||
TotalSocketsClosed uint64
|
||||
}
|
||||
|
||||
// CosocketManager Cosocket 管理器
|
||||
type CosocketManager struct {
|
||||
// 操作映射
|
||||
operations map[uint64]*SocketOperation
|
||||
|
||||
// 互斥锁
|
||||
mu sync.RWMutex
|
||||
|
||||
// 操作 ID 生成器
|
||||
nextID uint64
|
||||
|
||||
// 超时检查器
|
||||
timeoutChecker *time.Ticker
|
||||
|
||||
// 统计信息
|
||||
stats CosocketStats
|
||||
|
||||
// 上下文
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// 默认超时
|
||||
defaultTimeout time.Duration
|
||||
|
||||
// 清理间隔
|
||||
cleanupInterval time.Duration
|
||||
}
|
||||
|
||||
// DefaultCosocketManager 全局默认管理器
|
||||
var DefaultCosocketManager = NewCosocketManager()
|
||||
|
||||
// NewCosocketManager 创建新的 Cosocket 管理器
|
||||
func NewCosocketManager() *CosocketManager {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cm := &CosocketManager{
|
||||
operations: make(map[uint64]*SocketOperation),
|
||||
nextID: 0,
|
||||
timeoutChecker: time.NewTicker(30 * time.Second),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
defaultTimeout: 60 * time.Second,
|
||||
cleanupInterval: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 启动清理循环
|
||||
go cm.cleanupLoop()
|
||||
|
||||
return cm
|
||||
}
|
||||
|
||||
// StartOperation 开始一个新的 socket 操作
|
||||
func (cm *CosocketManager) StartOperation(socket *TCPSocket, opType OperationType, timeout time.Duration) *SocketOperation {
|
||||
if timeout <= 0 {
|
||||
timeout = cm.defaultTimeout
|
||||
}
|
||||
|
||||
id := atomic.AddUint64(&cm.nextID, 1)
|
||||
now := time.Now()
|
||||
|
||||
op := &SocketOperation{
|
||||
ID: id,
|
||||
Socket: socket,
|
||||
Type: opType,
|
||||
State: SocketStateIdle,
|
||||
CreatedAt: now,
|
||||
LastActivity: now,
|
||||
Timeout: timeout,
|
||||
Done: make(chan struct{}),
|
||||
}
|
||||
|
||||
cm.mu.Lock()
|
||||
cm.operations[id] = op
|
||||
cm.mu.Unlock()
|
||||
|
||||
atomic.AddUint64(&cm.stats.TotalOperations, 1)
|
||||
atomic.AddUint64(&cm.stats.ActiveOperations, 1)
|
||||
|
||||
return op
|
||||
}
|
||||
|
||||
// CompleteOperation 完成操作
|
||||
func (cm *CosocketManager) CompleteOperation(id uint64, result interface{}, err error) {
|
||||
cm.mu.Lock()
|
||||
op, exists := cm.operations[id]
|
||||
if exists {
|
||||
delete(cm.operations, id)
|
||||
}
|
||||
cm.mu.Unlock()
|
||||
|
||||
if exists && op != nil {
|
||||
op.Complete(result, err)
|
||||
atomic.AddUint64(&cm.stats.ActiveOperations, ^uint64(0))
|
||||
if err != nil {
|
||||
atomic.AddUint64(&cm.stats.ErrorOperations, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetOperation 获取操作
|
||||
func (cm *CosocketManager) GetOperation(id uint64) *SocketOperation {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return cm.operations[id]
|
||||
}
|
||||
|
||||
// cleanupLoop 清理循环
|
||||
func (cm *CosocketManager) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-cm.ctx.Done():
|
||||
return
|
||||
case <-cm.timeoutChecker.C:
|
||||
cm.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理超时操作
|
||||
func (cm *CosocketManager) cleanup() {
|
||||
now := time.Now()
|
||||
timeoutOps := make([]*SocketOperation, 0)
|
||||
|
||||
cm.mu.RLock()
|
||||
for _, op := range cm.operations {
|
||||
if !op.IsCompleted() && now.Sub(op.LastActivity) > op.Timeout {
|
||||
timeoutOps = append(timeoutOps, op)
|
||||
}
|
||||
}
|
||||
cm.mu.RUnlock()
|
||||
|
||||
for _, op := range timeoutOps {
|
||||
cm.CompleteOperation(op.ID, nil, context.DeadlineExceeded)
|
||||
atomic.AddUint64(&cm.stats.TimeoutOperations, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Stats 获取统计信息
|
||||
func (cm *CosocketManager) Stats() CosocketStats {
|
||||
return CosocketStats{
|
||||
TotalOperations: atomic.LoadUint64(&cm.stats.TotalOperations),
|
||||
ActiveOperations: atomic.LoadUint64(&cm.stats.ActiveOperations),
|
||||
TimeoutOperations: atomic.LoadUint64(&cm.stats.TimeoutOperations),
|
||||
ErrorOperations: atomic.LoadUint64(&cm.stats.ErrorOperations),
|
||||
ActiveSockets: atomic.LoadUint64(&cm.stats.ActiveSockets),
|
||||
TotalSocketsCreated: atomic.LoadUint64(&cm.stats.TotalSocketsCreated),
|
||||
TotalSocketsClosed: atomic.LoadUint64(&cm.stats.TotalSocketsClosed),
|
||||
}
|
||||
}
|
||||
|
||||
// SetDefaultTimeout 设置默认超时
|
||||
func (cm *CosocketManager) SetDefaultTimeout(timeout time.Duration) {
|
||||
cm.defaultTimeout = timeout
|
||||
}
|
||||
|
||||
// Close 关闭管理器
|
||||
func (cm *CosocketManager) Close() {
|
||||
cm.cancel()
|
||||
cm.timeoutChecker.Stop()
|
||||
|
||||
// 取消所有未完成操作
|
||||
cm.mu.Lock()
|
||||
ops := make([]*SocketOperation, 0, len(cm.operations))
|
||||
for _, op := range cm.operations {
|
||||
ops = append(ops, op)
|
||||
}
|
||||
cm.operations = make(map[uint64]*SocketOperation)
|
||||
cm.mu.Unlock()
|
||||
|
||||
for _, op := range ops {
|
||||
op.Complete(nil, context.Canceled)
|
||||
}
|
||||
}
|
||||
|
||||
// TrackSocketCreated 跟踪 socket 创建
|
||||
func (cm *CosocketManager) TrackSocketCreated() {
|
||||
atomic.AddUint64(&cm.stats.TotalSocketsCreated, 1)
|
||||
atomic.AddUint64(&cm.stats.ActiveSockets, 1)
|
||||
}
|
||||
|
||||
// TrackSocketClosed 跟踪 socket 关闭
|
||||
func (cm *CosocketManager) TrackSocketClosed() {
|
||||
atomic.AddUint64(&cm.stats.TotalSocketsClosed, 1)
|
||||
atomic.AddUint64(&cm.stats.ActiveSockets, ^uint64(0))
|
||||
}
|
||||
|
||||
// TCPAddr 解析 TCP 地址
|
||||
func (cm *CosocketManager) TCPAddr(host string, port int) (*net.TCPAddr, error) {
|
||||
return &net.TCPAddr{
|
||||
IP: net.ParseIP(host),
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
702
internal/lua/socket_test.go
Normal file
702
internal/lua/socket_test.go
Normal file
@ -0,0 +1,702 @@
|
||||
package lua
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockEchoServer 模拟 echo 服务器
|
||||
func mockEchoServer(t *testing.T, addr string) (net.Listener, func()) {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(c net.Conn) {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n > 0 {
|
||||
if _, err := c.Write(buf[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
cleanup := func() {
|
||||
close(stop)
|
||||
ln.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
return ln, cleanup
|
||||
}
|
||||
|
||||
// TestCosocketManager_Basic 测试基本功能
|
||||
func TestCosocketManager_Basic(t *testing.T) {
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
// 测试初始状态
|
||||
stats := cm.Stats()
|
||||
if stats.TotalOperations != 0 {
|
||||
t.Errorf("Expected 0 operations, got %d", stats.TotalOperations)
|
||||
}
|
||||
|
||||
// 测试操作创建
|
||||
socket := NewTCPSocket(cm)
|
||||
defer socket.Close()
|
||||
|
||||
op := cm.StartOperation(socket, OpConnect, 5*time.Second)
|
||||
if op == nil {
|
||||
t.Fatal("Expected non-nil operation")
|
||||
}
|
||||
|
||||
if op.ID == 0 {
|
||||
t.Error("Expected non-zero operation ID")
|
||||
}
|
||||
|
||||
if op.Type != OpConnect {
|
||||
t.Errorf("Expected OpConnect, got %v", op.Type)
|
||||
}
|
||||
|
||||
// 测试统计
|
||||
stats = cm.Stats()
|
||||
if stats.TotalOperations != 1 {
|
||||
t.Errorf("Expected 1 operation, got %d", stats.TotalOperations)
|
||||
}
|
||||
if stats.ActiveOperations != 1 {
|
||||
t.Errorf("Expected 1 active operation, got %d", stats.ActiveOperations)
|
||||
}
|
||||
|
||||
// 测试操作完成
|
||||
cm.CompleteOperation(op.ID, "done", nil)
|
||||
|
||||
stats = cm.Stats()
|
||||
if stats.ActiveOperations != 0 {
|
||||
t.Errorf("Expected 0 active operations after complete, got %d", stats.ActiveOperations)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCosocketManager_Timeout 测试超时机制
|
||||
func TestCosocketManager_Timeout(t *testing.T) {
|
||||
// 创建一个使用短清理间隔的管理器用于测试
|
||||
cm := NewCosocketManager()
|
||||
cm.SetDefaultTimeout(100 * time.Millisecond)
|
||||
cm.cleanupInterval = 50 * time.Millisecond
|
||||
cm.timeoutChecker.Reset(50 * time.Millisecond)
|
||||
defer cm.Close()
|
||||
|
||||
socket := NewTCPSocket(cm)
|
||||
defer socket.Close()
|
||||
|
||||
// 创建一个不完成的操作
|
||||
op := cm.StartOperation(socket, OpConnect, 100*time.Millisecond)
|
||||
|
||||
// 等待超时清理
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// 检查操作是否超时完成
|
||||
if !op.IsCompleted() {
|
||||
t.Error("Expected operation to be completed due to timeout")
|
||||
}
|
||||
|
||||
stats := cm.Stats()
|
||||
if stats.TimeoutOperations != 1 {
|
||||
t.Errorf("Expected 1 timeout operation, got %d", stats.TimeoutOperations)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTCPSocket_Connect 测试 TCP 连接
|
||||
func TestTCPSocket_Connect(t *testing.T) {
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19999")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
socket := NewTCPSocket(cm)
|
||||
defer socket.Close()
|
||||
|
||||
// 测试连接
|
||||
err := socket.Connect("127.0.0.1", 19999)
|
||||
if err != nil {
|
||||
t.Fatalf("Connect failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待连接完成
|
||||
op := socket.currentOp
|
||||
if op != nil {
|
||||
result, err := op.Wait(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Connect wait failed: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil connection")
|
||||
}
|
||||
}
|
||||
|
||||
if socket.State() != SocketStateConnected {
|
||||
t.Errorf("Expected state connected, got %v", socket.State())
|
||||
}
|
||||
}
|
||||
|
||||
// TestTCPSocket_SendReceive 测试发送接收
|
||||
func TestTCPSocket_SendReceive(t *testing.T) {
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19998")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
socket := NewTCPSocket(cm)
|
||||
defer socket.Close()
|
||||
|
||||
// 连接
|
||||
if err := socket.Connect("127.0.0.1", 19998); err != nil {
|
||||
t.Fatalf("Connect failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待连接完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 发送数据
|
||||
testData := "Hello, Cosocket!"
|
||||
n, err := socket.Send([]byte(testData))
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
if n != len(testData) {
|
||||
t.Errorf("Expected %d bytes sent, got %d", len(testData), n)
|
||||
}
|
||||
|
||||
// 接收数据
|
||||
received, err := socket.Receive(1024)
|
||||
if err != nil {
|
||||
t.Fatalf("Receive failed: %v", err)
|
||||
}
|
||||
if string(received) != testData {
|
||||
t.Errorf("Expected '%s', got '%s'", testData, string(received))
|
||||
}
|
||||
}
|
||||
|
||||
// TestTCPSocket_AsyncOperations 测试异步操作
|
||||
func TestTCPSocket_AsyncOperations(t *testing.T) {
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19997")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
socket := NewTCPSocket(cm)
|
||||
defer socket.Close()
|
||||
|
||||
// 测试异步连接
|
||||
err := socket.Connect("127.0.0.1", 19997)
|
||||
if err != nil {
|
||||
t.Fatalf("Connect failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待连接完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 测试异步发送
|
||||
testData := "Async test"
|
||||
sendOp, err := socket.SendAsync([]byte(testData))
|
||||
if err != nil {
|
||||
t.Fatalf("SendAsync failed: %v", err)
|
||||
}
|
||||
|
||||
result, err := sendOp.Wait(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Send wait failed: %v", err)
|
||||
}
|
||||
if n, ok := result.(int); !ok || n != len(testData) {
|
||||
t.Errorf("Expected %d bytes, got %v", len(testData), result)
|
||||
}
|
||||
|
||||
// 测试异步接收
|
||||
recvOp, err := socket.ReceiveAsync(1024)
|
||||
if err != nil {
|
||||
t.Fatalf("ReceiveAsync failed: %v", err)
|
||||
}
|
||||
|
||||
result, err = recvOp.Wait(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Receive wait failed: %v", err)
|
||||
}
|
||||
if data, ok := result.([]byte); !ok || string(data) != testData {
|
||||
t.Errorf("Expected '%s', got %v", testData, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTCPSocket_ReceiveUntil 测试接收直到特定模式
|
||||
func TestTCPSocket_ReceiveUntil(t *testing.T) {
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19996")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
socket := NewTCPSocket(cm)
|
||||
defer socket.Close()
|
||||
|
||||
// 连接
|
||||
if err := socket.Connect("127.0.0.1", 19996); err != nil {
|
||||
t.Fatalf("Connect failed: %v", err)
|
||||
}
|
||||
|
||||
// 等待连接完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 发送带换行的数据
|
||||
testData := "Line1\nLine2\nLine3\n"
|
||||
_, err := socket.Send([]byte(testData))
|
||||
if err != nil {
|
||||
t.Fatalf("Send failed: %v", err)
|
||||
}
|
||||
|
||||
// 接收直到换行
|
||||
data, err := socket.ReceiveUntil("\n", true)
|
||||
if err != nil {
|
||||
t.Fatalf("ReceiveUntil failed: %v", err)
|
||||
}
|
||||
if string(data) != "Line1\n" {
|
||||
t.Errorf("Expected 'Line1\\n', got '%s'", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
// TestTCPSocket_Close 测试关闭
|
||||
func TestTCPSocket_Close(t *testing.T) {
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
socket := NewTCPSocket(cm)
|
||||
|
||||
if socket.IsClosed() {
|
||||
t.Error("Socket should not be closed initially")
|
||||
}
|
||||
|
||||
err := socket.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
if !socket.IsClosed() {
|
||||
t.Error("Socket should be closed")
|
||||
}
|
||||
|
||||
// 重复关闭应该返回 nil
|
||||
err = socket.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Second close should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTCPSocket_StateTransitions 测试状态转换
|
||||
func TestTCPSocket_StateTransitions(t *testing.T) {
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19995")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
socket := NewTCPSocket(cm)
|
||||
defer socket.Close()
|
||||
|
||||
// 初始状态
|
||||
if socket.State() != SocketStateIdle {
|
||||
t.Errorf("Expected idle state, got %v", socket.State())
|
||||
}
|
||||
|
||||
// 连接中
|
||||
socket.Connect("127.0.0.1", 19995)
|
||||
if socket.State() != SocketStateConnecting {
|
||||
t.Errorf("Expected connecting state, got %v", socket.State())
|
||||
}
|
||||
|
||||
// 等待连接完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if socket.State() != SocketStateConnected {
|
||||
t.Errorf("Expected connected state, got %v", socket.State())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCosocketManager_Concurrent 测试并发操作
|
||||
func TestCosocketManager_Concurrent(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrent test in short mode")
|
||||
}
|
||||
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19994")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
const numSockets = 1000
|
||||
const numGoroutines = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numSockets)
|
||||
var completed int32
|
||||
|
||||
// 并发创建 socket 和连接
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(start int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numSockets/numGoroutines; j++ {
|
||||
socket := NewTCPSocket(cm)
|
||||
if err := socket.Connect("127.0.0.1", 19994); err != nil {
|
||||
errors <- fmt.Errorf("connect failed: %v", err)
|
||||
socket.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// 等待连接
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 发送数据
|
||||
data := fmt.Sprintf("Test%d", start+j)
|
||||
if _, err := socket.Send([]byte(data)); err != nil {
|
||||
errors <- fmt.Errorf("send failed: %v", err)
|
||||
socket.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
socket.Close()
|
||||
atomic.AddInt32(&completed, 1)
|
||||
}
|
||||
}(i * (numSockets / numGoroutines))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 检查错误
|
||||
close(errors)
|
||||
errCount := 0
|
||||
for err := range errors {
|
||||
t.Logf("Error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
t.Logf("Completed: %d/%d, Errors: %d", completed, numSockets, errCount)
|
||||
|
||||
// 检查统计
|
||||
stats := cm.Stats()
|
||||
t.Logf("Stats: %+v", stats)
|
||||
}
|
||||
|
||||
// TestCosocketManager_MemoryLeak 测试内存泄漏
|
||||
func TestCosocketManager_MemoryLeak(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory leak test in short mode")
|
||||
}
|
||||
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19993")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
// 记录初始 goroutine 数
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
// 创建和关闭大量 socket
|
||||
for i := 0; i < 10000; i++ {
|
||||
socket := NewTCPSocket(cm)
|
||||
// 使用同步连接避免竞态
|
||||
socket.Connect("127.0.0.1", 19993)
|
||||
time.Sleep(time.Millisecond) // 给连接时间完成
|
||||
socket.Close()
|
||||
}
|
||||
|
||||
// 强制 GC
|
||||
runtime.GC()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 检查 goroutine 数量
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
t.Logf("Initial goroutines: %d, Final goroutines: %d", initialGoroutines, finalGoroutines)
|
||||
|
||||
// 允许一定的波动
|
||||
if finalGoroutines > initialGoroutines+100 {
|
||||
t.Errorf("Possible goroutine leak: started with %d, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
|
||||
// 检查统计
|
||||
stats := cm.Stats()
|
||||
if stats.ActiveSockets > 100 {
|
||||
t.Errorf("Active sockets leak: %d", stats.ActiveSockets)
|
||||
}
|
||||
if stats.ActiveOperations > 100 {
|
||||
t.Errorf("Active operations leak: %d", stats.ActiveOperations)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCosocketManager_LongRunning 测试长时间运行
|
||||
func TestCosocketManager_LongRunning(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping long running test in short mode")
|
||||
}
|
||||
|
||||
_, cleanup := mockEchoServer(t, "127.0.0.1:19992")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
duration := 10 * time.Second // 缩短到 10 秒进行测试
|
||||
interval := 100 * time.Millisecond
|
||||
|
||||
var totalOps int32
|
||||
start := time.Now()
|
||||
|
||||
for time.Since(start) < duration {
|
||||
socket := NewTCPSocket(cm)
|
||||
if err := socket.Connect("127.0.0.1", 19992); err != nil {
|
||||
t.Logf("Connect error: %v", err)
|
||||
socket.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// 等待连接
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 发送接收
|
||||
if _, err := socket.Send([]byte("test")); err == nil {
|
||||
socket.Receive(1024)
|
||||
}
|
||||
|
||||
socket.Close()
|
||||
atomic.AddInt32(&totalOps, 1)
|
||||
time.Sleep(interval)
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
t.Logf("Completed %d operations in %v", totalOps, elapsed)
|
||||
|
||||
// 检查最终统计
|
||||
stats := cm.Stats()
|
||||
t.Logf("Final stats: %+v", stats)
|
||||
|
||||
if stats.ActiveSockets > 0 {
|
||||
t.Errorf("Expected 0 active sockets, got %d", stats.ActiveSockets)
|
||||
}
|
||||
if stats.ActiveOperations > 0 {
|
||||
t.Errorf("Expected 0 active operations, got %d", stats.ActiveOperations)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCosocket_Connect 基准测试:连接
|
||||
func BenchmarkCosocket_Connect(b *testing.B) {
|
||||
_, cleanup := mockEchoServer(nil, "127.0.0.1:19991")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
socket := NewTCPSocket(cm)
|
||||
socket.Connect("127.0.0.1", 19991)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
socket.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkCosocket_SendReceive 基准测试:发送接收
|
||||
func BenchmarkCosocket_SendReceive(b *testing.B) {
|
||||
_, cleanup := mockEchoServer(nil, "127.0.0.1:19990")
|
||||
defer cleanup()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
// 预先连接
|
||||
socket := NewTCPSocket(cm)
|
||||
socket.Connect("127.0.0.1", 19990)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
socket.Send([]byte("benchmark"))
|
||||
socket.Receive(1024)
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
socket.Close()
|
||||
}
|
||||
|
||||
// TestLuaAPI_TCPSocket 测试 Lua API
|
||||
func TestLuaAPI_TCPSocket(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping Lua API test in short mode")
|
||||
}
|
||||
|
||||
// 创建引擎
|
||||
engine, err := NewEngine(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create engine: %v", err)
|
||||
}
|
||||
defer engine.Close()
|
||||
|
||||
// 注册 TCP socket API
|
||||
RegisterTCPSocketAPI(engine.L, engine)
|
||||
|
||||
// 测试创建 socket
|
||||
script := `
|
||||
local sock = ngx.socket.tcp()
|
||||
if not sock then
|
||||
return nil, "failed to create socket"
|
||||
end
|
||||
return "ok"
|
||||
`
|
||||
|
||||
coro, err := engine.NewCoroutine(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create coroutine: %v", err)
|
||||
}
|
||||
defer coro.Close()
|
||||
|
||||
if err := coro.SetupSandbox(); err != nil {
|
||||
t.Fatalf("Failed to setup sandbox: %v", err)
|
||||
}
|
||||
|
||||
err = coro.Execute(script)
|
||||
if err != nil {
|
||||
t.Errorf("Execute failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCosocketManager_Stress 压力测试
|
||||
func TestCosocketManager_Stress(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping stress test in short mode")
|
||||
}
|
||||
|
||||
// 创建多个 echo 服务器
|
||||
ports := []int{19980, 19981, 19982, 19983}
|
||||
cleanups := make([]func(), len(ports))
|
||||
for i, port := range ports {
|
||||
_, cleanups[i] = mockEchoServer(t, fmt.Sprintf("127.0.0.1:%d", port))
|
||||
}
|
||||
defer func() {
|
||||
for _, c := range cleanups {
|
||||
c()
|
||||
}
|
||||
}()
|
||||
|
||||
cm := NewCosocketManager()
|
||||
defer cm.Close()
|
||||
|
||||
const totalConnections = 10000
|
||||
const concurrency = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int32
|
||||
var errorCount int32
|
||||
var latencySum int64
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// 使用信号量限制并发
|
||||
sem := make(chan struct{}, concurrency)
|
||||
|
||||
for i := 0; i < totalConnections; i++ {
|
||||
wg.Add(1)
|
||||
sem <- struct{}{} // 获取信号量
|
||||
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }() // 释放信号量
|
||||
|
||||
port := ports[idx%len(ports)]
|
||||
socket := NewTCPSocket(cm)
|
||||
|
||||
opStart := time.Now()
|
||||
err := socket.Connect("127.0.0.1", port)
|
||||
if err != nil {
|
||||
atomic.AddInt32(&errorCount, 1)
|
||||
socket.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 等待连接完成
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// 简单数据交换
|
||||
if _, err := socket.Send([]byte("hello")); err == nil {
|
||||
socket.Receive(1024)
|
||||
}
|
||||
|
||||
socket.Close()
|
||||
|
||||
latency := time.Since(opStart).Milliseconds()
|
||||
atomic.AddInt64(&latencySum, latency)
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
t.Logf("Stress test completed:")
|
||||
t.Logf(" Total: %d, Success: %d, Errors: %d", totalConnections, successCount, errorCount)
|
||||
t.Logf(" Duration: %v", elapsed)
|
||||
t.Logf(" RPS: %.2f", float64(totalConnections)/elapsed.Seconds())
|
||||
if successCount > 0 {
|
||||
t.Logf(" Avg Latency: %dms", latencySum/int64(successCount))
|
||||
}
|
||||
|
||||
// 内存检查
|
||||
var m runtime.MemStats
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&m)
|
||||
t.Logf(" Memory: %.2f MB", float64(m.HeapAlloc)/(1024*1024))
|
||||
|
||||
// 验证没有资源泄漏
|
||||
stats := cm.Stats()
|
||||
t.Logf(" Active sockets: %d, Active operations: %d", stats.ActiveSockets, stats.ActiveOperations)
|
||||
|
||||
if errorCount > totalConnections/10 { // 允许 10% 错误率
|
||||
t.Errorf("Too many errors: %d", errorCount)
|
||||
}
|
||||
|
||||
if stats.ActiveSockets > 100 {
|
||||
t.Errorf("Socket leak detected: %d active sockets", stats.ActiveSockets)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user