fix: resolve race conditions in handler sendfile and lua cosocket tests
This commit is contained in:
parent
f73a761632
commit
5e3196c37e
@ -309,46 +309,43 @@ func TestSendFile_LargeFile(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
// 启动 goroutine 接收连接
|
connCh := make(chan net.Conn, 1)
|
||||||
var serverConn net.Conn
|
|
||||||
go func() {
|
go func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err == nil {
|
||||||
|
connCh <- c
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 客户端连接
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to dial: %v", err)
|
t.Fatalf("Failed to dial: %v", err)
|
||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
// 等待服务器接受
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// 将客户端连接设置为非阻塞以便测试 sendfile
|
|
||||||
if err := clientConn.SetDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
if err := clientConn.SetDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||||
t.Fatalf("Failed to set deadline: %v", err)
|
t.Fatalf("Failed to set deadline: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构造 RequestCtx
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod("GET")
|
||||||
ctx.Request.SetRequestURI("/test")
|
ctx.Request.SetRequestURI("/test")
|
||||||
|
|
||||||
// 发送大文件(应使用 sendfile)
|
|
||||||
err = SendFile(ctx, file, 0, int64(len(content)))
|
err = SendFile(ctx, file, 0, int64(len(content)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("SendFile returned: %v", err)
|
t.Logf("SendFile returned: %v", err)
|
||||||
// EPIPE 是可接受的,因为服务器可能在读取后关闭连接
|
|
||||||
if err != syscall.EPIPE && err != syscall.ECONNRESET {
|
if err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||||||
t.Errorf("SendFile unexpected error: %v", err)
|
t.Errorf("SendFile unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 关闭服务器连接
|
select {
|
||||||
if serverConn != nil {
|
case sc := <-connCh:
|
||||||
serverConn.Close()
|
if sc != nil {
|
||||||
|
sc.Close()
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -438,23 +435,20 @@ func TestGetSocketFd_UnixConn(t *testing.T) {
|
|||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
defer os.Remove(socketPath)
|
defer os.Remove(socketPath)
|
||||||
|
|
||||||
// 启动 goroutine 接收连接
|
connCh := make(chan net.Conn, 1)
|
||||||
var serverConn net.Conn
|
|
||||||
go func() {
|
go func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err == nil {
|
||||||
|
connCh <- c
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 客户端连接
|
|
||||||
clientConn, err := net.Dial("unix", socketPath)
|
clientConn, err := net.Dial("unix", socketPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to dial unix socket: %v", err)
|
t.Fatalf("Failed to dial unix socket: %v", err)
|
||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
// 等待连接建立
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// 测试获取 socket fd
|
|
||||||
fd, err := getSocketFd(clientConn)
|
fd, err := getSocketFd(clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("getSocketFd failed: %v", err)
|
t.Errorf("getSocketFd failed: %v", err)
|
||||||
@ -463,8 +457,12 @@ func TestGetSocketFd_UnixConn(t *testing.T) {
|
|||||||
t.Error("Expected non-zero fd")
|
t.Error("Expected non-zero fd")
|
||||||
}
|
}
|
||||||
|
|
||||||
if serverConn != nil {
|
select {
|
||||||
serverConn.Close()
|
case sc := <-connCh:
|
||||||
|
if sc != nil {
|
||||||
|
sc.Close()
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -538,24 +536,22 @@ func TestSendFile_AtMinBoundary(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
// 启动 goroutine 接收连接
|
|
||||||
var serverConn net.Conn
|
|
||||||
go func() {
|
go func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
buf := make([]byte, MinSendfileSize)
|
buf := make([]byte, MinSendfileSize)
|
||||||
serverConn.Read(buf)
|
c.Read(buf)
|
||||||
serverConn.Close()
|
c.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 客户端连接
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to dial: %v", err)
|
t.Fatalf("Failed to dial: %v", err)
|
||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
ctx := &fasthttp.RequestCtx{}
|
ctx := &fasthttp.RequestCtx{}
|
||||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||||
ctx.Request.Header.SetMethod("GET")
|
ctx.Request.Header.SetMethod("GET")
|
||||||
@ -567,10 +563,6 @@ func TestSendFile_AtMinBoundary(t *testing.T) {
|
|||||||
t.Logf("SendFile returned: %v", err)
|
t.Logf("SendFile returned: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if serverConn != nil {
|
|
||||||
serverConn.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestSendFile_JustBelowMin 测试刚好小于 MinSendfileSize 的文件(使用 fallback)
|
// TestSendFile_JustBelowMin 测试刚好小于 MinSendfileSize 的文件(使用 fallback)
|
||||||
@ -610,23 +602,20 @@ func TestGetSocketFd_TCPConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
// 启动 goroutine 接收连接
|
connCh := make(chan net.Conn, 1)
|
||||||
var serverConn net.Conn
|
|
||||||
go func() {
|
go func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err == nil {
|
||||||
|
connCh <- c
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 客户端连接
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to dial: %v", err)
|
t.Fatalf("Failed to dial: %v", err)
|
||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
// 等待连接建立
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// 测试获取 socket fd
|
|
||||||
fd, err := getSocketFd(clientConn)
|
fd, err := getSocketFd(clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("getSocketFd failed for TCPConn: %v", err)
|
t.Errorf("getSocketFd failed for TCPConn: %v", err)
|
||||||
@ -635,8 +624,12 @@ func TestGetSocketFd_TCPConn(t *testing.T) {
|
|||||||
t.Error("Expected non-zero fd for TCPConn")
|
t.Error("Expected non-zero fd for TCPConn")
|
||||||
}
|
}
|
||||||
|
|
||||||
if serverConn != nil {
|
select {
|
||||||
serverConn.Close()
|
case sc := <-connCh:
|
||||||
|
if sc != nil {
|
||||||
|
sc.Close()
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -667,14 +660,15 @@ func TestLinuxSendfile_WithTCPConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
var serverConn net.Conn
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Go(func() {
|
wg.Go(func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
// 读取所有数据
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
buf := make([]byte, len(content))
|
buf := make([]byte, len(content))
|
||||||
_, _ = io.ReadFull(serverConn, buf)
|
_, _ = io.ReadFull(c, buf)
|
||||||
serverConn.Close()
|
c.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
@ -729,9 +723,12 @@ func TestLinuxSendfile_InvalidFileFd(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
var serverConn net.Conn
|
connCh := make(chan net.Conn, 1)
|
||||||
go func() {
|
go func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err == nil {
|
||||||
|
connCh <- c
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
@ -740,16 +737,17 @@ func TestLinuxSendfile_InvalidFileFd(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// 使用无效的文件描述符
|
|
||||||
err = linuxSendfile(clientConn, uintptr(99999), 0, 1024)
|
err = linuxSendfile(clientConn, uintptr(99999), 0, 1024)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error for invalid file descriptor")
|
t.Error("Expected error for invalid file descriptor")
|
||||||
}
|
}
|
||||||
|
|
||||||
if serverConn != nil {
|
select {
|
||||||
serverConn.Close()
|
case sc := <-connCh:
|
||||||
|
if sc != nil {
|
||||||
|
sc.Close()
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -772,9 +770,12 @@ func TestLinuxSendfile_ZeroLength(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
var serverConn net.Conn
|
connCh := make(chan net.Conn, 1)
|
||||||
go func() {
|
go func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err == nil {
|
||||||
|
connCh <- c
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
@ -783,16 +784,17 @@ func TestLinuxSendfile_ZeroLength(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
// 零长度应该立即返回
|
|
||||||
err = linuxSendfile(clientConn, file.Fd(), 0, 0)
|
err = linuxSendfile(clientConn, file.Fd(), 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected nil for zero length, got: %v", err)
|
t.Errorf("Expected nil for zero length, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if serverConn != nil {
|
select {
|
||||||
serverConn.Close()
|
case sc := <-connCh:
|
||||||
|
if sc != nil {
|
||||||
|
sc.Close()
|
||||||
|
}
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -822,15 +824,17 @@ func TestLinuxSendfile_PartialTransfer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
var serverConn net.Conn
|
|
||||||
var received []byte
|
var received []byte
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Go(func() {
|
wg.Go(func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
buf := make([]byte, len(content))
|
buf := make([]byte, len(content))
|
||||||
n, _ := serverConn.Read(buf)
|
n, _ := c.Read(buf)
|
||||||
received = buf[:n]
|
received = buf[:n]
|
||||||
serverConn.Close()
|
c.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
@ -882,13 +886,15 @@ func TestLinuxSendfile_WithOffset(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
var serverConn net.Conn
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Go(func() {
|
wg.Go(func() {
|
||||||
serverConn, _ = ln.Accept()
|
c, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
buf := make([]byte, 8*1024)
|
buf := make([]byte, 8*1024)
|
||||||
_, _ = serverConn.Read(buf)
|
_, _ = c.Read(buf)
|
||||||
serverConn.Close()
|
c.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
|||||||
@ -248,7 +248,9 @@ func TestTCPSocket_Connect_Failure(t *testing.T) {
|
|||||||
require.NoError(t, err) // Connect 本身不报错
|
require.NoError(t, err) // Connect 本身不报错
|
||||||
|
|
||||||
// 等待异步连接完成
|
// 等待异步连接完成
|
||||||
|
socket.mu.RLock()
|
||||||
op := socket.currentOp
|
op := socket.currentOp
|
||||||
|
socket.mu.RUnlock()
|
||||||
if op != nil {
|
if op != nil {
|
||||||
_, err := op.Wait(context.Background())
|
_, err := op.Wait(context.Background())
|
||||||
assert.Error(t, err) // 连接应该失败
|
assert.Error(t, err) // 连接应该失败
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user