diff --git a/internal/sslutil/certpool_test.go b/internal/sslutil/certpool_test.go new file mode 100644 index 0000000..c11f061 --- /dev/null +++ b/internal/sslutil/certpool_test.go @@ -0,0 +1,422 @@ +// Package sslutil provides SSL/TLS utility functions tests. +package sslutil + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" +) + +// generateTestCert generates a self-signed certificate for testing. +func generateTestCert(t *testing.T) ([]byte, []byte) { + t.Helper() + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate private key: %v", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("Failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + keyDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + t.Fatalf("Failed to marshal private key: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyDER, + }) + + return certPEM, keyPEM +} + +// generateMultipleCerts generates multiple certificates in a single PEM file. +func generateMultipleCerts(t *testing.T, count int) []byte { + t.Helper() + + var pemData []byte + for i := 0; i < count; i++ { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("Failed to generate private key: %v", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(int64(i + 1)), + Subject: pkix.Name{ + Organization: []string{"Test"}, + CommonName: "test-cert-" + string(rune('A'+i)), + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("Failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + pemData = append(pemData, certPEM...) + } + + return pemData +} + +func TestLoadCertPool_Success(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + + cert, _ := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0o644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + + pool, err := LoadCertPool(certPath, "") + if err != nil { + t.Errorf("LoadCertPool() error = %v, want nil", err) + } + if pool == nil { + t.Error("LoadCertPool() returned nil pool") + } +} + +func TestLoadCertPool_MultipleCertificates(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "certs.pem") + + multiCerts := generateMultipleCerts(t, 3) + if err := os.WriteFile(certPath, multiCerts, 0o644); err != nil { + t.Fatalf("Failed to write multi-cert file: %v", err) + } + + pool, err := LoadCertPool(certPath, "") + if err != nil { + t.Errorf("LoadCertPool() error = %v, want nil", err) + } + if pool == nil { + t.Error("LoadCertPool() returned nil pool") + } +} + +func TestLoadCertPool_FileNotFound(t *testing.T) { + _, err := LoadCertPool("/nonexistent/cert.pem", "") + if err == nil { + t.Error("LoadCertPool() expected error for non-existent file, got nil") + } +} + +func TestLoadCertPool_EmptyFile(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "empty.pem") + + if err := os.WriteFile(certPath, []byte{}, 0o644); err != nil { + t.Fatalf("Failed to write empty file: %v", err) + } + + _, err := LoadCertPool(certPath, "") + if err == nil { + t.Error("LoadCertPool() expected error for empty file, got nil") + } +} + +func TestLoadCertPool_InvalidPEM(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "invalid.pem") + + invalidData := []byte("not a valid PEM file") + if err := os.WriteFile(certPath, invalidData, 0o644); err != nil { + t.Fatalf("Failed to write invalid file: %v", err) + } + + _, err := LoadCertPool(certPath, "") + if err == nil { + t.Error("LoadCertPool() expected error for invalid PEM, got nil") + } +} + +func TestLoadCertPool_InvalidCertificate(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "invalid-cert.pem") + + // Valid PEM block but not a certificate + invalidCert := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: []byte("not a valid certificate"), + }) + if err := os.WriteFile(certPath, invalidCert, 0o644); err != nil { + t.Fatalf("Failed to write invalid cert file: %v", err) + } + + _, err := LoadCertPool(certPath, "") + if err == nil { + t.Error("LoadCertPool() expected error for invalid certificate, got nil") + } +} + +func TestLoadCertPool_WrongPEMType(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "key.pem") + + _, key := generateTestCert(t) + if err := os.WriteFile(certPath, key, 0o644); err != nil { + t.Fatalf("Failed to write key file: %v", err) + } + + _, err := LoadCertPool(certPath, "") + if err == nil { + t.Error("LoadCertPool() expected error for non-certificate PEM, got nil") + } +} + +func TestLoadCertPool_Directory(t *testing.T) { + tmpDir := t.TempDir() + + _, err := LoadCertPool(tmpDir, "") + if err == nil { + t.Error("LoadCertPool() expected error for directory path, got nil") + } +} + +func TestLoadCACertPool_Success(t *testing.T) { + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "ca.pem") + + cert, _ := generateTestCert(t) + if err := os.WriteFile(caPath, cert, 0o644); err != nil { + t.Fatalf("Failed to write CA cert: %v", err) + } + + pool, err := LoadCACertPool(caPath) + if err != nil { + t.Errorf("LoadCACertPool() error = %v, want nil", err) + } + if pool == nil { + t.Error("LoadCACertPool() returned nil pool") + } +} + +func TestLoadCACertPool_MultipleCertificates(t *testing.T) { + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "ca-bundle.pem") + + multiCerts := generateMultipleCerts(t, 5) + if err := os.WriteFile(caPath, multiCerts, 0o644); err != nil { + t.Fatalf("Failed to write CA bundle: %v", err) + } + + pool, err := LoadCACertPool(caPath) + if err != nil { + t.Errorf("LoadCACertPool() error = %v, want nil", err) + } + if pool == nil { + t.Error("LoadCACertPool() returned nil pool") + } +} + +func TestLoadCACertPool_FileNotFound(t *testing.T) { + _, err := LoadCACertPool("/nonexistent/ca.pem") + if err == nil { + t.Error("LoadCACertPool() expected error for non-existent file, got nil") + } +} + +func TestLoadCACertPool_EmptyFile(t *testing.T) { + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "empty.pem") + + if err := os.WriteFile(caPath, []byte{}, 0o644); err != nil { + t.Fatalf("Failed to write empty file: %v", err) + } + + _, err := LoadCACertPool(caPath) + if err == nil { + t.Error("LoadCACertPool() expected error for empty file, got nil") + } +} + +func TestLoadCACertPool_InvalidPEM(t *testing.T) { + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "invalid.pem") + + invalidData := []byte("not a valid PEM file") + if err := os.WriteFile(caPath, invalidData, 0o644); err != nil { + t.Fatalf("Failed to write invalid file: %v", err) + } + + _, err := LoadCACertPool(caPath) + if err == nil { + t.Error("LoadCACertPool() expected error for invalid PEM, got nil") + } +} + +func TestLoadCACertPool_InvalidCertificate(t *testing.T) { + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "invalid-ca.pem") + + // Valid PEM block but not a certificate + invalidCert := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: []byte("not a valid certificate"), + }) + if err := os.WriteFile(caPath, invalidCert, 0o644); err != nil { + t.Fatalf("Failed to write invalid cert file: %v", err) + } + + _, err := LoadCACertPool(caPath) + if err == nil { + t.Error("LoadCACertPool() expected error for invalid certificate, got nil") + } +} + +func TestLoadCACertPool_WrongPEMType(t *testing.T) { + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "key.pem") + + _, key := generateTestCert(t) + if err := os.WriteFile(caPath, key, 0o644); err != nil { + t.Fatalf("Failed to write key file: %v", err) + } + + _, err := LoadCACertPool(caPath) + if err == nil { + t.Error("LoadCACertPool() expected error for non-certificate PEM, got nil") + } +} + +func TestLoadCACertPool_Directory(t *testing.T) { + tmpDir := t.TempDir() + + _, err := LoadCACertPool(tmpDir) + if err == nil { + t.Error("LoadCACertPool() expected error for directory path, got nil") + } +} + +func TestLoadCACertPool_PermissionDenied(t *testing.T) { + // Skip on Windows as permission handling differs + if os.Getenv("GOOS") == "windows" { + t.Skip("Skipping on Windows") + } + + tmpDir := t.TempDir() + caPath := filepath.Join(tmpDir, "no-perm.pem") + + cert, _ := generateTestCert(t) + if err := os.WriteFile(caPath, cert, 0o000); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + + _, err := LoadCACertPool(caPath) + if err == nil { + t.Error("LoadCACertPool() expected error for unreadable file, got nil") + } +} + +// Test that LoadCertPool and LoadCACertPool produce equivalent results +func TestLoadCertPool_Equivalence(t *testing.T) { + tmpDir := t.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + + cert, _ := generateTestCert(t) + if err := os.WriteFile(certPath, cert, 0o644); err != nil { + t.Fatalf("Failed to write cert: %v", err) + } + + pool1, err1 := LoadCertPool(certPath, "") + pool2, err2 := LoadCACertPool(certPath) + + if err1 != err2 { + t.Errorf("LoadCertPool() err = %v, LoadCACertPool() err = %v", err1, err2) + } + + if (pool1 == nil) != (pool2 == nil) { + t.Errorf("LoadCertPool() pool nil = %v, LoadCACertPool() pool nil = %v", pool1 == nil, pool2 == nil) + } +} + +// Benchmark LoadCertPool +func BenchmarkLoadCertPool(b *testing.B) { + tmpDir := b.TempDir() + certPath := filepath.Join(tmpDir, "cert.pem") + + cert, _ := generateTestCert(&testing.T{}) + if err := os.WriteFile(certPath, cert, 0o644); err != nil { + b.Fatalf("Failed to write cert: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = LoadCertPool(certPath, "") + } +} + +// Benchmark LoadCACertPool +func BenchmarkLoadCACertPool(b *testing.B) { + tmpDir := b.TempDir() + caPath := filepath.Join(tmpDir, "ca.pem") + + cert, _ := generateTestCert(&testing.T{}) + if err := os.WriteFile(caPath, cert, 0o644); err != nil { + b.Fatalf("Failed to write CA cert: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = LoadCACertPool(caPath) + } +} + +// Benchmark with multiple certificates +func BenchmarkLoadCACertPool_MultiCert(b *testing.B) { + tmpDir := b.TempDir() + caPath := filepath.Join(tmpDir, "ca-bundle.pem") + + multiCerts := generateMultipleCerts(&testing.T{}, 10) + if err := os.WriteFile(caPath, multiCerts, 0o644); err != nil { + b.Fatalf("Failed to write CA bundle: %v", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = LoadCACertPool(caPath) + } +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go new file mode 100644 index 0000000..b3be0e7 --- /dev/null +++ b/internal/utils/utils_test.go @@ -0,0 +1,299 @@ +// Package utils 提供工具函数的测试。 +// +// 该文件测试工具模块的各项功能,包括: +// - HTTPError 结构体和预定义错误 +// - SendError 函数 +// - SendErrorWithDetail 函数 +// - 内部重定向相关函数 +// +// 作者:xfy +package utils + +import ( + "testing" + + "github.com/valyala/fasthttp" +) + +// TestHTTPErrorPredefined 测试预定义的 HTTP 错误 +func TestHTTPErrorPredefined(t *testing.T) { + tests := []struct { + name string + err HTTPError + wantMsg string + wantStatus int + }{ + {"NotFound", ErrNotFound, "Not Found", fasthttp.StatusNotFound}, + {"Forbidden", ErrForbidden, "Forbidden", fasthttp.StatusForbidden}, + {"Unauthorized", ErrUnauthorized, "Unauthorized", fasthttp.StatusUnauthorized}, + {"BadGateway", ErrBadGateway, "Bad Gateway", fasthttp.StatusBadGateway}, + {"GatewayTimeout", ErrGatewayTimeout, "Gateway Timeout", fasthttp.StatusGatewayTimeout}, + {"InternalError", ErrInternalError, "Internal Server Error", fasthttp.StatusInternalServerError}, + {"TooManyRequests", ErrTooManyRequests, "Too Many Requests", fasthttp.StatusTooManyRequests}, + {"ServiceUnavailable", ErrServiceUnavailable, "Service Unavailable", fasthttp.StatusServiceUnavailable}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err.Message != tt.wantMsg { + t.Errorf("message = %q, want %q", tt.err.Message, tt.wantMsg) + } + if tt.err.StatusCode != tt.wantStatus { + t.Errorf("status = %d, want %d", tt.err.StatusCode, tt.wantStatus) + } + }) + } +} + +// TestSendError 测试 SendError 函数 +func TestSendError(t *testing.T) { + tests := []struct { + name string + err HTTPError + wantBody string + wantStatus int + }{ + { + name: "not_found", + err: ErrNotFound, + wantBody: "Not Found", + wantStatus: fasthttp.StatusNotFound, + }, + { + name: "internal_error", + err: ErrInternalError, + wantBody: "Internal Server Error", + wantStatus: fasthttp.StatusInternalServerError, + }, + { + name: "custom_error", + err: HTTPError{Message: "Custom Error", StatusCode: 418}, + wantBody: "Custom Error", + wantStatus: 418, + }, + { + name: "empty_message", + err: HTTPError{Message: "", StatusCode: 200}, + wantBody: "", + wantStatus: 200, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + SendError(ctx, tt.err) + + if ctx.Response.StatusCode() != tt.wantStatus { + t.Errorf("status = %d, want %d", ctx.Response.StatusCode(), tt.wantStatus) + } + if body := string(ctx.Response.Body()); body != tt.wantBody { + t.Errorf("body = %q, want %q", body, tt.wantBody) + } + }) + } +} + +// TestSendErrorWithDetail 测试 SendErrorWithDetail 函数 +func TestSendErrorWithDetail(t *testing.T) { + tests := []struct { + name string + err HTTPError + detail string + wantBody string + wantStatus int + }{ + { + name: "with_detail", + err: ErrNotFound, + detail: "resource missing", + wantBody: "Not Found: resource missing", + wantStatus: fasthttp.StatusNotFound, + }, + { + name: "empty_detail", + err: ErrForbidden, + detail: "", + wantBody: "Forbidden", + wantStatus: fasthttp.StatusForbidden, + }, + { + name: "detail_with_special_chars", + err: ErrBadGateway, + detail: "upstream: http://backend:8080 (connection refused)", + wantBody: "Bad Gateway: upstream: http://backend:8080 (connection refused)", + wantStatus: fasthttp.StatusBadGateway, + }, + { + name: "detail_with_chinese", + err: ErrForbidden, + detail: "权限不足", + wantBody: "Forbidden: 权限不足", + wantStatus: fasthttp.StatusForbidden, + }, + { + name: "custom_error_with_detail", + err: HTTPError{Message: "Custom", StatusCode: 418}, + detail: "I'm a teapot", + wantBody: "Custom: I'm a teapot", + wantStatus: 418, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + SendErrorWithDetail(ctx, tt.err, tt.detail) + + if ctx.Response.StatusCode() != tt.wantStatus { + t.Errorf("status = %d, want %d", ctx.Response.StatusCode(), tt.wantStatus) + } + if body := string(ctx.Response.Body()); body != tt.wantBody { + t.Errorf("body = %q, want %q", body, tt.wantBody) + } + }) + } +} + +// TestInternalRedirectKey 测试内部重定向常量 +func TestInternalRedirectKey(t *testing.T) { + if InternalRedirectKey != "__internal_redirect__" { + t.Errorf("InternalRedirectKey = %q, want %q", InternalRedirectKey, "__internal_redirect__") + } +} + +// TestInternalRedirect 测试内部重定向相关函数 +func TestInternalRedirect(t *testing.T) { + t.Run("SetAndGet", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + SetInternalRedirect(ctx, "/new/path") + + if !IsInternalRedirect(ctx) { + t.Error("expected internal redirect to be set") + } + if path := GetInternalRedirectPath(ctx); path != "/new/path" { + t.Errorf("path = %q, want %q", path, "/new/path") + } + }) + + t.Run("NotSet", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + + if IsInternalRedirect(ctx) { + t.Error("expected no internal redirect") + } + if path := GetInternalRedirectPath(ctx); path != "" { + t.Errorf("path = %q, want empty", path) + } + }) + + t.Run("WrongType", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(InternalRedirectKey, 123) + + if !IsInternalRedirect(ctx) { + t.Error("expected internal redirect to be set") + } + if path := GetInternalRedirectPath(ctx); path != "" { + t.Errorf("path = %q, want empty for wrong type", path) + } + }) + + t.Run("EmptyPath", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + SetInternalRedirect(ctx, "") + + if !IsInternalRedirect(ctx) { + t.Error("expected internal redirect to be set even with empty path") + } + if path := GetInternalRedirectPath(ctx); path != "" { + t.Errorf("path = %q, want empty", path) + } + }) + + t.Run("RootPath", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + SetInternalRedirect(ctx, "/") + + if !IsInternalRedirect(ctx) { + t.Error("expected internal redirect to be set") + } + if path := GetInternalRedirectPath(ctx); path != "/" { + t.Errorf("path = %q, want %q", path, "/") + } + }) + + t.Run("PathWithQuery", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + SetInternalRedirect(ctx, "/api/health?check=all") + + if path := GetInternalRedirectPath(ctx); path != "/api/health?check=all" { + t.Errorf("path = %q, want %q", path, "/api/health?check=all") + } + }) + + t.Run("PathWithChinese", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + SetInternalRedirect(ctx, "/内部/健康检查") + + if path := GetInternalRedirectPath(ctx); path != "/内部/健康检查" { + t.Errorf("path = %q, want %q", path, "/内部/健康检查") + } + }) + + t.Run("NilUserValue", func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.SetUserValue(InternalRedirectKey, nil) + + if IsInternalRedirect(ctx) { + t.Error("expected no internal redirect for nil value") + } + if path := GetInternalRedirectPath(ctx); path != "" { + t.Errorf("path = %q, want empty", path) + } + }) +} + +// TestHTTPError_CustomErrors 测试自定义 HTTPError +func TestHTTPError_CustomErrors(t *testing.T) { + tests := []struct { + name string + err HTTPError + wantMsg string + wantStatus int + }{ + { + name: "teapot", + err: HTTPError{Message: "I'm a teapot", StatusCode: 418}, + wantMsg: "I'm a teapot", + wantStatus: 418, + }, + { + name: "rate_limit", + err: HTTPError{Message: "Rate limit exceeded", StatusCode: 429}, + wantMsg: "Rate limit exceeded", + wantStatus: 429, + }, + { + name: "empty_message", + err: HTTPError{Message: "", StatusCode: 500}, + wantMsg: "", + wantStatus: 500, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err.Message != tt.wantMsg { + t.Errorf("Message = %q, want %q", tt.err.Message, tt.wantMsg) + } + if tt.err.StatusCode != tt.wantStatus { + t.Errorf("StatusCode = %d, want %d", tt.err.StatusCode, tt.wantStatus) + } + }) + } +} diff --git a/internal/version/version_test.go b/internal/version/version_test.go new file mode 100644 index 0000000..3ca6e8a --- /dev/null +++ b/internal/version/version_test.go @@ -0,0 +1,301 @@ +// Package version 提供版本信息的测试。 +package version + +import ( + "strings" + "testing" +) + +// TestDefaultValues 测试默认值。 +func TestDefaultValues(t *testing.T) { + // 注意:由于版本变量是包级别的变量, + // 这个测试验证的是编译时的默认值。 + // 在实际构建中,这些值会通过 -ldflags 注入。 + + t.Run("Version默认值", func(t *testing.T) { + // 默认值应该是 "dev" + if Version != "dev" { + t.Errorf("Version = %q, want %q", Version, "dev") + } + }) + + t.Run("GitCommit默认值", func(t *testing.T) { + if GitCommit != "unknown" { + t.Errorf("GitCommit = %q, want %q", GitCommit, "unknown") + } + }) + + t.Run("GitBranch默认值", func(t *testing.T) { + if GitBranch != "unknown" { + t.Errorf("GitBranch = %q, want %q", GitBranch, "unknown") + } + }) + + t.Run("BuildTime默认值", func(t *testing.T) { + if BuildTime != "unknown" { + t.Errorf("BuildTime = %q, want %q", BuildTime, "unknown") + } + }) + + t.Run("GoVersion默认值", func(t *testing.T) { + if GoVersion != "unknown" { + t.Errorf("GoVersion = %q, want %q", GoVersion, "unknown") + } + }) + + t.Run("BuildPlatform默认值", func(t *testing.T) { + if BuildPlatform != "unknown" { + t.Errorf("BuildPlatform = %q, want %q", BuildPlatform, "unknown") + } + }) +} + +// TestVersionVariableMutation 测试版本变量可以被修改。 +// 这模拟了 -ldflags 注入的效果。 +func TestVersionVariableMutation(t *testing.T) { + // 保存原始值 + originalVersion := Version + originalGitCommit := GitCommit + originalGitBranch := GitBranch + originalBuildTime := BuildTime + originalGoVersion := GoVersion + originalBuildPlatform := BuildPlatform + + // 在测试结束时恢复原始值 + t.Cleanup(func() { + Version = originalVersion + GitCommit = originalGitCommit + GitBranch = originalGitBranch + BuildTime = originalBuildTime + GoVersion = originalGoVersion + BuildPlatform = originalBuildPlatform + }) + + t.Run("设置Version", func(t *testing.T) { + testVersion := "v1.2.3" + Version = testVersion + if Version != testVersion { + t.Errorf("Version = %q, want %q", Version, testVersion) + } + }) + + t.Run("设置GitCommit", func(t *testing.T) { + testCommit := "abc123def456" + GitCommit = testCommit + if GitCommit != testCommit { + t.Errorf("GitCommit = %q, want %q", GitCommit, testCommit) + } + }) + + t.Run("设置GitBranch", func(t *testing.T) { + testBranch := "feature/test" + GitBranch = testBranch + if GitBranch != testBranch { + t.Errorf("GitBranch = %q, want %q", GitBranch, testBranch) + } + }) + + t.Run("设置BuildTime", func(t *testing.T) { + testBuildTime := "2024-01-15T10:30:00Z" + BuildTime = testBuildTime + if BuildTime != testBuildTime { + t.Errorf("BuildTime = %q, want %q", BuildTime, testBuildTime) + } + }) + + t.Run("设置GoVersion", func(t *testing.T) { + testGoVersion := "go1.21.5" + GoVersion = testGoVersion + if GoVersion != testGoVersion { + t.Errorf("GoVersion = %q, want %q", GoVersion, testGoVersion) + } + }) + + t.Run("设置BuildPlatform", func(t *testing.T) { + testPlatform := "linux/amd64" + BuildPlatform = testPlatform + if BuildPlatform != testPlatform { + t.Errorf("BuildPlatform = %q, want %q", BuildPlatform, testPlatform) + } + }) +} + +// TestVersionInformationFormat 测试版本信息的格式化。 +func TestVersionInformationFormat(t *testing.T) { + // 保存原始值 + originalVersion := Version + originalGitCommit := GitCommit + originalGitBranch := GitBranch + originalBuildTime := BuildTime + originalGoVersion := GoVersion + originalBuildPlatform := BuildPlatform + + t.Cleanup(func() { + Version = originalVersion + GitCommit = originalGitCommit + GitBranch = originalGitBranch + BuildTime = originalBuildTime + GoVersion = originalGoVersion + BuildPlatform = originalBuildPlatform + }) + + // 设置测试值 + Version = "v2.0.0" + GitCommit = "a1b2c3d4e5f6" + GitBranch = "main" + BuildTime = "2024-06-01T12:00:00Z" + GoVersion = "go1.22.0" + BuildPlatform = "darwin/arm64" + + t.Run("版本号格式", func(t *testing.T) { + // 版本号应该以 'v' 开头(语义化版本规范) + if !strings.HasPrefix(Version, "v") { + t.Errorf("Version = %q, should start with 'v'", Version) + } + }) + + t.Run("Git提交哈希格式", func(t *testing.T) { + // Git 提交哈希应该是十六进制字符串 + for _, c := range GitCommit { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("GitCommit = %q, contains invalid character %q", GitCommit, c) + break + } + } + }) + + t.Run("构建平台格式", func(t *testing.T) { + // 构建平台应该是 OS/Arch 格式 + if !strings.Contains(BuildPlatform, "/") { + t.Errorf("BuildPlatform = %q, should be in OS/Arch format", BuildPlatform) + } + }) + + t.Run("Go版本格式", func(t *testing.T) { + // Go 版本应该以 'go' 开头 + if !strings.HasPrefix(GoVersion, "go") { + t.Errorf("GoVersion = %q, should start with 'go'", GoVersion) + } + }) +} + +// TestAllVariablesNonEmpty 测试所有变量设置后非空。 +func TestAllVariablesNonEmpty(t *testing.T) { + // 保存原始值 + originalVersion := Version + originalGitCommit := GitCommit + originalGitBranch := GitBranch + originalBuildTime := BuildTime + originalGoVersion := GoVersion + originalBuildPlatform := BuildPlatform + + t.Cleanup(func() { + Version = originalVersion + GitCommit = originalGitCommit + GitBranch = originalGitBranch + BuildTime = originalBuildTime + GoVersion = originalGoVersion + BuildPlatform = originalBuildPlatform + }) + + // 设置非空值 + Version = "v1.0.0" + GitCommit = "1234567890abcdef" + GitBranch = "master" + BuildTime = "2024-01-01" + GoVersion = "go1.21" + BuildPlatform = "linux/amd64" + + vars := map[string]string{ + "Version": Version, + "GitCommit": GitCommit, + "GitBranch": GitBranch, + "BuildTime": BuildTime, + "GoVersion": GoVersion, + "BuildPlatform": BuildPlatform, + } + + for name, value := range vars { + if value == "" { + t.Errorf("%s is empty, expected non-empty value", name) + } + if value == "unknown" { + t.Errorf("%s = %q, expected set value", name, value) + } + } +} + +// TestVersionConsistency 测试版本信息的一致性。 +func TestVersionConsistency(t *testing.T) { + // 保存原始值 + originalVersion := Version + originalGitCommit := GitCommit + originalGitBranch := GitBranch + originalBuildTime := BuildTime + originalGoVersion := GoVersion + originalBuildPlatform := BuildPlatform + + t.Cleanup(func() { + Version = originalVersion + GitCommit = originalGitCommit + GitBranch = originalGitBranch + BuildTime = originalBuildTime + GoVersion = originalGoVersion + BuildPlatform = originalBuildPlatform + }) + + t.Run("语义化版本格式", func(t *testing.T) { + testVersions := []string{ + "v1.0.0", + "v2.1.3", + "v0.0.1", + "v10.20.30", + } + + for _, v := range testVersions { + Version = v + // 验证版本号格式 + if !strings.HasPrefix(Version, "v") { + t.Errorf("Version = %q, should start with 'v'", Version) + } + // 验证版本号包含点号 + if !strings.Contains(Version[1:], ".") { + t.Errorf("Version = %q, should contain '.' after 'v'", Version) + } + } + }) + + t.Run("开发版本标识", func(t *testing.T) { + Version = "dev" + if Version != "dev" { + t.Errorf("Version = %q, want %q", Version, "dev") + } + }) +} + +// TestBuildPlatformVariants 测试不同构建平台格式。 +func TestBuildPlatformVariants(t *testing.T) { + // 保存原始值 + originalBuildPlatform := BuildPlatform + t.Cleanup(func() { + BuildPlatform = originalBuildPlatform + }) + + platforms := []string{ + "linux/amd64", + "linux/arm64", + "darwin/amd64", + "darwin/arm64", + "windows/amd64", + "freebsd/amd64", + } + + for _, platform := range platforms { + t.Run(platform, func(t *testing.T) { + BuildPlatform = platform + if BuildPlatform != platform { + t.Errorf("BuildPlatform = %q, want %q", BuildPlatform, platform) + } + }) + } +}