From 5e3196c37e4cea15c7bd4ff639ec94e51468ca68 Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 5 Jun 2026 12:31:39 +0800 Subject: [PATCH] fix: resolve race conditions in handler sendfile and lua cosocket tests --- internal/handler/sendfile_test.go | 152 +++++++++++++++------------- internal/lua/api_socket_tcp_test.go | 2 + 2 files changed, 81 insertions(+), 73 deletions(-) diff --git a/internal/handler/sendfile_test.go b/internal/handler/sendfile_test.go index cb7a0fd..247469f 100644 --- a/internal/handler/sendfile_test.go +++ b/internal/handler/sendfile_test.go @@ -309,46 +309,43 @@ func TestSendFile_LargeFile(t *testing.T) { } defer ln.Close() - // 启动 goroutine 接收连接 - var serverConn net.Conn + connCh := make(chan net.Conn, 1) go func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err == nil { + connCh <- c + } }() - // 客户端连接 clientConn, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("Failed to dial: %v", err) } defer clientConn.Close() - // 等待服务器接受 - time.Sleep(100 * time.Millisecond) - - // 将客户端连接设置为非阻塞以便测试 sendfile if err := clientConn.SetDeadline(time.Now().Add(2 * time.Second)); err != nil { t.Fatalf("Failed to set deadline: %v", err) } - // 构造 RequestCtx ctx := &fasthttp.RequestCtx{} ctx.Init(&fasthttp.Request{}, nil, nil) ctx.Request.Header.SetMethod("GET") ctx.Request.SetRequestURI("/test") - // 发送大文件(应使用 sendfile) err = SendFile(ctx, file, 0, int64(len(content))) if err != nil { t.Logf("SendFile returned: %v", err) - // EPIPE 是可接受的,因为服务器可能在读取后关闭连接 if err != syscall.EPIPE && err != syscall.ECONNRESET { t.Errorf("SendFile unexpected error: %v", err) } } - // 关闭服务器连接 - if serverConn != nil { - serverConn.Close() + select { + case sc := <-connCh: + if sc != nil { + sc.Close() + } + case <-time.After(3 * time.Second): } } @@ -438,23 +435,20 @@ func TestGetSocketFd_UnixConn(t *testing.T) { defer ln.Close() defer os.Remove(socketPath) - // 启动 goroutine 接收连接 - var serverConn net.Conn + connCh := make(chan net.Conn, 1) go func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err == nil { + connCh <- c + } }() - // 客户端连接 clientConn, err := net.Dial("unix", socketPath) if err != nil { t.Fatalf("Failed to dial unix socket: %v", err) } defer clientConn.Close() - // 等待连接建立 - time.Sleep(100 * time.Millisecond) - - // 测试获取 socket fd fd, err := getSocketFd(clientConn) if err != nil { t.Errorf("getSocketFd failed: %v", err) @@ -463,8 +457,12 @@ func TestGetSocketFd_UnixConn(t *testing.T) { t.Error("Expected non-zero fd") } - if serverConn != nil { - serverConn.Close() + select { + case sc := <-connCh: + if sc != nil { + sc.Close() + } + case <-time.After(3 * time.Second): } } @@ -538,24 +536,22 @@ func TestSendFile_AtMinBoundary(t *testing.T) { } defer ln.Close() - // 启动 goroutine 接收连接 - var serverConn net.Conn go func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err != nil { + return + } buf := make([]byte, MinSendfileSize) - serverConn.Read(buf) - serverConn.Close() + c.Read(buf) + c.Close() }() - // 客户端连接 clientConn, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("Failed to dial: %v", err) } defer clientConn.Close() - time.Sleep(100 * time.Millisecond) - ctx := &fasthttp.RequestCtx{} ctx.Init(&fasthttp.Request{}, nil, nil) ctx.Request.Header.SetMethod("GET") @@ -567,10 +563,6 @@ func TestSendFile_AtMinBoundary(t *testing.T) { t.Logf("SendFile returned: %v", err) } } - - if serverConn != nil { - serverConn.Close() - } } // TestSendFile_JustBelowMin 测试刚好小于 MinSendfileSize 的文件(使用 fallback) @@ -610,23 +602,20 @@ func TestGetSocketFd_TCPConn(t *testing.T) { } defer ln.Close() - // 启动 goroutine 接收连接 - var serverConn net.Conn + connCh := make(chan net.Conn, 1) go func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err == nil { + connCh <- c + } }() - // 客户端连接 clientConn, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("Failed to dial: %v", err) } defer clientConn.Close() - // 等待连接建立 - time.Sleep(100 * time.Millisecond) - - // 测试获取 socket fd fd, err := getSocketFd(clientConn) if err != nil { t.Errorf("getSocketFd failed for TCPConn: %v", err) @@ -635,8 +624,12 @@ func TestGetSocketFd_TCPConn(t *testing.T) { t.Error("Expected non-zero fd for TCPConn") } - if serverConn != nil { - serverConn.Close() + select { + case sc := <-connCh: + if sc != nil { + sc.Close() + } + case <-time.After(3 * time.Second): } } @@ -667,14 +660,15 @@ func TestLinuxSendfile_WithTCPConn(t *testing.T) { } defer ln.Close() - var serverConn net.Conn var wg sync.WaitGroup wg.Go(func() { - serverConn, _ = ln.Accept() - // 读取所有数据 + c, err := ln.Accept() + if err != nil { + return + } buf := make([]byte, len(content)) - _, _ = io.ReadFull(serverConn, buf) - serverConn.Close() + _, _ = io.ReadFull(c, buf) + c.Close() }) clientConn, err := net.Dial("tcp", ln.Addr().String()) @@ -729,9 +723,12 @@ func TestLinuxSendfile_InvalidFileFd(t *testing.T) { } defer ln.Close() - var serverConn net.Conn + connCh := make(chan net.Conn, 1) go func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err == nil { + connCh <- c + } }() clientConn, err := net.Dial("tcp", ln.Addr().String()) @@ -740,16 +737,17 @@ func TestLinuxSendfile_InvalidFileFd(t *testing.T) { } defer clientConn.Close() - time.Sleep(100 * time.Millisecond) - - // 使用无效的文件描述符 err = linuxSendfile(clientConn, uintptr(99999), 0, 1024) if err == nil { t.Error("Expected error for invalid file descriptor") } - if serverConn != nil { - serverConn.Close() + select { + case sc := <-connCh: + if sc != nil { + sc.Close() + } + case <-time.After(3 * time.Second): } } @@ -772,9 +770,12 @@ func TestLinuxSendfile_ZeroLength(t *testing.T) { } defer ln.Close() - var serverConn net.Conn + connCh := make(chan net.Conn, 1) go func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err == nil { + connCh <- c + } }() clientConn, err := net.Dial("tcp", ln.Addr().String()) @@ -783,16 +784,17 @@ func TestLinuxSendfile_ZeroLength(t *testing.T) { } defer clientConn.Close() - time.Sleep(100 * time.Millisecond) - - // 零长度应该立即返回 err = linuxSendfile(clientConn, file.Fd(), 0, 0) if err != nil { t.Errorf("Expected nil for zero length, got: %v", err) } - if serverConn != nil { - serverConn.Close() + select { + case sc := <-connCh: + if sc != nil { + sc.Close() + } + case <-time.After(3 * time.Second): } } @@ -822,15 +824,17 @@ func TestLinuxSendfile_PartialTransfer(t *testing.T) { } defer ln.Close() - var serverConn net.Conn var received []byte var wg sync.WaitGroup wg.Go(func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err != nil { + return + } buf := make([]byte, len(content)) - n, _ := serverConn.Read(buf) + n, _ := c.Read(buf) received = buf[:n] - serverConn.Close() + c.Close() }) clientConn, err := net.Dial("tcp", ln.Addr().String()) @@ -882,13 +886,15 @@ func TestLinuxSendfile_WithOffset(t *testing.T) { } defer ln.Close() - var serverConn net.Conn var wg sync.WaitGroup wg.Go(func() { - serverConn, _ = ln.Accept() + c, err := ln.Accept() + if err != nil { + return + } buf := make([]byte, 8*1024) - _, _ = serverConn.Read(buf) - serverConn.Close() + _, _ = c.Read(buf) + c.Close() }) clientConn, err := net.Dial("tcp", ln.Addr().String()) diff --git a/internal/lua/api_socket_tcp_test.go b/internal/lua/api_socket_tcp_test.go index 3245d7e..f5426af 100644 --- a/internal/lua/api_socket_tcp_test.go +++ b/internal/lua/api_socket_tcp_test.go @@ -248,7 +248,9 @@ func TestTCPSocket_Connect_Failure(t *testing.T) { require.NoError(t, err) // Connect 本身不报错 // 等待异步连接完成 + socket.mu.RLock() op := socket.currentOp + socket.mu.RUnlock() if op != nil { _, err := op.Wait(context.Background()) assert.Error(t, err) // 连接应该失败