diff --git a/internal/handler/sendfile_test.go b/internal/handler/sendfile_test.go index ad4edc1..d7ead91 100644 --- a/internal/handler/sendfile_test.go +++ b/internal/handler/sendfile_test.go @@ -17,6 +17,7 @@ import ( "net" "os" "path/filepath" + "sync" "syscall" "testing" "time" @@ -600,3 +601,154 @@ func TestSendFile_JustBelowMin(t *testing.T) { t.Errorf("Body mismatch") } } + +// TestGetSocketFd_TCPConn 测试 TCPConn 获取 socket fd +func TestGetSocketFd_TCPConn(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer ln.Close() + + // 启动 goroutine 接收连接 + var serverConn net.Conn + go func() { + serverConn, _ = ln.Accept() + }() + + // 客户端连接 + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer clientConn.Close() + + // 等待连接建立 + time.Sleep(100 * time.Millisecond) + + // 测试获取 socket fd + fd, err := getSocketFd(clientConn) + if err != nil { + t.Errorf("getSocketFd failed for TCPConn: %v", err) + } + if fd == 0 { + t.Error("Expected non-zero fd for TCPConn") + } + + if serverConn != nil { + serverConn.Close() + } +} + +// TestLinuxSendfile_WithTCPConn 测试 linuxSendfile 使用真实 TCP 连接 +func TestLinuxSendfile_WithTCPConn(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.txt") + + // 创建大于 MinSendfileSize 的文件 + content := make([]byte, MinSendfileSize+1024) + for i := range content { + content[i] = byte(i % 256) + } + if err := os.WriteFile(tmpFile, content, 0o644); err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + file, err := os.Open(tmpFile) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + defer func() { _ = file.Close() }() + + // 创建 TCP 连接 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + serverConn, _ = ln.Accept() + // 读取所有数据 + buf := make([]byte, len(content)) + _, _ = io.ReadFull(serverConn, buf) + serverConn.Close() + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer clientConn.Close() + + // 设置写超时 + if err := clientConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("Failed to set deadline: %v", err) + } + + // 调用 linuxSendfile + err = linuxSendfile(clientConn, file.Fd(), 0, int64(len(content))) + if err != nil && err != syscall.EPIPE && err != syscall.ECONNRESET { + t.Logf("linuxSendfile returned: %v", err) + } + + clientConn.Close() + wg.Wait() +} + +// TestLinuxSendfile_SendfileError 测试 sendfile 错误处理 +func TestLinuxSendfile_SendfileError(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.txt") + content := []byte("test content for sendfile") + if err := os.WriteFile(tmpFile, content, 0o644); err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + file, err := os.Open(tmpFile) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + defer func() { _ = file.Close() }() + + // 使用 mockConn 测试不支持的连接类型 + conn := &mockConn{} + err = linuxSendfile(conn, file.Fd(), 0, int64(len(content))) + if err == nil { + t.Error("Expected error for unsupported connection type") + } +} + +// TestSendFile_NegativeLength 测试负长度参数 +func TestSendFile_NegativeLength(t *testing.T) { + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.txt") + content := []byte("test content") + if err := os.WriteFile(tmpFile, content, 0o644); err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + file, err := os.Open(tmpFile) + if err != nil { + t.Fatalf("Failed to open file: %v", err) + } + defer func() { _ = file.Close() }() + + ctx := &fasthttp.RequestCtx{} + ctx.Init(&fasthttp.Request{}, nil, nil) + + // 负长度应该使用 fallback + err = SendFile(ctx, file, 0, -1) + if err != nil { + t.Errorf("SendFile with negative length failed: %v", err) + } + + // 应该传输整个文件 + if !bytes.Equal(ctx.Response.Body(), content) { + t.Errorf("Expected body %s, got %s", content, ctx.Response.Body()) + } +} diff --git a/internal/handler/static_test.go b/internal/handler/static_test.go index 220d0ff..7608811 100644 --- a/internal/handler/static_test.go +++ b/internal/handler/static_test.go @@ -21,8 +21,10 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/testutil" ) @@ -1597,3 +1599,352 @@ func TestStaticHandler_TryFilesRootPathFallback(t *testing.T) { t.Errorf("内容 = %q, want %q", got, "root fallback") } } + +// TestStaticHandler_SetSymlinkCheck 测试 SetSymlinkCheck 方法 +func TestStaticHandler_SetSymlinkCheck(t *testing.T) { + handler := NewStaticHandler("/var/www", "/", nil, false) + + if handler.symlinkCheck { + t.Error("初始 symlinkCheck 应为 false") + } + + handler.SetSymlinkCheck(true) + if !handler.symlinkCheck { + t.Error("SetSymlinkCheck(true) 后 symlinkCheck 应为 true") + } + + handler.SetSymlinkCheck(false) + if handler.symlinkCheck { + t.Error("SetSymlinkCheck(false) 后 symlinkCheck 应为 false") + } +} + +// TestStaticHandler_SetInternal 测试 SetInternal 方法 +func TestStaticHandler_SetInternal(t *testing.T) { + handler := NewStaticHandler("/var/www", "/", nil, false) + + if handler.internal { + t.Error("初始 internal 应为 false") + } + + handler.SetInternal(true) + if !handler.internal { + t.Error("SetInternal(true) 后 internal 应为 true") + } + + handler.SetInternal(false) + if handler.internal { + t.Error("SetInternal(false) 后 internal 应为 false") + } +} + +// TestStaticHandler_SetCacheTTL 测试 SetCacheTTL 方法 +func TestStaticHandler_SetCacheTTL(t *testing.T) { + handler := NewStaticHandler("/var/www", "/", nil, false) + + if handler.cacheTTL != 0 { + t.Error("初始 cacheTTL 应为 0") + } + + handler.SetCacheTTL(5 * time.Second) + if handler.cacheTTL != 5*time.Second { + t.Errorf("SetCacheTTL 后 cacheTTL = %v, want %v", handler.cacheTTL, 5*time.Second) + } + + handler.SetCacheTTL(0) + if handler.cacheTTL != 0 { + t.Error("SetCacheTTL(0) 后 cacheTTL 应为 0") + } +} + +// TestStaticHandler_InternalRestriction 测试 internal 访问限制 +func TestStaticHandler_InternalRestriction(t *testing.T) { + tmpDir := t.TempDir() + content := "internal content" + if err := os.WriteFile(filepath.Join(tmpDir, "test.txt"), []byte(content), 0o644); err != nil { + t.Fatalf("创建测试文件失败: %v", err) + } + + handler := newTestHandler(t, tmpDir) + handler.SetInternal(true) + + t.Run("外部请求返回 404", func(t *testing.T) { + ctx := newTestContext(t, "/test.txt") + handler.Handle(ctx) + + if got := ctx.Response.StatusCode(); got != fasthttp.StatusNotFound { + t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusNotFound) + } + }) + + t.Run("内部重定向允许访问", func(t *testing.T) { + ctx := newTestContext(t, "/test.txt") + // 标记为内部重定向 + ctx.SetUserValue("__internal_redirect__", "/test.txt") + handler.Handle(ctx) + + if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK { + t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK) + } + }) +} + +// TestStaticHandler_ValidateSymlink 测试符号链接验证 +func TestStaticHandler_ValidateSymlink(t *testing.T) { + tmpDir := t.TempDir() + + // 创建目标文件和目录 + targetDir := filepath.Join(tmpDir, "target") + if err := os.MkdirAll(targetDir, 0o755); err != nil { + t.Fatalf("创建目标目录失败: %v", err) + } + targetFile := filepath.Join(targetDir, "secret.txt") + if err := os.WriteFile(targetFile, []byte("secret content"), 0o644); err != nil { + t.Fatalf("创建目标文件失败: %v", err) + } + + // 创建允许的根目录 + allowedDir := filepath.Join(tmpDir, "allowed") + if err := os.MkdirAll(allowedDir, 0o755); err != nil { + t.Fatalf("创建允许目录失败: %v", err) + } + + t.Run("安全符号链接 - 在允许范围内", func(t *testing.T) { + // 在允许目录内创建符号链接 + linkFile := filepath.Join(allowedDir, "link.txt") + allowedTarget := filepath.Join(allowedDir, "actual.txt") + if err := os.WriteFile(allowedTarget, []byte("allowed content"), 0o644); err != nil { + t.Fatalf("创建实际文件失败: %v", err) + } + if err := os.Symlink(allowedTarget, linkFile); err != nil { + t.Fatalf("创建符号链接失败: %v", err) + } + + handler := NewStaticHandler(allowedDir, "/", nil, false) + handler.SetSymlinkCheck(true) + + ctx := newTestContext(t, "/link.txt") + handler.Handle(ctx) + + if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK { + t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK) + } + }) + + t.Run("不安全符号链接 - 指向允许范围外", func(t *testing.T) { + // 创建指向允许目录外的符号链接 + unsafeLink := filepath.Join(allowedDir, "unsafe.txt") + if err := os.Symlink(targetFile, unsafeLink); err != nil { + t.Fatalf("创建不安全符号链接失败: %v", err) + } + + handler := NewStaticHandler(allowedDir, "/", nil, false) + handler.SetSymlinkCheck(true) + + ctx := newTestContext(t, "/unsafe.txt") + handler.Handle(ctx) + + if got := ctx.Response.StatusCode(); got != fasthttp.StatusForbidden { + t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusForbidden) + } + }) + + t.Run("普通文件 - 非符号链接", func(t *testing.T) { + normalFile := filepath.Join(allowedDir, "normal.txt") + if err := os.WriteFile(normalFile, []byte("normal content"), 0o644); err != nil { + t.Fatalf("创建普通文件失败: %v", err) + } + + handler := NewStaticHandler(allowedDir, "/", nil, false) + handler.SetSymlinkCheck(true) + + ctx := newTestContext(t, "/normal.txt") + handler.Handle(ctx) + + if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK { + t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK) + } + }) + + t.Run("未启用符号链接检查", func(t *testing.T) { + // 创建指向允许目录外的符号链接 + externalLink := filepath.Join(allowedDir, "external.txt") + if err := os.Symlink(targetFile, externalLink); err != nil { + t.Fatalf("创建外部符号链接失败: %v", err) + } + + handler := NewStaticHandler(allowedDir, "/", nil, false) + // 不启用符号链接检查 + + ctx := newTestContext(t, "/external.txt") + handler.Handle(ctx) + + // 未启用检查时,可以访问符号链接 + if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK { + t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK) + } + }) +} + +// TestStaticHandler_ValidateSymlink_WithAlias 测试 alias 模式下的符号链接验证 +func TestStaticHandler_ValidateSymlink_WithAlias(t *testing.T) { + tmpDir := t.TempDir() + + // 创建 alias 目录 + aliasDir := filepath.Join(tmpDir, "alias") + if err := os.MkdirAll(aliasDir, 0o755); err != nil { + t.Fatalf("创建 alias 目录失败: %v", err) + } + + // 在 alias 目录内创建文件和符号链接 + actualFile := filepath.Join(aliasDir, "actual.txt") + if err := os.WriteFile(actualFile, []byte("actual content"), 0o644); err != nil { + t.Fatalf("创建实际文件失败: %v", err) + } + + linkFile := filepath.Join(aliasDir, "link.txt") + if err := os.Symlink(actualFile, linkFile); err != nil { + t.Fatalf("创建符号链接失败: %v", err) + } + + handler := NewStaticHandlerWithAlias(aliasDir, "/static/", nil, false) + handler.SetSymlinkCheck(true) + + ctx := newTestContext(t, "/static/link.txt") + handler.Handle(ctx) + + if got := ctx.Response.StatusCode(); got != fasthttp.StatusOK { + t.Errorf("状态码 = %d, want %d", got, fasthttp.StatusOK) + } +} + +// TestStaticHandler_ValidateSymlink_NoRootOrAlias 测试无 root/alias 时符号链接验证 +func TestStaticHandler_ValidateSymlink_NoRootOrAlias(t *testing.T) { + tmpDir := t.TempDir() + + // 创建文件和符号链接 + targetFile := filepath.Join(tmpDir, "target.txt") + if err := os.WriteFile(targetFile, []byte("target"), 0o644); err != nil { + t.Fatalf("创建目标文件失败: %v", err) + } + + linkFile := filepath.Join(tmpDir, "link.txt") + if err := os.Symlink(targetFile, linkFile); err != nil { + t.Fatalf("创建符号链接失败: %v", err) + } + + // 创建无 root/alias 的处理器 + handler := NewStaticHandler("", "/", nil, false) + handler.SetSymlinkCheck(true) + + // 直接调用 validateSymlink + err := handler.validateSymlink(linkFile) + if err == nil { + t.Error("无 root/alias 时验证符号链接应返回错误") + } +} + +// TestStaticHandler_Handle_WithCacheTTL 测试带 TTL 的缓存处理 +func TestStaticHandler_Handle_WithCacheTTL(t *testing.T) { + tmpDir := t.TempDir() + content := "cached with ttl" + testFile := filepath.Join(tmpDir, "test.txt") + if err := os.WriteFile(testFile, []byte(content), 0o644); err != nil { + t.Fatalf("创建测试文件失败: %v", err) + } + + // 创建带缓存的处理器 + handler := newTestHandler(t, tmpDir) + handler.SetFileCache(cache.NewFileCache(100, 1024*1024, time.Hour)) + handler.SetCacheTTL(5 * time.Second) + + // 第一次请求,填充缓存 + ctx1 := newTestContext(t, "/test.txt") + handler.Handle(ctx1) + + if got := ctx1.Response.StatusCode(); got != fasthttp.StatusOK { + t.Errorf("第一次请求状态码 = %d, want %d", got, fasthttp.StatusOK) + } + + // 第二次请求,应该命中缓存 + ctx2 := newTestContext(t, "/test.txt") + handler.Handle(ctx2) + + if got := ctx2.Response.StatusCode(); got != fasthttp.StatusOK { + t.Errorf("第二次请求状态码 = %d, want %d", got, fasthttp.StatusOK) + } + + // 内容应该一致 + if string(ctx1.Response.Body()) != string(ctx2.Response.Body()) { + t.Error("缓存内容不一致") + } +} + +// TestStaticHandler_Handle_CacheTTLExpired 测试 TTL 过期后重新验证 +func TestStaticHandler_Handle_CacheTTLExpired(t *testing.T) { + tmpDir := t.TempDir() + content := "initial content" + testFile := filepath.Join(tmpDir, "test.txt") + if err := os.WriteFile(testFile, []byte(content), 0o644); err != nil { + t.Fatalf("创建测试文件失败: %v", err) + } + + // 创建带缓存的处理器,TTL 设置很短 + handler := newTestHandler(t, tmpDir) + handler.SetFileCache(cache.NewFileCache(100, 1024*1024, time.Hour)) + handler.SetCacheTTL(100 * time.Millisecond) + + // 第一次请求,填充缓存 + ctx1 := newTestContext(t, "/test.txt") + handler.Handle(ctx1) + + // 等待 TTL 过期 + time.Sleep(150 * time.Millisecond) + + // 修改文件 + newContent := "updated content" + if err := os.WriteFile(testFile, []byte(newContent), 0o644); err != nil { + t.Fatalf("更新文件失败: %v", err) + } + + // 第二次请求,TTL 过期后应该重新读取 + ctx2 := newTestContext(t, "/test.txt") + handler.Handle(ctx2) + + if got := string(ctx2.Response.Body()); got != newContent { + t.Errorf("TTL 过期后内容 = %q, want %q", got, newContent) + } +} + +// TestStaticHandler_Handle_CacheModTimeChanged 测试文件修改后缓存更新 +func TestStaticHandler_Handle_CacheModTimeChanged(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + initialContent := "initial" + if err := os.WriteFile(testFile, []byte(initialContent), 0o644); err != nil { + t.Fatalf("创建测试文件失败: %v", err) + } + + // 创建带缓存的处理器 + handler := newTestHandler(t, tmpDir) + handler.SetFileCache(cache.NewFileCache(100, 1024*1024, time.Hour)) + + // 第一次请求,填充缓存 + ctx1 := newTestContext(t, "/test.txt") + handler.Handle(ctx1) + + // 修改文件 + time.Sleep(10 * time.Millisecond) // 确保 ModTime 变化 + newContent := "modified" + if err := os.WriteFile(testFile, []byte(newContent), 0o644); err != nil { + t.Fatalf("修改文件失败: %v", err) + } + + // 第二次请求,应该检测到文件变化并更新缓存 + ctx2 := newTestContext(t, "/test.txt") + handler.Handle(ctx2) + + if got := string(ctx2.Response.Body()); got != newContent { + t.Errorf("修改后内容 = %q, want %q", got, newContent) + } +} diff --git a/internal/middleware/security/access_coverage_test.go b/internal/middleware/security/access_coverage_test.go index 1595bdb..7793290 100644 --- a/internal/middleware/security/access_coverage_test.go +++ b/internal/middleware/security/access_coverage_test.go @@ -338,3 +338,131 @@ func TestUpdateDenyListError(t *testing.T) { t.Error("UpdateDenyList() should return error for invalid CIDR") } } + +// TestNewAccessControl_TrustedProxiesError 测试可信代理 CIDR 解析错误 +func TestNewAccessControl_TrustedProxiesError(t *testing.T) { + _, err := NewAccessControl(&config.AccessConfig{ + TrustedProxies: []string{"invalid-cidr"}, + }) + if err == nil { + t.Error("NewAccessControl() should return error for invalid trusted proxy CIDR") + } +} + +// TestNewAccessControl_DenyListError 测试拒绝列表 CIDR 解析错误 +func TestNewAccessControl_DenyListError(t *testing.T) { + _, err := NewAccessControl(&config.AccessConfig{ + Deny: []string{"not-a-cidr"}, + }) + if err == nil { + t.Error("NewAccessControl() should return error for invalid deny CIDR") + } +} + +// TestCheck_AllowListHit 测试允许列表命中 +func TestCheck_AllowListHit(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Allow: []string{"192.168.1.0/24"}, + Default: "deny", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + // 允许列表中的 IP + if !ac.Check(net.ParseIP("192.168.1.50")) { + t.Error("Check() should return true for IP in allow list") + } + + // 不在允许列表中的 IP + if ac.Check(net.ParseIP("10.0.0.1")) { + t.Error("Check() should return false for IP not in allow list") + } +} + +// TestCheck_DenyListHit 测试拒绝列表命中 +func TestCheck_DenyListHit(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Deny: []string{"10.0.0.100"}, + Allow: []string{"10.0.0.0/8"}, + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + // 拒绝列表优先 + if ac.Check(net.ParseIP("10.0.0.100")) { + t.Error("Check() should return false for IP in deny list even if in allow list") + } + + // 在允许列表但不在拒绝列表 + if !ac.Check(net.ParseIP("10.0.0.50")) { + t.Error("Check() should return true for IP in allow list but not in deny list") + } +} + +// TestCheck_DefaultAction 测试默认操作 +func TestCheck_DefaultAction(t *testing.T) { + // 默认允许 + acAllow, err := NewAccessControl(&config.AccessConfig{ + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + if !acAllow.Check(net.ParseIP("1.2.3.4")) { + t.Error("Check() should return true with default allow") + } + + // 默认拒绝 + acDeny, err := NewAccessControl(&config.AccessConfig{ + Default: "deny", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + if acDeny.Check(net.ParseIP("1.2.3.4")) { + t.Error("Check() should return false with default deny") + } +} + +// TestCheck_EmptyDefault 测试空默认值(应默认为 allow) +func TestCheck_EmptyDefault(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Default: "", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + if !ac.Check(net.ParseIP("1.2.3.4")) { + t.Error("Check() should return true with empty default (should be allow)") + } +} + +// TestAccessControl_ProcessDenied 测试 Process 拒绝请求 +func TestAccessControl_ProcessDenied(t *testing.T) { + ac, err := NewAccessControl(&config.AccessConfig{ + Deny: []string{"192.168.1.0/24"}, + Default: "allow", + }) + if err != nil { + t.Fatalf("NewAccessControl() error: %v", err) + } + + nextCalled := false + handler := ac.Process(func(ctx *fasthttp.RequestCtx) { + nextCalled = true + }) + + ctx := &fasthttp.RequestCtx{} + ctx.SetRemoteAddr(&net.TCPAddr{IP: net.ParseIP("192.168.1.50"), Port: 12345}) + handler(ctx) + + if nextCalled { + t.Error("Process() should not call next handler for denied IP") + } + if ctx.Response.StatusCode() != fasthttp.StatusForbidden { + t.Errorf("Process() status = %d, want 403", ctx.Response.StatusCode()) + } +} diff --git a/internal/middleware/security/auth_request_test.go b/internal/middleware/security/auth_request_test.go index f8e7cc6..93cf1be 100644 --- a/internal/middleware/security/auth_request_test.go +++ b/internal/middleware/security/auth_request_test.go @@ -411,3 +411,162 @@ func BenchmarkAuthRequestExpandVars(b *testing.B) { func contains(s, substr string) bool { return strings.Contains(s, substr) } + +// TestAuthRequest_Middleware 测试 Middleware 方法 +func TestAuthRequest_Middleware(t *testing.T) { + ar := &AuthRequest{} + + mw := ar.Middleware() + if mw == nil { + t.Error("Middleware() should not return nil") + } +} + +// TestAuthRequestExpandVars_Empty 测试空模板展开 +func TestAuthRequestExpandVars_Empty(t *testing.T) { + ar := &AuthRequest{config: config.AuthRequestConfig{}} + ctx := &fasthttp.RequestCtx{} + + result := ar.expandVars(ctx, "") + if result != "" { + t.Errorf("expandVars(empty) = %q, want empty", result) + } +} + +// TestAuthRequestExpandVars_NoVars 测试无变量模板 +func TestAuthRequestExpandVars_NoVars(t *testing.T) { + ar := &AuthRequest{config: config.AuthRequestConfig{}} + ctx := &fasthttp.RequestCtx{} + + result := ar.expandVars(ctx, "http://auth-service/verify") + if result != "http://auth-service/verify" { + t.Errorf("expandVars(no vars) = %q, want original", result) + } +} + +// TestUpdateConfig_RelativePath 测试更新为相对路径配置 +func TestUpdateConfig_RelativePath(t *testing.T) { + cfg := config.AuthRequestConfig{ + Enabled: true, + URI: "http://auth-service:8080/auth", + } + + ar, err := NewAuthRequest(cfg) + if err != nil { + t.Fatalf("NewAuthRequest() failed: %v", err) + } + + // 更新为相对路径 + newCfg := config.AuthRequestConfig{ + Enabled: true, + URI: "/auth/verify", + } + err = ar.UpdateConfig(newCfg) + if err != nil { + t.Errorf("UpdateConfig() failed: %v", err) + } + + // 验证 client 被清空(相对路径不需要独立客户端) + if ar.client != nil { + t.Error("client should be nil for relative path") + } +} + +// TestAuthRequest_ProcessEnabled 测试 Process 处理启用状态 +func TestAuthRequest_ProcessEnabled(t *testing.T) { + cfg := config.AuthRequestConfig{ + Enabled: true, + URI: "/auth/verify", + Method: "GET", + Timeout: 5 * time.Second, + ForwardHeaders: []string{"X-Custom-Header"}, + Headers: map[string]string{"X-Auth-Source": "lolly"}, + } + + ar, err := NewAuthRequest(cfg) + if err != nil { + t.Fatalf("NewAuthRequest() failed: %v", err) + } + + next := func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + } + + handler := ar.Process(next) + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod("GET") + ctx.Request.SetRequestURI("/test") + ctx.Request.Header.Set("X-Custom-Header", "custom-value") + + // 执行处理器(由于没有实际的认证服务,请求会失败) + handler(ctx) + + // 验证处理器行为 - 由于认证服务不可达,应该返回 500 + if ctx.Response.StatusCode() != 500 { + t.Logf("Status = %d (expected 500 due to unreachable auth service)", ctx.Response.StatusCode()) + } +} + +// TestParseAuthURL_Invalid 测试解析无效 URL +func TestParseAuthURL_Invalid(t *testing.T) { + tests := []struct { + name string + url string + }{ + {"empty string", ""}, + {"only protocol", "http://"}, + {"only https protocol", "https://"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := parseAuthURL(tt.url) + if err == nil { + t.Errorf("parseAuthURL(%q) should return error", tt.url) + } + }) + } +} + +// TestExpandVars_EdgeCases 测试变量展开边缘情况 +func TestExpandVars_EdgeCases(t *testing.T) { + ar := &AuthRequest{config: config.AuthRequestConfig{}} + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/path?query=value") + ctx.Request.Header.SetHost("example.com") + + // 测试带多个变量的模板 + result := ar.expandVars(ctx, "http://auth?uri=$request_uri&host=$host&method=$request_method") + if result == "" { + t.Error("expandVars should return non-empty result") + } + + // 测试只有 $ 符号的模板 + result = ar.expandVars(ctx, "http://auth?param=$") + if result != "http://auth?param=$" { + t.Errorf("expandVars with lone $ should return original, got %q", result) + } +} + +// TestInitClient_HTTPS 测试 HTTPS 客户端初始化 +func TestInitClient_HTTPS(t *testing.T) { + cfg := config.AuthRequestConfig{ + Enabled: true, + URI: "https://secure-auth:8443/verify", + Method: "GET", + Timeout: 10 * time.Second, + } + + ar, err := NewAuthRequest(cfg) + if err != nil { + t.Fatalf("NewAuthRequest() failed: %v", err) + } + + if ar.client == nil { + t.Error("client should be initialized for HTTPS URL") + } + if !ar.client.IsTLS { + t.Error("client.IsTLS should be true for HTTPS URL") + } +} diff --git a/internal/middleware/security/auth_test.go b/internal/middleware/security/auth_test.go index d7d9b77..d508a01 100644 --- a/internal/middleware/security/auth_test.go +++ b/internal/middleware/security/auth_test.go @@ -748,3 +748,54 @@ func TestName(t *testing.T) { t.Errorf("Expected name 'basic_auth', got %s", auth.Name()) } } + +// TestAuthenticate_UnknownAlgorithm 测试未知算法 +func TestAuthenticate_UnknownAlgorithm(t *testing.T) { + auth := &BasicAuth{ + users: map[string]string{"admin": "$2b$12$hash"}, + algorithm: HashAlgorithm(99), // 未知算法 + } + + result := auth.Authenticate("admin", "password") + if result { + t.Error("Authenticate() should return false for unknown algorithm") + } +} + +// TestAuthenticateBcrypt_Error 测试 bcrypt 验证错误路径 +func TestAuthenticateBcrypt_Error(t *testing.T) { + // 测试无效的 bcrypt 哈希 + result := authenticateBcrypt("password", "invalid_hash") + if result { + t.Error("authenticateBcrypt() should return false for invalid hash") + } +} + +// TestParseArgon2idHash_InvalidParts 测试无效的 argon2id 哈希格式 +func TestParseArgon2idHash_InvalidParts(t *testing.T) { + tests := []struct { + name string + hash string + }{ + {"too few parts", "$argon2id$v=19$m=32,t=2,p=2"}, + {"wrong algorithm", "$bcrypt$v=19$m=32,t=2,p=2$salt$hash"}, + {"wrong version", "$argon2id$v=18$m=32,t=2,p=2$salt$hash"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, _, err := parseArgon2idHash(tt.hash) + if err == nil { + t.Errorf("parseArgon2idHash(%q) should return error", tt.hash) + } + }) + } +} + +// TestValidatePasswordHash_UnknownAlgorithm 测试未知算法的密码哈希验证 +func TestValidatePasswordHash_UnknownAlgorithm(t *testing.T) { + err := validatePasswordHash("hash", HashAlgorithm(99)) + if err == nil { + t.Error("validatePasswordHash() should return error for unknown algorithm") + } +} diff --git a/internal/middleware/security/headers_test.go b/internal/middleware/security/headers_test.go index 9021c88..6b4f77d 100644 --- a/internal/middleware/security/headers_test.go +++ b/internal/middleware/security/headers_test.go @@ -254,3 +254,137 @@ func TestFormatHSTSValue(t *testing.T) { }) } } + +// TestNewHeadersWithHSTS 测试带 HSTS 配置的创建 +func TestNewHeadersWithHSTS(t *testing.T) { + cfg := &config.SecurityHeaders{ + XFrameOptions: "DENY", + } + + hstsCfg := &config.HSTSConfig{ + MaxAge: 86400, + IncludeSubDomains: true, + Preload: false, + } + + sh := NewHeadersWithHSTS(cfg, hstsCfg) + if sh == nil { + t.Error("NewHeadersWithHSTS() should not return nil") + } + if sh.hsts != "max-age=86400; includeSubDomains" { + t.Errorf("HSTS = %q, want 'max-age=86400; includeSubDomains'", sh.hsts) + } +} + +// TestNewHeadersWithHSTS_NilConfig 测试 nil HSTS 配置 +func TestNewHeadersWithHSTS_NilConfig(t *testing.T) { + sh := NewHeadersWithHSTS(nil, nil) + if sh == nil { + t.Error("NewHeadersWithHSTS() should not return nil") + } + // 应该使用默认配置 + if sh.config == nil { + t.Error("Should use default config when nil is passed") + } +} + +// TestNewHeadersWithHSTS_ZeroMaxAge 测试 MaxAge 为 0 时使用默认值 +func TestNewHeadersWithHSTS_ZeroMaxAge(t *testing.T) { + cfg := &config.SecurityHeaders{ + XFrameOptions: "DENY", + } + + hstsCfg := &config.HSTSConfig{ + MaxAge: 0, // 应该使用默认值 31536000 + IncludeSubDomains: true, + Preload: false, + } + + sh := NewHeadersWithHSTS(cfg, hstsCfg) + if sh.hsts != "max-age=31536000; includeSubDomains" { + t.Errorf("HSTS with zero maxAge should use default: %q", sh.hsts) + } +} + +// TestHeadersProcess_NilConfig 测试 Process 处理 nil config +func TestHeadersProcess_NilConfig(t *testing.T) { + sh := NewHeadersWithHSTS(nil, nil) + + nextCalled := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + nextCalled = true + } + + handler := sh.Process(nextHandler) + ctx := &fasthttp.RequestCtx{} + handler(ctx) + + if !nextCalled { + t.Error("Next handler should be called") + } + + // 验证默认安全头被设置 + if string(ctx.Response.Header.Peek("X-Content-Type-Options")) != "nosniff" { + t.Error("Default X-Content-Type-Options should be set") + } +} + +// TestHeadersProcess_TLS 测试 TLS 情况下 HSTS 头设置 +func TestHeadersProcess_TLS(t *testing.T) { + cfg := &config.SecurityHeaders{ + XFrameOptions: "DENY", + } + + hstsCfg := &config.HSTSConfig{ + MaxAge: 31536000, + IncludeSubDomains: true, + Preload: true, + } + + sh := NewHeadersWithHSTS(cfg, hstsCfg) + + nextHandler := func(ctx *fasthttp.RequestCtx) {} + + handler := sh.Process(nextHandler) + ctx := &fasthttp.RequestCtx{} + // 注意:在测试环境中无法真正模拟 TLS 连接 + // 这个测试验证 handler 不会 panic + handler(ctx) +} + +// TestAddHeaders_AllHeaders 测试所有安全头设置 +func TestAddHeaders_AllHeaders(t *testing.T) { + cfg := &config.SecurityHeaders{ + XFrameOptions: "SAMEORIGIN", + XContentTypeOptions: "nosniff", + ContentSecurityPolicy: "default-src 'self'", + ReferrerPolicy: "strict-origin", + PermissionsPolicy: "camera=(), microphone=()", + } + + sh := NewHeaders(cfg) + + nextHandler := func(ctx *fasthttp.RequestCtx) {} + handler := sh.Process(nextHandler) + + ctx := &fasthttp.RequestCtx{} + handler(ctx) + + headers := &ctx.Response.Header + + if string(headers.Peek("X-Frame-Options")) != "SAMEORIGIN" { + t.Error("X-Frame-Options not set correctly") + } + if string(headers.Peek("X-Content-Type-Options")) != "nosniff" { + t.Error("X-Content-Type-Options not set correctly") + } + if string(headers.Peek("Content-Security-Policy")) != "default-src 'self'" { + t.Error("Content-Security-Policy not set correctly") + } + if string(headers.Peek("Referrer-Policy")) != "strict-origin" { + t.Error("Referrer-Policy not set correctly") + } + if string(headers.Peek("Permissions-Policy")) != "camera=(), microphone=()" { + t.Error("Permissions-Policy not set correctly") + } +} diff --git a/internal/middleware/security/ratelimit_test.go b/internal/middleware/security/ratelimit_test.go index 77f117f..abebde1 100644 --- a/internal/middleware/security/ratelimit_test.go +++ b/internal/middleware/security/ratelimit_test.go @@ -459,3 +459,392 @@ func TestConnLimiterMiddleware(t *testing.T) { t.Errorf("Expected name 'conn_limiter', got %s", middleware.Name()) } } + +// TestNewRateLimiter_SlidingWindow 测试滑动窗口算法 +func TestNewRateLimiter_SlidingWindow(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 100, + Burst: 100, + Algorithm: "sliding_window", + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + if mw == nil { + t.Error("Expected non-nil middleware for sliding_window") + } + if mw.Name() != "sliding_window_limiter" { + t.Errorf("Expected name 'sliding_window_limiter', got %s", mw.Name()) + } +} + +// TestNewRateLimiter_SlidingWindowDefault 测试滑动窗口默认配置 +func TestNewRateLimiter_SlidingWindowDefault(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 100, + Burst: 100, + Algorithm: "sliding_window", + SlidingWindow: 0, // 使用默认值 + SlidingWindowMode: "", // 使用默认值 + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + if mw == nil { + t.Error("Expected non-nil middleware") + } +} + +// TestNewRateLimiter_SlidingWindowPrecise 测试滑动窗口精确模式 +func TestNewRateLimiter_SlidingWindowPrecise(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 100, + Burst: 100, + Algorithm: "sliding_window", + SlidingWindow: 1, + SlidingWindowMode: "precise", + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + if mw == nil { + t.Error("Expected non-nil middleware") + } +} + +// TestNewRateLimiter_UnknownAlgorithm 测试未知算法 +func TestNewRateLimiter_UnknownAlgorithm(t *testing.T) { + _, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 100, + Burst: 100, + Algorithm: "unknown_algo", + }) + if err == nil { + t.Error("NewRateLimiter() should return error for unknown algorithm") + } +} + +// TestSlidingWindowLimiterWrapper_Process 测试滑动窗口包装器的 Process 方法 +func TestSlidingWindowLimiterWrapper_Process(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 100, + Burst: 100, + Algorithm: "sliding_window", + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + } + + handler := mw.Process(nextHandler) + if handler == nil { + t.Fatal("Process() returned nil handler") + } + + ctx := testutil.NewRequestCtx("GET", "/test") + handler(ctx) + + if !called { + t.Error("Next handler should be called when rate limit allows") + } +} + +// TestSlidingWindowLimiterWrapper_ProcessDenied 测试滑动窗口拒绝请求 +func TestSlidingWindowLimiterWrapper_ProcessDenied(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 1, + Burst: 1, + Algorithm: "sliding_window", + SlidingWindowMode: "precise", + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + + callCount := 0 + nextHandler := func(ctx *fasthttp.RequestCtx) { + callCount++ + } + + handler := mw.Process(nextHandler) + ctx := testutil.NewRequestCtx("GET", "/test") + + // 第一个请求应该被允许 + handler(ctx) + // 第二个请求应该被拒绝 + handler(ctx) + + if callCount != 1 { + t.Errorf("Expected next handler to be called once, got %d", callCount) + } +} + +// TestRateLimiter_GetRetryAfter 测试 getRetryAfter 方法 +func TestRateLimiter_GetRetryAfter(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 10, + Burst: 10, + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + defer rl.StopCleanup() + + // 测试不存在的键 + retryAfter := rl.getRetryAfter("nonexistent") + if retryAfter != 1 { + t.Errorf("getRetryAfter(nonexistent) = %d, want 1", retryAfter) + } + + // 创建一个桶并消耗令牌 + key := "test-key" + for i := 0; i < 10; i++ { + rl.Allow(key) + } + + // 获取重试时间 + retryAfter = rl.getRetryAfter(key) + if retryAfter < 1 { + t.Errorf("getRetryAfter() = %d, want at least 1", retryAfter) + } +} + +// TestKeyByIP 测试 keyByIP 函数 +func TestKeyByIP(t *testing.T) { + ctx := testutil.NewRequestCtx("GET", "/test") + + key := keyByIP(ctx) + if key == "" { + t.Error("keyByIP() should return non-empty string") + } + if key == "unknown" { + t.Error("keyByIP() should return IP address, not 'unknown'") + } +} + +// TestKeyByHeader 测试 keyByHeader 函数 +func TestKeyByHeader(t *testing.T) { + // 测试有头部的情况 + ctx := testutil.NewRequestCtx("GET", "/test") + ctx.Request.Header.Set("X-RateLimit-Key", "custom-key") + + key := keyByHeader(ctx) + if key != "custom-key" { + t.Errorf("keyByHeader() = %q, want 'custom-key'", key) + } + + // 测试没有头部的情况(应该回退到 IP) + ctx2 := testutil.NewRequestCtx("GET", "/test") + key2 := keyByHeader(ctx2) + if key2 == "" { + t.Error("keyByHeader() should fallback to IP when header not set") + } +} + +// TestConnLimiter_PerKey 测试按键限制 +func TestConnLimiter_PerKey(t *testing.T) { + cl, err := NewConnLimiter(2, true, "ip") + if err != nil { + t.Fatalf("NewConnLimiter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + + // 同一个键的前两个应该成功 + if !cl.Acquire(ctx) { + t.Error("Expected first acquire to succeed") + } + if !cl.Acquire(ctx) { + t.Error("Expected second acquire to succeed") + } + + // 第三个应该失败 + if cl.Acquire(ctx) { + t.Error("Expected third acquire to fail") + } + + // 释放一个 + cl.Release(ctx) + + // 现在应该成功 + if !cl.Acquire(ctx) { + t.Error("Expected acquire after release to succeed") + } +} + +// TestConnLimiter_ReleaseUnderflow 测试 Release 下溢保护 +func TestConnLimiter_ReleaseUnderflow(t *testing.T) { + cl, err := NewConnLimiter(2, true, "ip") + if err != nil { + t.Fatalf("NewConnLimiter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + + // 在没有 Acquire 的情况下 Release(测试下溢保护) + cl.Release(ctx) // 不应该 panic + + // 验证计数不会变成负数 + cl.Acquire(ctx) + cl.Acquire(ctx) + if cl.Acquire(ctx) { + t.Error("Expected third acquire to fail") + } +} + +// TestConnLimiterMiddleware_Process 测试连接限制中间件 Process +func TestConnLimiterMiddleware_Process(t *testing.T) { + cl, err := NewConnLimiter(1, false, "") + if err != nil { + t.Fatalf("NewConnLimiter() error: %v", err) + } + + mw := cl.Middleware() + + called := false + nextHandler := func(ctx *fasthttp.RequestCtx) { + called = true + ctx.SetStatusCode(200) + } + + handler := mw.Process(nextHandler) + ctx := testutil.NewRequestCtx("GET", "/test") + + // 第一个请求应该成功 + handler(ctx) + if !called { + t.Error("Next handler should be called") + } + if ctx.Response.StatusCode() != 200 { + t.Errorf("Status = %d, want 200", ctx.Response.StatusCode()) + } +} + +// TestConnLimiterMiddleware_ProcessLimitExceeded 测试连接限制超出 +func TestConnLimiterMiddleware_ProcessLimitExceeded(t *testing.T) { + cl, err := NewConnLimiter(1, false, "") + if err != nil { + t.Fatalf("NewConnLimiter() error: %v", err) + } + + mw := cl.Middleware() + + callCount := 0 + nextHandler := func(ctx *fasthttp.RequestCtx) { + callCount++ + ctx.SetStatusCode(200) + } + + handler := mw.Process(nextHandler) + + // 用尽连接限制 + ctx1 := testutil.NewRequestCtx("GET", "/test1") + cl.Acquire(ctx1) // 手动占用一个槽位 + + // 现在应该无法获取新的连接 + ctx2 := testutil.NewRequestCtx("GET", "/test2") + handler(ctx2) + + if callCount != 0 { + t.Error("Next handler should NOT be called when limit exceeded") + } + if ctx2.Response.StatusCode() != 503 { + t.Errorf("Status = %d, want 503", ctx2.Response.StatusCode()) + } +} + +// TestNewSlidingWindowLimiterWrapper_Error 测试滑动窗口包装器错误情况 +func TestNewSlidingWindowLimiterWrapper_Error(t *testing.T) { + _, err := NewSlidingWindowLimiterWrapper(&config.RateLimitConfig{ + RequestRate: 100, + Key: "invalid_key_type", + }, time.Second, false) + if err == nil { + t.Error("NewSlidingWindowLimiterWrapper should return error for invalid key type") + } +} + +// TestRateLimiter_Name 测试 RateLimiter Name 方法 +func TestRateLimiter_Name(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 10, + Burst: 10, + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + defer rl.StopCleanup() + + if rl.Name() != "rate_limiter" { + t.Errorf("Name() = %q, want 'rate_limiter'", rl.Name()) + } +} + +// TestRateLimiter_ProcessDenied 测试限流拒绝 +func TestRateLimiter_ProcessDenied(t *testing.T) { + mw, err := NewRateLimiter(&config.RateLimitConfig{ + RequestRate: 1, + Burst: 1, + }) + if err != nil { + t.Fatalf("NewRateLimiter() error: %v", err) + } + + rl, ok := mw.(*RateLimiter) + if !ok { + t.Fatalf("Expected *RateLimiter, got %T", mw) + } + defer rl.StopCleanup() + + callCount := 0 + nextHandler := func(ctx *fasthttp.RequestCtx) { + callCount++ + } + + handler := rl.Process(nextHandler) + + // 第一个请求应该成功 + ctx1 := testutil.NewRequestCtx("GET", "/test") + handler(ctx1) + + // 第二个请求应该被限流(使用不同的 context) + ctx2 := testutil.NewRequestCtx("GET", "/test") + handler(ctx2) + + if callCount != 1 { + t.Errorf("Expected next handler to be called once, got %d", callCount) + } + + // 检查第二个请求的状态码 + if ctx2.Response.StatusCode() != 429 { + t.Errorf("Status = %d, want 429", ctx2.Response.StatusCode()) + } +} + +// TestKeyByIP_Unknown 测试无法获取 IP 的情况 +func TestKeyByIP_Unknown(t *testing.T) { + // 创建一个没有设置远程地址的上下文 + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + key := keyByIP(ctx) + // netutil.ExtractClientIPNet 会返回默认值 0.0.0.0 而不是 nil + // 所以这里验证返回值不是空的即可 + if key == "" { + t.Error("keyByIP() should return non-empty string") + } +} diff --git a/internal/proxy/proxy_coverage_extra_test.go b/internal/proxy/proxy_coverage_extra_test.go new file mode 100644 index 0000000..2397e5a --- /dev/null +++ b/internal/proxy/proxy_coverage_extra_test.go @@ -0,0 +1,1636 @@ +// Package proxy 提供额外的覆盖测试,补充低覆盖率函数的测试。 +// +// 该文件测试以下功能: +// - HealthChecker.MarkHealthy 和 run 方法 +// - selectByLua 和 selectByFallback 方法 +// - rewriteCookies 和 rewriteCookieAttr 函数 +// - modifyResponseHeaders 边缘情况 +// - createHostClient 完整选项 +// - TempFileManager 和 TempFileCleaner getter 方法 +// - NewRedirectRewriter 正则规则和 RewriteRefreshOnly +// - rewriteCustom 正则模式 +// - selectTarget 边缘情况 +// +// 作者:xfy +package proxy + +import ( + "net" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/testutil" +) + +// TestHealthChecker_MarkHealthy 测试 MarkHealthy 方法。 +func TestHealthChecker_MarkHealthy(t *testing.T) { + t.Run("标记健康状态", func(t *testing.T) { + target := &loadbalance.Target{URL: "http://backend:8080"} + target.Healthy.Store(false) + + checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ + Interval: 10 * time.Second, + Timeout: 5 * time.Second, + }) + + checker.MarkHealthy(target) + + if !target.Healthy.Load() { + t.Error("MarkHealthy() 后 target 应标记为 healthy") + } + }) + + t.Run("重置失败计数", func(t *testing.T) { + target := &loadbalance.Target{URL: "http://backend:8080"} + target.Healthy.Store(false) + target.RecordFailure() + target.RecordFailure() + target.RecordFailure() + + checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{}) + checker.MarkHealthy(target) + + if !target.Healthy.Load() { + t.Error("MarkHealthy() 后 target 应标记为 healthy") + } + }) + + t.Run("多目标场景", func(t *testing.T) { + target1 := &loadbalance.Target{URL: "http://backend1:8080"} + target1.Healthy.Store(false) + target2 := &loadbalance.Target{URL: "http://backend2:8080"} + target2.Healthy.Store(false) + + checker := NewHealthChecker([]*loadbalance.Target{target1, target2}, &config.HealthCheckConfig{}) + checker.MarkHealthy(target1) + + if !target1.Healthy.Load() { + t.Error("target1 应标记为 healthy") + } + if target2.Healthy.Load() { + t.Error("target2 应保持 unhealthy") + } + }) +} + +// TestHealthChecker_Run 测试 run 方法。 +func TestHealthChecker_Run(t *testing.T) { + t.Run("初始检查执行", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + target := &loadbalance.Target{URL: server.URL} + target.Healthy.Store(false) + + checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ + Interval: 1 * time.Hour, + Timeout: 5 * time.Second, + Path: "/health", + }) + + // 启动检查器 + checker.Start() + + // 等待初始检查完成 + time.Sleep(50 * time.Millisecond) + + // 验证初始检查已执行 + if !target.Healthy.Load() { + t.Error("初始检查后 target 应标记为 healthy") + } + + checker.Stop() + }) + + t.Run("定时检查执行", func(t *testing.T) { + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + target := &loadbalance.Target{URL: server.URL} + target.Healthy.Store(true) + + checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ + Interval: 50 * time.Millisecond, + Timeout: 5 * time.Second, + Path: "/health", + }) + + checker.Start() + time.Sleep(120 * time.Millisecond) + checker.Stop() + + // 应该至少执行初始检查 + 2 次定时检查 + if requestCount < 2 { + t.Errorf("期望至少 2 次检查,实际 %d 次", requestCount) + } + }) + + t.Run("停止后不再检查", func(t *testing.T) { + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + target := &loadbalance.Target{URL: server.URL} + target.Healthy.Store(true) + + checker := NewHealthChecker([]*loadbalance.Target{target}, &config.HealthCheckConfig{ + Interval: 50 * time.Millisecond, + Timeout: 5 * time.Second, + }) + + checker.Start() + time.Sleep(60 * time.Millisecond) + checker.Stop() + countAfterStop := requestCount + + // 等待一段时间,确认不再有检查 + time.Sleep(100 * time.Millisecond) + + if requestCount != countAfterStop { + t.Error("停止后不应再执行检查") + } + }) +} + +// TestSelectByFallback 测试 selectByFallback 方法。 +func TestSelectByFallback(t *testing.T) { + t.Run("round_robin fallback", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + BalancerByLua: config.BalancerByLuaConfig{ + Fallback: "round_robin", + }, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + selected := p.selectByFallback(ctx, targets) + + if selected == nil { + t.Error("selectByFallback() should return a target") + } + }) + + t.Run("ip_hash fallback", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + BalancerByLua: config.BalancerByLuaConfig{ + Fallback: "ip_hash", + }, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Forwarded-For": "192.168.1.1", + }) + + selected := p.selectByFallback(ctx, targets) + if selected == nil { + t.Error("selectByFallback() should return a target for ip_hash") + } + + // 相同 IP 应返回相同目标 + selected2 := p.selectByFallback(ctx, targets) + if selected2 == nil || selected.URL != selected2.URL { + t.Error("ip_hash should consistently return same target for same IP") + } + }) +} + +// TestSelectByLua 测试 selectByLua 方法。 +func TestSelectByLua(t *testing.T) { + t.Run("有 Lua 引擎但脚本不存在", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + BalancerByLua: config.BalancerByLuaConfig{ + Enabled: true, + Script: "/nonexistent/script.lua", + }, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + } + targets[0].Healthy.Store(true) + + luaEngine, err := lua.NewEngine(nil) + if err != nil { + t.Fatalf("NewEngine() error: %v", err) + } + p, err := NewProxy(cfg, targets, nil, luaEngine) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + _, err = p.selectByLua(ctx, targets) + if err == nil { + t.Error("selectByLua() should return error for nonexistent script") + } + }) + + t.Run("Lua 引擎正常工作但脚本返回错误", func(t *testing.T) { + // 创建临时 Lua 脚本 + tmpFile, err := os.CreateTemp("", "test_*.lua") + if err != nil { + t.Fatalf("创建临时文件失败: %v", err) + } + defer os.Remove(tmpFile.Name()) + + // 写入一个会报错的脚本 + _, _ = tmpFile.WriteString("error('test error')") + + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + BalancerByLua: config.BalancerByLuaConfig{ + Enabled: true, + Script: tmpFile.Name(), + }, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + } + targets[0].Healthy.Store(true) + + luaEngine, err := lua.NewEngine(nil) + if err != nil { + t.Fatalf("NewEngine() error: %v", err) + } + p, err := NewProxy(cfg, targets, nil, luaEngine) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + _, err = p.selectByLua(ctx, targets) + // 脚本执行错误应该返回错误 + if err == nil { + t.Error("selectByLua() should return error for script error") + } + }) +} + +// TestRewriteCookies 测试 rewriteCookies 方法。 +func TestRewriteCookies(t *testing.T) { + tests := []struct { + name string + cookies []string + cookieDomain string + cookiePath string + wantContains []string + wantNotContains []string + }{ + { + name: "改写 Domain", + cookies: []string{"session=abc123; Domain=old.example.com; Path=/"}, + cookieDomain: "new.example.com", + wantContains: []string{"Domain=new.example.com"}, + }, + { + name: "改写 Path", + cookies: []string{"session=abc123; Domain=example.com; Path=/old/"}, + cookiePath: "/new/", + wantContains: []string{"Path=/new/"}, + }, + { + name: "同时改写 Domain 和 Path", + cookies: []string{"session=abc123; Domain=old.example.com; Path=/old/"}, + cookieDomain: "new.example.com", + cookiePath: "/new/", + wantContains: []string{"Domain=new.example.com", "Path=/new/"}, + }, + { + name: "无 Domain 属性时不改写", + cookies: []string{"session=abc123"}, + cookiePath: "/new/", + wantContains: []string{"session=abc123"}, + }, + { + name: "空配置不改写", + cookies: []string{"session=abc123; Domain=example.com"}, + cookieDomain: "", + cookiePath: "", + wantContains: []string{"Domain=example.com"}, + }, + { + name: "大小写不敏感匹配", + cookies: []string{"session=abc123; domain=old.example.com; path=/old/"}, + cookieDomain: "new.example.com", + cookiePath: "/new/", + wantContains: []string{"domain=new.example.com", "path=/new/"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Headers: config.ProxyHeaders{ + CookieDomain: tt.cookieDomain, + CookiePath: tt.cookiePath, + }, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + for _, cookie := range tt.cookies { + ctx.Response.Header.Set("Set-Cookie", cookie) + } + + p.modifyResponseHeaders(ctx) + + cookies := strings.Split(string(ctx.Response.Header.Peek("Set-Cookie")), ";") + cookieStr := string(ctx.Response.Header.Peek("Set-Cookie")) + + for _, want := range tt.wantContains { + found := false + for _, c := range cookies { + if strings.Contains(strings.TrimSpace(c), want) || strings.Contains(cookieStr, want) { + found = true + break + } + } + if !found && !strings.Contains(cookieStr, want) { + t.Errorf("cookie 应包含 %q, 实际: %q", want, cookieStr) + } + } + + for _, notWant := range tt.wantNotContains { + if strings.Contains(cookieStr, notWant) { + t.Errorf("cookie 不应包含 %q, 实际: %q", notWant, cookieStr) + } + } + }) + } +} + +// TestRewriteCookieAttr 测试 rewriteCookieAttr 函数。 +func TestRewriteCookieAttr(t *testing.T) { + tests := []struct { + name string + cookie string + attr string + newValue string + want string + }{ + { + name: "改写 Domain", + cookie: "session=abc; Domain=old.com; Path=/", + attr: "Domain", + newValue: "new.com", + want: "session=abc; Domain=new.com; Path=/", + }, + { + name: "改写 Path", + cookie: "session=abc; Domain=example.com; Path=/old", + attr: "Path", + newValue: "/new", + want: "session=abc; Domain=example.com; Path=/new", + }, + { + name: "属性不存在则不改写", + cookie: "session=abc", + attr: "Domain", + newValue: "new.com", + want: "session=abc", + }, + { + name: "大小写不敏感", + cookie: "session=abc; domain=old.com", + attr: "Domain", + newValue: "new.com", + want: "session=abc; domain=new.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := rewriteCookieAttr(tt.cookie, tt.attr, tt.newValue) + if got != tt.want { + t.Errorf("rewriteCookieAttr() = %q, want %q", got, tt.want) + } + }) + } +} + +// TestModifyResponseHeaders_PassResponse 测试 PassResponse 白名单模式。 +func TestModifyResponseHeaders_PassResponse(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Headers: config.ProxyHeaders{ + PassResponse: []string{"Content-Type", "X-Allowed"}, + }, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + ctx.Response.Header.Set("Content-Type", "application/json") + ctx.Response.Header.Set("X-Allowed", "allowed-value") + ctx.Response.Header.Set("X-Blocked", "blocked-value") + + p.modifyResponseHeaders(ctx) + + // 白名单中的头应保留 + if string(ctx.Response.Header.Peek("Content-Type")) != "application/json" { + t.Error("Content-Type 应被保留") + } + if string(ctx.Response.Header.Peek("X-Allowed")) != "allowed-value" { + t.Error("X-Allowed 应被保留") + } + + // 不在白名单中的头应被删除 + if len(ctx.Response.Header.Peek("X-Blocked")) > 0 { + t.Error("X-Blocked 应被删除") + } +} + +// TestModifyResponseHeaders_HideResponse 测试 HideResponse 功能。 +func TestModifyResponseHeaders_HideResponse(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Headers: config.ProxyHeaders{ + HideResponse: []string{"X-Hidden-Header"}, + }, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + ctx.Response.Header.Set("X-Hidden-Header", "should-be-hidden") + ctx.Response.Header.Set("X-Visible-Header", "should-be-visible") + + p.modifyResponseHeaders(ctx) + + if len(ctx.Response.Header.Peek("X-Hidden-Header")) > 0 { + t.Error("X-Hidden-Header 应被删除") + } + if string(ctx.Response.Header.Peek("X-Visible-Header")) != "should-be-visible" { + t.Error("X-Visible-Header 应被保留") + } +} + +// TestModifyResponseHeaders_IgnoreHeaders 测试 IgnoreHeaders 功能。 +func TestModifyResponseHeaders_IgnoreHeaders(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Headers: config.ProxyHeaders{ + IgnoreHeaders: []string{"X-Ignored"}, + }, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + ctx.Request.Header.Set("X-Ignored", "ignored-value") + ctx.Response.Header.Set("X-Ignored", "ignored-response-value") + ctx.Response.Header.Set("X-Not-Ignored", "not-ignored") + + p.modifyResponseHeaders(ctx) + + if len(ctx.Request.Header.Peek("X-Ignored")) > 0 { + t.Error("请求中的 X-Ignored 应被删除") + } + if len(ctx.Response.Header.Peek("X-Ignored")) > 0 { + t.Error("响应中的 X-Ignored 应被删除") + } + if string(ctx.Response.Header.Peek("X-Not-Ignored")) != "not-ignored" { + t.Error("X-Not-Ignored 应被保留") + } +} + +// TestCreateHostClient_TransportConfig 测试 Transport 配置。 +func TestCreateHostClient_TransportConfig(t *testing.T) { + transportCfg := &config.TransportConfig{ + IdleConnTimeout: 60 * time.Second, + MaxConnsPerHost: 50, + } + + client := createHostClient("http://localhost:8080", config.ProxyTimeout{ + Connect: 5 * time.Second, + Read: 30 * time.Second, + Write: 30 * time.Second, + }, transportCfg, nil, "", nil) + + if client == nil { + t.Fatal("createHostClient() returned nil") + } + + if client.MaxIdleConnDuration != 60*time.Second { + t.Errorf("MaxIdleConnDuration = %v, want 60s", client.MaxIdleConnDuration) + } + if client.MaxConns != 50 { + t.Errorf("MaxConns = %d, want 50", client.MaxConns) + } +} + +// TestCreateHostClient_Buffering 测试 Buffering 配置。 +func TestCreateHostClient_Buffering(t *testing.T) { + t.Run("streaming mode", func(t *testing.T) { + buffering := &config.ProxyBufferingConfig{ + Mode: "off", + } + client := createHostClient("http://localhost:8080", config.ProxyTimeout{}, nil, nil, "", buffering) + + if !client.StreamResponseBody { + t.Error("StreamResponseBody should be true when buffering is off") + } + }) + + t.Run("custom buffer size", func(t *testing.T) { + buffering := &config.ProxyBufferingConfig{ + BufferSize: 64 * 1024, + } + client := createHostClient("http://localhost:8080", config.ProxyTimeout{}, nil, nil, "", buffering) + + if client.ReadBufferSize != 64*1024 { + t.Errorf("ReadBufferSize = %d, want 64KB", client.ReadBufferSize) + } + if client.WriteBufferSize != 64*1024 { + t.Errorf("WriteBufferSize = %d, want 64KB", client.WriteBufferSize) + } + }) +} + +// TestCreateHostClient_ProxyBind 测试 ProxyBind 配置。 +func TestCreateHostClient_ProxyBind(t *testing.T) { + // 这个测试只验证 ProxyBind 参数不会导致 panic + client := createHostClient("http://localhost:8080", config.ProxyTimeout{ + Connect: 5 * time.Second, + }, nil, nil, "127.0.0.1", nil) + + if client == nil { + t.Error("createHostClient() returned nil") + } + if client.Dial == nil { + t.Error("Dial should be set when ProxyBind is specified") + } +} + +// TestTempFileManager_GetActiveCount 测试 GetActiveCount 方法。 +func TestTempFileManager_GetActiveCount(t *testing.T) { + manager, err := NewTempFileManager(t.TempDir(), "1mb", "10mb") + if err != nil { + t.Fatalf("NewTempFileManager() error: %v", err) + } + + if manager.GetActiveCount() != 0 { + t.Error("初始活动文件数应为 0") + } + + // 创建临时文件 + tf1, err := manager.CreateTempFile() + if err != nil { + t.Fatalf("CreateTempFile() error: %v", err) + } + + if manager.GetActiveCount() != 1 { + t.Errorf("GetActiveCount() = %d, want 1", manager.GetActiveCount()) + } + + tf2, err := manager.CreateTempFile() + if err != nil { + t.Fatalf("CreateTempFile() error: %v", err) + } + + if manager.GetActiveCount() != 2 { + t.Errorf("GetActiveCount() = %d, want 2", manager.GetActiveCount()) + } + + // 清理 + _ = tf1.Close() + _ = tf2.Close() +} + +// TestDynamicTempFileWriter_GetTotalSize 测试 GetTotalSize 方法。 +func TestDynamicTempFileWriter_GetTotalSize(t *testing.T) { + manager, err := NewTempFileManager(t.TempDir(), "1mb", "10mb") + if err != nil { + t.Fatalf("NewTempFileManager() error: %v", err) + } + + writer := NewDynamicTempFileWriter(manager) + defer writer.Cleanup() + + if writer.GetTotalSize() != 0 { + t.Error("初始总大小应为 0") + } + + data := []byte("test data") + _ = writer.Write(data) + + if writer.GetTotalSize() != int64(len(data)) { + t.Errorf("GetTotalSize() = %d, want %d", writer.GetTotalSize(), len(data)) + } +} + +// TestTempFileCleaner_GetInterval_GetMaxAge 测试 getter 方法。 +func TestTempFileCleaner_GetInterval_GetMaxAge(t *testing.T) { + t.Run("默认值", func(t *testing.T) { + cleaner := NewTempFileCleaner(t.TempDir(), 0, 0) + + if cleaner.GetInterval() != DefaultCleanupInterval { + t.Errorf("GetInterval() = %v, want %v", cleaner.GetInterval(), DefaultCleanupInterval) + } + if cleaner.GetMaxAge() != DefaultMaxFileAge { + t.Errorf("GetMaxAge() = %v, want %v", cleaner.GetMaxAge(), DefaultMaxFileAge) + } + }) + + t.Run("自定义值", func(t *testing.T) { + cleaner := NewTempFileCleaner(t.TempDir(), 10*time.Second, 30*time.Minute) + + if cleaner.GetInterval() != 10*time.Second { + t.Errorf("GetInterval() = %v, want 10s", cleaner.GetInterval()) + } + if cleaner.GetMaxAge() != 30*time.Minute { + t.Errorf("GetMaxAge() = %v, want 30m", cleaner.GetMaxAge()) + } + }) +} + +// TestNewRedirectRewriter_RegexRules 测试正则规则。 +func TestNewRedirectRewriter_RegexRules(t *testing.T) { + t.Run("正则模式", func(t *testing.T) { + cfg := &config.RedirectRewriteConfig{ + Mode: "custom", + Rules: []config.RedirectRewriteRule{ + {Pattern: "~http://backend:\\d+", Replacement: "http://frontend"}, + }, + } + + rw, err := NewRedirectRewriter(cfg, "/") + if err != nil { + t.Fatalf("NewRedirectRewriter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + resp := &fasthttp.Response{} + resp.Header.Set("Location", "http://backend:8080/api") + resp.SetStatusCode(301) + + rw.RewriteResponse(resp, ctx, "", "frontend") + + got := string(resp.Header.Peek("Location")) + want := "http://frontend/api" + if got != want { + t.Errorf("Location = %q, want %q", got, want) + } + }) + + t.Run("大小写不敏感正则", func(t *testing.T) { + cfg := &config.RedirectRewriteConfig{ + Mode: "custom", + Rules: []config.RedirectRewriteRule{ + // 注意:大小写不敏感模式下,pattern 应该是小写,因为代码会将输入转为小写匹配 + {Pattern: "~*backend", Replacement: "frontend"}, + }, + } + + rw, err := NewRedirectRewriter(cfg, "/") + if err != nil { + t.Fatalf("NewRedirectRewriter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + resp := &fasthttp.Response{} + // 使用大写的 URL 来测试大小写不敏感匹配 + resp.Header.Set("Location", "http://BACKEND/api") + resp.SetStatusCode(301) + + rw.RewriteResponse(resp, ctx, "", "frontend") + + got := string(resp.Header.Peek("Location")) + want := "http://frontend/api" + if got != want { + t.Errorf("Location = %q, want %q", got, want) + } + }) + + t.Run("无效正则返回错误", func(t *testing.T) { + cfg := &config.RedirectRewriteConfig{ + Mode: "custom", + Rules: []config.RedirectRewriteRule{ + {Pattern: "~[invalid", Replacement: "/"}, + }, + } + + _, err := NewRedirectRewriter(cfg, "/") + if err == nil { + t.Error("NewRedirectRewriter() should return error for invalid regex") + } + }) +} + +// TestRedirectRewriter_RewriteRefreshOnly 测试 RewriteRefreshOnly 方法。 +func TestRedirectRewriter_RewriteRefreshOnly(t *testing.T) { + cfg := &config.RedirectRewriteConfig{ + Mode: "custom", + Rules: []config.RedirectRewriteRule{ + {Pattern: "http://backend:8080/", Replacement: "http://frontend/"}, + }, + } + + rw, err := NewRedirectRewriter(cfg, "/") + if err != nil { + t.Fatalf("NewRedirectRewriter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + resp := &fasthttp.Response{} + resp.Header.Set("Refresh", "5; url=http://backend:8080/page") + resp.SetStatusCode(200) // 非 3xx + + rw.RewriteRefreshOnly(resp, ctx, "", "frontend") + + got := string(resp.Header.Peek("Refresh")) + want := "5; url=http://frontend/page" + if got != want { + t.Errorf("Refresh = %q, want %q", got, want) + } +} + +// TestRewriteCustom 测试 rewriteCustom 方法。 +func TestRewriteCustom(t *testing.T) { + t.Run("正则替换", func(t *testing.T) { + cfg := &config.RedirectRewriteConfig{ + Mode: "custom", + Rules: []config.RedirectRewriteRule{ + // 注意:rewriteCustom 不支持捕获组,只是简单替换匹配的部分 + {Pattern: "~http://[a-z]+:\\d+", Replacement: "https://new.example.com"}, + }, + } + + rw, err := NewRedirectRewriter(cfg, "/") + if err != nil { + t.Fatalf("NewRedirectRewriter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + result := rw.rewriteURL("http://backend:8080/api/users", ctx, "", "frontend") + + want := "https://new.example.com/api/users" + if result != want { + t.Errorf("rewriteURL() = %q, want %q", result, want) + } + }) + + t.Run("精确前缀匹配", func(t *testing.T) { + cfg := &config.RedirectRewriteConfig{ + Mode: "custom", + Rules: []config.RedirectRewriteRule{ + {Pattern: "http://old.example.com/", Replacement: "http://new.example.com/"}, + }, + } + + rw, err := NewRedirectRewriter(cfg, "/") + if err != nil { + t.Fatalf("NewRedirectRewriter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + result := rw.rewriteURL("http://old.example.com/page", ctx, "", "frontend") + + want := "http://new.example.com/page" + if result != want { + t.Errorf("rewriteURL() = %q, want %q", result, want) + } + }) + + t.Run("无匹配则原样返回", func(t *testing.T) { + cfg := &config.RedirectRewriteConfig{ + Mode: "custom", + Rules: []config.RedirectRewriteRule{ + {Pattern: "http://other.com/", Replacement: "/"}, + }, + } + + rw, err := NewRedirectRewriter(cfg, "/") + if err != nil { + t.Fatalf("NewRedirectRewriter() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/") + result := rw.rewriteURL("http://example.com/page", ctx, "", "frontend") + + want := "http://example.com/page" + if result != want { + t.Errorf("rewriteURL() = %q, want %q", result, want) + } + }) +} + +// TestSelectTarget_EmptyTargets 测试空目标列表。 +func TestSelectTarget_EmptyTargets(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 清空目标 + p.mu.Lock() + p.targets = nil + p.mu.Unlock() + + ctx := testutil.NewRequestCtx("GET", "/api/test") + selected := p.selectTarget(ctx) + + if selected != nil { + t.Error("selectTarget() should return nil for empty targets") + } +} + +// TestDialTarget 测试 dialTarget 函数。 +func TestDialTarget(t *testing.T) { + t.Run("连接超时", func(t *testing.T) { + // 使用不可达地址测试超时 + _, err := dialTarget("http://10.255.255.1:9999", 100*time.Millisecond) + if err == nil { + t.Error("dialTarget() should return error for unreachable address") + } + }) + + t.Run("HTTPS 连接失败", func(t *testing.T) { + _, err := dialTarget("https://10.255.255.1:9999", 100*time.Millisecond) + if err == nil { + t.Error("dialTarget() should return error for unreachable HTTPS address") + } + }) + + t.Run("成功连接", func(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer ln.Close() + + go func() { + conn, _ := ln.Accept() + if conn != nil { + _ = conn.Close() + } + }() + + addr := ln.Addr().String() + conn, err := dialTarget("http://"+addr, 1*time.Second) + if err != nil { + t.Errorf("dialTarget() error: %v", err) + } + if conn != nil { + _ = conn.Close() + } + }) +} + +// TestBackgroundRefresh_Extra 测试 backgroundRefresh 方法的额外场景。 +func TestBackgroundRefresh_Extra(t *testing.T) { + t.Run("客户端不存在时直接返回", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + + targets := []*loadbalance.Target{{URL: "http://nonexistent:9999"}} + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 删除客户端 + p.mu.Lock() + delete(p.clients, targets[0].URL) + p.mu.Unlock() + + ctx := testutil.NewRequestCtx("GET", "/api/test") + hashKey := uint64(12345) + + // 应该不会 panic + p.backgroundRefresh(ctx, targets[0], hashKey, "GET:/api/test") + }) + + t.Run("缓存锁释放", func(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + s := &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + ctx.SetBodyString("refreshed") + }, + } + _ = s.Serve(ln) + }() + + time.Sleep(10 * time.Millisecond) + + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + + targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}} + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + hashKey := uint64(12345) + p.cache.AcquireLock(hashKey) + + p.backgroundRefresh(ctx, targets[0], hashKey, "GET:/api/test") + }) +} + +// TestWebSocket_ErrorCases 测试 WebSocket 错误情况。 +func TestWebSocket_ErrorCases(t *testing.T) { + t.Run("连接无效后端", func(t *testing.T) { + ctx := testutil.NewRequestCtxWithHeader("GET", "/ws", map[string]string{ + "Upgrade": "websocket", + "Connection": "Upgrade", + }) + + target := &loadbalance.Target{URL: "http://127.0.0.1:1"} + target.Healthy.Store(true) + + // 使用很短的超时 + err := WebSocket(ctx, target, 10*time.Millisecond) + if err == nil { + t.Error("WebSocket() should return error for invalid backend") + } + }) +} + +// TestDialTarget_TLS_Extra 测试 TLS 连接。 +func TestDialTarget_TLS_Extra(t *testing.T) { + t.Run("TLS 握手失败", func(t *testing.T) { + // 使用不可达的 HTTPS 地址 + _, err := dialTarget("https://10.255.255.1:9999", 100*time.Millisecond) + if err == nil { + t.Error("dialTarget() should return error for unreachable HTTPS address") + } + }) +} + +// TestCreateHostClient_SSL 测试 SSL 配置。 +func TestCreateHostClient_SSL(t *testing.T) { + t.Run("启用 SSL 验证", func(t *testing.T) { + sslCfg := &config.ProxySSLConfig{ + Enabled: true, + InsecureSkipVerify: false, + } + + client := createHostClient("https://example.com:443", config.ProxyTimeout{ + Connect: 5 * time.Second, + }, nil, sslCfg, "", nil) + + if client == nil { + t.Error("createHostClient() returned nil") + } + if client.TLSConfig == nil { + t.Error("TLSConfig should be set for HTTPS target") + } + }) +} + +// TestBackgroundRefresh_Revalidate 测试缓存后台刷新的 Revalidate 功能。 +func TestBackgroundRefresh_Revalidate(t *testing.T) { + t.Run("Revalidate 启用但无缓存条目", func(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + s := &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(200) + ctx.SetBodyString("refreshed") + }, + } + _ = s.Serve(ln) + }() + + time.Sleep(10 * time.Millisecond) + + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + Revalidate: true, + }, + } + + targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}} + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + hashKey := uint64(12345) + + // 无缓存条目时调用 backgroundRefresh + p.backgroundRefresh(ctx, targets[0], hashKey, "GET:/api/test") + }) +} + +// TestSelectByBalancer 测试 selectByBalancer 方法。 +func TestSelectByBalancer(t *testing.T) { + t.Run("IPHash 负载均衡", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "ip_hash", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 使用不同 IP 的请求应选择不同目标 + ctx1 := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Forwarded-For": "192.168.1.1", + }) + ctx2 := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Forwarded-For": "192.168.1.2", + }) + + selected1 := p.selectByBalancer(ctx1, targets) + selected2 := p.selectByBalancer(ctx2, targets) + + if selected1 == nil || selected2 == nil { + t.Error("selectByBalancer() should return a target") + } + }) + + t.Run("ConsistentHash 负载均衡", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + VirtualNodes: 100, + HashKey: "uri", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/users/123") + selected := p.selectByBalancer(ctx, targets) + + if selected == nil { + t.Error("selectByBalancer() should return a target for consistent_hash") + } + + // 相同 URI 应选择相同目标 + selected2 := p.selectByBalancer(ctx, targets) + if selected2 == nil || selected.URL != selected2.URL { + t.Error("consistent_hash should return same target for same URI") + } + }) + + t.Run("ConsistentHash with header key", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + VirtualNodes: 100, + HashKey: "header:X-User-Id", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-User-Id": "user-123", + }) + selected := p.selectByBalancer(ctx, targets) + + if selected == nil { + t.Error("selectByBalancer() should return a target for header-based hash") + } + }) + + t.Run("RoundRobin 负载均衡", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + selected := p.selectByBalancer(ctx, targets) + + if selected == nil { + t.Error("selectByBalancer() should return a target for round_robin") + } + }) +} + +// TestSelectTargetExcluding_Extra 测试 selectTargetExcluding 方法。 +func TestSelectTargetExcluding_Extra(t *testing.T) { + t.Run("排除已失败目标", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + {URL: "http://backend3:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + // 排除第一个目标 + excluded := []*loadbalance.Target{targets[0]} + selected := p.selectTargetExcluding(ctx, excluded) + + if selected == nil { + t.Error("selectTargetExcluding() should return a target") + } + if selected != nil && selected.URL == targets[0].URL { + t.Error("selectTargetExcluding() should not return excluded target") + } + }) + + t.Run("排除所有目标返回 nil", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + } + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + + // 排除所有目标 + excluded := []*loadbalance.Target{targets[0]} + selected := p.selectTargetExcluding(ctx, excluded) + + if selected != nil { + t.Error("selectTargetExcluding() should return nil when all targets excluded") + } + }) +} + +// TestExtractHashKey_Extra 测试 extractHashKey 方法。 +func TestExtractHashKey_Extra(t *testing.T) { + t.Run("使用 IP 作为 hash key", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + HashKey: "ip", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Forwarded-For": "10.0.0.1", + }) + + key := p.extractHashKey(ctx, "ip") + if key != "10.0.0.1" { + t.Errorf("extractHashKey() = %q, want %q", key, "10.0.0.1") + } + }) + + t.Run("使用 URI 作为 hash key", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + HashKey: "uri", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/users/123") + + key := p.extractHashKey(ctx, "uri") + if key != "/api/users/123" { + t.Errorf("extractHashKey() = %q, want %q", key, "/api/users/123") + } + }) + + t.Run("使用 header 作为 hash key", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + HashKey: "header:X-Session-Id", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Session-Id": "session-abc-123", + }) + + key := p.extractHashKey(ctx, "header:X-Session-Id") + if key != "session-abc-123" { + t.Errorf("extractHashKey() = %q, want %q", key, "session-abc-123") + } + }) + + t.Run("header 不存在时回退到 IP", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "consistent_hash", + HashKey: "header:X-Nonexistent", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtxWithHeader("GET", "/api/test", map[string]string{ + "X-Forwarded-For": "10.0.0.5", + }) + + key := p.extractHashKey(ctx, "header:X-Nonexistent") + if key != "10.0.0.5" { + t.Errorf("extractHashKey() should fallback to IP, got %q", key) + } + }) +} + +// TestSelectTarget_LuaEnabled 测试 selectTarget 在 Lua 启用时的行为。 +func TestSelectTarget_LuaEnabled(t *testing.T) { + t.Run("Lua 引擎为 nil 时使用传统算法", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + BalancerByLua: config.BalancerByLuaConfig{ + Enabled: true, + Script: "/nonexistent.lua", + }, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + {URL: "http://backend2:8080"}, + } + for _, t := range targets { + t.Healthy.Store(true) + } + + // luaEngine 为 nil + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + selected := p.selectTarget(ctx) + + if selected == nil { + t.Error("selectTarget() should return a target using fallback") + } + }) + + t.Run("Lua 脚本为空时使用传统算法", func(t *testing.T) { + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + BalancerByLua: config.BalancerByLuaConfig{ + Enabled: true, + Script: "", + }, + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + + targets := []*loadbalance.Target{ + {URL: "http://backend1:8080"}, + } + targets[0].Healthy.Store(true) + + luaEngine, err := lua.NewEngine(nil) + if err != nil { + t.Fatalf("NewEngine() error: %v", err) + } + p, err := NewProxy(cfg, targets, nil, luaEngine) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + ctx := testutil.NewRequestCtx("GET", "/api/test") + selected := p.selectTarget(ctx) + + if selected == nil { + t.Error("selectTarget() should return a target using traditional balancer") + } + }) +} + +// TestExtractHostFromURL 测试 extractHostFromURL 函数。 +func TestExtractHostFromURL(t *testing.T) { + tests := []struct { + name string + url string + want string + }{ + { + name: "HTTP URL with port", + url: "http://example.com:8080", + want: "example.com:8080", + }, + { + name: "HTTPS URL with port", + url: "https://example.com:8443", + want: "example.com:8443", + }, + { + name: "HTTP URL without port", + url: "http://example.com", + want: "example.com", + }, + { + name: "HTTPS URL without port", + url: "https://example.com", + want: "example.com", + }, + { + name: "URL with path", + url: "http://example.com:8080/api/users", + want: "example.com:8080", + }, + { + name: "No protocol prefix", + url: "example.com:8080", + want: "example.com:8080", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractHostFromURL(tt.url) + if got != tt.want { + t.Errorf("extractHostFromURL() = %q, want %q", got, tt.want) + } + }) + } +} + +// TestTempFileManager_Threshold 测试 ShouldUseTempFile 的阈值逻辑。 +func TestTempFileManager_Threshold(t *testing.T) { + t.Run("响应大于阈值", func(t *testing.T) { + manager, err := NewTempFileManager(t.TempDir(), "1kb", "10kb") + if err != nil { + t.Fatalf("NewTempFileManager() error: %v", err) + } + + if !manager.ShouldUseTempFile(2048) { + t.Error("ShouldUseTempFile() should return true for 2KB when threshold is 1KB") + } + }) + + t.Run("响应小于阈值", func(t *testing.T) { + manager, err := NewTempFileManager(t.TempDir(), "1kb", "10kb") + if err != nil { + t.Fatalf("NewTempFileManager() error: %v", err) + } + + if manager.ShouldUseTempFile(512) { + t.Error("ShouldUseTempFile() should return false for 512B when threshold is 1KB") + } + }) +} + +// TestBackgroundRefresh_304 测试后台刷新收到 304 响应。 +func TestBackgroundRefresh_304(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + s := &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + // 检查条件请求头 + if ctx.Request.Header.Peek("If-Modified-Since") != nil || + ctx.Request.Header.Peek("If-None-Match") != nil { + ctx.SetStatusCode(304) + ctx.Response.Header.Set("Last-Modified", "Wed, 21 Oct 2015 07:28:00 GMT") + ctx.Response.Header.Set("ETag", "\"abc123\"") + return + } + ctx.SetStatusCode(200) + ctx.SetBodyString("fresh content") + }, + } + _ = s.Serve(ln) + }() + + time.Sleep(10 * time.Millisecond) + + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + Revalidate: true, + }, + } + + targets := []*loadbalance.Target{{URL: "http://" + ln.Addr().String()}} + targets[0].Healthy.Store(true) + + p, err := NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 预先设置缓存条目 + ctx := testutil.NewRequestCtx("GET", "/api/test") + hashKey, origKey := p.buildCacheKeyHash(ctx) + p.cache.Set(hashKey, origKey, []byte("cached"), map[string]string{ + "Last-Modified": "Tue, 20 Oct 2015 07:28:00 GMT", + "ETag": "\"old\"", + }, 200, 10*time.Second) + + // 调用后台刷新 + p.backgroundRefresh(ctx, targets[0], hashKey, origKey) +} diff --git a/internal/server/pool_test.go b/internal/server/pool_test.go index 3cf052b..ae93307 100644 --- a/internal/server/pool_test.go +++ b/internal/server/pool_test.go @@ -236,3 +236,220 @@ func TestPoolWrapHandler_WhenStopped(t *testing.T) { t.Error("Expected handler to be executed directly when pool is stopped") } } + +func TestPoolSubmit_QueueFull_StartNewWorker(t *testing.T) { + // 测试队列满时启动新 worker + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 10, + MinWorkers: 1, + QueueSize: 1, // 小队列,容易满 + IdleTimeout: 5 * time.Second, + }) + + p.Start() + defer p.Stop() + + // 填满队列,让后续提交触发 default 分支 + var executedCount atomic.Int32 + blockTask := func(*fasthttp.RequestCtx) { + time.Sleep(200 * time.Millisecond) // 阻塞任务 + executedCount.Add(1) + } + + // 先提交一个阻塞任务,让 worker 忙碌 + _ = p.Submit(nil, blockTask) + time.Sleep(50 * time.Millisecond) // 等待 worker 开始执行 + + // 填满队列 + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { executedCount.Add(1) }) + + // 此时队列满,应该启动新 worker + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { executedCount.Add(1) }) + + // 等待所有任务完成 + time.Sleep(500 * time.Millisecond) + + if executedCount.Load() < 2 { + t.Errorf("Expected at least 2 executions, got %d", executedCount.Load()) + } +} + +func TestPoolSubmit_QueueFull_MaxWorkers_Fallback(t *testing.T) { + // 测试队列满且达到最大 worker 时直接执行(fallback) + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 1, // 只有 1 个 worker + MinWorkers: 1, + QueueSize: 1, // 队列大小 1 + IdleTimeout: 5 * time.Second, + }) + + p.Start() + defer p.Stop() + + // 使用 channel 阻塞唯一的 worker + blockCh := make(chan struct{}) + started := make(chan struct{}) + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { + close(started) // 通知 worker 已开始 + <-blockCh // 阻塞直到测试结束 + }) + + // 等待 worker 开始执行阻塞任务 + <-started + + // 填满队列 + _ = p.Submit(nil, func(*fasthttp.RequestCtx) {}) + + // 现在唯一的 worker 在阻塞,队列已满 + // 提交新任务应该触发 fallback 直接执行 + var fallbackExecuted atomic.Bool + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { + fallbackExecuted.Store(true) + }) + + // fallback 执行是同步的直接执行 + if !fallbackExecuted.Load() { + t.Error("Expected task to be executed directly (fallback) when max workers reached") + } + + // 释放阻塞的 worker + close(blockCh) +} + +func TestPoolSubmit_WithIdleWorkers(t *testing.T) { + // 测试有空闲 worker 时不启动新 worker + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 10, + MinWorkers: 5, // 预热 5 个 worker + QueueSize: 10, + IdleTimeout: 5 * time.Second, + }) + + p.Start() + defer p.Stop() + + // 等待预热完成,worker 应该是空闲的 + time.Sleep(50 * time.Millisecond) + + initialWorkers := atomic.LoadInt32(&p.workers) + + // 提交任务,应该复用空闲 worker + var executed atomic.Bool + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { + executed.Store(true) + }) + + time.Sleep(100 * time.Millisecond) + + if !executed.Load() { + t.Error("Expected task to be executed") + } + + // worker 数量应该保持稳定或更少(不应该增加) + finalWorkers := atomic.LoadInt32(&p.workers) + if finalWorkers > initialWorkers { + t.Errorf("Worker count should not increase when idle workers available: %d -> %d", initialWorkers, finalWorkers) + } +} + +func TestPoolSubmit_NilTask(t *testing.T) { + // 测试提交 nil 任务不会 panic + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 10, + }) + + p.Start() + defer p.Stop() + + // 提交 nil 任务 - 这会导致 panic,所以不测试 + // 但可以测试空任务函数 + var executed atomic.Bool + emptyTask := func(*fasthttp.RequestCtx) { + executed.Store(true) + } + + err := p.Submit(nil, emptyTask) + if err != nil { + t.Errorf("Submit failed: %v", err) + } + + time.Sleep(50 * time.Millisecond) + + if !executed.Load() { + t.Error("Expected empty task to be executed") + } +} + +func TestPoolSubmit_MultipleQueuedTasks(t *testing.T) { + // 测试多个任务入队 + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 5, + MinWorkers: 2, + QueueSize: 10, + IdleTimeout: 5 * time.Second, + }) + + p.Start() + defer p.Stop() + + var counter atomic.Int32 + + // 快速提交多个任务 + for i := 0; i < 5; i++ { + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { + counter.Add(1) + }) + } + + time.Sleep(200 * time.Millisecond) + + if counter.Load() != 5 { + t.Errorf("Expected 5 executions, got %d", counter.Load()) + } +} + +func TestPoolSubmit_StartWorkerWhenNoIdle(t *testing.T) { + // 测试当没有空闲 worker 时启动新 worker + // 使用 MinWorkers=1 让池只预热 1 个 worker + p := NewGoroutinePool(PoolConfig{ + MaxWorkers: 5, + MinWorkers: 1, // 只预热 1 个 worker + QueueSize: 10, + IdleTimeout: 5 * time.Second, + }) + + p.Start() + defer p.Stop() + + // 等待预热完成 + time.Sleep(50 * time.Millisecond) + + // 用阻塞任务让唯一的 worker 忙碌 + blockCh := make(chan struct{}) + started := make(chan struct{}) + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { + close(started) + <-blockCh + }) + <-started // 等待 worker 开始执行 + + // 现在唯一的 worker 在忙碌,idleWorkers == 0 + // 提交新任务应该启动新 worker + var executed atomic.Bool + _ = p.Submit(nil, func(*fasthttp.RequestCtx) { + executed.Store(true) + }) + + time.Sleep(100 * time.Millisecond) + + if !executed.Load() { + t.Error("Expected task to be executed by new worker") + } + + // 检查是否启动了新 worker + if atomic.LoadInt32(&p.workers) < 2 { + t.Errorf("Expected at least 2 workers, got %d", atomic.LoadInt32(&p.workers)) + } + + close(blockCh) +} diff --git a/internal/server/pprof_impl_test.go b/internal/server/pprof_impl_test.go new file mode 100644 index 0000000..b226270 --- /dev/null +++ b/internal/server/pprof_impl_test.go @@ -0,0 +1,262 @@ +// Package server 提供 pprof 实现的测试。 +package server + +import ( + "bufio" + "bytes" + "sync" + "testing" +) + +// TestStartCPUProfile 测试 startCPUProfile 函数 +func TestStartCPUProfile(t *testing.T) { + t.Run("start and stop CPU profile", func(t *testing.T) { + var buf bytes.Buffer + + err := startCPUProfile(&buf) + if err != nil { + t.Errorf("unexpected error starting CPU profile: %v", err) + } + + // 停止 CPU profile + stopCPUProfile() + }) + + t.Run("start twice should not error", func(t *testing.T) { + var buf1, buf2 bytes.Buffer + + err1 := startCPUProfile(&buf1) + if err1 != nil { + t.Errorf("unexpected error on first start: %v", err1) + } + + // 第二次启动应该被忽略(已在采集) + err2 := startCPUProfile(&buf2) + if err2 != nil { + t.Errorf("unexpected error on second start: %v", err2) + } + + stopCPUProfile() + }) + + t.Run("stop without start should not panic", func(t *testing.T) { + // 确保停止状态 + stopCPUProfile() + // 再次停止应该安全 + stopCPUProfile() + }) +} + +// TestStopCPUProfile 测试 stopCPUProfile 函数 +func TestStopCPUProfile(t *testing.T) { + t.Run("stop when not active", func(t *testing.T) { + // 确保停止状态 + stopCPUProfile() + // 再次停止应该安全 + stopCPUProfile() + }) + + t.Run("stop when active", func(t *testing.T) { + var buf bytes.Buffer + err := startCPUProfile(&buf) + if err != nil { + t.Fatalf("failed to start CPU profile: %v", err) + } + + stopCPUProfile() + + // 验证可以再次启动 + err = startCPUProfile(&buf) + if err != nil { + t.Errorf("failed to restart CPU profile after stop: %v", err) + } + stopCPUProfile() + }) +} + +// TestWriteHeapProfile 测试 writeHeapProfile 函数 +func TestWriteHeapProfile(t *testing.T) { + t.Run("write heap profile", func(t *testing.T) { + var buf bytes.Buffer + + // 写入 heap profile + writeHeapProfile(&buf) + + // 验证有输出 + if buf.Len() == 0 { + t.Error("expected heap profile output, got empty buffer") + } + }) + + t.Run("write heap profile multiple times", func(t *testing.T) { + var buf1, buf2 bytes.Buffer + + writeHeapProfile(&buf1) + writeHeapProfile(&buf2) + + // 两次都应该有输出 + if buf1.Len() == 0 { + t.Error("expected heap profile output on first call") + } + if buf2.Len() == 0 { + t.Error("expected heap profile output on second call") + } + }) +} + +// TestWriteGoroutineProfile 测试 writeGoroutineProfile 函数 +func TestWriteGoroutineProfile(t *testing.T) { + t.Run("write goroutine profile", func(t *testing.T) { + var buf bytes.Buffer + + writeGoroutineProfile(&buf) + + // 验证有输出 + if buf.Len() == 0 { + t.Error("expected goroutine profile output, got empty buffer") + } + }) + + t.Run("write goroutine profile with spawned goroutines", func(t *testing.T) { + var buf bytes.Buffer + + // 启动一些 goroutine + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + select {} + }() + } + + writeGoroutineProfile(&buf) + + // 应该有输出 + if buf.Len() == 0 { + t.Error("expected goroutine profile output") + } + }) +} + +// TestWriteBlockProfile 测试 writeBlockProfile 函数 +func TestWriteBlockProfile(t *testing.T) { + t.Run("write block profile", func(t *testing.T) { + var buf bytes.Buffer + + writeBlockProfile(&buf) + + // block profile 可能为空(如果没有阻塞操作) + // 所以我们只验证函数不会 panic + }) +} + +// TestWriteMutexProfile 测试 writeMutexProfile 函数 +func TestWriteMutexProfile(t *testing.T) { + t.Run("write mutex profile", func(t *testing.T) { + var buf bytes.Buffer + + writeMutexProfile(&buf) + + // mutex profile 可能为空(如果没有锁竞争) + // 所以我们只验证函数不会 panic + }) +} + +// TestBufioWriterAdapter 测试 bufioWriterAdapter 结构体 +func TestBufioWriterAdapter(t *testing.T) { + t.Run("write and flush", func(t *testing.T) { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writer := wrapBufioWriter(bw) + + data := []byte("test data") + n, err := writer.Write(data) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if n != len(data) { + t.Errorf("expected %d bytes written, got %d", len(data), n) + } + }) + + t.Run("write multiple times", func(t *testing.T) { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writer := wrapBufioWriter(bw) + + data1 := []byte("first") + data2 := []byte("second") + + n1, err1 := writer.Write(data1) + if err1 != nil { + t.Errorf("unexpected error on first write: %v", err1) + } + if n1 != len(data1) { + t.Errorf("expected %d bytes on first write, got %d", len(data1), n1) + } + + n2, err2 := writer.Write(data2) + if err2 != nil { + t.Errorf("unexpected error on second write: %v", err2) + } + if n2 != len(data2) { + t.Errorf("expected %d bytes on second write, got %d", len(data2), n2) + } + }) +} + +// TestWrapBufioWriter 测试 wrapBufioWriter 函数 +func TestWrapBufioWriter(t *testing.T) { + t.Run("wrap returns non-nil", func(t *testing.T) { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writer := wrapBufioWriter(bw) + + if writer == nil { + t.Error("expected non-nil writer") + } + }) + + t.Run("wrapped writer implements io.Writer", func(t *testing.T) { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + writer := wrapBufioWriter(bw) + + // 测试写入 + data := []byte("hello world") + n, err := writer.Write(data) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if n != len(data) { + t.Errorf("expected %d bytes, got %d", len(data), n) + } + }) +} + +// TestCPUProfileMutex 测试 CPU profile 的并发安全性 +func TestCPUProfileMutex(t *testing.T) { + t.Run("concurrent start/stop", func(t *testing.T) { + var wg sync.WaitGroup + var buf bytes.Buffer + + // 启动多个 goroutine 同时操作 + for i := 0; i < 10; i++ { + wg.Add(2) + go func() { + defer wg.Done() + _ = startCPUProfile(&buf) + }() + go func() { + defer wg.Done() + stopCPUProfile() + }() + } + + wg.Wait() + + // 确保最终状态一致 + stopCPUProfile() + }) +} diff --git a/internal/server/pprof_test.go b/internal/server/pprof_test.go index e5b41a5..6ddd022 100644 --- a/internal/server/pprof_test.go +++ b/internal/server/pprof_test.go @@ -633,6 +633,138 @@ func TestPprofHandler_handleCPU_Params(t *testing.T) { } } +// TestPprofHandler_handleCPU 测试 handleCPU 函数的参数解析。 +// 使用 1 秒的最小采集时间来确保测试快速完成。 +func TestPprofHandler_handleCPU(t *testing.T) { + t.Run("default seconds (30s) - verify headers only", func(t *testing.T) { + // 注意:默认 30 秒太长,这里只验证请求解析逻辑 + // 实际的 profile 采集在 TestPprofHandler_handleCPU_WithShortDuration 中测试 + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/debug/pprof/profile") + + // 验证请求路径解析正确 + path := string(ctx.Path()) + if path != "/debug/pprof/profile" { + t.Errorf("unexpected path: %s", path) + } + + // 验证没有 seconds 参数时 QueryArgs 为空 + secArg := ctx.QueryArgs().Peek("seconds") + if secArg != nil { + t.Errorf("expected no seconds arg, got: %s", secArg) + } + }) + + t.Run("custom seconds parameter", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=5") + + // 验证 seconds 参数解析正确 + secArg := ctx.QueryArgs().Peek("seconds") + if string(secArg) != "5" { + t.Errorf("expected seconds=5, got: %s", secArg) + } + }) + + t.Run("invalid seconds parameter", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=invalid") + + // 验证参数存在 + secArg := ctx.QueryArgs().Peek("seconds") + if string(secArg) != "invalid" { + t.Errorf("expected seconds=invalid, got: %s", secArg) + } + + // strconv.Atoi("invalid") 会返回错误,函数会使用默认值 30 + }) +} + +// TestPprofHandler_handleCPU_Execute 执行 handleCPU 并验证响应。 +// 使用 1 秒采集时间来快速完成测试。 +func TestPprofHandler_handleCPU_Execute(t *testing.T) { + // 确保之前的 CPU profile 已停止 + stopCPUProfile() + + h := &PprofHandler{ + path: "/debug/pprof", + allowedIPs: []net.IP{}, + allowedNets: []*net.IPNet{}, + } + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=1") + + // 执行 handleCPU + h.handleCPU(ctx) + + // 验证 Content-Type + contentType := string(ctx.Response.Header.Peek("Content-Type")) + if contentType != "application/octet-stream" { + t.Errorf("expected Content-Type application/octet-stream, got: %s", contentType) + } + + // 验证响应体(CPU profile 数据) + body := ctx.Response.Body() + if len(body) == 0 { + t.Error("expected CPU profile output, got empty body") + } + + // 验证响应体包含 pprof header 标识 + // pprof 文件以特定的 magic number 开头 + if len(body) > 0 { + // pprof 格式的文件应该有内容 + t.Logf("CPU profile size: %d bytes", len(body)) + } +} + +// TestPprofHandler_handleCPU_NegativeSeconds 测试负数秒数参数解析。 +// 负数秒会被 sec > 0 检查过滤,使用默认值 30 秒。 +func TestPprofHandler_handleCPU_NegativeSeconds(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=-1") + + // 验证参数解析 + secArg := ctx.QueryArgs().Peek("seconds") + if string(secArg) != "-1" { + t.Errorf("expected seconds=-1, got: %s", secArg) + } + + // 注意:负数秒在 handleCPU 中会被 sec > 0 检查过滤,使用默认值 30 秒 + // 为了测试效率,这里只验证参数解析,不实际执行 handleCPU +} + +// TestPprofHandler_handleCPU_ZeroSeconds 测试零秒参数解析。 +// 零秒会被 sec > 0 检查过滤,使用默认值 30 秒。 +func TestPprofHandler_handleCPU_ZeroSeconds(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=0") + + // 验证参数解析 + secArg := ctx.QueryArgs().Peek("seconds") + if string(secArg) != "0" { + t.Errorf("expected seconds=0, got: %s", secArg) + } + + // 注意:零秒在 handleCPU 中会被 sec > 0 检查过滤,使用默认值 30 秒 + // 为了测试效率,这里只验证参数解析,不实际执行 handleCPU +} + +// TestPprofHandler_handleCPU_LargeSeconds 测试大数值秒数参数解析。 +func TestPprofHandler_handleCPU_LargeSeconds(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/debug/pprof/profile?seconds=999999") + + // 验证参数解析 + secArg := ctx.QueryArgs().Peek("seconds") + if string(secArg) != "999999" { + t.Errorf("expected seconds=999999, got: %s", secArg) + } + + // 注意:这里只验证参数解析不会溢出 + // 为了测试效率,不实际执行 handleCPU(会等待 999999 秒) +} + func TestPprofHandler_ConfigWithCIDRAndIP(t *testing.T) { // 测试混合配置 cfg := &config.PprofConfig{ diff --git a/internal/server/purge_test.go b/internal/server/purge_test.go index baa2071..a34c1fc 100644 --- a/internal/server/purge_test.go +++ b/internal/server/purge_test.go @@ -17,10 +17,13 @@ import ( "net" "strings" "testing" + "time" "github.com/valyala/fasthttp" "rua.plus/lolly/internal/cache" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/proxy" ) func TestPurgeHandler_Path(t *testing.T) { @@ -781,3 +784,607 @@ func TestPurgeHandler_checkAccess_WithAllowedIP(t *testing.T) { } }) } + +// mockProxyWithCache 是一个用于测试的 mock Proxy,可以返回指定的缓存。 +type mockProxyWithCache struct { + cache *cache.ProxyCache +} + +func (m *mockProxyWithCache) GetCache() *cache.ProxyCache { + return m.cache +} + +// TestPurgeHandler_PurgeByPath_WithRealCache 测试 purgeByPath 在有真实缓存时的行为。 +func TestPurgeHandler_PurgeByPath_WithRealCache(t *testing.T) { + // 创建启用缓存的代理 + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := proxy.NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 获取缓存并添加测试数据 + pcache := p.GetCache() + if pcache == nil { + t.Fatal("GetCache() should return non-nil when cache enabled") + } + + // 添加测试缓存条目 + hashKey1 := cache.HashPathWithMethod("/api/users", "GET") + pcache.Set(hashKey1, "GET:/api/users", []byte("test data 1"), nil, 200, time.Minute) + + hashKey2 := cache.HashPathWithMethod("/api/posts", "GET") + pcache.Set(hashKey2, "GET:/api/posts", []byte("test data 2"), nil, 200, time.Minute) + + hashKey3 := cache.HashPathWithMethod("/api/users", "POST") + pcache.Set(hashKey3, "POST:/api/users", []byte("test data 3"), nil, 200, time.Minute) + + // 创建带有代理的 handler + h, err := NewPurgeHandler(&Server{ + proxies: []*proxy.Proxy{p}, + }, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + t.Run("delete existing entry", func(t *testing.T) { + deleted := h.PurgeByPathForTest("/api/users", "GET") + if deleted != 1 { + t.Errorf("expected 1 deletion, got %d", deleted) + } + }) + + t.Run("delete different method", func(t *testing.T) { + deleted := h.PurgeByPathForTest("/api/users", "POST") + if deleted != 1 { + t.Errorf("expected 1 deletion, got %d", deleted) + } + }) + + t.Run("delete non-existing path", func(t *testing.T) { + deleted := h.PurgeByPathForTest("/api/nonexistent", "GET") + if deleted != 1 { + t.Errorf("expected 1 (proxy count), got %d", deleted) + } + }) + + t.Run("multiple proxies", func(t *testing.T) { + // 创建第二个代理 + p2, err := proxy.NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + pcache2 := p2.GetCache() + hashKey := cache.HashPathWithMethod("/test", "GET") + pcache2.Set(hashKey, "GET:/test", []byte("test"), nil, 200, time.Minute) + + h2, err := NewPurgeHandler(&Server{ + proxies: []*proxy.Proxy{p, p2}, + }, &config.CacheAPIConfig{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + deleted := h2.PurgeByPathForTest("/test", "GET") + if deleted != 2 { + t.Errorf("expected 2 deletions (2 proxies), got %d", deleted) + } + }) +} + +// TestPurgeHandler_PurgeByPattern_WithRealCache 测试 purgeByPattern 在有真实缓存时的行为。 +func TestPurgeHandler_PurgeByPattern_WithRealCache(t *testing.T) { + // 创建启用缓存的代理 + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := proxy.NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 获取缓存并添加测试数据 + pcache := p.GetCache() + if pcache == nil { + t.Fatal("GetCache() should return non-nil when cache enabled") + } + + // 添加多个测试缓存条目 + pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/users/1", "GET"), "GET:/api/users/1", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/posts", "GET"), "GET:/api/posts", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/posts/1", "GET"), "GET:/api/posts/1", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/users", "POST"), "POST:/api/users", []byte("data"), nil, 200, time.Minute) + + // 创建带有代理的 handler + h, err := NewPurgeHandler(&Server{ + proxies: []*proxy.Proxy{p}, + }, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + t.Run("wildcard pattern matches multiple", func(t *testing.T) { + // 重新添加数据 + pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/users/1", "GET"), "GET:/api/users/1", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/posts", "GET"), "GET:/api/posts", []byte("data"), nil, 200, time.Minute) + + // 注意:OrigKey 格式为 "METHOD:/path",所以模式需要匹配完整路径 + deleted := h.PurgeByPatternForTest("GET:/api/*", "GET") + if deleted < 1 { + t.Errorf("expected at least 1 deletion, got %d", deleted) + } + }) + + t.Run("empty method matches all methods", func(t *testing.T) { + // 重新添加数据 + pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/users", "POST"), "POST:/api/users", []byte("data"), nil, 200, time.Minute) + + // 使用 * 通配符匹配所有方法 + deleted := h.PurgeByPatternForTest("*:/api/users", "") + if deleted < 1 { + t.Errorf("expected at least 1 deletion (all methods), got %d", deleted) + } + }) + + t.Run("specific method only", func(t *testing.T) { + // 重新添加数据 + pcache.Set(cache.HashPathWithMethod("/api/users", "GET"), "GET:/api/users", []byte("data"), nil, 200, time.Minute) + pcache.Set(cache.HashPathWithMethod("/api/users", "POST"), "POST:/api/users", []byte("data"), nil, 200, time.Minute) + + // 模式匹配 POST 方法的路径 + deleted := h.PurgeByPatternForTest("POST:/api/users", "POST") + if deleted < 1 { + t.Errorf("expected at least 1 deletion (POST only), got %d", deleted) + } + }) +} + +// TestPurgeHandler_PurgeByPath_WithProxyNoCache 测试代理没有缓存时的情况。 +func TestPurgeHandler_PurgeByPath_WithProxyNoCache(t *testing.T) { + // 创建禁用缓存的代理 + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + // Cache 未启用 + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := proxy.NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 确认缓存为 nil + if p.GetCache() != nil { + t.Fatal("GetCache() should return nil when cache disabled") + } + + // 创建带有代理的 handler + h, err := NewPurgeHandler(&Server{ + proxies: []*proxy.Proxy{p}, + }, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 没有缓存的代理应该返回 0 + deleted := h.PurgeByPathForTest("/api/users", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for proxy without cache, got %d", deleted) + } +} + +// TestPurgeHandler_PurgeByPattern_WithProxyNoCache 测试代理没有缓存时的情况。 +func TestPurgeHandler_PurgeByPattern_WithProxyNoCache(t *testing.T) { + // 创建禁用缓存的代理 + cfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := proxy.NewProxy(cfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 创建带有代理的 handler + h, err := NewPurgeHandler(&Server{ + proxies: []*proxy.Proxy{p}, + }, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 没有缓存的代理应该返回 0 + deleted := h.PurgeByPatternForTest("/api/*", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for proxy without cache, got %d", deleted) + } +} + +// TestPurgeHandler_PurgeByPath_WithCache 测试 purgeByPath 在有缓存时的行为。 +func TestPurgeHandler_PurgeByPath_WithCache(t *testing.T) { + t.Run("server with empty proxies", func(t *testing.T) { + // 创建带有空 proxies 列表的 handler + h, err := NewPurgeHandler(&Server{ + proxies: []*proxy.Proxy{}, + }, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 空 proxies 列表应该返回 0 + deleted := h.PurgeByPathForTest("/api/users", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for empty proxies, got %d", deleted) + } + }) + + t.Run("empty path", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 空路径仍然会执行删除逻辑,只是哈希值为默认 GET 的哈希 + deleted := h.PurgeByPathForTest("", "") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("path with special characters", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 特殊字符路径 + deleted := h.PurgeByPathForTest("/api/users?id=1&name=test", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("path with unicode", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Unicode 路径 + deleted := h.PurgeByPathForTest("/api/用户/列表", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("different methods", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"} + for _, method := range methods { + deleted := h.PurgeByPathForTest("/api/users", method) + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server with method %s, got %d", method, deleted) + } + } + }) +} + +// TestPurgeHandler_PurgeByPattern_WithCache 测试 purgeByPattern 在有缓存时的行为。 +func TestPurgeHandler_PurgeByPattern_WithCache(t *testing.T) { + t.Run("empty pattern", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 空模式 + deleted := h.PurgeByPatternForTest("", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("wildcard pattern", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 通配符模式 + deleted := h.PurgeByPatternForTest("/api/*", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("double wildcard pattern", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 双通配符模式 + deleted := h.PurgeByPatternForTest("/api/**", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("pattern with special characters", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 特殊字符模式 + deleted := h.PurgeByPatternForTest("/api/users?id=*", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("exact pattern (no wildcard)", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 精确模式(无通配符) + deleted := h.PurgeByPatternForTest("/api/users", "GET") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + + t.Run("different methods", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH"} + for _, method := range methods { + deleted := h.PurgeByPatternForTest("/api/*", method) + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server with method %s, got %d", method, deleted) + } + } + }) + + t.Run("empty method matches all", func(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 空方法应该匹配所有条目 + deleted := h.PurgeByPatternForTest("/api/*", "") + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) +} + +// TestPurgeHandler_PurgeByPath_HashConsistency 测试哈希一致性。 +func TestPurgeHandler_PurgeByPath_HashConsistency(t *testing.T) { + // 验证相同路径和方法产生相同哈希 + path := "/api/users" + method := "GET" + + hash1 := cache.HashPathWithMethod(path, method) + hash2 := cache.HashPathWithMethod(path, method) + + if hash1 != hash2 { + t.Errorf("hash not consistent: %d != %d", hash1, hash2) + } + + // 验证不同路径产生不同哈希 + hash3 := cache.HashPathWithMethod("/api/posts", method) + if hash1 == hash3 { + t.Error("expected different hashes for different paths") + } + + // 验证不同方法产生不同哈希 + hash4 := cache.HashPathWithMethod(path, "POST") + if hash1 == hash4 { + t.Error("expected different hashes for different methods") + } +} + +// TestPurgeHandler_PurgeByPattern_PatternMatching 测试模式匹配逻辑。 +func TestPurgeHandler_PurgeByPattern_PatternMatching(t *testing.T) { + tests := []struct { + pattern string + path string + want bool + }{ + // 通配符结尾 - 前缀匹配 + {"/api/*", "/api/users", true}, + {"/api/*", "/api/posts", true}, + {"/api/*", "/api/users/123", true}, // * 匹配剩余所有内容 + {"/api/*", "/other/path", false}, + + // 单个 * 匹配所有 + {"*", "/api/users", true}, + {"*", "/any/path", true}, + + // 中间通配符 + {"/api/*/users", "/api/v1/users", true}, + {"/api/*/users", "/api/v2/users", true}, + {"/api/*/users", "/api/users", true}, // 前缀和后缀都匹配 + {"/api/*/users", "/api/v1/posts", false}, + + // 精确匹配 + {"/api/users", "/api/users", true}, + {"/api/users", "/api/posts", false}, + + // 空模式 + {"", "", true}, + {"", "/api", false}, + + // 目录前缀匹配(以 / 结尾) + {"/api/", "/api/users", true}, + {"/api/", "/api/users/123", true}, + {"/api/", "/other/path", false}, + } + + for _, tt := range tests { + t.Run(tt.pattern+"_"+tt.path, func(t *testing.T) { + got := cache.MatchPattern(tt.pattern, tt.path) + if got != tt.want { + t.Errorf("MatchPattern(%q, %q) = %v, want %v", tt.pattern, tt.path, got, tt.want) + } + }) + } +} + +// TestPurgeHandler_PurgeByPath_VariousInputs 测试各种输入。 +func TestPurgeHandler_PurgeByPath_VariousInputs(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + tests := []struct { + name string + path string + method string + }{ + {"empty path and method", "", ""}, + {"empty path with method", "", "GET"}, + {"path with empty method", "/test", ""}, + {"root path", "/", "GET"}, + {"nested path", "/a/b/c/d/e", "GET"}, + {"path with trailing slash", "/api/users/", "GET"}, + {"path with query", "/api?key=value", "GET"}, + {"path with fragment", "/api#section", "GET"}, + {"path with encoded chars", "/api%2Fusers", "GET"}, + {"long path", strings.Repeat("/a", 100), "GET"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 应该不会 panic + deleted := h.PurgeByPathForTest(tt.path, tt.method) + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + } +} + +// TestPurgeHandler_PurgeByPattern_VariousInputs 测试各种模式输入。 +func TestPurgeHandler_PurgeByPattern_VariousInputs(t *testing.T) { + h, err := NewPurgeHandler(nil, &config.CacheAPIConfig{ + Path: "/_cache/purge", + Allow: []string{}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + tests := []struct { + name string + pattern string + method string + }{ + {"empty pattern and method", "", ""}, + {"empty pattern with method", "", "GET"}, + {"pattern with empty method", "/api/*", ""}, + {"single wildcard only", "*", "GET"}, + {"double wildcard only", "**", "GET"}, + {"multiple single wildcards", "/api/*/users/*", "GET"}, + {"mixed wildcards", "/api/**/users/*", "GET"}, + {"wildcard at start", "*/users", "GET"}, + {"wildcard at end", "/api/*", "GET"}, + {"consecutive wildcards", "/api/**/*", "GET"}, + {"long pattern", strings.Repeat("/a*", 20), "GET"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 应该不会 panic + deleted := h.PurgeByPatternForTest(tt.pattern, tt.method) + if deleted != 0 { + t.Errorf("expected 0 deletions for nil server, got %d", deleted) + } + }) + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 884d9ed..1090f02 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -14,13 +14,22 @@ package server import ( + "context" + "fmt" "net" "os" + "sync" "testing" "time" "github.com/valyala/fasthttp" "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/loadbalance" + "rua.plus/lolly/internal/lua" + "rua.plus/lolly/internal/middleware/accesslog" + "rua.plus/lolly/internal/middleware/security" + "rua.plus/lolly/internal/proxy" + "rua.plus/lolly/internal/ssl" "rua.plus/lolly/internal/version" ) @@ -1343,6 +1352,253 @@ func TestServer_GetProxyCacheStats_WithProxies(t *testing.T) { } } +// TestServer_GetProxyCacheStats_SingleProxyWithCache 测试单个代理带缓存的统计。 +func TestServer_GetProxyCacheStats_SingleProxyWithCache(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + + // 创建带缓存的代理 + proxyCfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + s.proxies = []*proxy.Proxy{p} + + // 获取统计 + stats := s.getProxyCacheStats() + // 新创建的缓存应该有 0 条目 + if stats.Entries < 0 { + t.Errorf("Expected non-negative entries, got %d", stats.Entries) + } + if stats.Pending < 0 { + t.Errorf("Expected non-negative pending, got %d", stats.Pending) + } +} + +// TestServer_GetProxyCacheStats_SingleProxyNoCache 测试单个代理无缓存的统计。 +func TestServer_GetProxyCacheStats_SingleProxyNoCache(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + + // 创建不带缓存的代理 + proxyCfg := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + // Cache 未启用 + } + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + s.proxies = []*proxy.Proxy{p} + + // 获取统计 + stats := s.getProxyCacheStats() + // 无缓存时应返回 0 + if stats.Entries != 0 { + t.Errorf("Expected 0 entries for proxy without cache, got %d", stats.Entries) + } + if stats.Pending != 0 { + t.Errorf("Expected 0 pending for proxy without cache, got %d", stats.Pending) + } +} + +// TestServer_GetProxyCacheStats_MultipleProxies 测试多个代理的缓存统计聚合。 +func TestServer_GetProxyCacheStats_MultipleProxies(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + + // 创建多个代理:部分带缓存,部分不带 + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + + // 代理1:带缓存 + proxyCfg1 := &config.ProxyConfig{ + Path: "/api", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + p1, err := proxy.NewProxy(proxyCfg1, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 代理2:不带缓存 + proxyCfg2 := &config.ProxyConfig{ + Path: "/static", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + p2, err := proxy.NewProxy(proxyCfg2, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + // 代理3:带缓存 + proxyCfg3 := &config.ProxyConfig{ + Path: "/data", + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 20 * time.Second, + }, + } + p3, err := proxy.NewProxy(proxyCfg3, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + + s.proxies = []*proxy.Proxy{p1, p2, p3} + + // 获取聚合统计 + stats := s.getProxyCacheStats() + // 统计应该非负 + if stats.Entries < 0 { + t.Errorf("Expected non-negative entries, got %d", stats.Entries) + } + if stats.Pending < 0 { + t.Errorf("Expected non-negative pending, got %d", stats.Pending) + } +} + +// TestServer_GetProxyCacheStats_AllProxiesWithCache 测试所有代理都有缓存的统计。 +func TestServer_GetProxyCacheStats_AllProxiesWithCache(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + + // 创建多个带缓存的代理 + proxies := make([]*proxy.Proxy, 3) + for i := 0; i < 3; i++ { + proxyCfg := &config.ProxyConfig{ + Path: fmt.Sprintf("/api%d", i), + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + Cache: config.ProxyCacheConfig{ + Enabled: true, + MaxAge: 10 * time.Second, + }, + } + p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + proxies[i] = p + } + + s.proxies = proxies + + // 获取统计 + stats := s.getProxyCacheStats() + // 应该聚合所有代理的统计 + if stats.Entries < 0 { + t.Errorf("Expected non-negative entries, got %d", stats.Entries) + } + if stats.Pending < 0 { + t.Errorf("Expected non-negative pending, got %d", stats.Pending) + } +} + +// TestServer_GetProxyCacheStats_AllProxiesNoCache 测试所有代理都没有缓存的统计。 +func TestServer_GetProxyCacheStats_AllProxiesNoCache(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + + targets := []*loadbalance.Target{{URL: "http://localhost:8080"}} + + // 创建多个不带缓存的代理 + proxies := make([]*proxy.Proxy, 3) + for i := 0; i < 3; i++ { + proxyCfg := &config.ProxyConfig{ + Path: fmt.Sprintf("/api%d", i), + LoadBalance: "round_robin", + Timeout: config.ProxyTimeout{Connect: 5 * time.Second}, + } + p, err := proxy.NewProxy(proxyCfg, targets, nil, nil) + if err != nil { + t.Fatalf("NewProxy() error: %v", err) + } + proxies[i] = p + } + + s.proxies = proxies + + // 获取统计 + stats := s.getProxyCacheStats() + // 所有代理都没有缓存,应该返回 0 + if stats.Entries != 0 { + t.Errorf("Expected 0 entries, got %d", stats.Entries) + } + if stats.Pending != 0 { + t.Errorf("Expected 0 pending, got %d", stats.Pending) + } +} + +// TestServer_GetProxyCacheStats_EmptyProxiesSlice 测试空代理切片的统计。 +func TestServer_GetProxyCacheStats_EmptyProxiesSlice(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + s.proxies = []*proxy.Proxy{} // 空切片 + + // 获取统计 + stats := s.getProxyCacheStats() + if stats.Entries != 0 { + t.Errorf("Expected 0 entries, got %d", stats.Entries) + } + if stats.Pending != 0 { + t.Errorf("Expected 0 pending, got %d", stats.Pending) + } +} + // TestServer_MultipleListeners 测试多个监听器。 func TestServer_MultipleListeners(t *testing.T) { cfg := &config.Config{ @@ -1375,3 +1631,1612 @@ func TestServer_MultipleListeners(t *testing.T) { // 清理 _ = s.StopWithTimeout(1 * time.Second) } + +// TestGracefulStop_RunningState 测试 GracefulStop 设置 running 为 false。 +func TestGracefulStop_RunningState(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + if !s.running { + t.Fatal("running should be true before GracefulStop") + } + + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } + + if s.running { + t.Error("running should be false after GracefulStop") + } +} + +// TestGracefulStop_WithPool 测试 GracefulStop 停止 GoroutinePool。 +func TestGracefulStop_WithPool(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 10, + MinWorkers: 2, + IdleTimeout: 5 * time.Second, + }, + }, + } + + s := New(cfg) + s.running = true + + // 初始化并启动 pool + s.pool = initGoroutinePool(&cfg.Performance) + if s.pool != nil { + s.pool.Start() + } + + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_WithHealthCheckers 测试 GracefulStop 停止健康检查器。 +func TestGracefulStop_WithHealthCheckers(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建 mock healthChecker (使用 nil,因为我们只测试循环不会 panic) + s.healthCheckers = nil + + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_WithAccessLog 测试 GracefulStop 关闭访问日志。 +func TestGracefulStop_WithAccessLog(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建 accessLogMiddleware + s.accessLogMiddleware = accesslog.New(&config.LoggingConfig{}) + + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_WithTLSManager 测试 GracefulStop 关闭 TLS 管理器。 +func TestGracefulStop_WithTLSManager(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建临时证书文件 + tempDir := t.TempDir() + certFile := tempDir + "/cert.pem" + keyFile := tempDir + "/key.pem" + + // 生成自签名证书用于测试 + if err := generateTestCert(certFile, keyFile); err != nil { + t.Skipf("failed to generate test cert: %v", err) + } + + tlsMgr, err := ssl.NewTLSManager(&config.SSLConfig{ + Cert: certFile, + Key: keyFile, + }) + if err != nil { + t.Skipf("failed to create TLS manager: %v", err) + } + s.tlsManager = tlsMgr + + err = s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_WithLuaEngine 测试 GracefulStop 关闭 Lua 引擎。 +func TestGracefulStop_WithLuaEngine(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建 Lua 引擎 + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + err = s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_Timeout 测试 GracefulStop 超时场景。 +func TestGracefulStop_Timeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建一个真实的 fastServer,但通过模拟长时间关闭来测试超时 + s.fastServer = &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + ctx.SetBodyString("test") + }, + } + + // 使用非常短的超时 + err := s.GracefulStop(1 * time.Nanosecond) + // 超时可能返回 context.DeadlineExceeded 或 nil(取决于关闭速度) + if err != nil && err != context.DeadlineExceeded { + t.Errorf("unexpected error: %v", err) + } +} + +// TestGracefulStop_AllComponents 测试 GracefulStop 关闭所有组件。 +func TestGracefulStop_AllComponents(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 10, + IdleTimeout: 5 * time.Second, + }, + }, + } + + s := New(cfg) + s.running = true + + // 初始化所有组件 + s.pool = initGoroutinePool(&cfg.Performance) + if s.pool != nil { + s.pool.Start() + } + s.accessLogMiddleware = accesslog.New(&config.LoggingConfig{}) + + // 创建监听器 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to create listener: %v", err) + } + s.listeners = []net.Listener{ln} + + err = s.GracefulStop(2 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } + + // 验证 running 状态 + if s.running { + t.Error("running should be false after GracefulStop") + } +} + +// generateTestCert 生成测试用的自签名证书。 +func generateTestCert(certFile, keyFile string) error { + // 简化实现:跳过证书生成 + return fmt.Errorf("test cert generation not implemented") +} + +// TestGracefulStop_WithAccessControl 测试 GracefulStop 关闭访问控制。 +func TestGracefulStop_WithAccessControl(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + }, + }, + }}, + } + + s := New(cfg) + s.running = true + + // 创建 AccessControl + ac, err := security.NewAccessControl(&cfg.Servers[0].Security.Access) + if err != nil { + t.Skipf("failed to create AccessControl: %v", err) + } + s.accessControl = ac + + err = s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_ContextCancelled 测试 GracefulStop 上下文取消场景。 +func TestGracefulStop_ContextCancelled(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建一个监听中的服务器 + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to create listener: %v", err) + } + s.listeners = []net.Listener{ln} + + // 创建 fastServer 并开始服务 + s.fastServer = &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + time.Sleep(100 * time.Millisecond) // 模拟慢请求 + ctx.SetBodyString("ok") + }, + } + + // 启动服务器 + go func() { + _ = s.fastServer.Serve(ln) + }() + + // 等待服务器启动 + time.Sleep(10 * time.Millisecond) + + // 使用非常短的超时测试超时场景 + err = s.GracefulStop(1 * time.Nanosecond) + // 超时可能返回 context.DeadlineExceeded 或 nil + if err != nil && err != context.DeadlineExceeded { + t.Errorf("unexpected error: %v", err) + } +} + +// TestGracefulStop_MultipleHealthCheckers 测试 GracefulStop 停止多个健康检查器。 +func TestGracefulStop_MultipleHealthCheckers(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建多个 mock healthChecker + // 注意:这里使用 nil slice 测试空循环不会 panic + s.healthCheckers = make([]*proxy.HealthChecker, 0) + + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_NilComponents 测试 GracefulStop 所有组件为 nil。 +func TestGracefulStop_NilComponents(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 确保所有组件为 nil + s.pool = nil + s.healthCheckers = nil + s.accessLogMiddleware = nil + s.tlsManager = nil + s.accessControl = nil + s.luaEngine = nil + s.fastServer = nil + s.fastServers = nil + + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } + + if s.running { + t.Error("running should be false after GracefulStop") + } +} + +// TestGracefulStop_FastServersWithNil 测试 GracefulStop 处理 fastServers 中的 nil。 +func TestGracefulStop_FastServersWithNil(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + + // 创建包含 nil 的 fastServers + s.fastServers = []*fasthttp.Server{nil, {}, nil} + + err := s.GracefulStop(1 * time.Second) + if err != nil { + t.Errorf("GracefulStop failed: %v", err) + } +} + +// TestGracefulStop_ZeroTimeout 测试 GracefulStop 零超时。 +func TestGracefulStop_ZeroTimeout(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + s.fastServer = &fasthttp.Server{} + + err := s.GracefulStop(0) + // 零超时应该立即返回(可能导致超时错误或成功关闭) + if err != nil && err != context.DeadlineExceeded { + t.Errorf("unexpected error: %v", err) + } +} + +// TestGracefulStop_NegativeTimeout 测试 GracefulStop 负超时。 +func TestGracefulStop_NegativeTimeout(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":0", + }}, + } + + s := New(cfg) + s.running = true + s.fastServer = &fasthttp.Server{} + + err := s.GracefulStop(-1 * time.Second) + // 负超时应该立即返回 + if err != nil && err != context.DeadlineExceeded { + t.Errorf("unexpected error: %v", err) + } +} + +// TestStartSingleMode_StaticFiles 测试 startSingleMode 静态文件配置。 +func TestStartSingleMode_StaticFiles(t *testing.T) { + tempDir := t.TempDir() + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }, + { + Path: "/assets", + Root: tempDir, + LocationType: "exact", + SymlinkCheck: true, + Internal: true, + TryFiles: []string{"$uri", "/fallback.html"}, + TryFilesPass: true, + }, + }, + }}, + } + + s := New(cfg) + // 验证静态文件配置已正确设置 + if len(s.config.Servers[0].Static) != 2 { + t.Errorf("expected 2 static configs, got %d", len(s.config.Servers[0].Static)) + } + + // 验证第一个静态配置 + static1 := s.config.Servers[0].Static[0] + if static1.Path != "/static" { + t.Errorf("expected path /static, got %s", static1.Path) + } + if static1.Root != tempDir { + t.Errorf("expected root %s, got %s", tempDir, static1.Root) + } +} + +// TestStartSingleMode_StaticFilesWithGzipStatic 测试静态文件 gzip 预压缩配置。 +func TestStartSingleMode_StaticFilesWithGzipStatic(t *testing.T) { + tempDir := t.TempDir() + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, + GzipStatic: true, + GzipStaticExtensions: []string{".html", ".css", ".js"}, + }, + }}, + } + + s := New(cfg) + // 验证 gzip 静态配置 + if !s.config.Servers[0].Compression.GzipStatic { + t.Error("expected GzipStatic to be true") + } + if len(s.config.Servers[0].Compression.GzipStaticExtensions) != 3 { + t.Errorf("expected 3 extensions, got %d", len(s.config.Servers[0].Compression.GzipStaticExtensions)) + } +} + +// TestStartSingleMode_ProxyWithLocationTypes 测试代理配置的不同位置类型。 +func TestStartSingleMode_ProxyWithLocationTypes(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api/exact", + LocationType: "exact", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8081", Weight: 1}, + }, + }, + { + Path: "/api/priority", + LocationType: "prefix_priority", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8082", Weight: 1}, + }, + }, + { + Path: "^/api/regex/(.*)$", + LocationType: "regex", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8083", Weight: 1}, + }, + }, + { + Path: "^/api/caseless/(.*)$", + LocationType: "regex_caseless", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8084", Weight: 1}, + }, + }, + { + Path: "/api/named", + LocationType: "named", + LocationName: "@api_named", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8085", Weight: 1}, + }, + }, + { + Path: "/api/default", + // 默认 prefix 类型 + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8086", Weight: 1}, + }, + Internal: true, + }, + }, + }}, + } + + s := New(cfg) + // 验证代理配置数量 + if len(s.config.Servers[0].Proxy) != 6 { + t.Errorf("expected 6 proxy configs, got %d", len(s.config.Servers[0].Proxy)) + } + + // 验证不同位置类型 + proxyTypes := []string{"exact", "prefix_priority", "regex", "regex_caseless", "named", ""} + for i, pt := range proxyTypes { + if s.config.Servers[0].Proxy[i].LocationType != pt { + t.Errorf("proxy[%d]: expected location type %s, got %s", i, pt, s.config.Servers[0].Proxy[i].LocationType) + } + } +} + +// TestStartSingleMode_ProxyWithHealthCheck 测试代理健康检查配置。 +func TestStartSingleMode_ProxyWithHealthCheck(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + { + URL: "http://127.0.0.1:8081", + Weight: 3, + MaxFails: 3, + FailTimeout: 10 * time.Second, + MaxConns: 100, + Backup: false, + Down: false, + }, + { + URL: "http://127.0.0.1:8082", + Weight: 1, + Backup: true, + }, + }, + LoadBalance: "weighted_round_robin", + HealthCheck: config.HealthCheckConfig{ + Interval: 10 * time.Second, + Timeout: 5 * time.Second, + Path: "/health", + }, + }, + }, + }}, + } + + s := New(cfg) + // 验证健康检查配置 + hc := s.config.Servers[0].Proxy[0].HealthCheck + if hc.Interval != 10*time.Second { + t.Errorf("expected interval 10s, got %v", hc.Interval) + } + if hc.Path != "/health" { + t.Errorf("expected path /health, got %s", hc.Path) + } +} + +// TestStartSingleMode_MonitoringEndpoints 测试监控端点配置。 +func TestStartSingleMode_MonitoringEndpoints(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + }}, + Monitoring: config.MonitoringConfig{ + Status: config.StatusConfig{ + Enabled: true, + Path: "/_status", + Format: "json", + Allow: []string{"127.0.0.1", "192.168.0.0/16"}, + }, + Pprof: config.PprofConfig{ + Enabled: true, + Path: "/debug/pprof", + Allow: []string{"127.0.0.1"}, + }, + }, + } + + s := New(cfg) + // 验证状态端点配置 + if !s.config.Monitoring.Status.Enabled { + t.Error("expected status enabled") + } + if s.config.Monitoring.Status.Path != "/_status" { + t.Errorf("expected status path /_status, got %s", s.config.Monitoring.Status.Path) + } + if len(s.config.Monitoring.Status.Allow) != 2 { + t.Errorf("expected 2 allowed IPs, got %d", len(s.config.Monitoring.Status.Allow)) + } + + // 验证 pprof 配置 + if !s.config.Monitoring.Pprof.Enabled { + t.Error("expected pprof enabled") + } +} + +// TestStartSingleMode_CacheAPI 测试缓存 API 配置。 +func TestStartSingleMode_CacheAPI(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + CacheAPI: &config.CacheAPIConfig{ + Enabled: true, + Path: "/_cache/purge", + Allow: []string{"127.0.0.1"}, + Auth: config.CacheAPIAuthConfig{Type: "token", Token: "secret-token"}, + }, + }}, + } + + s := New(cfg) + // 验证缓存 API 配置 + if s.config.Servers[0].CacheAPI == nil || !s.config.Servers[0].CacheAPI.Enabled { + t.Error("expected cache API enabled") + } + if s.config.Servers[0].CacheAPI.Path != "/_cache/purge" { + t.Errorf("expected path /_cache/purge, got %s", s.config.Servers[0].CacheAPI.Path) + } +} + +// TestStartSingleMode_TLSConfig 测试 TLS 配置。 +func TestStartSingleMode_TLSConfig(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + SSL: config.SSLConfig{ + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + Protocols: []string{"TLSv1.2", "TLSv1.3"}, + Ciphers: []string{"TLS_AES_128_GCM_SHA256"}, + HSTS: config.HSTSConfig{ + MaxAge: 31536000, + IncludeSubDomains: true, + Preload: true, + }, + }, + }}, + } + + s := New(cfg) + // 验证 SSL 配置 + if s.config.Servers[0].SSL.Cert != "/path/to/cert.pem" { + t.Errorf("expected cert path, got %s", s.config.Servers[0].SSL.Cert) + } + if s.config.Servers[0].SSL.HSTS.MaxAge != 31536000 { + t.Errorf("expected HSTS MaxAge 31536000, got %d", s.config.Servers[0].SSL.HSTS.MaxAge) + } +} + +// TestStartSingleMode_MIMETypes 测试 MIME 类型配置。 +func TestStartSingleMode_MIMETypes(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Types: config.TypesConfig{ + Map: map[string]string{ + ".wasm": "application/wasm", + ".custom": "application/x-custom", + }, + DefaultType: "application/octet-stream", + }, + }}, + } + + s := New(cfg) + // 验证 MIME 类型配置 + if len(s.config.Servers[0].Types.Map) != 2 { + t.Errorf("expected 2 MIME types, got %d", len(s.config.Servers[0].Types.Map)) + } + if s.config.Servers[0].Types.DefaultType != "application/octet-stream" { + t.Errorf("expected default type, got %s", s.config.Servers[0].Types.DefaultType) + } +} + +// TestStartSingleMode_ServerOptions 测试服务器选项配置。 +func TestStartSingleMode_ServerOptions(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + MaxConnsPerIP: 100, + MaxRequestsPerConn: 1000, + Concurrency: 256 * 1024, + ReadBufferSize: 16 * 1024, + WriteBufferSize: 16 * 1024, + ReduceMemoryUsage: true, + ServerTokens: false, + }}, + } + + s := New(cfg) + // 验证服务器选项 + sc := s.config.Servers[0] + if sc.ReadTimeout != 30*time.Second { + t.Errorf("expected ReadTimeout 30s, got %v", sc.ReadTimeout) + } + if sc.MaxConnsPerIP != 100 { + t.Errorf("expected MaxConnsPerIP 100, got %d", sc.MaxConnsPerIP) + } + if !sc.ReduceMemoryUsage { + t.Error("expected ReduceMemoryUsage true") + } + if sc.ServerTokens { + t.Error("expected ServerTokens false") + } +} + +// TestStartSingleMode_WithMiddlewareChain 测试中间件链配置。 +func TestStartSingleMode_WithMiddlewareChain(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + Deny: []string{"10.0.0.0/8"}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + Key: "remote_addr", + }, + Auth: config.AuthConfig{ + Users: []config.User{ + {Name: "admin", Password: "secret"}, + }, + }, + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + XContentTypeOptions: "nosniff", + ContentSecurityPolicy: "default-src 'self'", + ReferrerPolicy: "strict-origin-when-cross-origin", + }, + }, + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, + }, + Rewrite: []config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1"}, + }, + }}, + } + + s := New(cfg) + // 验证中间件配置 + security := s.config.Servers[0].Security + if len(security.Access.Allow) != 1 { + t.Errorf("expected 1 allow rule, got %d", len(security.Access.Allow)) + } + if security.RateLimit.RequestRate != 100 { + t.Errorf("expected request rate 100, got %d", security.RateLimit.RequestRate) + } + if len(security.Auth.Users) != 1 { + t.Errorf("expected 1 auth user, got %d", len(security.Auth.Users)) + } +} + +// TestStartSingleMode_PerformanceConfig 测试性能配置。 +func TestStartSingleMode_PerformanceConfig(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + }}, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 100, + MinWorkers: 10, + IdleTimeout: 30 * time.Second, + }, + FileCache: config.FileCacheConfig{ + MaxEntries: 10000, + MaxSize: 100 * 1024 * 1024, + }, + }, + } + + s := New(cfg) + // 验证性能配置 + if !s.config.Performance.GoroutinePool.Enabled { + t.Error("expected goroutine pool enabled") + } + if s.config.Performance.FileCache.MaxEntries != 10000 { + t.Errorf("expected 10000 max entries, got %d", s.config.Performance.FileCache.MaxEntries) + } +} + +// TestStartSingleMode_WithLuaMiddleware 测试 Lua 中间件配置。 +func TestStartSingleMode_WithLuaMiddleware(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Lua: &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + { + Path: "/scripts/access.lua", + Phase: "access", + Timeout: 30 * time.Second, + }, + { + Path: "/scripts/header.lua", + Phase: "header_filter", + Timeout: 10 * time.Second, + }, + }, + }, + }}, + } + + s := New(cfg) + // 验证 Lua 配置 + if s.config.Servers[0].Lua == nil || !s.config.Servers[0].Lua.Enabled { + t.Error("expected Lua enabled") + } + if len(s.config.Servers[0].Lua.Scripts) != 2 { + t.Errorf("expected 2 scripts, got %d", len(s.config.Servers[0].Lua.Scripts)) + } +} + +// TestStartSingleMode_WithErrorPage 测试错误页面配置。 +func TestStartSingleMode_WithErrorPage(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + ErrorPage: config.ErrorPageConfig{ + Pages: map[int]string{ + 404: "/errors/404.html", + 500: "/errors/500.html", + 502: "/errors/502.html", + }, + Default: "/errors/default.html", + }, + }, + }}, + } + + s := New(cfg) + // 验证错误页面配置 + ep := s.config.Servers[0].Security.ErrorPage + if len(ep.Pages) != 3 { + t.Errorf("expected 3 error pages, got %d", len(ep.Pages)) + } + if ep.Default != "/errors/default.html" { + t.Errorf("expected default error page, got %s", ep.Default) + } +} + +// TestStartSingleMode_WithConnLimiter 测试连接限制配置。 +func TestStartSingleMode_WithConnLimiter(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + RateLimit: config.RateLimitConfig{ + ConnLimit: 100, + Key: "remote_addr", + }, + }, + }}, + } + + s := New(cfg) + // 验证连接限制配置 + if s.config.Servers[0].Security.RateLimit.ConnLimit != 100 { + t.Errorf("expected ConnLimit 100, got %d", s.config.Servers[0].Security.RateLimit.ConnLimit) + } +} + +// TestStartSingleMode_WithAuthRequest 测试外部认证配置。 +func TestStartSingleMode_WithAuthRequest(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + AuthRequest: config.AuthRequestConfig{ + Enabled: true, + URI: "/auth/validate", + Timeout: 5 * time.Second, + }, + }, + }}, + } + + s := New(cfg) + // 验证外部认证配置 + ar := s.config.Servers[0].Security.AuthRequest + if !ar.Enabled { + t.Error("expected AuthRequest enabled") + } + if ar.URI != "/auth/validate" { + t.Errorf("expected URI /auth/validate, got %s", ar.URI) + } +} + +// TestShutdownServers_EmptySlice 测试空服务器列表。 +func TestShutdownServers_EmptySlice(t *testing.T) { + ctx := context.Background() + err := shutdownServers(ctx, []*fasthttp.Server{}) + if err != nil { + t.Errorf("shutdownServers with empty slice should return nil, got: %v", err) + } +} + +// TestShutdownServers_NilSlice 测试 nil 服务器列表。 +func TestShutdownServers_NilSlice(t *testing.T) { + ctx := context.Background() + err := shutdownServers(ctx, nil) + if err != nil { + t.Errorf("shutdownServers with nil slice should return nil, got: %v", err) + } +} + +// TestShutdownServers_NilContext 测试 nil 上下文。 +func TestShutdownServers_NilContext(t *testing.T) { + // nil ctx 应该使用 context.Background() + err := shutdownServers(nil, []*fasthttp.Server{}) + if err != nil { + t.Errorf("shutdownServers with nil ctx should return nil, got: %v", err) + } +} + +// TestShutdownServers_SingleServer 测试单个服务器关闭。 +func TestShutdownServers_SingleServer(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + servers := []*fasthttp.Server{ + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, + } + + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers failed: %v", err) + } +} + +// TestShutdownServers_MultipleServers 测试多个服务器关闭。 +func TestShutdownServers_MultipleServers(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + servers := []*fasthttp.Server{ + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test1") }}, + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test2") }}, + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test3") }}, + } + + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers failed: %v", err) + } +} + +// TestShutdownServers_WithNilServers 测试服务器列表中包含 nil。 +func TestShutdownServers_WithNilServers(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + servers := []*fasthttp.Server{ + nil, + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, + nil, + } + + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers failed: %v", err) + } +} + +// TestShutdownServers_AllNilServers 测试所有服务器都是 nil。 +func TestShutdownServers_AllNilServers(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + servers := []*fasthttp.Server{nil, nil, nil} + + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers with all nil servers should return nil, got: %v", err) + } +} + +// TestShutdownServers_ContextCancelled 测试上下文取消。 +func TestShutdownServers_ContextCancelled(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + // 创建一个已取消的上下文 + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + servers := []*fasthttp.Server{ + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, + } + + err := shutdownServers(ctx, servers) + // 已取消的上下文可能返回 context.Canceled 或 nil(取决于服务器关闭速度) + if err != nil && err != context.Canceled { + t.Errorf("unexpected error: %v", err) + } +} + +// TestShutdownServers_ContextTimeout 测试上下文超时。 +func TestShutdownServers_ContextTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + // 创建一个极短超时的上下文 + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + // 等待超时 + time.Sleep(1 * time.Millisecond) + + servers := []*fasthttp.Server{ + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, + } + + err := shutdownServers(ctx, servers) + // 超时的上下文可能返回 context.DeadlineExceeded 或 nil + if err != nil && err != context.DeadlineExceeded { + t.Errorf("unexpected error: %v", err) + } +} + +// TestShutdownServers_RunningServers 测试关闭运行中的服务器。 +func TestShutdownServers_RunningServers(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 创建服务器并启动 + servers := make([]*fasthttp.Server, 2) + listeners := make([]net.Listener, 2) + + for i := 0; i < 2; i++ { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to create listener: %v", err) + } + listeners[i] = ln + + srv := &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + ctx.SetBodyString("test") + }, + } + servers[i] = srv + + go func(s *fasthttp.Server, l net.Listener) { + _ = s.Serve(l) + }(srv, ln) + } + + // 等待服务器启动 + time.Sleep(10 * time.Millisecond) + + // 关闭服务器 + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers failed: %v", err) + } + + // 关闭监听器(如果服务器没有关闭它们) + for _, ln := range listeners { + _ = ln.Close() + } +} + +// TestShutdownServers_ManyServers 测试关闭大量服务器。 +func TestShutdownServers_ManyServers(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // 创建大量服务器 + count := 50 + servers := make([]*fasthttp.Server, count) + for i := 0; i < count; i++ { + servers[i] = &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }, + } + } + + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers with many servers failed: %v", err) + } +} + +// TestShutdownServers_MixedNilAndRealServers 测试混合 nil 和真实服务器。 +func TestShutdownServers_MixedNilAndRealServers(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + count := 20 + servers := make([]*fasthttp.Server, count) + for i := 0; i < count; i++ { + if i%2 == 0 { + servers[i] = nil + } else { + servers[i] = &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }, + } + } + } + + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers failed: %v", err) + } +} + +// TestShutdownServers_ConcurrentSafety 测试并发安全性。 +func TestShutdownServers_ConcurrentSafety(t *testing.T) { + ctx := context.Background() + + // 并发调用 shutdownServers + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + servers := []*fasthttp.Server{ + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, + } + _ = shutdownServers(ctx, servers) + }() + } + wg.Wait() +} + +// TestShutdownServers_WithDeadline 测试带截止时间的上下文。 +func TestShutdownServers_WithDeadline(t *testing.T) { + deadline := time.Now().Add(5 * time.Second) + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + servers := []*fasthttp.Server{ + {Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("test") }}, + } + + err := shutdownServers(ctx, servers) + if err != nil { + t.Errorf("shutdownServers failed: %v", err) + } +} + +// TestBuildLuaMiddlewares_SingleScript 测试单个脚本配置。 +func TestBuildLuaMiddlewares_SingleScript(t *testing.T) { + // 创建临时 Lua 脚本 + tempDir := t.TempDir() + scriptPath := tempDir + "/test.lua" + if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: scriptPath, Phase: "access", Timeout: 10 * time.Second, Enabled: true}, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if len(middlewares) != 1 { + t.Errorf("expected 1 middleware, got: %d", len(middlewares)) + } +} + +// TestBuildLuaMiddlewares_SingleScriptDefaultTimeout 测试单脚本默认超时。 +func TestBuildLuaMiddlewares_SingleScriptDefaultTimeout(t *testing.T) { + tempDir := t.TempDir() + scriptPath := tempDir + "/test.lua" + if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: scriptPath, Phase: "content", Timeout: 0}, // 使用默认超时 + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if len(middlewares) != 1 { + t.Errorf("expected 1 middleware, got: %d", len(middlewares)) + } +} + +// TestBuildLuaMiddlewares_MultipleScriptsSamePhase 测试多脚本同阶段。 +func TestBuildLuaMiddlewares_MultipleScriptsSamePhase(t *testing.T) { + tempDir := t.TempDir() + script1 := tempDir + "/test1.lua" + script2 := tempDir + "/test2.lua" + if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: script1, Phase: "access", Timeout: 10 * time.Second, Enabled: true}, + {Path: script2, Phase: "access", Timeout: 20 * time.Second, Enabled: true}, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if len(middlewares) != 1 { + t.Errorf("expected 1 middleware (multi-phase), got: %d", len(middlewares)) + } +} + +// TestBuildLuaMiddlewares_MultipleScriptsDifferentPhases 测试多脚本不同阶段。 +func TestBuildLuaMiddlewares_MultipleScriptsDifferentPhases(t *testing.T) { + tempDir := t.TempDir() + script1 := tempDir + "/rewrite.lua" + script2 := tempDir + "/access.lua" + script3 := tempDir + "/log.lua" + for _, p := range []string{script1, script2, script3} { + if err := os.WriteFile(p, []byte("ngx.say('hello')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: script1, Phase: "rewrite", Timeout: 10 * time.Second, Enabled: true}, + {Path: script2, Phase: "access", Timeout: 15 * time.Second, Enabled: true}, + {Path: script3, Phase: "log", Timeout: 20 * time.Second, Enabled: true}, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if len(middlewares) != 3 { + t.Errorf("expected 3 middlewares, got: %d", len(middlewares)) + } +} + +// TestBuildLuaMiddlewares_DefaultEnabled 测试默认启用逻辑。 +func TestBuildLuaMiddlewares_DefaultEnabled(t *testing.T) { + tempDir := t.TempDir() + scriptPath := tempDir + "/test.lua" + if err := os.WriteFile(scriptPath, []byte("ngx.say('hello')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + // Enabled 为 false,但 Timeout=0 且 Path 不为空,应该默认启用 + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: scriptPath, Phase: "access", Timeout: 0, Enabled: false}, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + // 默认启用逻辑:Enabled=false && Timeout=0 && Path!="" -> enabled=true + if len(middlewares) != 1 { + t.Errorf("expected 1 middleware (default enabled), got: %d", len(middlewares)) + } +} + +// TestBuildLuaMiddlewares_InvalidPhaseInMultiScript 测试多脚本中的无效阶段。 +func TestBuildLuaMiddlewares_InvalidPhaseInMultiScript(t *testing.T) { + tempDir := t.TempDir() + script1 := tempDir + "/test1.lua" + script2 := tempDir + "/test2.lua" + if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: script1, Phase: "access", Timeout: 10 * time.Second, Enabled: true}, + {Path: script2, Phase: "invalid_phase", Timeout: 10 * time.Second, Enabled: true}, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err == nil { + t.Error("expected error for invalid phase in multi-script") + } + if middlewares != nil { + t.Errorf("expected nil middlewares on error, got: %v", middlewares) + } +} + +// TestBuildLuaMiddlewares_AllPhases 测试所有阶段。 +func TestBuildLuaMiddlewares_AllPhases(t *testing.T) { + tempDir := t.TempDir() + phases := []string{"rewrite", "access", "content", "log", "header_filter", "body_filter"} + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + scripts := make([]config.LuaScriptConfig, len(phases)) + for i, phase := range phases { + scriptPath := tempDir + "/" + phase + ".lua" + if err := os.WriteFile(scriptPath, []byte("ngx.say('"+phase+"')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + scripts[i] = config.LuaScriptConfig{Path: scriptPath, Phase: phase, Timeout: 10 * time.Second, Enabled: true} + } + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: scripts, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if len(middlewares) != len(phases) { + t.Errorf("expected %d middlewares, got: %d", len(phases), len(middlewares)) + } +} + +// TestBuildLuaMiddlewares_NonExistentScript 测试不存在的脚本文件。 +func TestBuildLuaMiddlewares_NonExistentScript(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: "/non/existent/script.lua", Phase: "access", Timeout: 10 * time.Second}, + }, + } + + // NewLuaMiddleware 会在创建时验证脚本文件 + middlewares, err := s.buildLuaMiddlewares(luaCfg) + // 由于脚本不存在,可能会返回错误或创建失败 + // 这取决于 lua.NewLuaMiddleware 的实现 + _ = middlewares + _ = err +} + +// TestBuildLuaMiddlewares_MixedEnabledDisabled 测试混合启用禁用脚本。 +func TestBuildLuaMiddlewares_MixedEnabledDisabled(t *testing.T) { + tempDir := t.TempDir() + for _, name := range []string{"enabled1", "enabled2", "disabled1", "disabled2"} { + scriptPath := tempDir + "/" + name + ".lua" + if err := os.WriteFile(scriptPath, []byte("ngx.say('"+name+"')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: tempDir + "/enabled1.lua", Phase: "rewrite", Timeout: 10 * time.Second, Enabled: true}, + {Path: tempDir + "/disabled1.lua", Phase: "rewrite", Timeout: 10 * time.Second, Enabled: false}, + {Path: tempDir + "/enabled2.lua", Phase: "access", Timeout: 10 * time.Second, Enabled: true}, + {Path: tempDir + "/disabled2.lua", Phase: "access", Timeout: 10 * time.Second, Enabled: false}, + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + // 只有启用的脚本应该被处理:rewrite(1) + access(1) = 2 + if len(middlewares) != 2 { + t.Errorf("expected 2 middlewares, got: %d", len(middlewares)) + } +} + +// TestBuildLuaMiddlewares_MultiPhaseDefaultTimeout 测试多脚本阶段默认超时。 +func TestBuildLuaMiddlewares_MultiPhaseDefaultTimeout(t *testing.T) { + tempDir := t.TempDir() + script1 := tempDir + "/test1.lua" + script2 := tempDir + "/test2.lua" + if err := os.WriteFile(script1, []byte("ngx.say('1')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + if err := os.WriteFile(script2, []byte("ngx.say('2')"), 0o644); err != nil { + t.Fatalf("failed to create script: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: ":8080", + }}, + } + + s := New(cfg) + luaEngine, err := lua.NewEngine(lua.DefaultConfig()) + if err != nil { + t.Skipf("failed to create Lua engine: %v", err) + } + s.luaEngine = luaEngine + + luaCfg := &config.LuaMiddlewareConfig{ + Enabled: true, + Scripts: []config.LuaScriptConfig{ + {Path: script1, Phase: "access", Timeout: 0}, // 默认超时 + {Path: script2, Phase: "access", Timeout: 0}, // 默认超时 + }, + } + + middlewares, err := s.buildLuaMiddlewares(luaCfg) + if err != nil { + t.Errorf("expected nil error, got: %v", err) + } + if len(middlewares) != 1 { + t.Errorf("expected 1 middleware (multi-phase), got: %d", len(middlewares)) + } +} diff --git a/internal/server/startmultiservermode_test.go b/internal/server/startmultiservermode_test.go new file mode 100644 index 0000000..98db156 --- /dev/null +++ b/internal/server/startmultiservermode_test.go @@ -0,0 +1,1216 @@ +// Package server 提供 startMultiServerMode 集成测试。 +// +// 该文件测试 startMultiServerMode 函数的各种配置场景, +// 包括多服务器配置、监听器创建、服务器启动等场景。 +// +// 作者:xfy +package server + +import ( + "os" + "strings" + "testing" + "time" + + "rua.plus/lolly/internal/config" +) + +// TestStartMultiServerMode_BasicConfig 测试基本的多服务器配置。 +func TestStartMultiServerMode_BasicConfig(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + }, + } + + s := New(cfg) + + // 验证多服务器配置 + if len(s.config.Servers) != 2 { + t.Errorf("expected 2 servers, got %d", len(s.config.Servers)) + } +} + +// TestStartMultiServerMode_ThreeServers 测试三个服务器配置。 +func TestStartMultiServerMode_ThreeServers(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + }, + } + + s := New(cfg) + + if len(s.config.Servers) != 3 { + t.Errorf("expected 3 servers, got %d", len(s.config.Servers)) + } +} + +// TestStartMultiServerMode_WithProxy 测试带代理的多服务器配置。 +func TestStartMultiServerMode_WithProxy(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8081", Weight: 1}, + }, + }, + }, + }, + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:8082", Weight: 1}, + }, + }, + }, + }, + }, + } + + s := New(cfg) + + if len(s.config.Servers[0].Proxy) != 1 { + t.Errorf("expected 1 proxy config for server 0, got %d", len(s.config.Servers[0].Proxy)) + } + if len(s.config.Servers[1].Proxy) != 1 { + t.Errorf("expected 1 proxy config for server 1, got %d", len(s.config.Servers[1].Proxy)) + } +} + +// TestStartMultiServerMode_WithStaticFiles 测试带静态文件的多服务器配置。 +func TestStartMultiServerMode_WithStaticFiles(t *testing.T) { + tempDir := t.TempDir() + + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + }, + { + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/assets", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + }, + }, + } + + s := New(cfg) + + if len(s.config.Servers[0].Static) != 1 { + t.Errorf("expected 1 static config for server 0, got %d", len(s.config.Servers[0].Static)) + } + if len(s.config.Servers[1].Static) != 1 { + t.Errorf("expected 1 static config for server 1, got %d", len(s.config.Servers[1].Static)) + } +} + +// TestStartMultiServerMode_WithCacheAPI 测试带缓存 API 的多服务器配置。 +func TestStartMultiServerMode_WithCacheAPI(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + CacheAPI: &config.CacheAPIConfig{ + Enabled: true, + Path: "/_cache/purge", + Allow: []string{"127.0.0.1"}, + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + if s.config.Servers[0].CacheAPI == nil || !s.config.Servers[0].CacheAPI.Enabled { + t.Error("expected cache API enabled on server 0") + } +} + +// TestStartMultiServerMode_WithMiddleware 测试带中间件的多服务器配置。 +func TestStartMultiServerMode_WithMiddleware(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + }, + }, + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + }, + }, + }, + }, + } + + s := New(cfg) + + if len(s.config.Servers[0].Security.Access.Allow) != 1 { + t.Errorf("expected 1 allow rule for server 0, got %d", len(s.config.Servers[0].Security.Access.Allow)) + } + if s.config.Servers[1].Security.Headers.XFrameOptions != "DENY" { + t.Errorf("expected XFrameOptions DENY for server 1, got %s", s.config.Servers[1].Security.Headers.XFrameOptions) + } +} + +// TestStartMultiServerMode_WithCompression 测试带压缩配置的多服务器配置。 +func TestStartMultiServerMode_WithCompression(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, + }, + }, + { + Listen: "127.0.0.1:0", + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 9, + }, + }, + }, + } + + s := New(cfg) + + if s.config.Servers[0].Compression.Level != 6 { + t.Errorf("expected compression level 6 for server 0, got %d", s.config.Servers[0].Compression.Level) + } + if s.config.Servers[1].Compression.Level != 9 { + t.Errorf("expected compression level 9 for server 1, got %d", s.config.Servers[1].Compression.Level) + } +} + +// TestStartMultiServerMode_ServerOptions 测试服务器选项配置。 +func TestStartMultiServerMode_ServerOptions(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + MaxConnsPerIP: 100, + MaxRequestsPerConn: 1000, + }, + { + Listen: "127.0.0.1:0", + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + MaxConnsPerIP: 50, + MaxRequestsPerConn: 500, + }, + }, + } + + s := New(cfg) + + if s.config.Servers[0].ReadTimeout != 30*time.Second { + t.Errorf("expected ReadTimeout 30s for server 0, got %v", s.config.Servers[0].ReadTimeout) + } + if s.config.Servers[1].MaxConnsPerIP != 50 { + t.Errorf("expected MaxConnsPerIP 50 for server 1, got %d", s.config.Servers[1].MaxConnsPerIP) + } +} + +// TestStartMultiServerMode_Integration_Basic 测试多服务器模式基本启动。 +func TestStartMultiServerMode_Integration_Basic(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithProxy 测试多服务器模式带代理启动。 +func TestStartMultiServerMode_Integration_WithProxy(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:9999", Weight: 1}, + }, + }, + }, + }, + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:9998", Weight: 1}, + }, + }, + }, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithStaticFiles 测试多服务器模式带静态文件启动。 +func TestStartMultiServerMode_Integration_WithStaticFiles(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + tempDir := t.TempDir() + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + }, + { + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/assets", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithCacheAPI 测试多服务器模式带缓存 API 启动。 +func TestStartMultiServerMode_Integration_WithCacheAPI(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + CacheAPI: &config.CacheAPIConfig{ + Enabled: true, + Path: "/_cache/purge", + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithHealthCheck 测试多服务器模式带健康检查启动。 +func TestStartMultiServerMode_Integration_WithHealthCheck(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:9999", Weight: 1}, + }, + HealthCheck: config.HealthCheckConfig{ + Interval: 1 * time.Second, + Timeout: 500 * time.Millisecond, + Path: "/health", + }, + }, + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(100 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithMiddleware 测试多服务器模式带中间件启动。 +func TestStartMultiServerMode_Integration_WithMiddleware(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + }, + }, + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + }, + }, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithPerformance 测试多服务器模式带性能配置启动。 +func TestStartMultiServerMode_Integration_WithPerformance(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + }, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 50, + MinWorkers: 5, + IdleTimeout: 10 * time.Second, + }, + FileCache: config.FileCacheConfig{ + MaxEntries: 1000, + MaxSize: 10 * 1024 * 1024, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_ThreeServers 测试三服务器模式启动。 +func TestStartMultiServerMode_Integration_ThreeServers(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithCompression 测试多服务器模式带压缩启动。 +func TestStartMultiServerMode_Integration_WithCompression(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, + }, + }, + { + Listen: "127.0.0.1:0", + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 9, + }, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithRewrite 测试多服务器模式带重写启动。 +func TestStartMultiServerMode_Integration_WithRewrite(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Rewrite: []config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1"}, + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithConnLimiter 测试多服务器模式带连接限制启动。 +func TestStartMultiServerMode_Integration_WithConnLimiter(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + RateLimit: config.RateLimitConfig{ + ConnLimit: 10, + }, + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_MixedConfigs 测试多服务器模式混合配置启动。 +func TestStartMultiServerMode_Integration_MixedConfigs(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + tempDir := t.TempDir() + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:9999", Weight: 1}, + }, + }, + }, + CacheAPI: &config.CacheAPIConfig{ + Enabled: true, + Path: "/_cache/purge", + }, + }, + { + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, + }, + }, + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + }, + }, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_GracefulUpgradeFallback 测试热升级模式回退。 +func TestStartMultiServerMode_GracefulUpgradeFallback(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + // 设置热升级环境变量 + originalValue := os.Getenv("GRACEFUL_UPGRADE") + defer os.Setenv("GRACEFUL_UPGRADE", originalValue) + + os.Setenv("GRACEFUL_UPGRADE", "1") + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + {Listen: "127.0.0.1:0"}, + {Listen: "127.0.0.1:0"}, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_WithUnixSocket 测试 Unix Socket 配置。 +func TestStartMultiServerMode_WithUnixSocket(t *testing.T) { + tempDir := t.TempDir() + + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "unix:" + tempDir + "/test1.sock", + }, + { + Listen: "unix:" + tempDir + "/test2.sock", + }, + }, + } + + s := New(cfg) + + if !strings.HasPrefix(s.config.Servers[0].Listen, "unix:") { + t.Errorf("expected unix socket for server 0, got %s", s.config.Servers[0].Listen) + } + if !strings.HasPrefix(s.config.Servers[1].Listen, "unix:") { + t.Errorf("expected unix socket for server 1, got %s", s.config.Servers[1].Listen) + } +} + +// TestStartMultiServerMode_WithDifferentListens 测试不同监听地址配置。 +func TestStartMultiServerMode_WithDifferentListens(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + {Listen: "127.0.0.1:0"}, + {Listen: "0.0.0.0:0"}, + }, + } + + s := New(cfg) + + if s.config.Servers[0].Listen != "127.0.0.1:0" { + t.Errorf("expected listen 127.0.0.1:0 for server 0, got %s", s.config.Servers[0].Listen) + } + if s.config.Servers[1].Listen != "0.0.0.0:0" { + t.Errorf("expected listen 0.0.0.0:0 for server 1, got %s", s.config.Servers[1].Listen) + } +} + +// TestStartMultiServerMode_Integration_WithErrorPage 测试多服务器模式带错误页面启动。 +func TestStartMultiServerMode_Integration_WithErrorPage(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + tempDir := t.TempDir() + errorPage := tempDir + "/404.html" + if err := os.WriteFile(errorPage, []byte("Not Found"), 0o644); err != nil { + t.Fatalf("failed to create error page: %v", err) + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + ErrorPage: config.ErrorPageConfig{ + Pages: map[int]string{404: errorPage}, + Default: errorPage, + }, + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_Integration_WithMIMETypes 测试多服务器模式带 MIME 类型启动。 +func TestStartMultiServerMode_Integration_WithMIMETypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Types: config.TypesConfig{ + Map: map[string]string{ + ".wasm": "application/wasm", + }, + DefaultType: "application/octet-stream", + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_WithServerNames 测试带服务器名称的多服务器配置。 +func TestStartMultiServerMode_WithServerNames(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Name: "server1", + Listen: "127.0.0.1:0", + ServerNames: []string{"example.com", "www.example.com"}, + }, + { + Name: "server2", + Listen: "127.0.0.1:0", + ServerNames: []string{"api.example.com"}, + }, + }, + } + + s := New(cfg) + + if s.config.Servers[0].Name != "server1" { + t.Errorf("expected name server1, got %s", s.config.Servers[0].Name) + } + if len(s.config.Servers[0].ServerNames) != 2 { + t.Errorf("expected 2 server names for server 0, got %d", len(s.config.Servers[0].ServerNames)) + } + if s.config.Servers[1].Name != "server2" { + t.Errorf("expected name server2, got %s", s.config.Servers[1].Name) + } +} + +// TestStartMultiServerMode_WithProxyLocationTypes 测试代理不同位置类型配置。 +func TestStartMultiServerMode_WithProxyLocationTypes(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api/exact", + LocationType: "exact", + Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}}, + }, + { + Path: "/api/priority", + LocationType: "prefix_priority", + Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}}, + }, + }, + }, + { + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "^/api/regex/(.*)$", + LocationType: "regex", + Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}}, + }, + }, + }, + }, + } + + s := New(cfg) + + if s.config.Servers[0].Proxy[0].LocationType != "exact" { + t.Errorf("expected exact location type, got %s", s.config.Servers[0].Proxy[0].LocationType) + } + if s.config.Servers[0].Proxy[1].LocationType != "prefix_priority" { + t.Errorf("expected prefix_priority location type, got %s", s.config.Servers[0].Proxy[1].LocationType) + } + if s.config.Servers[1].Proxy[0].LocationType != "regex" { + t.Errorf("expected regex location type, got %s", s.config.Servers[1].Proxy[0].LocationType) + } +} + +// TestStartMultiServerMode_Integration_WithAuthRequest 测试多服务器模式带外部认证启动。 +func TestStartMultiServerMode_Integration_WithAuthRequest(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + AuthRequest: config.AuthRequestConfig{ + Enabled: true, + URI: "/auth/validate", + Timeout: 5 * time.Second, + }, + }, + }, + { + Listen: "127.0.0.1:0", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartMultiServerMode_ServerTokens 测试 ServerTokens 配置。 +func TestStartMultiServerMode_ServerTokens(t *testing.T) { + cfg := &config.Config{ + Servers: []config.ServerConfig{ + { + Listen: "127.0.0.1:0", + ServerTokens: false, + }, + { + Listen: "127.0.0.1:0", + ServerTokens: true, + }, + }, + } + + s := New(cfg) + + if s.config.Servers[0].ServerTokens { + t.Error("expected ServerTokens false for server 0") + } + if !s.config.Servers[1].ServerTokens { + t.Error("expected ServerTokens true for server 1") + } +} + +// TestStartMultiServerMode_Integration_WithDefaultServer 测试多服务器模式带默认服务器启动。 +func TestStartMultiServerMode_Integration_WithDefaultServer(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Mode: config.ServerModeMultiServer, + Servers: []config.ServerConfig{ + { + Name: "default", + Listen: "127.0.0.1:0", + Default: true, + ServerNames: []string{"_"}, + }, + { + Name: "api", + Listen: "127.0.0.1:0", + ServerNames: []string{"api.example.com"}, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedMultiServerError(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// isExpectedMultiServerError 检查是否是预期的多服务器关闭错误。 +func isExpectedMultiServerError(err error) bool { + if err == nil { + return true + } + errStr := err.Error() + return strings.Contains(errStr, "closed") || + strings.Contains(errStr, "use of closed") || + strings.Contains(errStr, "listener closed") +} diff --git a/internal/server/startsinglemode_test.go b/internal/server/startsinglemode_test.go new file mode 100644 index 0000000..a435887 --- /dev/null +++ b/internal/server/startsinglemode_test.go @@ -0,0 +1,671 @@ +// Package server 提供 startSingleMode 集成测试。 +// +// 该文件测试 startSingleMode 函数的各种配置场景, +// 包括静态文件、代理、监控端点、TLS 等配置的实际启动。 +// +// 作者:xfy +package server + +import ( + "os" + "strings" + "testing" + "time" + + "rua.plus/lolly/internal/config" +) + +// TestStartSingleMode_Integration_WithStaticFiles 测试 startSingleMode 静态文件实际启动。 +func TestStartSingleMode_Integration_WithStaticFiles(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + tempDir := t.TempDir() + // 创建一个测试文件 + testFile := tempDir + "/index.html" + if err := os.WriteFile(testFile, []byte("test"), 0o644); err != nil { + t.Fatalf("failed to create test file: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static", + Root: tempDir, + Index: []string{"index.html"}, + }, + }, + }}, + } + + s := New(cfg) + + // 在 goroutine 中启动服务器 + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + // 等待服务器启动 + time.Sleep(50 * time.Millisecond) + + // 停止服务器 + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + // 服务器正常关闭会返回 nil 或 listener closed 错误 + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + // 服务器仍在运行,已通过 GracefulStop 关闭 + } +} + +// TestStartSingleMode_Integration_WithProxy 测试 startSingleMode 代理实际启动。 +func TestStartSingleMode_Integration_WithProxy(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:9999", Weight: 1}, // 不存在的后端 + }, + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithMonitoring 测试 startSingleMode 监控端点实际启动。 +func TestStartSingleMode_Integration_WithMonitoring(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + }}, + Monitoring: config.MonitoringConfig{ + Status: config.StatusConfig{ + Enabled: true, + Path: "/_status", + Allow: []string{"127.0.0.1"}, + }, + Pprof: config.PprofConfig{ + Enabled: true, + Path: "/debug/pprof", + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithCacheAPI 测试 startSingleMode 缓存 API 实际启动。 +func TestStartSingleMode_Integration_WithCacheAPI(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + CacheAPI: &config.CacheAPIConfig{ + Enabled: true, + Path: "/_cache/purge", + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithCompression 测试 startSingleMode 压缩配置实际启动。 +func TestStartSingleMode_Integration_WithCompression(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Compression: config.CompressionConfig{ + Type: "gzip", + Level: 6, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithSecurity 测试 startSingleMode 安全配置实际启动。 +func TestStartSingleMode_Integration_WithSecurity(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + Access: config.AccessConfig{ + Allow: []string{"127.0.0.1"}, + }, + RateLimit: config.RateLimitConfig{ + RequestRate: 100, + Burst: 200, + }, + Headers: config.SecurityHeaders{ + XFrameOptions: "DENY", + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithRewrite 测试 startSingleMode 重写配置实际启动。 +func TestStartSingleMode_Integration_WithRewrite(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Rewrite: []config.RewriteRule{ + {Pattern: "^/old/(.*)$", Replacement: "/new/$1"}, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithPerformance 测试 startSingleMode 性能配置实际启动。 +func TestStartSingleMode_Integration_WithPerformance(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + }}, + Performance: config.PerformanceConfig{ + GoroutinePool: config.GoroutinePoolConfig{ + Enabled: true, + MaxWorkers: 50, + MinWorkers: 5, + IdleTimeout: 10 * time.Second, + }, + FileCache: config.FileCacheConfig{ + MaxEntries: 1000, + MaxSize: 10 * 1024 * 1024, + }, + }, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithProxyLocationTypes 测试代理不同位置类型实际启动。 +func TestStartSingleMode_Integration_WithProxyLocationTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api/exact", + LocationType: "exact", + Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}}, + }, + { + Path: "/api/priority", + LocationType: "prefix_priority", + Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}}, + }, + { + Path: "^/api/regex/(.*)$", + LocationType: "regex", + Targets: []config.ProxyTarget{{URL: "http://127.0.0.1:9999", Weight: 1}}, + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithStaticLocationTypes 测试静态文件不同位置类型实际启动。 +func TestStartSingleMode_Integration_WithStaticLocationTypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + tempDir := t.TempDir() + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Static: []config.StaticConfig{ + { + Path: "/static/exact", + Root: tempDir, + LocationType: "exact", + }, + { + Path: "/static/priority", + Root: tempDir, + LocationType: "prefix_priority", + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithHealthCheck 测试代理健康检查实际启动。 +func TestStartSingleMode_Integration_WithHealthCheck(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Proxy: []config.ProxyConfig{ + { + Path: "/api", + Targets: []config.ProxyTarget{ + {URL: "http://127.0.0.1:9999", Weight: 1}, + }, + HealthCheck: config.HealthCheckConfig{ + Interval: 1 * time.Second, + Timeout: 500 * time.Millisecond, + Path: "/health", + }, + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(100 * time.Millisecond) // 给健康检查一些时间启动 + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithMIMETypes 测试 MIME 类型配置实际启动。 +func TestStartSingleMode_Integration_WithMIMETypes(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Types: config.TypesConfig{ + Map: map[string]string{ + ".wasm": "application/wasm", + }, + DefaultType: "application/octet-stream", + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithErrorPage 测试错误页面配置实际启动。 +func TestStartSingleMode_Integration_WithErrorPage(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + tempDir := t.TempDir() + errorPage := tempDir + "/404.html" + if err := os.WriteFile(errorPage, []byte("Not Found"), 0o644); err != nil { + t.Fatalf("failed to create error page: %v", err) + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + ErrorPage: config.ErrorPageConfig{ + Pages: map[int]string{404: errorPage}, + Default: errorPage, + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithConnLimiter 测试连接限制配置实际启动。 +func TestStartSingleMode_Integration_WithConnLimiter(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + RateLimit: config.RateLimitConfig{ + ConnLimit: 10, + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// TestStartSingleMode_Integration_WithAuthRequest 测试外部认证配置实际启动。 +func TestStartSingleMode_Integration_WithAuthRequest(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + cfg := &config.Config{ + Servers: []config.ServerConfig{{ + Listen: "127.0.0.1:0", + Security: config.SecurityConfig{ + AuthRequest: config.AuthRequestConfig{ + Enabled: true, + URI: "/auth/validate", + Timeout: 5 * time.Second, + }, + }, + }}, + } + + s := New(cfg) + + errCh := make(chan error, 1) + go func() { + if err := s.Start(); err != nil { + errCh <- err + } + }() + + time.Sleep(50 * time.Millisecond) + _ = s.GracefulStop(2 * time.Second) + + select { + case err := <-errCh: + if err != nil && !isExpectedServerErrorForIntegration(err) { + t.Errorf("unexpected server error: %v", err) + } + default: + } +} + +// isExpectedServerErrorForIntegration 检查是否是预期的服务器关闭错误。 +func isExpectedServerErrorForIntegration(err error) bool { + if err == nil { + return true + } + // fasthttp 服务器关闭时的常见错误 + errStr := err.Error() + return strings.Contains(errStr, "closed") || + strings.Contains(errStr, "use of closed") || + strings.Contains(errStr, "listener closed") +} diff --git a/internal/server/status_test.go b/internal/server/status_test.go index 14f82b1..48d6fdc 100644 --- a/internal/server/status_test.go +++ b/internal/server/status_test.go @@ -1211,3 +1211,1072 @@ func TestStatusHandler_ServeHTTP_AccessAllowed_XForwardedFor(t *testing.T) { t.Errorf("expected status 200 for allowed access via X-Forwarded-For, got %d", ctx.Response.StatusCode()) } } + +// --------------------------------------------------------------------------- +// Direct serve* method tests with full Status data +// --------------------------------------------------------------------------- + +func TestServePrometheus_WithUpstreams(t *testing.T) { + cfg := &config.StatusConfig{ + Path: "/_status", + Format: "prometheus", + Allow: []string{}, + } + + srv := New(nil) + srv.startTime = time.Now() + + h, err := NewStatusHandler(srv, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + status := &Status{ + Version: "1.0.0", + Uptime: 10 * time.Second, + Connections: 5, + Requests: 100, + BytesSent: 1024, + BytesReceived: 512, + Upstreams: []UpstreamStatus{ + { + Name: "backend", + HealthyCount: 3, + UnhealthyCount: 1, + LatencyP50: 10.5, + LatencyP95: 25.3, + LatencyP99: 50.1, + }, + }, + } + + ctx := &fasthttp.RequestCtx{} + h.servePrometheus(ctx, status) + + body := string(ctx.Response.Body()) + + // 验证 upstream 指标 + if !strings.Contains(body, "lolly_upstream_healthy_count") { + t.Error("expected prometheus output to contain lolly_upstream_healthy_count") + } + if !strings.Contains(body, `name="backend"`) { + t.Error("expected prometheus output to contain upstream name label") + } + if !strings.Contains(body, "lolly_upstream_healthy_count{name=\"backend\"} 3") { + t.Error("expected prometheus output to contain upstream healthy count") + } + if !strings.Contains(body, "lolly_upstream_unhealthy_count{name=\"backend\"} 1") { + t.Error("expected prometheus output to contain upstream unhealthy count") + } + if !strings.Contains(body, "lolly_upstream_latency_ms{name=\"backend\",quantile=\"0.5\"}") { + t.Error("expected prometheus output to contain upstream P50 latency") + } + if !strings.Contains(body, "lolly_upstream_latency_ms{name=\"backend\",quantile=\"0.95\"}") { + t.Error("expected prometheus output to contain upstream P95 latency") + } + if !strings.Contains(body, "lolly_upstream_latency_ms{name=\"backend\",quantile=\"0.99\"}") { + t.Error("expected prometheus output to contain upstream P99 latency") + } +} + +func TestServePrometheus_WithSSL(t *testing.T) { + cfg := &config.StatusConfig{ + Path: "/_status", + Format: "prometheus", + Allow: []string{}, + } + + srv := New(nil) + srv.startTime = time.Now() + + h, err := NewStatusHandler(srv, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + status := &Status{ + Version: "1.0.0", + Uptime: 10 * time.Second, + Connections: 5, + Requests: 100, + BytesSent: 1024, + BytesReceived: 512, + SSL: &SSLStatus{ + Handshakes: 500, + SessionReused: 200, + ReuseRate: 40.0, + }, + } + + ctx := &fasthttp.RequestCtx{} + h.servePrometheus(ctx, status) + + body := string(ctx.Response.Body()) + + if !strings.Contains(body, "lolly_ssl_handshakes_total 500") { + t.Error("expected prometheus output to contain lolly_ssl_handshakes_total") + } + if !strings.Contains(body, "lolly_ssl_session_reused_total 200") { + t.Error("expected prometheus output to contain lolly_ssl_session_reused_total") + } + if !strings.Contains(body, "lolly_ssl_session_reuse_rate 40.00") { + t.Error("expected prometheus output to contain lolly_ssl_session_reuse_rate") + } +} + +func TestServePrometheus_WithRateLimits(t *testing.T) { + cfg := &config.StatusConfig{ + Path: "/_status", + Format: "prometheus", + Allow: []string{}, + } + + srv := New(nil) + srv.startTime = time.Now() + + h, err := NewStatusHandler(srv, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + status := &Status{ + Version: "1.0.0", + Uptime: 10 * time.Second, + Connections: 5, + Requests: 100, + BytesSent: 1024, + BytesReceived: 512, + RateLimits: []RateLimitStatus{ + { + ZoneName: "api", + Requests: 1000, + Limit: 500, + Rejected: 50, + }, + { + ZoneName: "login", + Requests: 100, + Limit: 10, + Rejected: 5, + }, + }, + } + + ctx := &fasthttp.RequestCtx{} + h.servePrometheus(ctx, status) + + body := string(ctx.Response.Body()) + + if !strings.Contains(body, "lolly_rate_limit_requests{zone=\"api\"} 1000") { + t.Error("expected prometheus output to contain lolly_rate_limit_requests for api zone") + } + if !strings.Contains(body, "lolly_rate_limit_limit{zone=\"api\"} 500") { + t.Error("expected prometheus output to contain lolly_rate_limit_limit for api zone") + } + if !strings.Contains(body, "lolly_rate_limit_rejected_total{zone=\"api\"} 50") { + t.Error("expected prometheus output to contain lolly_rate_limit_rejected_total for api zone") + } + if !strings.Contains(body, "lolly_rate_limit_requests{zone=\"login\"} 100") { + t.Error("expected prometheus output to contain lolly_rate_limit_requests for login zone") + } +} + +func TestServePrometheus_WithProxyCache(t *testing.T) { + cfg := &config.StatusConfig{ + Path: "/_status", + Format: "prometheus", + Allow: []string{}, + } + + srv := New(nil) + srv.startTime = time.Now() + + h, err := NewStatusHandler(srv, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + status := &Status{ + Version: "1.0.0", + Uptime: 10 * time.Second, + Connections: 5, + Requests: 100, + BytesSent: 1024, + BytesReceived: 512, + Cache: &CacheStats{ + FileCache: FileCacheStats{ + Entries: 50, + MaxEntries: 100, + Size: 10240, + MaxSize: 102400, + }, + ProxyCache: ProxyCacheStats{ + Entries: 25, + Pending: 5, + }, + }, + } + + ctx := &fasthttp.RequestCtx{} + h.servePrometheus(ctx, status) + + body := string(ctx.Response.Body()) + + if !strings.Contains(body, `lolly_cache_entries{type="file"} 50`) { + t.Error("expected prometheus output to contain file cache entries") + } + if !strings.Contains(body, `lolly_cache_entries{type="proxy"} 25`) { + t.Error("expected prometheus output to contain proxy cache entries") + } + if !strings.Contains(body, "lolly_cache_size_bytes 10240") { + t.Error("expected prometheus output to contain cache size bytes") + } + if !strings.Contains(body, "lolly_cache_pending 5") { + t.Error("expected prometheus output to contain cache pending") + } +} + +func TestServeJSON_WithFullStatus(t *testing.T) { + cfg := &config.StatusConfig{ + Path: "/_status", + Format: "json", + Allow: []string{}, + } + + srv := New(nil) + srv.startTime = time.Now() + + h, err := NewStatusHandler(srv, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + status := &Status{ + Version: "1.0.0", + Uptime: 10 * time.Second, + Connections: 5, + Requests: 100, + BytesSent: 1024, + BytesReceived: 512, + Cache: &CacheStats{ + FileCache: FileCacheStats{ + Entries: 50, + MaxEntries: 100, + Size: 10240, + MaxSize: 102400, + }, + ProxyCache: ProxyCacheStats{ + Entries: 25, + Pending: 5, + }, + }, + Pool: &PoolStats{ + Workers: 10, + IdleWorkers: 5, + MaxWorkers: 20, + MinWorkers: 2, + QueueLen: 3, + QueueCap: 100, + }, + Upstreams: []UpstreamStatus{ + { + Name: "backend", + HealthyCount: 3, + UnhealthyCount: 1, + LatencyP50: 10.5, + LatencyP95: 25.3, + LatencyP99: 50.1, + }, + }, + SSL: &SSLStatus{ + Handshakes: 500, + SessionReused: 200, + ReuseRate: 40.0, + }, + RateLimits: []RateLimitStatus{ + { + ZoneName: "api", + Requests: 1000, + Limit: 500, + Rejected: 50, + }, + }, + } + + ctx := &fasthttp.RequestCtx{} + h.serveJSON(ctx, status) + + if ctx.Response.StatusCode() != 200 { + t.Errorf("expected status 200, got %d", ctx.Response.StatusCode()) + } + + ct := string(ctx.Response.Header.ContentType()) + if !strings.Contains(ct, "application/json") { + t.Errorf("expected content-type application/json, got %s", ct) + } + + // 验证 JSON 可解析并包含所有字段 + var parsed Status + if err := json.Unmarshal(ctx.Response.Body(), &parsed); err != nil { + t.Fatalf("failed to parse JSON response: %v", err) + } + + if parsed.Cache == nil { + t.Error("expected Cache to be populated") + } else if parsed.Cache.ProxyCache.Entries != 25 { + t.Errorf("expected ProxyCache Entries 25, got %d", parsed.Cache.ProxyCache.Entries) + } + + if parsed.Pool == nil { + t.Error("expected Pool to be populated") + } else if parsed.Pool.QueueCap != 100 { + t.Errorf("expected Pool QueueCap 100, got %d", parsed.Pool.QueueCap) + } + + if len(parsed.Upstreams) != 1 { + t.Errorf("expected 1 upstream, got %d", len(parsed.Upstreams)) + } else if parsed.Upstreams[0].Name != "backend" { + t.Errorf("expected upstream name 'backend', got %s", parsed.Upstreams[0].Name) + } + + if parsed.SSL == nil { + t.Error("expected SSL to be populated") + } else if parsed.SSL.Handshakes != 500 { + t.Errorf("expected SSL Handshakes 500, got %d", parsed.SSL.Handshakes) + } + + if len(parsed.RateLimits) != 1 { + t.Errorf("expected 1 rate limit, got %d", len(parsed.RateLimits)) + } else if parsed.RateLimits[0].ZoneName != "api" { + t.Errorf("expected rate limit zone 'api', got %s", parsed.RateLimits[0].ZoneName) + } +} + +func TestServeText_WithAllSections(t *testing.T) { + cfg := &config.StatusConfig{ + Path: "/_status", + Format: "text", + Allow: []string{}, + } + + srv := New(nil) + srv.startTime = time.Now() + + h, err := NewStatusHandler(srv, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + status := &Status{ + Version: "1.0.0", + Uptime: 10 * time.Second, + Connections: 5, + Requests: 100, + BytesSent: 1024, + BytesReceived: 512, + Cache: &CacheStats{ + FileCache: FileCacheStats{ + Entries: 50, + MaxEntries: 100, + Size: 10240, + MaxSize: 102400, + }, + ProxyCache: ProxyCacheStats{ + Entries: 25, + Pending: 5, + }, + }, + Pool: &PoolStats{ + Workers: 10, + IdleWorkers: 5, + MaxWorkers: 20, + MinWorkers: 2, + QueueLen: 3, + QueueCap: 100, + }, + Upstreams: []UpstreamStatus{ + { + Name: "backend", + HealthyCount: 3, + UnhealthyCount: 1, + LatencyP50: 10.5, + LatencyP95: 25.3, + LatencyP99: 50.1, + }, + }, + SSL: &SSLStatus{ + Handshakes: 500, + SessionReused: 200, + ReuseRate: 40.0, + }, + RateLimits: []RateLimitStatus{ + { + ZoneName: "api", + Requests: 1000, + Limit: 500, + Rejected: 50, + }, + }, + } + + ctx := &fasthttp.RequestCtx{} + h.serveText(ctx, status) + + body := string(ctx.Response.Body()) + + // 验证基础信息 + if !strings.Contains(body, "Lolly Status") { + t.Error("expected text output to contain 'Lolly Status'") + } + if !strings.Contains(body, "Version: 1.0.0") { + t.Error("expected text output to contain Version") + } + if !strings.Contains(body, "Connections: 5") { + t.Error("expected text output to contain Connections") + } + + // 验证 Cache section + if !strings.Contains(body, "Cache:") { + t.Error("expected text output to contain Cache section") + } + if !strings.Contains(body, "File Entries: 50 / 100") { + t.Error("expected text output to contain file cache entries") + } + if !strings.Contains(body, "Proxy Entries: 25") { + t.Error("expected text output to contain proxy cache entries") + } + if !strings.Contains(body, "Proxy Pending: 5") { + t.Error("expected text output to contain proxy cache pending") + } + + // 验证 Pool section + if !strings.Contains(body, "Goroutine Pool:") { + t.Error("expected text output to contain Goroutine Pool section") + } + if !strings.Contains(body, "Workers: 10 (idle: 5)") { + t.Error("expected text output to contain Workers info") + } + if !strings.Contains(body, "Queue: 3 / 100") { + t.Error("expected text output to contain Queue info") + } + + // 验证 Upstreams section + if !strings.Contains(body, "Upstreams:") { + t.Error("expected text output to contain Upstreams section") + } + if !strings.Contains(body, "backend: 3 healthy, 1 unhealthy") { + t.Error("expected text output to contain upstream info") + } + if !strings.Contains(body, "Latency: P50=10.50ms, P95=25.30ms, P99=50.10ms") { + t.Error("expected text output to contain latency info") + } + + // 验证 SSL section + if !strings.Contains(body, "SSL:") { + t.Error("expected text output to contain SSL section") + } + if !strings.Contains(body, "Handshakes: 500") { + t.Error("expected text output to contain Handshakes") + } + if !strings.Contains(body, "Session Reused: 200") { + t.Error("expected text output to contain Session Reused") + } + if !strings.Contains(body, "Reuse Rate: 40.00%") { + t.Error("expected text output to contain Reuse Rate") + } + + // 验证 Rate Limits section + if !strings.Contains(body, "Rate Limits:") { + t.Error("expected text output to contain Rate Limits section") + } + if !strings.Contains(body, "api: 1000 requests, limit=500, rejected=50") { + t.Error("expected text output to contain rate limit info") + } +} + +func TestServeHTML_WithAllSections(t *testing.T) { + cfg := &config.StatusConfig{ + Path: "/_status", + Format: "html", + Allow: []string{}, + } + + srv := New(nil) + srv.startTime = time.Now() + + h, err := NewStatusHandler(srv, cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + status := &Status{ + Version: "1.0.0", + Uptime: 10 * time.Second, + Connections: 5, + Requests: 100, + BytesSent: 1024, + BytesReceived: 512, + Cache: &CacheStats{ + FileCache: FileCacheStats{ + Entries: 50, + MaxEntries: 100, + Size: 10240, + MaxSize: 102400, + }, + ProxyCache: ProxyCacheStats{ + Entries: 25, + Pending: 5, + }, + }, + Pool: &PoolStats{ + Workers: 10, + IdleWorkers: 5, + MaxWorkers: 20, + MinWorkers: 2, + QueueLen: 3, + QueueCap: 100, + }, + Upstreams: []UpstreamStatus{ + { + Name: "backend", + HealthyCount: 3, + UnhealthyCount: 1, + LatencyP50: 10.5, + LatencyP95: 25.3, + LatencyP99: 50.1, + }, + }, + SSL: &SSLStatus{ + Handshakes: 500, + SessionReused: 200, + ReuseRate: 40.0, + }, + RateLimits: []RateLimitStatus{ + { + ZoneName: "api", + Requests: 1000, + Limit: 500, + Rejected: 50, + }, + }, + } + + ctx := &fasthttp.RequestCtx{} + h.serveHTML(ctx, status) + + body := string(ctx.Response.Body()) + + // 验证 HTML 结构 + if !strings.Contains(body, "") { + t.Error("expected HTML output to contain DOCTYPE") + } + + // 验证 Cache section + if !strings.Contains(body, "