- proxy/proxy.go: decrement connection count on dangerous path rejection (line 724) to prevent connection count leak - handler/sendfile_linux.go: return *os.File from getSocketFile and let linuxSendfile close it, fixing EBADF from deferred close in getSocketFd - proxy/websocket.go: return bufio.Reader from readWebSocketUpgradeResponse and wrap targetConn with bufferedConn to consume pre-buffered frame data, preventing first-frame loss - server/pool.go: use non-blocking send after starting new worker to avoid deadlock when queue is full - stream/stream.go: check stopCh on non-timeout UDP read errors to prevent infinite loop and shutdown deadlock - middleware/ratelimit: replace select-based close guard with sync.Once in StopCleanup to prevent double-close panic
948 lines
22 KiB
Go
948 lines
22 KiB
Go
//go:build linux
|
||
|
||
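// Package handler provides Linux-specific tests for the Sendfile functionality.
//
// This file tests the Sendfile features that are specific to Linux, including:
//   - the Linux sendfile system call
//   - obtaining the socket file descriptor
//   - the fallback used when sending small files
//
// Author: xfy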
|
||
package handler
|
||
|
||
import (
|
||
"bytes"
|
||
"io"
|
||
"math/rand"
|
||
"net"
|
||
"os"
|
||
"path/filepath"
|
||
"sync"
|
||
"syscall"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/valyala/fasthttp"
|
||
)
|
||
|
||
func TestMinSendfileSize(t *testing.T) {
|
||
if MinSendfileSize != 8*1024 {
|
||
t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize)
|
||
}
|
||
}
|
||
|
||
// TestCopyFile 测试 copyFile fallback 函数
|
||
func TestCopyFile(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
|
||
content := []byte("Hello, World! This is test content for copyFile.")
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
tests := []struct {
|
||
name string
|
||
offset int64
|
||
length int64
|
||
wantLen int
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "full file",
|
||
offset: 0,
|
||
length: 0,
|
||
wantLen: len(content),
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "with length",
|
||
offset: 0,
|
||
length: 10,
|
||
wantLen: 10,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "with offset",
|
||
offset: 7,
|
||
length: 5,
|
||
wantLen: 5,
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "offset beyond file",
|
||
offset: 1000,
|
||
length: 10,
|
||
wantLen: 0,
|
||
wantErr: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
_, _ = file.Seek(0, io.SeekStart)
|
||
ctx := &fasthttp.RequestCtx{}
|
||
|
||
err := copyFile(ctx, file, tt.offset, tt.length)
|
||
if tt.wantErr {
|
||
if err == nil {
|
||
t.Error("expected error, got nil")
|
||
}
|
||
} else {
|
||
if err != nil {
|
||
t.Errorf("unexpected error: %v", err)
|
||
}
|
||
body := ctx.Response.Body()
|
||
if len(body) != tt.wantLen {
|
||
t.Errorf("expected body length %d, got %d", tt.wantLen, len(body))
|
||
}
|
||
if tt.wantLen > 0 && tt.length == 0 {
|
||
if string(body) != string(content[tt.offset:]) {
|
||
t.Errorf("body content mismatch")
|
||
}
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestGetSocketFd_NilConn 测试 nil 连接的情况
|
||
func TestGetSocketFd_NilConn(t *testing.T) {
|
||
_, err := getSocketFile(nil)
|
||
if err == nil {
|
||
t.Error("expected error for nil connection")
|
||
}
|
||
}
|
||
|
||
// TestGetSocketFd_UnsupportedType 测试不支持的连接类型
|
||
func TestGetSocketFd_UnsupportedType(t *testing.T) {
|
||
conn := &mockConn{}
|
||
_, err := getSocketFile(conn)
|
||
if err != syscall.ENOTSUP {
|
||
t.Errorf("expected ENOTSUP for unsupported conn type, got: %v", err)
|
||
}
|
||
}
|
||
|
||
// mockConn is a stub connection that is neither a TCPConn nor a UnixConn.
//
// It is used to test how getSocketFile handles unsupported connection types.
// Every method is a no-op that returns zero values.
type mockConn struct{}

// Read reports zero bytes read and no error.
func (*mockConn) Read(_ []byte) (int, error) { return 0, nil }

// Write reports zero bytes written and no error.
func (*mockConn) Write(_ []byte) (int, error) { return 0, nil }

// Close always succeeds.
func (*mockConn) Close() error { return nil }

// LocalAddr returns a nil address.
func (*mockConn) LocalAddr() net.Addr { return nil }

// RemoteAddr returns a nil address.
func (*mockConn) RemoteAddr() net.Addr { return nil }

// SetDeadline ignores the deadline and always succeeds.
func (*mockConn) SetDeadline(_ time.Time) error { return nil }

// SetReadDeadline ignores the deadline and always succeeds.
func (*mockConn) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores the deadline and always succeeds.
func (*mockConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||
|
||
// TestSendFile_SmallFile 测试小文件发送(使用 fallback)
|
||
func TestSendFile_SmallFile(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "small.txt")
|
||
|
||
content := []byte("small file content")
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
err = SendFile(ctx, file, 0, int64(len(content)))
|
||
if err != nil {
|
||
t.Errorf("SendFile failed: %v", err)
|
||
}
|
||
|
||
if !bytes.Equal(ctx.Response.Body(), content) {
|
||
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
|
||
}
|
||
}
|
||
|
||
// TestSendFile_WithOffset 测试带偏移量的文件发送
|
||
func TestSendFile_WithOffset(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
|
||
content := []byte("0123456789ABCDEF")
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
err = SendFile(ctx, file, 5, 5)
|
||
if err != nil {
|
||
t.Errorf("SendFile failed: %v", err)
|
||
}
|
||
|
||
expected := content[5:10]
|
||
if !bytes.Equal(ctx.Response.Body(), expected) {
|
||
t.Errorf("Expected body %s, got %s", expected, ctx.Response.Body())
|
||
}
|
||
}
|
||
|
||
// TestSendFile_ZeroLength 测试零长度文件
|
||
func TestSendFile_ZeroLength(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "empty.txt")
|
||
|
||
if err := os.WriteFile(tmpFile, []byte{}, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
err = SendFile(ctx, file, 0, 0)
|
||
if err != nil {
|
||
t.Errorf("SendFile failed: %v", err)
|
||
}
|
||
|
||
if len(ctx.Response.Body()) != 0 {
|
||
t.Errorf("Expected empty body, got %s", ctx.Response.Body())
|
||
}
|
||
}
|
||
|
||
// TestGetNetConn 测试获取底层连接
|
||
func TestGetNetConn(_ *testing.T) {
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
conn := getNetConn(ctx)
|
||
_ = conn
|
||
}
|
||
|
||
// TestCopyFile_Error 测试 copyFile 错误情况
|
||
func TestCopyFile_Error(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
|
||
content := []byte("test content")
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
err = copyFile(ctx, file, 1000, 10)
|
||
if err == nil {
|
||
t.Error("Expected error for offset beyond file size")
|
||
}
|
||
}
|
||
|
||
// TestLinuxSendfile_NilConn 测试 linuxSendfile 空连接
|
||
func TestLinuxSendfile_NilConn(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
content := []byte("test")
|
||
_ = os.WriteFile(tmpFile, content, 0o644)
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
err = linuxSendfile(nil, file.Fd(), 0, int64(len(content)))
|
||
if err == nil {
|
||
t.Error("Expected error for nil connection")
|
||
}
|
||
}
|
||
|
||
// TestSendFile_LargeFile 测试大文件使用 sendfile 调用
|
||
func TestSendFile_LargeFile(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "large.bin")
|
||
|
||
// 创建超过 MinSendfileSize (8KB) 的文件
|
||
content := make([]byte, 16*1024) // 16KB
|
||
_, _ = rand.Read(content)
|
||
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
// 创建真正的 TCP 连接用于 sendfile
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
connCh := make(chan net.Conn, 1)
|
||
go func() {
|
||
c, err := ln.Accept()
|
||
if err == nil {
|
||
connCh <- c
|
||
}
|
||
}()
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
if err := clientConn.SetDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||
t.Fatalf("Failed to set deadline: %v", err)
|
||
}
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
ctx.Request.Header.SetMethod("GET")
|
||
ctx.Request.SetRequestURI("/test")
|
||
|
||
err = SendFile(ctx, file, 0, int64(len(content)))
|
||
if err != nil {
|
||
t.Logf("SendFile returned: %v", err)
|
||
if err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||
t.Errorf("SendFile unexpected error: %v", err)
|
||
}
|
||
}
|
||
|
||
select {
|
||
case sc := <-connCh:
|
||
if sc != nil {
|
||
sc.Close()
|
||
}
|
||
case <-time.After(3 * time.Second):
|
||
}
|
||
}
|
||
|
||
// TestSendFile_FullRange 测试传输完整文件范围
|
||
func TestSendFile_FullRange(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "range.txt")
|
||
|
||
content := []byte("0123456789ABCDEF")
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
// 传输整个文件
|
||
err = SendFile(ctx, file, 0, -1)
|
||
if err != nil {
|
||
t.Errorf("SendFile failed: %v", err)
|
||
}
|
||
|
||
if !bytes.Equal(ctx.Response.Body(), content) {
|
||
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
|
||
}
|
||
}
|
||
|
||
// TestSendFile_FileNotFound 测试文件不存在的情况
|
||
func TestSendFile_FileNotFound(t *testing.T) {
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
// 打开不存在的文件
|
||
file, err := os.Open("/nonexistent/file/test.txt")
|
||
if err != nil {
|
||
t.Skip("Skipping: file not found")
|
||
}
|
||
defer file.Close()
|
||
|
||
err = SendFile(ctx, file, 0, 100)
|
||
if err == nil {
|
||
t.Error("Expected error for non-existent file")
|
||
}
|
||
}
|
||
|
||
// TestCopyFile_EmptyFile 测试空文件拷贝
|
||
func TestCopyFile_EmptyFile(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "empty.txt")
|
||
|
||
if err := os.WriteFile(tmpFile, []byte{}, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpDir)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open dir: %v", err)
|
||
}
|
||
defer file.Close()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
// 尝试拷贝目录(应失败)
|
||
err = copyFile(ctx, file, 0, 0)
|
||
// 目录不可读,应返回错误
|
||
if err == nil {
|
||
t.Error("Expected error when copying directory")
|
||
}
|
||
}
|
||
|
||
// TestGetSocketFd_UnixConn 测试 UnixConn 获取 socket fd
|
||
func TestGetSocketFd_UnixConn(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
socketPath := filepath.Join(tmpDir, "test.sock")
|
||
|
||
ln, err := net.Listen("unix", socketPath)
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen on unix socket: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
defer os.Remove(socketPath)
|
||
|
||
connCh := make(chan net.Conn, 1)
|
||
go func() {
|
||
c, err := ln.Accept()
|
||
if err == nil {
|
||
connCh <- c
|
||
}
|
||
}()
|
||
|
||
clientConn, err := net.Dial("unix", socketPath)
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial unix socket: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
f, err := getSocketFile(clientConn)
|
||
if err != nil {
|
||
t.Errorf("getSocketFile failed: %v", err)
|
||
}
|
||
if f == nil {
|
||
t.Error("Expected non-zero fd")
|
||
}
|
||
|
||
select {
|
||
case sc := <-connCh:
|
||
if sc != nil {
|
||
sc.Close()
|
||
}
|
||
case <-time.After(3 * time.Second):
|
||
}
|
||
}
|
||
|
||
// TestSendFile_OffsetBeyondFile 测试偏移量超出文件大小
|
||
func TestSendFile_OffsetBeyondFile(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
content := []byte("short content")
|
||
_ = os.WriteFile(tmpFile, content, 0o644)
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
// 偏移量超出文件大小
|
||
err = SendFile(ctx, file, 1000, 10)
|
||
if err == nil {
|
||
t.Error("Expected error when offset beyond file size")
|
||
}
|
||
}
|
||
|
||
// TestSendFile_LengthOutOfRange 测试长度超出文件范围
|
||
func TestSendFile_LengthOutOfRange(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
content := []byte("short")
|
||
_ = os.WriteFile(tmpFile, content, 0o644)
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
// 请求长度超出文件大小
|
||
err = SendFile(ctx, file, 0, 1000)
|
||
if err != nil {
|
||
// 小文件会使用 copyFile,可能返回错误
|
||
t.Logf("SendFile returned: %v", err)
|
||
}
|
||
}
|
||
|
||
// TestSendFile_AtMinBoundary 测试刚好等于 MinSendfileSize 的文件
|
||
func TestSendFile_AtMinBoundary(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "boundary.bin")
|
||
|
||
// 创建刚好等于 MinSendfileSize 的文件
|
||
content := make([]byte, MinSendfileSize)
|
||
_, _ = rand.Read(content)
|
||
_ = os.WriteFile(tmpFile, content, 0o644)
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
// 创建监听器
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
go func() {
|
||
c, err := ln.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
buf := make([]byte, MinSendfileSize)
|
||
c.Read(buf)
|
||
c.Close()
|
||
}()
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
ctx.Request.Header.SetMethod("GET")
|
||
ctx.Request.SetRequestURI("/test")
|
||
|
||
err = SendFile(ctx, file, 0, int64(len(content)))
|
||
if err != nil {
|
||
if err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||
t.Logf("SendFile returned: %v", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
// TestSendFile_JustBelowMin 测试刚好小于 MinSendfileSize 的文件(使用 fallback)
|
||
func TestSendFile_JustBelowMin(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "below.bin")
|
||
|
||
// 创建略小于 MinSendfileSize 的文件
|
||
content := make([]byte, MinSendfileSize-1)
|
||
_, _ = rand.Read(content)
|
||
_ = os.WriteFile(tmpFile, content, 0o644)
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
err = SendFile(ctx, file, 0, int64(len(content)))
|
||
if err != nil {
|
||
t.Errorf("SendFile failed: %v", err)
|
||
}
|
||
|
||
if !bytes.Equal(ctx.Response.Body(), content) {
|
||
t.Errorf("Body mismatch")
|
||
}
|
||
}
|
||
|
||
// TestGetSocketFd_TCPConn 测试 TCPConn 获取 socket fd
|
||
func TestGetSocketFd_TCPConn(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
connCh := make(chan net.Conn, 1)
|
||
go func() {
|
||
c, err := ln.Accept()
|
||
if err == nil {
|
||
connCh <- c
|
||
}
|
||
}()
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
f, err := getSocketFile(clientConn)
|
||
if err != nil {
|
||
t.Errorf("getSocketFile failed for TCPConn: %v", err)
|
||
}
|
||
if f == nil {
|
||
t.Error("Expected non-zero fd for TCPConn")
|
||
}
|
||
|
||
select {
|
||
case sc := <-connCh:
|
||
if sc != nil {
|
||
sc.Close()
|
||
}
|
||
case <-time.After(3 * time.Second):
|
||
}
|
||
}
|
||
|
||
// TestLinuxSendfile_WithTCPConn 测试 linuxSendfile 使用真实 TCP 连接
|
||
func TestLinuxSendfile_WithTCPConn(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
|
||
// 创建大于 MinSendfileSize 的文件
|
||
content := make([]byte, MinSendfileSize+1024)
|
||
for i := range content {
|
||
content[i] = byte(i % 256)
|
||
}
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
// 创建 TCP 连接
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
var wg sync.WaitGroup
|
||
wg.Go(func() {
|
||
c, err := ln.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
buf := make([]byte, len(content))
|
||
_, _ = io.ReadFull(c, buf)
|
||
c.Close()
|
||
})
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
// 设置写超时
|
||
if err := clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||
t.Fatalf("Failed to set deadline: %v", err)
|
||
}
|
||
|
||
// 调用 linuxSendfile
|
||
err = linuxSendfile(clientConn, file.Fd(), 0, int64(len(content)))
|
||
if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||
t.Logf("linuxSendfile returned: %v", err)
|
||
}
|
||
|
||
clientConn.Close()
|
||
wg.Wait()
|
||
}
|
||
|
||
// TestLinuxSendfile_SendfileError 测试 sendfile 错误处理
|
||
func TestLinuxSendfile_SendfileError(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
content := []byte("test content for sendfile")
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
// 使用 mockConn 测试不支持的连接类型
|
||
conn := &mockConn{}
|
||
err = linuxSendfile(conn, file.Fd(), 0, int64(len(content)))
|
||
if err == nil {
|
||
t.Error("Expected error for unsupported connection type")
|
||
}
|
||
}
|
||
|
||
// TestLinuxSendfile_InvalidFileFd 测试无效文件描述符
|
||
func TestLinuxSendfile_InvalidFileFd(t *testing.T) {
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
connCh := make(chan net.Conn, 1)
|
||
go func() {
|
||
c, err := ln.Accept()
|
||
if err == nil {
|
||
connCh <- c
|
||
}
|
||
}()
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
err = linuxSendfile(clientConn, uintptr(99999), 0, 1024)
|
||
if err == nil {
|
||
t.Error("Expected error for invalid file descriptor")
|
||
}
|
||
|
||
select {
|
||
case sc := <-connCh:
|
||
if sc != nil {
|
||
sc.Close()
|
||
}
|
||
case <-time.After(3 * time.Second):
|
||
}
|
||
}
|
||
|
||
// TestLinuxSendfile_ZeroLength 测试零长度传输
|
||
func TestLinuxSendfile_ZeroLength(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
content := []byte("test")
|
||
_ = os.WriteFile(tmpFile, content, 0o644)
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
connCh := make(chan net.Conn, 1)
|
||
go func() {
|
||
c, err := ln.Accept()
|
||
if err == nil {
|
||
connCh <- c
|
||
}
|
||
}()
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
err = linuxSendfile(clientConn, file.Fd(), 0, 0)
|
||
if err != nil {
|
||
t.Errorf("Expected nil for zero length, got: %v", err)
|
||
}
|
||
|
||
select {
|
||
case sc := <-connCh:
|
||
if sc != nil {
|
||
sc.Close()
|
||
}
|
||
case <-time.After(3 * time.Second):
|
||
}
|
||
}
|
||
|
||
// TestLinuxSendfile_PartialTransfer 测试部分传输后继续
|
||
func TestLinuxSendfile_PartialTransfer(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "partial.bin")
|
||
|
||
// 创建大文件
|
||
content := make([]byte, 32*1024)
|
||
for i := range content {
|
||
content[i] = byte(i % 256)
|
||
}
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
var received []byte
|
||
var wg sync.WaitGroup
|
||
wg.Go(func() {
|
||
c, err := ln.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
buf := make([]byte, len(content))
|
||
n, _ := c.Read(buf)
|
||
received = buf[:n]
|
||
c.Close()
|
||
})
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||
|
||
// 调用 linuxSendfile 传输整个文件
|
||
err = linuxSendfile(clientConn, file.Fd(), 0, int64(len(content)))
|
||
// EPIPE/ECONNRESET 是可接受的,因为服务器可能提前关闭
|
||
if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||
t.Logf("linuxSendfile returned: %v", err)
|
||
}
|
||
|
||
clientConn.Close()
|
||
wg.Wait()
|
||
|
||
// 验证至少传输了部分数据
|
||
if len(received) > 0 {
|
||
t.Logf("Received %d bytes", len(received))
|
||
}
|
||
}
|
||
|
||
// TestLinuxSendfile_WithOffset 测试带偏移量的传输
|
||
func TestLinuxSendfile_WithOffset(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "offset.bin")
|
||
|
||
content := make([]byte, 16*1024)
|
||
for i := range content {
|
||
content[i] = byte(i % 256)
|
||
}
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||
if err != nil {
|
||
t.Fatalf("Failed to listen: %v", err)
|
||
}
|
||
defer ln.Close()
|
||
|
||
var wg sync.WaitGroup
|
||
wg.Go(func() {
|
||
c, err := ln.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
buf := make([]byte, 8*1024)
|
||
_, _ = c.Read(buf)
|
||
c.Close()
|
||
})
|
||
|
||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||
if err != nil {
|
||
t.Fatalf("Failed to dial: %v", err)
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||
|
||
// 注意:linuxSendfile 的 offset 参数未使用(由内核处理)
|
||
// 这里测试 length 参数
|
||
err = linuxSendfile(clientConn, file.Fd(), 0, 8*1024)
|
||
if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||
t.Logf("linuxSendfile returned: %v", err)
|
||
}
|
||
|
||
clientConn.Close()
|
||
wg.Wait()
|
||
}
|
||
|
||
// TestSendFile_NegativeLength 测试负长度参数
|
||
func TestSendFile_NegativeLength(t *testing.T) {
|
||
tmpDir := t.TempDir()
|
||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||
content := []byte("test content")
|
||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||
t.Fatalf("Failed to create temp file: %v", err)
|
||
}
|
||
|
||
file, err := os.Open(tmpFile)
|
||
if err != nil {
|
||
t.Fatalf("Failed to open file: %v", err)
|
||
}
|
||
defer func() { _ = file.Close() }()
|
||
|
||
ctx := &fasthttp.RequestCtx{}
|
||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||
|
||
// 负长度应该使用 fallback
|
||
err = SendFile(ctx, file, 0, -1)
|
||
if err != nil {
|
||
t.Errorf("SendFile with negative length failed: %v", err)
|
||
}
|
||
|
||
// 应该传输整个文件
|
||
if !bytes.Equal(ctx.Response.Body(), content) {
|
||
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
|
||
}
|
||
}
|