fix: resolve race conditions in handler sendfile and lua cosocket tests

This commit is contained in:
xfy 2026-06-05 12:31:39 +08:00
parent f73a761632
commit 5e3196c37e
2 changed files with 81 additions and 73 deletions

View File

@ -309,46 +309,43 @@ func TestSendFile_LargeFile(t *testing.T) {
}
defer ln.Close()
// 启动 goroutine 接收连接
var serverConn net.Conn
connCh := make(chan net.Conn, 1)
go func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err == nil {
connCh <- c
}
}()
// 客户端连接
clientConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer clientConn.Close()
// 等待服务器接受
time.Sleep(100 * time.Millisecond)
// 将客户端连接设置为非阻塞以便测试 sendfile
if err := clientConn.SetDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Fatalf("Failed to set deadline: %v", err)
}
// 构造 RequestCtx
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
ctx.Request.Header.SetMethod("GET")
ctx.Request.SetRequestURI("/test")
// 发送大文件(应使用 sendfile
err = SendFile(ctx, file, 0, int64(len(content)))
if err != nil {
t.Logf("SendFile returned: %v", err)
// EPIPE 是可接受的,因为服务器可能在读取后关闭连接
if err != syscall.EPIPE && err != syscall.ECONNRESET {
t.Errorf("SendFile unexpected error: %v", err)
}
}
// 关闭服务器连接
if serverConn != nil {
serverConn.Close()
select {
case sc := <-connCh:
if sc != nil {
sc.Close()
}
case <-time.After(3 * time.Second):
}
}
@ -438,23 +435,20 @@ func TestGetSocketFd_UnixConn(t *testing.T) {
defer ln.Close()
defer os.Remove(socketPath)
// 启动 goroutine 接收连接
var serverConn net.Conn
connCh := make(chan net.Conn, 1)
go func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err == nil {
connCh <- c
}
}()
// 客户端连接
clientConn, err := net.Dial("unix", socketPath)
if err != nil {
t.Fatalf("Failed to dial unix socket: %v", err)
}
defer clientConn.Close()
// 等待连接建立
time.Sleep(100 * time.Millisecond)
// 测试获取 socket fd
fd, err := getSocketFd(clientConn)
if err != nil {
t.Errorf("getSocketFd failed: %v", err)
@ -463,8 +457,12 @@ func TestGetSocketFd_UnixConn(t *testing.T) {
t.Error("Expected non-zero fd")
}
if serverConn != nil {
serverConn.Close()
select {
case sc := <-connCh:
if sc != nil {
sc.Close()
}
case <-time.After(3 * time.Second):
}
}
@ -538,24 +536,22 @@ func TestSendFile_AtMinBoundary(t *testing.T) {
}
defer ln.Close()
// 启动 goroutine 接收连接
var serverConn net.Conn
go func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err != nil {
return
}
buf := make([]byte, MinSendfileSize)
serverConn.Read(buf)
serverConn.Close()
c.Read(buf)
c.Close()
}()
// 客户端连接
clientConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer clientConn.Close()
time.Sleep(100 * time.Millisecond)
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
ctx.Request.Header.SetMethod("GET")
@ -567,10 +563,6 @@ func TestSendFile_AtMinBoundary(t *testing.T) {
t.Logf("SendFile returned: %v", err)
}
}
if serverConn != nil {
serverConn.Close()
}
}
// TestSendFile_JustBelowMin 测试刚好小于 MinSendfileSize 的文件(使用 fallback
@ -610,23 +602,20 @@ func TestGetSocketFd_TCPConn(t *testing.T) {
}
defer ln.Close()
// 启动 goroutine 接收连接
var serverConn net.Conn
connCh := make(chan net.Conn, 1)
go func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err == nil {
connCh <- c
}
}()
// 客户端连接
clientConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer clientConn.Close()
// 等待连接建立
time.Sleep(100 * time.Millisecond)
// 测试获取 socket fd
fd, err := getSocketFd(clientConn)
if err != nil {
t.Errorf("getSocketFd failed for TCPConn: %v", err)
@ -635,8 +624,12 @@ func TestGetSocketFd_TCPConn(t *testing.T) {
t.Error("Expected non-zero fd for TCPConn")
}
if serverConn != nil {
serverConn.Close()
select {
case sc := <-connCh:
if sc != nil {
sc.Close()
}
case <-time.After(3 * time.Second):
}
}
@ -667,14 +660,15 @@ func TestLinuxSendfile_WithTCPConn(t *testing.T) {
}
defer ln.Close()
var serverConn net.Conn
var wg sync.WaitGroup
wg.Go(func() {
serverConn, _ = ln.Accept()
// 读取所有数据
c, err := ln.Accept()
if err != nil {
return
}
buf := make([]byte, len(content))
_, _ = io.ReadFull(serverConn, buf)
serverConn.Close()
_, _ = io.ReadFull(c, buf)
c.Close()
})
clientConn, err := net.Dial("tcp", ln.Addr().String())
@ -729,9 +723,12 @@ func TestLinuxSendfile_InvalidFileFd(t *testing.T) {
}
defer ln.Close()
var serverConn net.Conn
connCh := make(chan net.Conn, 1)
go func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err == nil {
connCh <- c
}
}()
clientConn, err := net.Dial("tcp", ln.Addr().String())
@ -740,16 +737,17 @@ func TestLinuxSendfile_InvalidFileFd(t *testing.T) {
}
defer clientConn.Close()
time.Sleep(100 * time.Millisecond)
// 使用无效的文件描述符
err = linuxSendfile(clientConn, uintptr(99999), 0, 1024)
if err == nil {
t.Error("Expected error for invalid file descriptor")
}
if serverConn != nil {
serverConn.Close()
select {
case sc := <-connCh:
if sc != nil {
sc.Close()
}
case <-time.After(3 * time.Second):
}
}
@ -772,9 +770,12 @@ func TestLinuxSendfile_ZeroLength(t *testing.T) {
}
defer ln.Close()
var serverConn net.Conn
connCh := make(chan net.Conn, 1)
go func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err == nil {
connCh <- c
}
}()
clientConn, err := net.Dial("tcp", ln.Addr().String())
@ -783,16 +784,17 @@ func TestLinuxSendfile_ZeroLength(t *testing.T) {
}
defer clientConn.Close()
time.Sleep(100 * time.Millisecond)
// 零长度应该立即返回
err = linuxSendfile(clientConn, file.Fd(), 0, 0)
if err != nil {
t.Errorf("Expected nil for zero length, got: %v", err)
}
if serverConn != nil {
serverConn.Close()
select {
case sc := <-connCh:
if sc != nil {
sc.Close()
}
case <-time.After(3 * time.Second):
}
}
@ -822,15 +824,17 @@ func TestLinuxSendfile_PartialTransfer(t *testing.T) {
}
defer ln.Close()
var serverConn net.Conn
var received []byte
var wg sync.WaitGroup
wg.Go(func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err != nil {
return
}
buf := make([]byte, len(content))
n, _ := serverConn.Read(buf)
n, _ := c.Read(buf)
received = buf[:n]
serverConn.Close()
c.Close()
})
clientConn, err := net.Dial("tcp", ln.Addr().String())
@ -882,13 +886,15 @@ func TestLinuxSendfile_WithOffset(t *testing.T) {
}
defer ln.Close()
var serverConn net.Conn
var wg sync.WaitGroup
wg.Go(func() {
serverConn, _ = ln.Accept()
c, err := ln.Accept()
if err != nil {
return
}
buf := make([]byte, 8*1024)
_, _ = serverConn.Read(buf)
serverConn.Close()
_, _ = c.Read(buf)
c.Close()
})
clientConn, err := net.Dial("tcp", ln.Addr().String())

View File

@ -248,7 +248,9 @@ func TestTCPSocket_Connect_Failure(t *testing.T) {
require.NoError(t, err) // Connect 本身不报错
// 等待异步连接完成
socket.mu.RLock()
op := socket.currentOp
socket.mu.RUnlock()
if op != nil {
_, err := op.Wait(context.Background())
assert.Error(t, err) // 连接应该失败