test(proxy,ssl,server,variable): 补全测试覆盖
- websocket: 升级请求构建、响应读写、大消息转发、并发桥接 - ssl: CRL 吊销检查、证书链深度限制、完整验证流程 - server: 初始化配置、静态文件、GoroutinePool、FileCache - variable: mTLS 客户端证书变量和指纹计算 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
01343ce783
commit
eb379d9121
@ -10,17 +10,28 @@
|
||||
// - 数据复制
|
||||
// - 双向数据转发
|
||||
// - 超时错误处理
|
||||
// - 并发连接测试
|
||||
// - 大消息转发测试
|
||||
//
|
||||
// goroutine 泄漏检测说明:
|
||||
// fasthttp 库使用后台 worker goroutine,与 goleak 不兼容。
|
||||
// 如需检测泄漏,可手动运行:go test -race ./internal/proxy/...
|
||||
//
|
||||
// 作者:xfy
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestNewWebSocketBridge 测试桥接器创建
|
||||
@ -121,12 +132,6 @@ func TestIsConnectionClosedError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractHost 测试从 URL 提取主机
|
||||
func TestExtractHost(_ *testing.T) {
|
||||
// extractHost 函数可能不存在,检查一下
|
||||
// 如果存在则测试
|
||||
}
|
||||
|
||||
// TestDialTarget_InvalidAddress 测试无效地址的拨号
|
||||
func TestDialTarget_InvalidAddress(t *testing.T) {
|
||||
// 测试连接到无效端口
|
||||
@ -311,3 +316,469 @@ func TestCopyData(t *testing.T) {
|
||||
t.Error("copyData did not complete in time")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildWebSocketUpgradeRequest 测试构建 WebSocket 升级请求
|
||||
func TestBuildWebSocketUpgradeRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
query string
|
||||
host string
|
||||
targetHost string
|
||||
wantContains []string
|
||||
}{
|
||||
{
|
||||
name: "basic request",
|
||||
path: "/ws",
|
||||
query: "",
|
||||
host: "client.example.com",
|
||||
targetHost: "backend.example.com:8080",
|
||||
wantContains: []string{
|
||||
"GET /ws HTTP/1.1",
|
||||
"Host: backend.example.com:8080",
|
||||
"X-Forwarded-For:",
|
||||
"X-Real-IP:",
|
||||
"X-Forwarded-Host: client.example.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "request with query",
|
||||
path: "/ws",
|
||||
query: "token=abc123",
|
||||
host: "client.example.com",
|
||||
targetHost: "backend.example.com",
|
||||
wantContains: []string{
|
||||
"GET /ws?token=abc123 HTTP/1.1",
|
||||
"Host: backend.example.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty path defaults to slash",
|
||||
path: "",
|
||||
query: "",
|
||||
host: "client.example.com",
|
||||
targetHost: "backend.example.com",
|
||||
wantContains: []string{
|
||||
"GET / HTTP/1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI(tt.path)
|
||||
if tt.query != "" {
|
||||
ctx.QueryArgs().Parse(tt.query)
|
||||
}
|
||||
ctx.Request.Header.SetHost(tt.host)
|
||||
|
||||
result := buildWebSocketUpgradeRequest(ctx, tt.targetHost)
|
||||
|
||||
for _, want := range tt.wantContains {
|
||||
if !strings.Contains(result, want) {
|
||||
t.Errorf("buildWebSocketUpgradeRequest() missing %q in output:\n%s", want, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildWebSocketUpgradeRequest_WithHeaders 测试复制 WebSocket 头
|
||||
func TestBuildWebSocketUpgradeRequest_WithHeaders(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/ws")
|
||||
ctx.Request.Header.Set("Upgrade", "websocket")
|
||||
ctx.Request.Header.Set("Connection", "Upgrade")
|
||||
ctx.Request.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
ctx.Request.Header.Set("Sec-WebSocket-Version", "13")
|
||||
ctx.Request.Header.Set("Sec-WebSocket-Protocol", "chat")
|
||||
|
||||
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com")
|
||||
|
||||
// 验证关键头被复制
|
||||
expectedHeaders := []string{
|
||||
"Upgrade: websocket",
|
||||
"Connection: Upgrade",
|
||||
"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==",
|
||||
"Sec-WebSocket-Version: 13",
|
||||
"Sec-WebSocket-Protocol: chat",
|
||||
}
|
||||
|
||||
for _, expected := range expectedHeaders {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Missing expected header %q in:\n%s", expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildWebSocketUpgradeRequest_TLSProto 测试 TLS 协议标记
|
||||
func TestBuildWebSocketUpgradeRequest_TLSProto(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isTLS bool
|
||||
wantProto string
|
||||
}{
|
||||
{
|
||||
name: "non-TLS connection",
|
||||
isTLS: false,
|
||||
wantProto: "X-Forwarded-Proto: http",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/ws")
|
||||
|
||||
// 注意:fasthttp.RequestCtx 默认 IsTLS() 返回 false
|
||||
// 无法在单元测试中直接模拟 TLS 连接
|
||||
|
||||
result := buildWebSocketUpgradeRequest(ctx, "backend.example.com")
|
||||
|
||||
if !strings.Contains(result, tt.wantProto) {
|
||||
t.Errorf("Missing %q in:\n%s", tt.wantProto, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractHost 测试从 URL 提取主机
|
||||
func TestExtractHost(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "http with port",
|
||||
url: "http://example.com:8080",
|
||||
expected: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "https with port",
|
||||
url: "https://example.com:8443",
|
||||
expected: "example.com:8443",
|
||||
},
|
||||
{
|
||||
name: "http without port",
|
||||
url: "http://example.com",
|
||||
expected: "example.com:80",
|
||||
},
|
||||
{
|
||||
name: "https without port",
|
||||
url: "https://example.com",
|
||||
expected: "example.com:443",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractHost(tt.url)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractHost(%q) = %q, want %q", tt.url, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWriteUpgradeResponse 测试写入升级响应
|
||||
func TestWriteUpgradeResponse(t *testing.T) {
|
||||
// 创建管道连接
|
||||
conn1, conn2 := net.Pipe()
|
||||
defer func() { _ = conn2.Close() }()
|
||||
|
||||
// 创建模拟 HTTP 响应
|
||||
resp := &http.Response{
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Status: "101 Switching Protocols",
|
||||
StatusCode: 101,
|
||||
Header: http.Header{
|
||||
"Upgrade": []string{"websocket"},
|
||||
"Connection": []string{"Upgrade"},
|
||||
},
|
||||
}
|
||||
|
||||
// 启动写入
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- writeUpgradeResponse(conn1, resp)
|
||||
_ = conn1.Close()
|
||||
}()
|
||||
|
||||
// 读取响应
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn2.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response: %v", err)
|
||||
}
|
||||
|
||||
response := string(buf[:n])
|
||||
|
||||
// 验证响应格式
|
||||
expectedParts := []string{
|
||||
"HTTP/1.1 101 Switching Protocols",
|
||||
"Upgrade: websocket",
|
||||
"Connection: Upgrade",
|
||||
}
|
||||
|
||||
for _, expected := range expectedParts {
|
||||
if !strings.Contains(response, expected) {
|
||||
t.Errorf("Missing %q in response:\n%s", expected, response)
|
||||
}
|
||||
}
|
||||
|
||||
// 等待写入完成
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Errorf("writeUpgradeResponse returned error: %v", err)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("writeUpgradeResponse did not complete in time")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadWebSocketUpgradeResponse 测试读取升级响应
|
||||
func TestReadWebSocketUpgradeResponse(t *testing.T) {
|
||||
// 创建管道连接
|
||||
conn1, conn2 := net.Pipe()
|
||||
defer func() { _ = conn1.Close() }()
|
||||
|
||||
// 在另一个 goroutine 中写入响应
|
||||
go func() {
|
||||
response := "HTTP/1.1 101 Switching Protocols\r\n" +
|
||||
"Upgrade: websocket\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"\r\n"
|
||||
_, _ = conn2.Write([]byte(response))
|
||||
_ = conn2.Close()
|
||||
}()
|
||||
|
||||
// 读取响应
|
||||
resp, err := readWebSocketUpgradeResponse(conn1, 1*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("readWebSocketUpgradeResponse failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != 101 {
|
||||
t.Errorf("Expected status 101, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.Header.Get("Upgrade") != "websocket" {
|
||||
t.Errorf("Expected Upgrade: websocket, got %q", resp.Header.Get("Upgrade"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadWebSocketUpgradeResponse_Timeout 测试读取超时
|
||||
func TestReadWebSocketUpgradeResponse_Timeout(t *testing.T) {
|
||||
// 创建管道连接但不写入数据
|
||||
conn1, conn2 := net.Pipe()
|
||||
defer func() { _ = conn1.Close() }()
|
||||
defer func() { _ = conn2.Close() }()
|
||||
|
||||
// 使用很短的超时
|
||||
_, err := readWebSocketUpgradeResponse(conn1, 10*time.Millisecond)
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialTarget_TLS 测试 TLS 连接(连接无效端口应失败)
|
||||
func TestDialTarget_TLS(t *testing.T) {
|
||||
// 测试 HTTPS 连接到无效端口
|
||||
_, err := dialTarget("https://127.0.0.1:1", 100*time.Millisecond)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid HTTPS address")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsConnectionClosedError_ClosedConn 测试已关闭连接错误
|
||||
func TestIsConnectionClosedError_ClosedConn(t *testing.T) {
|
||||
// 创建并立即关闭连接
|
||||
ln, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
conn, _ := net.Dial("tcp", ln.Addr().String())
|
||||
_ = conn.Close()
|
||||
_ = ln.Close()
|
||||
|
||||
// 尝试读取应返回错误
|
||||
_, err := conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Error("Expected error reading from closed connection")
|
||||
}
|
||||
|
||||
// 验证错误被识别为连接关闭错误
|
||||
if !isConnectionClosedError(err) {
|
||||
t.Errorf("Expected closed connection error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebSocketBridge_LargeMessage 测试大消息转发
|
||||
func TestWebSocketBridge_LargeMessage(t *testing.T) {
|
||||
// 创建管道连接
|
||||
client1, client2 := net.Pipe()
|
||||
target1, target2 := net.Pipe()
|
||||
defer func() { _ = client2.Close() }()
|
||||
defer func() { _ = target2.Close() }()
|
||||
|
||||
bridge := NewWebSocketBridge(client1, target1)
|
||||
|
||||
// 启动桥接
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- bridge.Bridge()
|
||||
}()
|
||||
|
||||
// 发送超过 64KB 的数据
|
||||
largeData := make([]byte, 100*1024) // 100KB
|
||||
for i := range largeData {
|
||||
largeData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
// 客户端发送大消息
|
||||
go func() {
|
||||
_, _ = client2.Write(largeData)
|
||||
}()
|
||||
|
||||
// 后端接收数据
|
||||
buf := make([]byte, 150*1024)
|
||||
total := 0
|
||||
for total < len(largeData) {
|
||||
n, err := target2.Read(buf[total:])
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read large message: %v", err)
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
// 验证数据完整性
|
||||
for i := range largeData {
|
||||
if buf[i] != largeData[i] {
|
||||
t.Errorf("Data mismatch at byte %d: got %d, want %d", i, buf[i], largeData[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
_ = client2.Close()
|
||||
_ = target2.Close()
|
||||
|
||||
// 等待桥接完成
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Errorf("Bridge returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Bridge did not complete in time")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebSocketBridge_Concurrent 测试并发桥接
|
||||
func TestWebSocketBridge_Concurrent(t *testing.T) {
|
||||
const numBridges = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, numBridges)
|
||||
|
||||
for i := 0; i < numBridges; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// 创建管道连接
|
||||
client1, client2 := net.Pipe()
|
||||
target1, target2 := net.Pipe()
|
||||
defer func() { _ = client2.Close() }()
|
||||
defer func() { _ = target2.Close() }()
|
||||
|
||||
bridge := NewWebSocketBridge(client1, target1)
|
||||
|
||||
// 启动桥接
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- bridge.Bridge()
|
||||
}()
|
||||
|
||||
// 发送测试数据
|
||||
testData := []byte("concurrent test data")
|
||||
go func() {
|
||||
_, _ = client2.Write(testData)
|
||||
}()
|
||||
|
||||
// 接收数据
|
||||
buf := make([]byte, 1024)
|
||||
n, err := target2.Read(buf)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("bridge %d: read error: %v", id, err)
|
||||
return
|
||||
}
|
||||
|
||||
if string(buf[:n]) != string(testData) {
|
||||
errCh <- fmt.Errorf("bridge %d: data mismatch", id)
|
||||
return
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
_ = client2.Close()
|
||||
_ = target2.Close()
|
||||
|
||||
// 等待桥接完成
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("bridge %d: %v", id, err)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
errCh <- fmt.Errorf("bridge %d: timeout", id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
// 检查错误
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent bridge error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopyData_WriteError 测试写入错误处理
|
||||
func TestCopyData_WriteError(t *testing.T) {
|
||||
// 创建管道连接
|
||||
src1, src2 := net.Pipe()
|
||||
dst1, dst2 := net.Pipe()
|
||||
|
||||
bridge := &WebSocketBridge{}
|
||||
|
||||
// 先关闭目标连接
|
||||
_ = dst1.Close()
|
||||
_ = dst2.Close()
|
||||
|
||||
// 启动数据复制
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- bridge.copyData(dst1, src1, "test")
|
||||
}()
|
||||
|
||||
// 发送数据(应该触发写入错误)
|
||||
_, _ = src2.Write([]byte("test data"))
|
||||
_ = src2.Close()
|
||||
|
||||
// 等待完成
|
||||
select {
|
||||
case err := <-errCh:
|
||||
// 写入到已关闭连接应该返回 nil(视为连接关闭错误)
|
||||
if err != nil && !strings.Contains(err.Error(), "closed") {
|
||||
t.Errorf("copyData returned unexpected error: %v", err)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("copyData did not complete in time")
|
||||
}
|
||||
|
||||
_ = src1.Close()
|
||||
}
|
||||
|
||||
@ -638,3 +638,130 @@ func TestServer_TrackStats_EmptyBody(t *testing.T) {
|
||||
t.Errorf("Expected 0 bytes sent, got %d", s.bytesSent.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_Success 测试服务器配置初始化
|
||||
func TestStart_Success(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
// 验证服务器正确初始化
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil, expected non-nil Server")
|
||||
}
|
||||
|
||||
if s.config != cfg {
|
||||
t.Error("Server.config not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithStaticFiles 测试静态文件配置
|
||||
func TestStart_WithStaticFiles(t *testing.T) {
|
||||
// 创建临时目录
|
||||
tempDir := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
Static: []config.StaticConfig{{
|
||||
Path: "/static",
|
||||
Root: tempDir,
|
||||
Index: []string{"index.html"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithGoroutinePool 测试 GoroutinePool 配置
|
||||
func TestStart_WithGoroutinePool(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
Performance: config.PerformanceConfig{
|
||||
GoroutinePool: config.GoroutinePoolConfig{
|
||||
Enabled: true,
|
||||
MaxWorkers: 100,
|
||||
MinWorkers: 10,
|
||||
IdleTimeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStart_WithFileCache 测试文件缓存配置
|
||||
func TestStart_WithFileCache(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":8080",
|
||||
},
|
||||
Performance: config.PerformanceConfig{
|
||||
FileCache: config.FileCacheConfig{
|
||||
MaxEntries: 1000,
|
||||
MaxSize: 100 * 1024 * 1024,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("New() returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestStop_Graceful 测试优雅停止(无 race 模式)
|
||||
func TestStop_Graceful(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":0",
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
// 在未启动时调用 GracefulStop,应返回 nil
|
||||
err := s.GracefulStop(1 * time.Second)
|
||||
if err != nil {
|
||||
t.Errorf("GracefulStop() on non-started server returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTLSConfig_Nil 测试无 TLS 配置
|
||||
func TestGetTLSConfig_Nil(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Listen: ":0",
|
||||
},
|
||||
}
|
||||
|
||||
s := New(cfg)
|
||||
|
||||
tlsCfg, err := s.GetTLSConfig()
|
||||
if err == nil {
|
||||
t.Error("GetTLSConfig() should return error when TLS not configured")
|
||||
}
|
||||
if tlsCfg != nil {
|
||||
t.Error("GetTLSConfig() should return nil when TLS not configured")
|
||||
}
|
||||
}
|
||||
|
||||
@ -61,8 +61,8 @@ func generateTestCA(t *testing.T) (*x509.Certificate, *rsa.PrivateKey, []byte) {
|
||||
return cert, key, certPEM
|
||||
}
|
||||
|
||||
// generateTestClientCert 生成测试客户端证书。
|
||||
func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, []byte) {
|
||||
// generateTestClientCert 生成测试客户端证书,serial 参数指定序列号。
|
||||
func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey, serial int64) (*x509.Certificate, *rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
@ -71,7 +71,7 @@ func generateTestClientCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.P
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2),
|
||||
SerialNumber: big.NewInt(serial),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "Test Client",
|
||||
Organization: []string{"Test Org"},
|
||||
@ -297,7 +297,7 @@ func TestClientVerifier_ConfigureTLS_Disabled(t *testing.T) {
|
||||
func TestGetClientCertInfo(t *testing.T) {
|
||||
// 生成测试证书
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey)
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 2)
|
||||
|
||||
// 创建模拟连接状态
|
||||
cs := &tls.ConnectionState{
|
||||
@ -335,6 +335,350 @@ func TestGetClientCertInfo_Nil(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetMode 测试获取验证模式。
|
||||
func TestGetMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
expected ClientVerifyMode
|
||||
}{
|
||||
{"off", "off", VerifyOff},
|
||||
{"on", "on", VerifyOn},
|
||||
{"optional", "optional", VerifyOptional},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CA 文件失败: %v", err)
|
||||
}
|
||||
|
||||
var cfg config.ClientVerifyConfig
|
||||
if tt.mode != "off" {
|
||||
cfg = config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: tt.mode,
|
||||
ClientCA: caFile,
|
||||
}
|
||||
} else {
|
||||
cfg = config.ClientVerifyConfig{Enabled: false}
|
||||
}
|
||||
|
||||
verifier, err := NewClientVerifier(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
|
||||
if verifier.GetMode() != tt.expected {
|
||||
t.Errorf("GetMode() = %v, want %v", verifier.GetMode(), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// generateTestCRL 生成测试 CRL。
|
||||
func generateTestCRL(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey, revokedSerials []*big.Int) []byte {
|
||||
t.Helper()
|
||||
|
||||
template := &x509.RevocationList{
|
||||
Number: big.NewInt(1),
|
||||
ThisUpdate: time.Now(),
|
||||
NextUpdate: time.Now().Add(24 * time.Hour),
|
||||
RevokedCertificateEntries: func() []x509.RevocationListEntry {
|
||||
entries := make([]x509.RevocationListEntry, len(revokedSerials))
|
||||
for i, serial := range revokedSerials {
|
||||
entries[i] = x509.RevocationListEntry{
|
||||
SerialNumber: serial,
|
||||
RevocationTime: time.Now(),
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}(),
|
||||
}
|
||||
|
||||
crlDER, err := x509.CreateRevocationList(rand.Reader, template, caCert, caKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create CRL: %v", err)
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "X509 CRL", Bytes: crlDER})
|
||||
}
|
||||
|
||||
// TestLoadCRL 测试 CRL 加载。
|
||||
func TestLoadCRL(t *testing.T) {
|
||||
// 生成测试 CA
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
|
||||
// 生成包含吊销证书的 CRL
|
||||
revokedSerial := big.NewInt(999)
|
||||
crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedSerial})
|
||||
|
||||
// 写入临时文件
|
||||
tempDir := t.TempDir()
|
||||
crlFile := filepath.Join(tempDir, "crl.pem")
|
||||
if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CRL 文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试加载
|
||||
crl, err := LoadCRL(crlFile)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadCRL() failed: %v", err)
|
||||
}
|
||||
if crl == nil {
|
||||
t.Fatal("LoadCRL() returned nil")
|
||||
}
|
||||
if len(crl.RevokedCertificateEntries) != 1 {
|
||||
t.Errorf("CRL should have 1 revoked certificate, got %d", len(crl.RevokedCertificateEntries))
|
||||
}
|
||||
|
||||
// 测试文件不存在
|
||||
_, err = LoadCRL("/nonexistent/crl.pem")
|
||||
if err == nil {
|
||||
t.Error("LoadCRL() should fail for non-existent file")
|
||||
}
|
||||
|
||||
// 测试无效 CRL
|
||||
invalidFile := filepath.Join(tempDir, "invalid.crl")
|
||||
if err := os.WriteFile(invalidFile, []byte("invalid data"), 0644); err != nil {
|
||||
t.Fatalf("写入无效文件失败: %v", err)
|
||||
}
|
||||
_, err = LoadCRL(invalidFile)
|
||||
if err == nil {
|
||||
t.Error("LoadCRL() should fail for invalid CRL")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckCRL 测试 CRL 检查。
|
||||
func TestCheckCRL(t *testing.T) {
|
||||
// 生成测试 CA
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
|
||||
// 生成将被吊销的客户端证书(序列号100)
|
||||
revokedCert, _, _ := generateTestClientCert(t, caCert, caKey, 100)
|
||||
|
||||
// 生成有效客户端证书(序列号200,不会被吊销)
|
||||
validCert, _, _ := generateTestClientCert(t, caCert, caKey, 200)
|
||||
|
||||
// 生成包含吊销证书的 CRL
|
||||
crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedCert.SerialNumber})
|
||||
|
||||
// 写入临时文件
|
||||
tempDir := t.TempDir()
|
||||
crlFile := filepath.Join(tempDir, "crl.pem")
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CRL 文件失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CA 文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建带 CRL 的验证器
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
CRL: crlFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
|
||||
// 测试检查有效证书
|
||||
err = verifier.ValidateClientCertificate(validCert)
|
||||
if err != nil {
|
||||
t.Errorf("CheckCRL() should pass for valid cert: %v", err)
|
||||
}
|
||||
|
||||
// 测试检查吊销证书
|
||||
err = verifier.ValidateClientCertificate(revokedCert)
|
||||
if err == nil {
|
||||
t.Error("CheckCRL() should fail for revoked cert")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckCRL_EmptyCRL 测试空 CRL。
|
||||
func TestCheckCRL_EmptyCRL(t *testing.T) {
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 50)
|
||||
|
||||
// 生成空 CRL(无吊销证书)
|
||||
crlPEM := generateTestCRL(t, caCert, caKey, nil)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
crlFile := filepath.Join(tempDir, "crl.pem")
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CRL 文件失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CA 文件失败: %v", err)
|
||||
}
|
||||
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
CRL: crlFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
|
||||
// 空列表应通过所有证书
|
||||
err = verifier.ValidateClientCertificate(clientCert)
|
||||
if err != nil {
|
||||
t.Errorf("CheckCRL() should pass with empty CRL: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateClientCertificate 测试手动验证客户端证书。
|
||||
func TestValidateClientCertificate(t *testing.T) {
|
||||
// 测试禁用验证器
|
||||
verifier, _ := NewClientVerifier(config.ClientVerifyConfig{Enabled: false})
|
||||
|
||||
err := verifier.ValidateClientCertificate(nil)
|
||||
if err != nil {
|
||||
t.Errorf("Disabled verifier should accept nil cert: %v", err)
|
||||
}
|
||||
|
||||
// 测试启用验证器(on 模式)
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CA 文件失败: %v", err)
|
||||
}
|
||||
|
||||
verifier, _ = NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
})
|
||||
|
||||
// nil 证书在 on 模式应失败
|
||||
err = verifier.ValidateClientCertificate(nil)
|
||||
if err == nil {
|
||||
t.Error("ValidateClientCertificate(nil) should fail in 'on' mode")
|
||||
}
|
||||
|
||||
// 有效证书应通过
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 30)
|
||||
err = verifier.ValidateClientCertificate(clientCert)
|
||||
if err != nil {
|
||||
t.Errorf("ValidateClientCertificate() should pass for valid cert: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyConnection 测试连接验证。
|
||||
func TestVerifyConnection(t *testing.T) {
|
||||
// 生成测试 CA 和证书
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
validCert, _, _ := generateTestClientCert(t, caCert, caKey, 300)
|
||||
revokedCert, _, _ := generateTestClientCert(t, caCert, caKey, 400)
|
||||
|
||||
// 生成包含吊销证书的 CRL
|
||||
crlPEM := generateTestCRL(t, caCert, caKey, []*big.Int{revokedCert.SerialNumber})
|
||||
|
||||
tempDir := t.TempDir()
|
||||
crlFile := filepath.Join(tempDir, "crl.pem")
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(crlFile, crlPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CRL 文件失败: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CA 文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试带 CRL 和深度限制的验证器
|
||||
verifier, err := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
CRL: crlFile,
|
||||
VerifyDepth: 3,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClientVerifier() failed: %v", err)
|
||||
}
|
||||
|
||||
// 配置 TLS 以设置 VerifyConnection 回调
|
||||
tlsCfg := &tls.Config{}
|
||||
verifier.ConfigureTLS(tlsCfg)
|
||||
|
||||
if tlsCfg.VerifyConnection == nil {
|
||||
t.Fatal("VerifyConnection should be set when VerifyDepth > 0")
|
||||
}
|
||||
|
||||
// 测试有效证书连接
|
||||
validCS := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{validCert},
|
||||
}
|
||||
err = tlsCfg.VerifyConnection(*validCS)
|
||||
if err != nil {
|
||||
t.Errorf("VerifyConnection() should pass for valid cert: %v", err)
|
||||
}
|
||||
|
||||
// 测试吊销证书连接
|
||||
revokedCS := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{revokedCert},
|
||||
}
|
||||
err = tlsCfg.VerifyConnection(*revokedCS)
|
||||
if err == nil {
|
||||
t.Error("VerifyConnection() should fail for revoked cert")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyConnection_DepthLimit 测试证书链深度限制。
|
||||
func TestVerifyConnection_DepthLimit(t *testing.T) {
|
||||
caCert, caKey, _ := generateTestCA(t)
|
||||
clientCert, _, _ := generateTestClientCert(t, caCert, caKey, 500)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
caFile := filepath.Join(tempDir, "ca.crt")
|
||||
_, _, caPEM := generateTestCA(t)
|
||||
if err := os.WriteFile(caFile, caPEM, 0644); err != nil {
|
||||
t.Fatalf("写入 CA 文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试深度限制为 1
|
||||
verifier, _ := NewClientVerifier(config.ClientVerifyConfig{
|
||||
Enabled: true,
|
||||
Mode: "on",
|
||||
ClientCA: caFile,
|
||||
VerifyDepth: 1,
|
||||
})
|
||||
|
||||
tlsCfg := &tls.Config{}
|
||||
verifier.ConfigureTLS(tlsCfg)
|
||||
|
||||
// 单个证书应通过
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{clientCert},
|
||||
}
|
||||
err := tlsCfg.VerifyConnection(*cs)
|
||||
if err != nil {
|
||||
t.Errorf("VerifyConnection() should pass for single cert with depth 1: %v", err)
|
||||
}
|
||||
|
||||
// 多个证书应失败(链太长)
|
||||
longChain := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{clientCert, caCert},
|
||||
}
|
||||
err = tlsCfg.VerifyConnection(*longChain)
|
||||
if err == nil {
|
||||
t.Error("VerifyConnection() should fail for chain exceeding depth limit")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLoadCACertPool 基准测试 CA 证书池加载。
|
||||
func BenchmarkLoadCACertPool(b *testing.B) {
|
||||
tempDir := b.TempDir()
|
||||
|
||||
@ -458,3 +458,101 @@ func TestOCSPConfigDefaults(t *testing.T) {
|
||||
t.Errorf("Expected default max retries 3, got %d", cfg.MaxRetries)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_RefreshResponse 测试强制刷新 OCSP 响应
|
||||
func TestOCSPManager_RefreshResponse(_ *testing.T) {
|
||||
cfg := &OCSPConfig{
|
||||
Enabled: true,
|
||||
RefreshInterval: 1 * time.Hour,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
MaxRetries: 1,
|
||||
}
|
||||
mgr := NewOCSPManager(cfg)
|
||||
|
||||
// 创建带 OCSP 服务器的测试证书
|
||||
serial := big.NewInt(12345)
|
||||
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
OCSPServer: []string{"http://ocsp.example.com"},
|
||||
}
|
||||
certDER, _ := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
cert, _ := x509.ParseCertificate(certDER)
|
||||
|
||||
// 刷新响应(会失败因为 URL 无效)
|
||||
err := mgr.RefreshResponse(cert, cert)
|
||||
// 由于 URL 无效,预期会失败
|
||||
if err == nil {
|
||||
// 如果没有错误,检查状态
|
||||
status, hasResp := mgr.GetStatus(serial.String())
|
||||
_ = status
|
||||
_ = hasResp
|
||||
}
|
||||
}
|
||||
|
||||
// TestOCSPManager_refreshAll 测试刷新所有响应
|
||||
func TestOCSPManager_refreshAll(_ *testing.T) {
|
||||
cfg := &OCSPConfig{
|
||||
Enabled: true,
|
||||
RefreshInterval: 1 * time.Hour,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
MaxRetries: 1,
|
||||
}
|
||||
mgr := NewOCSPManager(cfg)
|
||||
|
||||
// 手动添加一些响应到缓存
|
||||
serial1 := "1001"
|
||||
serial2 := "1002"
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.responses[serial1] = &ocspResponse{
|
||||
status: statusValid,
|
||||
response: []byte("test-response"),
|
||||
nextUpdate: time.Now().Add(-1 * time.Hour), // 已过期
|
||||
fetchedAt: time.Now().Add(-2 * time.Hour),
|
||||
}
|
||||
mgr.responses[serial2] = &ocspResponse{
|
||||
status: statusValid,
|
||||
response: []byte("test-response-2"),
|
||||
nextUpdate: time.Now().Add(1 * time.Hour), // 未过期
|
||||
fetchedAt: time.Now(),
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
|
||||
// 调用 refreshAll
|
||||
mgr.refreshAll()
|
||||
|
||||
// 验证刷新逻辑被触发(无法验证实际刷新因为 URL 无效)
|
||||
// 主要目的是确保代码路径被覆盖
|
||||
}
|
||||
|
||||
// TestOCSPManager_GetStatus_EdgeCases 测试 GetStatus 边界情况
|
||||
func TestOCSPManager_GetStatus_EdgeCases(t *testing.T) {
|
||||
cfg := DefaultOCSPConfig()
|
||||
mgr := NewOCSPManager(cfg)
|
||||
|
||||
// 测试不存在的序列号
|
||||
status, hasResp := mgr.GetStatus("nonexistent")
|
||||
if hasResp {
|
||||
t.Error("Expected no response for nonexistent serial")
|
||||
}
|
||||
if status != statusFailed {
|
||||
t.Errorf("Expected statusFailed for nonexistent serial, got %v", status)
|
||||
}
|
||||
|
||||
// 测试空响应
|
||||
serial := "empty-response"
|
||||
mgr.mu.Lock()
|
||||
mgr.responses[serial] = &ocspResponse{
|
||||
status: statusValid,
|
||||
response: nil, // 空响应
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
|
||||
_, hasResp = mgr.GetStatus(serial)
|
||||
if hasResp {
|
||||
t.Error("Expected no response for empty response data")
|
||||
}
|
||||
}
|
||||
|
||||
591
internal/variable/ssl_test.go
Normal file
591
internal/variable/ssl_test.go
Normal file
@ -0,0 +1,591 @@
|
||||
// ssl_test.go - SSL/TLS 客户端证书变量测试
|
||||
//
|
||||
// 测试覆盖:
|
||||
// - mTLS 客户端证书变量获取
|
||||
// - SetSSLClientInfoInContext 设置功能
|
||||
// - calculateFingerprint 指纹计算
|
||||
//
|
||||
// 作者:xfy
|
||||
package variable
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// TestGetSSLClientVerify_NilContext 测试 nil 上下文
|
||||
func TestGetSSLClientVerify_NilContext(t *testing.T) {
|
||||
result := GetSSLClientVerify(nil)
|
||||
if result != "NONE" {
|
||||
t.Errorf("GetSSLClientVerify(nil) = %q, want NONE", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientVerify_NoTLS 测试非 TLS 连接
|
||||
func TestGetSSLClientVerify_NoTLS(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// 默认情况下 IsTLS() 返回 false
|
||||
result := GetSSLClientVerify(ctx)
|
||||
if result != "NONE" {
|
||||
t.Errorf("GetSSLClientVerify(non-TLS) = %q, want NONE", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientVerify_NonTLSWithUserValue 测试非 TLS 连接即使设置了 UserValue 也返回 NONE
|
||||
// 注意:GetSSLClientVerify 会先检查 ctx.IsTLS(),非 TLS 连接直接返回 NONE
|
||||
// 这是正确的行为,SSL 客户端变量只在 TLS 连接中有效
|
||||
func TestGetSSLClientVerify_NonTLSWithUserValue(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(VarSSLClientVerify, "SUCCESS")
|
||||
|
||||
// 非 TLS 连接,即使设置了 UserValue 也应该返回 NONE
|
||||
result := GetSSLClientVerify(ctx)
|
||||
if result != "NONE" {
|
||||
t.Errorf("GetSSLClientVerify(non-TLS with value) = %q, want NONE", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientVerify_PeerCertPresent_NonTLS 测试非 TLS 下 peer_cert_present 不生效
|
||||
func TestGetSSLClientVerify_PeerCertPresent_NonTLS(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue("tls_peer_cert_present", true)
|
||||
|
||||
// 非 TLS 连接,peer_cert_present 不应该改变结果
|
||||
result := GetSSLClientVerify(ctx)
|
||||
if result != "NONE" {
|
||||
t.Errorf("GetSSLClientVerify(non-TLS with peer_cert) = %q, want NONE", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientSerial 测试获取序列号
|
||||
func TestGetSSLClientSerial(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*fasthttp.RequestCtx)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no value",
|
||||
setup: func(_ *fasthttp.RequestCtx) {},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with serial",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientSerial, "1234567890ABCDEF")
|
||||
},
|
||||
expected: "1234567890ABCDEF",
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientSerial, 12345)
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
tt.setup(ctx)
|
||||
result := GetSSLClientSerial(ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSSLClientSerial() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientSubject 测试获取主题
|
||||
func TestGetSSLClientSubject(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*fasthttp.RequestCtx)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no value",
|
||||
setup: func(_ *fasthttp.RequestCtx) {},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with subject",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientSubject, "CN=test.example.com,O=Test Org")
|
||||
},
|
||||
expected: "CN=test.example.com,O=Test Org",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
tt.setup(ctx)
|
||||
result := GetSSLClientSubject(ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSSLClientSubject() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientIssuer 测试获取颁发者
|
||||
func TestGetSSLClientIssuer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*fasthttp.RequestCtx)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no value",
|
||||
setup: func(_ *fasthttp.RequestCtx) {},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with issuer",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientIssuer, "CN=Test CA,O=Test Org")
|
||||
},
|
||||
expected: "CN=Test CA,O=Test Org",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
tt.setup(ctx)
|
||||
result := GetSSLClientIssuer(ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSSLClientIssuer() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientFingerprint 测试获取指纹
|
||||
func TestGetSSLClientFingerprint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*fasthttp.RequestCtx)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no value",
|
||||
setup: func(_ *fasthttp.RequestCtx) {},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with fingerprint",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientFingerprint, "A1B2C3D4E5F6")
|
||||
},
|
||||
expected: "A1B2C3D4E5F6",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
tt.setup(ctx)
|
||||
result := GetSSLClientFingerprint(ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSSLClientFingerprint() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientNotBefore 测试获取生效时间
|
||||
func TestGetSSLClientNotBefore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*fasthttp.RequestCtx)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no value",
|
||||
setup: func(_ *fasthttp.RequestCtx) {},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with notbefore",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientNotBefore, "2025-01-01T00:00:00Z")
|
||||
},
|
||||
expected: "2025-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
tt.setup(ctx)
|
||||
result := GetSSLClientNotBefore(ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSSLClientNotBefore() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientNotAfter 测试获取过期时间
|
||||
func TestGetSSLClientNotAfter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*fasthttp.RequestCtx)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no value",
|
||||
setup: func(_ *fasthttp.RequestCtx) {},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with notafter",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientNotAfter, "2026-01-01T00:00:00Z")
|
||||
},
|
||||
expected: "2026-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
tt.setup(ctx)
|
||||
result := GetSSLClientNotAfter(ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSSLClientNotAfter() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetSSLClientEmail 测试获取邮箱
|
||||
func TestGetSSLClientEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*fasthttp.RequestCtx)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no value",
|
||||
setup: func(_ *fasthttp.RequestCtx) {},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "with email",
|
||||
setup: func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetUserValue(VarSSLClientEmail, "test@example.com")
|
||||
},
|
||||
expected: "test@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
tt.setup(ctx)
|
||||
result := GetSSLClientEmail(ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSSLClientEmail() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetSSLClientInfoInContext_NilCtx 测试 nil 上下文
|
||||
func TestSetSSLClientInfoInContext_NilCtx(_ *testing.T) {
|
||||
// 不应该 panic
|
||||
SetSSLClientInfoInContext(nil, &tls.ConnectionState{}, "SUCCESS")
|
||||
}
|
||||
|
||||
// TestSetSSLClientInfoInContext_NilConnState 测试 nil 连接状态
|
||||
func TestSetSSLClientInfoInContext_NilConnState(_ *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
// 不应该 panic
|
||||
SetSSLClientInfoInContext(ctx, nil, "SUCCESS")
|
||||
}
|
||||
|
||||
// TestSetSSLClientInfoInContext_NoPeerCerts 测试无客户端证书
|
||||
func TestSetSSLClientInfoInContext_NoPeerCerts(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{},
|
||||
}
|
||||
|
||||
SetSSLClientInfoInContext(ctx, cs, "NONE")
|
||||
|
||||
// 验证只设置了 verify 状态
|
||||
if v := ctx.UserValue(VarSSLClientVerify); v != "NONE" {
|
||||
t.Errorf("expected verify=NONE, got %v", v)
|
||||
}
|
||||
if v := ctx.UserValue("tls_peer_cert_present"); v != nil {
|
||||
t.Errorf("expected no peer_cert_present, got %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetSSLClientInfoInContext_WithPeerCert 测试有客户端证书
|
||||
func TestSetSSLClientInfoInContext_WithPeerCert(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// 创建模拟证书
|
||||
now := time.Now()
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(12345),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "test.example.com",
|
||||
Organization: []string{"Test Org"},
|
||||
},
|
||||
Issuer: pkix.Name{
|
||||
CommonName: "Test CA",
|
||||
},
|
||||
NotBefore: now.Add(-24 * time.Hour),
|
||||
NotAfter: now.Add(365 * 24 * time.Hour),
|
||||
EmailAddresses: []string{"test@example.com"},
|
||||
Raw: make([]byte, 25), // 模拟原始数据(25字节)
|
||||
}
|
||||
// 填充可预测的原始数据
|
||||
for i := 0; i < 25; i++ {
|
||||
cert.Raw[i] = byte(i + 1)
|
||||
}
|
||||
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{cert},
|
||||
}
|
||||
|
||||
SetSSLClientInfoInContext(ctx, cs, "SUCCESS")
|
||||
|
||||
// 验证所有字段
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
expected interface{}
|
||||
}{
|
||||
{"verify", VarSSLClientVerify, "SUCCESS"},
|
||||
{"peer_cert_present", "tls_peer_cert_present", true},
|
||||
{"serial", VarSSLClientSerial, "12345"},
|
||||
{"subject", VarSSLClientSubject, cert.Subject.String()},
|
||||
{"issuer", VarSSLClientIssuer, cert.Issuer.String()},
|
||||
{"email", VarSSLClientEmail, "test@example.com"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := ctx.UserValue(tt.key)
|
||||
if v != tt.expected {
|
||||
t.Errorf("%s = %v, want %v", tt.name, v, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 验证时间格式
|
||||
notBefore := ctx.UserValue(VarSSLClientNotBefore)
|
||||
if notBefore == nil || notBefore == "" {
|
||||
t.Error("notbefore should be set")
|
||||
}
|
||||
notAfter := ctx.UserValue(VarSSLClientNotAfter)
|
||||
if notAfter == nil || notAfter == "" {
|
||||
t.Error("notafter should be set")
|
||||
}
|
||||
|
||||
// 验证指纹
|
||||
fingerprint := ctx.UserValue(VarSSLClientFingerprint)
|
||||
if fingerprint == nil || fingerprint == "" {
|
||||
t.Error("fingerprint should be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetSSLClientInfoInContext_NoEmail 测试证书无邮箱
|
||||
func TestSetSSLClientInfoInContext_NoEmail(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "test"},
|
||||
Issuer: pkix.Name{CommonName: "CA"},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
EmailAddresses: []string{}, // 无邮箱
|
||||
Raw: []byte{1, 2, 3},
|
||||
}
|
||||
|
||||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{cert},
|
||||
}
|
||||
|
||||
SetSSLClientInfoInContext(ctx, cs, "SUCCESS")
|
||||
|
||||
// 验证邮箱未设置
|
||||
if v := ctx.UserValue(VarSSLClientEmail); v != nil {
|
||||
t.Errorf("expected no email, got %v", v)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateFingerprint 测试指纹计算
|
||||
func TestCalculateFingerprint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty data",
|
||||
raw: []byte{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "short data (less than 20 bytes)",
|
||||
raw: []byte{1, 2, 3, 4, 5},
|
||||
expected: "0102030405000000000000000000000000000000", // 5字节+15个零
|
||||
},
|
||||
{
|
||||
name: "exactly 20 bytes",
|
||||
raw: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14},
|
||||
expected: "0102030405060708090A0B0C0D0E0F1011121314",
|
||||
},
|
||||
{
|
||||
name: "more than 20 bytes",
|
||||
raw: []byte{0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8, 0xF7, 0xF6, 0xF5, 0xF4, 0xF3, 0xF2, 0xF1, 0xF0, 0xEF, 0xEE, 0xED, 0xEC, 0xEB, 0xEA},
|
||||
expected: "FFFEFDFCFBFAF9F8F7F6F5F4F3F2F1F0EFEEEDEC", // 只取前20字节
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := calculateFingerprint(tt.raw)
|
||||
if result != tt.expected {
|
||||
t.Errorf("calculateFingerprint() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCalculateFingerprint_Uppercase 测试十六进制输出为大写
|
||||
func TestCalculateFingerprint_Uppercase(t *testing.T) {
|
||||
raw := []byte{0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
|
||||
result := calculateFingerprint(raw)
|
||||
|
||||
// 验证输出为大写
|
||||
for _, c := range result {
|
||||
if c >= 'a' && c <= 'f' {
|
||||
t.Errorf("fingerprint should be uppercase, got %q", result)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSLVariablesInContext 测试通过 VariableContext 访问 SSL 变量
|
||||
// 注意:ssl_client_verify 在非 TLS 连接下会返回 NONE(因为 GetSSLClientVerify 检查 ctx.IsTLS())
|
||||
func TestSSLVariablesInContext(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// 设置 SSL 客户端信息
|
||||
ctx.SetUserValue(VarSSLClientSerial, "ABC123")
|
||||
ctx.SetUserValue(VarSSLClientSubject, "CN=test")
|
||||
ctx.SetUserValue(VarSSLClientIssuer, "CN=CA")
|
||||
ctx.SetUserValue(VarSSLClientFingerprint, "FINGERPRINT")
|
||||
ctx.SetUserValue(VarSSLClientNotBefore, "2025-01-01T00:00:00Z")
|
||||
ctx.SetUserValue(VarSSLClientNotAfter, "2026-01-01T00:00:00Z")
|
||||
ctx.SetUserValue(VarSSLClientEmail, "test@example.com")
|
||||
|
||||
vc := NewContext(ctx)
|
||||
defer ReleaseContext(vc)
|
||||
|
||||
tests := []struct {
|
||||
varName string
|
||||
expected string
|
||||
}{
|
||||
{VarSSLClientSerial, "ABC123"},
|
||||
{VarSSLClientSubject, "CN=test"},
|
||||
{VarSSLClientIssuer, "CN=CA"},
|
||||
{VarSSLClientFingerprint, "FINGERPRINT"},
|
||||
{VarSSLClientNotBefore, "2025-01-01T00:00:00Z"},
|
||||
{VarSSLClientNotAfter, "2026-01-01T00:00:00Z"},
|
||||
{VarSSLClientEmail, "test@example.com"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.varName, func(t *testing.T) {
|
||||
value, ok := vc.Get(tt.varName)
|
||||
if !ok {
|
||||
t.Errorf("variable %s not found", tt.varName)
|
||||
return
|
||||
}
|
||||
if value != tt.expected {
|
||||
t.Errorf("%s = %q, want %q", tt.varName, value, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSLVariablesInContext_VerifyNonTLS 测试 ssl_client_verify 在非 TLS 下返回 NONE
|
||||
func TestSSLVariablesInContext_VerifyNonTLS(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(VarSSLClientVerify, "SUCCESS")
|
||||
|
||||
vc := NewContext(ctx)
|
||||
defer ReleaseContext(vc)
|
||||
|
||||
// 非 TLS 连接,ssl_client_verify 应该返回 NONE
|
||||
value, ok := vc.Get(VarSSLClientVerify)
|
||||
if !ok {
|
||||
t.Error("ssl_client_verify not found")
|
||||
return
|
||||
}
|
||||
if value != "NONE" {
|
||||
t.Errorf("ssl_client_verify = %q, want NONE (non-TLS context)", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSLVariablesExpand 测试在模板中展开 SSL 变量
|
||||
// 注意:ssl_client_verify 在非 TLS 连接下会返回 NONE
|
||||
func TestSSLVariablesExpand(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
ctx.SetUserValue(VarSSLClientSerial, "12345")
|
||||
ctx.SetUserValue(VarSSLClientSubject, "CN=test")
|
||||
|
||||
vc := NewContext(ctx)
|
||||
defer ReleaseContext(vc)
|
||||
|
||||
tests := []struct {
|
||||
template string
|
||||
expected string
|
||||
}{
|
||||
{"$ssl_client_serial", "12345"},
|
||||
{"$ssl_client_subject", "CN=test"},
|
||||
{"serial=$ssl_client_serial subject=$ssl_client_subject", "serial=12345 subject=CN=test"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.template, func(t *testing.T) {
|
||||
result := vc.Expand(tt.template)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expand(%q) = %q, want %q", tt.template, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSLVariablesExpand_VerifyNonTLS 测试 ssl_client_verify 在非 TLS 下展开为 NONE
|
||||
func TestSSLVariablesExpand_VerifyNonTLS(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.SetUserValue(VarSSLClientVerify, "SUCCESS")
|
||||
|
||||
vc := NewContext(ctx)
|
||||
defer ReleaseContext(vc)
|
||||
|
||||
// 非 TLS 连接,ssl_client_verify 应该展开为 NONE
|
||||
result := vc.Expand("$ssl_client_verify")
|
||||
if result != "NONE" {
|
||||
t.Errorf("Expand($ssl_client_verify) = %q, want NONE (non-TLS context)", result)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user