refactor(handler): 拆分 sendfile 实现为平台特定文件

- Linux 平台保留 sendfile 系统调用的零拷贝实现
- 非 Linux 平台使用普通 IO fallback
- 分离平台特定测试到独立文件

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-12 11:45:46 +08:00
parent 6a6cfcd11c
commit 92ef122226
4 changed files with 406 additions and 202 deletions

View File

@ -1,18 +1,8 @@
//go:build !linux
// Package handler 提供 HTTP 请求处理器,包括路由、静态文件服务和零拷贝传输。
//
// 该文件包含零拷贝文件传输相关的核心逻辑,包括:
// - sendfile 系统调用的平台特定实现
// - 文件传输的 fallback 机制
// - 缓冲池管理
//
// 主要用途:
//
// 用于优化大文件传输性能,通过零拷贝技术减少 CPU 和内存开销。
//
// 注意事项:
// - Linux 平台使用 sendfile 系统调用
// - macOS 和 Windows 使用 fallback 方式
// - 小文件(< 8KB直接使用 io.Copy
// 该文件包含非 Linux 平台的 sendfile 实现(使用 fallback 方式)。
//
// 作者xfy
package handler
@ -21,7 +11,6 @@ import (
"io"
"net"
"os"
"runtime"
"syscall"
"github.com/valyala/fasthttp"
@ -31,9 +20,6 @@ const (
// MinSendfileSize 使用 sendfile 的最小文件大小8KB
// 小于该值的文件使用普通 io.Copy避免系统调用开销。
MinSendfileSize = 8 * 1024
// platformLinux Linux 平台标识符。
platformLinux = "linux"
)
// SendFile 零拷贝文件传输。
@ -52,8 +38,7 @@ const (
//
// 注意事项:
// - 小于 8KB 的文件使用普通 io.Copy
// - Linux 使用 sendfile 系统调用
// - macOS 和 Windows 使用 fallback 方式
// - 非 Linux 平台macOS、Windows使用 fallback 方式
//
// 使用示例:
//
@ -73,7 +58,7 @@ func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
return copyFile(ctx, file, offset, length)
}
// 根据平台选择 sendfile 实现
// 非 Linux 平台使用 fallback
err := platformSendfile(conn, file, offset, length)
if err != nil {
// sendfile 失败fallback 到 io.Copy
@ -124,9 +109,9 @@ func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
return err
}
// platformSendfile 平台特定的 sendfile 实现。
// platformSendfile 非 Linux 平台的 sendfile 实现。
//
// 根据运行平台选择合适的零拷贝传输方式
// macOS 和 Windows 不支持 sendfile 系统调用,返回 ENOTSUP 触发 fallback
//
// 参数:
// - conn: 目标网络连接
@ -135,90 +120,9 @@ func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
// - length: 传输长度
//
// 返回值:
// - error: 传输错误或不支持时返回 ENOTSUP
// - error: 始终返回 ENOTSUP表示不支持
func platformSendfile(conn net.Conn, file *os.File, offset, length int64) error {
switch runtime.GOOS {
case platformLinux:
return linuxSendfile(conn, file.Fd(), offset, length)
case "darwin":
// macOS sendfile 签名复杂,简化使用 fallback
return syscall.ENOTSUP
case "windows":
// Windows TransmitFile 需要特殊 API
return syscall.ENOTSUP
default:
return syscall.ENOTSUP
}
}
// linuxSendfile Linux sendfile 系统调用。
//
// 使用 Linux 特有的 sendfile 系统调用实现零拷贝传输。
//
// 参数:
// - conn: 目标网络连接
// - fileFd: 源文件描述符
// - offset: 文件起始偏移量
// - length: 传输长度
//
// 返回值:
// - error: 系统调用错误
func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
socketFd, err := getSocketFd(conn)
if err != nil {
return err
}
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
var sent int64
remain := length
for remain > 0 {
n, err := syscall.Sendfile(int(socketFd), int(fileFd), nil, int(remain))
if err != nil {
return err
}
if n == 0 {
break // EOF
}
sent += int64(n)
remain -= int64(n)
}
return nil
}
// getSocketFd 获取 socket 文件描述符。
//
// 从网络连接中提取底层的文件描述符,用于 sendfile 系统调用。
//
// 参数:
// - conn: 网络连接对象
//
// 返回值:
// - uintptr: 文件描述符
// - error: 不支持的连接类型或获取失败时返回错误
//
// 支持的连接类型:
// - *net.TCPConn
// - *net.UnixConn
func getSocketFd(conn net.Conn) (uintptr, error) {
switch c := conn.(type) {
case *net.TCPConn:
file, err := c.File()
if err != nil {
return 0, err
}
defer func() { _ = file.Close() }()
return file.Fd(), nil
case *net.UnixConn:
file, err := c.File()
if err != nil {
return 0, err
}
defer func() { _ = file.Close() }()
return file.Fd(), nil
default:
return 0, syscall.ENOTSUP
}
}
// macOS sendfile 签名复杂,简化使用 fallback
// Windows TransmitFile 需要特殊 API
return syscall.ENOTSUP
}

View File

@ -0,0 +1,135 @@
//go:build linux
// Package handler 提供 HTTP 请求处理器,包括路由、静态文件服务和零拷贝传输。
//
// 该文件包含 Linux 平台完整的 sendfile 实现(零拷贝 + 公共函数)。
//
// 作者xfy
package handler
import (
"io"
"net"
"os"
"syscall"
"github.com/valyala/fasthttp"
)
const (
// MinSendfileSize 使用 sendfile 的最小文件大小8KB
// 小于该值的文件使用普通 io.Copy避免系统调用开销。
MinSendfileSize = 8 * 1024
)
// platformLinux Linux 平台标识符。
const platformLinux = "linux"
// SendFile 零拷贝文件传输。
//
// 大文件使用系统调用直接从文件传输到 socket避免用户空间拷贝
// 从而减少 CPU 和内存开销,提升传输性能。
//
// 参数:
// - ctx: fasthttp 请求上下文,用于获取底层连接
// - file: 要传输的文件对象
// - offset: 文件起始偏移量(字节)
// - length: 传输长度(字节),-1 表示传输到文件末尾
//
// 返回值:
// - error: 传输过程中的错误
func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error {
// 小文件使用普通 io.Copy
if length < MinSendfileSize {
return copyFile(ctx, file, offset, length)
}
// 尝试获取 socket 文件描述符
conn := getNetConn(ctx)
if conn == nil {
return copyFile(ctx, file, offset, length)
}
// Linux 平台使用 sendfile 系统调用
err := linuxSendfile(conn, file.Fd(), offset, length)
if err != nil {
// sendfile 失败fallback 到 io.Copy
return copyFile(ctx, file, offset, length)
}
return nil
}
// getNetConn 从 fasthttp.RequestCtx 获取底层 net.Conn。
func getNetConn(ctx *fasthttp.RequestCtx) net.Conn {
return ctx.Conn()
}
// copyFile 普通文件拷贝fallback
func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error {
if offset > 0 {
if _, err := file.Seek(offset, io.SeekStart); err != nil {
return err
}
}
if length > 0 {
_, err := io.CopyN(ctx, file, length)
return err
}
_, err := io.Copy(ctx, file)
return err
}
// linuxSendfile Linux sendfile 系统调用。
//
// 使用 Linux 特有的 sendfile 系统调用实现零拷贝传输。
func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
socketFd, err := getSocketFd(conn)
if err != nil {
return err
}
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
var sent int64
remain := length
for remain > 0 {
n, err := syscall.Sendfile(int(socketFd), int(fileFd), nil, int(remain))
if err != nil {
return err
}
if n == 0 {
break // EOF
}
sent += int64(n)
remain -= int64(n)
}
return nil
}
// getSocketFd 获取 socket 文件描述符。
//
// 从网络连接中提取底层的文件描述符,用于 sendfile 系统调用。
func getSocketFd(conn net.Conn) (uintptr, error) {
switch c := conn.(type) {
case *net.TCPConn:
file, err := c.File()
if err != nil {
return 0, err
}
defer func() { _ = file.Close() }()
return file.Fd(), nil
case *net.UnixConn:
file, err := c.File()
if err != nil {
return 0, err
}
defer func() { _ = file.Close() }()
return file.Fd(), nil
default:
return 0, syscall.ENOTSUP
}
}

View File

@ -0,0 +1,248 @@
//go:build !linux
// Package handler 提供 Sendfile 功能的测试(非 Linux 平台)。
//
// 该文件测试非 Linux 平台的 Sendfile 功能。
//
// 作者xfy
package handler
import (
"bytes"
"io"
"os"
"path/filepath"
"syscall"
"testing"
"github.com/valyala/fasthttp"
)
func TestMinSendfileSize(t *testing.T) {
if MinSendfileSize != 8*1024 {
t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize)
}
}
// TestPlatformSendfile_NonLinux 测试非 Linux 平台的 sendfile 行为
func TestPlatformSendfile_NonLinux(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("test content")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
err = platformSendfile(nil, file, 0, int64(len(content)))
if err != syscall.ENOTSUP {
t.Errorf("expected ENOTSUP on non-Linux, got: %v", err)
}
}
// TestCopyFile 测试 copyFile fallback 函数
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("Hello, World! This is test content for copyFile.")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
tests := []struct {
name string
offset int64
length int64
wantLen int
wantErr bool
}{
{
name: "full file",
offset: 0,
length: 0,
wantLen: len(content),
wantErr: false,
},
{
name: "with length",
offset: 0,
length: 10,
wantLen: 10,
wantErr: false,
},
{
name: "with offset",
offset: 7,
length: 5,
wantLen: 5,
wantErr: false,
},
{
name: "offset beyond file",
offset: 1000,
length: 10,
wantLen: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _ = file.Seek(0, io.SeekStart)
ctx := &fasthttp.RequestCtx{}
err := copyFile(ctx, file, tt.offset, tt.length)
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
body := ctx.Response.Body()
if len(body) != tt.wantLen {
t.Errorf("expected body length %d, got %d", tt.wantLen, len(body))
}
if tt.wantLen > 0 && tt.length == 0 {
if string(body) != string(content[tt.offset:]) {
t.Errorf("body content mismatch")
}
}
}
})
}
}
// TestSendFile_SmallFile 测试小文件发送(使用 fallback
func TestSendFile_SmallFile(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "small.txt")
content := []byte("small file content")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
err = SendFile(ctx, file, 0, int64(len(content)))
if err != nil {
t.Errorf("SendFile failed: %v", err)
}
if !bytes.Equal(ctx.Response.Body(), content) {
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
}
}
// TestSendFile_WithOffset 测试带偏移量的文件发送
func TestSendFile_WithOffset(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("0123456789ABCDEF")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
err = SendFile(ctx, file, 5, 5)
if err != nil {
t.Errorf("SendFile failed: %v", err)
}
expected := content[5:10]
if !bytes.Equal(ctx.Response.Body(), expected) {
t.Errorf("Expected body %s, got %s", expected, ctx.Response.Body())
}
}
// TestSendFile_ZeroLength 测试零长度文件
func TestSendFile_ZeroLength(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "empty.txt")
if err := os.WriteFile(tmpFile, []byte{}, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
err = SendFile(ctx, file, 0, 0)
if err != nil {
t.Errorf("SendFile failed: %v", err)
}
if len(ctx.Response.Body()) != 0 {
t.Errorf("Expected empty body, got %s", ctx.Response.Body())
}
}
// TestGetNetConn 测试获取底层连接
func TestGetNetConn(_ *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
conn := getNetConn(ctx)
_ = conn
}
// TestCopyFile_Error 测试 copyFile 错误情况
func TestCopyFile_Error(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("test content")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
err = copyFile(ctx, file, 1000, 10)
if err == nil {
t.Error("Expected error for offset beyond file size")
}
}

View File

@ -1,16 +1,11 @@
// Package handler 提供 Sendfile 功能的测试。
//go:build linux
// Package handler 提供 Sendfile 功能的 Linux 平台测试。
//
// 该文件测试 Sendfile 模块的各项功能,包括:
// - 最小 Sendfile 大小
// - 平台 Sendfile 行为
// - 文件复制功能
// - 非 Linux 平台行为
// 该文件测试 Linux 平台特有的 Sendfile 功能,包括:
// - Linux sendfile 系统调用
// - 套接字文件描述符获取
// - 小文件发送
// - 带偏移量发送
// - 零长度文件
// - 网络连接获取
// - 错误处理
// - 小文件发送 fallback
//
// 作者xfy
package handler
@ -21,7 +16,6 @@ import (
"net"
"os"
"path/filepath"
"runtime"
"syscall"
"testing"
"time"
@ -35,27 +29,6 @@ func TestMinSendfileSize(t *testing.T) {
}
}
func TestPlatformSendfile(t *testing.T) {
// 创建临时文件
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("Hello, World! This is a test file for sendfile.")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
// 测试平台 sendfile小文件会 fallback 到 copyFile
// 由于没有真实的网络连接,这个测试主要验证不会崩溃
_ = platformSendfile(nil, file, 0, int64(len(content)))
}
// TestCopyFile 测试 copyFile fallback 函数
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()
@ -82,7 +55,7 @@ func TestCopyFile(t *testing.T) {
{
name: "full file",
offset: 0,
length: 0, // 0 means copy all
length: 0,
wantLen: len(content),
wantErr: false,
},
@ -105,16 +78,13 @@ func TestCopyFile(t *testing.T) {
offset: 1000,
length: 10,
wantLen: 0,
wantErr: true, // io.CopyN returns EOF error
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置文件位置
_, _ = file.Seek(0, io.SeekStart)
// 创建响应上下文
ctx := &fasthttp.RequestCtx{}
err := copyFile(ctx, file, tt.offset, tt.length)
@ -131,7 +101,6 @@ func TestCopyFile(t *testing.T) {
t.Errorf("expected body length %d, got %d", tt.wantLen, len(body))
}
if tt.wantLen > 0 && tt.length == 0 {
// 全量拷贝时验证内容
if string(body) != string(content[tt.offset:]) {
t.Errorf("body content mismatch")
}
@ -141,31 +110,6 @@ func TestCopyFile(t *testing.T) {
}
}
// TestPlatformSendfile_NonLinux 测试非 Linux 平台的 sendfile 行为
func TestPlatformSendfile_NonLinux(t *testing.T) {
if runtime.GOOS == platformLinux {
t.Skip("this test is for non-Linux platforms")
}
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("test content")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
file, err := os.Open(tmpFile)
if err != nil {
t.Fatalf("Failed to open file: %v", err)
}
defer func() { _ = file.Close() }()
err = platformSendfile(nil, file, 0, int64(len(content)))
if err != syscall.ENOTSUP {
t.Errorf("expected ENOTSUP on non-Linux, got: %v", err)
}
}
// TestGetSocketFd_NilConn 测试 nil 连接的情况
func TestGetSocketFd_NilConn(t *testing.T) {
_, err := getSocketFd(nil)
@ -176,7 +120,6 @@ func TestGetSocketFd_NilConn(t *testing.T) {
// TestGetSocketFd_UnsupportedType 测试不支持的连接类型
func TestGetSocketFd_UnsupportedType(t *testing.T) {
// 创建一个不支持的连接类型
conn := &mockConn{}
_, err := getSocketFd(conn)
if err != syscall.ENOTSUP {
@ -201,7 +144,6 @@ func TestSendFile_SmallFile(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "small.txt")
// 创建小文件 (< 8KB)
content := []byte("small file content")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
@ -221,7 +163,6 @@ func TestSendFile_SmallFile(t *testing.T) {
t.Errorf("SendFile failed: %v", err)
}
// 验证响应体
if !bytes.Equal(ctx.Response.Body(), content) {
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
}
@ -232,7 +173,7 @@ func TestSendFile_WithOffset(t *testing.T) {
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("0123456789ABCDEF") // 16 bytes
content := []byte("0123456789ABCDEF")
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
@ -246,13 +187,12 @@ func TestSendFile_WithOffset(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
// 从偏移量 5 开始,读取 5 字节
err = SendFile(ctx, file, 5, 5)
if err != nil {
t.Errorf("SendFile failed: %v", err)
}
expected := content[5:10] // "56789"
expected := content[5:10]
if !bytes.Equal(ctx.Response.Body(), expected) {
t.Errorf("Expected body %s, got %s", expected, ctx.Response.Body())
}
@ -291,28 +231,10 @@ func TestGetNetConn(_ *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
// fasthttp 会创建内部连接,所以这里测试能正常获取
conn := getNetConn(ctx)
// 主要验证不会崩溃
_ = conn
}
// TestSendFile_NilFile 测试空文件指针
func TestSendFile_NilFile(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
// 传入 nil 文件应该 panic 或返回错误
defer func() {
if r := recover(); r == nil {
// 没有 panic检查是否有错误返回
t.Log("expected panic did not occur for nil file")
}
}()
// 这个测试主要确保不会静默失败
}
// TestCopyFile_Error 测试 copyFile 错误情况
func TestCopyFile_Error(t *testing.T) {
tmpDir := t.TempDir()
@ -332,7 +254,6 @@ func TestCopyFile_Error(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Init(&fasthttp.Request{}, nil, nil)
// 测试偏移量超出文件大小
err = copyFile(ctx, file, 1000, 10)
if err == nil {
t.Error("Expected error for offset beyond file size")
@ -341,10 +262,6 @@ func TestCopyFile_Error(t *testing.T) {
// TestLinuxSendfile_NilConn 测试 linuxSendfile 空连接
func TestLinuxSendfile_NilConn(t *testing.T) {
if runtime.GOOS != platformLinux {
t.Skip("This test is for Linux only")
}
tmpDir := t.TempDir()
tmpFile := filepath.Join(tmpDir, "test.txt")
content := []byte("test")
@ -360,4 +277,4 @@ func TestLinuxSendfile_NilConn(t *testing.T) {
if err == nil {
t.Error("Expected error for nil connection")
}
}
}