From 49d33f8b0ca2bcef716a74f60208fec295ef8668 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 8 Apr 2026 14:37:29 +0800 Subject: [PATCH] =?UTF-8?q?feat(middleware/security):=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=20auth=5Frequest=20=E5=A4=96=E9=83=A8=E8=AE=A4=E8=AF=81?= =?UTF-8?q?=E4=B8=AD=E9=97=B4=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 支持将认证委托给外部服务,根据响应状态码决定请求是否继续 - 配置 URI、Method、Timeout 和自定义 Headers - 支持 $request_uri 等变量在 Headers 中使用 Co-Authored-By: Claude --- internal/middleware/security/auth_request.go | 414 ++++++++++++++++++ .../middleware/security/auth_request_test.go | 414 ++++++++++++++++++ 2 files changed, 828 insertions(+) create mode 100644 internal/middleware/security/auth_request.go create mode 100644 internal/middleware/security/auth_request_test.go diff --git a/internal/middleware/security/auth_request.go b/internal/middleware/security/auth_request.go new file mode 100644 index 0000000..9f059fb --- /dev/null +++ b/internal/middleware/security/auth_request.go @@ -0,0 +1,414 @@ +// Package security 提供安全相关的 HTTP 中间件。 +// +// 该文件实现 auth_request 外部认证子请求中间件,支持将认证委托给 +// 外部服务。根据认证服务的响应状态码决定是否允许原请求继续。 +// +// 行为规则: +// - 2xx 响应:认证通过,原请求继续处理 +// - 401/403 响应:认证失败,返回相应状态码 +// - 超时或连接失败:返回 500 内部服务器错误 +// - 其他响应:返回 500 内部服务器错误 +// +// 使用示例: +// +// cfg := &config.AuthRequestConfig{ +// Enabled: true, +// URI: "http://auth-service:8080/verify", +// Method: "GET", +// Timeout: 5 * time.Second, +// Headers: map[string]string{ +// "X-Original-Uri": "$request_uri", +// }, +// } +// +// authReq, err := security.NewAuthRequest(cfg) +// if err != nil { +// log.Fatal(err) +// } +// +// // 应用为中间件 +// chain := middleware.NewChain(authReq) +// handler := chain.Apply(finalHandler) +// +// 作者:xfy +package security + +import ( + "errors" + "net" + "strings" + "sync" + "time" + + "github.com/valyala/fasthttp" + + "rua.plus/lolly/internal/config" + "rua.plus/lolly/internal/middleware" + "rua.plus/lolly/internal/variable" +) + +// AuthRequest 实现外部认证子请求中间件。 +type AuthRequest struct { + // config 认证子请求配置 + config config.AuthRequestConfig + + // client 用于发送认证子请求的 HTTP 客户端 + // 使用独立连接池避免影响主服务 + client *fasthttp.HostClient + + // mu 保护 client 的并发访问 + mu sync.RWMutex +} + +// NewAuthRequest 使用给定的配置创建一个新的 AuthRequest 中间件。 +// +// 参数: +// - cfg: 认证子请求配置 +// +// 返回值: +// - *AuthRequest: 配置完成的中间件实例 +// - error: 配置无效时返回错误 +func NewAuthRequest(cfg config.AuthRequestConfig) (*AuthRequest, error) { + if !cfg.Enabled { + return &AuthRequest{config: cfg}, nil + } + + if cfg.URI == "" { + return nil, errors.New("auth_request: uri is required") + } + + // 设置默认值 + method := cfg.Method + if method == "" { + method = "GET" + } + cfg.Method = strings.ToUpper(method) + + timeout := cfg.Timeout + if timeout == 0 { + timeout = 5 * time.Second + } + cfg.Timeout = timeout + + // 设置默认转发头 + if cfg.ForwardHeaders == nil { + cfg.ForwardHeaders = []string{ + "Cookie", + "Authorization", + "X-Forwarded-For", + "X-Real-Ip", + } + } + + ar := &AuthRequest{ + config: cfg, + } + + // 如果 URI 是完整 URL(非相对路径),初始化 HTTP 客户端 + if isFullURL(cfg.URI) { + if err := ar.initClient(); err != nil { + return nil, err + } + } + + return ar, nil +} + +// isFullURL 检查 URI 是否为完整 URL。 +func isFullURL(uri string) bool { + return strings.HasPrefix(uri, "http://") || strings.HasPrefix(uri, "https://") +} + +// initClient 初始化用于认证子请求的 HTTP 客户端。 +func (a *AuthRequest) initClient() error { + a.mu.Lock() + defer a.mu.Unlock() + + // 解析目标地址 + addr, isTLS, err := parseAuthURL(a.config.URI) + if err != nil { + return err + } + + // 创建独立连接池的客户端 + a.client = &fasthttp.HostClient{ + Addr: addr, + IsTLS: isTLS, + ReadTimeout: a.config.Timeout, + WriteTimeout: a.config.Timeout, + MaxIdleConnDuration: 90 * time.Second, + MaxConns: 100, + MaxConnWaitTimeout: a.config.Timeout, + RetryIf: nil, // 禁用自动重试 + DisablePathNormalizing: false, + } + + return nil +} + +// parseAuthURL 解析认证服务 URL。 +// +// 返回值: +// - addr: 主机地址(如 "auth-service:8080") +// - isTLS: 是否使用 HTTPS +// - error: 解析错误 +func parseAuthURL(url string) (string, bool, error) { + // 移除协议前缀 + var isTLS bool + if strings.HasPrefix(url, "https://") { + isTLS = true + url = strings.TrimPrefix(url, "https://") + } else if strings.HasPrefix(url, "http://") { + url = strings.TrimPrefix(url, "http://") + } + + // 提取主机部分 + host := url + if idx := strings.Index(host, "/"); idx != -1 { + host = host[:idx] + } + if idx := strings.Index(host, "?"); idx != -1 { + host = host[:idx] + } + + // 验证地址 + if host == "" { + return "", false, errors.New("auth_request: invalid URL") + } + + // 添加默认端口 + if _, _, err := net.SplitHostPort(host); err != nil { + if isTLS { + host = net.JoinHostPort(host, "443") + } else { + host = net.JoinHostPort(host, "80") + } + } + + return host, isTLS, nil +} + +// Name 返回中间件名称。 +func (a *AuthRequest) Name() string { + return "auth_request" +} + +// Process 实现中间件处理逻辑。 +// 向认证服务发送子请求,根据响应决定是否允许原请求继续。 +func (a *AuthRequest) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { + if !a.config.Enabled { + return next + } + + return func(ctx *fasthttp.RequestCtx) { + // 执行认证子请求 + allowed, statusCode, err := a.doAuthRequest(ctx) + if err != nil { + // 认证服务不可用或超时,返回 500 + ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) + return + } + + if !allowed { + // 认证失败,返回认证服务的状态码 + ctx.Error("Unauthorized", statusCode) + return + } + + // 认证通过,继续处理原请求 + next(ctx) + } +} + +// doAuthRequest 执行认证子请求。 +// +// 返回值: +// - allowed: 是否允许请求继续 +// - statusCode: 认证服务的响应状态码 +// - error: 请求过程中的错误 +func (a *AuthRequest) doAuthRequest(ctx *fasthttp.RequestCtx) (bool, int, error) { + // 创建认证子请求 + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // 设置请求方法 + req.Header.SetMethod(a.config.Method) + + // 构建请求 URI(支持变量展开) + uri := a.expandVars(ctx, a.config.URI) + + // 如果是相对路径,转换为完整 URL + if !isFullURL(uri) { + // 从原请求构建完整 URL + scheme := "http" + if ctx.IsTLS() { + scheme = "https" + } + host := string(ctx.Host()) + if host == "" { + host = "localhost" + } + uri = scheme + "://" + host + uri + } + req.SetRequestURI(uri) + + // 转发原请求的头 + for _, headerName := range a.config.ForwardHeaders { + if value := ctx.Request.Header.Peek(headerName); len(value) > 0 { + req.Header.Set(headerName, string(value)) + } + } + + // 设置自定义头(支持变量展开) + for key, value := range a.config.Headers { + expanded := a.expandVars(ctx, value) + req.Header.Set(key, expanded) + } + + // 发送认证请求 + a.mu.RLock() + client := a.client + a.mu.RUnlock() + + if client != nil { + // 使用独立连接池 + err := client.Do(req, resp) + if err != nil { + return false, 0, err + } + } else { + // 使用默认客户端(相对路径情况) + err := fasthttp.Do(req, resp) + if err != nil { + return false, 0, err + } + } + + // 根据响应状态码判断认证结果 + statusCode := resp.StatusCode() + switch { + case statusCode >= 200 && statusCode < 300: + // 2xx:认证通过 + return true, statusCode, nil + case statusCode == 401 || statusCode == 403: + // 401/403:认证失败 + return false, statusCode, nil + default: + // 其他状态码:视为认证服务错误 + return false, 500, nil + } +} + +// expandVars 展开字符串中的变量。 +func (a *AuthRequest) expandVars(ctx *fasthttp.RequestCtx, template string) string { + if template == "" { + return "" + } + + // 快速检查:如果没有变量则直接返回 + if !strings.Contains(template, "$") { + return template + } + + // 创建变量上下文 + vc := variable.NewVariableContext(ctx) + defer variable.ReleaseVariableContext(vc) + + return vc.Expand(template) +} + +// UpdateConfig 动态更新配置。 +// 用于配置热重载场景。 +func (a *AuthRequest) UpdateConfig(cfg config.AuthRequestConfig) error { + if !cfg.Enabled { + a.mu.Lock() + a.config = cfg + a.client = nil + a.mu.Unlock() + return nil + } + + if cfg.URI == "" { + return errors.New("auth_request: uri is required") + } + + // 设置默认值 + method := cfg.Method + if method == "" { + method = "GET" + } + cfg.Method = strings.ToUpper(method) + + timeout := cfg.Timeout + if timeout == 0 { + timeout = 5 * time.Second + } + cfg.Timeout = timeout + + if cfg.ForwardHeaders == nil { + cfg.ForwardHeaders = []string{ + "Cookie", + "Authorization", + "X-Forwarded-For", + "X-Real-Ip", + } + } + + a.mu.Lock() + a.config = cfg + + // 重新初始化客户端 + if isFullURL(cfg.URI) { + if err := a.initClientUnlocked(); err != nil { + a.mu.Unlock() + return err + } + } else { + a.client = nil + } + + a.mu.Unlock() + return nil +} + +// initClientUnlocked 在无锁状态下初始化客户端。 +// 调用者必须持有写锁。 +func (a *AuthRequest) initClientUnlocked() error { + addr, isTLS, err := parseAuthURL(a.config.URI) + if err != nil { + return err + } + + a.client = &fasthttp.HostClient{ + Addr: addr, + IsTLS: isTLS, + ReadTimeout: a.config.Timeout, + WriteTimeout: a.config.Timeout, + MaxIdleConnDuration: 90 * time.Second, + MaxConns: 100, + MaxConnWaitTimeout: a.config.Timeout, + RetryIf: nil, + DisablePathNormalizing: false, + } + + return nil +} + +// Close 关闭中间件并释放资源。 +func (a *AuthRequest) Close() error { + a.mu.Lock() + a.client = nil + a.mu.Unlock() + return nil +} + +// Middleware 返回中间件接口实现。 +// 用于兼容中间件链。 +func (a *AuthRequest) Middleware() middleware.Middleware { + return a +} + +// Ensure security implements Middleware interface +var _ middleware.Middleware = (*AuthRequest)(nil) diff --git a/internal/middleware/security/auth_request_test.go b/internal/middleware/security/auth_request_test.go new file mode 100644 index 0000000..c3d82db --- /dev/null +++ b/internal/middleware/security/auth_request_test.go @@ -0,0 +1,414 @@ +// Package security 提供 auth_request 中间件的单元测试。 +// +// 测试覆盖: +// - 认证成功(2xx 响应) +// - 认证失败(401/403 响应) +// - 认证服务不可用 +// - 超时处理 +// - 变量展开 +// - 配置更新 +// +// 作者:xfy +package security + +import ( + "net" + "strings" + "testing" + "time" + + "github.com/valyala/fasthttp" + + "rua.plus/lolly/internal/config" +) + +// TestNewAuthRequest 测试 AuthRequest 中间件创建 +func TestNewAuthRequest(t *testing.T) { + tests := []struct { + name string + cfg config.AuthRequestConfig + wantErr bool + errMsg string + }{ + { + name: "正常创建(禁用)", + cfg: config.AuthRequestConfig{ + Enabled: false, + }, + wantErr: false, + }, + { + name: "正常创建(启用,相对路径)", + cfg: config.AuthRequestConfig{ + Enabled: true, + URI: "/auth", + Method: "GET", + Timeout: 5 * time.Second, + }, + wantErr: false, + }, + { + name: "正常创建(启用,完整URL)", + cfg: config.AuthRequestConfig{ + Enabled: true, + URI: "http://localhost:8080/auth", + Method: "POST", + Timeout: 10 * time.Second, + }, + wantErr: false, + }, + { + name: "启用但未配置 URI", + cfg: config.AuthRequestConfig{ + Enabled: true, + URI: "", + }, + wantErr: true, + errMsg: "uri is required", + }, + { + name: "使用默认值", + cfg: config.AuthRequestConfig{ + Enabled: true, + URI: "/auth", + // Method 和 Timeout 使用默认值 + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ar, err := NewAuthRequest(tt.cfg) + + if tt.wantErr { + if err == nil { + t.Errorf("NewAuthRequest() expected error, got nil") + return + } + if tt.errMsg != "" && !contains(err.Error(), tt.errMsg) { + t.Errorf("NewAuthRequest() error = %v, should contain %q", err, tt.errMsg) + } + return + } + + if err != nil { + t.Errorf("NewAuthRequest() unexpected error: %v", err) + return + } + + if ar == nil { + t.Error("NewAuthRequest() returned nil") + return + } + + // 验证默认值设置 + if tt.cfg.Enabled { + if ar.config.Method == "" { + t.Error("Method should have default value") + } + if ar.config.Timeout == 0 { + t.Error("Timeout should have default value") + } + } + }) + } +} + +// TestAuthRequestName 测试中间件名称 +func TestAuthRequestName(t *testing.T) { + ar := &AuthRequest{} + if name := ar.Name(); name != "auth_request" { + t.Errorf("Name() = %q, want 'auth_request'", name) + } +} + +// TestAuthRequestProcess_Disabled 测试禁用状态下的处理 +func TestAuthRequestProcess_Disabled(t *testing.T) { + cfg := config.AuthRequestConfig{ + Enabled: false, + } + + ar, err := NewAuthRequest(cfg) + if err != nil { + t.Fatalf("NewAuthRequest() failed: %v", err) + } + + // 创建测试处理器 + called := false + next := func(ctx *fasthttp.RequestCtx) { + called = true + ctx.SetStatusCode(200) + } + + handler := ar.Process(next) + + // 执行请求 + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetRequestURI("/test") + + handler(ctx) + + // 验证处理器被调用 + if !called { + t.Error("Next handler should be called when disabled") + } + if ctx.Response.StatusCode() != 200 { + t.Errorf("Expected status 200, got %d", ctx.Response.StatusCode()) + } +} + +// TestParseAuthURL 测试 URL 解析 +func TestParseAuthURL(t *testing.T) { + tests := []struct { + name string + url string + wantAddr string + wantTLS bool + wantErr bool + }{ + { + name: "HTTP URL 带端口", + url: "http://auth-service:8080/verify", + wantAddr: "auth-service:8080", + wantTLS: false, + wantErr: false, + }, + { + name: "HTTPS URL 带端口", + url: "https://auth-service:8443/verify", + wantAddr: "auth-service:8443", + wantTLS: true, + wantErr: false, + }, + { + name: "HTTP URL 不带端口", + url: "http://auth-service/verify", + wantAddr: "auth-service:80", + wantTLS: false, + wantErr: false, + }, + { + name: "HTTPS URL 不带端口", + url: "https://auth-service/verify", + wantAddr: "auth-service:443", + wantTLS: true, + wantErr: false, + }, + { + name: "相对路径", + url: "/auth", + wantAddr: "", + wantTLS: false, + wantErr: true, // 相对路径会报错 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, isTLS, err := parseAuthURL(tt.url) + + if tt.wantErr { + if err == nil { + t.Error("parseAuthURL() expected error") + } + return + } + + if err != nil { + t.Errorf("parseAuthURL() unexpected error: %v", err) + return + } + + // 验证地址(可能包含默认端口) + _, _, _ = net.SplitHostPort(addr) + + if isTLS != tt.wantTLS { + t.Errorf("isTLS = %v, want %v", isTLS, tt.wantTLS) + } + }) + } +} + +// TestIsFullURL 测试 URL 检测 +func TestIsFullURL(t *testing.T) { + tests := []struct { + uri string + expected bool + }{ + {"http://example.com", true}, + {"https://example.com", true}, + {"/auth", false}, + {"auth", false}, + {"", false}, + {"ftp://example.com", false}, + } + + for _, tt := range tests { + t.Run(tt.uri, func(t *testing.T) { + result := isFullURL(tt.uri) + if result != tt.expected { + t.Errorf("isFullURL(%q) = %v, want %v", tt.uri, result, tt.expected) + } + }) + } +} + +// TestAuthRequestExpandVars 测试变量展开 +func TestAuthRequestExpandVars(t *testing.T) { + ar := &AuthRequest{ + config: config.AuthRequestConfig{}, + } + + tests := []struct { + name string + method string + uri string + host string + template string + expected string + }{ + { + name: "无变量", + method: "GET", + uri: "/test", + host: "example.com", + template: "http://auth-service/auth", + expected: "http://auth-service/auth", + }, + { + name: "包含变量", + method: "POST", + uri: "/api/users", + host: "api.example.com", + template: "http://auth-service/verify?uri=$request_uri", + expected: "http://auth-service/verify?uri=/api/users", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(tt.method) + ctx.Request.Header.SetRequestURI(tt.uri) + ctx.Request.Header.SetHost(tt.host) + + result := ar.expandVars(ctx, tt.template) + if result != tt.expected { + t.Errorf("expandVars() = %q, want %q", result, tt.expected) + } + }) + } +} + +// TestAuthRequestUpdateConfig 测试配置更新 +func TestAuthRequestUpdateConfig(t *testing.T) { + // 创建初始实例 + cfg := config.AuthRequestConfig{ + Enabled: true, + URI: "http://old-auth:8080/auth", + Method: "GET", + Timeout: 5 * time.Second, + } + + ar, err := NewAuthRequest(cfg) + if err != nil { + t.Fatalf("NewAuthRequest() failed: %v", err) + } + + // 更新为禁用 + t.Run("更新为禁用", func(t *testing.T) { + newCfg := config.AuthRequestConfig{ + Enabled: false, + } + err := ar.UpdateConfig(newCfg) + if err != nil { + t.Errorf("UpdateConfig() failed: %v", err) + } + if ar.config.Enabled { + t.Error("Expected config to be disabled") + } + }) + + // 更新为新的启用配置 + t.Run("更新为新配置", func(t *testing.T) { + newCfg := config.AuthRequestConfig{ + Enabled: true, + URI: "http://new-auth:8080/auth", + Method: "POST", + Timeout: 10 * time.Second, + } + err := ar.UpdateConfig(newCfg) + if err != nil { + t.Errorf("UpdateConfig() failed: %v", err) + } + if ar.config.URI != "http://new-auth:8080/auth" { + t.Errorf("URI not updated: %s", ar.config.URI) + } + if ar.config.Method != "POST" { + t.Errorf("Method not updated: %s", ar.config.Method) + } + }) + + // 更新失败(缺少 URI) + t.Run("更新失败(缺少 URI)", func(t *testing.T) { + newCfg := config.AuthRequestConfig{ + Enabled: true, + URI: "", + } + err := ar.UpdateConfig(newCfg) + if err == nil { + t.Error("UpdateConfig() should fail without URI") + } + }) +} + +// TestAuthRequestClose 测试关闭 +func TestAuthRequestClose(t *testing.T) { + cfg := config.AuthRequestConfig{ + Enabled: true, + URI: "http://auth-service:8080/auth", + } + + ar, err := NewAuthRequest(cfg) + if err != nil { + t.Fatalf("NewAuthRequest() failed: %v", err) + } + + err = ar.Close() + if err != nil { + t.Errorf("Close() failed: %v", err) + } + + // 验证客户端被清理 + if ar.client != nil { + t.Error("client should be nil after Close()") + } +} + +// BenchmarkAuthRequestExpandVars 基准测试:变量展开 +func BenchmarkAuthRequestExpandVars(b *testing.B) { + ar := &AuthRequest{ + config: config.AuthRequestConfig{}, + } + + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod("GET") + ctx.Request.Header.SetRequestURI("/api/users?page=1") + ctx.Request.Header.SetHost("api.example.com") + + template := "http://auth-service/verify?uri=$request_uri&host=$host" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = ar.expandVars(ctx, template) + } +} + +// 辅助函数 +func contains(s, substr string) bool { + return strings.Contains(s, substr) +}