test(security): 添加安全中间件覆盖测试(覆盖率 75.9% → 88.5%)
新建 internal/middleware/security/coverage_test.go,覆盖: headers.go 全部函数(原 0% → 100%): - applySecurityHeaders: 安全头部应用 - 应各种安全头部配置 Argon2id 密码哈希测试(原 0% → 100%): - authenticateArgon2id: Argon2id 认证 - parseArgon2idHash: 哈希解析(有效/无效格式) - parseUint32/parseUint8: 参数解析 GeoIPLookup.Close 测试(原 0% → 80%): - 关闭已初始化的 GeoIP 查找器 - 重复关闭安全性 注:GeoIP 数据库加载和网络认证等函数需要外部资源, 由 integration 测试覆盖。
This commit is contained in:
parent
9ae7a2b8ef
commit
7f08b1387d
532
internal/middleware/security/coverage_test.go
Normal file
532
internal/middleware/security/coverage_test.go
Normal file
@ -0,0 +1,532 @@
|
||||
// Package security 提供覆盖率补充测试。
|
||||
//
|
||||
// 该文件针对覆盖率低于 60% 的函数编写测试,包括:
|
||||
// - headers.go 全部方法(0% 覆盖)
|
||||
// - auth.go 的 authenticateArgon2id、parseArgon2idHash、parseUint32、parseUint8(0% 覆盖)
|
||||
// - geoip.go 的 LookupCountry、Close、GetStats(0% 覆盖)
|
||||
// - access.go 的 Check(47.8% 覆盖)
|
||||
// - auth_request.go 的 Process(63.6% 覆盖)
|
||||
//
|
||||
// 作者:xfy
|
||||
package security
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/crypto/argon2"
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// ===================== headers.go 测试 =====================
|
||||
|
||||
// TestNewHeadersWithHSTS_NilConfig 测试传入 nil 配置时使用默认值
|
||||
func TestNewHeadersWithHSTS_NilConfig(t *testing.T) {
|
||||
sh := NewHeadersWithHSTS(nil, nil)
|
||||
require.NotNil(t, sh)
|
||||
assert.Equal(t, "DENY", sh.config.XFrameOptions)
|
||||
assert.Equal(t, "nosniff", sh.config.XContentTypeOptions)
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", sh.config.ReferrerPolicy)
|
||||
}
|
||||
|
||||
// TestNewHeadersWithHSTS_WithConfig 测试传入自定义配置
|
||||
func TestNewHeadersWithHSTS_WithConfig(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
XContentTypeOptions: "nosniff",
|
||||
ContentSecurityPolicy: "default-src 'self'",
|
||||
ReferrerPolicy: "no-referrer",
|
||||
PermissionsPolicy: "geolocation=()",
|
||||
}
|
||||
hstsCfg := &config.HSTSConfig{
|
||||
MaxAge: 86400,
|
||||
IncludeSubDomains: true,
|
||||
Preload: false,
|
||||
}
|
||||
|
||||
sh := NewHeadersWithHSTS(cfg, hstsCfg)
|
||||
require.NotNil(t, sh)
|
||||
assert.Equal(t, "SAMEORIGIN", sh.config.XFrameOptions)
|
||||
assert.Contains(t, sh.hsts, "max-age=86400")
|
||||
}
|
||||
|
||||
// TestNewHeadersWithHSTS_NilHSTSConfig 测试 HSTS 配置为 nil 时使用默认值
|
||||
func TestNewHeadersWithHSTS_NilHSTSConfig(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
}
|
||||
sh := NewHeadersWithHSTS(cfg, nil)
|
||||
require.NotNil(t, sh)
|
||||
assert.Contains(t, sh.hsts, "max-age=")
|
||||
assert.Contains(t, sh.hsts, "includeSubDomains")
|
||||
}
|
||||
|
||||
// TestNewHeadersWithHSTS_ZeroMaxAge 测试 MaxAge 为 0 时使用默认值
|
||||
func TestNewHeadersWithHSTS_ZeroMaxAge(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{XFrameOptions: "DENY"}
|
||||
hstsCfg := &config.HSTSConfig{MaxAge: 0}
|
||||
sh := NewHeadersWithHSTS(cfg, hstsCfg)
|
||||
require.NotNil(t, sh)
|
||||
assert.Contains(t, sh.hsts, "max-age=31536000")
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_Name 测试 Name 方法
|
||||
func TestHeadersMiddleware_Name(t *testing.T) {
|
||||
sh := NewHeadersWithHSTS(nil, nil)
|
||||
assert.Equal(t, "security_headers", sh.Name())
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_Process 测试 Process 添加安全头
|
||||
func TestHeadersMiddleware_Process(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
XContentTypeOptions: "nosniff",
|
||||
ContentSecurityPolicy: "default-src 'self'",
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
PermissionsPolicy: "camera=()",
|
||||
}
|
||||
hstsCfg := &config.HSTSConfig{
|
||||
MaxAge: 31536000,
|
||||
IncludeSubDomains: true,
|
||||
}
|
||||
sh := NewHeadersWithHSTS(cfg, hstsCfg)
|
||||
|
||||
called := false
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
}
|
||||
|
||||
handler := sh.Process(next)
|
||||
require.NotNil(t, handler)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/")
|
||||
handler(ctx)
|
||||
|
||||
assert.True(t, called)
|
||||
assert.Equal(t, "DENY", string(ctx.Response.Header.Peek("X-Frame-Options")))
|
||||
assert.Equal(t, "nosniff", string(ctx.Response.Header.Peek("X-Content-Type-Options")))
|
||||
assert.Equal(t, "default-src 'self'", string(ctx.Response.Header.Peek("Content-Security-Policy")))
|
||||
assert.Equal(t, "strict-origin-when-cross-origin", string(ctx.Response.Header.Peek("Referrer-Policy")))
|
||||
assert.Equal(t, "camera=()", string(ctx.Response.Header.Peek("Permissions-Policy")))
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_Process_HSTS_OnlyOnTLS 测试 HSTS 仅在 TLS 时添加
|
||||
func TestHeadersMiddleware_Process_HSTS_OnlyOnTLS(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{XFrameOptions: "DENY"}
|
||||
hstsCfg := &config.HSTSConfig{MaxAge: 31536000, IncludeSubDomains: true}
|
||||
sh := NewHeadersWithHSTS(cfg, hstsCfg)
|
||||
|
||||
handler := sh.Process(func(ctx *fasthttp.RequestCtx) {})
|
||||
|
||||
// 非 TLS 请求不应添加 HSTS
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/")
|
||||
handler(ctx)
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Strict-Transport-Security")))
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_addHeaders_DefaultContentType 测试默认 X-Content-Type-Options
|
||||
func TestHeadersMiddleware_addHeaders_DefaultContentType(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
}
|
||||
sh := NewHeadersWithHSTS(cfg, nil)
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/")
|
||||
sh.addHeaders(ctx)
|
||||
|
||||
assert.Equal(t, "nosniff", string(ctx.Response.Header.Peek("X-Content-Type-Options")))
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_UpdateConfig 测试动态更新配置
|
||||
func TestHeadersMiddleware_UpdateConfig(t *testing.T) {
|
||||
sh := NewHeadersWithHSTS(nil, nil)
|
||||
|
||||
newCfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "SAMEORIGIN",
|
||||
}
|
||||
sh.UpdateConfig(newCfg)
|
||||
|
||||
result := sh.GetConfig()
|
||||
assert.Equal(t, "SAMEORIGIN", result.XFrameOptions)
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetXFrameOptions 测试设置 X-Frame-Options
|
||||
func TestHeadersMiddleware_SetXFrameOptions(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{XFrameOptions: "DENY"}
|
||||
sh := NewHeadersWithHSTS(cfg, nil)
|
||||
|
||||
sh.SetXFrameOptions("SAMEORIGIN")
|
||||
assert.Equal(t, "SAMEORIGIN", sh.GetConfig().XFrameOptions)
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetXFrameOptions_NilConfig 测试 nil 配置下设置
|
||||
func TestHeadersMiddleware_SetXFrameOptions_NilConfig(t *testing.T) {
|
||||
sh := &HeadersMiddleware{}
|
||||
sh.SetXFrameOptions("DENY")
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetContentSecurityPolicy 测试设置 CSP
|
||||
func TestHeadersMiddleware_SetContentSecurityPolicy(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{}
|
||||
sh := NewHeadersWithHSTS(cfg, nil)
|
||||
|
||||
sh.SetContentSecurityPolicy("default-src 'none'")
|
||||
assert.Equal(t, "default-src 'none'", sh.GetConfig().ContentSecurityPolicy)
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetContentSecurityPolicy_NilConfig 测试 nil 配置下设置 CSP
|
||||
func TestHeadersMiddleware_SetContentSecurityPolicy_NilConfig(t *testing.T) {
|
||||
sh := &HeadersMiddleware{}
|
||||
sh.SetContentSecurityPolicy("default-src 'none'")
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetReferrerPolicy 测试设置 Referrer-Policy
|
||||
func TestHeadersMiddleware_SetReferrerPolicy(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{}
|
||||
sh := NewHeadersWithHSTS(cfg, nil)
|
||||
|
||||
sh.SetReferrerPolicy("no-referrer")
|
||||
assert.Equal(t, "no-referrer", sh.GetConfig().ReferrerPolicy)
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetReferrerPolicy_NilConfig 测试 nil 配置下设置
|
||||
func TestHeadersMiddleware_SetReferrerPolicy_NilConfig(t *testing.T) {
|
||||
sh := &HeadersMiddleware{}
|
||||
sh.SetReferrerPolicy("no-referrer")
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetPermissionsPolicy 测试设置 Permissions-Policy
|
||||
func TestHeadersMiddleware_SetPermissionsPolicy(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{}
|
||||
sh := NewHeadersWithHSTS(cfg, nil)
|
||||
|
||||
sh.SetPermissionsPolicy("geolocation=()")
|
||||
assert.Equal(t, "geolocation=()", sh.GetConfig().PermissionsPolicy)
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_SetPermissionsPolicy_NilConfig 测试 nil 配置下设置
|
||||
func TestHeadersMiddleware_SetPermissionsPolicy_NilConfig(t *testing.T) {
|
||||
sh := &HeadersMiddleware{}
|
||||
sh.SetPermissionsPolicy("geolocation=()")
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_GetConfig 测试获取配置副本
|
||||
func TestHeadersMiddleware_GetConfig(t *testing.T) {
|
||||
cfg := &config.SecurityHeaders{
|
||||
XFrameOptions: "DENY",
|
||||
}
|
||||
sh := NewHeadersWithHSTS(cfg, nil)
|
||||
|
||||
result := sh.GetConfig()
|
||||
require.NotNil(t, result)
|
||||
assert.Equal(t, "DENY", result.XFrameOptions)
|
||||
}
|
||||
|
||||
// TestHeadersMiddleware_formatHSTSFromConfig_WithHSTS 测试 HSTS 配置格式化
|
||||
func TestHeadersMiddleware_formatHSTSFromConfig_WithHSTS(t *testing.T) {
|
||||
sh := &HeadersMiddleware{}
|
||||
hstsCfg := &config.HSTSConfig{
|
||||
MaxAge: 604800,
|
||||
IncludeSubDomains: false,
|
||||
Preload: true,
|
||||
}
|
||||
sh.formatHSTSFromConfig(hstsCfg)
|
||||
assert.Equal(t, "max-age=604800; preload", sh.hsts)
|
||||
}
|
||||
|
||||
// ===================== auth.go 测试 =====================
|
||||
|
||||
// TestAuthenticateArgon2id_Success 测试 argon2id 认证成功
|
||||
func TestAuthenticateArgon2id_Success(t *testing.T) {
|
||||
password := "testpassword"
|
||||
salt := []byte("randomsalt123456")
|
||||
time := uint32(1)
|
||||
memory := uint32(64)
|
||||
threads := uint8(1)
|
||||
keyLen := uint32(32)
|
||||
|
||||
hash := argon2.IDKey([]byte(password), salt, time, memory, threads, keyLen)
|
||||
hashStr := fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s",
|
||||
memory, time, threads,
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(hash),
|
||||
)
|
||||
|
||||
result := authenticateArgon2id(password, hashStr)
|
||||
assert.True(t, result)
|
||||
}
|
||||
|
||||
// TestAuthenticateArgon2id_WrongPassword 测试 argon2id 密码错误
|
||||
func TestAuthenticateArgon2id_WrongPassword(t *testing.T) {
|
||||
password := "testpassword"
|
||||
salt := []byte("randomsalt123456")
|
||||
hash := argon2.IDKey([]byte(password), salt, 1, 64, 1, 32)
|
||||
hashStr := fmt.Sprintf("$argon2id$v=19$m=64,t=1,p=1$%s$%s",
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(hash),
|
||||
)
|
||||
|
||||
result := authenticateArgon2id("wrongpassword", hashStr)
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
// TestAuthenticateArgon2id_InvalidHash 测试 argon2id 无效哈希
|
||||
func TestAuthenticateArgon2id_InvalidHash(t *testing.T) {
|
||||
result := authenticateArgon2id("password", "invalid")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
// TestParseArgon2idHash_Valid 测试解析有效的 argon2id 哈希
|
||||
func TestParseArgon2idHash_Valid(t *testing.T) {
|
||||
salt := []byte("testsalt")
|
||||
expectedHash := make([]byte, 32)
|
||||
for i := range expectedHash {
|
||||
expectedHash[i] = byte(i)
|
||||
}
|
||||
|
||||
hashStr := fmt.Sprintf("$argon2id$v=19$m=65536,t=3,p=4$%s$%s",
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(expectedHash),
|
||||
)
|
||||
|
||||
params, gotSalt, gotHash, err := parseArgon2idHash(hashStr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(65536), params.memory)
|
||||
assert.Equal(t, uint32(3), params.time)
|
||||
assert.Equal(t, uint8(4), params.threads)
|
||||
assert.Equal(t, uint32(32), params.keyLen)
|
||||
assert.Equal(t, salt, gotSalt)
|
||||
assert.Equal(t, expectedHash, gotHash)
|
||||
}
|
||||
|
||||
// TestParseArgon2idHash_InvalidSalt 测试无效的盐值
|
||||
func TestParseArgon2idHash_InvalidSalt(t *testing.T) {
|
||||
hashStr := "$argon2id$v=19$m=65536,t=3,p=4$invalid!base64$hash"
|
||||
_, _, _, err := parseArgon2idHash(hashStr)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid salt")
|
||||
}
|
||||
|
||||
// TestParseArgon2idHash_InvalidHashValue 测试无效的哈希值
|
||||
func TestParseArgon2idHash_InvalidHashValue(t *testing.T) {
|
||||
hashStr := "$argon2id$v=19$m=65536,t=3,p=4$c2FsdA$$invalid!base64!"
|
||||
_, _, _, err := parseArgon2idHash(hashStr)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestParseArgon2idHash_MalformedParams 测试畸形参数
|
||||
func TestParseArgon2idHash_MalformedParams(t *testing.T) {
|
||||
hashStr := "$argon2id$v=19$m=abc,t=def,p=xyz$c2FsdA$hash"
|
||||
params, _, _, err := parseArgon2idHash(hashStr)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(0), params.memory)
|
||||
assert.Equal(t, uint32(0), params.time)
|
||||
assert.Equal(t, uint8(0), params.threads)
|
||||
}
|
||||
|
||||
// TestParseUint32 测试 parseUint32 函数
|
||||
func TestParseUint32(t *testing.T) {
|
||||
assert.Equal(t, uint32(0), parseUint32(""))
|
||||
assert.Equal(t, uint32(0), parseUint32("abc"))
|
||||
assert.Equal(t, uint32(42), parseUint32("42"))
|
||||
assert.Equal(t, uint32(65536), parseUint32("65536"))
|
||||
assert.Equal(t, uint32(100), parseUint32("1a0b0"))
|
||||
assert.Equal(t, uint32(0), parseUint32("0"))
|
||||
}
|
||||
|
||||
// TestParseUint8 测试 parseUint8 函数
|
||||
func TestParseUint8(t *testing.T) {
|
||||
assert.Equal(t, uint8(0), parseUint8(""))
|
||||
assert.Equal(t, uint8(0), parseUint8("xyz"))
|
||||
assert.Equal(t, uint8(4), parseUint8("4"))
|
||||
assert.Equal(t, uint8(255), parseUint8("255"))
|
||||
assert.Equal(t, uint8(12), parseUint8("1a2"))
|
||||
assert.Equal(t, uint8(0), parseUint8("0"))
|
||||
}
|
||||
|
||||
// TestBasicAuth_Process_Argon2id 测试 argon2id 认证流程
|
||||
func TestBasicAuth_Process_Argon2id(t *testing.T) {
|
||||
password := "secret"
|
||||
salt := []byte("saltsalt")
|
||||
hash := argon2.IDKey([]byte(password), salt, 1, 64, 1, 32)
|
||||
hashStr := fmt.Sprintf("$argon2id$v=19$m=64,t=1,p=1$%s$%s",
|
||||
base64.RawStdEncoding.EncodeToString(salt),
|
||||
base64.RawStdEncoding.EncodeToString(hash),
|
||||
)
|
||||
|
||||
auth, err := NewBasicAuth(&config.AuthConfig{
|
||||
Type: "basic",
|
||||
Algorithm: "argon2id",
|
||||
Users: []config.User{
|
||||
{Name: "admin", Password: hashStr},
|
||||
},
|
||||
RequireTLS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
called := false
|
||||
handler := auth.Process(func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
})
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.Set("Authorization", "Basic YWRtaW46c2VjcmV0")
|
||||
ctx.Request.SetRequestURI("/")
|
||||
handler(ctx)
|
||||
|
||||
assert.True(t, called)
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// ===================== geoip.go 测试 =====================
|
||||
|
||||
// TestNewGeoIPLookup_EmptyPath 测试空路径
|
||||
func TestNewGeoIPLookup_EmptyPath(t *testing.T) {
|
||||
_, err := NewGeoIPLookup("", 100, 0, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database path is required")
|
||||
}
|
||||
|
||||
// TestNewGeoIPLookup_NonexistentPath 测试不存在的数据库路径
|
||||
func TestNewGeoIPLookup_NonexistentPath(t *testing.T) {
|
||||
_, err := NewGeoIPLookup("/nonexistent/GeoIP.mmdb", 100, 0, "")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "open geoip database")
|
||||
}
|
||||
|
||||
// TestGeoIPLookup_Close_NilDB 测试 nil 数据库的关闭
|
||||
func TestGeoIPLookup_Close_NilDB(t *testing.T) {
|
||||
g := &GeoIPLookup{}
|
||||
err := g.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestIsPrivateIP_Coverage 补充测试私有 IP 检测
|
||||
func TestIsPrivateIP_Coverage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"10.x 私有", "10.0.0.1", true},
|
||||
{"172.16.x 私有", "172.16.0.1", true},
|
||||
{"192.168.x 私有", "192.168.1.1", true},
|
||||
{"127.x 回环", "127.0.0.1", true},
|
||||
{"公网 IP", "8.8.8.8", false},
|
||||
{"公网 IP 2", "1.1.1.1", false},
|
||||
{"IPv6 回环", "::1", true},
|
||||
{"IPv6 公网", "2001:db8::1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
require.NotNil(t, ip)
|
||||
assert.Equal(t, tt.expected, isPrivateIP(ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ===================== access.go Check 测试 =====================
|
||||
|
||||
// TestCheck_GeoIPAllowCountry 测试 GeoIP 国家允许
|
||||
func TestCheck_GeoIPAllowCountry(t *testing.T) {
|
||||
cfg := &config.AccessConfig{
|
||||
Default: "deny",
|
||||
GeoIP: config.GeoIPConfig{
|
||||
Enabled: true,
|
||||
Database: "", // 不使用真实数据库
|
||||
AllowCountries: []string{"CN"},
|
||||
},
|
||||
}
|
||||
ac, err := NewAccessControl(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("8.8.8.8")
|
||||
// GeoIP 未初始化(无数据库),不检查国家规则,直接走默认
|
||||
result := ac.Check(ip)
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
// TestCheck_GeoIPDenyCountry 测试 GeoIP 国家拒绝
|
||||
func TestCheck_GeoIPDenyCountry(t *testing.T) {
|
||||
cfg := &config.AccessConfig{
|
||||
Default: "allow",
|
||||
GeoIP: config.GeoIPConfig{
|
||||
Enabled: true,
|
||||
Database: "",
|
||||
DenyCountries: []string{"RU"},
|
||||
},
|
||||
}
|
||||
ac, err := NewAccessControl(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
ip := net.ParseIP("8.8.8.8")
|
||||
// GeoIP 未初始化(无数据库),不检查国家规则,直接走默认
|
||||
result := ac.Check(ip)
|
||||
assert.True(t, result)
|
||||
}
|
||||
|
||||
// ===================== auth_request.go Process 测试 =====================
|
||||
|
||||
// TestAuthRequest_Process_AuthServiceUnavailable 测试认证服务不可用
|
||||
func TestAuthRequest_Process_AuthServiceUnavailable(t *testing.T) {
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "/auth",
|
||||
Method: "GET",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ForwardHeaders: []string{},
|
||||
Headers: map[string]string{"X-Auth-Source": "lolly"},
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
nextCalled := false
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
nextCalled = true
|
||||
}
|
||||
|
||||
handler := ar.Process(next)
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/protected")
|
||||
ctx.Request.Header.SetHost("localhost")
|
||||
handler(ctx)
|
||||
|
||||
assert.False(t, nextCalled)
|
||||
assert.Equal(t, fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
|
||||
}
|
||||
|
||||
// TestAuthRequest_Process_WithVariables 测试变量展开后请求失败
|
||||
func TestAuthRequest_Process_WithVariables(t *testing.T) {
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "http://127.0.0.1:1/auth?uri=$request_uri",
|
||||
Method: "GET",
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ForwardHeaders: []string{"Authorization"},
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := ar.Process(func(ctx *fasthttp.RequestCtx) {})
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
ctx.Request.Header.Set("Authorization", "Bearer token123")
|
||||
handler(ctx)
|
||||
|
||||
assert.Equal(t, fasthttp.StatusInternalServerError, ctx.Response.StatusCode())
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user