refactor(handler): 拆分 sendfile 实现为平台特定文件
- Linux 平台保留 sendfile 系统调用的零拷贝实现 - 非 Linux 平台使用普通 IO fallback - 分离平台特定测试到独立文件 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6a6cfcd11c
commit
92ef122226
@ -1,18 +1,8 @@
|
||||
//go:build !linux
|
||||
|
||||
// Package handler 提供 HTTP 请求处理器,包括路由、静态文件服务和零拷贝传输。
|
||||
//
|
||||
// 该文件包含零拷贝文件传输相关的核心逻辑,包括:
|
||||
// - sendfile 系统调用的平台特定实现
|
||||
// - 文件传输的 fallback 机制
|
||||
// - 缓冲池管理
|
||||
//
|
||||
// 主要用途:
|
||||
//
|
||||
// 用于优化大文件传输性能,通过零拷贝技术减少 CPU 和内存开销。
|
||||
//
|
||||
// 注意事项:
|
||||
// - Linux 平台使用 sendfile 系统调用
|
||||
// - macOS 和 Windows 使用 fallback 方式
|
||||
// - 小文件(< 8KB)直接使用 io.Copy
|
||||
// 该文件包含非 Linux 平台的 sendfile 实现(使用 fallback 方式)。
|
||||
//
|
||||
// 作者:xfy
|
||||
package handler
|
||||
@ -21,7 +11,6 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
@ -31,9 +20,6 @@ const (
|
||||
// MinSendfileSize 使用 sendfile 的最小文件大小(8KB)。
|
||||
// 小于该值的文件使用普通 io.Copy,避免系统调用开销。
|
||||
MinSendfileSize = 8 * 1024
|
||||
|
||||
// platformLinux Linux 平台标识符。
|
||||
platformLinux = "linux"
|
||||
)
|
||||
|
||||
// SendFile 零拷贝文件传输。
|
||||
@ -52,8 +38,7 @@ const (
|
||||
//
|
||||
// 注意事项:
|
||||
// - 小于 8KB 的文件使用普通 io.Copy
|
||||
// - Linux 使用 sendfile 系统调用
|
||||
// - macOS 和 Windows 使用 fallback 方式
|
||||
// - 非 Linux 平台(macOS、Windows)使用 fallback 方式
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
@ -73,7 +58,7 @@ func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
|
||||
return copyFile(ctx, file, offset, length)
|
||||
}
|
||||
|
||||
// 根据平台选择 sendfile 实现
|
||||
// 非 Linux 平台使用 fallback
|
||||
err := platformSendfile(conn, file, offset, length)
|
||||
if err != nil {
|
||||
// sendfile 失败,fallback 到 io.Copy
|
||||
@ -124,9 +109,9 @@ func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
|
||||
return err
|
||||
}
|
||||
|
||||
// platformSendfile 平台特定的 sendfile 实现。
|
||||
// platformSendfile 非 Linux 平台的 sendfile 实现。
|
||||
//
|
||||
// 根据运行平台选择合适的零拷贝传输方式。
|
||||
// macOS 和 Windows 不支持 sendfile 系统调用,返回 ENOTSUP 触发 fallback。
|
||||
//
|
||||
// 参数:
|
||||
// - conn: 目标网络连接
|
||||
@ -135,90 +120,9 @@ func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) err
|
||||
// - length: 传输长度
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 传输错误或不支持时返回 ENOTSUP
|
||||
// - error: 始终返回 ENOTSUP,表示不支持
|
||||
func platformSendfile(conn net.Conn, file *os.File, offset, length int64) error {
|
||||
switch runtime.GOOS {
|
||||
case platformLinux:
|
||||
return linuxSendfile(conn, file.Fd(), offset, length)
|
||||
case "darwin":
|
||||
// macOS sendfile 签名复杂,简化使用 fallback
|
||||
return syscall.ENOTSUP
|
||||
case "windows":
|
||||
// Windows TransmitFile 需要特殊 API
|
||||
return syscall.ENOTSUP
|
||||
default:
|
||||
return syscall.ENOTSUP
|
||||
}
|
||||
}
|
||||
|
||||
// linuxSendfile Linux sendfile 系统调用。
|
||||
//
|
||||
// 使用 Linux 特有的 sendfile 系统调用实现零拷贝传输。
|
||||
//
|
||||
// 参数:
|
||||
// - conn: 目标网络连接
|
||||
// - fileFd: 源文件描述符
|
||||
// - offset: 文件起始偏移量
|
||||
// - length: 传输长度
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 系统调用错误
|
||||
func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
|
||||
socketFd, err := getSocketFd(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
|
||||
var sent int64
|
||||
remain := length
|
||||
|
||||
for remain > 0 {
|
||||
n, err := syscall.Sendfile(int(socketFd), int(fileFd), nil, int(remain))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
break // EOF
|
||||
}
|
||||
sent += int64(n)
|
||||
remain -= int64(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSocketFd 获取 socket 文件描述符。
|
||||
//
|
||||
// 从网络连接中提取底层的文件描述符,用于 sendfile 系统调用。
|
||||
//
|
||||
// 参数:
|
||||
// - conn: 网络连接对象
|
||||
//
|
||||
// 返回值:
|
||||
// - uintptr: 文件描述符
|
||||
// - error: 不支持的连接类型或获取失败时返回错误
|
||||
//
|
||||
// 支持的连接类型:
|
||||
// - *net.TCPConn
|
||||
// - *net.UnixConn
|
||||
func getSocketFd(conn net.Conn) (uintptr, error) {
|
||||
switch c := conn.(type) {
|
||||
case *net.TCPConn:
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
return file.Fd(), nil
|
||||
case *net.UnixConn:
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
return file.Fd(), nil
|
||||
default:
|
||||
return 0, syscall.ENOTSUP
|
||||
}
|
||||
}
|
||||
// macOS sendfile 签名复杂,简化使用 fallback
|
||||
// Windows TransmitFile 需要特殊 API
|
||||
return syscall.ENOTSUP
|
||||
}
|
||||
135
internal/handler/sendfile_linux.go
Normal file
135
internal/handler/sendfile_linux.go
Normal file
@ -0,0 +1,135 @@
|
||||
//go:build linux
|
||||
|
||||
// Package handler 提供 HTTP 请求处理器,包括路由、静态文件服务和零拷贝传输。
|
||||
//
|
||||
// 该文件包含 Linux 平台完整的 sendfile 实现(零拷贝 + 公共函数)。
|
||||
//
|
||||
// 作者:xfy
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// MinSendfileSize 使用 sendfile 的最小文件大小(8KB)。
|
||||
// 小于该值的文件使用普通 io.Copy,避免系统调用开销。
|
||||
MinSendfileSize = 8 * 1024
|
||||
)
|
||||
|
||||
// platformLinux Linux 平台标识符。
|
||||
const platformLinux = "linux"
|
||||
|
||||
// SendFile 零拷贝文件传输。
|
||||
//
|
||||
// 大文件使用系统调用直接从文件传输到 socket,避免用户空间拷贝,
|
||||
// 从而减少 CPU 和内存开销,提升传输性能。
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: fasthttp 请求上下文,用于获取底层连接
|
||||
// - file: 要传输的文件对象
|
||||
// - offset: 文件起始偏移量(字节)
|
||||
// - length: 传输长度(字节),-1 表示传输到文件末尾
|
||||
//
|
||||
// 返回值:
|
||||
// - error: 传输过程中的错误
|
||||
func SendFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error {
|
||||
// 小文件使用普通 io.Copy
|
||||
if length < MinSendfileSize {
|
||||
return copyFile(ctx, file, offset, length)
|
||||
}
|
||||
|
||||
// 尝试获取 socket 文件描述符
|
||||
conn := getNetConn(ctx)
|
||||
if conn == nil {
|
||||
return copyFile(ctx, file, offset, length)
|
||||
}
|
||||
|
||||
// Linux 平台使用 sendfile 系统调用
|
||||
err := linuxSendfile(conn, file.Fd(), offset, length)
|
||||
if err != nil {
|
||||
// sendfile 失败,fallback 到 io.Copy
|
||||
return copyFile(ctx, file, offset, length)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getNetConn 从 fasthttp.RequestCtx 获取底层 net.Conn。
|
||||
func getNetConn(ctx *fasthttp.RequestCtx) net.Conn {
|
||||
return ctx.Conn()
|
||||
}
|
||||
|
||||
// copyFile 普通文件拷贝(fallback)。
|
||||
func copyFile(ctx *fasthttp.RequestCtx, file *os.File, offset, length int64) error {
|
||||
if offset > 0 {
|
||||
if _, err := file.Seek(offset, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if length > 0 {
|
||||
_, err := io.CopyN(ctx, file, length)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := io.Copy(ctx, file)
|
||||
return err
|
||||
}
|
||||
|
||||
// linuxSendfile Linux sendfile 系统调用。
|
||||
//
|
||||
// 使用 Linux 特有的 sendfile 系统调用实现零拷贝传输。
|
||||
func linuxSendfile(conn net.Conn, fileFd uintptr, _, length int64) error {
|
||||
socketFd, err := getSocketFd(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Linux sendfile: sendfile(out_fd, in_fd, offset, count)
|
||||
var sent int64
|
||||
remain := length
|
||||
|
||||
for remain > 0 {
|
||||
n, err := syscall.Sendfile(int(socketFd), int(fileFd), nil, int(remain))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n == 0 {
|
||||
break // EOF
|
||||
}
|
||||
sent += int64(n)
|
||||
remain -= int64(n)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSocketFd 获取 socket 文件描述符。
|
||||
//
|
||||
// 从网络连接中提取底层的文件描述符,用于 sendfile 系统调用。
|
||||
func getSocketFd(conn net.Conn) (uintptr, error) {
|
||||
switch c := conn.(type) {
|
||||
case *net.TCPConn:
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
return file.Fd(), nil
|
||||
case *net.UnixConn:
|
||||
file, err := c.File()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
return file.Fd(), nil
|
||||
default:
|
||||
return 0, syscall.ENOTSUP
|
||||
}
|
||||
}
|
||||
248
internal/handler/sendfile_other_test.go
Normal file
248
internal/handler/sendfile_other_test.go
Normal file
@ -0,0 +1,248 @@
|
||||
//go:build !linux
|
||||
|
||||
// Package handler 提供 Sendfile 功能的测试(非 Linux 平台)。
|
||||
//
|
||||
// 该文件测试非 Linux 平台的 Sendfile 功能。
|
||||
//
|
||||
// 作者:xfy
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func TestMinSendfileSize(t *testing.T) {
|
||||
if MinSendfileSize != 8*1024 {
|
||||
t.Errorf("Expected MinSendfileSize 8KB, got %d", MinSendfileSize)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlatformSendfile_NonLinux 测试非 Linux 平台的 sendfile 行为
|
||||
func TestPlatformSendfile_NonLinux(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
content := []byte("test content")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
err = platformSendfile(nil, file, 0, int64(len(content)))
|
||||
if err != syscall.ENOTSUP {
|
||||
t.Errorf("expected ENOTSUP on non-Linux, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyFile 测试 copyFile fallback 函数
|
||||
func TestCopyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := []byte("Hello, World! This is test content for copyFile.")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
offset int64
|
||||
length int64
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "full file",
|
||||
offset: 0,
|
||||
length: 0,
|
||||
wantLen: len(content),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "with length",
|
||||
offset: 0,
|
||||
length: 10,
|
||||
wantLen: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "with offset",
|
||||
offset: 7,
|
||||
length: 5,
|
||||
wantLen: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "offset beyond file",
|
||||
offset: 1000,
|
||||
length: 10,
|
||||
wantLen: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _ = file.Seek(0, io.SeekStart)
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
err := copyFile(ctx, file, tt.offset, tt.length)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
body := ctx.Response.Body()
|
||||
if len(body) != tt.wantLen {
|
||||
t.Errorf("expected body length %d, got %d", tt.wantLen, len(body))
|
||||
}
|
||||
if tt.wantLen > 0 && tt.length == 0 {
|
||||
if string(body) != string(content[tt.offset:]) {
|
||||
t.Errorf("body content mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendFile_SmallFile 测试小文件发送(使用 fallback)
|
||||
func TestSendFile_SmallFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "small.txt")
|
||||
|
||||
content := []byte("small file content")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
err = SendFile(ctx, file, 0, int64(len(content)))
|
||||
if err != nil {
|
||||
t.Errorf("SendFile failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(ctx.Response.Body(), content) {
|
||||
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendFile_WithOffset 测试带偏移量的文件发送
|
||||
func TestSendFile_WithOffset(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := []byte("0123456789ABCDEF")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
err = SendFile(ctx, file, 5, 5)
|
||||
if err != nil {
|
||||
t.Errorf("SendFile failed: %v", err)
|
||||
}
|
||||
|
||||
expected := content[5:10]
|
||||
if !bytes.Equal(ctx.Response.Body(), expected) {
|
||||
t.Errorf("Expected body %s, got %s", expected, ctx.Response.Body())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendFile_ZeroLength 测试零长度文件
|
||||
func TestSendFile_ZeroLength(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "empty.txt")
|
||||
|
||||
if err := os.WriteFile(tmpFile, []byte{}, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
err = SendFile(ctx, file, 0, 0)
|
||||
if err != nil {
|
||||
t.Errorf("SendFile failed: %v", err)
|
||||
}
|
||||
|
||||
if len(ctx.Response.Body()) != 0 {
|
||||
t.Errorf("Expected empty body, got %s", ctx.Response.Body())
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetNetConn 测试获取底层连接
|
||||
func TestGetNetConn(_ *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
conn := getNetConn(ctx)
|
||||
_ = conn
|
||||
}
|
||||
|
||||
// TestCopyFile_Error 测试 copyFile 错误情况
|
||||
func TestCopyFile_Error(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := []byte("test content")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
err = copyFile(ctx, file, 1000, 10)
|
||||
if err == nil {
|
||||
t.Error("Expected error for offset beyond file size")
|
||||
}
|
||||
}
|
||||
@ -1,16 +1,11 @@
|
||||
// Package handler 提供 Sendfile 功能的测试。
|
||||
//go:build linux
|
||||
|
||||
// Package handler 提供 Sendfile 功能的 Linux 平台测试。
|
||||
//
|
||||
// 该文件测试 Sendfile 模块的各项功能,包括:
|
||||
// - 最小 Sendfile 大小
|
||||
// - 平台 Sendfile 行为
|
||||
// - 文件复制功能
|
||||
// - 非 Linux 平台行为
|
||||
// 该文件测试 Linux 平台特有的 Sendfile 功能,包括:
|
||||
// - Linux sendfile 系统调用
|
||||
// - 套接字文件描述符获取
|
||||
// - 小文件发送
|
||||
// - 带偏移量发送
|
||||
// - 零长度文件
|
||||
// - 网络连接获取
|
||||
// - 错误处理
|
||||
// - 小文件发送 fallback
|
||||
//
|
||||
// 作者:xfy
|
||||
package handler
|
||||
@ -21,7 +16,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@ -35,27 +29,6 @@ func TestMinSendfileSize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlatformSendfile(t *testing.T) {
|
||||
// 创建临时文件
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := []byte("Hello, World! This is a test file for sendfile.")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
// 测试平台 sendfile(小文件会 fallback 到 copyFile)
|
||||
// 由于没有真实的网络连接,这个测试主要验证不会崩溃
|
||||
_ = platformSendfile(nil, file, 0, int64(len(content)))
|
||||
}
|
||||
|
||||
// TestCopyFile 测试 copyFile fallback 函数
|
||||
func TestCopyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
@ -82,7 +55,7 @@ func TestCopyFile(t *testing.T) {
|
||||
{
|
||||
name: "full file",
|
||||
offset: 0,
|
||||
length: 0, // 0 means copy all
|
||||
length: 0,
|
||||
wantLen: len(content),
|
||||
wantErr: false,
|
||||
},
|
||||
@ -105,16 +78,13 @@ func TestCopyFile(t *testing.T) {
|
||||
offset: 1000,
|
||||
length: 10,
|
||||
wantLen: 0,
|
||||
wantErr: true, // io.CopyN returns EOF error
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 重置文件位置
|
||||
_, _ = file.Seek(0, io.SeekStart)
|
||||
|
||||
// 创建响应上下文
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
err := copyFile(ctx, file, tt.offset, tt.length)
|
||||
@ -131,7 +101,6 @@ func TestCopyFile(t *testing.T) {
|
||||
t.Errorf("expected body length %d, got %d", tt.wantLen, len(body))
|
||||
}
|
||||
if tt.wantLen > 0 && tt.length == 0 {
|
||||
// 全量拷贝时验证内容
|
||||
if string(body) != string(content[tt.offset:]) {
|
||||
t.Errorf("body content mismatch")
|
||||
}
|
||||
@ -141,31 +110,6 @@ func TestCopyFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlatformSendfile_NonLinux 测试非 Linux 平台的 sendfile 行为
|
||||
func TestPlatformSendfile_NonLinux(t *testing.T) {
|
||||
if runtime.GOOS == platformLinux {
|
||||
t.Skip("this test is for non-Linux platforms")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
content := []byte("test content")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
err = platformSendfile(nil, file, 0, int64(len(content)))
|
||||
if err != syscall.ENOTSUP {
|
||||
t.Errorf("expected ENOTSUP on non-Linux, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSocketFd_NilConn 测试 nil 连接的情况
|
||||
func TestGetSocketFd_NilConn(t *testing.T) {
|
||||
_, err := getSocketFd(nil)
|
||||
@ -176,7 +120,6 @@ func TestGetSocketFd_NilConn(t *testing.T) {
|
||||
|
||||
// TestGetSocketFd_UnsupportedType 测试不支持的连接类型
|
||||
func TestGetSocketFd_UnsupportedType(t *testing.T) {
|
||||
// 创建一个不支持的连接类型
|
||||
conn := &mockConn{}
|
||||
_, err := getSocketFd(conn)
|
||||
if err != syscall.ENOTSUP {
|
||||
@ -201,7 +144,6 @@ func TestSendFile_SmallFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "small.txt")
|
||||
|
||||
// 创建小文件 (< 8KB)
|
||||
content := []byte("small file content")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
@ -221,7 +163,6 @@ func TestSendFile_SmallFile(t *testing.T) {
|
||||
t.Errorf("SendFile failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证响应体
|
||||
if !bytes.Equal(ctx.Response.Body(), content) {
|
||||
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
|
||||
}
|
||||
@ -232,7 +173,7 @@ func TestSendFile_WithOffset(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
content := []byte("0123456789ABCDEF") // 16 bytes
|
||||
content := []byte("0123456789ABCDEF")
|
||||
if err := os.WriteFile(tmpFile, content, 0644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
@ -246,13 +187,12 @@ func TestSendFile_WithOffset(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// 从偏移量 5 开始,读取 5 字节
|
||||
err = SendFile(ctx, file, 5, 5)
|
||||
if err != nil {
|
||||
t.Errorf("SendFile failed: %v", err)
|
||||
}
|
||||
|
||||
expected := content[5:10] // "56789"
|
||||
expected := content[5:10]
|
||||
if !bytes.Equal(ctx.Response.Body(), expected) {
|
||||
t.Errorf("Expected body %s, got %s", expected, ctx.Response.Body())
|
||||
}
|
||||
@ -291,28 +231,10 @@ func TestGetNetConn(_ *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// fasthttp 会创建内部连接,所以这里测试能正常获取
|
||||
conn := getNetConn(ctx)
|
||||
// 主要验证不会崩溃
|
||||
_ = conn
|
||||
}
|
||||
|
||||
// TestSendFile_NilFile 测试空文件指针
|
||||
func TestSendFile_NilFile(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// 传入 nil 文件应该 panic 或返回错误
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
// 没有 panic,检查是否有错误返回
|
||||
t.Log("expected panic did not occur for nil file")
|
||||
}
|
||||
}()
|
||||
|
||||
// 这个测试主要确保不会静默失败
|
||||
}
|
||||
|
||||
// TestCopyFile_Error 测试 copyFile 错误情况
|
||||
func TestCopyFile_Error(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
@ -332,7 +254,6 @@ func TestCopyFile_Error(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// 测试偏移量超出文件大小
|
||||
err = copyFile(ctx, file, 1000, 10)
|
||||
if err == nil {
|
||||
t.Error("Expected error for offset beyond file size")
|
||||
@ -341,10 +262,6 @@ func TestCopyFile_Error(t *testing.T) {
|
||||
|
||||
// TestLinuxSendfile_NilConn 测试 linuxSendfile 空连接
|
||||
func TestLinuxSendfile_NilConn(t *testing.T) {
|
||||
if runtime.GOOS != platformLinux {
|
||||
t.Skip("This test is for Linux only")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
content := []byte("test")
|
||||
@ -360,4 +277,4 @@ func TestLinuxSendfile_NilConn(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user