lolly/internal/handler/sendfile_test.go
xfy 27e00b84a8 fix(proxy,handler,server,stream,ratelimit): fix resource leaks and functional bugs
- proxy/proxy.go: decrement connection count on dangerous path rejection
  (line 724) to prevent connection count leak
- handler/sendfile_linux.go: return *os.File from getSocketFile and let
  linuxSendfile close it, fixing EBADF from deferred close in getSocketFd
- proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse
  and wrap targetConn with bufferedConn to consume pre-buffered frame data,
  preventing first-frame loss
- server/pool.go: use non-blocking send after starting new worker to avoid
  deadlock when queue is full
- stream/stream.go: check stopCh on non-timeout UDP read errors to prevent
  infinite loop and shutdown deadlock
- middleware/ratelimit: replace select-based close guard with sync.Once in
  StopCleanup to prevent double-close panic
2026-06-11 16:35:10 +08:00

948 lines
22 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build linux
// Package handler 提供 Sendfile 功能的 Linux 平台测试。
//
// 该文件测试 Linux 平台特有的 Sendfile 功能,包括:
// - Linux sendfile 系统调用
// - 套接字文件描述符获取
// - 小文件发送 fallback
//
// 作者：xfy
package handler
import (
"bytes"
"io"
"math/rand"
"net"
"os"
"path/filepath"
"sync"
"syscall"
"testing"
"time"
"github.com/valyala/fasthttp"
)
// TestMinSendfileSize pins the size threshold that separates the small-file
// fallback from the sendfile path at 8 KiB.
func TestMinSendfileSize(t *testing.T) {
	const want = 8 * 1024
	if got := MinSendfileSize; got != want {
		t.Errorf("Expected MinSendfileSize 8KB, got %d", got)
	}
}
// TestCopyFile exercises the copyFile fallback over a table of offset/length
// combinations.
//
// For every successful case the response body is compared byte-for-byte with
// the expected window of the source file, not just by length. An
// implementation that ignored the offset would therefore be caught.
func TestCopyFile(t *testing.T) {
	tmpDir := t.TempDir()
	tmpFile := filepath.Join(tmpDir, "test.txt")
	content := []byte("Hello, World! This is test content for copyFile.")
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	tests := []struct {
		name    string
		offset  int64
		length  int64 // 0 means "to the end of the file"
		wantLen int
		wantErr bool
	}{
		{name: "full file", offset: 0, length: 0, wantLen: len(content)},
		{name: "with length", offset: 0, length: 10, wantLen: 10},
		{name: "with offset", offset: 7, length: 5, wantLen: 5},
		{name: "offset beyond file", offset: 1000, length: 10, wantErr: true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The subtests share one *os.File. Rewind it so the read position
			// left by one case cannot affect the next.
			if _, err := file.Seek(0, io.SeekStart); err != nil {
				t.Fatalf("Failed to rewind file: %v", err)
			}
			ctx := &fasthttp.RequestCtx{}
			err := copyFile(ctx, file, tt.offset, tt.length)
			if tt.wantErr {
				if err == nil {
					t.Error("expected error, got nil")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			body := ctx.Response.Body()
			if len(body) != tt.wantLen {
				t.Fatalf("expected body length %d, got %d", tt.wantLen, len(body))
			}
			want := content[tt.offset : tt.offset+int64(tt.wantLen)]
			if !bytes.Equal(body, want) {
				t.Errorf("body content mismatch: want %q, got %q", want, body)
			}
		})
	}
}
// TestGetSocketFd_NilConn verifies that getSocketFile returns an error when
// it is handed a nil connection.
func TestGetSocketFd_NilConn(t *testing.T) {
	if _, err := getSocketFile(nil); err == nil {
		t.Error("expected error for nil connection")
	}
}
// TestGetSocketFd_UnsupportedType 测试不支持的连接类型
func TestGetSocketFd_UnsupportedType(t *testing.T) {
conn := &mockConn{}
_, err := getSocketFile(conn)
if err != syscall.ENOTSUP {
t.Errorf("expected ENOTSUP for unsupported conn type, got: %v", err)
}
}
// mockConn is a connection that is neither a *net.TCPConn nor a
// *net.UnixConn.
//
// Tests use it to exercise getSocketFile's handling of unsupported connection
// types. Every method is a no-op: Read and Write report 0 bytes with a nil
// error, the addresses are nil, and Close and the deadline setters succeed.
type mockConn struct{}

// Compile-time check that mockConn satisfies net.Conn. If the interface
// drifts, the build fails instead of the tests.
var _ net.Conn = (*mockConn)(nil)

func (*mockConn) Read([]byte) (n int, err error) { return 0, nil }

func (*mockConn) Write([]byte) (n int, err error) { return 0, nil }

func (*mockConn) Close() error { return nil }

func (*mockConn) LocalAddr() net.Addr { return nil }

func (*mockConn) RemoteAddr() net.Addr { return nil }

func (*mockConn) SetDeadline(time.Time) error { return nil }

func (*mockConn) SetReadDeadline(time.Time) error { return nil }

func (*mockConn) SetWriteDeadline(time.Time) error { return nil }
// TestSendFile_SmallFile checks that a small file, served through the
// fallback path, lands in the response body unchanged.
func TestSendFile_SmallFile(t *testing.T) {
	path := filepath.Join(t.TempDir(), "small.txt")
	want := []byte("small file content")
	if err := os.WriteFile(path, want, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	if err := SendFile(ctx, f, 0, int64(len(want))); err != nil {
		t.Errorf("SendFile failed: %v", err)
	}
	if got := ctx.Response.Body(); !bytes.Equal(got, want) {
		t.Errorf("Expected body %s, got %s", want, got)
	}
}
// TestSendFile_WithOffset requests bytes [5, 10) of a 16-byte file and
// expects exactly that window in the response body.
func TestSendFile_WithOffset(t *testing.T) {
	path := filepath.Join(t.TempDir(), "test.txt")
	data := []byte("0123456789ABCDEF")
	if err := os.WriteFile(path, data, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)

	const offset, length = 5, 5
	if err := SendFile(ctx, f, offset, length); err != nil {
		t.Errorf("SendFile failed: %v", err)
	}
	want := data[offset : offset+length]
	if got := ctx.Response.Body(); !bytes.Equal(got, want) {
		t.Errorf("Expected body %s, got %s", want, got)
	}
}
// TestSendFile_ZeroLength sends an empty file with length 0. It expects no
// error and an empty response body.
func TestSendFile_ZeroLength(t *testing.T) {
	path := filepath.Join(t.TempDir(), "empty.txt")
	if err := os.WriteFile(path, []byte{}, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	if err := SendFile(ctx, f, 0, 0); err != nil {
		t.Errorf("SendFile failed: %v", err)
	}
	if body := ctx.Response.Body(); len(body) != 0 {
		t.Errorf("Expected empty body, got %s", body)
	}
}
// TestGetNetConn is a smoke test. It only checks that getNetConn can be
// called on a RequestCtx initialised without a real connection; the returned
// value is discarded.
func TestGetNetConn(_ *testing.T) {
	ctx := new(fasthttp.RequestCtx)
	ctx.Init(&fasthttp.Request{}, nil, nil)
	_ = getNetConn(ctx)
}
// TestCopyFile_Error asks copyFile for a range that starts far past the end
// of a 12-byte file and expects an error.
func TestCopyFile_Error(t *testing.T) {
	path := filepath.Join(t.TempDir(), "test.txt")
	if err := os.WriteFile(path, []byte("test content"), 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	if err := copyFile(ctx, f, 1000, 10); err == nil {
		t.Error("Expected error for offset beyond file size")
	}
}
// TestLinuxSendfile_NilConn verifies that linuxSendfile returns an error when
// it is given a nil connection.
func TestLinuxSendfile_NilConn(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "test.txt")
	content := []byte("test")
	// Check the setup write. A silent failure here would otherwise show up
	// later as a misleading "Failed to open file" error.
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	if err := linuxSendfile(nil, file.Fd(), 0, int64(len(content))); err == nil {
		t.Error("Expected error for nil connection")
	}
}
// TestSendFile_LargeFile sends a 16 KiB file (twice MinSendfileSize), so the
// size is above the fallback threshold. EPIPE and ECONNRESET are the only
// errors tolerated.
//
// NOTE(review): the loopback TCP pair created below is never attached to ctx
// (ctx.Init receives no connection and a nil remote address). SendFile
// therefore presumably does not write to the client connection. Confirm
// whether this test actually reaches the sendfile(2) path.
func TestSendFile_LargeFile(t *testing.T) {
	path := filepath.Join(t.TempDir(), "large.bin")
	payload := make([]byte, 16*1024)
	_, _ = rand.Read(payload)
	if err := os.WriteFile(path, payload, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()
	accepted := make(chan net.Conn, 1)
	go func() {
		if c, acceptErr := ln.Accept(); acceptErr == nil {
			accepted <- c
		}
	}()

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer client.Close()
	if err := client.SetDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("Failed to set deadline: %v", err)
	}

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	ctx.Request.Header.SetMethod("GET")
	ctx.Request.SetRequestURI("/test")

	if err := SendFile(ctx, f, 0, int64(len(payload))); err != nil {
		t.Logf("SendFile returned: %v", err)
		switch err {
		case syscall.EPIPE, syscall.ECONNRESET:
			// Peer hang-ups are acceptable.
		default:
			t.Errorf("SendFile unexpected error: %v", err)
		}
	}

	select {
	case sc := <-accepted:
		if sc != nil {
			sc.Close()
		}
	case <-time.After(3 * time.Second):
	}
}
// TestSendFile_FullRange passes length -1 and expects the entire file in the
// response body.
func TestSendFile_FullRange(t *testing.T) {
	path := filepath.Join(t.TempDir(), "range.txt")
	want := []byte("0123456789ABCDEF")
	if err := os.WriteFile(path, want, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)

	const wholeFile = -1
	if err := SendFile(ctx, f, 0, wholeFile); err != nil {
		t.Errorf("SendFile failed: %v", err)
	}
	if got := ctx.Response.Body(); !bytes.Equal(got, want) {
		t.Errorf("Expected body %s, got %s", want, got)
	}
}
// TestSendFile_FileNotFound verifies that SendFile reports an error when the
// *os.File it receives can no longer be read.
//
// os.Open on a missing path fails before any *os.File exists, so that case
// would only ever skip. Instead, the test opens a real file and closes it
// before the call, leaving an unusable handle. The requested range matches
// the file's size, so the closed handle is the only reason to fail.
func TestSendFile_FileNotFound(t *testing.T) {
	content := []byte("test")
	path := filepath.Join(t.TempDir(), "gone.txt")
	if err := os.WriteFile(path, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	if err := file.Close(); err != nil {
		t.Fatalf("Failed to close file: %v", err)
	}

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	if err := SendFile(ctx, file, 0, int64(len(content))); err == nil {
		t.Error("Expected error for unreadable (closed) file")
	}
}
// TestCopyFile_EmptyFile covers two degenerate inputs for copyFile:
//   - an empty regular file, which must copy cleanly into an empty body;
//   - a directory handle, which cannot be read and must yield an error.
//
// Both subtests use the same temporary directory. The empty file inside it is
// the input for the first case, and the directory itself is the input for the
// second.
func TestCopyFile_EmptyFile(t *testing.T) {
	tmpDir := t.TempDir()
	tmpFile := filepath.Join(tmpDir, "empty.txt")
	if err := os.WriteFile(tmpFile, []byte{}, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}

	t.Run("empty file", func(t *testing.T) {
		file, err := os.Open(tmpFile)
		if err != nil {
			t.Fatalf("Failed to open file: %v", err)
		}
		defer func() { _ = file.Close() }()

		ctx := &fasthttp.RequestCtx{}
		ctx.Init(&fasthttp.Request{}, nil, nil)
		// Length 0 means "to end of file", which here is nothing at all.
		if err := copyFile(ctx, file, 0, 0); err != nil {
			t.Errorf("copyFile on empty file failed: %v", err)
		}
		if body := ctx.Response.Body(); len(body) != 0 {
			t.Errorf("Expected empty body, got %q", body)
		}
	})

	t.Run("directory", func(t *testing.T) {
		dir, err := os.Open(tmpDir)
		if err != nil {
			t.Fatalf("Failed to open dir: %v", err)
		}
		defer func() { _ = dir.Close() }()

		ctx := &fasthttp.RequestCtx{}
		ctx.Init(&fasthttp.Request{}, nil, nil)
		// A directory handle cannot be read like a regular file, so copyFile
		// must surface an error.
		if err := copyFile(ctx, dir, 0, 0); err == nil {
			t.Error("Expected error when copying directory")
		}
	})
}
// TestGetSocketFd_UnixConn checks that getSocketFile can extract an *os.File
// from a connected Unix-domain socket.
//
// getSocketFile hands the returned file to its caller, who must close it
// (linuxSendfile does). This test closes it too; otherwise every run leaks a
// descriptor.
func TestGetSocketFd_UnixConn(t *testing.T) {
	tmpDir := t.TempDir()
	socketPath := filepath.Join(tmpDir, "test.sock")
	ln, err := net.Listen("unix", socketPath)
	if err != nil {
		t.Fatalf("Failed to listen on unix socket: %v", err)
	}
	defer ln.Close()
	defer os.Remove(socketPath)

	connCh := make(chan net.Conn, 1)
	go func() {
		c, err := ln.Accept()
		if err == nil {
			connCh <- c
		}
	}()

	clientConn, err := net.Dial("unix", socketPath)
	if err != nil {
		t.Fatalf("Failed to dial unix socket: %v", err)
	}
	defer clientConn.Close()
	// Close the accepted server side on every exit path, including t.Fatal.
	defer func() {
		select {
		case sc := <-connCh:
			_ = sc.Close()
		case <-time.After(3 * time.Second):
		}
	}()

	f, err := getSocketFile(clientConn)
	if f != nil {
		defer func() { _ = f.Close() }()
	}
	if err != nil {
		t.Fatalf("getSocketFile failed: %v", err)
	}
	if f == nil {
		t.Fatal("Expected non-nil *os.File")
	}
}
// TestSendFile_OffsetBeyondFile asks SendFile for a range that starts past
// the end of the file and expects an error.
func TestSendFile_OffsetBeyondFile(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "test.txt")
	content := []byte("short content")
	// Fail loudly on setup errors instead of letting a later step misreport them.
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	// Offset 1000 is far beyond the 13-byte file.
	if err := SendFile(ctx, file, 1000, 10); err == nil {
		t.Error("Expected error when offset beyond file size")
	}
}
// TestSendFile_LengthOutOfRange requests 1000 bytes from a 5-byte file.
//
// The outcome is logged, not asserted. A file this small goes through the
// copyFile fallback, which may return an error for the short read.
// NOTE(review): pin the intended behaviour (error vs. truncated body) once the
// contract is settled.
func TestSendFile_LengthOutOfRange(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "test.txt")
	content := []byte("short")
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	if err := SendFile(ctx, file, 0, 1000); err != nil {
		t.Logf("SendFile returned: %v", err)
	}
}
// TestSendFile_AtMinBoundary sends a file of exactly MinSendfileSize bytes,
// the boundary value of SendFile's size threshold. Errors other than EPIPE
// and ECONNRESET are logged, not failed.
//
// NOTE(review): the loopback TCP pair is never attached to ctx (ctx.Init gets
// no connection), so SendFile presumably does not write to it. Confirm which
// code path this test really exercises.
func TestSendFile_AtMinBoundary(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "boundary.bin")
	content := make([]byte, MinSendfileSize)
	_, _ = rand.Read(content)
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()

	// Peer: accept one connection, read once, then hang up. The WaitGroup
	// makes sure the goroutine does not outlive the test.
	var wg sync.WaitGroup
	wg.Go(func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		buf := make([]byte, MinSendfileSize)
		_, _ = c.Read(buf) // best effort; the data is not inspected
	})

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer clientConn.Close()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	ctx.Request.Header.SetMethod("GET")
	ctx.Request.SetRequestURI("/test")
	err = SendFile(ctx, file, 0, int64(len(content)))
	if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
		t.Logf("SendFile returned: %v", err)
	}

	// Closing the client makes the peer's Read return, so Wait cannot hang.
	_ = clientConn.Close()
	wg.Wait()
}
// TestSendFile_JustBelowMin sends a file one byte shorter than
// MinSendfileSize. It is served by the fallback path and must arrive in the
// response body unchanged.
func TestSendFile_JustBelowMin(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "below.bin")
	content := make([]byte, MinSendfileSize-1)
	_, _ = rand.Read(content)
	// A silently failed write would make the comparison below meaningless.
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)
	if err := SendFile(ctx, file, 0, int64(len(content))); err != nil {
		t.Errorf("SendFile failed: %v", err)
	}
	if !bytes.Equal(ctx.Response.Body(), content) {
		t.Errorf("Body mismatch")
	}
}
// TestGetSocketFd_TCPConn checks that getSocketFile can extract an *os.File
// from a connected TCP socket.
//
// getSocketFile hands the returned file to its caller, who must close it
// (linuxSendfile does). This test closes it too; otherwise every run leaks a
// descriptor.
func TestGetSocketFd_TCPConn(t *testing.T) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()

	connCh := make(chan net.Conn, 1)
	go func() {
		c, err := ln.Accept()
		if err == nil {
			connCh <- c
		}
	}()

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer clientConn.Close()
	// Close the accepted server side on every exit path, including t.Fatal.
	defer func() {
		select {
		case sc := <-connCh:
			_ = sc.Close()
		case <-time.After(3 * time.Second):
		}
	}()

	f, err := getSocketFile(clientConn)
	if f != nil {
		defer func() { _ = f.Close() }()
	}
	if err != nil {
		t.Fatalf("getSocketFile failed for TCPConn: %v", err)
	}
	if f == nil {
		t.Fatal("Expected non-nil *os.File for TCPConn")
	}
}
// TestLinuxSendfile_WithTCPConn drives linuxSendfile over a real loopback TCP
// connection with a file larger than MinSendfileSize.
//
// A peer goroutine reads the full payload. When linuxSendfile reports
// success, the peer must have received exactly the file contents, so the
// data is compared rather than discarded. EPIPE and ECONNRESET are tolerated;
// other errors are logged (and the content check is skipped).
func TestLinuxSendfile_WithTCPConn(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "test.txt")
	content := make([]byte, MinSendfileSize+1024)
	for i := range content {
		content[i] = byte(i % 256)
	}
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()

	var (
		wg       sync.WaitGroup
		received []byte // written by the peer; read only after wg.Wait
	)
	wg.Go(func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		buf := make([]byte, len(content))
		n, _ := io.ReadFull(c, buf)
		received = buf[:n]
	})

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer clientConn.Close()
	if err := clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set deadline: %v", err)
	}

	err = linuxSendfile(clientConn, file.Fd(), 0, int64(len(content)))
	if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
		t.Logf("linuxSendfile returned: %v", err)
	}
	// Closing the sender delivers EOF to the peer, so Wait cannot hang.
	_ = clientConn.Close()
	wg.Wait()

	if err == nil && !bytes.Equal(received, content) {
		t.Errorf("linuxSendfile reported success but peer received %d bytes, want %d matching the file",
			len(received), len(content))
	}
}
// TestLinuxSendfile_SendfileError checks linuxSendfile's error handling by
// passing a connection of an unsupported type (mockConn). It expects an
// error.
func TestLinuxSendfile_SendfileError(t *testing.T) {
	data := []byte("test content for sendfile")
	path := filepath.Join(t.TempDir(), "test.txt")
	if err := os.WriteFile(path, data, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	unsupported := &mockConn{}
	if err := linuxSendfile(unsupported, f.Fd(), 0, int64(len(data))); err == nil {
		t.Error("Expected error for unsupported connection type")
	}
}
// TestLinuxSendfile_InvalidFileFd passes a file descriptor number (99999)
// that the test process is not expected to have open. It expects
// linuxSendfile to report an error.
func TestLinuxSendfile_InvalidFileFd(t *testing.T) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()

	accepted := make(chan net.Conn, 1)
	go func() {
		if c, acceptErr := ln.Accept(); acceptErr == nil {
			accepted <- c
		}
	}()

	client, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer client.Close()
	// Release the accepted server side (runs before client.Close).
	defer func() {
		select {
		case sc := <-accepted:
			if sc != nil {
				sc.Close()
			}
		case <-time.After(3 * time.Second):
		}
	}()

	const bogusFd = uintptr(99999)
	if err := linuxSendfile(client, bogusFd, 0, 1024); err == nil {
		t.Error("Expected error for invalid file descriptor")
	}
}
// TestLinuxSendfile_ZeroLength calls linuxSendfile with length 0 on a real
// loopback TCP connection and expects a nil error.
func TestLinuxSendfile_ZeroLength(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "test.txt")
	content := []byte("test")
	// Check the setup write rather than silently continuing with no file.
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()
	connCh := make(chan net.Conn, 1)
	go func() {
		c, err := ln.Accept()
		if err == nil {
			connCh <- c
		}
	}()

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer clientConn.Close()

	if err := linuxSendfile(clientConn, file.Fd(), 0, 0); err != nil {
		t.Errorf("Expected nil for zero length, got: %v", err)
	}

	select {
	case sc := <-connCh:
		if sc != nil {
			sc.Close()
		}
	case <-time.After(3 * time.Second):
	}
}
// TestLinuxSendfile_PartialTransfer sends a 32 KiB file to a peer that reads
// only once and then hangs up, so the sender may see a partial transfer.
//
// EPIPE and ECONNRESET are tolerated because the peer may close early. If
// linuxSendfile instead reports success, at least some data must have reached
// the peer, and that is asserted.
func TestLinuxSendfile_PartialTransfer(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "partial.bin")
	content := make([]byte, 32*1024)
	for i := range content {
		content[i] = byte(i % 256)
	}
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()

	var (
		wg       sync.WaitGroup
		received []byte // written by the peer; read only after wg.Wait
	)
	wg.Go(func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		buf := make([]byte, len(content))
		// A single Read on purpose: take whatever has arrived, then hang up.
		n, _ := c.Read(buf)
		received = buf[:n]
	})

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer clientConn.Close()
	if err := clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set deadline: %v", err)
	}

	err = linuxSendfile(clientConn, file.Fd(), 0, int64(len(content)))
	if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
		t.Logf("linuxSendfile returned: %v", err)
	}
	_ = clientConn.Close()
	wg.Wait()

	if err == nil && len(received) == 0 {
		t.Error("linuxSendfile reported success but the peer received no data")
	}
	if len(received) > 0 {
		t.Logf("Received %d bytes", len(received))
	}
}
// TestLinuxSendfile_WithOffset sends the first 8 KiB of a 16 KiB file over a
// loopback TCP connection.
//
// NOTE(review): despite the name, offset 0 is passed. linuxSendfile's offset
// argument is believed to be unused (the kernel tracks the file position), so
// only the length argument is exercised here. Confirm against linuxSendfile
// before adding a non-zero offset case.
func TestLinuxSendfile_WithOffset(t *testing.T) {
	tmpFile := filepath.Join(t.TempDir(), "offset.bin")
	content := make([]byte, 16*1024)
	for i := range content {
		content[i] = byte(i % 256)
	}
	if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	file, err := os.Open(tmpFile)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = file.Close() }()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	defer ln.Close()

	var wg sync.WaitGroup
	wg.Go(func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		buf := make([]byte, 8*1024)
		_, _ = c.Read(buf)
	})

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer clientConn.Close()
	// Without a working deadline a stuck send could hang the test.
	if err := clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
		t.Fatalf("Failed to set deadline: %v", err)
	}

	err = linuxSendfile(clientConn, file.Fd(), 0, 8*1024)
	if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
		t.Logf("linuxSendfile returned: %v", err)
	}
	_ = clientConn.Close()
	wg.Wait()
}
// TestSendFile_NegativeLength passes length -1. SendFile is expected to serve
// the whole file via the fallback path, so the body must equal the file
// contents.
func TestSendFile_NegativeLength(t *testing.T) {
	path := filepath.Join(t.TempDir(), "test.txt")
	want := []byte("test content")
	if err := os.WriteFile(path, want, 0o644); err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	f, err := os.Open(path)
	if err != nil {
		t.Fatalf("Failed to open file: %v", err)
	}
	defer func() { _ = f.Close() }()

	ctx := &fasthttp.RequestCtx{}
	ctx.Init(&fasthttp.Request{}, nil, nil)

	const negativeLength = -1
	if err := SendFile(ctx, f, 0, negativeLength); err != nil {
		t.Errorf("SendFile with negative length failed: %v", err)
	}
	if got := ctx.Response.Body(); !bytes.Equal(got, want) {
		t.Errorf("Expected body %s, got %s", want, got)
	}
}