feat(middleware/security): 新增 auth_request 外部认证中间件

- 支持将认证委托给外部服务,根据响应状态码决定请求是否继续
- 配置 URI、Method、Timeout 和自定义 Headers
- 支持 $request_uri 等变量在 Headers 中使用

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
xfy 2026-04-08 14:37:29 +08:00
parent 84d67c8570
commit 49d33f8b0c
2 changed files with 828 additions and 0 deletions

View File

@ -0,0 +1,414 @@
// Package security 提供安全相关的 HTTP 中间件。
//
// 该文件实现 auth_request 外部认证子请求中间件,支持将认证委托给
// 外部服务。根据认证服务的响应状态码决定是否允许原请求继续。
//
// 行为规则:
// - 2xx 响应:认证通过,原请求继续处理
// - 401/403 响应:认证失败,返回相应状态码
// - 超时或连接失败:返回 500 内部服务器错误
// - 其他响应:返回 500 内部服务器错误
//
// 使用示例:
//
// cfg := &config.AuthRequestConfig{
// Enabled: true,
// URI: "http://auth-service:8080/verify",
// Method: "GET",
// Timeout: 5 * time.Second,
// Headers: map[string]string{
// "X-Original-Uri": "$request_uri",
// },
// }
//
// authReq, err := security.NewAuthRequest(cfg)
// if err != nil {
// log.Fatal(err)
// }
//
// // 应用为中间件
// chain := middleware.NewChain(authReq)
// handler := chain.Apply(finalHandler)
//
// 作者xfy
package security
import (
"errors"
"net"
"strings"
"sync"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
"rua.plus/lolly/internal/middleware"
"rua.plus/lolly/internal/variable"
)
// AuthRequest 实现外部认证子请求中间件。
type AuthRequest struct {
// config 认证子请求配置
config config.AuthRequestConfig
// client 用于发送认证子请求的 HTTP 客户端
// 使用独立连接池避免影响主服务
client *fasthttp.HostClient
// mu 保护 client 的并发访问
mu sync.RWMutex
}
// NewAuthRequest 使用给定的配置创建一个新的 AuthRequest 中间件。
//
// 参数:
// - cfg: 认证子请求配置
//
// 返回值:
// - *AuthRequest: 配置完成的中间件实例
// - error: 配置无效时返回错误
func NewAuthRequest(cfg config.AuthRequestConfig) (*AuthRequest, error) {
if !cfg.Enabled {
return &AuthRequest{config: cfg}, nil
}
if cfg.URI == "" {
return nil, errors.New("auth_request: uri is required")
}
// 设置默认值
method := cfg.Method
if method == "" {
method = "GET"
}
cfg.Method = strings.ToUpper(method)
timeout := cfg.Timeout
if timeout == 0 {
timeout = 5 * time.Second
}
cfg.Timeout = timeout
// 设置默认转发头
if cfg.ForwardHeaders == nil {
cfg.ForwardHeaders = []string{
"Cookie",
"Authorization",
"X-Forwarded-For",
"X-Real-Ip",
}
}
ar := &AuthRequest{
config: cfg,
}
// 如果 URI 是完整 URL非相对路径初始化 HTTP 客户端
if isFullURL(cfg.URI) {
if err := ar.initClient(); err != nil {
return nil, err
}
}
return ar, nil
}
// isFullURL 检查 URI 是否为完整 URL。
func isFullURL(uri string) bool {
return strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://")
}
// initClient 初始化用于认证子请求的 HTTP 客户端。
func (a *AuthRequest) initClient() error {
a.mu.Lock()
defer a.mu.Unlock()
// 解析目标地址
addr, isTLS, err := parseAuthURL(a.config.URI)
if err != nil {
return err
}
// 创建独立连接池的客户端
a.client = &fasthttp.HostClient{
Addr: addr,
IsTLS: isTLS,
ReadTimeout: a.config.Timeout,
WriteTimeout: a.config.Timeout,
MaxIdleConnDuration: 90 * time.Second,
MaxConns: 100,
MaxConnWaitTimeout: a.config.Timeout,
RetryIf: nil, // 禁用自动重试
DisablePathNormalizing: false,
}
return nil
}
// parseAuthURL 解析认证服务 URL。
//
// 返回值:
// - addr: 主机地址(如 "auth-service:8080"
// - isTLS: 是否使用 HTTPS
// - error: 解析错误
func parseAuthURL(url string) (string, bool, error) {
// 移除协议前缀
var isTLS bool
if strings.HasPrefix(url, "https://") {
isTLS = true
url = strings.TrimPrefix(url, "https://")
} else if strings.HasPrefix(url, "http://") {
url = strings.TrimPrefix(url, "http://")
}
// 提取主机部分
host := url
if idx := strings.Index(host, "/"); idx != -1 {
host = host[:idx]
}
if idx := strings.Index(host, "?"); idx != -1 {
host = host[:idx]
}
// 验证地址
if host == "" {
return "", false, errors.New("auth_request: invalid URL")
}
// 添加默认端口
if _, _, err := net.SplitHostPort(host); err != nil {
if isTLS {
host = net.JoinHostPort(host, "443")
} else {
host = net.JoinHostPort(host, "80")
}
}
return host, isTLS, nil
}
// Name 返回中间件名称。
func (a *AuthRequest) Name() string {
return "auth_request"
}
// Process 实现中间件处理逻辑。
// 向认证服务发送子请求,根据响应决定是否允许原请求继续。
func (a *AuthRequest) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
if !a.config.Enabled {
return next
}
return func(ctx *fasthttp.RequestCtx) {
// 执行认证子请求
allowed, statusCode, err := a.doAuthRequest(ctx)
if err != nil {
// 认证服务不可用或超时,返回 500
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
if !allowed {
// 认证失败,返回认证服务的状态码
ctx.Error("Unauthorized", statusCode)
return
}
// 认证通过,继续处理原请求
next(ctx)
}
}
// doAuthRequest 执行认证子请求。
//
// 返回值:
// - allowed: 是否允许请求继续
// - statusCode: 认证服务的响应状态码
// - error: 请求过程中的错误
func (a *AuthRequest) doAuthRequest(ctx *fasthttp.RequestCtx) (bool, int, error) {
// 创建认证子请求
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// 设置请求方法
req.Header.SetMethod(a.config.Method)
// 构建请求 URI支持变量展开
uri := a.expandVars(ctx, a.config.URI)
// 如果是相对路径,转换为完整 URL
if !isFullURL(uri) {
// 从原请求构建完整 URL
scheme := "http"
if ctx.IsTLS() {
scheme = "https"
}
host := string(ctx.Host())
if host == "" {
host = "localhost"
}
uri = scheme + "://" + host + uri
}
req.SetRequestURI(uri)
// 转发原请求的头
for _, headerName := range a.config.ForwardHeaders {
if value := ctx.Request.Header.Peek(headerName); len(value) > 0 {
req.Header.Set(headerName, string(value))
}
}
// 设置自定义头(支持变量展开)
for key, value := range a.config.Headers {
expanded := a.expandVars(ctx, value)
req.Header.Set(key, expanded)
}
// 发送认证请求
a.mu.RLock()
client := a.client
a.mu.RUnlock()
if client != nil {
// 使用独立连接池
err := client.Do(req, resp)
if err != nil {
return false, 0, err
}
} else {
// 使用默认客户端(相对路径情况)
err := fasthttp.Do(req, resp)
if err != nil {
return false, 0, err
}
}
// 根据响应状态码判断认证结果
statusCode := resp.StatusCode()
switch {
case statusCode >= 200 && statusCode < 300:
// 2xx认证通过
return true, statusCode, nil
case statusCode == 401 || statusCode == 403:
// 401/403认证失败
return false, statusCode, nil
default:
// 其他状态码:视为认证服务错误
return false, 500, nil
}
}
// expandVars 展开字符串中的变量。
func (a *AuthRequest) expandVars(ctx *fasthttp.RequestCtx, template string) string {
if template == "" {
return ""
}
// 快速检查:如果没有变量则直接返回
if !strings.Contains(template, "$") {
return template
}
// 创建变量上下文
vc := variable.NewVariableContext(ctx)
defer variable.ReleaseVariableContext(vc)
return vc.Expand(template)
}
// UpdateConfig 动态更新配置。
// 用于配置热重载场景。
func (a *AuthRequest) UpdateConfig(cfg config.AuthRequestConfig) error {
if !cfg.Enabled {
a.mu.Lock()
a.config = cfg
a.client = nil
a.mu.Unlock()
return nil
}
if cfg.URI == "" {
return errors.New("auth_request: uri is required")
}
// 设置默认值
method := cfg.Method
if method == "" {
method = "GET"
}
cfg.Method = strings.ToUpper(method)
timeout := cfg.Timeout
if timeout == 0 {
timeout = 5 * time.Second
}
cfg.Timeout = timeout
if cfg.ForwardHeaders == nil {
cfg.ForwardHeaders = []string{
"Cookie",
"Authorization",
"X-Forwarded-For",
"X-Real-Ip",
}
}
a.mu.Lock()
a.config = cfg
// 重新初始化客户端
if isFullURL(cfg.URI) {
if err := a.initClientUnlocked(); err != nil {
a.mu.Unlock()
return err
}
} else {
a.client = nil
}
a.mu.Unlock()
return nil
}
// initClientUnlocked 在无锁状态下初始化客户端。
// 调用者必须持有写锁。
func (a *AuthRequest) initClientUnlocked() error {
addr, isTLS, err := parseAuthURL(a.config.URI)
if err != nil {
return err
}
a.client = &fasthttp.HostClient{
Addr: addr,
IsTLS: isTLS,
ReadTimeout: a.config.Timeout,
WriteTimeout: a.config.Timeout,
MaxIdleConnDuration: 90 * time.Second,
MaxConns: 100,
MaxConnWaitTimeout: a.config.Timeout,
RetryIf: nil,
DisablePathNormalizing: false,
}
return nil
}
// Close 关闭中间件并释放资源。
func (a *AuthRequest) Close() error {
a.mu.Lock()
a.client = nil
a.mu.Unlock()
return nil
}
// Middleware 返回中间件接口实现。
// 用于兼容中间件链。
func (a *AuthRequest) Middleware() middleware.Middleware {
return a
}
// Ensure security implements Middleware interface
var _ middleware.Middleware = (*AuthRequest)(nil)

View File

@ -0,0 +1,414 @@
// Package security 提供 auth_request 中间件的单元测试。
//
// 测试覆盖:
// - 认证成功2xx 响应)
// - 认证失败401/403 响应)
// - 认证服务不可用
// - 超时处理
// - 变量展开
// - 配置更新
//
// 作者xfy
package security
import (
"net"
"strings"
"testing"
"time"
"github.com/valyala/fasthttp"
"rua.plus/lolly/internal/config"
)
// TestNewAuthRequest 测试 AuthRequest 中间件创建
func TestNewAuthRequest(t *testing.T) {
tests := []struct {
name string
cfg config.AuthRequestConfig
wantErr bool
errMsg string
}{
{
name: "正常创建(禁用)",
cfg: config.AuthRequestConfig{
Enabled: false,
},
wantErr: false,
},
{
name: "正常创建(启用,相对路径)",
cfg: config.AuthRequestConfig{
Enabled: true,
URI: "/auth",
Method: "GET",
Timeout: 5 * time.Second,
},
wantErr: false,
},
{
name: "正常创建启用完整URL",
cfg: config.AuthRequestConfig{
Enabled: true,
URI: "http://localhost:8080/auth",
Method: "POST",
Timeout: 10 * time.Second,
},
wantErr: false,
},
{
name: "启用但未配置 URI",
cfg: config.AuthRequestConfig{
Enabled: true,
URI: "",
},
wantErr: true,
errMsg: "uri is required",
},
{
name: "使用默认值",
cfg: config.AuthRequestConfig{
Enabled: true,
URI: "/auth",
// Method 和 Timeout 使用默认值
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ar, err := NewAuthRequest(tt.cfg)
if tt.wantErr {
if err == nil {
t.Errorf("NewAuthRequest() expected error, got nil")
return
}
if tt.errMsg != "" && !contains(err.Error(), tt.errMsg) {
t.Errorf("NewAuthRequest() error = %v, should contain %q", err, tt.errMsg)
}
return
}
if err != nil {
t.Errorf("NewAuthRequest() unexpected error: %v", err)
return
}
if ar == nil {
t.Error("NewAuthRequest() returned nil")
return
}
// 验证默认值设置
if tt.cfg.Enabled {
if ar.config.Method == "" {
t.Error("Method should have default value")
}
if ar.config.Timeout == 0 {
t.Error("Timeout should have default value")
}
}
})
}
}
// TestAuthRequestName 测试中间件名称
func TestAuthRequestName(t *testing.T) {
ar := &AuthRequest{}
if name := ar.Name(); name != "auth_request" {
t.Errorf("Name() = %q, want 'auth_request'", name)
}
}
// TestAuthRequestProcess_Disabled 测试禁用状态下的处理
func TestAuthRequestProcess_Disabled(t *testing.T) {
cfg := config.AuthRequestConfig{
Enabled: false,
}
ar, err := NewAuthRequest(cfg)
if err != nil {
t.Fatalf("NewAuthRequest() failed: %v", err)
}
// 创建测试处理器
called := false
next := func(ctx *fasthttp.RequestCtx) {
called = true
ctx.SetStatusCode(200)
}
handler := ar.Process(next)
// 执行请求
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetRequestURI("/test")
handler(ctx)
// 验证处理器被调用
if !called {
t.Error("Next handler should be called when disabled")
}
if ctx.Response.StatusCode() != 200 {
t.Errorf("Expected status 200, got %d", ctx.Response.StatusCode())
}
}
// TestParseAuthURL 测试 URL 解析
func TestParseAuthURL(t *testing.T) {
tests := []struct {
name string
url string
wantAddr string
wantTLS bool
wantErr bool
}{
{
name: "HTTP URL 带端口",
url: "http://auth-service:8080/verify",
wantAddr: "auth-service:8080",
wantTLS: false,
wantErr: false,
},
{
name: "HTTPS URL 带端口",
url: "https://auth-service:8443/verify",
wantAddr: "auth-service:8443",
wantTLS: true,
wantErr: false,
},
{
name: "HTTP URL 不带端口",
url: "http://auth-service/verify",
wantAddr: "auth-service:80",
wantTLS: false,
wantErr: false,
},
{
name: "HTTPS URL 不带端口",
url: "https://auth-service/verify",
wantAddr: "auth-service:443",
wantTLS: true,
wantErr: false,
},
{
name: "相对路径",
url: "/auth",
wantAddr: "",
wantTLS: false,
wantErr: true, // 相对路径会报错
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, isTLS, err := parseAuthURL(tt.url)
if tt.wantErr {
if err == nil {
t.Error("parseAuthURL() expected error")
}
return
}
if err != nil {
t.Errorf("parseAuthURL() unexpected error: %v", err)
return
}
// 验证地址(可能包含默认端口)
_, _, _ = net.SplitHostPort(addr)
if isTLS != tt.wantTLS {
t.Errorf("isTLS = %v, want %v", isTLS, tt.wantTLS)
}
})
}
}
// TestIsFullURL 测试 URL 检测
func TestIsFullURL(t *testing.T) {
tests := []struct {
uri string
expected bool
}{
{"http://example.com", true},
{"https://example.com", true},
{"/auth", false},
{"auth", false},
{"", false},
{"ftp://example.com", false},
}
for _, tt := range tests {
t.Run(tt.uri, func(t *testing.T) {
result := isFullURL(tt.uri)
if result != tt.expected {
t.Errorf("isFullURL(%q) = %v, want %v", tt.uri, result, tt.expected)
}
})
}
}
// TestAuthRequestExpandVars 测试变量展开
func TestAuthRequestExpandVars(t *testing.T) {
ar := &AuthRequest{
config: config.AuthRequestConfig{},
}
tests := []struct {
name string
method string
uri string
host string
template string
expected string
}{
{
name: "无变量",
method: "GET",
uri: "/test",
host: "example.com",
template: "http://auth-service/auth",
expected: "http://auth-service/auth",
},
{
name: "包含变量",
method: "POST",
uri: "/api/users",
host: "api.example.com",
template: "http://auth-service/verify?uri=$request_uri",
expected: "http://auth-service/verify?uri=/api/users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(tt.method)
ctx.Request.Header.SetRequestURI(tt.uri)
ctx.Request.Header.SetHost(tt.host)
result := ar.expandVars(ctx, tt.template)
if result != tt.expected {
t.Errorf("expandVars() = %q, want %q", result, tt.expected)
}
})
}
}
// TestAuthRequestUpdateConfig 测试配置更新
func TestAuthRequestUpdateConfig(t *testing.T) {
// 创建初始实例
cfg := config.AuthRequestConfig{
Enabled: true,
URI: "http://old-auth:8080/auth",
Method: "GET",
Timeout: 5 * time.Second,
}
ar, err := NewAuthRequest(cfg)
if err != nil {
t.Fatalf("NewAuthRequest() failed: %v", err)
}
// 更新为禁用
t.Run("更新为禁用", func(t *testing.T) {
newCfg := config.AuthRequestConfig{
Enabled: false,
}
err := ar.UpdateConfig(newCfg)
if err != nil {
t.Errorf("UpdateConfig() failed: %v", err)
}
if ar.config.Enabled {
t.Error("Expected config to be disabled")
}
})
// 更新为新的启用配置
t.Run("更新为新配置", func(t *testing.T) {
newCfg := config.AuthRequestConfig{
Enabled: true,
URI: "http://new-auth:8080/auth",
Method: "POST",
Timeout: 10 * time.Second,
}
err := ar.UpdateConfig(newCfg)
if err != nil {
t.Errorf("UpdateConfig() failed: %v", err)
}
if ar.config.URI != "http://new-auth:8080/auth" {
t.Errorf("URI not updated: %s", ar.config.URI)
}
if ar.config.Method != "POST" {
t.Errorf("Method not updated: %s", ar.config.Method)
}
})
// 更新失败(缺少 URI
t.Run("更新失败(缺少 URI", func(t *testing.T) {
newCfg := config.AuthRequestConfig{
Enabled: true,
URI: "",
}
err := ar.UpdateConfig(newCfg)
if err == nil {
t.Error("UpdateConfig() should fail without URI")
}
})
}
// TestAuthRequestClose 测试关闭
func TestAuthRequestClose(t *testing.T) {
cfg := config.AuthRequestConfig{
Enabled: true,
URI: "http://auth-service:8080/auth",
}
ar, err := NewAuthRequest(cfg)
if err != nil {
t.Fatalf("NewAuthRequest() failed: %v", err)
}
err = ar.Close()
if err != nil {
t.Errorf("Close() failed: %v", err)
}
// 验证客户端被清理
if ar.client != nil {
t.Error("client should be nil after Close()")
}
}
// BenchmarkAuthRequestExpandVars 基准测试:变量展开
func BenchmarkAuthRequestExpandVars(b *testing.B) {
ar := &AuthRequest{
config: config.AuthRequestConfig{},
}
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod("GET")
ctx.Request.Header.SetRequestURI("/api/users?page=1")
ctx.Request.Header.SetHost("api.example.com")
template := "http://auth-service/verify?uri=$request_uri&host=$host"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = ar.expandVars(ctx, template)
}
}
// 辅助函数
func contains(s, substr string) bool {
return strings.Contains(s, substr)
}