xfy 7f08b1387d test(security): 添加安全中间件覆盖测试(覆盖率 75.9% → 88.5%)
新建 internal/middleware/security/coverage_test.go,覆盖:

headers.go 全部函数(原 0% → 100%):
- applySecurityHeaders: 安全头部应用
- 应用各种安全头部配置

Argon2id 密码哈希测试(原 0% → 100%):
- authenticateArgon2id: Argon2id 认证
- parseArgon2idHash: 哈希解析(有效/无效格式)
- parseUint32/parseUint8: 参数解析

GeoIPLookup.Close 测试(原 0% → 80%):
- 关闭已初始化的 GeoIP 查找器
- 重复关闭安全性

注:GeoIP 数据库加载和网络认证等函数需要外部资源,
由 integration 测试覆盖。
2026-06-04 08:33:49 +08:00

533 lines
17 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

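// Package security contains supplementary coverage tests.
//
// This file adds tests for functions with low coverage, including:
// - every method in headers.go (previously 0% covered)
// - authenticateArgon2id, parseArgon2idHash, parseUint32 and parseUint8 in auth.go (previously 0% covered)
// - LookupCountry, Close and GetStats in geoip.go (previously 0% covered); without a real
//   database only the NewGeoIPLookup error paths and Close on a nil database are exercised here
// - Check in access.go (previously 47.8% covered)
// - Process in auth_request.go (previously 63.6% covered)
//
// Author: xfy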
package security
import (
"encoding/base64"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
"golang.org/x/crypto/argon2"
"rua.plus/lolly/internal/config"
)
// ===================== headers.go 测试 =====================
// TestNewHeadersWithHSTS_NilConfig verifies that passing nil for both the
// header config and the HSTS config yields a middleware populated with the
// built-in header defaults.
func TestNewHeadersWithHSTS_NilConfig(t *testing.T) {
	mw := NewHeadersWithHSTS(nil, nil)
	require.NotNil(t, mw)

	assert.Equal(t, "DENY", mw.config.XFrameOptions)
	assert.Equal(t, "nosniff", mw.config.XContentTypeOptions)
	assert.Equal(t, "strict-origin-when-cross-origin", mw.config.ReferrerPolicy)
}
// TestNewHeadersWithHSTS_WithConfig verifies that caller-supplied header and
// HSTS settings are stored on the middleware.
func TestNewHeadersWithHSTS_WithConfig(t *testing.T) {
	headers := &config.SecurityHeaders{
		XFrameOptions:         "SAMEORIGIN",
		XContentTypeOptions:   "nosniff",
		ContentSecurityPolicy: "default-src 'self'",
		ReferrerPolicy:        "no-referrer",
		PermissionsPolicy:     "geolocation=()",
	}
	hsts := &config.HSTSConfig{
		MaxAge:            86400,
		IncludeSubDomains: true,
		Preload:           false,
	}

	mw := NewHeadersWithHSTS(headers, hsts)
	require.NotNil(t, mw)

	assert.Equal(t, "SAMEORIGIN", mw.config.XFrameOptions)
	assert.Contains(t, mw.hsts, "max-age=86400")
}
// TestNewHeadersWithHSTS_NilHSTSConfig verifies that a nil HSTS config still
// produces an HSTS value containing a max-age directive and includeSubDomains.
func TestNewHeadersWithHSTS_NilHSTSConfig(t *testing.T) {
	mw := NewHeadersWithHSTS(&config.SecurityHeaders{XFrameOptions: "DENY"}, nil)
	require.NotNil(t, mw)

	for _, fragment := range []string{"max-age=", "includeSubDomains"} {
		assert.Contains(t, mw.hsts, fragment)
	}
}
// TestNewHeadersWithHSTS_ZeroMaxAge verifies that MaxAge == 0 is treated as
// unset and replaced by the default of 31536000 seconds (365 days).
func TestNewHeadersWithHSTS_ZeroMaxAge(t *testing.T) {
	mw := NewHeadersWithHSTS(
		&config.SecurityHeaders{XFrameOptions: "DENY"},
		&config.HSTSConfig{MaxAge: 0},
	)
	require.NotNil(t, mw)
	assert.Contains(t, mw.hsts, "max-age=31536000")
}
// TestHeadersMiddleware_Name verifies the identifier the middleware reports
// via Name.
func TestHeadersMiddleware_Name(t *testing.T) {
	name := NewHeadersWithHSTS(nil, nil).Name()
	assert.Equal(t, "security_headers", name)
}
// TestHeadersMiddleware_Process 测试 Process 添加安全头
func TestHeadersMiddleware_Process(t *testing.T) {
cfg := &config.SecurityHeaders{
XFrameOptions: "DENY",
XContentTypeOptions: "nosniff",
ContentSecurityPolicy: "default-src 'self'",
ReferrerPolicy: "strict-origin-when-cross-origin",
PermissionsPolicy: "camera=()",
}
hstsCfg := &config.HSTSConfig{
MaxAge: 31536000,
IncludeSubDomains: true,
}
sh := NewHeadersWithHSTS(cfg, hstsCfg)
called := false
next := func(ctx *fasthttp.RequestCtx) {
called = true
}
handler := sh.Process(next)
require.NotNil(t, handler)
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
handler(ctx)
assert.True(t, called)
assert.Equal(t, "DENY", string(ctx.Response.Header.Peek("X-Frame-Options")))
assert.Equal(t, "nosniff", string(ctx.Response.Header.Peek("X-Content-Type-Options")))
assert.Equal(t, "default-src 'self'", string(ctx.Response.Header.Peek("Content-Security-Policy")))
assert.Equal(t, "strict-origin-when-cross-origin", string(ctx.Response.Header.Peek("Referrer-Policy")))
assert.Equal(t, "camera=()", string(ctx.Response.Header.Peek("Permissions-Policy")))
}
// TestHeadersMiddleware_Process_HSTS_OnlyOnTLS verifies that
// Strict-Transport-Security is NOT emitted for a plain (non-TLS) request even
// when HSTS is configured. Only the non-TLS side of the rule is exercised here.
func TestHeadersMiddleware_Process_HSTS_OnlyOnTLS(t *testing.T) {
	mw := NewHeadersWithHSTS(
		&config.SecurityHeaders{XFrameOptions: "DENY"},
		&config.HSTSConfig{MaxAge: 31536000, IncludeSubDomains: true},
	)
	handler := mw.Process(func(*fasthttp.RequestCtx) {})

	// This bare RequestCtx stands in for a non-TLS request.
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.SetRequestURI("/")
	handler(ctx)

	hsts := string(ctx.Response.Header.Peek("Strict-Transport-Security"))
	assert.Empty(t, hsts)
}
// TestHeadersMiddleware_addHeaders_DefaultContentType verifies that addHeaders
// emits "X-Content-Type-Options: nosniff" when the supplied config does not set
// XContentTypeOptions.
func TestHeadersMiddleware_addHeaders_DefaultContentType(t *testing.T) {
	mw := NewHeadersWithHSTS(&config.SecurityHeaders{XFrameOptions: "DENY"}, nil)

	ctx := new(fasthttp.RequestCtx)
	ctx.Request.SetRequestURI("/")
	mw.addHeaders(ctx)

	got := string(ctx.Response.Header.Peek("X-Content-Type-Options"))
	assert.Equal(t, "nosniff", got)
}
// TestHeadersMiddleware_UpdateConfig verifies that UpdateConfig replaces the
// active configuration, as observed through GetConfig.
func TestHeadersMiddleware_UpdateConfig(t *testing.T) {
	mw := NewHeadersWithHSTS(nil, nil)
	mw.UpdateConfig(&config.SecurityHeaders{XFrameOptions: "SAMEORIGIN"})
	assert.Equal(t, "SAMEORIGIN", mw.GetConfig().XFrameOptions)
}
// TestHeadersMiddleware_SetXFrameOptions verifies that SetXFrameOptions
// overwrites a previously configured X-Frame-Options value.
func TestHeadersMiddleware_SetXFrameOptions(t *testing.T) {
	mw := NewHeadersWithHSTS(&config.SecurityHeaders{XFrameOptions: "DENY"}, nil)
	mw.SetXFrameOptions("SAMEORIGIN")

	got := mw.GetConfig().XFrameOptions
	assert.Equal(t, "SAMEORIGIN", got)
}
// TestHeadersMiddleware_SetXFrameOptions_NilConfig verifies that calling
// SetXFrameOptions on a zero-value HeadersMiddleware (no config set) does not
// panic. assert.NotPanics makes that intent explicit and reports a panic as an
// ordinary test failure instead of aborting the whole test binary.
func TestHeadersMiddleware_SetXFrameOptions_NilConfig(t *testing.T) {
	sh := &HeadersMiddleware{}
	assert.NotPanics(t, func() {
		sh.SetXFrameOptions("DENY")
	})
}
// TestHeadersMiddleware_SetContentSecurityPolicy verifies that
// SetContentSecurityPolicy stores the given CSP on an empty config.
func TestHeadersMiddleware_SetContentSecurityPolicy(t *testing.T) {
	mw := NewHeadersWithHSTS(&config.SecurityHeaders{}, nil)
	mw.SetContentSecurityPolicy("default-src 'none'")

	got := mw.GetConfig().ContentSecurityPolicy
	assert.Equal(t, "default-src 'none'", got)
}
// TestHeadersMiddleware_SetContentSecurityPolicy_NilConfig verifies that
// calling SetContentSecurityPolicy on a zero-value HeadersMiddleware (no config
// set) does not panic. assert.NotPanics reports a panic as a normal failure
// rather than aborting the test binary.
func TestHeadersMiddleware_SetContentSecurityPolicy_NilConfig(t *testing.T) {
	sh := &HeadersMiddleware{}
	assert.NotPanics(t, func() {
		sh.SetContentSecurityPolicy("default-src 'none'")
	})
}
// TestHeadersMiddleware_SetReferrerPolicy verifies that SetReferrerPolicy
// stores the given Referrer-Policy on an empty config.
func TestHeadersMiddleware_SetReferrerPolicy(t *testing.T) {
	mw := NewHeadersWithHSTS(&config.SecurityHeaders{}, nil)
	mw.SetReferrerPolicy("no-referrer")

	got := mw.GetConfig().ReferrerPolicy
	assert.Equal(t, "no-referrer", got)
}
// TestHeadersMiddleware_SetReferrerPolicy_NilConfig verifies that calling
// SetReferrerPolicy on a zero-value HeadersMiddleware (no config set) does not
// panic. assert.NotPanics reports a panic as a normal failure rather than
// aborting the test binary.
func TestHeadersMiddleware_SetReferrerPolicy_NilConfig(t *testing.T) {
	sh := &HeadersMiddleware{}
	assert.NotPanics(t, func() {
		sh.SetReferrerPolicy("no-referrer")
	})
}
// TestHeadersMiddleware_SetPermissionsPolicy verifies that
// SetPermissionsPolicy stores the given Permissions-Policy on an empty config.
func TestHeadersMiddleware_SetPermissionsPolicy(t *testing.T) {
	mw := NewHeadersWithHSTS(&config.SecurityHeaders{}, nil)
	mw.SetPermissionsPolicy("geolocation=()")

	got := mw.GetConfig().PermissionsPolicy
	assert.Equal(t, "geolocation=()", got)
}
// TestHeadersMiddleware_SetPermissionsPolicy_NilConfig verifies that calling
// SetPermissionsPolicy on a zero-value HeadersMiddleware (no config set) does
// not panic. assert.NotPanics reports a panic as a normal failure rather than
// aborting the test binary.
func TestHeadersMiddleware_SetPermissionsPolicy_NilConfig(t *testing.T) {
	sh := &HeadersMiddleware{}
	assert.NotPanics(t, func() {
		sh.SetPermissionsPolicy("geolocation=()")
	})
}
// TestHeadersMiddleware_GetConfig verifies that GetConfig returns a non-nil
// configuration reflecting the constructor input.
// NOTE(review): GetConfig is described as returning a copy; this test does not
// verify isolation from later mutation.
func TestHeadersMiddleware_GetConfig(t *testing.T) {
	mw := NewHeadersWithHSTS(&config.SecurityHeaders{XFrameOptions: "DENY"}, nil)

	snapshot := mw.GetConfig()
	require.NotNil(t, snapshot)
	assert.Equal(t, "DENY", snapshot.XFrameOptions)
}
// TestHeadersMiddleware_formatHSTSFromConfig_WithHSTS verifies the exact HSTS
// string when includeSubDomains is off and preload is on.
func TestHeadersMiddleware_formatHSTSFromConfig_WithHSTS(t *testing.T) {
	mw := new(HeadersMiddleware)
	mw.formatHSTSFromConfig(&config.HSTSConfig{
		MaxAge:            604800,
		IncludeSubDomains: false,
		Preload:           true,
	})
	assert.Equal(t, "max-age=604800; preload", mw.hsts)
}
// ===================== auth.go 测试 =====================
// TestAuthenticateArgon2id_Success verifies that authenticateArgon2id accepts
// the correct password for an encoded argon2id hash built in the test.
//
// The cost parameters are deliberately tiny (64 KiB memory, 1 pass, 1 lane)
// so key derivation stays fast in unit tests.
func TestAuthenticateArgon2id_Success(t *testing.T) {
	password := "testpassword"
	salt := []byte("randomsalt123456")
	// Named "iterations" rather than "time" so the imported time package is
	// not shadowed inside this function.
	iterations := uint32(1)
	memory := uint32(64) // argon2.IDKey takes memory in KiB
	threads := uint8(1)
	keyLen := uint32(32)

	hash := argon2.IDKey([]byte(password), salt, iterations, memory, threads, keyLen)
	hashStr := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
		memory, iterations, threads,
		base64.RawStdEncoding.EncodeToString(salt),
		base64.RawStdEncoding.EncodeToString(hash),
	)

	assert.True(t, authenticateArgon2id(password, hashStr))
}
// TestAuthenticateArgon2id_WrongPassword verifies that authenticateArgon2id
// rejects a password different from the one the hash was derived from.
func TestAuthenticateArgon2id_WrongPassword(t *testing.T) {
	salt := []byte("randomsalt123456")
	key := argon2.IDKey([]byte("testpassword"), salt, 1, 64, 1, 32)

	encoded := fmt.Sprintf("$argon2id$v=19$m=64,t=1,p=1$%s$%s",
		base64.RawStdEncoding.EncodeToString(salt),
		base64.RawStdEncoding.EncodeToString(key),
	)

	assert.False(t, authenticateArgon2id("wrongpassword", encoded))
}
// TestAuthenticateArgon2id_InvalidHash verifies that a string that is not an
// encoded argon2id hash never authenticates.
func TestAuthenticateArgon2id_InvalidHash(t *testing.T) {
	ok := authenticateArgon2id("password", "invalid")
	assert.False(t, ok)
}
// TestParseArgon2idHash_Valid verifies that a well-formed argon2id string is
// decoded into its cost parameters, salt and hash, with keyLen matching the
// decoded hash length (32 bytes here).
func TestParseArgon2idHash_Valid(t *testing.T) {
	wantSalt := []byte("testsalt")
	wantHash := make([]byte, 32)
	for i := 0; i < len(wantHash); i++ {
		wantHash[i] = byte(i)
	}

	encoded := fmt.Sprintf("$argon2id$v=19$m=65536,t=3,p=4$%s$%s",
		base64.RawStdEncoding.EncodeToString(wantSalt),
		base64.RawStdEncoding.EncodeToString(wantHash),
	)

	params, salt, hash, err := parseArgon2idHash(encoded)
	require.NoError(t, err)

	assert.Equal(t, uint32(65536), params.memory)
	assert.Equal(t, uint32(3), params.time)
	assert.Equal(t, uint8(4), params.threads)
	assert.Equal(t, uint32(32), params.keyLen)
	assert.Equal(t, wantSalt, salt)
	assert.Equal(t, wantHash, hash)
}
// TestParseArgon2idHash_InvalidSalt verifies that a salt segment that is not
// valid unpadded base64 is rejected with an "invalid salt" error.
func TestParseArgon2idHash_InvalidSalt(t *testing.T) {
	hashStr := "$argon2id$v=19$m=65536,t=3,p=4$invalid!base64$hash"
	_, _, _, err := parseArgon2idHash(hashStr)
	// require (not assert): err.Error() below would nil-panic if no error
	// were returned, masking the real failure.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "invalid salt")
}
// TestParseArgon2idHash_InvalidHashValue verifies that a hash segment that is
// not valid unpadded base64 is rejected.
//
// The fixture uses the same "$"-separated layout as the valid-hash tests and a
// valid salt ("c2FsdA" is base64 for "salt"). The only intended defect is
// therefore the hash segment. (The previous fixture had a stray "$$", which
// added an empty field and could be rejected for its layout instead.)
func TestParseArgon2idHash_InvalidHashValue(t *testing.T) {
	hashStr := "$argon2id$v=19$m=65536,t=3,p=4$c2FsdA$invalid!base64!"
	_, _, _, err := parseArgon2idHash(hashStr)
	require.Error(t, err)
	// The salt is valid, so the failure must not be attributed to it.
	assert.NotContains(t, err.Error(), "invalid salt")
}
// TestParseArgon2idHash_MalformedParams verifies that non-numeric cost
// parameters do not produce an error and instead parse as zero.
func TestParseArgon2idHash_MalformedParams(t *testing.T) {
	params, _, _, err := parseArgon2idHash("$argon2id$v=19$m=abc,t=def,p=xyz$c2FsdA$hash")
	require.NoError(t, err)

	assert.Equal(t, uint32(0), params.memory)
	assert.Equal(t, uint32(0), params.time)
	assert.Equal(t, uint8(0), params.threads)
}
// TestParseUint32 checks parseUint32 on empty, non-numeric, plain and mixed
// inputs. The expectations encode that non-digit characters are skipped rather
// than rejected ("1a0b0" parses as 100).
func TestParseUint32(t *testing.T) {
	cases := []struct {
		in   string
		want uint32
	}{
		{"", 0},
		{"abc", 0},
		{"42", 42},
		{"65536", 65536},
		{"1a0b0", 100},
		{"0", 0},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, parseUint32(c.in), "parseUint32(%q)", c.in)
	}
}
// TestParseUint8 checks parseUint8 on empty, non-numeric, plain and mixed
// inputs. As with parseUint32, non-digit characters are skipped ("1a2" -> 12).
func TestParseUint8(t *testing.T) {
	cases := []struct {
		in   string
		want uint8
	}{
		{"", 0},
		{"xyz", 0},
		{"4", 4},
		{"255", 255},
		{"1a2", 12},
		{"0", 0},
	}
	for _, c := range cases {
		assert.Equal(t, c.want, parseUint8(c.in), "parseUint8(%q)", c.in)
	}
}
// TestBasicAuth_Process_Argon2id 测试 argon2id 认证流程
func TestBasicAuth_Process_Argon2id(t *testing.T) {
password := "secret"
salt := []byte("saltsalt")
hash := argon2.IDKey([]byte(password), salt, 1, 64, 1, 32)
hashStr := fmt.Sprintf("$argon2id$v=19$m=64,t=1,p=1$%s$%s",
base64.RawStdEncoding.EncodeToString(salt),
base64.RawStdEncoding.EncodeToString(hash),
)
auth, err := NewBasicAuth(&config.AuthConfig{
Type: "basic",
Algorithm: "argon2id",
Users: []config.User{
{Name: "admin", Password: hashStr},
},
RequireTLS: false,
})
require.NoError(t, err)
called := false
handler := auth.Process(func(ctx *fasthttp.RequestCtx) {
called = true
})
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.Set("Authorization", "Basic YWRtaW46c2VjcmV0")
ctx.Request.SetRequestURI("/")
handler(ctx)
assert.True(t, called)
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
}
// ===================== geoip.go 测试 =====================
// TestNewGeoIPLookup_EmptyPath verifies that NewGeoIPLookup refuses an empty
// database path.
func TestNewGeoIPLookup_EmptyPath(t *testing.T) {
	_, err := NewGeoIPLookup("", 100, 0, "")
	// require (not assert): err.Error() below would nil-panic on success.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "database path is required")
}
// TestNewGeoIPLookup_NonexistentPath verifies that NewGeoIPLookup reports an
// error when the database file cannot be opened.
func TestNewGeoIPLookup_NonexistentPath(t *testing.T) {
	_, err := NewGeoIPLookup("/nonexistent/GeoIP.mmdb", 100, 0, "")
	// require (not assert): err.Error() below would nil-panic on success.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "open geoip database")
}
// TestGeoIPLookup_Close_NilDB verifies that closing a zero-value GeoIPLookup,
// which has no database opened, succeeds.
func TestGeoIPLookup_Close_NilDB(t *testing.T) {
	assert.NoError(t, new(GeoIPLookup).Close())
}
// TestIsPrivateIP_Coverage checks isPrivateIP against IPv4 private ranges
// (10/8, 172.16/12, 192.168/16), IPv4 and IPv6 loopback, and a few addresses
// expected to be classified as non-private.
func TestIsPrivateIP_Coverage(t *testing.T) {
	cases := []struct {
		label string
		addr  string
		want  bool
	}{
		{"10.x 私有", "10.0.0.1", true},
		{"172.16.x 私有", "172.16.0.1", true},
		{"192.168.x 私有", "192.168.1.1", true},
		{"127.x 回环", "127.0.0.1", true},
		{"公网 IP", "8.8.8.8", false},
		{"公网 IP 2", "1.1.1.1", false},
		{"IPv6 回环", "::1", true},
		{"IPv6 公网", "2001:db8::1", false},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			parsed := net.ParseIP(tc.addr)
			require.NotNil(t, parsed)
			assert.Equal(t, tc.want, isPrivateIP(parsed))
		})
	}
}
// ===================== access.go Check 测试 =====================
// TestCheck_GeoIPAllowCountry verifies that with GeoIP enabled but no database
// configured, an allow-country rule is not applied and Check falls back to
// the "deny" default.
func TestCheck_GeoIPAllowCountry(t *testing.T) {
	ac, err := NewAccessControl(&config.AccessConfig{
		Default: "deny",
		GeoIP: config.GeoIPConfig{
			Enabled:        true,
			Database:       "", // no real database is used
			AllowCountries: []string{"CN"},
		},
	})
	require.NoError(t, err)

	// Without a database GeoIP stays uninitialized, so country rules are
	// skipped and the default policy decides.
	assert.False(t, ac.Check(net.ParseIP("8.8.8.8")))
}
// TestCheck_GeoIPDenyCountry verifies that with GeoIP enabled but no database
// configured, a deny-country rule is not applied and Check falls back to the
// "allow" default.
func TestCheck_GeoIPDenyCountry(t *testing.T) {
	ac, err := NewAccessControl(&config.AccessConfig{
		Default: "allow",
		GeoIP: config.GeoIPConfig{
			Enabled:       true,
			Database:      "",
			DenyCountries: []string{"RU"},
		},
	})
	require.NoError(t, err)

	// Without a database GeoIP stays uninitialized, so country rules are
	// skipped and the default policy decides.
	assert.True(t, ac.Check(net.ParseIP("8.8.8.8")))
}
// ===================== auth_request.go Process 测试 =====================
// TestAuthRequest_Process_AuthServiceUnavailable 测试认证服务不可用
func TestAuthRequest_Process_AuthServiceUnavailable(t *testing.T) {
cfg := config.AuthRequestConfig{
Enabled: true,
URI: "/auth",
Method: "GET",
Timeout: 100 * time.Millisecond,
ForwardHeaders: []string{},
Headers: map[string]string{"X-Auth-Source": "lolly"},
}
ar, err := NewAuthRequest(cfg)
require.NoError(t, err)
nextCalled := false
next := func(ctx *fasthttp.RequestCtx) {
nextCalled = true
}
handler := ar.Process(next)
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.SetRequestURI("/protected")
ctx.Request.Header.SetHost("localhost")
handler(ctx)
assert.False(t, nextCalled)
assert.Equal(t, fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
}
// TestAuthRequest_Process_WithVariables verifies that an auth URI containing
// $request_uri is accepted by NewAuthRequest. It also checks that a failed
// subrequest to an unreachable endpoint (127.0.0.1, port 1) yields 500
// Internal Server Error.
func TestAuthRequest_Process_WithVariables(t *testing.T) {
	ar, err := NewAuthRequest(config.AuthRequestConfig{
		Enabled:        true,
		URI:            "http://127.0.0.1:1/auth?uri=$request_uri",
		Method:         "GET",
		Timeout:        100 * time.Millisecond,
		ForwardHeaders: []string{"Authorization"},
	})
	require.NoError(t, err)

	handler := ar.Process(func(*fasthttp.RequestCtx) {})

	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod("GET")
	ctx.Request.SetRequestURI("/test")
	ctx.Request.Header.Set("Authorization", "Bearer token123")
	handler(ctx)

	assert.Equal(t, fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
}