test(sslutil,utils,version): 添加工具模块测试
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
a832e48656
commit
31b3d4d0a3
422
internal/sslutil/certpool_test.go
Normal file
422
internal/sslutil/certpool_test.go
Normal file
@ -0,0 +1,422 @@
|
||||
// Package sslutil provides SSL/TLS utility functions tests.
|
||||
package sslutil
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// generateTestCert generates a self-signed certificate for testing.
|
||||
func generateTestCert(t *testing.T) ([]byte, []byte) {
|
||||
t.Helper()
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
keyDER, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal private key: %v", err)
|
||||
}
|
||||
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: keyDER,
|
||||
})
|
||||
|
||||
return certPEM, keyPEM
|
||||
}
|
||||
|
||||
// generateMultipleCerts generates multiple certificates in a single PEM file.
|
||||
func generateMultipleCerts(t *testing.T, count int) []byte {
|
||||
t.Helper()
|
||||
|
||||
var pemData []byte
|
||||
for i := 0; i < count; i++ {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(int64(i + 1)),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test"},
|
||||
CommonName: "test-cert-" + string(rune('A'+i)),
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certDER,
|
||||
})
|
||||
|
||||
pemData = append(pemData, certPEM...)
|
||||
}
|
||||
|
||||
return pemData
|
||||
}
|
||||
|
||||
func TestLoadCertPool_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
|
||||
cert, _ := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
|
||||
pool, err := LoadCertPool(certPath, "")
|
||||
if err != nil {
|
||||
t.Errorf("LoadCertPool() error = %v, want nil", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Error("LoadCertPool() returned nil pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool_MultipleCertificates(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "certs.pem")
|
||||
|
||||
multiCerts := generateMultipleCerts(t, 3)
|
||||
if err := os.WriteFile(certPath, multiCerts, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write multi-cert file: %v", err)
|
||||
}
|
||||
|
||||
pool, err := LoadCertPool(certPath, "")
|
||||
if err != nil {
|
||||
t.Errorf("LoadCertPool() error = %v, want nil", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Error("LoadCertPool() returned nil pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool_FileNotFound(t *testing.T) {
|
||||
_, err := LoadCertPool("/nonexistent/cert.pem", "")
|
||||
if err == nil {
|
||||
t.Error("LoadCertPool() expected error for non-existent file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool_EmptyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "empty.pem")
|
||||
|
||||
if err := os.WriteFile(certPath, []byte{}, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write empty file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCertPool(certPath, "")
|
||||
if err == nil {
|
||||
t.Error("LoadCertPool() expected error for empty file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool_InvalidPEM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "invalid.pem")
|
||||
|
||||
invalidData := []byte("not a valid PEM file")
|
||||
if err := os.WriteFile(certPath, invalidData, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write invalid file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCertPool(certPath, "")
|
||||
if err == nil {
|
||||
t.Error("LoadCertPool() expected error for invalid PEM, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool_InvalidCertificate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "invalid-cert.pem")
|
||||
|
||||
// Valid PEM block but not a certificate
|
||||
invalidCert := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: []byte("not a valid certificate"),
|
||||
})
|
||||
if err := os.WriteFile(certPath, invalidCert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write invalid cert file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCertPool(certPath, "")
|
||||
if err == nil {
|
||||
t.Error("LoadCertPool() expected error for invalid certificate, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool_WrongPEMType(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
_, key := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, key, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write key file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCertPool(certPath, "")
|
||||
if err == nil {
|
||||
t.Error("LoadCertPool() expected error for non-certificate PEM, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCertPool_Directory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, err := LoadCertPool(tmpDir, "")
|
||||
if err == nil {
|
||||
t.Error("LoadCertPool() expected error for directory path, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "ca.pem")
|
||||
|
||||
cert, _ := generateTestCert(t)
|
||||
if err := os.WriteFile(caPath, cert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CA cert: %v", err)
|
||||
}
|
||||
|
||||
pool, err := LoadCACertPool(caPath)
|
||||
if err != nil {
|
||||
t.Errorf("LoadCACertPool() error = %v, want nil", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Error("LoadCACertPool() returned nil pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_MultipleCertificates(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "ca-bundle.pem")
|
||||
|
||||
multiCerts := generateMultipleCerts(t, 5)
|
||||
if err := os.WriteFile(caPath, multiCerts, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write CA bundle: %v", err)
|
||||
}
|
||||
|
||||
pool, err := LoadCACertPool(caPath)
|
||||
if err != nil {
|
||||
t.Errorf("LoadCACertPool() error = %v, want nil", err)
|
||||
}
|
||||
if pool == nil {
|
||||
t.Error("LoadCACertPool() returned nil pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_FileNotFound(t *testing.T) {
|
||||
_, err := LoadCACertPool("/nonexistent/ca.pem")
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() expected error for non-existent file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_EmptyFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "empty.pem")
|
||||
|
||||
if err := os.WriteFile(caPath, []byte{}, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write empty file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCACertPool(caPath)
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() expected error for empty file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_InvalidPEM(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "invalid.pem")
|
||||
|
||||
invalidData := []byte("not a valid PEM file")
|
||||
if err := os.WriteFile(caPath, invalidData, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write invalid file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCACertPool(caPath)
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() expected error for invalid PEM, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_InvalidCertificate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "invalid-ca.pem")
|
||||
|
||||
// Valid PEM block but not a certificate
|
||||
invalidCert := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: []byte("not a valid certificate"),
|
||||
})
|
||||
if err := os.WriteFile(caPath, invalidCert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write invalid cert file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCACertPool(caPath)
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() expected error for invalid certificate, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_WrongPEMType(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "key.pem")
|
||||
|
||||
_, key := generateTestCert(t)
|
||||
if err := os.WriteFile(caPath, key, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write key file: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCACertPool(caPath)
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() expected error for non-certificate PEM, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_Directory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
_, err := LoadCACertPool(tmpDir)
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() expected error for directory path, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCACertPool_PermissionDenied(t *testing.T) {
|
||||
// Skip on Windows as permission handling differs
|
||||
if os.Getenv("GOOS") == "windows" {
|
||||
t.Skip("Skipping on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "no-perm.pem")
|
||||
|
||||
cert, _ := generateTestCert(t)
|
||||
if err := os.WriteFile(caPath, cert, 0o000); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
|
||||
_, err := LoadCACertPool(caPath)
|
||||
if err == nil {
|
||||
t.Error("LoadCACertPool() expected error for unreadable file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// Test that LoadCertPool and LoadCACertPool produce equivalent results
|
||||
func TestLoadCertPool_Equivalence(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
|
||||
cert, _ := generateTestCert(t)
|
||||
if err := os.WriteFile(certPath, cert, 0o644); err != nil {
|
||||
t.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
|
||||
pool1, err1 := LoadCertPool(certPath, "")
|
||||
pool2, err2 := LoadCACertPool(certPath)
|
||||
|
||||
if err1 != err2 {
|
||||
t.Errorf("LoadCertPool() err = %v, LoadCACertPool() err = %v", err1, err2)
|
||||
}
|
||||
|
||||
if (pool1 == nil) != (pool2 == nil) {
|
||||
t.Errorf("LoadCertPool() pool nil = %v, LoadCACertPool() pool nil = %v", pool1 == nil, pool2 == nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark LoadCertPool
|
||||
func BenchmarkLoadCertPool(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
certPath := filepath.Join(tmpDir, "cert.pem")
|
||||
|
||||
cert, _ := generateTestCert(&testing.T{})
|
||||
if err := os.WriteFile(certPath, cert, 0o644); err != nil {
|
||||
b.Fatalf("Failed to write cert: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = LoadCertPool(certPath, "")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark LoadCACertPool
|
||||
func BenchmarkLoadCACertPool(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "ca.pem")
|
||||
|
||||
cert, _ := generateTestCert(&testing.T{})
|
||||
if err := os.WriteFile(caPath, cert, 0o644); err != nil {
|
||||
b.Fatalf("Failed to write CA cert: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = LoadCACertPool(caPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark with multiple certificates
|
||||
func BenchmarkLoadCACertPool_MultiCert(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
caPath := filepath.Join(tmpDir, "ca-bundle.pem")
|
||||
|
||||
multiCerts := generateMultipleCerts(&testing.T{}, 10)
|
||||
if err := os.WriteFile(caPath, multiCerts, 0o644); err != nil {
|
||||
b.Fatalf("Failed to write CA bundle: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = LoadCACertPool(caPath)
|
||||
}
|
||||
}
|
||||
299
internal/utils/utils_test.go
Normal file
299
internal/utils/utils_test.go
Normal file
@ -0,0 +1,299 @@
|
||||
// Package utils 提供工具函数的测试。
|
||||
//
|
||||
// 该文件测试工具模块的各项功能,包括:
|
||||
// - HTTPError 结构体和预定义错误
|
||||
// - SendError 函数
|
||||
// - SendErrorWithDetail 函数
|
||||
// - 内部重定向相关函数
|
||||
//
|
||||
// 作者:xfy
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestHTTPErrorPredefined 测试预定义的 HTTP 错误
|
||||
func TestHTTPErrorPredefined(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err HTTPError
|
||||
wantMsg string
|
||||
wantStatus int
|
||||
}{
|
||||
{"NotFound", ErrNotFound, "Not Found", fasthttp.StatusNotFound},
|
||||
{"Forbidden", ErrForbidden, "Forbidden", fasthttp.StatusForbidden},
|
||||
{"Unauthorized", ErrUnauthorized, "Unauthorized", fasthttp.StatusUnauthorized},
|
||||
{"BadGateway", ErrBadGateway, "Bad Gateway", fasthttp.StatusBadGateway},
|
||||
{"GatewayTimeout", ErrGatewayTimeout, "Gateway Timeout", fasthttp.StatusGatewayTimeout},
|
||||
{"InternalError", ErrInternalError, "Internal Server Error", fasthttp.StatusInternalServerError},
|
||||
{"TooManyRequests", ErrTooManyRequests, "Too Many Requests", fasthttp.StatusTooManyRequests},
|
||||
{"ServiceUnavailable", ErrServiceUnavailable, "Service Unavailable", fasthttp.StatusServiceUnavailable},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err.Message != tt.wantMsg {
|
||||
t.Errorf("message = %q, want %q", tt.err.Message, tt.wantMsg)
|
||||
}
|
||||
if tt.err.StatusCode != tt.wantStatus {
|
||||
t.Errorf("status = %d, want %d", tt.err.StatusCode, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendError 测试 SendError 函数
|
||||
func TestSendError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err HTTPError
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "not_found",
|
||||
err: ErrNotFound,
|
||||
wantBody: "Not Found",
|
||||
wantStatus: fasthttp.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "internal_error",
|
||||
err: ErrInternalError,
|
||||
wantBody: "Internal Server Error",
|
||||
wantStatus: fasthttp.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "custom_error",
|
||||
err: HTTPError{Message: "Custom Error", StatusCode: 418},
|
||||
wantBody: "Custom Error",
|
||||
wantStatus: 418,
|
||||
},
|
||||
{
|
||||
name: "empty_message",
|
||||
err: HTTPError{Message: "", StatusCode: 200},
|
||||
wantBody: "",
|
||||
wantStatus: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
|
||||
SendError(ctx, tt.err)
|
||||
|
||||
if ctx.Response.StatusCode() != tt.wantStatus {
|
||||
t.Errorf("status = %d, want %d", ctx.Response.StatusCode(), tt.wantStatus)
|
||||
}
|
||||
if body := string(ctx.Response.Body()); body != tt.wantBody {
|
||||
t.Errorf("body = %q, want %q", body, tt.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSendErrorWithDetail 测试 SendErrorWithDetail 函数
|
||||
func TestSendErrorWithDetail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err HTTPError
|
||||
detail string
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "with_detail",
|
||||
err: ErrNotFound,
|
||||
detail: "resource missing",
|
||||
wantBody: "Not Found: resource missing",
|
||||
wantStatus: fasthttp.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "empty_detail",
|
||||
err: ErrForbidden,
|
||||
detail: "",
|
||||
wantBody: "Forbidden",
|
||||
wantStatus: fasthttp.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "detail_with_special_chars",
|
||||
err: ErrBadGateway,
|
||||
detail: "upstream: http://backend:8080 (connection refused)",
|
||||
wantBody: "Bad Gateway: upstream: http://backend:8080 (connection refused)",
|
||||
wantStatus: fasthttp.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "detail_with_chinese",
|
||||
err: ErrForbidden,
|
||||
detail: "权限不足",
|
||||
wantBody: "Forbidden: 权限不足",
|
||||
wantStatus: fasthttp.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "custom_error_with_detail",
|
||||
err: HTTPError{Message: "Custom", StatusCode: 418},
|
||||
detail: "I'm a teapot",
|
||||
wantBody: "Custom: I'm a teapot",
|
||||
wantStatus: 418,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
|
||||
SendErrorWithDetail(ctx, tt.err, tt.detail)
|
||||
|
||||
if ctx.Response.StatusCode() != tt.wantStatus {
|
||||
t.Errorf("status = %d, want %d", ctx.Response.StatusCode(), tt.wantStatus)
|
||||
}
|
||||
if body := string(ctx.Response.Body()); body != tt.wantBody {
|
||||
t.Errorf("body = %q, want %q", body, tt.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalRedirectKey 测试内部重定向常量
|
||||
func TestInternalRedirectKey(t *testing.T) {
|
||||
if InternalRedirectKey != "__internal_redirect__" {
|
||||
t.Errorf("InternalRedirectKey = %q, want %q", InternalRedirectKey, "__internal_redirect__")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInternalRedirect 测试内部重定向相关函数
|
||||
func TestInternalRedirect(t *testing.T) {
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
SetInternalRedirect(ctx, "/new/path")
|
||||
|
||||
if !IsInternalRedirect(ctx) {
|
||||
t.Error("expected internal redirect to be set")
|
||||
}
|
||||
if path := GetInternalRedirectPath(ctx); path != "/new/path" {
|
||||
t.Errorf("path = %q, want %q", path, "/new/path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NotSet", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
if IsInternalRedirect(ctx) {
|
||||
t.Error("expected no internal redirect")
|
||||
}
|
||||
if path := GetInternalRedirectPath(ctx); path != "" {
|
||||
t.Errorf("path = %q, want empty", path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("WrongType", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(InternalRedirectKey, 123)
|
||||
|
||||
if !IsInternalRedirect(ctx) {
|
||||
t.Error("expected internal redirect to be set")
|
||||
}
|
||||
if path := GetInternalRedirectPath(ctx); path != "" {
|
||||
t.Errorf("path = %q, want empty for wrong type", path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmptyPath", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
SetInternalRedirect(ctx, "")
|
||||
|
||||
if !IsInternalRedirect(ctx) {
|
||||
t.Error("expected internal redirect to be set even with empty path")
|
||||
}
|
||||
if path := GetInternalRedirectPath(ctx); path != "" {
|
||||
t.Errorf("path = %q, want empty", path)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RootPath", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
SetInternalRedirect(ctx, "/")
|
||||
|
||||
if !IsInternalRedirect(ctx) {
|
||||
t.Error("expected internal redirect to be set")
|
||||
}
|
||||
if path := GetInternalRedirectPath(ctx); path != "/" {
|
||||
t.Errorf("path = %q, want %q", path, "/")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PathWithQuery", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
SetInternalRedirect(ctx, "/api/health?check=all")
|
||||
|
||||
if path := GetInternalRedirectPath(ctx); path != "/api/health?check=all" {
|
||||
t.Errorf("path = %q, want %q", path, "/api/health?check=all")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PathWithChinese", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
SetInternalRedirect(ctx, "/内部/健康检查")
|
||||
|
||||
if path := GetInternalRedirectPath(ctx); path != "/内部/健康检查" {
|
||||
t.Errorf("path = %q, want %q", path, "/内部/健康检查")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NilUserValue", func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(InternalRedirectKey, nil)
|
||||
|
||||
if IsInternalRedirect(ctx) {
|
||||
t.Error("expected no internal redirect for nil value")
|
||||
}
|
||||
if path := GetInternalRedirectPath(ctx); path != "" {
|
||||
t.Errorf("path = %q, want empty", path)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestHTTPError_CustomErrors 测试自定义 HTTPError
|
||||
func TestHTTPError_CustomErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err HTTPError
|
||||
wantMsg string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "teapot",
|
||||
err: HTTPError{Message: "I'm a teapot", StatusCode: 418},
|
||||
wantMsg: "I'm a teapot",
|
||||
wantStatus: 418,
|
||||
},
|
||||
{
|
||||
name: "rate_limit",
|
||||
err: HTTPError{Message: "Rate limit exceeded", StatusCode: 429},
|
||||
wantMsg: "Rate limit exceeded",
|
||||
wantStatus: 429,
|
||||
},
|
||||
{
|
||||
name: "empty_message",
|
||||
err: HTTPError{Message: "", StatusCode: 500},
|
||||
wantMsg: "",
|
||||
wantStatus: 500,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.err.Message != tt.wantMsg {
|
||||
t.Errorf("Message = %q, want %q", tt.err.Message, tt.wantMsg)
|
||||
}
|
||||
if tt.err.StatusCode != tt.wantStatus {
|
||||
t.Errorf("StatusCode = %d, want %d", tt.err.StatusCode, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
301
internal/version/version_test.go
Normal file
301
internal/version/version_test.go
Normal file
@ -0,0 +1,301 @@
|
||||
// Package version 提供版本信息的测试。
|
||||
package version
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDefaultValues 测试默认值。
|
||||
func TestDefaultValues(t *testing.T) {
|
||||
// 注意:由于版本变量是包级别的变量,
|
||||
// 这个测试验证的是编译时的默认值。
|
||||
// 在实际构建中,这些值会通过 -ldflags 注入。
|
||||
|
||||
t.Run("Version默认值", func(t *testing.T) {
|
||||
// 默认值应该是 "dev"
|
||||
if Version != "dev" {
|
||||
t.Errorf("Version = %q, want %q", Version, "dev")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GitCommit默认值", func(t *testing.T) {
|
||||
if GitCommit != "unknown" {
|
||||
t.Errorf("GitCommit = %q, want %q", GitCommit, "unknown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GitBranch默认值", func(t *testing.T) {
|
||||
if GitBranch != "unknown" {
|
||||
t.Errorf("GitBranch = %q, want %q", GitBranch, "unknown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BuildTime默认值", func(t *testing.T) {
|
||||
if BuildTime != "unknown" {
|
||||
t.Errorf("BuildTime = %q, want %q", BuildTime, "unknown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GoVersion默认值", func(t *testing.T) {
|
||||
if GoVersion != "unknown" {
|
||||
t.Errorf("GoVersion = %q, want %q", GoVersion, "unknown")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BuildPlatform默认值", func(t *testing.T) {
|
||||
if BuildPlatform != "unknown" {
|
||||
t.Errorf("BuildPlatform = %q, want %q", BuildPlatform, "unknown")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestVersionVariableMutation 测试版本变量可以被修改。
|
||||
// 这模拟了 -ldflags 注入的效果。
|
||||
func TestVersionVariableMutation(t *testing.T) {
|
||||
// 保存原始值
|
||||
originalVersion := Version
|
||||
originalGitCommit := GitCommit
|
||||
originalGitBranch := GitBranch
|
||||
originalBuildTime := BuildTime
|
||||
originalGoVersion := GoVersion
|
||||
originalBuildPlatform := BuildPlatform
|
||||
|
||||
// 在测试结束时恢复原始值
|
||||
t.Cleanup(func() {
|
||||
Version = originalVersion
|
||||
GitCommit = originalGitCommit
|
||||
GitBranch = originalGitBranch
|
||||
BuildTime = originalBuildTime
|
||||
GoVersion = originalGoVersion
|
||||
BuildPlatform = originalBuildPlatform
|
||||
})
|
||||
|
||||
t.Run("设置Version", func(t *testing.T) {
|
||||
testVersion := "v1.2.3"
|
||||
Version = testVersion
|
||||
if Version != testVersion {
|
||||
t.Errorf("Version = %q, want %q", Version, testVersion)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("设置GitCommit", func(t *testing.T) {
|
||||
testCommit := "abc123def456"
|
||||
GitCommit = testCommit
|
||||
if GitCommit != testCommit {
|
||||
t.Errorf("GitCommit = %q, want %q", GitCommit, testCommit)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("设置GitBranch", func(t *testing.T) {
|
||||
testBranch := "feature/test"
|
||||
GitBranch = testBranch
|
||||
if GitBranch != testBranch {
|
||||
t.Errorf("GitBranch = %q, want %q", GitBranch, testBranch)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("设置BuildTime", func(t *testing.T) {
|
||||
testBuildTime := "2024-01-15T10:30:00Z"
|
||||
BuildTime = testBuildTime
|
||||
if BuildTime != testBuildTime {
|
||||
t.Errorf("BuildTime = %q, want %q", BuildTime, testBuildTime)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("设置GoVersion", func(t *testing.T) {
|
||||
testGoVersion := "go1.21.5"
|
||||
GoVersion = testGoVersion
|
||||
if GoVersion != testGoVersion {
|
||||
t.Errorf("GoVersion = %q, want %q", GoVersion, testGoVersion)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("设置BuildPlatform", func(t *testing.T) {
|
||||
testPlatform := "linux/amd64"
|
||||
BuildPlatform = testPlatform
|
||||
if BuildPlatform != testPlatform {
|
||||
t.Errorf("BuildPlatform = %q, want %q", BuildPlatform, testPlatform)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestVersionInformationFormat 测试版本信息的格式化。
|
||||
func TestVersionInformationFormat(t *testing.T) {
|
||||
// 保存原始值
|
||||
originalVersion := Version
|
||||
originalGitCommit := GitCommit
|
||||
originalGitBranch := GitBranch
|
||||
originalBuildTime := BuildTime
|
||||
originalGoVersion := GoVersion
|
||||
originalBuildPlatform := BuildPlatform
|
||||
|
||||
t.Cleanup(func() {
|
||||
Version = originalVersion
|
||||
GitCommit = originalGitCommit
|
||||
GitBranch = originalGitBranch
|
||||
BuildTime = originalBuildTime
|
||||
GoVersion = originalGoVersion
|
||||
BuildPlatform = originalBuildPlatform
|
||||
})
|
||||
|
||||
// 设置测试值
|
||||
Version = "v2.0.0"
|
||||
GitCommit = "a1b2c3d4e5f6"
|
||||
GitBranch = "main"
|
||||
BuildTime = "2024-06-01T12:00:00Z"
|
||||
GoVersion = "go1.22.0"
|
||||
BuildPlatform = "darwin/arm64"
|
||||
|
||||
t.Run("版本号格式", func(t *testing.T) {
|
||||
// 版本号应该以 'v' 开头(语义化版本规范)
|
||||
if !strings.HasPrefix(Version, "v") {
|
||||
t.Errorf("Version = %q, should start with 'v'", Version)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Git提交哈希格式", func(t *testing.T) {
|
||||
// Git 提交哈希应该是十六进制字符串
|
||||
for _, c := range GitCommit {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("GitCommit = %q, contains invalid character %q", GitCommit, c)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("构建平台格式", func(t *testing.T) {
|
||||
// 构建平台应该是 OS/Arch 格式
|
||||
if !strings.Contains(BuildPlatform, "/") {
|
||||
t.Errorf("BuildPlatform = %q, should be in OS/Arch format", BuildPlatform)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Go版本格式", func(t *testing.T) {
|
||||
// Go 版本应该以 'go' 开头
|
||||
if !strings.HasPrefix(GoVersion, "go") {
|
||||
t.Errorf("GoVersion = %q, should start with 'go'", GoVersion)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAllVariablesNonEmpty 测试所有变量设置后非空。
|
||||
func TestAllVariablesNonEmpty(t *testing.T) {
|
||||
// 保存原始值
|
||||
originalVersion := Version
|
||||
originalGitCommit := GitCommit
|
||||
originalGitBranch := GitBranch
|
||||
originalBuildTime := BuildTime
|
||||
originalGoVersion := GoVersion
|
||||
originalBuildPlatform := BuildPlatform
|
||||
|
||||
t.Cleanup(func() {
|
||||
Version = originalVersion
|
||||
GitCommit = originalGitCommit
|
||||
GitBranch = originalGitBranch
|
||||
BuildTime = originalBuildTime
|
||||
GoVersion = originalGoVersion
|
||||
BuildPlatform = originalBuildPlatform
|
||||
})
|
||||
|
||||
// 设置非空值
|
||||
Version = "v1.0.0"
|
||||
GitCommit = "1234567890abcdef"
|
||||
GitBranch = "master"
|
||||
BuildTime = "2024-01-01"
|
||||
GoVersion = "go1.21"
|
||||
BuildPlatform = "linux/amd64"
|
||||
|
||||
vars := map[string]string{
|
||||
"Version": Version,
|
||||
"GitCommit": GitCommit,
|
||||
"GitBranch": GitBranch,
|
||||
"BuildTime": BuildTime,
|
||||
"GoVersion": GoVersion,
|
||||
"BuildPlatform": BuildPlatform,
|
||||
}
|
||||
|
||||
for name, value := range vars {
|
||||
if value == "" {
|
||||
t.Errorf("%s is empty, expected non-empty value", name)
|
||||
}
|
||||
if value == "unknown" {
|
||||
t.Errorf("%s = %q, expected set value", name, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVersionConsistency 测试版本信息的一致性。
|
||||
func TestVersionConsistency(t *testing.T) {
|
||||
// 保存原始值
|
||||
originalVersion := Version
|
||||
originalGitCommit := GitCommit
|
||||
originalGitBranch := GitBranch
|
||||
originalBuildTime := BuildTime
|
||||
originalGoVersion := GoVersion
|
||||
originalBuildPlatform := BuildPlatform
|
||||
|
||||
t.Cleanup(func() {
|
||||
Version = originalVersion
|
||||
GitCommit = originalGitCommit
|
||||
GitBranch = originalGitBranch
|
||||
BuildTime = originalBuildTime
|
||||
GoVersion = originalGoVersion
|
||||
BuildPlatform = originalBuildPlatform
|
||||
})
|
||||
|
||||
t.Run("语义化版本格式", func(t *testing.T) {
|
||||
testVersions := []string{
|
||||
"v1.0.0",
|
||||
"v2.1.3",
|
||||
"v0.0.1",
|
||||
"v10.20.30",
|
||||
}
|
||||
|
||||
for _, v := range testVersions {
|
||||
Version = v
|
||||
// 验证版本号格式
|
||||
if !strings.HasPrefix(Version, "v") {
|
||||
t.Errorf("Version = %q, should start with 'v'", Version)
|
||||
}
|
||||
// 验证版本号包含点号
|
||||
if !strings.Contains(Version[1:], ".") {
|
||||
t.Errorf("Version = %q, should contain '.' after 'v'", Version)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("开发版本标识", func(t *testing.T) {
|
||||
Version = "dev"
|
||||
if Version != "dev" {
|
||||
t.Errorf("Version = %q, want %q", Version, "dev")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildPlatformVariants 测试不同构建平台格式。
|
||||
func TestBuildPlatformVariants(t *testing.T) {
|
||||
// 保存原始值
|
||||
originalBuildPlatform := BuildPlatform
|
||||
t.Cleanup(func() {
|
||||
BuildPlatform = originalBuildPlatform
|
||||
})
|
||||
|
||||
platforms := []string{
|
||||
"linux/amd64",
|
||||
"linux/arm64",
|
||||
"darwin/amd64",
|
||||
"darwin/arm64",
|
||||
"windows/amd64",
|
||||
"freebsd/amd64",
|
||||
}
|
||||
|
||||
for _, platform := range platforms {
|
||||
t.Run(platform, func(t *testing.T) {
|
||||
BuildPlatform = platform
|
||||
if BuildPlatform != platform {
|
||||
t.Errorf("BuildPlatform = %q, want %q", BuildPlatform, platform)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user