test(handler,middleware,server): 新增 try_files、错误页面、pprof 单元测试
- static_test.go: 新增 try_files 配置解析、占位符解析、SPA 场景测试 - errorpage_test.go: 新增错误页面管理器完整测试覆盖 - errorintercept_test.go: 新增错误拦截中间件功能测试 - pprof_test.go: 新增 pprof 性能分析端点测试 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
4d108267c3
commit
b6f8894d78
667
internal/handler/errorpage_test.go
Normal file
667
internal/handler/errorpage_test.go
Normal file
@ -0,0 +1,667 @@
|
||||
// Package handler 提供错误页面管理器功能的测试。
|
||||
//
|
||||
// 该文件测试错误页面管理模块的各项功能,包括:
|
||||
// - 管理器构造函数
|
||||
// - 空配置处理
|
||||
// - 部分加载失败处理
|
||||
// - 全部加载失败处理
|
||||
// - 错误页面查找
|
||||
// - 默认页面 fallback
|
||||
// - 状态码检查
|
||||
// - 响应码覆盖
|
||||
// - 配置状态检查
|
||||
//
|
||||
// 作者:xfy
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestNewErrorPageManager_EmptyConfig 测试空配置情况
|
||||
func TestNewErrorPageManager_EmptyConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ErrorPageConfig
|
||||
want *ErrorPageManager
|
||||
}{
|
||||
{
|
||||
name: "完全空配置",
|
||||
cfg: config.ErrorPageConfig{},
|
||||
},
|
||||
{
|
||||
name: "空的 pages 和 default",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{},
|
||||
Default: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "设置了 responseCode 但无页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{},
|
||||
Default: "",
|
||||
ResponseCode: 200,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := NewErrorPageManager(&tt.cfg)
|
||||
if err != nil {
|
||||
t.Errorf("NewErrorPageManager() 不应返回错误, got %v", err)
|
||||
}
|
||||
if manager == nil {
|
||||
t.Fatal("NewErrorPageManager() 返回 nil")
|
||||
}
|
||||
if manager.IsConfigured() {
|
||||
t.Error("IsConfigured() 应返回 false")
|
||||
}
|
||||
if manager.GetResponseCode() != tt.cfg.ResponseCode {
|
||||
t.Errorf("GetResponseCode() = %d, want %d", manager.GetResponseCode(), tt.cfg.ResponseCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewErrorPageManager_PartialLoadFailure 测试部分加载失败
|
||||
func TestNewErrorPageManager_PartialLoadFailure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建有效的错误页面文件
|
||||
validPage := filepath.Join(tmpDir, "404.html")
|
||||
if err := os.WriteFile(validPage, []byte("404 page"), 0644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 不存在的文件路径
|
||||
nonExistent := filepath.Join(tmpDir, "nonexistent", "500.html")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ErrorPageConfig
|
||||
wantConfigured bool
|
||||
wantPages map[int]bool
|
||||
wantPartialErr bool
|
||||
}{
|
||||
{
|
||||
name: "一个成功一个失败",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: validPage,
|
||||
500: nonExistent,
|
||||
},
|
||||
},
|
||||
wantConfigured: true,
|
||||
wantPages: map[int]bool{404: true},
|
||||
wantPartialErr: true,
|
||||
},
|
||||
{
|
||||
name: "特定页面成功默认失败",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: validPage,
|
||||
},
|
||||
Default: filepath.Join(tmpDir, "default.html"),
|
||||
},
|
||||
wantConfigured: true,
|
||||
wantPages: map[int]bool{404: true},
|
||||
wantPartialErr: true,
|
||||
},
|
||||
{
|
||||
name: "默认成功特定页面失败",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: nonExistent,
|
||||
},
|
||||
Default: validPage,
|
||||
},
|
||||
wantConfigured: true,
|
||||
wantPages: map[int]bool{},
|
||||
wantPartialErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := NewErrorPageManager(&tt.cfg)
|
||||
|
||||
// 检查是否为部分错误
|
||||
if tt.wantPartialErr {
|
||||
if err == nil {
|
||||
t.Error("期望返回部分加载错误,但 got nil")
|
||||
return
|
||||
}
|
||||
if _, ok := err.(*PartialLoadError); !ok {
|
||||
t.Errorf("期望返回 *PartialLoadError,但 got %T", err)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("NewErrorPageManager() 不应返回错误, got %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if manager == nil {
|
||||
t.Fatal("NewErrorPageManager() 返回 nil")
|
||||
}
|
||||
|
||||
if got := manager.IsConfigured(); got != tt.wantConfigured {
|
||||
t.Errorf("IsConfigured() = %v, want %v", got, tt.wantConfigured)
|
||||
}
|
||||
|
||||
// 检查特定页面是否存在
|
||||
for code, shouldExist := range tt.wantPages {
|
||||
if got := manager.HasPage(code); got != shouldExist {
|
||||
t.Errorf("HasPage(%d) = %v, want %v", code, got, shouldExist)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewErrorPageManager_AllLoadFailure 测试全部加载失败
|
||||
func TestNewErrorPageManager_AllLoadFailure(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 不存在的文件路径
|
||||
nonExistent1 := filepath.Join(tmpDir, "nonexistent1.html")
|
||||
nonExistent2 := filepath.Join(tmpDir, "nonexistent2.html")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ErrorPageConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "单个页面加载失败",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: nonExistent1,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "多个页面全部加载失败",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: nonExistent1,
|
||||
500: nonExistent2,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "默认页面加载失败",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Default: nonExistent1,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "页面和默认都加载失败",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: nonExistent1,
|
||||
},
|
||||
Default: nonExistent2,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := NewErrorPageManager(&tt.cfg)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("期望返回错误,但 got nil")
|
||||
}
|
||||
// 全部失败时不应返回 PartialLoadError
|
||||
if _, ok := err.(*PartialLoadError); ok {
|
||||
t.Error("全部失败时不应返回 *PartialLoadError")
|
||||
}
|
||||
if manager != nil {
|
||||
t.Error("全部失败时 manager 应为 nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("NewErrorPageManager() 不应返回错误, got %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPartialLoadError_Error 测试错误消息格式
|
||||
func TestPartialLoadError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errors map[int]error
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "单个错误",
|
||||
errors: map[int]error{404: os.ErrNotExist},
|
||||
want: "部分错误页面加载失败: 1 个错误",
|
||||
},
|
||||
{
|
||||
name: "多个错误",
|
||||
errors: map[int]error{404: os.ErrNotExist, 500: os.ErrPermission},
|
||||
want: "部分错误页面加载失败: 2 个错误",
|
||||
},
|
||||
{
|
||||
name: "包含默认页面错误",
|
||||
errors: map[int]error{0: os.ErrNotExist, 404: os.ErrNotExist},
|
||||
want: "部分错误页面加载失败: 2 个错误",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := &PartialLoadError{Errors: tt.errors}
|
||||
got := err.Error()
|
||||
if !strings.Contains(got, tt.want) {
|
||||
t.Errorf("Error() = %q, want contain %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorPageManager_GetPage 测试获取错误页面
|
||||
func TestErrorPageManager_GetPage(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建测试页面文件
|
||||
page404 := filepath.Join(tmpDir, "404.html")
|
||||
page500 := filepath.Join(tmpDir, "500.html")
|
||||
pageDefault := filepath.Join(tmpDir, "default.html")
|
||||
|
||||
if err := os.WriteFile(page404, []byte("404 page content"), 0644); err != nil {
|
||||
t.Fatalf("创建 404 页面失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(page500, []byte("500 page content"), 0644); err != nil {
|
||||
t.Fatalf("创建 500 页面失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(pageDefault, []byte("default page content"), 0644); err != nil {
|
||||
t.Fatalf("创建默认页面失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ErrorPageConfig
|
||||
requestCode int
|
||||
wantContent string
|
||||
wantFound bool
|
||||
wantResponseCode int
|
||||
}{
|
||||
{
|
||||
name: "找到特定状态码页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
500: page500,
|
||||
},
|
||||
},
|
||||
requestCode: 404,
|
||||
wantContent: "404 page content",
|
||||
wantFound: true,
|
||||
wantResponseCode: 404,
|
||||
},
|
||||
{
|
||||
name: "找到另一个状态码页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
500: page500,
|
||||
},
|
||||
},
|
||||
requestCode: 500,
|
||||
wantContent: "500 page content",
|
||||
wantFound: true,
|
||||
wantResponseCode: 500,
|
||||
},
|
||||
{
|
||||
name: "未找到特定页面,使用默认页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
},
|
||||
Default: pageDefault,
|
||||
},
|
||||
requestCode: 500,
|
||||
wantContent: "default page content",
|
||||
wantFound: true,
|
||||
wantResponseCode: 500,
|
||||
},
|
||||
{
|
||||
name: "未找到页面且无默认页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
},
|
||||
},
|
||||
requestCode: 500,
|
||||
wantContent: "",
|
||||
wantFound: false,
|
||||
wantResponseCode: 500,
|
||||
},
|
||||
{
|
||||
name: "有默认页面但请求特定页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Default: pageDefault,
|
||||
},
|
||||
requestCode: 503,
|
||||
wantContent: "default page content",
|
||||
wantFound: true,
|
||||
wantResponseCode: 503,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := NewErrorPageManager(&tt.cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewErrorPageManager() 失败: %v", err)
|
||||
}
|
||||
|
||||
content, found, responseCode := manager.GetPage(tt.requestCode)
|
||||
|
||||
if found != tt.wantFound {
|
||||
t.Errorf("GetPage() found = %v, want %v", found, tt.wantFound)
|
||||
}
|
||||
if string(content) != tt.wantContent {
|
||||
t.Errorf("GetPage() content = %q, want %q", string(content), tt.wantContent)
|
||||
}
|
||||
if responseCode != tt.wantResponseCode {
|
||||
t.Errorf("GetPage() responseCode = %d, want %d", responseCode, tt.wantResponseCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorPageManager_GetPage_WithResponseCodeOverride 测试响应码覆盖
|
||||
func TestErrorPageManager_GetPage_WithResponseCodeOverride(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
page404 := filepath.Join(tmpDir, "404.html")
|
||||
if err := os.WriteFile(page404, []byte("404 page"), 0644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
cfg := config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
},
|
||||
ResponseCode: 200, // 覆盖响应码
|
||||
}
|
||||
|
||||
manager, err := NewErrorPageManager(&cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewErrorPageManager() 失败: %v", err)
|
||||
}
|
||||
|
||||
_, _, responseCode := manager.GetPage(404)
|
||||
if responseCode != 200 {
|
||||
t.Errorf("GetPage() responseCode = %d, want 200", responseCode)
|
||||
}
|
||||
|
||||
// 测试默认页面也受覆盖影响
|
||||
cfg2 := config.ErrorPageConfig{
|
||||
Default: page404,
|
||||
ResponseCode: 418,
|
||||
}
|
||||
|
||||
manager2, err := NewErrorPageManager(&cfg2)
|
||||
if err != nil {
|
||||
t.Fatalf("NewErrorPageManager() 失败: %v", err)
|
||||
}
|
||||
|
||||
_, _, responseCode2 := manager2.GetPage(500)
|
||||
if responseCode2 != 418 {
|
||||
t.Errorf("GetPage() responseCode = %d, want 418", responseCode2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorPageManager_HasPage 测试页面存在检查
|
||||
func TestErrorPageManager_HasPage(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
page404 := filepath.Join(tmpDir, "404.html")
|
||||
pageDefault := filepath.Join(tmpDir, "default.html")
|
||||
|
||||
if err := os.WriteFile(page404, []byte("404"), 0644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(pageDefault, []byte("default"), 0644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ErrorPageConfig
|
||||
code int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "配置的特定页面存在",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{404: page404},
|
||||
},
|
||||
code: 404,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "未配置的页面但有默认页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{404: page404},
|
||||
Default: pageDefault,
|
||||
},
|
||||
code: 500,
|
||||
expected: true, // 因为有默认页面
|
||||
},
|
||||
{
|
||||
name: "只有默认页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Default: pageDefault,
|
||||
},
|
||||
code: 404,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "页面不存在且无默认页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{404: page404},
|
||||
},
|
||||
code: 500,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := NewErrorPageManager(&tt.cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewErrorPageManager() 失败: %v", err)
|
||||
}
|
||||
|
||||
if got := manager.HasPage(tt.code); got != tt.expected {
|
||||
t.Errorf("HasPage(%d) = %v, want %v", tt.code, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorPageManager_GetResponseCode 测试获取响应码覆盖值
|
||||
func TestErrorPageManager_GetResponseCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "无覆盖",
|
||||
code: 0,
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "覆盖为 200",
|
||||
code: 200,
|
||||
want: 200,
|
||||
},
|
||||
{
|
||||
name: "覆盖为 418",
|
||||
code: 418,
|
||||
want: 418,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := config.ErrorPageConfig{
|
||||
ResponseCode: tt.code,
|
||||
}
|
||||
manager, err := NewErrorPageManager(&cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewErrorPageManager() 失败: %v", err)
|
||||
}
|
||||
|
||||
if got := manager.GetResponseCode(); got != tt.want {
|
||||
t.Errorf("GetResponseCode() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorPageManager_IsConfigured 测试配置状态检查
|
||||
func TestErrorPageManager_IsConfigured(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
page404 := filepath.Join(tmpDir, "404.html")
|
||||
pageDefault := filepath.Join(tmpDir, "default.html")
|
||||
|
||||
if err := os.WriteFile(page404, []byte("404"), 0644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(pageDefault, []byte("default"), 0644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.ErrorPageConfig
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "空配置",
|
||||
cfg: config.ErrorPageConfig{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "配置特定页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{404: page404},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "配置默认页面",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Default: pageDefault,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "配置两者",
|
||||
cfg: config.ErrorPageConfig{
|
||||
Pages: map[int]string{404: page404},
|
||||
Default: pageDefault,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := NewErrorPageManager(&tt.cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewErrorPageManager() 失败: %v", err)
|
||||
}
|
||||
|
||||
if got := manager.IsConfigured(); got != tt.expected {
|
||||
t.Errorf("IsConfigured() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorPageManager_SuccessfulLoad 测试成功加载场景
|
||||
func TestErrorPageManager_SuccessfulLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建多个测试页面
|
||||
pages := map[int]string{
|
||||
404: filepath.Join(tmpDir, "404.html"),
|
||||
500: filepath.Join(tmpDir, "500.html"),
|
||||
403: filepath.Join(tmpDir, "403.html"),
|
||||
}
|
||||
defaultPage := filepath.Join(tmpDir, "default.html")
|
||||
|
||||
for code, path := range pages {
|
||||
content := []byte(fmt.Sprintf("Error %d page", code))
|
||||
if err := os.WriteFile(path, content, 0644); err != nil {
|
||||
t.Fatalf("创建页面 %d 失败: %v", code, err)
|
||||
}
|
||||
}
|
||||
if err := os.WriteFile(defaultPage, []byte("Default error page"), 0644); err != nil {
|
||||
t.Fatalf("创建默认页面失败: %v", err)
|
||||
}
|
||||
|
||||
cfg := config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: pages[404],
|
||||
500: pages[500],
|
||||
403: pages[403],
|
||||
},
|
||||
Default: defaultPage,
|
||||
}
|
||||
|
||||
manager, err := NewErrorPageManager(&cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewErrorPageManager() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证所有页面都能正常访问
|
||||
for code := range pages {
|
||||
content, found, responseCode := manager.GetPage(code)
|
||||
if !found {
|
||||
t.Errorf("GetPage(%d) 应返回 found=true", code)
|
||||
}
|
||||
if responseCode != code {
|
||||
t.Errorf("GetPage(%d) responseCode = %d, want %d", code, responseCode, code)
|
||||
}
|
||||
wantContent := fmt.Sprintf("Error %d page", code)
|
||||
if string(content) != wantContent {
|
||||
t.Errorf("GetPage(%d) content = %q, want %q", code, string(content), wantContent)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证未配置的状态码返回默认页面
|
||||
content, found, responseCode := manager.GetPage(503)
|
||||
if !found {
|
||||
t.Error("GetPage(503) 应返回 found=true (默认页面)")
|
||||
}
|
||||
if responseCode != 503 {
|
||||
t.Errorf("GetPage(503) responseCode = %d, want 503", responseCode)
|
||||
}
|
||||
if string(content) != "Default error page" {
|
||||
t.Errorf("GetPage(503) content = %q, want default page", string(content))
|
||||
}
|
||||
}
|
||||
@ -581,3 +581,561 @@ func TestStaticHandler_Handle_Symlink(t *testing.T) {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_SetTryFiles 测试 SetTryFiles 配置设置
|
||||
func TestStaticHandler_SetTryFiles(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tryFiles []string
|
||||
tryFilesPass bool
|
||||
wantTryFiles []string
|
||||
wantPass bool
|
||||
}{
|
||||
{
|
||||
name: "基本配置",
|
||||
tryFiles: []string{"$uri", "$uri/", "/index.html"},
|
||||
tryFilesPass: false,
|
||||
wantTryFiles: []string{"$uri", "$uri/", "/index.html"},
|
||||
wantPass: false,
|
||||
},
|
||||
{
|
||||
name: "启用 tryFilesPass",
|
||||
tryFiles: []string{"$uri", "/fallback.html"},
|
||||
tryFilesPass: true,
|
||||
wantTryFiles: []string{"$uri", "/fallback.html"},
|
||||
wantPass: true,
|
||||
},
|
||||
{
|
||||
name: "空配置",
|
||||
tryFiles: []string{},
|
||||
tryFilesPass: false,
|
||||
wantTryFiles: []string{},
|
||||
wantPass: false,
|
||||
},
|
||||
{
|
||||
name: "单一项配置",
|
||||
tryFiles: []string{"/app.html"},
|
||||
tryFilesPass: false,
|
||||
wantTryFiles: []string{"/app.html"},
|
||||
wantPass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := NewStaticHandler("/var/www", "/", []string{"index.html"}, false)
|
||||
router := NewRouter()
|
||||
|
||||
handler.SetTryFiles(tt.tryFiles, tt.tryFilesPass, router)
|
||||
|
||||
// 验证配置
|
||||
if len(handler.tryFiles) != len(tt.wantTryFiles) {
|
||||
t.Errorf("tryFiles length = %d, want %d", len(handler.tryFiles), len(tt.wantTryFiles))
|
||||
}
|
||||
for i, v := range tt.wantTryFiles {
|
||||
if handler.tryFiles[i] != v {
|
||||
t.Errorf("tryFiles[%d] = %q, want %q", i, handler.tryFiles[i], v)
|
||||
}
|
||||
}
|
||||
if handler.tryFilesPass != tt.wantPass {
|
||||
t.Errorf("tryFilesPass = %v, want %v", handler.tryFilesPass, tt.wantPass)
|
||||
}
|
||||
if handler.router != router {
|
||||
t.Error("router 未正确设置")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_resolveTryFilePath 测试 resolveTryFilePath 占位符解析
|
||||
func TestStaticHandler_resolveTryFilePath(t *testing.T) {
|
||||
handler := NewStaticHandler("/var/www", "/", []string{"index.html"}, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tryFile string
|
||||
relPath string
|
||||
wantResult string
|
||||
}{
|
||||
{
|
||||
name: "$uri 占位符",
|
||||
tryFile: "$uri",
|
||||
relPath: "/api/user",
|
||||
wantResult: "/api/user",
|
||||
},
|
||||
{
|
||||
name: "$uri/ 占位符",
|
||||
tryFile: "$uri/",
|
||||
relPath: "/api/user",
|
||||
wantResult: "/api/user/",
|
||||
},
|
||||
{
|
||||
name: "绝对路径",
|
||||
tryFile: "/index.html",
|
||||
relPath: "/api/user",
|
||||
wantResult: "index.html",
|
||||
},
|
||||
{
|
||||
name: "普通文件名",
|
||||
tryFile: "fallback.html",
|
||||
relPath: "/api/user",
|
||||
wantResult: "fallback.html",
|
||||
},
|
||||
{
|
||||
name: "根路径 $uri",
|
||||
tryFile: "$uri",
|
||||
relPath: "/",
|
||||
wantResult: "/",
|
||||
},
|
||||
{
|
||||
name: "嵌套路径 $uri",
|
||||
tryFile: "$uri",
|
||||
relPath: "/assets/js/app.js",
|
||||
wantResult: "/assets/js/app.js",
|
||||
},
|
||||
{
|
||||
name: "带查询风格路径",
|
||||
tryFile: "$uri",
|
||||
relPath: "/path/to/file.txt",
|
||||
wantResult: "/path/to/file.txt",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := handler.resolveTryFilePath(tt.tryFile, tt.relPath)
|
||||
if got != tt.wantResult {
|
||||
t.Errorf("resolveTryFilePath(%q, %q) = %q, want %q", tt.tryFile, tt.relPath, got, tt.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_handleTryFiles 测试 handleTryFiles 功能
|
||||
func TestStaticHandler_handleTryFiles(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T, root string)
|
||||
tryFiles []string
|
||||
path string
|
||||
wantStatus int
|
||||
wantContent string
|
||||
skipContent bool
|
||||
}{
|
||||
{
|
||||
name: "$uri 找到文件",
|
||||
setup: func(t *testing.T, root string) {
|
||||
if err := os.WriteFile(filepath.Join(root, "app.js"), []byte("app content"), 0644); err != nil {
|
||||
t.Fatalf("创建文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri", "$uri/", "/index.html"},
|
||||
path: "/app.js",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "app content",
|
||||
},
|
||||
{
|
||||
name: "$uri 未找到回退到 $uri/",
|
||||
setup: func(t *testing.T, root string) {
|
||||
dir := filepath.Join(root, "assets")
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
t.Fatalf("创建目录失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("assets index"), 0644); err != nil {
|
||||
t.Fatalf("创建索引文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri", "$uri/", "/index.html"},
|
||||
path: "/assets",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "assets index",
|
||||
},
|
||||
{
|
||||
name: "回退到 fallback 文件",
|
||||
setup: func(t *testing.T, root string) {
|
||||
if err := os.WriteFile(filepath.Join(root, "index.html"), []byte("spa fallback"), 0644); err != nil {
|
||||
t.Fatalf("创建 fallback 文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri", "$uri/", "/index.html"},
|
||||
path: "/nonexistent",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "spa fallback",
|
||||
},
|
||||
{
|
||||
name: "所有 try_files 都未找到",
|
||||
setup: func(t *testing.T, root string) {
|
||||
// 不创建任何文件
|
||||
},
|
||||
tryFiles: []string{"$uri", "$uri/", "/index.html"},
|
||||
path: "/nonexistent",
|
||||
wantStatus: fasthttp.StatusNotFound,
|
||||
skipContent: true,
|
||||
},
|
||||
{
|
||||
name: "嵌套目录回退",
|
||||
setup: func(t *testing.T, root string) {
|
||||
if err := os.WriteFile(filepath.Join(root, "app.html"), []byte("app shell"), 0644); err != nil {
|
||||
t.Fatalf("创建 fallback 文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri", "/app.html"},
|
||||
path: "/user/profile",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "app shell",
|
||||
},
|
||||
{
|
||||
name: "路径前缀剥离",
|
||||
setup: func(t *testing.T, root string) {
|
||||
apiDir := filepath.Join(root, "api")
|
||||
if err := os.MkdirAll(apiDir, 0755); err != nil {
|
||||
t.Fatalf("创建目录失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(apiDir, "data.json"), []byte("json data"), 0644); err != nil {
|
||||
t.Fatalf("创建文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri"},
|
||||
path: "/static/api/data.json",
|
||||
wantStatus: fasthttp.StatusNotFound, // 路径前缀剥离后找不到
|
||||
skipContent: true,
|
||||
},
|
||||
{
|
||||
name: "空 try_files 数组",
|
||||
setup: func(t *testing.T, root string) {
|
||||
if err := os.WriteFile(filepath.Join(root, "test.txt"), []byte("test"), 0644); err != nil {
|
||||
t.Fatalf("创建文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{},
|
||||
path: "/test.txt",
|
||||
wantStatus: fasthttp.StatusOK, // 空 try_files 走标准处理流程
|
||||
wantContent: "test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tt.setup(t, tmpDir)
|
||||
|
||||
handler := NewStaticHandler(tmpDir, "/", []string{"index.html"}, false)
|
||||
handler.SetTryFiles(tt.tryFiles, false, nil)
|
||||
|
||||
ctx := newTestContext(t, tt.path)
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != tt.wantStatus {
|
||||
t.Errorf("状态码 = %d, want %d", got, tt.wantStatus)
|
||||
}
|
||||
|
||||
if !tt.skipContent && tt.wantContent != "" {
|
||||
got := string(ctx.Response.Body())
|
||||
if got != tt.wantContent {
|
||||
t.Errorf("内容 = %q, want %q", got, tt.wantContent)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_handleInternalRedirect 测试内部重定向功能
|
||||
func TestStaticHandler_handleInternalRedirect(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T, root string)
|
||||
tryFiles []string
|
||||
tryFilesPass bool
|
||||
path string
|
||||
wantStatus int
|
||||
wantContent string
|
||||
skipContent bool
|
||||
}{
|
||||
{
|
||||
name: "tryFilesPass false 直接服务文件",
|
||||
setup: func(t *testing.T, root string) {
|
||||
if err := os.WriteFile(filepath.Join(root, "index.html"), []byte("index content"), 0644); err != nil {
|
||||
t.Fatalf("创建文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri", "/index.html"},
|
||||
tryFilesPass: false,
|
||||
path: "/nonexistent",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "index content",
|
||||
},
|
||||
{
|
||||
name: "tryFilesPass true 触发重定向",
|
||||
setup: func(t *testing.T, root string) {
|
||||
if err := os.WriteFile(filepath.Join(root, "fallback.txt"), []byte("fallback content"), 0644); err != nil {
|
||||
t.Fatalf("创建文件失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri", "/fallback.txt"},
|
||||
tryFilesPass: true,
|
||||
path: "/nonexistent",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "fallback content",
|
||||
},
|
||||
{
|
||||
name: "内部重定向目标不存在",
|
||||
setup: func(t *testing.T, root string) {
|
||||
// 不创建 fallback 文件
|
||||
},
|
||||
tryFiles: []string{"$uri", "/fallback.html"},
|
||||
tryFilesPass: false,
|
||||
path: "/nonexistent",
|
||||
wantStatus: fasthttp.StatusNotFound,
|
||||
skipContent: true,
|
||||
},
|
||||
{
|
||||
name: "内部重定向目标是目录",
|
||||
setup: func(t *testing.T, root string) {
|
||||
dir := filepath.Join(root, "fallback")
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
t.Fatalf("创建目录失败: %v", err)
|
||||
}
|
||||
// 在 fallback 目录中创建一个 index.html
|
||||
if err := os.WriteFile(filepath.Join(dir, "index.html"), []byte("fallback index"), 0644); err != nil {
|
||||
t.Fatalf("创建 index.html 失败: %v", err)
|
||||
}
|
||||
},
|
||||
tryFiles: []string{"$uri", "$uri/", "/fallback"},
|
||||
tryFilesPass: false,
|
||||
path: "/nonexistent",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "fallback index",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tt.setup(t, tmpDir)
|
||||
|
||||
handler := NewStaticHandler(tmpDir, "/", []string{"index.html"}, false)
|
||||
router := NewRouter()
|
||||
handler.SetTryFiles(tt.tryFiles, tt.tryFilesPass, router)
|
||||
|
||||
// 注册路由处理器用于测试 tryFilesPass 重定向
|
||||
if tt.tryFilesPass {
|
||||
router.GET("/{filepath:*}", func(ctx *fasthttp.RequestCtx) {
|
||||
// 通配符路由,可以匹配任何路径
|
||||
path := string(ctx.Path())
|
||||
// 从 root 读取文件
|
||||
filePath := filepath.Join(tmpDir, path[1:]) // 去掉开头的 /
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
ctx.SetStatusCode(fasthttp.StatusNotFound)
|
||||
ctx.SetBodyString("Not Found")
|
||||
return
|
||||
}
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetBody(data)
|
||||
})
|
||||
}
|
||||
|
||||
ctx := newTestContext(t, tt.path)
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != tt.wantStatus {
|
||||
t.Errorf("状态码 = %d, want %d", got, tt.wantStatus)
|
||||
}
|
||||
|
||||
if !tt.skipContent && tt.wantContent != "" {
|
||||
got := string(ctx.Response.Body())
|
||||
if got != tt.wantContent {
|
||||
t.Errorf("内容 = %q, want %q", got, tt.wantContent)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_TryFilesSPA 测试 SPA 场景下的 try_files
|
||||
func TestStaticHandler_TryFilesSPA(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建 SPA 文件结构
|
||||
// index.html - 主应用入口
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "index.html"), []byte("<!DOCTYPE html><html><body>SPA App</body></html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 index.html 失败: %v", err)
|
||||
}
|
||||
|
||||
// 静态资源文件
|
||||
assetsDir := filepath.Join(tmpDir, "assets")
|
||||
if err := os.MkdirAll(assetsDir, 0755); err != nil {
|
||||
t.Fatalf("创建 assets 目录失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(assetsDir, "app.js"), []byte("console.log('app')"), 0644); err != nil {
|
||||
t.Fatalf("创建 app.js 失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(assetsDir, "style.css"), []byte("body { margin: 0 }"), 0644); err != nil {
|
||||
t.Fatalf("创建 style.css 失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandler(tmpDir, "/", []string{"index.html"}, false)
|
||||
handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
wantContent string
|
||||
}{
|
||||
{
|
||||
name: "访问存在的静态资源",
|
||||
path: "/assets/app.js",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "console.log('app')",
|
||||
},
|
||||
{
|
||||
name: "访问存在的 CSS 文件",
|
||||
path: "/assets/style.css",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "body { margin: 0 }",
|
||||
},
|
||||
{
|
||||
name: "访问前端路由回退到 index.html",
|
||||
path: "/dashboard",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "<!DOCTYPE html><html><body>SPA App</body></html>",
|
||||
},
|
||||
{
|
||||
name: "访问嵌套前端路由",
|
||||
path: "/user/profile/settings",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "<!DOCTYPE html><html><body>SPA App</body></html>",
|
||||
},
|
||||
{
|
||||
name: "访问根路径",
|
||||
path: "/",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "<!DOCTYPE html><html><body>SPA App</body></html>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := newTestContext(t, tt.path)
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != tt.wantStatus {
|
||||
t.Errorf("状态码 = %d, want %d", got, tt.wantStatus)
|
||||
}
|
||||
|
||||
got := string(ctx.Response.Body())
|
||||
if got != tt.wantContent {
|
||||
t.Errorf("内容 = %q, want %q", got, tt.wantContent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_TryFilesWithPathPrefix 测试带路径前缀的 try_files
|
||||
func TestStaticHandler_TryFilesWithPathPrefix(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建 API 模拟文件
|
||||
apiDir := filepath.Join(tmpDir, "api")
|
||||
if err := os.MkdirAll(apiDir, 0755); err != nil {
|
||||
t.Fatalf("创建 api 目录失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(apiDir, "users.json"), []byte("[]"), 0644); err != nil {
|
||||
t.Fatalf("创建 users.json 失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建静态文件
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "index.html"), []byte("static index"), 0644); err != nil {
|
||||
t.Fatalf("创建 index.html 失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandler(tmpDir, "/static", []string{"index.html"}, false)
|
||||
handler.SetTryFiles([]string{"$uri", "$uri/", "/index.html"}, false, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
wantContent string
|
||||
skipContent bool
|
||||
}{
|
||||
{
|
||||
name: "带前缀访问文件",
|
||||
path: "/static/api/users.json",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "[]",
|
||||
},
|
||||
{
|
||||
name: "带前缀访问目录",
|
||||
path: "/static/api/",
|
||||
wantStatus: fasthttp.StatusOK, // 目录无索引文件,但会回退到 /index.html
|
||||
wantContent: "static index",
|
||||
},
|
||||
{
|
||||
name: "前缀剥离后回退",
|
||||
path: "/static/unknown",
|
||||
wantStatus: fasthttp.StatusOK,
|
||||
wantContent: "static index",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := newTestContext(t, tt.path)
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != tt.wantStatus {
|
||||
t.Errorf("状态码 = %d, want %d", got, tt.wantStatus)
|
||||
}
|
||||
|
||||
if !tt.skipContent {
|
||||
got := string(ctx.Response.Body())
|
||||
if got != tt.wantContent {
|
||||
t.Errorf("内容 = %q, want %q", got, tt.wantContent)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_TryFilesEdgeCases 测试 try_files 边界情况
|
||||
func TestStaticHandler_TryFilesEdgeCases(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建测试文件
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "file with spaces.txt"), []byte("spaces"), 0644); err != nil {
|
||||
t.Fatalf("创建带空格文件失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandler(tmpDir, "/", []string{"index.html"}, false)
|
||||
handler.SetTryFiles([]string{"$uri", "/index.html"}, false, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "路径遍历攻击被阻止 - fasthttp 规范化",
|
||||
path: "/../secret",
|
||||
wantStatus: fasthttp.StatusNotFound, // fasthttp 规范化为 /secret,文件不存在返回 404
|
||||
},
|
||||
{
|
||||
name: "双点号在路径中被阻止",
|
||||
path: "/file..name",
|
||||
wantStatus: fasthttp.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := newTestContext(t, tt.path)
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != tt.wantStatus {
|
||||
t.Errorf("状态码 = %d, want %d", got, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
528
internal/middleware/errorintercept/errorintercept_test.go
Normal file
528
internal/middleware/errorintercept/errorintercept_test.go
Normal file
@ -0,0 +1,528 @@
|
||||
// Package errorintercept 提供 HTTP 错误拦截中间件的测试。
|
||||
//
|
||||
// 该文件测试错误拦截中间件的各项功能,包括:
|
||||
// - 中间件实例创建
|
||||
// - 中间件名称获取
|
||||
// - 错误响应拦截
|
||||
// - 4xx/5xx 状态码检测
|
||||
// - 错误页面替换
|
||||
// - 边界值测试
|
||||
//
|
||||
// 作者:xfy
|
||||
package errorintercept
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/handler"
|
||||
)
|
||||
|
||||
// TestNew 测试创建错误拦截中间件。
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
manager *handler.ErrorPageManager
|
||||
}{
|
||||
{
|
||||
name: "创建带有 manager 的实例",
|
||||
manager: &handler.ErrorPageManager{},
|
||||
},
|
||||
{
|
||||
name: "创建带有 nil manager 的实例",
|
||||
manager: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ei := New(tt.manager)
|
||||
|
||||
if ei == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
|
||||
if ei.manager != tt.manager {
|
||||
t.Errorf("expected manager %v, got %v", tt.manager, ei.manager)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Name 测试获取中间件名称。
|
||||
func TestErrorIntercept_Name(t *testing.T) {
|
||||
ei := New(nil)
|
||||
|
||||
name := ei.Name()
|
||||
|
||||
if name != "ErrorIntercept" {
|
||||
t.Errorf("expected name 'ErrorIntercept', got '%s'", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_NilManager 测试 nil manager 情况下的 Process。
|
||||
func TestErrorIntercept_Process_NilManager(t *testing.T) {
|
||||
ei := New(nil)
|
||||
|
||||
called := false
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetBodyString("OK")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
if wrapped == nil {
|
||||
t.Fatal("Process() returned nil")
|
||||
}
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
wrapped(&ctx)
|
||||
|
||||
if !called {
|
||||
t.Error("next handler was not called")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_NotConfigured 测试未配置错误页面的情况。
|
||||
func TestErrorIntercept_Process_NotConfigured(t *testing.T) {
|
||||
// 创建一个空 manager(没有配置任何页面)
|
||||
manager, err := handler.NewErrorPageManager(&config.ErrorPageConfig{
|
||||
Pages: make(map[int]string),
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("跳过测试:无法创建 manager: %v", err)
|
||||
}
|
||||
ei := New(manager)
|
||||
|
||||
called := false
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
ctx.SetStatusCode(200)
|
||||
ctx.SetBodyString("OK")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
// 未配置时应该返回 next 本身(通过行为验证)
|
||||
// 当 manager 未配置时,Process 直接返回 next,不会包装
|
||||
// 验证方式是确认 wrapped 调用后会直接执行 next 且不做额外操作
|
||||
if wrapped == nil {
|
||||
t.Fatal("Process() returned nil")
|
||||
}
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
wrapped(&ctx)
|
||||
|
||||
if !called {
|
||||
t.Error("next handler was not called")
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_SuccessStatus 测试成功状态码不拦截。
|
||||
func TestErrorIntercept_Process_SuccessStatus(t *testing.T) {
|
||||
manager := createConfiguredManager(t)
|
||||
if manager == nil {
|
||||
t.Skip("跳过测试:无法创建配置好的 manager")
|
||||
}
|
||||
ei := New(manager)
|
||||
|
||||
tests := []int{200, 201, 299, 301, 304, 399}
|
||||
|
||||
for _, status := range tests {
|
||||
t.Run("状态码_"+string(rune('0'+status/100)), func(t *testing.T) {
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(status)
|
||||
ctx.SetBodyString("success")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
wrapped(&ctx)
|
||||
|
||||
// 验证状态码未被修改
|
||||
if ctx.Response.StatusCode() != status {
|
||||
t.Errorf("expected status %d, got %d", status, ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// 验证 body 未被修改
|
||||
if string(ctx.Response.Body()) != "success" {
|
||||
t.Errorf("expected body 'success', got '%s'", string(ctx.Response.Body()))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_ErrorStatus_Intercepted 测试错误状态码被拦截并替换。
|
||||
func TestErrorIntercept_Process_ErrorStatus_Intercepted(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 创建测试用的错误页面文件
|
||||
page404 := filepath.Join(tempDir, "404.html")
|
||||
page500 := filepath.Join(tempDir, "500.html")
|
||||
pageDefault := filepath.Join(tempDir, "default.html")
|
||||
|
||||
if err := os.WriteFile(page404, []byte("<html>404 Not Found</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 404.html 失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(page500, []byte("<html>500 Error</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 500.html 失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(pageDefault, []byte("<html>Default Error</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 default.html 失败: %v", err)
|
||||
}
|
||||
|
||||
manager, err := handler.NewErrorPageManager(&config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
500: page500,
|
||||
},
|
||||
Default: pageDefault,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("创建 ErrorPageManager 失败: %v", err)
|
||||
}
|
||||
|
||||
ei := New(manager)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
expectedBody string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "拦截 404 错误",
|
||||
statusCode: 404,
|
||||
expectedBody: "<html>404 Not Found</html>",
|
||||
expectedStatus: 404,
|
||||
},
|
||||
{
|
||||
name: "拦截 500 错误",
|
||||
statusCode: 500,
|
||||
expectedBody: "<html>500 Error</html>",
|
||||
expectedStatus: 500,
|
||||
},
|
||||
{
|
||||
name: "拦截 403 错误(使用默认页面)",
|
||||
statusCode: 403,
|
||||
expectedBody: "<html>Default Error</html>",
|
||||
expectedStatus: 403,
|
||||
},
|
||||
{
|
||||
name: "拦截 502 错误(使用默认页面)",
|
||||
statusCode: 502,
|
||||
expectedBody: "<html>Default Error</html>",
|
||||
expectedStatus: 502,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(tt.statusCode)
|
||||
ctx.SetBodyString("original error response")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
wrapped(&ctx)
|
||||
|
||||
// 验证 body 被替换
|
||||
if string(ctx.Response.Body()) != tt.expectedBody {
|
||||
t.Errorf("expected body '%s', got '%s'", tt.expectedBody, string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// 验证 content-type
|
||||
contentType := string(ctx.Response.Header.ContentType())
|
||||
if contentType != "text/html; charset=utf-8" {
|
||||
t.Errorf("expected content-type 'text/html; charset=utf-8', got '%s'", contentType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_WithResponseCodeOverride 测试响应状态码覆盖。
|
||||
func TestErrorIntercept_Process_WithResponseCodeOverride(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
page404 := filepath.Join(tempDir, "404.html")
|
||||
if err := os.WriteFile(page404, []byte("<html>404 Not Found</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 404.html 失败: %v", err)
|
||||
}
|
||||
|
||||
manager, err := handler.NewErrorPageManager(&config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
},
|
||||
ResponseCode: 200, // 覆盖状态码为 200
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("创建 ErrorPageManager 失败: %v", err)
|
||||
}
|
||||
|
||||
ei := New(manager)
|
||||
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(404)
|
||||
ctx.SetBodyString("not found")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
wrapped(&ctx)
|
||||
|
||||
// 状态码应该被覆盖为 200
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("expected status 200 (overridden), got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// body 应该被替换
|
||||
if string(ctx.Response.Body()) != "<html>404 Not Found</html>" {
|
||||
t.Errorf("expected custom error page, got '%s'", string(ctx.Response.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_NoMatchingPage 测试没有匹配错误页面的情况。
|
||||
func TestErrorIntercept_Process_NoMatchingPage(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 只创建 404 页面,没有默认页面
|
||||
page404 := filepath.Join(tempDir, "404.html")
|
||||
if err := os.WriteFile(page404, []byte("<html>404 Not Found</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 404.html 失败: %v", err)
|
||||
}
|
||||
|
||||
manager, err := handler.NewErrorPageManager(&config.ErrorPageConfig{
|
||||
Pages: map[int]string{
|
||||
404: page404,
|
||||
},
|
||||
// 没有配置默认页面
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("创建 ErrorPageManager 失败: %v", err)
|
||||
}
|
||||
|
||||
ei := New(manager)
|
||||
|
||||
// 请求 500,但没有配置 500 页面和默认页面
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(500)
|
||||
ctx.SetBodyString("original 500 error")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
wrapped(&ctx)
|
||||
|
||||
// 没有匹配的错误页面,应该保持原样
|
||||
if ctx.Response.StatusCode() != 500 {
|
||||
t.Errorf("expected status 500, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// body 不应该被修改(因为没有找到匹配页面)
|
||||
if string(ctx.Response.Body()) != "original 500 error" {
|
||||
t.Errorf("expected body unchanged, got '%s'", string(ctx.Response.Body()))
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_4xxErrors 测试所有 4xx 错误被拦截。
|
||||
func TestErrorIntercept_Process_4xxErrors(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 创建默认错误页面
|
||||
pageDefault := filepath.Join(tempDir, "default.html")
|
||||
if err := os.WriteFile(pageDefault, []byte("<html>Error</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 default.html 失败: %v", err)
|
||||
}
|
||||
|
||||
manager, err := handler.NewErrorPageManager(&config.ErrorPageConfig{
|
||||
Default: pageDefault,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("创建 ErrorPageManager 失败: %v", err)
|
||||
}
|
||||
|
||||
ei := New(manager)
|
||||
|
||||
// 测试不同的 4xx 错误码
|
||||
codes := []int{400, 401, 403, 404, 405, 408, 429, 499}
|
||||
|
||||
for _, code := range codes {
|
||||
t.Run("状态码_"+string(rune('0'+code/100)), func(t *testing.T) {
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(code)
|
||||
ctx.SetBodyString("error")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
wrapped(&ctx)
|
||||
|
||||
// 验证 body 被替换为默认页面
|
||||
if string(ctx.Response.Body()) != "<html>Error</html>" {
|
||||
t.Errorf("expected custom error page, got '%s'", string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != code {
|
||||
t.Errorf("expected status %d, got %d", code, ctx.Response.StatusCode())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_Process_5xxErrors 测试所有 5xx 错误被拦截。
|
||||
func TestErrorIntercept_Process_5xxErrors(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 创建默认错误页面
|
||||
pageDefault := filepath.Join(tempDir, "default.html")
|
||||
if err := os.WriteFile(pageDefault, []byte("<html>Server Error</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 default.html 失败: %v", err)
|
||||
}
|
||||
|
||||
manager, err := handler.NewErrorPageManager(&config.ErrorPageConfig{
|
||||
Default: pageDefault,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("创建 ErrorPageManager 失败: %v", err)
|
||||
}
|
||||
|
||||
ei := New(manager)
|
||||
|
||||
// 测试不同的 5xx 错误码
|
||||
codes := []int{500, 501, 502, 503, 504, 505, 599}
|
||||
|
||||
for _, code := range codes {
|
||||
t.Run("状态码_"+string(rune('0'+code/100)), func(t *testing.T) {
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(code)
|
||||
ctx.SetBodyString("server error")
|
||||
}
|
||||
|
||||
wrapped := ei.Process(next)
|
||||
|
||||
var ctx fasthttp.RequestCtx
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
wrapped(&ctx)
|
||||
|
||||
// 验证 body 被替换
|
||||
if string(ctx.Response.Body()) != "<html>Server Error</html>" {
|
||||
t.Errorf("expected custom error page, got '%s'", string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != code {
|
||||
t.Errorf("expected status %d, got %d", code, ctx.Response.StatusCode())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsErrorStatusCode 测试错误状态码检测。
|
||||
func TestIsErrorStatusCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
expected bool
|
||||
}{
|
||||
// 边界测试
|
||||
{"边界 399", 399, false},
|
||||
{"边界 400", 400, true},
|
||||
{"边界 599", 599, true},
|
||||
{"边界 600", 600, false},
|
||||
|
||||
// 4xx 错误
|
||||
{"400 Bad Request", 400, true},
|
||||
{"401 Unauthorized", 401, true},
|
||||
{"403 Forbidden", 403, true},
|
||||
{"404 Not Found", 404, true},
|
||||
{"429 Too Many Requests", 429, true},
|
||||
{"499", 499, true},
|
||||
|
||||
// 5xx 错误
|
||||
{"500 Internal Server Error", 500, true},
|
||||
{"502 Bad Gateway", 502, true},
|
||||
{"503 Service Unavailable", 503, true},
|
||||
{"504 Gateway Timeout", 504, true},
|
||||
|
||||
// 非错误状态码
|
||||
{"200 OK", 200, false},
|
||||
{"201 Created", 201, false},
|
||||
{"301 Redirect", 301, false},
|
||||
{"304 Not Modified", 304, false},
|
||||
|
||||
// 边缘值
|
||||
{"0", 0, false},
|
||||
{"负值 -1", -1, false},
|
||||
{"极大值 999", 999, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isErrorStatusCode(tt.code)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isErrorStatusCode(%d) = %v, expected %v", tt.code, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestErrorIntercept_GetManager 测试获取 manager。
|
||||
func TestErrorIntercept_GetManager(t *testing.T) {
|
||||
manager := &handler.ErrorPageManager{}
|
||||
ei := New(manager)
|
||||
|
||||
got := ei.GetManager()
|
||||
if got != manager {
|
||||
t.Error("GetManager() did not return the expected manager")
|
||||
}
|
||||
}
|
||||
|
||||
// createConfiguredManager 创建一个已配置的 ErrorPageManager 用于测试。
|
||||
func createConfiguredManager(t *testing.T) *handler.ErrorPageManager {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// 创建一个简单的错误页面
|
||||
pageDefault := filepath.Join(tempDir, "default.html")
|
||||
if err := os.WriteFile(pageDefault, []byte("<html>Error</html>"), 0644); err != nil {
|
||||
t.Fatalf("创建 default.html 失败: %v", err)
|
||||
}
|
||||
|
||||
manager, err := handler.NewErrorPageManager(&config.ErrorPageConfig{
|
||||
Default: pageDefault,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("创建 ErrorPageManager 失败: %v", err)
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
821
internal/server/pprof_test.go
Normal file
821
internal/server/pprof_test.go
Normal file
@ -0,0 +1,821 @@
|
||||
// Package server 提供 pprof 性能分析端点功能的测试。
|
||||
//
|
||||
// 该文件测试 pprof 处理器模块的各项功能,包括:
|
||||
// - pprof 处理器创建
|
||||
// - 配置解析和默认值
|
||||
// - IP/CIDR 白名单验证
|
||||
// - 路径返回
|
||||
// - ServeHTTP 路径分发
|
||||
// - 访问控制逻辑
|
||||
// - HTML 索引页面生成
|
||||
//
|
||||
// 作者:xfy
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
func TestNewPprofHandler_Disabled(t *testing.T) {
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if h != nil {
|
||||
t.Error("expected nil handler when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPprofHandler_DefaultPath(t *testing.T) {
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "",
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if h == nil {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
|
||||
if h.Path() != "/debug/pprof" {
|
||||
t.Errorf("expected default path /debug/pprof, got %s", h.Path())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPprofHandler_CustomPath(t *testing.T) {
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "/custom/pprof",
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if h == nil {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
|
||||
if h.Path() != "/custom/pprof" {
|
||||
t.Errorf("expected custom path /custom/pprof, got %s", h.Path())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPprofHandler_SingleIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allow []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid IPv4",
|
||||
allow: []string{"192.168.1.100"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid IPv6",
|
||||
allow: []string{"::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple IPs",
|
||||
allow: []string{"192.168.1.1", "127.0.0.1", "::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty allow list - use default localhost",
|
||||
allow: []string{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "/debug/pprof",
|
||||
Allow: tt.allow,
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if h == nil {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
// 空列表时应该默认允许 localhost
|
||||
if len(tt.allow) == 0 {
|
||||
if len(h.allowedIPs) != 2 {
|
||||
t.Errorf("expected 2 default allowed IPs (127.0.0.1 and ::1), got %d", len(h.allowedIPs))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPprofHandler_CIDR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allow []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid CIDR IPv4",
|
||||
allow: []string{"192.168.1.0/24"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid CIDR IPv6",
|
||||
allow: []string{"2001:db8::/32"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple CIDRs",
|
||||
allow: []string{"10.0.0.0/8", "172.16.0.0/12"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed IP and CIDR",
|
||||
allow: []string{"192.168.1.1", "10.0.0.0/8"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "/debug/pprof",
|
||||
Allow: tt.allow,
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if h == nil {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewPprofHandler_InvalidIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allow []string
|
||||
}{
|
||||
{
|
||||
name: "invalid IP format",
|
||||
allow: []string{"not-an-ip"},
|
||||
},
|
||||
{
|
||||
name: "invalid CIDR format",
|
||||
allow: []string{"invalid-cidr"},
|
||||
},
|
||||
{
|
||||
name: "CIDR with invalid mask",
|
||||
allow: []string{"192.168.1.0/33"},
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid",
|
||||
allow: []string{"127.0.0.1", "invalid"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "/debug/pprof",
|
||||
Allow: tt.allow,
|
||||
}
|
||||
|
||||
_, err := NewPprofHandler(cfg)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid IP/CIDR, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_Path(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantPath string
|
||||
}{
|
||||
{
|
||||
name: "default path",
|
||||
path: "",
|
||||
wantPath: "/debug/pprof",
|
||||
},
|
||||
{
|
||||
name: "custom path",
|
||||
path: "/admin/pprof",
|
||||
wantPath: "/admin/pprof",
|
||||
},
|
||||
{
|
||||
name: "nested path",
|
||||
path: "/api/v1/debug/pprof",
|
||||
wantPath: "/api/v1/debug/pprof",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: tt.path,
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if h == nil {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
|
||||
if h.Path() != tt.wantPath {
|
||||
t.Errorf("expected path %s, got %s", tt.wantPath, h.Path())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_isAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedIPs []string
|
||||
allowedNets []string
|
||||
clientIP string
|
||||
wantAllowed bool
|
||||
}{
|
||||
{
|
||||
name: "empty allow list - allow all",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{},
|
||||
clientIP: "192.168.1.100",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "IP exact match",
|
||||
allowedIPs: []string{"127.0.0.1"},
|
||||
allowedNets: []string{},
|
||||
clientIP: "127.0.0.1",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "IP no match",
|
||||
allowedIPs: []string{"127.0.0.1"},
|
||||
allowedNets: []string{},
|
||||
clientIP: "127.0.0.2",
|
||||
wantAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "CIDR match",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{"192.168.0.0/16"},
|
||||
clientIP: "192.168.1.100",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "CIDR no match",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{"10.0.0.0/8"},
|
||||
clientIP: "192.168.1.100",
|
||||
wantAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6 CIDR match",
|
||||
allowedIPs: []string{},
|
||||
allowedNets: []string{"2001:db8::/32"},
|
||||
clientIP: "2001:db8::1",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6 exact match",
|
||||
allowedIPs: []string{"::1"},
|
||||
allowedNets: []string{},
|
||||
clientIP: "::1",
|
||||
wantAllowed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: parseIPs(tt.allowedIPs),
|
||||
allowedNets: parseNets(tt.allowedNets),
|
||||
}
|
||||
|
||||
// 创建请求上下文,模拟客户端 IP
|
||||
// 通过设置请求头来模拟 IP 需要特殊处理
|
||||
// fasthttp 的 RemoteIP() 从连接获取,这里我们直接测试 isAllowed 逻辑
|
||||
|
||||
// 手动测试 isAllowed 的内部逻辑
|
||||
clientIP := net.ParseIP(tt.clientIP)
|
||||
if clientIP == nil {
|
||||
t.Fatalf("failed to parse client IP: %s", tt.clientIP)
|
||||
}
|
||||
|
||||
// 复制 isAllowed 的逻辑进行测试
|
||||
allowed := false
|
||||
if len(h.allowedIPs) == 0 && len(h.allowedNets) == 0 {
|
||||
allowed = true
|
||||
} else {
|
||||
for _, ip := range h.allowedIPs {
|
||||
if ip.Equal(clientIP) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, n := range h.allowedNets {
|
||||
if n.Contains(clientIP) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowed != tt.wantAllowed {
|
||||
t.Errorf("isAllowed() = %v, want %v", allowed, tt.wantAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// parseIPs 辅助函数,解析 IP 字符串列表
|
||||
func parseIPs(ips []string) []net.IP {
|
||||
result := make([]net.IP, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
if parsed := net.ParseIP(ip); parsed != nil {
|
||||
result = append(result, parsed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// parseNets 辅助函数,解析 CIDR 字符串列表
|
||||
func parseNets(cidrs []string) []*net.IPNet {
|
||||
result := make([]*net.IPNet, 0, len(cidrs))
|
||||
for _, cidr := range cidrs {
|
||||
_, net, err := net.ParseCIDR(cidr)
|
||||
if err == nil {
|
||||
result = append(result, net)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func TestPprofHandler_ServeHTTP_WithAllowListEmpty(t *testing.T) {
|
||||
// 测试空 allow 列表时允许所有访问
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof")
|
||||
|
||||
h.ServeHTTP(ctx)
|
||||
|
||||
// 空 allow 列表时应允许访问(返回 200)
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("expected status 200 for open access, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_ServeHTTP_ProfileEndpoints(t *testing.T) {
|
||||
// 使用空 allow 列表允许所有访问
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "heap endpoint",
|
||||
path: "/debug/pprof/heap",
|
||||
wantStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "goroutine endpoint",
|
||||
path: "/debug/pprof/goroutine",
|
||||
wantStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "block endpoint",
|
||||
path: "/debug/pprof/block",
|
||||
wantStatus: 200,
|
||||
},
|
||||
{
|
||||
name: "mutex endpoint",
|
||||
path: "/debug/pprof/mutex",
|
||||
wantStatus: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI(tt.path)
|
||||
|
||||
h.ServeHTTP(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.wantStatus, ctx.Response.StatusCode())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleIndex(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof")
|
||||
|
||||
h.handleIndex(ctx)
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// 验证 Content-Type
|
||||
contentType := string(ctx.Response.Header.Peek("Content-Type"))
|
||||
if !strings.Contains(contentType, "text/html") {
|
||||
t.Errorf("expected Content-Type text/html, got %s", contentType)
|
||||
}
|
||||
|
||||
// 验证响应体包含关键内容
|
||||
body := ctx.Response.Body()
|
||||
if !bytes.Contains(body, []byte("Pprof Profiles")) {
|
||||
t.Error("expected body to contain 'Pprof Profiles'")
|
||||
}
|
||||
if !bytes.Contains(body, []byte("/debug/pprof/profile")) {
|
||||
t.Error("expected body to contain profile link")
|
||||
}
|
||||
if !bytes.Contains(body, []byte("/debug/pprof/heap")) {
|
||||
t.Error("expected body to contain heap link")
|
||||
}
|
||||
if !bytes.Contains(body, []byte("/debug/pprof/goroutine")) {
|
||||
t.Error("expected body to contain goroutine link")
|
||||
}
|
||||
if !bytes.Contains(body, []byte("/debug/pprof/block")) {
|
||||
t.Error("expected body to contain block link")
|
||||
}
|
||||
if !bytes.Contains(body, []byte("/debug/pprof/mutex")) {
|
||||
t.Error("expected body to contain mutex link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_ServeHTTP_PathRouting(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "index path",
|
||||
path: "/debug/pprof",
|
||||
wantStatus: 200,
|
||||
wantBody: "Pprof Profiles",
|
||||
},
|
||||
{
|
||||
name: "index path with slash",
|
||||
path: "/debug/pprof/",
|
||||
wantStatus: 200,
|
||||
wantBody: "Pprof Profiles",
|
||||
},
|
||||
{
|
||||
name: "unknown path",
|
||||
path: "/debug/pprof/unknown",
|
||||
wantStatus: 404,
|
||||
wantBody: "Unknown profile",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI(tt.path)
|
||||
|
||||
h.ServeHTTP(ctx)
|
||||
|
||||
if ctx.Response.StatusCode() != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.wantStatus, ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
if tt.wantBody != "" {
|
||||
body := string(ctx.Response.Body())
|
||||
if !strings.Contains(body, tt.wantBody) {
|
||||
t.Errorf("expected body to contain '%s', got '%s'", tt.wantBody, body)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_ServeHTTP_Forbidden(t *testing.T) {
|
||||
// 创建只允许特定 IP 的 handler
|
||||
allowedIP := net.ParseIP("10.0.0.1")
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{allowedIP},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
// 由于无法轻松设置 RemoteIP,我们直接测试 isAllowed 返回 false 的情况
|
||||
// 通过构造一个 allowedIPs 非空的情况来触发检查
|
||||
|
||||
// 验证 handler 配置正确
|
||||
if len(h.allowedIPs) != 1 {
|
||||
t.Errorf("expected 1 allowed IP, got %d", len(h.allowedIPs))
|
||||
}
|
||||
|
||||
// 验证 allowed IPs 包含配置的 IP
|
||||
if !h.allowedIPs[0].Equal(allowedIP) {
|
||||
t.Error("expected allowedIPs to contain configured IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleCPU_Params(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
seconds string
|
||||
wantType string
|
||||
}{
|
||||
{
|
||||
name: "default seconds",
|
||||
seconds: "",
|
||||
wantType: "application/octet-stream",
|
||||
},
|
||||
{
|
||||
name: "custom seconds",
|
||||
seconds: "1",
|
||||
wantType: "application/octet-stream",
|
||||
},
|
||||
{
|
||||
name: "invalid seconds",
|
||||
seconds: "invalid",
|
||||
wantType: "application/octet-stream",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
if tt.seconds != "" {
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=" + tt.seconds)
|
||||
} else {
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile")
|
||||
}
|
||||
|
||||
// 注意:handleCPU 会启动实际的 CPU profile,需要特殊处理
|
||||
// 这里主要验证 Content-Type 设置正确
|
||||
// 实际 profile 测试需要更复杂的设置
|
||||
|
||||
// 验证 handler 配置
|
||||
if h.path != "/debug/pprof" {
|
||||
t.Error("unexpected handler path")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_ConfigWithCIDRAndIP(t *testing.T) {
|
||||
// 测试混合配置
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "/debug/pprof",
|
||||
Allow: []string{"127.0.0.1", "192.168.0.0/24", "::1"},
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if h == nil {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
|
||||
// 验证 IP 和 CIDR 都被正确解析
|
||||
if len(h.allowedIPs) != 2 {
|
||||
t.Errorf("expected 2 allowed IPs, got %d", len(h.allowedIPs))
|
||||
}
|
||||
if len(h.allowedNets) != 1 {
|
||||
t.Errorf("expected 1 allowed net, got %d", len(h.allowedNets))
|
||||
}
|
||||
|
||||
// 验证具体内容
|
||||
foundV4 := false
|
||||
foundV6 := false
|
||||
for _, ip := range h.allowedIPs {
|
||||
if ip.Equal(net.ParseIP("127.0.0.1")) {
|
||||
foundV4 = true
|
||||
}
|
||||
if ip.Equal(net.ParseIP("::1")) {
|
||||
foundV6 = true
|
||||
}
|
||||
}
|
||||
if !foundV4 {
|
||||
t.Error("expected to find 127.0.0.1 in allowedIPs")
|
||||
}
|
||||
if !foundV6 {
|
||||
t.Error("expected to find ::1 in allowedIPs")
|
||||
}
|
||||
|
||||
// 验证 CIDR
|
||||
if h.allowedNets[0].String() != "192.168.0.0/24" {
|
||||
t.Errorf("expected CIDR 192.168.0.0/24, got %s", h.allowedNets[0].String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_DefaultLocalhostBehavior(t *testing.T) {
|
||||
// 测试空配置时默认只允许 localhost
|
||||
cfg := &config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "/debug/pprof",
|
||||
Allow: []string{},
|
||||
}
|
||||
|
||||
h, err := NewPprofHandler(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if h == nil {
|
||||
t.Fatal("expected non-nil handler")
|
||||
}
|
||||
|
||||
// 验证默认允许 localhost
|
||||
if len(h.allowedIPs) != 2 {
|
||||
t.Errorf("expected 2 default allowed IPs, got %d", len(h.allowedIPs))
|
||||
}
|
||||
|
||||
// 验证包含 IPv4 和 IPv6 localhost
|
||||
hasV4 := false
|
||||
hasV6 := false
|
||||
for _, ip := range h.allowedIPs {
|
||||
if ip.Equal(net.ParseIP("127.0.0.1")) {
|
||||
hasV4 = true
|
||||
}
|
||||
if ip.Equal(net.ParseIP("::1")) {
|
||||
hasV6 = true
|
||||
}
|
||||
}
|
||||
if !hasV4 {
|
||||
t.Error("expected default to include 127.0.0.1")
|
||||
}
|
||||
if !hasV6 {
|
||||
t.Error("expected default to include ::1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleHeap(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/heap")
|
||||
|
||||
h.handleHeap(ctx)
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// 验证 Content-Type
|
||||
contentType := string(ctx.Response.Header.Peek("Content-Type"))
|
||||
if contentType != "application/octet-stream" {
|
||||
t.Errorf("expected Content-Type application/octet-stream, got %s", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleGoroutine(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/goroutine")
|
||||
|
||||
h.handleGoroutine(ctx)
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// 验证 Content-Type
|
||||
contentType := string(ctx.Response.Header.Peek("Content-Type"))
|
||||
if contentType != "application/octet-stream" {
|
||||
t.Errorf("expected Content-Type application/octet-stream, got %s", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleBlock(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/block")
|
||||
|
||||
h.handleBlock(ctx)
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// 验证 Content-Type
|
||||
contentType := string(ctx.Response.Header.Peek("Content-Type"))
|
||||
if contentType != "application/octet-stream" {
|
||||
t.Errorf("expected Content-Type application/octet-stream, got %s", contentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPprofHandler_handleMutex(t *testing.T) {
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/mutex")
|
||||
|
||||
h.handleMutex(ctx)
|
||||
|
||||
// 验证状态码
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// 验证 Content-Type
|
||||
contentType := string(ctx.Response.Header.Peek("Content-Type"))
|
||||
if contentType != "application/octet-stream" {
|
||||
t.Errorf("expected Content-Type application/octet-stream, got %s", contentType)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user