feat(middleware/security): 新增 auth_request 外部认证中间件
- 支持将认证委托给外部服务,根据响应状态码决定请求是否继续 - 配置 URI、Method、Timeout 和自定义 Headers - 支持 $request_uri 等变量在 Headers 中使用 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
84d67c8570
commit
49d33f8b0c
414
internal/middleware/security/auth_request.go
Normal file
414
internal/middleware/security/auth_request.go
Normal file
@ -0,0 +1,414 @@
|
||||
// Package security 提供安全相关的 HTTP 中间件。
|
||||
//
|
||||
// 该文件实现 auth_request 外部认证子请求中间件,支持将认证委托给
|
||||
// 外部服务。根据认证服务的响应状态码决定是否允许原请求继续。
|
||||
//
|
||||
// 行为规则:
|
||||
// - 2xx 响应:认证通过,原请求继续处理
|
||||
// - 401/403 响应:认证失败,返回相应状态码
|
||||
// - 超时或连接失败:返回 500 内部服务器错误
|
||||
// - 其他响应:返回 500 内部服务器错误
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// cfg := &config.AuthRequestConfig{
|
||||
// Enabled: true,
|
||||
// URI: "http://auth-service:8080/verify",
|
||||
// Method: "GET",
|
||||
// Timeout: 5 * time.Second,
|
||||
// Headers: map[string]string{
|
||||
// "X-Original-Uri": "$request_uri",
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// authReq, err := security.NewAuthRequest(cfg)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// // 应用为中间件
|
||||
// chain := middleware.NewChain(authReq)
|
||||
// handler := chain.Apply(finalHandler)
|
||||
//
|
||||
// 作者:xfy
|
||||
package security
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
"rua.plus/lolly/internal/variable"
|
||||
)
|
||||
|
||||
// AuthRequest 实现外部认证子请求中间件。
|
||||
type AuthRequest struct {
|
||||
// config 认证子请求配置
|
||||
config config.AuthRequestConfig
|
||||
|
||||
// client 用于发送认证子请求的 HTTP 客户端
|
||||
// 使用独立连接池避免影响主服务
|
||||
client *fasthttp.HostClient
|
||||
|
||||
// mu 保护 client 的并发访问
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewAuthRequest 使用给定的配置创建一个新的 AuthRequest 中间件。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 认证子请求配置
|
||||
//
|
||||
// 返回值:
|
||||
// - *AuthRequest: 配置完成的中间件实例
|
||||
// - error: 配置无效时返回错误
|
||||
func NewAuthRequest(cfg config.AuthRequestConfig) (*AuthRequest, error) {
|
||||
if !cfg.Enabled {
|
||||
return &AuthRequest{config: cfg}, nil
|
||||
}
|
||||
|
||||
if cfg.URI == "" {
|
||||
return nil, errors.New("auth_request: uri is required")
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
method := cfg.Method
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
cfg.Method = strings.ToUpper(method)
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
cfg.Timeout = timeout
|
||||
|
||||
// 设置默认转发头
|
||||
if cfg.ForwardHeaders == nil {
|
||||
cfg.ForwardHeaders = []string{
|
||||
"Cookie",
|
||||
"Authorization",
|
||||
"X-Forwarded-For",
|
||||
"X-Real-Ip",
|
||||
}
|
||||
}
|
||||
|
||||
ar := &AuthRequest{
|
||||
config: cfg,
|
||||
}
|
||||
|
||||
// 如果 URI 是完整 URL(非相对路径),初始化 HTTP 客户端
|
||||
if isFullURL(cfg.URI) {
|
||||
if err := ar.initClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ar, nil
|
||||
}
|
||||
|
||||
// isFullURL 检查 URI 是否为完整 URL。
|
||||
func isFullURL(uri string) bool {
|
||||
return strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://")
|
||||
}
|
||||
|
||||
// initClient 初始化用于认证子请求的 HTTP 客户端。
|
||||
func (a *AuthRequest) initClient() error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
// 解析目标地址
|
||||
addr, isTLS, err := parseAuthURL(a.config.URI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建独立连接池的客户端
|
||||
a.client = &fasthttp.HostClient{
|
||||
Addr: addr,
|
||||
IsTLS: isTLS,
|
||||
ReadTimeout: a.config.Timeout,
|
||||
WriteTimeout: a.config.Timeout,
|
||||
MaxIdleConnDuration: 90 * time.Second,
|
||||
MaxConns: 100,
|
||||
MaxConnWaitTimeout: a.config.Timeout,
|
||||
RetryIf: nil, // 禁用自动重试
|
||||
DisablePathNormalizing: false,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAuthURL 解析认证服务 URL。
|
||||
//
|
||||
// 返回值:
|
||||
// - addr: 主机地址(如 "auth-service:8080")
|
||||
// - isTLS: 是否使用 HTTPS
|
||||
// - error: 解析错误
|
||||
func parseAuthURL(url string) (string, bool, error) {
|
||||
// 移除协议前缀
|
||||
var isTLS bool
|
||||
if strings.HasPrefix(url, "https://") {
|
||||
isTLS = true
|
||||
url = strings.TrimPrefix(url, "https://")
|
||||
} else if strings.HasPrefix(url, "http://") {
|
||||
url = strings.TrimPrefix(url, "http://")
|
||||
}
|
||||
|
||||
// 提取主机部分
|
||||
host := url
|
||||
if idx := strings.Index(host, "/"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
if idx := strings.Index(host, "?"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// 验证地址
|
||||
if host == "" {
|
||||
return "", false, errors.New("auth_request: invalid URL")
|
||||
}
|
||||
|
||||
// 添加默认端口
|
||||
if _, _, err := net.SplitHostPort(host); err != nil {
|
||||
if isTLS {
|
||||
host = net.JoinHostPort(host, "443")
|
||||
} else {
|
||||
host = net.JoinHostPort(host, "80")
|
||||
}
|
||||
}
|
||||
|
||||
return host, isTLS, nil
|
||||
}
|
||||
|
||||
// Name 返回中间件名称。
|
||||
func (a *AuthRequest) Name() string {
|
||||
return "auth_request"
|
||||
}
|
||||
|
||||
// Process 实现中间件处理逻辑。
|
||||
// 向认证服务发送子请求,根据响应决定是否允许原请求继续。
|
||||
func (a *AuthRequest) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
if !a.config.Enabled {
|
||||
return next
|
||||
}
|
||||
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
// 执行认证子请求
|
||||
allowed, statusCode, err := a.doAuthRequest(ctx)
|
||||
if err != nil {
|
||||
// 认证服务不可用或超时,返回 500
|
||||
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
// 认证失败,返回认证服务的状态码
|
||||
ctx.Error("Unauthorized", statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
// 认证通过,继续处理原请求
|
||||
next(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// doAuthRequest 执行认证子请求。
|
||||
//
|
||||
// 返回值:
|
||||
// - allowed: 是否允许请求继续
|
||||
// - statusCode: 认证服务的响应状态码
|
||||
// - error: 请求过程中的错误
|
||||
func (a *AuthRequest) doAuthRequest(ctx *fasthttp.RequestCtx) (bool, int, error) {
|
||||
// 创建认证子请求
|
||||
req := fasthttp.AcquireRequest()
|
||||
resp := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseRequest(req)
|
||||
defer fasthttp.ReleaseResponse(resp)
|
||||
|
||||
// 设置请求方法
|
||||
req.Header.SetMethod(a.config.Method)
|
||||
|
||||
// 构建请求 URI(支持变量展开)
|
||||
uri := a.expandVars(ctx, a.config.URI)
|
||||
|
||||
// 如果是相对路径,转换为完整 URL
|
||||
if !isFullURL(uri) {
|
||||
// 从原请求构建完整 URL
|
||||
scheme := "http"
|
||||
if ctx.IsTLS() {
|
||||
scheme = "https"
|
||||
}
|
||||
host := string(ctx.Host())
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
uri = scheme + "://" + host + uri
|
||||
}
|
||||
req.SetRequestURI(uri)
|
||||
|
||||
// 转发原请求的头
|
||||
for _, headerName := range a.config.ForwardHeaders {
|
||||
if value := ctx.Request.Header.Peek(headerName); len(value) > 0 {
|
||||
req.Header.Set(headerName, string(value))
|
||||
}
|
||||
}
|
||||
|
||||
// 设置自定义头(支持变量展开)
|
||||
for key, value := range a.config.Headers {
|
||||
expanded := a.expandVars(ctx, value)
|
||||
req.Header.Set(key, expanded)
|
||||
}
|
||||
|
||||
// 发送认证请求
|
||||
a.mu.RLock()
|
||||
client := a.client
|
||||
a.mu.RUnlock()
|
||||
|
||||
if client != nil {
|
||||
// 使用独立连接池
|
||||
err := client.Do(req, resp)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
} else {
|
||||
// 使用默认客户端(相对路径情况)
|
||||
err := fasthttp.Do(req, resp)
|
||||
if err != nil {
|
||||
return false, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// 根据响应状态码判断认证结果
|
||||
statusCode := resp.StatusCode()
|
||||
switch {
|
||||
case statusCode >= 200 && statusCode < 300:
|
||||
// 2xx:认证通过
|
||||
return true, statusCode, nil
|
||||
case statusCode == 401 || statusCode == 403:
|
||||
// 401/403:认证失败
|
||||
return false, statusCode, nil
|
||||
default:
|
||||
// 其他状态码:视为认证服务错误
|
||||
return false, 500, nil
|
||||
}
|
||||
}
|
||||
|
||||
// expandVars 展开字符串中的变量。
|
||||
func (a *AuthRequest) expandVars(ctx *fasthttp.RequestCtx, template string) string {
|
||||
if template == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 快速检查:如果没有变量则直接返回
|
||||
if !strings.Contains(template, "$") {
|
||||
return template
|
||||
}
|
||||
|
||||
// 创建变量上下文
|
||||
vc := variable.NewVariableContext(ctx)
|
||||
defer variable.ReleaseVariableContext(vc)
|
||||
|
||||
return vc.Expand(template)
|
||||
}
|
||||
|
||||
// UpdateConfig 动态更新配置。
|
||||
// 用于配置热重载场景。
|
||||
func (a *AuthRequest) UpdateConfig(cfg config.AuthRequestConfig) error {
|
||||
if !cfg.Enabled {
|
||||
a.mu.Lock()
|
||||
a.config = cfg
|
||||
a.client = nil
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.URI == "" {
|
||||
return errors.New("auth_request: uri is required")
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
method := cfg.Method
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
cfg.Method = strings.ToUpper(method)
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
cfg.Timeout = timeout
|
||||
|
||||
if cfg.ForwardHeaders == nil {
|
||||
cfg.ForwardHeaders = []string{
|
||||
"Cookie",
|
||||
"Authorization",
|
||||
"X-Forwarded-For",
|
||||
"X-Real-Ip",
|
||||
}
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
a.config = cfg
|
||||
|
||||
// 重新初始化客户端
|
||||
if isFullURL(cfg.URI) {
|
||||
if err := a.initClientUnlocked(); err != nil {
|
||||
a.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
a.client = nil
|
||||
}
|
||||
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// initClientUnlocked 在无锁状态下初始化客户端。
|
||||
// 调用者必须持有写锁。
|
||||
func (a *AuthRequest) initClientUnlocked() error {
|
||||
addr, isTLS, err := parseAuthURL(a.config.URI)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.client = &fasthttp.HostClient{
|
||||
Addr: addr,
|
||||
IsTLS: isTLS,
|
||||
ReadTimeout: a.config.Timeout,
|
||||
WriteTimeout: a.config.Timeout,
|
||||
MaxIdleConnDuration: 90 * time.Second,
|
||||
MaxConns: 100,
|
||||
MaxConnWaitTimeout: a.config.Timeout,
|
||||
RetryIf: nil,
|
||||
DisablePathNormalizing: false,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭中间件并释放资源。
|
||||
func (a *AuthRequest) Close() error {
|
||||
a.mu.Lock()
|
||||
a.client = nil
|
||||
a.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Middleware 返回中间件接口实现。
|
||||
// 用于兼容中间件链。
|
||||
func (a *AuthRequest) Middleware() middleware.Middleware {
|
||||
return a
|
||||
}
|
||||
|
||||
// Ensure security implements Middleware interface
|
||||
var _ middleware.Middleware = (*AuthRequest)(nil)
|
||||
414
internal/middleware/security/auth_request_test.go
Normal file
414
internal/middleware/security/auth_request_test.go
Normal file
@ -0,0 +1,414 @@
|
||||
// Package security 提供 auth_request 中间件的单元测试。
|
||||
//
|
||||
// 测试覆盖:
|
||||
// - 认证成功(2xx 响应)
|
||||
// - 认证失败(401/403 响应)
|
||||
// - 认证服务不可用
|
||||
// - 超时处理
|
||||
// - 变量展开
|
||||
// - 配置更新
|
||||
//
|
||||
// 作者:xfy
|
||||
package security
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
|
||||
"rua.plus/lolly/internal/config"
|
||||
)
|
||||
|
||||
// TestNewAuthRequest 测试 AuthRequest 中间件创建
|
||||
func TestNewAuthRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.AuthRequestConfig
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "正常创建(禁用)",
|
||||
cfg: config.AuthRequestConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "正常创建(启用,相对路径)",
|
||||
cfg: config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "/auth",
|
||||
Method: "GET",
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "正常创建(启用,完整URL)",
|
||||
cfg: config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "http://localhost:8080/auth",
|
||||
Method: "POST",
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "启用但未配置 URI",
|
||||
cfg: config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "uri is required",
|
||||
},
|
||||
{
|
||||
name: "使用默认值",
|
||||
cfg: config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "/auth",
|
||||
// Method 和 Timeout 使用默认值
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ar, err := NewAuthRequest(tt.cfg)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("NewAuthRequest() expected error, got nil")
|
||||
return
|
||||
}
|
||||
if tt.errMsg != "" && !contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("NewAuthRequest() error = %v, should contain %q", err, tt.errMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("NewAuthRequest() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if ar == nil {
|
||||
t.Error("NewAuthRequest() returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证默认值设置
|
||||
if tt.cfg.Enabled {
|
||||
if ar.config.Method == "" {
|
||||
t.Error("Method should have default value")
|
||||
}
|
||||
if ar.config.Timeout == 0 {
|
||||
t.Error("Timeout should have default value")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRequestName 测试中间件名称
|
||||
func TestAuthRequestName(t *testing.T) {
|
||||
ar := &AuthRequest{}
|
||||
if name := ar.Name(); name != "auth_request" {
|
||||
t.Errorf("Name() = %q, want 'auth_request'", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRequestProcess_Disabled 测试禁用状态下的处理
|
||||
func TestAuthRequestProcess_Disabled(t *testing.T) {
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthRequest() failed: %v", err)
|
||||
}
|
||||
|
||||
// 创建测试处理器
|
||||
called := false
|
||||
next := func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
ctx.SetStatusCode(200)
|
||||
}
|
||||
|
||||
handler := ar.Process(next)
|
||||
|
||||
// 执行请求
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetRequestURI("/test")
|
||||
|
||||
handler(ctx)
|
||||
|
||||
// 验证处理器被调用
|
||||
if !called {
|
||||
t.Error("Next handler should be called when disabled")
|
||||
}
|
||||
if ctx.Response.StatusCode() != 200 {
|
||||
t.Errorf("Expected status 200, got %d", ctx.Response.StatusCode())
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseAuthURL 测试 URL 解析
|
||||
func TestParseAuthURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantAddr string
|
||||
wantTLS bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "HTTP URL 带端口",
|
||||
url: "http://auth-service:8080/verify",
|
||||
wantAddr: "auth-service:8080",
|
||||
wantTLS: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTPS URL 带端口",
|
||||
url: "https://auth-service:8443/verify",
|
||||
wantAddr: "auth-service:8443",
|
||||
wantTLS: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP URL 不带端口",
|
||||
url: "http://auth-service/verify",
|
||||
wantAddr: "auth-service:80",
|
||||
wantTLS: false,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTPS URL 不带端口",
|
||||
url: "https://auth-service/verify",
|
||||
wantAddr: "auth-service:443",
|
||||
wantTLS: true,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "相对路径",
|
||||
url: "/auth",
|
||||
wantAddr: "",
|
||||
wantTLS: false,
|
||||
wantErr: true, // 相对路径会报错
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, isTLS, err := parseAuthURL(tt.url)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("parseAuthURL() expected error")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("parseAuthURL() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证地址(可能包含默认端口)
|
||||
_, _, _ = net.SplitHostPort(addr)
|
||||
|
||||
if isTLS != tt.wantTLS {
|
||||
t.Errorf("isTLS = %v, want %v", isTLS, tt.wantTLS)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsFullURL 测试 URL 检测
|
||||
func TestIsFullURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
uri string
|
||||
expected bool
|
||||
}{
|
||||
{"http://example.com", true},
|
||||
{"https://example.com", true},
|
||||
{"/auth", false},
|
||||
{"auth", false},
|
||||
{"", false},
|
||||
{"ftp://example.com", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.uri, func(t *testing.T) {
|
||||
result := isFullURL(tt.uri)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isFullURL(%q) = %v, want %v", tt.uri, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRequestExpandVars 测试变量展开
|
||||
func TestAuthRequestExpandVars(t *testing.T) {
|
||||
ar := &AuthRequest{
|
||||
config: config.AuthRequestConfig{},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
uri string
|
||||
host string
|
||||
template string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "无变量",
|
||||
method: "GET",
|
||||
uri: "/test",
|
||||
host: "example.com",
|
||||
template: "http://auth-service/auth",
|
||||
expected: "http://auth-service/auth",
|
||||
},
|
||||
{
|
||||
name: "包含变量",
|
||||
method: "POST",
|
||||
uri: "/api/users",
|
||||
host: "api.example.com",
|
||||
template: "http://auth-service/verify?uri=$request_uri",
|
||||
expected: "http://auth-service/verify?uri=/api/users",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod(tt.method)
|
||||
ctx.Request.Header.SetRequestURI(tt.uri)
|
||||
ctx.Request.Header.SetHost(tt.host)
|
||||
|
||||
result := ar.expandVars(ctx, tt.template)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expandVars() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthRequestUpdateConfig 测试配置更新
|
||||
func TestAuthRequestUpdateConfig(t *testing.T) {
|
||||
// 创建初始实例
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "http://old-auth:8080/auth",
|
||||
Method: "GET",
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthRequest() failed: %v", err)
|
||||
}
|
||||
|
||||
// 更新为禁用
|
||||
t.Run("更新为禁用", func(t *testing.T) {
|
||||
newCfg := config.AuthRequestConfig{
|
||||
Enabled: false,
|
||||
}
|
||||
err := ar.UpdateConfig(newCfg)
|
||||
if err != nil {
|
||||
t.Errorf("UpdateConfig() failed: %v", err)
|
||||
}
|
||||
if ar.config.Enabled {
|
||||
t.Error("Expected config to be disabled")
|
||||
}
|
||||
})
|
||||
|
||||
// 更新为新的启用配置
|
||||
t.Run("更新为新配置", func(t *testing.T) {
|
||||
newCfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "http://new-auth:8080/auth",
|
||||
Method: "POST",
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
err := ar.UpdateConfig(newCfg)
|
||||
if err != nil {
|
||||
t.Errorf("UpdateConfig() failed: %v", err)
|
||||
}
|
||||
if ar.config.URI != "http://new-auth:8080/auth" {
|
||||
t.Errorf("URI not updated: %s", ar.config.URI)
|
||||
}
|
||||
if ar.config.Method != "POST" {
|
||||
t.Errorf("Method not updated: %s", ar.config.Method)
|
||||
}
|
||||
})
|
||||
|
||||
// 更新失败(缺少 URI)
|
||||
t.Run("更新失败(缺少 URI)", func(t *testing.T) {
|
||||
newCfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "",
|
||||
}
|
||||
err := ar.UpdateConfig(newCfg)
|
||||
if err == nil {
|
||||
t.Error("UpdateConfig() should fail without URI")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAuthRequestClose 测试关闭
|
||||
func TestAuthRequestClose(t *testing.T) {
|
||||
cfg := config.AuthRequestConfig{
|
||||
Enabled: true,
|
||||
URI: "http://auth-service:8080/auth",
|
||||
}
|
||||
|
||||
ar, err := NewAuthRequest(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAuthRequest() failed: %v", err)
|
||||
}
|
||||
|
||||
err = ar.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证客户端被清理
|
||||
if ar.client != nil {
|
||||
t.Error("client should be nil after Close()")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAuthRequestExpandVars 基准测试:变量展开
|
||||
func BenchmarkAuthRequestExpandVars(b *testing.B) {
|
||||
ar := &AuthRequest{
|
||||
config: config.AuthRequestConfig{},
|
||||
}
|
||||
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
ctx.Request.Header.SetRequestURI("/api/users?page=1")
|
||||
ctx.Request.Header.SetHost("api.example.com")
|
||||
|
||||
template := "http://auth-service/verify?uri=$request_uri&host=$host"
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = ar.expandVars(ctx, template)
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user