test(handler,middleware,server,ssl,proxy): 扩展测试覆盖率
- handler: 添加 sendfile 和 static 处理器测试 - middleware/security: 添加访问控制、认证、请求头、限流测试 - server: 添加池、pprof、清理、状态、升级、vhost 测试 - ssl: 添加客户端验证、OCSP、SSL 测试 - proxy: 添加代理覆盖率补充测试 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
91f954936e
commit
9f7090df67
@ -17,6 +17,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
@ -600,3 +601,154 @@ func TestSendFile_JustBelowMin(t *testing.T) {
|
||||
t.Errorf("Body mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSocketFd_TCPConn 测试 TCPConn 获取 socket fd
|
||||
func TestGetSocketFd_TCPConn(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// 启动 goroutine 接收连接
|
||||
var serverConn net.Conn
|
||||
go func() {
|
||||
serverConn, _ = ln.Accept()
|
||||
}()
|
||||
|
||||
// 客户端连接
|
||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
// 等待连接建立
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 测试获取 socket fd
|
||||
fd, err := getSocketFd(clientConn)
|
||||
if err != nil {
|
||||
t.Errorf("getSocketFd failed for TCPConn: %v", err)
|
||||
}
|
||||
if fd == 0 {
|
||||
t.Error("Expected non-zero fd for TCPConn")
|
||||
}
|
||||
|
||||
if serverConn != nil {
|
||||
serverConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// TestLinuxSendfile_WithTCPConn 测试 linuxSendfile 使用真实 TCP 连接
|
||||
func TestLinuxSendfile_WithTCPConn(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
|
||||
// 创建大于 MinSendfileSize 的文件
|
||||
content := make([]byte, MinSendfileSize+1024)
|
||||
for i := range content {
|
||||
content[i] = byte(i % 256)
|
||||
}
|
||||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
// 创建 TCP 连接
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var serverConn net.Conn
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
serverConn, _ = ln.Accept()
|
||||
// 读取所有数据
|
||||
buf := make([]byte, len(content))
|
||||
_, _ = io.ReadFull(serverConn, buf)
|
||||
serverConn.Close()
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to dial: %v", err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
// 设置写超时
|
||||
if err := clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
t.Fatalf("Failed to set deadline: %v", err)
|
||||
}
|
||||
|
||||
// 调用 linuxSendfile
|
||||
err = linuxSendfile(clientConn, file.Fd(), 0, int64(len(content)))
|
||||
if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET {
|
||||
t.Logf("linuxSendfile returned: %v", err)
|
||||
}
|
||||
|
||||
clientConn.Close()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestLinuxSendfile_SendfileError 测试 sendfile 错误处理
|
||||
func TestLinuxSendfile_SendfileError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
content := []byte("test content for sendfile")
|
||||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
// 使用 mockConn 测试不支持的连接类型
|
||||
conn := &mockConn{}
|
||||
err = linuxSendfile(conn, file.Fd(), 0, int64(len(content)))
|
||||
if err == nil {
|
||||
t.Error("Expected error for unsupported connection type")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendFile_NegativeLength 测试负长度参数
|
||||
func TestSendFile_NegativeLength(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tmpFile := filepath.Join(tmpDir, "test.txt")
|
||||
content := []byte("test content")
|
||||
if err := os.WriteFile(tmpFile, content, 0o644); err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
file, err := os.Open(tmpFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open file: %v", err)
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Init(&fasthttp.Request{}, nil, nil)
|
||||
|
||||
// 负长度应该使用 fallback
|
||||
err = SendFile(ctx, file, 0, -1)
|
||||
if err != nil {
|
||||
t.Errorf("SendFile with negative length failed: %v", err)
|
||||
}
|
||||
|
||||
// 应该传输整个文件
|
||||
if !bytes.Equal(ctx.Response.Body(), content) {
|
||||
t.Errorf("Expected body %s, got %s", content, ctx.Response.Body())
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,8 +21,10 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/testutil"
|
||||
)
|
||||
|
||||
@ -1597,3 +1599,352 @@ func TestStaticHandler_TryFilesRootPathFallback(t *testing.T) {
|
||||
t.Errorf("内容 = %q, want %q", got, "root fallback")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_SetSymlinkCheck 测试 SetSymlinkCheck 方法
|
||||
func TestStaticHandler_SetSymlinkCheck(t *testing.T) {
|
||||
handler := NewStaticHandler("/var/www", "/", nil, false)
|
||||
|
||||
if handler.symlinkCheck {
|
||||
t.Error("初始 symlinkCheck 应为 false")
|
||||
}
|
||||
|
||||
handler.SetSymlinkCheck(true)
|
||||
if !handler.symlinkCheck {
|
||||
t.Error("SetSymlinkCheck(true) 后 symlinkCheck 应为 true")
|
||||
}
|
||||
|
||||
handler.SetSymlinkCheck(false)
|
||||
if handler.symlinkCheck {
|
||||
t.Error("SetSymlinkCheck(false) 后 symlinkCheck 应为 false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_SetInternal 测试 SetInternal 方法
|
||||
func TestStaticHandler_SetInternal(t *testing.T) {
|
||||
handler := NewStaticHandler("/var/www", "/", nil, false)
|
||||
|
||||
if handler.internal {
|
||||
t.Error("初始 internal 应为 false")
|
||||
}
|
||||
|
||||
handler.SetInternal(true)
|
||||
if !handler.internal {
|
||||
t.Error("SetInternal(true) 后 internal 应为 true")
|
||||
}
|
||||
|
||||
handler.SetInternal(false)
|
||||
if handler.internal {
|
||||
t.Error("SetInternal(false) 后 internal 应为 false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_SetCacheTTL 测试 SetCacheTTL 方法
|
||||
func TestStaticHandler_SetCacheTTL(t *testing.T) {
|
||||
handler := NewStaticHandler("/var/www", "/", nil, false)
|
||||
|
||||
if handler.cacheTTL != 0 {
|
||||
t.Error("初始 cacheTTL 应为 0")
|
||||
}
|
||||
|
||||
handler.SetCacheTTL(5 * time.Second)
|
||||
if handler.cacheTTL != 5*time.Second {
|
||||
t.Errorf("SetCacheTTL 后 cacheTTL = %v, want %v", handler.cacheTTL, 5*time.Second)
|
||||
}
|
||||
|
||||
handler.SetCacheTTL(0)
|
||||
if handler.cacheTTL != 0 {
|
||||
t.Error("SetCacheTTL(0) 后 cacheTTL 应为 0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_InternalRestriction 测试 internal 访问限制
|
||||
func TestStaticHandler_InternalRestriction(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
content := "internal content"
|
||||
if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
handler := newTestHandler(t, tmpDir)
|
||||
handler.SetInternal(true)
|
||||
|
||||
t.Run("外部请求返回 404", func(t *testing.T) {
|
||||
ctx := newTestContext(t, "/test.txt")
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != fasthttp.StatusNotFound {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusNotFound)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("内部重定向允许访问", func(t *testing.T) {
|
||||
ctx := newTestContext(t, "/test.txt")
|
||||
// 标记为内部重定向
|
||||
ctx.SetUserValue("__internal_redirect__", "/test.txt")
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestStaticHandler_ValidateSymlink 测试符号链接验证
|
||||
func TestStaticHandler_ValidateSymlink(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建目标文件和目录
|
||||
targetDir := filepath.Join(tmpDir, "target")
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
t.Fatalf("创建目标目录失败: %v", err)
|
||||
}
|
||||
targetFile := filepath.Join(targetDir, "secret.txt")
|
||||
if err := os.WriteFile(targetFile, []byte("secret content"), 0o644); err != nil {
|
||||
t.Fatalf("创建目标文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建允许的根目录
|
||||
allowedDir := filepath.Join(tmpDir, "allowed")
|
||||
if err := os.MkdirAll(allowedDir, 0o755); err != nil {
|
||||
t.Fatalf("创建允许目录失败: %v", err)
|
||||
}
|
||||
|
||||
t.Run("安全符号链接 - 在允许范围内", func(t *testing.T) {
|
||||
// 在允许目录内创建符号链接
|
||||
linkFile := filepath.Join(allowedDir, "link.txt")
|
||||
allowedTarget := filepath.Join(allowedDir, "actual.txt")
|
||||
if err := os.WriteFile(allowedTarget, []byte("allowed content"), 0o644); err != nil {
|
||||
t.Fatalf("创建实际文件失败: %v", err)
|
||||
}
|
||||
if err := os.Symlink(allowedTarget, linkFile); err != nil {
|
||||
t.Fatalf("创建符号链接失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandler(allowedDir, "/", nil, false)
|
||||
handler.SetSymlinkCheck(true)
|
||||
|
||||
ctx := newTestContext(t, "/link.txt")
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("不安全符号链接 - 指向允许范围外", func(t *testing.T) {
|
||||
// 创建指向允许目录外的符号链接
|
||||
unsafeLink := filepath.Join(allowedDir, "unsafe.txt")
|
||||
if err := os.Symlink(targetFile, unsafeLink); err != nil {
|
||||
t.Fatalf("创建不安全符号链接失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandler(allowedDir, "/", nil, false)
|
||||
handler.SetSymlinkCheck(true)
|
||||
|
||||
ctx := newTestContext(t, "/unsafe.txt")
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != fasthttp.StatusForbidden {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("普通文件 - 非符号链接", func(t *testing.T) {
|
||||
normalFile := filepath.Join(allowedDir, "normal.txt")
|
||||
if err := os.WriteFile(normalFile, []byte("normal content"), 0o644); err != nil {
|
||||
t.Fatalf("创建普通文件失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandler(allowedDir, "/", nil, false)
|
||||
handler.SetSymlinkCheck(true)
|
||||
|
||||
ctx := newTestContext(t, "/normal.txt")
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("未启用符号链接检查", func(t *testing.T) {
|
||||
// 创建指向允许目录外的符号链接
|
||||
externalLink := filepath.Join(allowedDir, "external.txt")
|
||||
if err := os.Symlink(targetFile, externalLink); err != nil {
|
||||
t.Fatalf("创建外部符号链接失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandler(allowedDir, "/", nil, false)
|
||||
// 不启用符号链接检查
|
||||
|
||||
ctx := newTestContext(t, "/external.txt")
|
||||
handler.Handle(ctx)
|
||||
|
||||
// 未启用检查时,可以访问符号链接
|
||||
if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestStaticHandler_ValidateSymlink_WithAlias 测试 alias 模式下的符号链接验证
|
||||
func TestStaticHandler_ValidateSymlink_WithAlias(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建 alias 目录
|
||||
aliasDir := filepath.Join(tmpDir, "alias")
|
||||
if err := os.MkdirAll(aliasDir, 0o755); err != nil {
|
||||
t.Fatalf("创建 alias 目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 在 alias 目录内创建文件和符号链接
|
||||
actualFile := filepath.Join(aliasDir, "actual.txt")
|
||||
if err := os.WriteFile(actualFile, []byte("actual content"), 0o644); err != nil {
|
||||
t.Fatalf("创建实际文件失败: %v", err)
|
||||
}
|
||||
|
||||
linkFile := filepath.Join(aliasDir, "link.txt")
|
||||
if err := os.Symlink(actualFile, linkFile); err != nil {
|
||||
t.Fatalf("创建符号链接失败: %v", err)
|
||||
}
|
||||
|
||||
handler := NewStaticHandlerWithAlias(aliasDir, "/static/", nil, false)
|
||||
handler.SetSymlinkCheck(true)
|
||||
|
||||
ctx := newTestContext(t, "/static/link.txt")
|
||||
handler.Handle(ctx)
|
||||
|
||||
if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_ValidateSymlink_NoRootOrAlias 测试无 root/alias 时符号链接验证
|
||||
func TestStaticHandler_ValidateSymlink_NoRootOrAlias(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
// 创建文件和符号链接
|
||||
targetFile := filepath.Join(tmpDir, "target.txt")
|
||||
if err := os.WriteFile(targetFile, []byte("target"), 0o644); err != nil {
|
||||
t.Fatalf("创建目标文件失败: %v", err)
|
||||
}
|
||||
|
||||
linkFile := filepath.Join(tmpDir, "link.txt")
|
||||
if err := os.Symlink(targetFile, linkFile); err != nil {
|
||||
t.Fatalf("创建符号链接失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建无 root/alias 的处理器
|
||||
handler := NewStaticHandler("", "/", nil, false)
|
||||
handler.SetSymlinkCheck(true)
|
||||
|
||||
// 直接调用 validateSymlink
|
||||
err := handler.validateSymlink(linkFile)
|
||||
if err == nil {
|
||||
t.Error("无 root/alias 时验证符号链接应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_Handle_WithCacheTTL 测试带 TTL 的缓存处理
|
||||
func TestStaticHandler_Handle_WithCacheTTL(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
content := "cached with ttl"
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
if err := os.WriteFile(testFile, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建带缓存的处理器
|
||||
handler := newTestHandler(t, tmpDir)
|
||||
handler.SetFileCache(cache.NewFileCache(100, 1024*1024, time.Hour))
|
||||
handler.SetCacheTTL(5 * time.Second)
|
||||
|
||||
// 第一次请求,填充缓存
|
||||
ctx1 := newTestContext(t, "/test.txt")
|
||||
handler.Handle(ctx1)
|
||||
|
||||
if got := ctx1.Response.StatusCode(); got != fasthttp.StatusOK {
|
||||
t.Errorf("第一次请求状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
// 第二次请求,应该命中缓存
|
||||
ctx2 := newTestContext(t, "/test.txt")
|
||||
handler.Handle(ctx2)
|
||||
|
||||
if got := ctx2.Response.StatusCode(); got != fasthttp.StatusOK {
|
||||
t.Errorf("第二次请求状态码 = %d, want %d", got, fasthttp.StatusOK)
|
||||
}
|
||||
|
||||
// 内容应该一致
|
||||
if string(ctx1.Response.Body()) != string(ctx2.Response.Body()) {
|
||||
t.Error("缓存内容不一致")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_Handle_CacheTTLExpired 测试 TTL 过期后重新验证
|
||||
func TestStaticHandler_Handle_CacheTTLExpired(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
content := "initial content"
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
if err := os.WriteFile(testFile, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建带缓存的处理器,TTL 设置很短
|
||||
handler := newTestHandler(t, tmpDir)
|
||||
handler.SetFileCache(cache.NewFileCache(100, 1024*1024, time.Hour))
|
||||
handler.SetCacheTTL(100 * time.Millisecond)
|
||||
|
||||
// 第一次请求,填充缓存
|
||||
ctx1 := newTestContext(t, "/test.txt")
|
||||
handler.Handle(ctx1)
|
||||
|
||||
// 等待 TTL 过期
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 修改文件
|
||||
newContent := "updated content"
|
||||
if err := os.WriteFile(testFile, []byte(newContent), 0o644); err != nil {
|
||||
t.Fatalf("更新文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 第二次请求,TTL 过期后应该重新读取
|
||||
ctx2 := newTestContext(t, "/test.txt")
|
||||
handler.Handle(ctx2)
|
||||
|
||||
if got := string(ctx2.Response.Body()); got != newContent {
|
||||
t.Errorf("TTL 过期后内容 = %q, want %q", got, newContent)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStaticHandler_Handle_CacheModTimeChanged 测试文件修改后缓存更新
|
||||
func TestStaticHandler_Handle_CacheModTimeChanged(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
initialContent := "initial"
|
||||
if err := os.WriteFile(testFile, []byte(initialContent), 0o644); err != nil {
|
||||
t.Fatalf("创建测试文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建带缓存的处理器
|
||||
handler := newTestHandler(t, tmpDir)
|
||||
handler.SetFileCache(cache.NewFileCache(100, 1024*1024, time.Hour))
|
||||
|
||||
// 第一次请求,填充缓存
|
||||
ctx1 := newTestContext(t, "/test.txt")
|
||||
handler.Handle(ctx1)
|
||||
|
||||
// 修改文件
|
||||
time.Sleep(10 * time.Millisecond) // 确保 ModTime 变化
|
||||
newContent := "modified"
|
||||
if err := os.WriteFile(testFile, []byte(newContent), 0o644); err != nil {
|
||||
t.Fatalf("修改文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 第二次请求,应该检测到文件变化并更新缓存
|
||||
ctx2 := newTestContext(t, "/test.txt")
|
||||
handler.Handle(ctx2)
|
||||
|
||||
if got := string(ctx2.Response.Body()); got != newContent {
|
||||
t.Errorf("修改后内容 = %q, want %q", got, newContent)
|
||||
}
|
||||
}
|
||||
|
||||
@ -338,3 +338,131 @@ func TestUpdateDenyListError(t *testing.T) {
|
||||
t.Error("UpdateDenyList() should return error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAccessControl_TrustedProxiesError 测试可信代理 CIDR 解析错误
|
||||
func TestNewAccessControl_TrustedProxiesError(t *testing.T) {
|
||||
_, err := NewAccessControl(&config.AccessConfig{
|
||||
TrustedProxies: []string{"invalid-cidr"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("NewAccessControl() should return error for invalid trusted proxy CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAccessControl_DenyListError 测试拒绝列表 CIDR 解析错误
|
||||
func TestNewAccessControl_DenyListError(t *testing.T) {
|
||||
_, err := NewAccessControl(&config.AccessConfig{
|
||||
Deny: []string{"not-a-cidr"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("NewAccessControl() should return error for invalid deny CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheck_AllowListHit 测试允许列表命中
|
||||
func TestCheck_AllowListHit(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Allow: []string{"192.168.1.0/24"},
|
||||
Default: "deny",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
// 允许列表中的 IP
|
||||
if !ac.Check(net.ParseIP("192.168.1.50")) {
|
||||
t.Error("Check() should return true for IP in allow list")
|
||||
}
|
||||
|
||||
// 不在允许列表中的 IP
|
||||
if ac.Check(net.ParseIP("10.0.0.1")) {
|
||||
t.Error("Check() should return false for IP not in allow list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheck_DenyListHit 测试拒绝列表命中
|
||||
func TestCheck_DenyListHit(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Deny: []string{"10.0.0.100"},
|
||||
Allow: []string{"10.0.0.0/8"},
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
// 拒绝列表优先
|
||||
if ac.Check(net.ParseIP("10.0.0.100")) {
|
||||
t.Error("Check() should return false for IP in deny list even if in allow list")
|
||||
}
|
||||
|
||||
// 在允许列表但不在拒绝列表
|
||||
if !ac.Check(net.ParseIP("10.0.0.50")) {
|
||||
t.Error("Check() should return true for IP in allow list but not in deny list")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheck_DefaultAction 测试默认操作
|
||||
func TestCheck_DefaultAction(t *testing.T) {
|
||||
// 默认允许
|
||||
acAllow, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
if !acAllow.Check(net.ParseIP("1.2.3.4")) {
|
||||
t.Error("Check() should return true with default allow")
|
||||
}
|
||||
|
||||
// 默认拒绝
|
||||
acDeny, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "deny",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
if acDeny.Check(net.ParseIP("1.2.3.4")) {
|
||||
t.Error("Check() should return false with default deny")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheck_EmptyDefault 测试空默认值(应默认为 allow)
|
||||
func TestCheck_EmptyDefault(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Default: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
if !ac.Check(net.ParseIP("1.2.3.4")) {
|
||||
t.Error("Check() should return true with empty default (should be allow)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAccessControl_ProcessDenied 测试 Process 拒绝请求
|
||||
func TestAccessControl_ProcessDenied(t *testing.T) {
|
||||
ac, err := NewAccessControl(&config.AccessConfig{
|
||||
Deny: []string{"192.168.1.0/24"},
|
||||
Default: "allow",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewAccessControl() error: %v", err)
|
||||
}
|
||||
|
||||
nextCalled := false
|
||||
handler := ac.Process(func(ctx *fasthttp.RequestCtx) {
|
||||
nextCalled = true
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("192.168.1.50"), Port: 12345})
|
||||
handler(ctx)
|
||||
|
||||
if nextCalled {
|
||||
t.Error("Process() should not call next handler for denied IP")
|
||||
}
|
||||
if ctx.Response.StatusCode() != fasthttp.StatusForbidden {
|
||||
t.Errorf("Process() status = %d, want 403", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
@ -411,3 +411,162 @@ func BenchmarkAuthRequestExpandVars(b *testing.B) {
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
|
||||
// TestAuthRequest_Middleware 测试 Middleware 方法
|
||||
func TestAuthRequest_Middleware(t *testing.T) {
|
||||
ar := &AuthRequest{}
|
||||
|
||||
mw := ar.Middleware()
|
||||
if mw == nil {
|
||||
t.Error("Middleware() should not return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRequestExpandVars_Empty 测试空模板展开
|
||||
func TestAuthRequestExpandVars_Empty(t *testing.T) {
|
||||
ar := &AuthRequest{config: config.AuthRequestConfig{}}
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
result := ar.expandVars(ctx, "")
|
||||
if result != "" {
|
||||
t.Errorf("expandVars(empty) = %q, want empty", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRequestExpandVars_NoVars 测试无变量模板
|
||||
func TestAuthRequestExpandVars_NoVars(t *testing.T) {
|
||||
ar := &AuthRequest{config: config.AuthRequestConfig{}}
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
result := ar.expandVars(ctx, "http://auth-service/verify")
|
||||
if result != "http://auth-service/verify" {
|
||||
t.Errorf("expandVars(no vars) = %q, want original", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateConfig_RelativePath 测试更新为相对路径配置
|
||||
func TestUpdateConfig_RelativePath(t *testing.T) {
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "http://auth-service:8080/auth",
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthRequest() failed: %v", err)
|
||||
}
|
||||
|
||||
// 更新为相对路径
|
||||
newCfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "/auth/verify",
|
||||
}
|
||||
err = ar.UpdateConfig(newCfg)
|
||||
if err != nil {
|
||||
t.Errorf("UpdateConfig() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证 client 被清空(相对路径不需要独立客户端)
|
||||
if ar.client != nil {
|
||||
t.Error("client should be nil for relative path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRequest_ProcessEnabled 测试 Process 处理启用状态
|
||||
func TestAuthRequest_ProcessEnabled(t *testing.T) {
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "/auth/verify",
|
||||
Method: "GET",
|
||||
Timeout: 5 * time.Second,
|
||||
ForwardHeaders: []string{"X-Custom-Header"},
|
||||
Headers: map[string]string{"X-Auth-Source": "lolly"},
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthRequest() failed: %v", err)
|
||||
}
|
||||
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(200)
|
||||
}
|
||||
|
||||
handler := ar.Process(next)
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
ctx.Request.Header.Set("X-Custom-Header", "custom-value")
|
||||
|
||||
// 执行处理器(由于没有实际的认证服务,请求会失败)
|
||||
handler(ctx)
|
||||
|
||||
// 验证处理器行为 - 由于认证服务不可达,应该返回 500
|
||||
if ctx.Response.StatusCode() != 500 {
|
||||
t.Logf("Status = %d (expected 500 due to unreachable auth service)", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseAuthURL_Invalid 测试解析无效 URL
|
||||
func TestParseAuthURL_Invalid(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
}{
|
||||
{"empty string", ""},
|
||||
{"only protocol", "http://"},
|
||||
{"only https protocol", "https://"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _, err := parseAuthURL(tt.url)
|
||||
if err == nil {
|
||||
t.Errorf("parseAuthURL(%q) should return error", tt.url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExpandVars_EdgeCases 测试变量展开边缘情况
|
||||
func TestExpandVars_EdgeCases(t *testing.T) {
|
||||
ar := &AuthRequest{config: config.AuthRequestConfig{}}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/path?query=value")
|
||||
ctx.Request.Header.SetHost("example.com")
|
||||
|
||||
// 测试带多个变量的模板
|
||||
result := ar.expandVars(ctx, "http://auth?uri=$request_uri&host=$host&method=$request_method")
|
||||
if result == "" {
|
||||
t.Error("expandVars should return non-empty result")
|
||||
}
|
||||
|
||||
// 测试只有 $ 符号的模板
|
||||
result = ar.expandVars(ctx, "http://auth?param=$")
|
||||
if result != "http://auth?param=$" {
|
||||
t.Errorf("expandVars with lone $ should return original, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInitClient_HTTPS 测试 HTTPS 客户端初始化
|
||||
func TestInitClient_HTTPS(t *testing.T) {
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "https://secure-auth:8443/verify",
|
||||
Method: "GET",
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthRequest() failed: %v", err)
|
||||
}
|
||||
|
||||
if ar.client == nil {
|
||||
t.Error("client should be initialized for HTTPS URL")
|
||||
}
|
||||
if !ar.client.IsTLS {
|
||||
t.Error("client.IsTLS should be true for HTTPS URL")
|
||||
}
|
||||
}
|
||||
|
||||
@ -748,3 +748,54 @@ func TestName(t *testing.T) {
|
||||
t.Errorf("Expected name 'basic_auth', got %s", auth.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthenticate_UnknownAlgorithm 测试未知算法
|
||||
func TestAuthenticate_UnknownAlgorithm(t *testing.T) {
|
||||
auth := &BasicAuth{
|
||||
users: map[string]string{"admin": "$2b$12$hash"},
|
||||
algorithm: HashAlgorithm(99), // 未知算法
|
||||
}
|
||||
|
||||
result := auth.Authenticate("admin", "password")
|
||||
if result {
|
||||
t.Error("Authenticate() should return false for unknown algorithm")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthenticateBcrypt_Error 测试 bcrypt 验证错误路径
|
||||
func TestAuthenticateBcrypt_Error(t *testing.T) {
|
||||
// 测试无效的 bcrypt 哈希
|
||||
result := authenticateBcrypt("password", "invalid_hash")
|
||||
if result {
|
||||
t.Error("authenticateBcrypt() should return false for invalid hash")
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseArgon2idHash_InvalidParts 测试无效的 argon2id 哈希格式
|
||||
func TestParseArgon2idHash_InvalidParts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hash string
|
||||
}{
|
||||
{"too few parts", "$argon2id$v=19$m=32,t=2,p=2"},
|
||||
{"wrong algorithm", "$bcrypt$v=19$m=32,t=2,p=2$salt$hash"},
|
||||
{"wrong version", "$argon2id$v=18$m=32,t=2,p=2$salt$hash"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _, _, err := parseArgon2idHash(tt.hash)
|
||||
if err == nil {
|
||||
t.Errorf("parseArgon2idHash(%q) should return error", tt.hash)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidatePasswordHash_UnknownAlgorithm 测试未知算法的密码哈希验证
|
||||
func TestValidatePasswordHash_UnknownAlgorithm(t *testing.T) {
|
||||
err := validatePasswordHash("hash", HashAlgorithm(99))
|
||||
if err == nil {
|
||||
t.Error("validatePasswordHash() should return error for unknown algorithm")
|
||||
}
|
||||
}
|
||||
|
||||
@ -254,3 +254,137 @@ func TestFormatHSTSValue(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewHeadersWithHSTS 测试带 HSTS 配置的创建
|
||||
func TestNewHeadersWithHSTS(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
}
|
||||
|
||||
hstsCfg := &config.HSTSConfig{
|
||||
MaxAge: 86400,
|
||||
IncludeSubDomains: true,
|
||||
Preload: false,
|
||||
}
|
||||
|
||||
sh := NewHeadersWithHSTS(cfg, hstsCfg)
|
||||
if sh == nil {
|
||||
t.Error("NewHeadersWithHSTS() should not return nil")
|
||||
}
|
||||
if sh.hsts != "max-age=86400; includeSubDomains" {
|
||||
t.Errorf("HSTS = %q, want 'max-age=86400; includeSubDomains'", sh.hsts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewHeadersWithHSTS_NilConfig 测试 nil HSTS 配置
|
||||
func TestNewHeadersWithHSTS_NilConfig(t *testing.T) {
|
||||
sh := NewHeadersWithHSTS(nil, nil)
|
||||
if sh == nil {
|
||||
t.Error("NewHeadersWithHSTS() should not return nil")
|
||||
}
|
||||
// 应该使用默认配置
|
||||
if sh.config == nil {
|
||||
t.Error("Should use default config when nil is passed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewHeadersWithHSTS_ZeroMaxAge 测试 MaxAge 为 0 时使用默认值
|
||||
func TestNewHeadersWithHSTS_ZeroMaxAge(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
}
|
||||
|
||||
hstsCfg := &config.HSTSConfig{
|
||||
MaxAge: 0, // 应该使用默认值 31536000
|
||||
IncludeSubDomains: true,
|
||||
Preload: false,
|
||||
}
|
||||
|
||||
sh := NewHeadersWithHSTS(cfg, hstsCfg)
|
||||
if sh.hsts != "max-age=31536000; includeSubDomains" {
|
||||
t.Errorf("HSTS with zero maxAge should use default: %q", sh.hsts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHeadersProcess_NilConfig 测试 Process 处理 nil config
|
||||
func TestHeadersProcess_NilConfig(t *testing.T) {
|
||||
sh := NewHeadersWithHSTS(nil, nil)
|
||||
|
||||
nextCalled := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
nextCalled = true
|
||||
}
|
||||
|
||||
handler := sh.Process(nextHandler)
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
handler(ctx)
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("Next handler should be called")
|
||||
}
|
||||
|
||||
// 验证默认安全头被设置
|
||||
if string(ctx.Response.Header.Peek("X-Content-Type-Options")) != "nosniff" {
|
||||
t.Error("Default X-Content-Type-Options should be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHeadersProcess_TLS 测试 TLS 情况下 HSTS 头设置
|
||||
func TestHeadersProcess_TLS(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
}
|
||||
|
||||
hstsCfg := &config.HSTSConfig{
|
||||
MaxAge: 31536000,
|
||||
IncludeSubDomains: true,
|
||||
Preload: true,
|
||||
}
|
||||
|
||||
sh := NewHeadersWithHSTS(cfg, hstsCfg)
|
||||
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {}
|
||||
|
||||
handler := sh.Process(nextHandler)
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// 注意:在测试环境中无法真正模拟 TLS 连接
|
||||
// 这个测试验证 handler 不会 panic
|
||||
handler(ctx)
|
||||
}
|
||||
|
||||
// TestAddHeaders_AllHeaders 测试所有安全头设置
|
||||
func TestAddHeaders_AllHeaders(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
XContentTypeOptions: "nosniff",
|
||||
ContentSecurityPolicy: "default-src 'self'",
|
||||
ReferrerPolicy: "strict-origin",
|
||||
PermissionsPolicy: "camera=(), microphone=()",
|
||||
}
|
||||
|
||||
sh := NewHeaders(cfg)
|
||||
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {}
|
||||
handler := sh.Process(nextHandler)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
handler(ctx)
|
||||
|
||||
headers := &ctx.Response.Header
|
||||
|
||||
if string(headers.Peek("X-Frame-Options")) != "SAMEORIGIN" {
|
||||
t.Error("X-Frame-Options not set correctly")
|
||||
}
|
||||
if string(headers.Peek("X-Content-Type-Options")) != "nosniff" {
|
||||
t.Error("X-Content-Type-Options not set correctly")
|
||||
}
|
||||
if string(headers.Peek("Content-Security-Policy")) != "default-src 'self'" {
|
||||
t.Error("Content-Security-Policy not set correctly")
|
||||
}
|
||||
if string(headers.Peek("Referrer-Policy")) != "strict-origin" {
|
||||
t.Error("Referrer-Policy not set correctly")
|
||||
}
|
||||
if string(headers.Peek("Permissions-Policy")) != "camera=(), microphone=()" {
|
||||
t.Error("Permissions-Policy not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
@ -459,3 +459,392 @@ func TestConnLimiterMiddleware(t *testing.T) {
|
||||
t.Errorf("Expected name 'conn_limiter', got %s", middleware.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewRateLimiter_SlidingWindow 测试滑动窗口算法
|
||||
func TestNewRateLimiter_SlidingWindow(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 100,
|
||||
Algorithm: "sliding_window",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
if mw == nil {
|
||||
t.Error("Expected non-nil middleware for sliding_window")
|
||||
}
|
||||
if mw.Name() != "sliding_window_limiter" {
|
||||
t.Errorf("Expected name 'sliding_window_limiter', got %s", mw.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewRateLimiter_SlidingWindowDefault 测试滑动窗口默认配置
|
||||
func TestNewRateLimiter_SlidingWindowDefault(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 100,
|
||||
Algorithm: "sliding_window",
|
||||
SlidingWindow: 0, // 使用默认值
|
||||
SlidingWindowMode: "", // 使用默认值
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
if mw == nil {
|
||||
t.Error("Expected non-nil middleware")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewRateLimiter_SlidingWindowPrecise 测试滑动窗口精确模式
|
||||
func TestNewRateLimiter_SlidingWindowPrecise(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 100,
|
||||
Algorithm: "sliding_window",
|
||||
SlidingWindow: 1,
|
||||
SlidingWindowMode: "precise",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
if mw == nil {
|
||||
t.Error("Expected non-nil middleware")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewRateLimiter_UnknownAlgorithm 测试未知算法
|
||||
func TestNewRateLimiter_UnknownAlgorithm(t *testing.T) {
|
||||
_, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 100,
|
||||
Algorithm: "unknown_algo",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("NewRateLimiter() should return error for unknown algorithm")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSlidingWindowLimiterWrapper_Process 测试滑动窗口包装器的 Process 方法
|
||||
func TestSlidingWindowLimiterWrapper_Process(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 100,
|
||||
Algorithm: "sliding_window",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := mw.Process(nextHandler)
|
||||
if handler == nil {
|
||||
t.Fatal("Process() returned nil handler")
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||
handler(ctx)
|
||||
|
||||
if !called {
|
||||
t.Error("Next handler should be called when rate limit allows")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSlidingWindowLimiterWrapper_ProcessDenied 测试滑动窗口拒绝请求
|
||||
func TestSlidingWindowLimiterWrapper_ProcessDenied(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 1,
|
||||
Burst: 1,
|
||||
Algorithm: "sliding_window",
|
||||
SlidingWindowMode: "precise",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
callCount := 0
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
callCount++
|
||||
}
|
||||
|
||||
handler := mw.Process(nextHandler)
|
||||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||
|
||||
// 第一个请求应该被允许
|
||||
handler(ctx)
|
||||
// 第二个请求应该被拒绝
|
||||
handler(ctx)
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected next handler to be called once, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiter_GetRetryAfter 测试 getRetryAfter 方法
|
||||
func TestRateLimiter_GetRetryAfter(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 10,
|
||||
Burst: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
defer rl.StopCleanup()
|
||||
|
||||
// 测试不存在的键
|
||||
retryAfter := rl.getRetryAfter("nonexistent")
|
||||
if retryAfter != 1 {
|
||||
t.Errorf("getRetryAfter(nonexistent) = %d, want 1", retryAfter)
|
||||
}
|
||||
|
||||
// 创建一个桶并消耗令牌
|
||||
key := "test-key"
|
||||
for i := 0; i < 10; i++ {
|
||||
rl.Allow(key)
|
||||
}
|
||||
|
||||
// 获取重试时间
|
||||
retryAfter = rl.getRetryAfter(key)
|
||||
if retryAfter < 1 {
|
||||
t.Errorf("getRetryAfter() = %d, want at least 1", retryAfter)
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeyByIP 测试 keyByIP 函数
|
||||
func TestKeyByIP(t *testing.T) {
|
||||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||
|
||||
key := keyByIP(ctx)
|
||||
if key == "" {
|
||||
t.Error("keyByIP() should return non-empty string")
|
||||
}
|
||||
if key == "unknown" {
|
||||
t.Error("keyByIP() should return IP address, not 'unknown'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeyByHeader 测试 keyByHeader 函数
|
||||
func TestKeyByHeader(t *testing.T) {
|
||||
// 测试有头部的情况
|
||||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||
ctx.Request.Header.Set("X-RateLimit-Key", "custom-key")
|
||||
|
||||
key := keyByHeader(ctx)
|
||||
if key != "custom-key" {
|
||||
t.Errorf("keyByHeader() = %q, want 'custom-key'", key)
|
||||
}
|
||||
|
||||
// 测试没有头部的情况(应该回退到 IP)
|
||||
ctx2 := testutil.NewRequestCtx("GET", "/test")
|
||||
key2 := keyByHeader(ctx2)
|
||||
if key2 == "" {
|
||||
t.Error("keyByHeader() should fallback to IP when header not set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnLimiter_PerKey 测试按键限制
|
||||
func TestConnLimiter_PerKey(t *testing.T) {
|
||||
cl, err := NewConnLimiter(2, true, "ip")
|
||||
if err != nil {
|
||||
t.Fatalf("NewConnLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/")
|
||||
|
||||
// 同一个键的前两个应该成功
|
||||
if !cl.Acquire(ctx) {
|
||||
t.Error("Expected first acquire to succeed")
|
||||
}
|
||||
if !cl.Acquire(ctx) {
|
||||
t.Error("Expected second acquire to succeed")
|
||||
}
|
||||
|
||||
// 第三个应该失败
|
||||
if cl.Acquire(ctx) {
|
||||
t.Error("Expected third acquire to fail")
|
||||
}
|
||||
|
||||
// 释放一个
|
||||
cl.Release(ctx)
|
||||
|
||||
// 现在应该成功
|
||||
if !cl.Acquire(ctx) {
|
||||
t.Error("Expected acquire after release to succeed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnLimiter_ReleaseUnderflow 测试 Release 下溢保护
|
||||
func TestConnLimiter_ReleaseUnderflow(t *testing.T) {
|
||||
cl, err := NewConnLimiter(2, true, "ip")
|
||||
if err != nil {
|
||||
t.Fatalf("NewConnLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
ctx := testutil.NewRequestCtx("GET", "/")
|
||||
|
||||
// 在没有 Acquire 的情况下 Release(测试下溢保护)
|
||||
cl.Release(ctx) // 不应该 panic
|
||||
|
||||
// 验证计数不会变成负数
|
||||
cl.Acquire(ctx)
|
||||
cl.Acquire(ctx)
|
||||
if cl.Acquire(ctx) {
|
||||
t.Error("Expected third acquire to fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnLimiterMiddleware_Process 测试连接限制中间件 Process
|
||||
func TestConnLimiterMiddleware_Process(t *testing.T) {
|
||||
cl, err := NewConnLimiter(1, false, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewConnLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
mw := cl.Middleware()
|
||||
|
||||
called := false
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
ctx.SetStatusCode(200)
|
||||
}
|
||||
|
||||
handler := mw.Process(nextHandler)
|
||||
ctx := testutil.NewRequestCtx("GET", "/test")
|
||||
|
||||
// 第一个请求应该成功
|
||||
handler(ctx)
|
||||
if !called {
|
||||
t.Error("Next handler should be called")
|
||||
}
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("Status = %d, want 200", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnLimiterMiddleware_ProcessLimitExceeded 测试连接限制超出
|
||||
func TestConnLimiterMiddleware_ProcessLimitExceeded(t *testing.T) {
|
||||
cl, err := NewConnLimiter(1, false, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewConnLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
mw := cl.Middleware()
|
||||
|
||||
callCount := 0
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
callCount++
|
||||
ctx.SetStatusCode(200)
|
||||
}
|
||||
|
||||
handler := mw.Process(nextHandler)
|
||||
|
||||
// 用尽连接限制
|
||||
ctx1 := testutil.NewRequestCtx("GET", "/test1")
|
||||
cl.Acquire(ctx1) // 手动占用一个槽位
|
||||
|
||||
// 现在应该无法获取新的连接
|
||||
ctx2 := testutil.NewRequestCtx("GET", "/test2")
|
||||
handler(ctx2)
|
||||
|
||||
if callCount != 0 {
|
||||
t.Error("Next handler should NOT be called when limit exceeded")
|
||||
}
|
||||
if ctx2.Response.StatusCode() != 503 {
|
||||
t.Errorf("Status = %d, want 503", ctx2.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSlidingWindowLimiterWrapper_Error 测试滑动窗口包装器错误情况
|
||||
func TestNewSlidingWindowLimiterWrapper_Error(t *testing.T) {
|
||||
_, err := NewSlidingWindowLimiterWrapper(&config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Key: "invalid_key_type",
|
||||
}, time.Second, false)
|
||||
if err == nil {
|
||||
t.Error("NewSlidingWindowLimiterWrapper should return error for invalid key type")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiter_Name 测试 RateLimiter Name 方法
|
||||
func TestRateLimiter_Name(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 10,
|
||||
Burst: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
defer rl.StopCleanup()
|
||||
|
||||
if rl.Name() != "rate_limiter" {
|
||||
t.Errorf("Name() = %q, want 'rate_limiter'", rl.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimiter_ProcessDenied 测试限流拒绝
|
||||
func TestRateLimiter_ProcessDenied(t *testing.T) {
|
||||
mw, err := NewRateLimiter(&config.RateLimitConfig{
|
||||
RequestRate: 1,
|
||||
Burst: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewRateLimiter() error: %v", err)
|
||||
}
|
||||
|
||||
rl, ok := mw.(*RateLimiter)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *RateLimiter, got %T", mw)
|
||||
}
|
||||
defer rl.StopCleanup()
|
||||
|
||||
callCount := 0
|
||||
nextHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
callCount++
|
||||
}
|
||||
|
||||
handler := rl.Process(nextHandler)
|
||||
|
||||
// 第一个请求应该成功
|
||||
ctx1 := testutil.NewRequestCtx("GET", "/test")
|
||||
handler(ctx1)
|
||||
|
||||
// 第二个请求应该被限流(使用不同的 context)
|
||||
ctx2 := testutil.NewRequestCtx("GET", "/test")
|
||||
handler(ctx2)
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected next handler to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
// 检查第二个请求的状态码
|
||||
if ctx2.Response.StatusCode() != 429 {
|
||||
t.Errorf("Status = %d, want 429", ctx2.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeyByIP_Unknown 测试无法获取 IP 的情况
|
||||
func TestKeyByIP_Unknown(t *testing.T) {
|
||||
// 创建一个没有设置远程地址的上下文
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
|
||||
key := keyByIP(ctx)
|
||||
// netutil.ExtractClientIPNet 会返回默认值 0.0.0.0 而不是 nil
|
||||
// 所以这里验证返回值不是空的即可
|
||||
if key == "" {
|
||||
t.Error("keyByIP() should return non-empty string")
|
||||
}
|
||||
}
|
||||
|
||||
1636
internal/proxy/proxy_coverage_extra_test.go
Normal file
1636
internal/proxy/proxy_coverage_extra_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -236,3 +236,220 @@ func TestPoolWrapHandler_WhenStopped(t *testing.T) {
|
||||
t.Error("Expected handler to be executed directly when pool is stopped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSubmit_QueueFull_StartNewWorker(t *testing.T) {
|
||||
// 测试队列满时启动新 worker
|
||||
p := NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: 10,
|
||||
MinWorkers: 1,
|
||||
QueueSize: 1, // 小队列,容易满
|
||||
IdleTimeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
p.Start()
|
||||
defer p.Stop()
|
||||
|
||||
// 填满队列,让后续提交触发 default 分支
|
||||
var executedCount atomic.Int32
|
||||
blockTask := func(*fasthttp.RequestCtx) {
|
||||
time.Sleep(200 * time.Millisecond) // 阻塞任务
|
||||
executedCount.Add(1)
|
||||
}
|
||||
|
||||
// 先提交一个阻塞任务,让 worker 忙碌
|
||||
_ = p.Submit(nil, blockTask)
|
||||
time.Sleep(50 * time.Millisecond) // 等待 worker 开始执行
|
||||
|
||||
// 填满队列
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) { executedCount.Add(1) })
|
||||
|
||||
// 此时队列满,应该启动新 worker
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) { executedCount.Add(1) })
|
||||
|
||||
// 等待所有任务完成
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if executedCount.Load() < 2 {
|
||||
t.Errorf("Expected at least 2 executions, got %d", executedCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSubmit_QueueFull_MaxWorkers_Fallback(t *testing.T) {
|
||||
// 测试队列满且达到最大 worker 时直接执行(fallback)
|
||||
p := NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: 1, // 只有 1 个 worker
|
||||
MinWorkers: 1,
|
||||
QueueSize: 1, // 队列大小 1
|
||||
IdleTimeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
p.Start()
|
||||
defer p.Stop()
|
||||
|
||||
// 使用 channel 阻塞唯一的 worker
|
||||
blockCh := make(chan struct{})
|
||||
started := make(chan struct{})
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) {
|
||||
close(started) // 通知 worker 已开始
|
||||
<-blockCh // 阻塞直到测试结束
|
||||
})
|
||||
|
||||
// 等待 worker 开始执行阻塞任务
|
||||
<-started
|
||||
|
||||
// 填满队列
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) {})
|
||||
|
||||
// 现在唯一的 worker 在阻塞,队列已满
|
||||
// 提交新任务应该触发 fallback 直接执行
|
||||
var fallbackExecuted atomic.Bool
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) {
|
||||
fallbackExecuted.Store(true)
|
||||
})
|
||||
|
||||
// fallback 执行是同步的直接执行
|
||||
if !fallbackExecuted.Load() {
|
||||
t.Error("Expected task to be executed directly (fallback) when max workers reached")
|
||||
}
|
||||
|
||||
// 释放阻塞的 worker
|
||||
close(blockCh)
|
||||
}
|
||||
|
||||
func TestPoolSubmit_WithIdleWorkers(t *testing.T) {
|
||||
// 测试有空闲 worker 时不启动新 worker
|
||||
p := NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: 10,
|
||||
MinWorkers: 5, // 预热 5 个 worker
|
||||
QueueSize: 10,
|
||||
IdleTimeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
p.Start()
|
||||
defer p.Stop()
|
||||
|
||||
// 等待预热完成,worker 应该是空闲的
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
initialWorkers := atomic.LoadInt32(&p.workers)
|
||||
|
||||
// 提交任务,应该复用空闲 worker
|
||||
var executed atomic.Bool
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) {
|
||||
executed.Store(true)
|
||||
})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected task to be executed")
|
||||
}
|
||||
|
||||
// worker 数量应该保持稳定或更少(不应该增加)
|
||||
finalWorkers := atomic.LoadInt32(&p.workers)
|
||||
if finalWorkers > initialWorkers {
|
||||
t.Errorf("Worker count should not increase when idle workers available: %d -> %d", initialWorkers, finalWorkers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSubmit_NilTask(t *testing.T) {
|
||||
// 测试提交 nil 任务不会 panic
|
||||
p := NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: 10,
|
||||
})
|
||||
|
||||
p.Start()
|
||||
defer p.Stop()
|
||||
|
||||
// 提交 nil 任务 - 这会导致 panic,所以不测试
|
||||
// 但可以测试空任务函数
|
||||
var executed atomic.Bool
|
||||
emptyTask := func(*fasthttp.RequestCtx) {
|
||||
executed.Store(true)
|
||||
}
|
||||
|
||||
err := p.Submit(nil, emptyTask)
|
||||
if err != nil {
|
||||
t.Errorf("Submit failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected empty task to be executed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSubmit_MultipleQueuedTasks(t *testing.T) {
|
||||
// 测试多个任务入队
|
||||
p := NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: 5,
|
||||
MinWorkers: 2,
|
||||
QueueSize: 10,
|
||||
IdleTimeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
p.Start()
|
||||
defer p.Stop()
|
||||
|
||||
var counter atomic.Int32
|
||||
|
||||
// 快速提交多个任务
|
||||
for i := 0; i < 5; i++ {
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) {
|
||||
counter.Add(1)
|
||||
})
|
||||
}
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
if counter.Load() != 5 {
|
||||
t.Errorf("Expected 5 executions, got %d", counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSubmit_StartWorkerWhenNoIdle(t *testing.T) {
|
||||
// 测试当没有空闲 worker 时启动新 worker
|
||||
// 使用 MinWorkers=1 让池只预热 1 个 worker
|
||||
p := NewGoroutinePool(PoolConfig{
|
||||
MaxWorkers: 5,
|
||||
MinWorkers: 1, // 只预热 1 个 worker
|
||||
QueueSize: 10,
|
||||
IdleTimeout: 5 * time.Second,
|
||||
})
|
||||
|
||||
p.Start()
|
||||
defer p.Stop()
|
||||
|
||||
// 等待预热完成
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 用阻塞任务让唯一的 worker 忙碌
|
||||
blockCh := make(chan struct{})
|
||||
started := make(chan struct{})
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) {
|
||||
close(started)
|
||||
<-blockCh
|
||||
})
|
||||
<-started // 等待 worker 开始执行
|
||||
|
||||
// 现在唯一的 worker 在忙碌,idleWorkers == 0
|
||||
// 提交新任务应该启动新 worker
|
||||
var executed atomic.Bool
|
||||
_ = p.Submit(nil, func(*fasthttp.RequestCtx) {
|
||||
executed.Store(true)
|
||||
})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected task to be executed by new worker")
|
||||
}
|
||||
|
||||
// 检查是否启动了新 worker
|
||||
if atomic.LoadInt32(&p.workers) < 2 {
|
||||
t.Errorf("Expected at least 2 workers, got %d", atomic.LoadInt32(&p.workers))
|
||||
}
|
||||
|
||||
close(blockCh)
|
||||
}
|
||||
|
||||
262
internal/server/pprof_impl_test.go
Normal file
262
internal/server/pprof_impl_test.go
Normal file
@ -0,0 +1,262 @@
|
||||
// Package server 提供 pprof 实现的测试。
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestStartCPUProfile 测试 startCPUProfile 函数
|
||||
func TestStartCPUProfile(t *testing.T) {
|
||||
t.Run("start and stop CPU profile", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
err := startCPUProfile(&buf)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error starting CPU profile: %v", err)
|
||||
}
|
||||
|
||||
// 停止 CPU profile
|
||||
stopCPUProfile()
|
||||
})
|
||||
|
||||
t.Run("start twice should not error", func(t *testing.T) {
|
||||
var buf1, buf2 bytes.Buffer
|
||||
|
||||
err1 := startCPUProfile(&buf1)
|
||||
if err1 != nil {
|
||||
t.Errorf("unexpected error on first start: %v", err1)
|
||||
}
|
||||
|
||||
// 第二次启动应该被忽略(已在采集)
|
||||
err2 := startCPUProfile(&buf2)
|
||||
if err2 != nil {
|
||||
t.Errorf("unexpected error on second start: %v", err2)
|
||||
}
|
||||
|
||||
stopCPUProfile()
|
||||
})
|
||||
|
||||
t.Run("stop without start should not panic", func(t *testing.T) {
|
||||
// 确保停止状态
|
||||
stopCPUProfile()
|
||||
// 再次停止应该安全
|
||||
stopCPUProfile()
|
||||
})
|
||||
}
|
||||
|
||||
// TestStopCPUProfile 测试 stopCPUProfile 函数
|
||||
func TestStopCPUProfile(t *testing.T) {
|
||||
t.Run("stop when not active", func(t *testing.T) {
|
||||
// 确保停止状态
|
||||
stopCPUProfile()
|
||||
// 再次停止应该安全
|
||||
stopCPUProfile()
|
||||
})
|
||||
|
||||
t.Run("stop when active", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := startCPUProfile(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start CPU profile: %v", err)
|
||||
}
|
||||
|
||||
stopCPUProfile()
|
||||
|
||||
// 验证可以再次启动
|
||||
err = startCPUProfile(&buf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to restart CPU profile after stop: %v", err)
|
||||
}
|
||||
stopCPUProfile()
|
||||
})
|
||||
}
|
||||
|
||||
// TestWriteHeapProfile 测试 writeHeapProfile 函数
|
||||
func TestWriteHeapProfile(t *testing.T) {
|
||||
t.Run("write heap profile", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// 写入 heap profile
|
||||
writeHeapProfile(&buf)
|
||||
|
||||
// 验证有输出
|
||||
if buf.Len() == 0 {
|
||||
t.Error("expected heap profile output, got empty buffer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("write heap profile multiple times", func(t *testing.T) {
|
||||
var buf1, buf2 bytes.Buffer
|
||||
|
||||
writeHeapProfile(&buf1)
|
||||
writeHeapProfile(&buf2)
|
||||
|
||||
// 两次都应该有输出
|
||||
if buf1.Len() == 0 {
|
||||
t.Error("expected heap profile output on first call")
|
||||
}
|
||||
if buf2.Len() == 0 {
|
||||
t.Error("expected heap profile output on second call")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestWriteGoroutineProfile 测试 writeGoroutineProfile 函数
|
||||
func TestWriteGoroutineProfile(t *testing.T) {
|
||||
t.Run("write goroutine profile", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
writeGoroutineProfile(&buf)
|
||||
|
||||
// 验证有输出
|
||||
if buf.Len() == 0 {
|
||||
t.Error("expected goroutine profile output, got empty buffer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("write goroutine profile with spawned goroutines", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// 启动一些 goroutine
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {}
|
||||
}()
|
||||
}
|
||||
|
||||
writeGoroutineProfile(&buf)
|
||||
|
||||
// 应该有输出
|
||||
if buf.Len() == 0 {
|
||||
t.Error("expected goroutine profile output")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestWriteBlockProfile 测试 writeBlockProfile 函数
|
||||
func TestWriteBlockProfile(t *testing.T) {
|
||||
t.Run("write block profile", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
writeBlockProfile(&buf)
|
||||
|
||||
// block profile 可能为空(如果没有阻塞操作)
|
||||
// 所以我们只验证函数不会 panic
|
||||
})
|
||||
}
|
||||
|
||||
// TestWriteMutexProfile 测试 writeMutexProfile 函数
|
||||
func TestWriteMutexProfile(t *testing.T) {
|
||||
t.Run("write mutex profile", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
writeMutexProfile(&buf)
|
||||
|
||||
// mutex profile 可能为空(如果没有锁竞争)
|
||||
// 所以我们只验证函数不会 panic
|
||||
})
|
||||
}
|
||||
|
||||
// TestBufioWriterAdapter 测试 bufioWriterAdapter 结构体
|
||||
func TestBufioWriterAdapter(t *testing.T) {
|
||||
t.Run("write and flush", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
bw := bufio.NewWriter(&buf)
|
||||
writer := wrapBufioWriter(bw)
|
||||
|
||||
data := []byte("test data")
|
||||
n, err := writer.Write(data)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("expected %d bytes written, got %d", len(data), n)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("write multiple times", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
bw := bufio.NewWriter(&buf)
|
||||
writer := wrapBufioWriter(bw)
|
||||
|
||||
data1 := []byte("first")
|
||||
data2 := []byte("second")
|
||||
|
||||
n1, err1 := writer.Write(data1)
|
||||
if err1 != nil {
|
||||
t.Errorf("unexpected error on first write: %v", err1)
|
||||
}
|
||||
if n1 != len(data1) {
|
||||
t.Errorf("expected %d bytes on first write, got %d", len(data1), n1)
|
||||
}
|
||||
|
||||
n2, err2 := writer.Write(data2)
|
||||
if err2 != nil {
|
||||
t.Errorf("unexpected error on second write: %v", err2)
|
||||
}
|
||||
if n2 != len(data2) {
|
||||
t.Errorf("expected %d bytes on second write, got %d", len(data2), n2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestWrapBufioWriter 测试 wrapBufioWriter 函数
|
||||
func TestWrapBufioWriter(t *testing.T) {
|
||||
t.Run("wrap returns non-nil", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
bw := bufio.NewWriter(&buf)
|
||||
writer := wrapBufioWriter(bw)
|
||||
|
||||
if writer == nil {
|
||||
t.Error("expected non-nil writer")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrapped writer implements io.Writer", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
bw := bufio.NewWriter(&buf)
|
||||
writer := wrapBufioWriter(bw)
|
||||
|
||||
// 测试写入
|
||||
data := []byte("hello world")
|
||||
n, err := writer.Write(data)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if n != len(data) {
|
||||
t.Errorf("expected %d bytes, got %d", len(data), n)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCPUProfileMutex 测试 CPU profile 的并发安全性
|
||||
func TestCPUProfileMutex(t *testing.T) {
|
||||
t.Run("concurrent start/stop", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
var buf bytes.Buffer
|
||||
|
||||
// 启动多个 goroutine 同时操作
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = startCPUProfile(&buf)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
stopCPUProfile()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 确保最终状态一致
|
||||
stopCPUProfile()
|
||||
})
|
||||
}
|
||||
@ -633,6 +633,138 @@ func TestPprofHandler_handleCPU_Params(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPprofHandler_handleCPU 测试 handleCPU 函数的参数解析。
|
||||
// 使用 1 秒的最小采集时间来确保测试快速完成。
|
||||
func TestPprofHandler_handleCPU(t *testing.T) {
|
||||
t.Run("default seconds (30s) - verify headers only", func(t *testing.T) {
|
||||
// 注意:默认 30 秒太长,这里只验证请求解析逻辑
|
||||
// 实际的 profile 采集在 TestPprofHandler_handleCPU_WithShortDuration 中测试
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile")
|
||||
|
||||
// 验证请求路径解析正确
|
||||
path := string(ctx.Path())
|
||||
if path != "/debug/pprof/profile" {
|
||||
t.Errorf("unexpected path: %s", path)
|
||||
}
|
||||
|
||||
// 验证没有 seconds 参数时 QueryArgs 为空
|
||||
secArg := ctx.QueryArgs().Peek("seconds")
|
||||
if secArg != nil {
|
||||
t.Errorf("expected no seconds arg, got: %s", secArg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom seconds parameter", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=5")
|
||||
|
||||
// 验证 seconds 参数解析正确
|
||||
secArg := ctx.QueryArgs().Peek("seconds")
|
||||
if string(secArg) != "5" {
|
||||
t.Errorf("expected seconds=5, got: %s", secArg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid seconds parameter", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=invalid")
|
||||
|
||||
// 验证参数存在
|
||||
secArg := ctx.QueryArgs().Peek("seconds")
|
||||
if string(secArg) != "invalid" {
|
||||
t.Errorf("expected seconds=invalid, got: %s", secArg)
|
||||
}
|
||||
|
||||
// strconv.Atoi("invalid") 会返回错误,函数会使用默认值 30
|
||||
})
|
||||
}
|
||||
|
||||
// TestPprofHandler_handleCPU_Execute 执行 handleCPU 并验证响应。
|
||||
// 使用 1 秒采集时间来快速完成测试。
|
||||
func TestPprofHandler_handleCPU_Execute(t *testing.T) {
|
||||
// 确保之前的 CPU profile 已停止
|
||||
stopCPUProfile()
|
||||
|
||||
h := &PprofHandler{
|
||||
path: "/debug/pprof",
|
||||
allowedIPs: []net.IP{},
|
||||
allowedNets: []*net.IPNet{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=1")
|
||||
|
||||
// 执行 handleCPU
|
||||
h.handleCPU(ctx)
|
||||
|
||||
// 验证 Content-Type
|
||||
contentType := string(ctx.Response.Header.Peek("Content-Type"))
|
||||
if contentType != "application/octet-stream" {
|
||||
t.Errorf("expected Content-Type application/octet-stream, got: %s", contentType)
|
||||
}
|
||||
|
||||
// 验证响应体(CPU profile 数据)
|
||||
body := ctx.Response.Body()
|
||||
if len(body) == 0 {
|
||||
t.Error("expected CPU profile output, got empty body")
|
||||
}
|
||||
|
||||
// 验证响应体包含 pprof header 标识
|
||||
// pprof 文件以特定的 magic number 开头
|
||||
if len(body) > 0 {
|
||||
// pprof 格式的文件应该有内容
|
||||
t.Logf("CPU profile size: %d bytes", len(body))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPprofHandler_handleCPU_NegativeSeconds 测试负数秒数参数解析。
|
||||
// 负数秒会被 sec > 0 检查过滤,使用默认值 30 秒。
|
||||
func TestPprofHandler_handleCPU_NegativeSeconds(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=-1")
|
||||
|
||||
// 验证参数解析
|
||||
secArg := ctx.QueryArgs().Peek("seconds")
|
||||
if string(secArg) != "-1" {
|
||||
t.Errorf("expected seconds=-1, got: %s", secArg)
|
||||
}
|
||||
|
||||
// 注意:负数秒在 handleCPU 中会被 sec > 0 检查过滤,使用默认值 30 秒
|
||||
// 为了测试效率,这里只验证参数解析,不实际执行 handleCPU
|
||||
}
|
||||
|
||||
// TestPprofHandler_handleCPU_ZeroSeconds 测试零秒参数解析。
|
||||
// 零秒会被 sec > 0 检查过滤,使用默认值 30 秒。
|
||||
func TestPprofHandler_handleCPU_ZeroSeconds(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=0")
|
||||
|
||||
// 验证参数解析
|
||||
secArg := ctx.QueryArgs().Peek("seconds")
|
||||
if string(secArg) != "0" {
|
||||
t.Errorf("expected seconds=0, got: %s", secArg)
|
||||
}
|
||||
|
||||
// 注意:零秒在 handleCPU 中会被 sec > 0 检查过滤,使用默认值 30 秒
|
||||
// 为了测试效率,这里只验证参数解析,不实际执行 handleCPU
|
||||
}
|
||||
|
||||
// TestPprofHandler_handleCPU_LargeSeconds 测试大数值秒数参数解析。
|
||||
func TestPprofHandler_handleCPU_LargeSeconds(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=999999")
|
||||
|
||||
// 验证参数解析
|
||||
secArg := ctx.QueryArgs().Peek("seconds")
|
||||
if string(secArg) != "999999" {
|
||||
t.Errorf("expected seconds=999999, got: %s", secArg)
|
||||
}
|
||||
|
||||
// 注意:这里只验证参数解析不会溢出
|
||||
// 为了测试效率,不实际执行 handleCPU(会等待 999999 秒)
|
||||
}
|
||||
|
||||
func TestPprofHandler_ConfigWithCIDRAndIP(t *testing.T) {
|
||||
// 测试混合配置
|
||||
cfg := &config.PprofConfig{
|
||||
|
||||
@ -17,10 +17,13 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/cache"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/loadbalance"
|
||||
"rua.plus/lolly/internal/proxy"
|
||||
)
|
||||
|
||||
func TestPurgeHandler_Path(t *testing.T) {
|
||||
@ -781,3 +784,607 @@ func TestPurgeHandler_checkAccess_WithAllowedIP(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mockProxyWithCache 是一个用于测试的 mock Proxy,可以返回指定的缓存。
|
||||
type mockProxyWithCache struct {
|
||||
cache *cache.ProxyCache
|
||||
}
|
||||
|
||||
func (m *mockProxyWithCache) GetCache() *cache.ProxyCache {
|
||||
return m.cache
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPath_WithRealCache 测试 purgeByPath 在有真实缓存时的行为。
|
||||
func TestPurgeHandler_PurgeByPath_WithRealCache(t *testing.T) {
|
||||
// 创建启用缓存的代理
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 获取缓存并添加测试数据
|
||||
pcache := p.GetCache()
|
||||
if pcache == nil {
|
||||
t.Fatal("GetCache() should return non-nil when cache enabled")
|
||||
}
|
||||
|
||||
// 添加测试缓存条目
|
||||
hashKey1 := cache.HashPathWithMethod("/api/users", "GET")
|
||||
pcache.Set(hashKey1, "GET:/api/users", []byte("test data 1"), nil, 200, time.Minute)
|
||||
|
||||
hashKey2 := cache.HashPathWithMethod("/api/posts", "GET")
|
||||
pcache.Set(hashKey2, "GET:/api/posts", []byte("test data 2"), nil, 200, time.Minute)
|
||||
|
||||
hashKey3 := cache.HashPathWithMethod("/api/users", "POST")
|
||||
pcache.Set(hashKey3, "POST:/api/users", []byte("test data 3"), nil, 200, time.Minute)
|
||||
|
||||
// 创建带有代理的 handler
|
||||
h, err := NewPurgeHandler(&Server{
|
||||
proxies: []*proxy.Proxy{p},
|
||||
}, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
t.Run("delete existing entry", func(t *testing.T) {
|
||||
deleted := h.PurgeByPathForTest("/api/users", "GET")
|
||||
if deleted != 1 {
|
||||
t.Errorf("expected 1 deletion, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete different method", func(t *testing.T) {
|
||||
deleted := h.PurgeByPathForTest("/api/users", "POST")
|
||||
if deleted != 1 {
|
||||
t.Errorf("expected 1 deletion, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete non-existing path", func(t *testing.T) {
|
||||
deleted := h.PurgeByPathForTest("/api/nonexistent", "GET")
|
||||
if deleted != 1 {
|
||||
t.Errorf("expected 1 (proxy count), got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple proxies", func(t *testing.T) {
|
||||
// 创建第二个代理
|
||||
p2, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
pcache2 := p2.GetCache()
|
||||
hashKey := cache.HashPathWithMethod("/test", "GET")
|
||||
pcache2.Set(hashKey, "GET:/test", []byte("test"), nil, 200, time.Minute)
|
||||
|
||||
h2, err := NewPurgeHandler(&Server{
|
||||
proxies: []*proxy.Proxy{p, p2},
|
||||
}, &config.CacheAPIConfig{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
deleted := h2.PurgeByPathForTest("/test", "GET")
|
||||
if deleted != 2 {
|
||||
t.Errorf("expected 2 deletions (2 proxies), got %d", deleted)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPattern_WithRealCache 测试 purgeByPattern 在有真实缓存时的行为。
|
||||
func TestPurgeHandler_PurgeByPattern_WithRealCache(t *testing.T) {
|
||||
// 创建启用缓存的代理
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
Cache: config.ProxyCacheConfig{
|
||||
Enabled: true,
|
||||
MaxAge: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 获取缓存并添加测试数据
|
||||
pcache := p.GetCache()
|
||||
if pcache == nil {
|
||||
t.Fatal("GetCache() should return non-nil when cache enabled")
|
||||
}
|
||||
|
||||
// 添加多个测试缓存条目
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users/1", "GET"), "GET:/api/users/1", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/posts", "GET"), "GET:/api/posts", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/posts/1", "GET"), "GET:/api/posts/1", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users", "POST"), "POST:/api/users", []byte("data"), nil, 200, time.Minute)
|
||||
|
||||
// 创建带有代理的 handler
|
||||
h, err := NewPurgeHandler(&Server{
|
||||
proxies: []*proxy.Proxy{p},
|
||||
}, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
t.Run("wildcard pattern matches multiple", func(t *testing.T) {
|
||||
// 重新添加数据
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users/1", "GET"), "GET:/api/users/1", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/posts", "GET"), "GET:/api/posts", []byte("data"), nil, 200, time.Minute)
|
||||
|
||||
// 注意:OrigKey 格式为 "METHOD:/path",所以模式需要匹配完整路径
|
||||
deleted := h.PurgeByPatternForTest("GET:/api/*", "GET")
|
||||
if deleted < 1 {
|
||||
t.Errorf("expected at least 1 deletion, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty method matches all methods", func(t *testing.T) {
|
||||
// 重新添加数据
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users", "POST"), "POST:/api/users", []byte("data"), nil, 200, time.Minute)
|
||||
|
||||
// 使用 * 通配符匹配所有方法
|
||||
deleted := h.PurgeByPatternForTest("*:/api/users", "")
|
||||
if deleted < 1 {
|
||||
t.Errorf("expected at least 1 deletion (all methods), got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("specific method only", func(t *testing.T) {
|
||||
// 重新添加数据
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute)
|
||||
pcache.Set(cache.HashPathWithMethod("/api/users", "POST"), "POST:/api/users", []byte("data"), nil, 200, time.Minute)
|
||||
|
||||
// 模式匹配 POST 方法的路径
|
||||
deleted := h.PurgeByPatternForTest("POST:/api/users", "POST")
|
||||
if deleted < 1 {
|
||||
t.Errorf("expected at least 1 deletion (POST only), got %d", deleted)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPath_WithProxyNoCache 测试代理没有缓存时的情况。
|
||||
func TestPurgeHandler_PurgeByPath_WithProxyNoCache(t *testing.T) {
|
||||
// 创建禁用缓存的代理
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
// Cache 未启用
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 确认缓存为 nil
|
||||
if p.GetCache() != nil {
|
||||
t.Fatal("GetCache() should return nil when cache disabled")
|
||||
}
|
||||
|
||||
// 创建带有代理的 handler
|
||||
h, err := NewPurgeHandler(&Server{
|
||||
proxies: []*proxy.Proxy{p},
|
||||
}, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 没有缓存的代理应该返回 0
|
||||
deleted := h.PurgeByPathForTest("/api/users", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for proxy without cache, got %d", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPattern_WithProxyNoCache 测试代理没有缓存时的情况。
|
||||
func TestPurgeHandler_PurgeByPattern_WithProxyNoCache(t *testing.T) {
|
||||
// 创建禁用缓存的代理
|
||||
cfg := &config.ProxyConfig{
|
||||
Path: "/api",
|
||||
LoadBalance: "round_robin",
|
||||
Timeout: config.ProxyTimeout{Connect: 5 * time.Second},
|
||||
}
|
||||
targets := []*loadbalance.Target{{URL: "http://localhost:8080"}}
|
||||
p, err := proxy.NewProxy(cfg, targets, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxy() error: %v", err)
|
||||
}
|
||||
|
||||
// 创建带有代理的 handler
|
||||
h, err := NewPurgeHandler(&Server{
|
||||
proxies: []*proxy.Proxy{p},
|
||||
}, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 没有缓存的代理应该返回 0
|
||||
deleted := h.PurgeByPatternForTest("/api/*", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for proxy without cache, got %d", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPath_WithCache 测试 purgeByPath 在有缓存时的行为。
|
||||
func TestPurgeHandler_PurgeByPath_WithCache(t *testing.T) {
|
||||
t.Run("server with empty proxies", func(t *testing.T) {
|
||||
// 创建带有空 proxies 列表的 handler
|
||||
h, err := NewPurgeHandler(&Server{
|
||||
proxies: []*proxy.Proxy{},
|
||||
}, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 空 proxies 列表应该返回 0
|
||||
deleted := h.PurgeByPathForTest("/api/users", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for empty proxies, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty path", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 空路径仍然会执行删除逻辑,只是哈希值为默认 GET 的哈希
|
||||
deleted := h.PurgeByPathForTest("", "")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("path with special characters", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 特殊字符路径
|
||||
deleted := h.PurgeByPathForTest("/api/users?id=1&name=test", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("path with unicode", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Unicode 路径
|
||||
deleted := h.PurgeByPathForTest("/api/用户/列表", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different methods", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"}
|
||||
for _, method := range methods {
|
||||
deleted := h.PurgeByPathForTest("/api/users", method)
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server with method %s, got %d", method, deleted)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPattern_WithCache 测试 purgeByPattern 在有缓存时的行为。
|
||||
func TestPurgeHandler_PurgeByPattern_WithCache(t *testing.T) {
|
||||
t.Run("empty pattern", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 空模式
|
||||
deleted := h.PurgeByPatternForTest("", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wildcard pattern", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 通配符模式
|
||||
deleted := h.PurgeByPatternForTest("/api/*", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("double wildcard pattern", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 双通配符模式
|
||||
deleted := h.PurgeByPatternForTest("/api/**", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pattern with special characters", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 特殊字符模式
|
||||
deleted := h.PurgeByPatternForTest("/api/users?id=*", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exact pattern (no wildcard)", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 精确模式(无通配符)
|
||||
deleted := h.PurgeByPatternForTest("/api/users", "GET")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("different methods", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH"}
|
||||
for _, method := range methods {
|
||||
deleted := h.PurgeByPatternForTest("/api/*", method)
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server with method %s, got %d", method, deleted)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty method matches all", func(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 空方法应该匹配所有条目
|
||||
deleted := h.PurgeByPatternForTest("/api/*", "")
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPath_HashConsistency 测试哈希一致性。
|
||||
func TestPurgeHandler_PurgeByPath_HashConsistency(t *testing.T) {
|
||||
// 验证相同路径和方法产生相同哈希
|
||||
path := "/api/users"
|
||||
method := "GET"
|
||||
|
||||
hash1 := cache.HashPathWithMethod(path, method)
|
||||
hash2 := cache.HashPathWithMethod(path, method)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("hash not consistent: %d != %d", hash1, hash2)
|
||||
}
|
||||
|
||||
// 验证不同路径产生不同哈希
|
||||
hash3 := cache.HashPathWithMethod("/api/posts", method)
|
||||
if hash1 == hash3 {
|
||||
t.Error("expected different hashes for different paths")
|
||||
}
|
||||
|
||||
// 验证不同方法产生不同哈希
|
||||
hash4 := cache.HashPathWithMethod(path, "POST")
|
||||
if hash1 == hash4 {
|
||||
t.Error("expected different hashes for different methods")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPattern_PatternMatching 测试模式匹配逻辑。
|
||||
func TestPurgeHandler_PurgeByPattern_PatternMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
pattern string
|
||||
path string
|
||||
want bool
|
||||
}{
|
||||
// 通配符结尾 - 前缀匹配
|
||||
{"/api/*", "/api/users", true},
|
||||
{"/api/*", "/api/posts", true},
|
||||
{"/api/*", "/api/users/123", true}, // * 匹配剩余所有内容
|
||||
{"/api/*", "/other/path", false},
|
||||
|
||||
// 单个 * 匹配所有
|
||||
{"*", "/api/users", true},
|
||||
{"*", "/any/path", true},
|
||||
|
||||
// 中间通配符
|
||||
{"/api/*/users", "/api/v1/users", true},
|
||||
{"/api/*/users", "/api/v2/users", true},
|
||||
{"/api/*/users", "/api/users", true}, // 前缀和后缀都匹配
|
||||
{"/api/*/users", "/api/v1/posts", false},
|
||||
|
||||
// 精确匹配
|
||||
{"/api/users", "/api/users", true},
|
||||
{"/api/users", "/api/posts", false},
|
||||
|
||||
// 空模式
|
||||
{"", "", true},
|
||||
{"", "/api", false},
|
||||
|
||||
// 目录前缀匹配(以 / 结尾)
|
||||
{"/api/", "/api/users", true},
|
||||
{"/api/", "/api/users/123", true},
|
||||
{"/api/", "/other/path", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.pattern+"_"+tt.path, func(t *testing.T) {
|
||||
got := cache.MatchPattern(tt.pattern, tt.path)
|
||||
if got != tt.want {
|
||||
t.Errorf("MatchPattern(%q, %q) = %v, want %v", tt.pattern, tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPath_VariousInputs 测试各种输入。
|
||||
func TestPurgeHandler_PurgeByPath_VariousInputs(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"empty path and method", "", ""},
|
||||
{"empty path with method", "", "GET"},
|
||||
{"path with empty method", "/test", ""},
|
||||
{"root path", "/", "GET"},
|
||||
{"nested path", "/a/b/c/d/e", "GET"},
|
||||
{"path with trailing slash", "/api/users/", "GET"},
|
||||
{"path with query", "/api?key=value", "GET"},
|
||||
{"path with fragment", "/api#section", "GET"},
|
||||
{"path with encoded chars", "/api%2Fusers", "GET"},
|
||||
{"long path", strings.Repeat("/a", 100), "GET"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 应该不会 panic
|
||||
deleted := h.PurgeByPathForTest(tt.path, tt.method)
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPurgeHandler_PurgeByPattern_VariousInputs 测试各种模式输入。
|
||||
func TestPurgeHandler_PurgeByPattern_VariousInputs(t *testing.T) {
|
||||
h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{
|
||||
Path: "/_cache/purge",
|
||||
Allow: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
method string
|
||||
}{
|
||||
{"empty pattern and method", "", ""},
|
||||
{"empty pattern with method", "", "GET"},
|
||||
{"pattern with empty method", "/api/*", ""},
|
||||
{"single wildcard only", "*", "GET"},
|
||||
{"double wildcard only", "**", "GET"},
|
||||
{"multiple single wildcards", "/api/*/users/*", "GET"},
|
||||
{"mixed wildcards", "/api/**/users/*", "GET"},
|
||||
{"wildcard at start", "*/users", "GET"},
|
||||
{"wildcard at end", "/api/*", "GET"},
|
||||
{"consecutive wildcards", "/api/**/*", "GET"},
|
||||
{"long pattern", strings.Repeat("/a*", 20), "GET"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 应该不会 panic
|
||||
deleted := h.PurgeByPatternForTest(tt.pattern, tt.method)
|
||||
if deleted != 0 {
|
||||
t.Errorf("expected 0 deletions for nil server, got %d", deleted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1216
internal/server/startmultiservermode_test.go
Normal file
1216
internal/server/startmultiservermode_test.go
Normal file
File diff suppressed because it is too large
Load Diff
671
internal/server/startsinglemode_test.go
Normal file
671
internal/server/startsinglemode_test.go
Normal file
@ -0,0 +1,671 @@
|
||||
// Package server 提供 startSingleMode 集成测试。
|
||||
//
|
||||
// 该文件测试 startSingleMode 函数的各种配置场景,
|
||||
// 包括静态文件、代理、监控端点、TLS 等配置的实际启动。
|
||||
//
|
||||
// 作者:xfy
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestStartSingleMode_Integration_WithStaticFiles 测试 startSingleMode 静态文件实际启动。
|
||||
func TestStartSingleMode_Integration_WithStaticFiles(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
tempDir := t.TempDir()
|
||||
// 创建一个测试文件
|
||||
testFile := tempDir + "/index.html"
|
||||
if err := os.WriteFile(testFile, []byte("<html>test</html>"), 0o644); err != nil {
|
||||
t.Fatalf("failed to create test file: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Static: []config.StaticConfig{
|
||||
{
|
||||
Path: "/static",
|
||||
Root: tempDir,
|
||||
Index: []string{"index.html"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
// 在 goroutine 中启动服务器
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待服务器启动
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
// 服务器正常关闭会返回 nil 或 listener closed 错误
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
// 服务器仍在运行,已通过 GracefulStop 关闭
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithProxy 测试 startSingleMode 代理实际启动。
|
||||
func TestStartSingleMode_Integration_WithProxy(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://127.0.0.1:9999", Weight: 1}, // 不存在的后端
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithMonitoring 测试 startSingleMode 监控端点实际启动。
|
||||
func TestStartSingleMode_Integration_WithMonitoring(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
}},
|
||||
Monitoring: config.MonitoringConfig{
|
||||
Status: config.StatusConfig{
|
||||
Enabled: true,
|
||||
Path: "/_status",
|
||||
Allow: []string{"127.0.0.1"},
|
||||
},
|
||||
Pprof: config.PprofConfig{
|
||||
Enabled: true,
|
||||
Path: "/debug/pprof",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithCacheAPI 测试 startSingleMode 缓存 API 实际启动。
|
||||
func TestStartSingleMode_Integration_WithCacheAPI(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
CacheAPI: &config.CacheAPIConfig{
|
||||
Enabled: true,
|
||||
Path: "/_cache/purge",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithCompression 测试 startSingleMode 压缩配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithCompression(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Compression: config.CompressionConfig{
|
||||
Type: "gzip",
|
||||
Level: 6,
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithSecurity 测试 startSingleMode 安全配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithSecurity(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Security: config.SecurityConfig{
|
||||
Access: config.AccessConfig{
|
||||
Allow: []string{"127.0.0.1"},
|
||||
},
|
||||
RateLimit: config.RateLimitConfig{
|
||||
RequestRate: 100,
|
||||
Burst: 200,
|
||||
},
|
||||
Headers: config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithRewrite 测试 startSingleMode 重写配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithRewrite(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Rewrite: []config.RewriteRule{
|
||||
{Pattern: "^/old/(.*)$", Replacement: "/new/$1"},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithPerformance 测试 startSingleMode 性能配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithPerformance(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
}},
|
||||
Performance: config.PerformanceConfig{
|
||||
GoroutinePool: config.GoroutinePoolConfig{
|
||||
Enabled: true,
|
||||
MaxWorkers: 50,
|
||||
MinWorkers: 5,
|
||||
IdleTimeout: 10 * time.Second,
|
||||
},
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 1000,
|
||||
MaxSize: 10 * 1024 * 1024,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithProxyLocationTypes 测试代理不同位置类型实际启动。
|
||||
func TestStartSingleMode_Integration_WithProxyLocationTypes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api/exact",
|
||||
LocationType: "exact",
|
||||
Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}},
|
||||
},
|
||||
{
|
||||
Path: "/api/priority",
|
||||
LocationType: "prefix_priority",
|
||||
Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}},
|
||||
},
|
||||
{
|
||||
Path: "^/api/regex/(.*)$",
|
||||
LocationType: "regex",
|
||||
Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}},
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithStaticLocationTypes 测试静态文件不同位置类型实际启动。
|
||||
func TestStartSingleMode_Integration_WithStaticLocationTypes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Static: []config.StaticConfig{
|
||||
{
|
||||
Path: "/static/exact",
|
||||
Root: tempDir,
|
||||
LocationType: "exact",
|
||||
},
|
||||
{
|
||||
Path: "/static/priority",
|
||||
Root: tempDir,
|
||||
LocationType: "prefix_priority",
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithHealthCheck 测试代理健康检查实际启动。
|
||||
func TestStartSingleMode_Integration_WithHealthCheck(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Proxy: []config.ProxyConfig{
|
||||
{
|
||||
Path: "/api",
|
||||
Targets: []config.ProxyTarget{
|
||||
{URL: "http://127.0.0.1:9999", Weight: 1},
|
||||
},
|
||||
HealthCheck: config.HealthCheckConfig{
|
||||
Interval: 1 * time.Second,
|
||||
Timeout: 500 * time.Millisecond,
|
||||
Path: "/health",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond) // 给健康检查一些时间启动
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithMIMETypes 测试 MIME 类型配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithMIMETypes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Types: config.TypesConfig{
|
||||
Map: map[string]string{
|
||||
".wasm": "application/wasm",
|
||||
},
|
||||
DefaultType: "application/octet-stream",
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithErrorPage 测试错误页面配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithErrorPage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
tempDir := t.TempDir()
|
||||
errorPage := tempDir + "/404.html"
|
||||
if err := os.WriteFile(errorPage, []byte("<html>Not Found</html>"), 0o644); err != nil {
|
||||
t.Fatalf("failed to create error page: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Security: config.SecurityConfig{
|
||||
ErrorPage: config.ErrorPageConfig{
|
||||
Pages: map[int]string{404: errorPage},
|
||||
Default: errorPage,
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithConnLimiter 测试连接限制配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithConnLimiter(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Security: config.SecurityConfig{
|
||||
RateLimit: config.RateLimitConfig{
|
||||
ConnLimit: 10,
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartSingleMode_Integration_WithAuthRequest 测试外部认证配置实际启动。
|
||||
func TestStartSingleMode_Integration_WithAuthRequest(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0",
|
||||
Security: config.SecurityConfig{
|
||||
AuthRequest: config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "/auth/validate",
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
},
|
||||
}},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := s.Start(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = s.GracefulStop(2 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && !isExpectedServerErrorForIntegration(err) {
|
||||
t.Errorf("unexpected server error: %v", err)
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// isExpectedServerErrorForIntegration 检查是否是预期的服务器关闭错误。
|
||||
func isExpectedServerErrorForIntegration(err error) bool {
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
// fasthttp 服务器关闭时的常见错误
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "closed") ||
|
||||
strings.Contains(errStr, "use of closed") ||
|
||||
strings.Contains(errStr, "listener closed")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
514
internal/server/testutil_test.go
Normal file
514
internal/server/testutil_test.go
Normal file
@ -0,0 +1,514 @@
|
||||
// Package server 提供测试工具函数的测试。
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/lua"
|
||||
"rua.plus/lolly/internal/ssl"
|
||||
)
|
||||
|
||||
// TestMockFastServer_Serve 测试 MockFastServer.Serve 方法
|
||||
func TestMockFastServer_Serve(t *testing.T) {
|
||||
t.Run("with custom ServeFunc", func(t *testing.T) {
|
||||
called := false
|
||||
mock := &MockFastServer{
|
||||
ServeFunc: func(ln net.Listener) error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
err = mock.Serve(ln)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("ServeFunc was not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without ServeFunc", func(t *testing.T) {
|
||||
mock := &MockFastServer{}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
err = mock.Serve(ln)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with error from ServeFunc", func(t *testing.T) {
|
||||
expectedErr := errors.New("serve error")
|
||||
mock := &MockFastServer{
|
||||
ServeFunc: func(ln net.Listener) error {
|
||||
return expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
err = mock.Serve(ln)
|
||||
if err != expectedErr {
|
||||
t.Errorf("expected error %v, got %v", expectedErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMockFastServer_ServeTLS 测试 MockFastServer.ServeTLS 方法
|
||||
func TestMockFastServer_ServeTLS(t *testing.T) {
|
||||
t.Run("with custom ServeTLSFunc", func(t *testing.T) {
|
||||
called := false
|
||||
mock := &MockFastServer{
|
||||
ServeTLSFunc: func(ln net.Listener, certFile, keyFile string) error {
|
||||
called = true
|
||||
if certFile != "cert.pem" {
|
||||
t.Errorf("expected certFile cert.pem, got %s", certFile)
|
||||
}
|
||||
if keyFile != "key.pem" {
|
||||
t.Errorf("expected keyFile key.pem, got %s", keyFile)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
err = mock.ServeTLS(ln, "cert.pem", "key.pem")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("ServeTLSFunc was not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without ServeTLSFunc", func(t *testing.T) {
|
||||
mock := &MockFastServer{}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
err = mock.ServeTLS(ln, "cert.pem", "key.pem")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMockFastServer_Shutdown 测试 MockFastServer.Shutdown 方法
|
||||
func TestMockFastServer_Shutdown(t *testing.T) {
|
||||
t.Run("with custom ShutdownFunc", func(t *testing.T) {
|
||||
called := false
|
||||
mock := &MockFastServer{
|
||||
ShutdownFunc: func() error {
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
err := mock.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("ShutdownFunc was not called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without ShutdownFunc", func(t *testing.T) {
|
||||
mock := &MockFastServer{}
|
||||
|
||||
err := mock.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with error from ShutdownFunc", func(t *testing.T) {
|
||||
expectedErr := errors.New("shutdown error")
|
||||
mock := &MockFastServer{
|
||||
ShutdownFunc: func() error {
|
||||
return expectedErr
|
||||
},
|
||||
}
|
||||
|
||||
err := mock.Shutdown()
|
||||
if err != expectedErr {
|
||||
t.Errorf("expected error %v, got %v", expectedErr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewServerForTesting 测试 NewServerForTesting 函数
|
||||
func TestNewServerForTesting(t *testing.T) {
|
||||
t.Run("with nil deps", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
s := NewServerForTesting(cfg, nil)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
if s.config != cfg {
|
||||
t.Error("config not set correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with lua engine", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
luaEngine := &lua.LuaEngine{}
|
||||
deps := &TestDependencies{
|
||||
LuaEngine: luaEngine,
|
||||
}
|
||||
|
||||
s := NewServerForTesting(cfg, deps)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
if s.luaEngine != luaEngine {
|
||||
t.Error("lua engine not set correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with TLS manager", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
tlsManager := &ssl.TLSManager{}
|
||||
deps := &TestDependencies{
|
||||
TLSManager: tlsManager,
|
||||
}
|
||||
|
||||
s := NewServerForTesting(cfg, deps)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
if s.tlsManager != tlsManager {
|
||||
t.Error("TLS manager not set correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with all deps", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
luaEngine := &lua.LuaEngine{}
|
||||
tlsManager := &ssl.TLSManager{}
|
||||
deps := &TestDependencies{
|
||||
LuaEngine: luaEngine,
|
||||
TLSManager: tlsManager,
|
||||
}
|
||||
|
||||
s := NewServerForTesting(cfg, deps)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
if s.luaEngine != luaEngine {
|
||||
t.Error("lua engine not set correctly")
|
||||
}
|
||||
if s.tlsManager != tlsManager {
|
||||
t.Error("TLS manager not set correctly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewTestServerWithOptions 测试 NewTestServerWithOptions 函数
|
||||
func TestNewTestServerWithOptions(t *testing.T) {
|
||||
t.Run("with nil opts", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
s := NewTestServerWithOptions(cfg, nil)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
if s.config != cfg {
|
||||
t.Error("config not set correctly")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with custom handler", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
customHandler := func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetBodyString("custom response")
|
||||
}
|
||||
|
||||
opts := &TestServerOptions{
|
||||
CustomHandler: customHandler,
|
||||
}
|
||||
|
||||
s := NewTestServerWithOptions(cfg, opts)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
if s.handler == nil {
|
||||
t.Error("handler should be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with empty opts", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
opts := &TestServerOptions{}
|
||||
|
||||
s := NewTestServerWithOptions(cfg, opts)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with mock fast server", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":8080",
|
||||
}},
|
||||
}
|
||||
|
||||
opts := &TestServerOptions{
|
||||
MockFastServer: &MockFastServer{
|
||||
Name: "test-server",
|
||||
},
|
||||
}
|
||||
|
||||
s := NewTestServerWithOptions(cfg, opts)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMustStartTestServer 测试 MustStartTestServer 函数
|
||||
func TestMustStartTestServer(t *testing.T) {
|
||||
t.Run("basic server start", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "127.0.0.1:0", // 随机端口
|
||||
}},
|
||||
}
|
||||
|
||||
s := MustStartTestServer(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
|
||||
// 给服务器一点时间启动
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
_ = s.StopWithTimeout(1 * time.Second)
|
||||
})
|
||||
|
||||
t.Run("with empty listen address", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: "",
|
||||
}},
|
||||
}
|
||||
|
||||
s := MustStartTestServer(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
|
||||
// 给服务器一点时间启动
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
_ = s.StopWithTimeout(1 * time.Second)
|
||||
})
|
||||
|
||||
t.Run("with default port", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{{
|
||||
Listen: ":80",
|
||||
}},
|
||||
}
|
||||
|
||||
s := MustStartTestServer(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
|
||||
// 给服务器一点时间启动
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
_ = s.StopWithTimeout(1 * time.Second)
|
||||
})
|
||||
|
||||
t.Run("with multiple servers", func(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Servers: []config.ServerConfig{
|
||||
{Listen: "127.0.0.1:0"},
|
||||
{Listen: "127.0.0.1:0"},
|
||||
},
|
||||
}
|
||||
|
||||
s := MustStartTestServer(cfg)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil server")
|
||||
}
|
||||
|
||||
// 给服务器一点时间启动
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// 停止服务器
|
||||
_ = s.StopWithTimeout(1 * time.Second)
|
||||
})
|
||||
}
|
||||
|
||||
// TestTestDependencies 测试 TestDependencies 结构体
|
||||
func TestTestDependencies(t *testing.T) {
|
||||
t.Run("empty dependencies", func(t *testing.T) {
|
||||
deps := &TestDependencies{}
|
||||
if deps.LuaEngine != nil {
|
||||
t.Error("LuaEngine should be nil")
|
||||
}
|
||||
if deps.TLSManager != nil {
|
||||
t.Error("TLSManager should be nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with lua engine only", func(t *testing.T) {
|
||||
luaEngine := &lua.LuaEngine{}
|
||||
deps := &TestDependencies{
|
||||
LuaEngine: luaEngine,
|
||||
}
|
||||
if deps.LuaEngine != luaEngine {
|
||||
t.Error("LuaEngine not set correctly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestTestServerOptions 测试 TestServerOptions 结构体
|
||||
func TestTestServerOptions(t *testing.T) {
|
||||
t.Run("empty options", func(t *testing.T) {
|
||||
opts := &TestServerOptions{}
|
||||
if opts.MockFastServer != nil {
|
||||
t.Error("MockFastServer should be nil")
|
||||
}
|
||||
if opts.CustomHandler != nil {
|
||||
t.Error("CustomHandler should be nil")
|
||||
}
|
||||
if opts.SkipListener {
|
||||
t.Error("SkipListener should be false")
|
||||
}
|
||||
if opts.DisableMiddleware {
|
||||
t.Error("DisableMiddleware should be false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with all options", func(t *testing.T) {
|
||||
mock := &MockFastServer{Name: "test"}
|
||||
handler := func(ctx *fasthttp.RequestCtx) {}
|
||||
|
||||
opts := &TestServerOptions{
|
||||
MockFastServer: mock,
|
||||
CustomHandler: handler,
|
||||
SkipListener: true,
|
||||
DisableMiddleware: true,
|
||||
}
|
||||
|
||||
if opts.MockFastServer != mock {
|
||||
t.Error("MockFastServer not set correctly")
|
||||
}
|
||||
if opts.CustomHandler == nil {
|
||||
t.Error("CustomHandler should be set")
|
||||
}
|
||||
if !opts.SkipListener {
|
||||
t.Error("SkipListener should be true")
|
||||
}
|
||||
if !opts.DisableMiddleware {
|
||||
t.Error("DisableMiddleware should be true")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMockFastServer_Fields 测试 MockFastServer 字段
|
||||
func TestMockFastServer_Fields(t *testing.T) {
|
||||
mock := &MockFastServer{
|
||||
Name: "test-server",
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 20 * time.Second,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
MaxConnsPerIP: 100,
|
||||
MaxRequestsPerConn: 1000,
|
||||
CloseOnShutdown: true,
|
||||
}
|
||||
|
||||
if mock.Name != "test-server" {
|
||||
t.Errorf("expected Name test-server, got %s", mock.Name)
|
||||
}
|
||||
if mock.ReadTimeout != 10*time.Second {
|
||||
t.Errorf("expected ReadTimeout 10s, got %v", mock.ReadTimeout)
|
||||
}
|
||||
if mock.WriteTimeout != 20*time.Second {
|
||||
t.Errorf("expected WriteTimeout 20s, got %v", mock.WriteTimeout)
|
||||
}
|
||||
if mock.IdleTimeout != 30*time.Second {
|
||||
t.Errorf("expected IdleTimeout 30s, got %v", mock.IdleTimeout)
|
||||
}
|
||||
if mock.MaxConnsPerIP != 100 {
|
||||
t.Errorf("expected MaxConnsPerIP 100, got %d", mock.MaxConnsPerIP)
|
||||
}
|
||||
if mock.MaxRequestsPerConn != 1000 {
|
||||
t.Errorf("expected MaxRequestsPerConn 1000, got %d", mock.MaxRequestsPerConn)
|
||||
}
|
||||
if !mock.CloseOnShutdown {
|
||||
t.Error("CloseOnShutdown should be true")
|
||||
}
|
||||
}
|
||||
@ -17,6 +17,7 @@ package server
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -313,3 +314,285 @@ func TestReadOldPid_EmptyFile(t *testing.T) {
|
||||
t.Error("Expected error for empty PID file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyOldProcess_ReadPidError 测试读取 PID 失败的情况
|
||||
func TestNotifyOldProcess_ReadPidError(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
// 不设置 PID 文件,ReadOldPid 会返回错误
|
||||
|
||||
err := mgr.NotifyOldProcess()
|
||||
if err != nil {
|
||||
t.Errorf("NotifyOldProcess should return nil when ReadOldPid fails, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyOldProcess_ZeroPid 测试 PID 为 0 的情况
|
||||
func TestNotifyOldProcess_ZeroPid(t *testing.T) {
|
||||
tmpFile := "/tmp/lolly-test-zero.pid"
|
||||
defer func() {
|
||||
_ = os.Remove(tmpFile)
|
||||
}()
|
||||
|
||||
// 写入 PID 0
|
||||
if err := os.WriteFile(tmpFile, []byte("0"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
mgr := NewUpgradeManager(nil)
|
||||
mgr.SetPidFile(tmpFile)
|
||||
|
||||
err := mgr.NotifyOldProcess()
|
||||
if err != nil {
|
||||
t.Errorf("NotifyOldProcess should return nil for PID 0, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyOldProcess_NonExistentProcess 测试通知不存在的进程
|
||||
func TestNotifyOldProcess_NonExistentProcess(t *testing.T) {
|
||||
tmpFile := "/tmp/lolly-test-nonexistent.pid"
|
||||
defer func() {
|
||||
_ = os.Remove(tmpFile)
|
||||
}()
|
||||
|
||||
// 写入一个不存在的 PID (使用一个极大的 PID 值)
|
||||
if err := os.WriteFile(tmpFile, []byte("9999999"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to write temp file: %v", err)
|
||||
}
|
||||
|
||||
mgr := NewUpgradeManager(nil)
|
||||
mgr.SetPidFile(tmpFile)
|
||||
|
||||
// 对于不存在的进程,Signal 会返回错误
|
||||
// NotifyOldProcess 会直接返回这个错误
|
||||
err := mgr.NotifyOldProcess()
|
||||
if err == nil {
|
||||
t.Error("Expected error when notifying non-existent process")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyOldProcess_FindProcessError 测试 os.FindProcess 的行为
|
||||
// 注意:在 Unix 系统上,os.FindProcess 总是成功,即使进程不存在
|
||||
func TestNotifyOldProcess_FindProcessBehavior(t *testing.T) {
|
||||
// 这个测试验证 os.FindProcess 的行为
|
||||
// 在 Unix 上,FindProcess 总是返回一个 Process 对象
|
||||
process, err := os.FindProcess(9999999)
|
||||
if err != nil {
|
||||
t.Errorf("FindProcess should not return error on Unix, got: %v", err)
|
||||
}
|
||||
if process == nil {
|
||||
t.Error("FindProcess should return non-nil Process on Unix")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetupSignalHandlers_SetsUpChannel 测试信号处理器设置
|
||||
func TestSetupSignalHandlers_SetsUpChannel(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// 调用 SetupSignalHandlers 应该不会 panic
|
||||
mgr.SetupSignalHandlers("/nonexistent/binary")
|
||||
|
||||
// 给 goroutine 一点时间启动
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// 测试通过如果没 panic
|
||||
}
|
||||
|
||||
// TestSetupSignalHandlers_TriggersUpgrade 测试信号触发升级
|
||||
func TestSetupSignalHandlers_TriggersUpgrade(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// 创建一个监听器
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create listener: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
mgr.SetListeners([]net.Listener{listener})
|
||||
|
||||
// 设置信号处理器,使用一个不存在的二进制文件
|
||||
mgr.SetupSignalHandlers("/nonexistent/binary/path")
|
||||
|
||||
// 给 goroutine 启动时间
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// 发送 SIGUSR2 信号给当前进程
|
||||
// 注意:这会触发 GracefulUpgrade,但由于二进制文件不存在会失败
|
||||
// 信号处理器会忽略错误(使用 _ = u.GracefulUpgrade)
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to find current process: %v", err)
|
||||
}
|
||||
|
||||
// 发送 SIGUSR2
|
||||
if err := process.Signal(syscall.SIGUSR2); err != nil {
|
||||
t.Fatalf("Failed to send SIGUSR2: %v", err)
|
||||
}
|
||||
|
||||
// 等待信号处理
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 测试通过如果没有 panic
|
||||
}
|
||||
|
||||
// TestGracefulUpgrade_UnsupportedListener 测试不支持的监听器类型
|
||||
func TestGracefulUpgrade_UnsupportedListener(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// 使用 mock 监听器(不支持的类型)
|
||||
mgr.SetListeners([]net.Listener{&mockListener{}})
|
||||
|
||||
err := mgr.GracefulUpgrade("/nonexistent/binary")
|
||||
if err == nil {
|
||||
t.Error("Expected error for unsupported listener type")
|
||||
}
|
||||
if err != nil && !containsString(err.Error(), "unsupported listener type") &&
|
||||
!containsString(err.Error(), "failed to get listener file") {
|
||||
t.Errorf("Expected unsupported listener error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGracefulUpgrade_NonexistentBinary 测试不存在的二进制文件
|
||||
// 注意:此测试使用 mock 监听器避免创建实际网络连接
|
||||
func TestGracefulUpgrade_NonexistentBinary(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// 使用 mock 监听器测试不支持类型的错误路径
|
||||
mgr.SetListeners([]net.Listener{&mockListener{}})
|
||||
|
||||
// 由于 mockListener 是不支持的类型,应该返回错误
|
||||
err := mgr.GracefulUpgrade("/nonexistent/path/to/binary")
|
||||
if err == nil {
|
||||
t.Error("Expected error for unsupported listener type")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGracefulUpgrade_WithPidFile 测试升级时写入 PID 文件
|
||||
// 注意:此测试使用 mock 监听器避免创建实际网络连接
|
||||
func TestGracefulUpgrade_WithPidFile(t *testing.T) {
|
||||
tmpFile := "/tmp/lolly-test-upgrade.pid"
|
||||
defer func() {
|
||||
_ = os.Remove(tmpFile)
|
||||
}()
|
||||
|
||||
mgr := NewUpgradeManager(nil)
|
||||
mgr.SetPidFile(tmpFile)
|
||||
|
||||
// 使用 mock 监听器
|
||||
mgr.SetListeners([]net.Listener{&mockListener{}})
|
||||
|
||||
// 使用不存在的二进制文件,会失败但测试 PID 文件设置逻辑
|
||||
_ = mgr.GracefulUpgrade("/nonexistent/binary")
|
||||
// 测试通过如果没有 panic
|
||||
}
|
||||
|
||||
// TestWaitForShutdown_ProcessExits 测试进程退出后的等待
|
||||
func TestWaitForShutdown_ProcessExits(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// 使用一个不存在的 PID(Signal(0) 会返回错误)
|
||||
mgr.oldPid = 9999999
|
||||
|
||||
// 不存在的进程应该立即返回 nil
|
||||
err := mgr.WaitForShutdown(1 * time.Second)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil for non-existent process, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWaitForShutdown_Timeout 测试等待超时
|
||||
func TestWaitForShutdown_Timeout(t *testing.T) {
|
||||
// 跳过此测试:需要实际运行的进程来测试超时
|
||||
// 向当前进程发送 SIGKILL 会导致测试崩溃
|
||||
t.Skip("Skipping test that would kill current process")
|
||||
}
|
||||
|
||||
// TestWaitForShutdown_SetsOldPid 测试 oldPid 设置
|
||||
func TestWaitForShutdown_SetsOldPid(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// oldPid 为 0 时应该直接返回 nil
|
||||
if mgr.oldPid != 0 {
|
||||
t.Error("Expected oldPid to be 0 initially")
|
||||
}
|
||||
|
||||
err := mgr.WaitForShutdown(100 * time.Millisecond)
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil when oldPid is 0, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestListenerFile_UnixListener 测试 Unix 监听器获取文件
|
||||
// 注意:跳过此测试,因为在大量测试运行时可能导致 FD 问题
|
||||
func TestListenerFile_UnixListener(t *testing.T) {
|
||||
t.Skip("Skipping test to avoid FD exhaustion in parallel test runs")
|
||||
}
|
||||
|
||||
// TestGracefulUpgrade_MultipleListeners 测试多个监听器的升级
|
||||
// 注意:使用 mock 监听器避免 FD 问题
|
||||
func TestGracefulUpgrade_MultipleListeners(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// 使用多个 mock 监听器
|
||||
mgr.SetListeners([]net.Listener{&mockListener{}, &mockListener{}})
|
||||
|
||||
// 由于 mockListener 是不支持的类型,应该返回错误
|
||||
err := mgr.GracefulUpgrade("/nonexistent/binary")
|
||||
if err == nil {
|
||||
t.Error("Expected error for unsupported listener type")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGracefulUpgrade_RelativePath 测试相对路径的二进制文件
|
||||
// 注意:使用 mock 监听器避免 FD 问题
|
||||
func TestGracefulUpgrade_RelativePath(t *testing.T) {
|
||||
mgr := NewUpgradeManager(nil)
|
||||
|
||||
// 使用 mock 监听器
|
||||
mgr.SetListeners([]net.Listener{&mockListener{}})
|
||||
|
||||
// 使用相对路径,应该返回错误
|
||||
err := mgr.GracefulUpgrade("./nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Expected error for unsupported listener type")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWaitForShutdown_FindProcessError 测试 FindProcess 行为
|
||||
func TestWaitForShutdown_FindProcessError(t *testing.T) {
|
||||
// 在 Unix 系统上,os.FindProcess 总是成功
|
||||
// 我们需要测试 Signal(0) 失败的情况
|
||||
mgr := NewUpgradeManager(nil)
|
||||
mgr.oldPid = 1 // init 进程,通常存在但无法发送信号
|
||||
|
||||
// 短超时
|
||||
_ = mgr.WaitForShutdown(50 * time.Millisecond)
|
||||
// 无论结果如何,测试都应该正常完成
|
||||
}
|
||||
|
||||
// TestGetInheritedListeners_EnvPreserved 测试环境变量处理
|
||||
func TestGetInheritedListeners_EnvPreserved(t *testing.T) {
|
||||
t.Setenv("LISTEN_FDS", "1")
|
||||
|
||||
mgr := NewUpgradeManager(nil)
|
||||
_, err := mgr.GetInheritedListeners()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 环境变量应该仍然存在
|
||||
if os.Getenv("LISTEN_FDS") != "1" {
|
||||
t.Error("LISTEN_FDS env should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
// containsString 检查字符串是否包含子串
|
||||
func containsString(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -17,6 +17,7 @@ import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -699,3 +700,172 @@ func BenchmarkLoadCACertPool(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFingerprint_Nil 测试 nil 证书的指纹。
|
||||
func TestFingerprint_Nil(t *testing.T) {
|
||||
result := fingerprint(nil)
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty string for nil cert, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFingerprint_Valid 测试有效证书的指纹。
|
||||
func TestFingerprint_Valid(t *testing.T) {
|
||||
caCert, _, _ := generateTestCA(t)
|
||||
result := fingerprint(caCert)
|
||||
if result == "" {
|
||||
t.Error("Expected non-empty fingerprint for valid cert")
|
||||
}
|
||||
// 指纹应该是证书 Raw 的十六进制表示
|
||||
expected := fmt.Sprintf("%x", caCert.Raw)
|
||||
if result != expected {
|
||||
t.Error("Fingerprint should match certificate Raw hex")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientVerifyMode_String 测试验证模式字符串表示。
|
||||
func TestClientVerifyMode_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
mode ClientVerifyMode
|
||||
expected string
|
||||
}{
|
||||
{VerifyOff, "off"},
|
||||
{VerifyOn, "on"},
|
||||
{VerifyOptional, "optional"},
|
||||
{VerifyOptionalNoCA, "optional_no_ca"},
|
||||
{ClientVerifyMode(99), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
if got := tt.mode.String(); got != tt.expected {
|
||||
t.Errorf("String() = %q, want %q", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClientVerifier_InvalidMode 测试无效验证模式。
|
||||
func TestNewClientVerifier_InvalidMode(t *testing.T) {
|
||||
_, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "invalid_mode",
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid verify mode")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClientVerifier_WithCRL 测试带 CRL 的验证器。
|
||||
func TestNewClientVerifier_WithCRL(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
crlFile := filepath.Join(tempDir, "crl.pem")
|
||||
|
||||
caCert, caKey, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CA file: %v", err)
|
||||
}
|
||||
|
||||
// 生成空 CRL
|
||||
crlPEM := generateTestCRL(t, caCert, caKey, nil)
|
||||
if err := os.WriteFile(crlFile, crlPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CRL file: %v", err)
|
||||
}
|
||||
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
CRL: crlFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
if !verifier.IsEnabled() {
|
||||
t.Error("Verifier should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewClientVerifier_InvalidCRL 测试无效 CRL 文件。
|
||||
func TestNewClientVerifier_InvalidCRL(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
crlFile := filepath.Join(tempDir, "invalid.crl")
|
||||
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CA file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(crlFile, []byte("invalid crl data"), 0o644); err != nil {
|
||||
t.Fatalf("Failed to write invalid CRL file: %v", err)
|
||||
}
|
||||
|
||||
_, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
CRL: crlFile,
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid CRL file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientVerifier_ValidateClientCertificate_WithCRL 测试带 CRL 的证书验证。
|
||||
func TestClientVerifier_ValidateClientCertificate_WithCRL(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
crlFile := filepath.Join(tempDir, "crl.pem")
|
||||
|
||||
caCert, caKey, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CA file: %v", err)
|
||||
}
|
||||
|
||||
// 生成客户端证书
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 600)
|
||||
|
||||
// 生成包含吊销证书的 CRL
|
||||
crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{clientCert.SerialNumber})
|
||||
if err := os.WriteFile(crlFile, crlPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CRL file: %v", err)
|
||||
}
|
||||
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
CRL: crlFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证吊销证书应失败
|
||||
err = verifier.ValidateClientCertificate(clientCert)
|
||||
if err == nil {
|
||||
t.Error("Expected error for revoked certificate")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetClientCertInfo_WithEmail 测试带邮件地址的证书信息提取。
|
||||
func TestGetClientCertInfo_WithEmail(t *testing.T) {
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 700)
|
||||
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{clientCert},
|
||||
}
|
||||
|
||||
info := GetClientCertInfo(cs)
|
||||
if info == nil {
|
||||
t.Fatal("GetClientCertInfo() returned nil")
|
||||
}
|
||||
if len(info.Email) == 0 {
|
||||
t.Error("Expected email addresses in cert info")
|
||||
}
|
||||
if info.Email[0] != "test@example.com" {
|
||||
t.Errorf("Expected email test@example.com, got %s", info.Email[0])
|
||||
}
|
||||
}
|
||||
|
||||
@ -556,3 +556,232 @@ func TestOCSPManager_GetStatus_EdgeCases(t *testing.T) {
|
||||
t.Error("Expected no response for empty response data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_RegisterCertificate_NilCert 测试注册空证书
|
||||
func TestOCSPManager_RegisterCertificate_NilCert(t *testing.T) {
|
||||
mgr := NewOCSPManager(nil)
|
||||
defer mgr.Stop()
|
||||
|
||||
err := mgr.RegisterCertificate(nil, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil certificate")
|
||||
}
|
||||
if err.Error() != "certificate is nil" {
|
||||
t.Errorf("Expected 'certificate is nil' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_RegisterCertificate_NoOCSPServer 测试无 OCSP 服务器的证书
|
||||
func TestOCSPManager_RegisterCertificate_NoOCSPServer(t *testing.T) {
|
||||
mgr := NewOCSPManager(nil)
|
||||
defer mgr.Stop()
|
||||
|
||||
// 创建无 OCSP 服务器的证书
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(12345),
|
||||
Subject: pkix.Name{CommonName: "test"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(1 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
DNSNames: []string{"localhost"},
|
||||
// 无 OCSPServer
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse certificate: %v", err)
|
||||
}
|
||||
|
||||
err = mgr.RegisterCertificate(cert, cert)
|
||||
if err == nil {
|
||||
t.Error("Expected error for certificate without OCSP server")
|
||||
}
|
||||
if err.Error() != "certificate has no OCSP server URL" {
|
||||
t.Errorf("Expected 'certificate has no OCSP server URL' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_SendOCSPRequest_Error 测试 OCSP 请求错误
|
||||
func TestOCSPManager_SendOCSPRequest_Error(t *testing.T) {
|
||||
cfg := &OCSPConfig{
|
||||
Enabled: true,
|
||||
RefreshInterval: 1 * time.Hour,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
MaxRetries: 1,
|
||||
}
|
||||
mgr := NewOCSPManager(cfg)
|
||||
|
||||
// 测试无效 URL
|
||||
_, err := mgr.sendOCSPRequest("://invalid-url", []byte("test"))
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid URL")
|
||||
}
|
||||
|
||||
// 测试连接失败
|
||||
_, err = mgr.sendOCSPRequest("http://127.0.0.1:9999/ocsp", []byte("test"))
|
||||
if err == nil {
|
||||
t.Error("Expected error for connection failure")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_RefreshResponse_WithExistingEntry 测试刷新已有条目的响应
|
||||
func TestOCSPManager_RefreshResponse_WithExistingEntry(t *testing.T) {
|
||||
cfg := &OCSPConfig{
|
||||
Enabled: true,
|
||||
RefreshInterval: 1 * time.Hour,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
MaxRetries: 1,
|
||||
}
|
||||
mgr := NewOCSPManager(cfg)
|
||||
|
||||
serial := big.NewInt(12345)
|
||||
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
OCSPServer: []string{"http://invalid.ocsp.server.example.com"},
|
||||
}
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
cert, _ := x509.ParseCertificate(certDER)
|
||||
|
||||
// 预先添加一个条目
|
||||
mgr.mu.Lock()
|
||||
mgr.responses[serial.String()] = &ocspResponse{
|
||||
status: statusValid,
|
||||
response: []byte("test-response"),
|
||||
fetchedAt: time.Now(),
|
||||
nextUpdate: time.Now().Add(1 * time.Hour),
|
||||
errors: 0,
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
|
||||
// 刷新会失败,但应该增加错误计数
|
||||
_ = mgr.RefreshResponse(cert, cert)
|
||||
|
||||
// 验证错误计数增加
|
||||
mgr.mu.RLock()
|
||||
entry := mgr.responses[serial.String()]
|
||||
mgr.mu.RUnlock()
|
||||
|
||||
if entry.errors != 1 {
|
||||
t.Errorf("Expected errors=1, got %d", entry.errors)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_RefreshResponse_StatusFailed 测试刷新失败后状态变化
|
||||
func TestOCSPManager_RefreshResponse_StatusFailed(t *testing.T) {
|
||||
cfg := &OCSPConfig{
|
||||
Enabled: true,
|
||||
RefreshInterval: 1 * time.Hour,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
MaxRetries: 1,
|
||||
}
|
||||
mgr := NewOCSPManager(cfg)
|
||||
|
||||
serial := big.NewInt(99999)
|
||||
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
OCSPServer: []string{"http://invalid.ocsp.server.example.com"},
|
||||
}
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
cert, _ := x509.ParseCertificate(certDER)
|
||||
|
||||
// 预先添加一个条目,错误计数接近阈值
|
||||
mgr.mu.Lock()
|
||||
mgr.responses[serial.String()] = &ocspResponse{
|
||||
status: statusValid,
|
||||
response: []byte("test-response"),
|
||||
fetchedAt: time.Now(),
|
||||
nextUpdate: time.Now().Add(1 * time.Hour),
|
||||
errors: 2, // 接近 maxRetries=1, 下次失败会变成 statusFailed
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
|
||||
// 刷新会失败
|
||||
_ = mgr.RefreshResponse(cert, cert)
|
||||
|
||||
// 验证状态变为 failed(因为 errors >= maxRetries)
|
||||
mgr.mu.RLock()
|
||||
entry := mgr.responses[serial.String()]
|
||||
mgr.mu.RUnlock()
|
||||
|
||||
if entry.status != statusFailed {
|
||||
t.Errorf("Expected statusFailed, got %v", entry.status)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_FetchOCSP_NoServer 测试无 OCSP 服务器时的 fetchOCSP
|
||||
func TestOCSPManager_FetchOCSP_NoServer(t *testing.T) {
|
||||
mgr := NewOCSPManager(nil)
|
||||
|
||||
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
// 无 OCSPServer
|
||||
}
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
cert, _ := x509.ParseCertificate(certDER)
|
||||
|
||||
_, err := mgr.fetchOCSP(cert, cert)
|
||||
if err == nil {
|
||||
t.Error("Expected error for certificate without OCSP server")
|
||||
}
|
||||
if err.Error() != "no OCSP server in certificate" {
|
||||
t.Errorf("Expected 'no OCSP server in certificate' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_StartTwice 测试重复启动
|
||||
func TestOCSPManager_StartTwice(t *testing.T) {
|
||||
mgr := NewOCSPManager(nil)
|
||||
|
||||
mgr.Start()
|
||||
defer mgr.Stop()
|
||||
|
||||
// 第二次启动应该无效果
|
||||
mgr.Start()
|
||||
|
||||
mgr.mu.RLock()
|
||||
running := mgr.running
|
||||
mgr.mu.RUnlock()
|
||||
|
||||
if !running {
|
||||
t.Error("Expected manager to be running")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_StopTwice 测试重复停止
|
||||
func TestOCSPManager_StopTwice(t *testing.T) {
|
||||
mgr := NewOCSPManager(nil)
|
||||
|
||||
mgr.Start()
|
||||
mgr.Stop()
|
||||
|
||||
// 第二次停止应该无效果
|
||||
mgr.Stop()
|
||||
|
||||
mgr.mu.RLock()
|
||||
running := mgr.running
|
||||
mgr.mu.RUnlock()
|
||||
|
||||
if running {
|
||||
t.Error("Expected manager to be stopped")
|
||||
}
|
||||
}
|
||||
|
||||
@ -884,3 +884,412 @@ func TestGetOCSPStatus_NoManager(t *testing.T) {
|
||||
t.Errorf("Expected empty status, got %d entries", len(status))
|
||||
}
|
||||
}
|
||||
|
||||
// TestParsePEMChain 测试 PEM 证书链解析
|
||||
func TestParsePEMChain(t *testing.T) {
|
||||
// 测试有效的 PEM 数据
|
||||
certPEM, _ := generateTestCert(t)
|
||||
certs := parsePEMChain(certPEM)
|
||||
if len(certs) == 0 {
|
||||
t.Error("Expected at least one certificate from valid PEM")
|
||||
}
|
||||
|
||||
// 测试空数据
|
||||
emptyCerts := parsePEMChain([]byte{})
|
||||
if len(emptyCerts) != 0 {
|
||||
t.Error("Expected no certificates from empty data")
|
||||
}
|
||||
|
||||
// 测试无效 PEM 数据
|
||||
invalidCerts := parsePEMChain([]byte("not valid pem"))
|
||||
if len(invalidCerts) != 0 {
|
||||
t.Error("Expected no certificates from invalid PEM")
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractPEMBlock 测试 PEM 块提取
|
||||
func TestExtractPEMBlock(t *testing.T) {
|
||||
// 测试有效的证书块
|
||||
certPEM, _ := generateTestCert(t)
|
||||
block, rest := extractPEMBlock(certPEM)
|
||||
if block == nil {
|
||||
t.Error("Expected non-nil block from valid PEM")
|
||||
}
|
||||
if len(block) == 0 {
|
||||
t.Error("Expected non-empty block")
|
||||
}
|
||||
_ = rest
|
||||
|
||||
// 测试空数据
|
||||
block, _ = extractPEMBlock([]byte{})
|
||||
if block != nil {
|
||||
t.Error("Expected nil block from empty data")
|
||||
}
|
||||
|
||||
// 测试无结束标记的数据
|
||||
invalidData := []byte("-----BEGIN CERTIFICATE-----\nsome data without end")
|
||||
block, _ = extractPEMBlock(invalidData)
|
||||
if block != nil {
|
||||
t.Error("Expected nil block from incomplete PEM")
|
||||
}
|
||||
|
||||
// 测试无开始标记的数据
|
||||
noStartData := []byte("some data\n-----END CERTIFICATE-----")
|
||||
block, _ = extractPEMBlock(noStartData)
|
||||
if block != nil {
|
||||
t.Error("Expected nil block from data without start marker")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindMarker 测试标记查找
|
||||
func TestFindMarker(t *testing.T) {
|
||||
data := []byte("prefix-----BEGIN CERTIFICATE-----suffix")
|
||||
marker := []byte("-----BEGIN CERTIFICATE-----")
|
||||
|
||||
idx := findMarker(data, marker)
|
||||
if idx != 6 {
|
||||
t.Errorf("Expected index 6, got %d", idx)
|
||||
}
|
||||
|
||||
// 测试不存在的标记
|
||||
idx = findMarker(data, []byte("NOTFOUND"))
|
||||
if idx != -1 {
|
||||
t.Errorf("Expected -1 for not found marker, got %d", idx)
|
||||
}
|
||||
|
||||
// 测试空数据
|
||||
idx = findMarker([]byte{}, marker)
|
||||
if idx != -1 {
|
||||
t.Errorf("Expected -1 for empty data, got %d", idx)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatchMarker 测试标记匹配
|
||||
func TestMatchMarker(t *testing.T) {
|
||||
data := []byte("-----BEGIN CERTIFICATE-----suffix")
|
||||
marker := []byte("-----BEGIN CERTIFICATE-----")
|
||||
|
||||
if !matchMarker(data, marker) {
|
||||
t.Error("Expected true for matching marker")
|
||||
}
|
||||
|
||||
// 测试不匹配
|
||||
if matchMarker(data, []byte("-----END CERTIFICATE-----")) {
|
||||
t.Error("Expected false for non-matching marker")
|
||||
}
|
||||
|
||||
// 测试数据长度小于标记
|
||||
shortData := []byte("short")
|
||||
if matchMarker(shortData, marker) {
|
||||
t.Error("Expected false when data is shorter than marker")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCertificate_NoCertificate 测试无证书时的错误情况
|
||||
func TestGetCertificate_NoCertificate(t *testing.T) {
|
||||
manager := &TLSManager{
|
||||
configs: make(map[string]*tls.Config),
|
||||
}
|
||||
|
||||
getCert := manager.GetCertificate()
|
||||
if getCert == nil {
|
||||
t.Fatal("Expected non-nil GetCertificate function")
|
||||
}
|
||||
|
||||
// 测试未知服务器名且无默认证书
|
||||
testHello := &tls.ClientHelloInfo{
|
||||
ServerName: "unknown.com",
|
||||
}
|
||||
certResult, err := getCert(testHello)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no certificate available")
|
||||
}
|
||||
if certResult != nil {
|
||||
t.Error("Expected nil certificate")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetConfigForClientWithOCSP 测试 OCSP 配置回调
|
||||
func TestGetConfigForClientWithOCSP(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
// 生成带有 OCSP 服务器的证书
|
||||
certPEM, keyPEM := generateTestCertWithOCSP(t, []string{"http://ocsp.example.com"})
|
||||
if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
OCSPStapling: true,
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 测试 GetConfigForClient 回调
|
||||
testHello := &tls.ClientHelloInfo{
|
||||
ServerName: "localhost",
|
||||
}
|
||||
tlsCfg, err := manager.getConfigForClientWithOCSP(testHello)
|
||||
if err != nil {
|
||||
t.Errorf("getConfigForClientWithOCSP() error = %v", err)
|
||||
}
|
||||
if tlsCfg == nil {
|
||||
t.Error("Expected non-nil TLS config")
|
||||
}
|
||||
|
||||
// 测试空 ServerName
|
||||
emptyHello := &tls.ClientHelloInfo{
|
||||
ServerName: "",
|
||||
}
|
||||
if _, err := manager.getConfigForClientWithOCSP(emptyHello); err != nil {
|
||||
t.Errorf("getConfigForClientWithOCSP() error with empty ServerName = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadCertificate_WithCertChain 测试带证书链的加载
|
||||
func TestLoadCertificate_WithCertChain(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
chainPath := filepath.Join(tmpDir, "chain.pem")
|
||||
|
||||
// 生成主证书
|
||||
certPEM, keyPEM := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
// 生成证书链(使用另一个测试证书)
|
||||
chainCert, _ := generateTestCert(t)
|
||||
if err := os.WriteFile(chainPath, chainCert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write chain: %v", err)
|
||||
}
|
||||
|
||||
// 测试加载带证书链的证书
|
||||
cert, err := loadCertificate(certPath, keyPath, chainPath)
|
||||
if err != nil {
|
||||
t.Fatalf("loadCertificate() error = %v", err)
|
||||
}
|
||||
if len(cert.Certificate) < 2 {
|
||||
t.Errorf("Expected at least 2 certificates in chain, got %d", len(cert.Certificate))
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadCertificate_InvalidChain 测试无效证书链
|
||||
func TestLoadCertificate_InvalidChain(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
certPEM, keyPEM := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
// 测试不存在的证书链文件
|
||||
_, err := loadCertificate(certPath, keyPath, "/nonexistent/chain.pem")
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent chain file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateTLSConfig_NilConfig 测试 nil 配置
|
||||
func TestCreateTLSConfig_NilConfig(t *testing.T) {
|
||||
_, err := createTLSConfig(nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error for nil config")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTLSManager_WithSessionTickets 测试启用 Session Tickets
|
||||
func TestNewTLSManager_WithSessionTickets(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
ticketKeyPath := filepath.Join(tmpDir, "ticket.key")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0o600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
SessionTickets: config.SessionTicketsConfig{
|
||||
Enabled: true,
|
||||
KeyFile: ticketKeyPath,
|
||||
RotateInterval: time.Hour,
|
||||
RetainKeys: 3,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 验证 Session Ticket 管理器已初始化
|
||||
manager.mu.RLock()
|
||||
stm := manager.sessionTicketMgr
|
||||
manager.mu.RUnlock()
|
||||
|
||||
if stm == nil {
|
||||
t.Error("Expected session ticket manager to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTLSManager_WithClientVerify 测试启用客户端验证
|
||||
func TestNewTLSManager_WithClientVerify(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
caPath := filepath.Join(tmpDir, "ca.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0o600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
// 创建 CA 证书
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caPath, caPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CA: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
ClientVerify: config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caPath,
|
||||
VerifyDepth: 3,
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 验证客户端验证器已初始化
|
||||
manager.mu.RLock()
|
||||
cv := manager.clientVerifier
|
||||
manager.mu.RUnlock()
|
||||
|
||||
if cv == nil {
|
||||
t.Error("Expected client verifier to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTLSManager_WithInvalidClientCA 测试无效的客户端 CA
|
||||
func TestNewTLSManager_WithInvalidClientCA(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
cert, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, key, 0o600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
ClientVerify: config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: "/nonexistent/ca.pem",
|
||||
},
|
||||
}
|
||||
|
||||
// 客户端验证配置失败不阻止 TLS 工作
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() should not fail for invalid client CA: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 客户端验证器应未初始化
|
||||
manager.mu.RLock()
|
||||
cv := manager.clientVerifier
|
||||
manager.mu.RUnlock()
|
||||
|
||||
if cv != nil {
|
||||
t.Error("Expected client verifier to be nil for invalid CA")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTLSManager_WithOCSPAndIssuer 测试带颁发者证书的 OCSP
|
||||
func TestNewTLSManager_WithOCSPAndIssuer(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
keyPath := filepath.Join(tmpDir, "key.pem")
|
||||
chainPath := filepath.Join(tmpDir, "chain.pem")
|
||||
|
||||
// 生成带 OCSP 服务器的证书
|
||||
certPEM, keyPEM := generateTestCertWithOCSP(t, []string{"http://ocsp.example.com"})
|
||||
if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
|
||||
t.Fatalf("Failed to write key: %v", err)
|
||||
}
|
||||
|
||||
// 生成证书链(颁发者证书)
|
||||
chainCert, _ := generateTestCert(t)
|
||||
if err := os.WriteFile(chainPath, chainCert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write chain: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.SSLConfig{
|
||||
Cert: certPath,
|
||||
Key: keyPath,
|
||||
CertChain: chainPath,
|
||||
OCSPStapling: true,
|
||||
}
|
||||
|
||||
manager, err := NewTLSManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTLSManager() failed: %v", err)
|
||||
}
|
||||
defer manager.Close()
|
||||
|
||||
// 验证 OCSP 管理器已初始化
|
||||
manager.mu.RLock()
|
||||
om := manager.ocspManager
|
||||
manager.mu.RUnlock()
|
||||
|
||||
if om == nil {
|
||||
t.Error("Expected OCSP manager to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user