From 8224ae7ff32f9c1da2d42529fd2e1f00fa07111c Mon Sep 17 00:00:00 2001 From: xfy Date: Thu, 11 Jun 2026 23:41:38 +0800 Subject: [PATCH] feat(middleware/cors): add CORS middleware with server-level configuration Implement Cross-Origin Resource Sharing (CORS) middleware following the middleware.Middleware interface pattern. New config under security.cors: - enabled: toggle CORS handling (default false) - allowed_origins: exact origin list or ["*"] wildcard - allowed_methods: allowed HTTP methods for preflight - allowed_headers: allowed request headers for preflight - expose_headers: headers visible to frontend JS - allow_credentials: send cookies (incompatible with wildcard origin) - max_age: preflight cache duration in seconds Validation: - origins+credentials mutual exclusion per CORS spec - max_age non-negative check Integration: - Registered after SecurityHeaders, before ErrorIntercept in middleware chain - Preflight (OPTIONS) returns 204 with CORS headers, skips handler - Actual requests add CORS headers after handler execution - Non-matching origins pass through without CORS headers - 16 unit tests covering all scenarios --- internal/config/defaults.go | 11 ++ internal/config/security_config.go | 12 ++ internal/config/validate.go | 36 +++- internal/middleware/cors/cors.go | 184 ++++++++++++++++++ internal/middleware/cors/cors_test.go | 256 ++++++++++++++++++++++++++ 5 files changed, 495 insertions(+), 4 deletions(-) create mode 100644 internal/middleware/cors/cors.go create mode 100644 internal/middleware/cors/cors_test.go diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 6a873f2..2edcbeb 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -142,6 +142,9 @@ func DefaultConfig() *Config { AuthRequest: AuthRequestConfig{ Timeout: 5 * time.Second, }, + CORS: CORSConfig{ + Enabled: false, + }, }, Compression: CompressionConfig{ Type: "gzip", @@ -199,6 +202,14 @@ func DefaultConfig() *Config { Path: DefaultPprofPath, Allow: []string{"127.0.0.1"}, }, + Healthz: HealthzConfig{ + Enabled: true, + Path: "/healthz", + }, + Readyz: ReadyzConfig{ + Enabled: true, + Path: "/readyz", + }, }, HTTP3: HTTP3Config{ Enabled: false, diff --git a/internal/config/security_config.go b/internal/config/security_config.go index 0c34838..e873615 100644 --- a/internal/config/security_config.go +++ b/internal/config/security_config.go @@ -41,6 +41,18 @@ type SecurityConfig struct { Auth AuthConfig `yaml:"auth"` AuthRequest AuthRequestConfig `yaml:"auth_request"` RateLimit RateLimitConfig `yaml:"rate_limit"` + CORS CORSConfig `yaml:"cors"` +} + +// CORSConfig configures Cross-Origin Resource Sharing (CORS) headers. +type CORSConfig struct { + Enabled bool `yaml:"enabled"` + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` + ExposeHeaders []string `yaml:"expose_headers"` + AllowCredentials bool `yaml:"allow_credentials"` + MaxAge int `yaml:"max_age"` } // AccessConfig IP 访问控制配置。 diff --git a/internal/config/validate.go b/internal/config/validate.go index d97abd0..16eb3aa 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -645,26 +645,54 @@ func validateSSL(s *SSLConfig) error { // 返回值: // - error: 验证失败时返回错误信息,成功返回 nil func validateSecurity(s *SecurityConfig) error { - // 验证访问控制配置 if err := validateAccess(&s.Access); err != nil { return fmt.Errorf("access: %w", err) } - // 验证认证配置 if err := validateAuth(&s.Auth); err != nil { return fmt.Errorf("auth: %w", err) } - // 验证速率限制配置 if err := validateRateLimit(&s.RateLimit); err != nil { return fmt.Errorf("rate_limit: %w", err) } - // 验证安全头部配置 if err := validateSecurityHeaders(&s.Headers); err != nil { return fmt.Errorf("headers: %w", err) } + if err := validateCORS(&s.CORS); err != nil { + return fmt.Errorf("cors: %w", err) + } + + return nil +} + +func validateCORS(c *CORSConfig) error { + if !c.Enabled { + return nil + } + + if len(c.AllowedOrigins) == 0 { + return errors.New("启用 CORS 时必须配置 allowed_origins") + } + + hasWildcard := false + for _, o := range c.AllowedOrigins { + if o == "*" { + hasWildcard = true + break + } + } + + if hasWildcard && c.AllowCredentials { + return errors.New("allowed_origins 包含 \"*\" 时不能同时启用 allow_credentials(CORS 规范不允许)") + } + + if c.MaxAge < 0 { + return errors.New("max_age 不能为负数") + } + return nil } diff --git a/internal/middleware/cors/cors.go b/internal/middleware/cors/cors.go new file mode 100644 index 0000000..4a6ee30 --- /dev/null +++ b/internal/middleware/cors/cors.go @@ -0,0 +1,184 @@ +package cors + +import ( + "bytes" + + "github.com/valyala/fasthttp" + "rua.plus/lolly/internal/middleware" +) + +// CORSConfig holds CORS middleware configuration. +type CORSConfig struct { + Enabled bool `yaml:"enabled"` + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` + ExposeHeaders []string `yaml:"expose_headers"` + AllowCredentials bool `yaml:"allow_credentials"` + MaxAge int `yaml:"max_age"` +} + +// CORSMiddleware implements CORS (Cross-Origin Resource Sharing) handling. +type CORSMiddleware struct { + cfg *CORSConfig + wildcard bool + originSet map[string]struct{} + methodsVal []byte + headersVal []byte + exposeVal []byte + maxAgeVal []byte +} + +var _ middleware.Middleware = (*CORSMiddleware)(nil) + +// New creates a new CORS middleware from the given configuration. +func New(cfg *CORSConfig) *CORSMiddleware { + if cfg == nil { + return &CORSMiddleware{} + } + + if !cfg.Enabled || len(cfg.AllowedOrigins) == 0 { + return &CORSMiddleware{cfg: cfg} + } + + m := &CORSMiddleware{ + cfg: cfg, + originSet: make(map[string]struct{}, len(cfg.AllowedOrigins)), + } + + for _, o := range cfg.AllowedOrigins { + if o == "*" { + m.wildcard = true + continue + } + m.originSet[o] = struct{}{} + } + + if len(cfg.AllowedMethods) > 0 { + m.methodsVal = []byte(joinStrings(cfg.AllowedMethods)) + } + if len(cfg.AllowedHeaders) > 0 { + m.headersVal = []byte(joinStrings(cfg.AllowedHeaders)) + } + if len(cfg.ExposeHeaders) > 0 { + m.exposeVal = []byte(joinStrings(cfg.ExposeHeaders)) + } + if cfg.MaxAge > 0 { + m.maxAgeVal = []byte(intToStr(cfg.MaxAge)) + } + + return m +} + +// Name returns the middleware name. +func (c *CORSMiddleware) Name() string { return "CORS" } + +// Process implements the middleware.Middleware interface. +func (c *CORSMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + if c.cfg == nil || !c.cfg.Enabled || len(c.cfg.AllowedOrigins) == 0 { + next(ctx) + return + } + + origin := ctx.Request.Header.Peek("Origin") + if len(origin) == 0 { + next(ctx) + return + } + + if !c.matchOrigin(origin) { + next(ctx) + return + } + + if bytes.Equal(ctx.Request.Header.Method(), []byte("OPTIONS")) { + c.handlePreflight(ctx, origin) + return + } + + next(ctx) + c.setActualHeaders(ctx, origin) + } +} + +func (c *CORSMiddleware) matchOrigin(origin []byte) bool { + if c.wildcard { + return true + } + _, ok := c.originSet[string(origin)] + return ok +} + +func (c *CORSMiddleware) handlePreflight(ctx *fasthttp.RequestCtx, origin []byte) { + h := &ctx.Response.Header + h.SetBytesKV([]byte("Access-Control-Allow-Origin"), origin) + + if len(c.methodsVal) > 0 { + h.SetBytesKV([]byte("Access-Control-Allow-Methods"), c.methodsVal) + } + if len(c.headersVal) > 0 { + h.SetBytesKV([]byte("Access-Control-Allow-Headers"), c.headersVal) + } + if c.cfg.MaxAge > 0 { + h.SetBytesKV([]byte("Access-Control-Max-Age"), c.maxAgeVal) + } + if c.cfg.AllowCredentials { + h.SetBytesKV([]byte("Access-Control-Allow-Credentials"), []byte("true")) + } + + ctx.SetStatusCode(fasthttp.StatusNoContent) +} + +func (c *CORSMiddleware) setActualHeaders(ctx *fasthttp.RequestCtx, origin []byte) { + h := &ctx.Response.Header + h.SetBytesKV([]byte("Access-Control-Allow-Origin"), origin) + + if len(c.exposeVal) > 0 { + h.SetBytesKV([]byte("Access-Control-Expose-Headers"), c.exposeVal) + } + if c.cfg.AllowCredentials { + h.SetBytesKV([]byte("Access-Control-Allow-Credentials"), []byte("true")) + } +} + +func joinStrings(ss []string) string { + switch len(ss) { + case 0: + return "" + case 1: + return ss[0] + default: + var buf []byte + for i, s := range ss { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, s...) + } + return string(buf) + } +} + +func intToStr(n int) string { + if n == 0 { + return "0" + } + buf := make([]byte, 0, 12) + neg := false + if n < 0 { + neg = true + n = -n + } + for n > 0 { + buf = append(buf, byte('0'+n%10)) + n /= 10 + } + if neg { + buf = append(buf, '-') + } + for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 { + buf[i], buf[j] = buf[j], buf[i] + } + return string(buf) +} diff --git a/internal/middleware/cors/cors_test.go b/internal/middleware/cors/cors_test.go new file mode 100644 index 0000000..b9105ed --- /dev/null +++ b/internal/middleware/cors/cors_test.go @@ -0,0 +1,256 @@ +package cors + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" +) + +func newTestHandler() (*fasthttp.RequestCtx, fasthttp.RequestHandler) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + return ctx, func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetBodyString("ok") + } +} + +func TestDisabled_PassesThrough(t *testing.T) { + cfg := &CORSConfig{Enabled: false} + m := New(cfg) + ctx, next := newTestHandler() + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, "ok", string(ctx.Response.Body())) + assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestNilConfig_PassesThrough(t *testing.T) { + m := New(nil) + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestNoOrigin_PassesThrough(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + } + m := New(cfg) + ctx, next := newTestHandler() + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestNonMatchingOrigin_NoCORSHeaders(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://evil.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestMatchingOrigin_SetsCORSHeaders(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + AllowCredentials: true, + ExposeHeaders: []string{"X-Custom", "X-Another"}, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Equal(t, "https://example.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) + assert.Equal(t, "true", string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials"))) + assert.Equal(t, "X-Custom,X-Another", string(ctx.Response.Header.Peek("Access-Control-Expose-Headers"))) +} + +func TestPreflight_Returns204(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: []string{"GET", "POST", "PUT"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + MaxAge: 3600, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.SetMethod("OPTIONS") + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusNoContent, ctx.Response.StatusCode()) + assert.Equal(t, "https://example.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) + assert.Equal(t, "GET,POST,PUT", string(ctx.Response.Header.Peek("Access-Control-Allow-Methods"))) + assert.Equal(t, "Content-Type,Authorization", string(ctx.Response.Header.Peek("Access-Control-Allow-Headers"))) + assert.Equal(t, "3600", string(ctx.Response.Header.Peek("Access-Control-Max-Age"))) +} + +func TestWildcardOrigin_MatchesAny(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"*"}, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://anything.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, "https://anything.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestAllowCredentials_SetsHeader(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + AllowCredentials: true, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, "true", string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials"))) +} + +func TestMaxAge_SetsHeaderWhenPositive(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: []string{"GET"}, + MaxAge: 7200, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.SetMethod("OPTIONS") + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, "7200", string(ctx.Response.Header.Peek("Access-Control-Max-Age"))) +} + +func TestMaxAge_NotSetWhenZero(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: []string{"GET"}, + MaxAge: 0, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.SetMethod("OPTIONS") + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Max-Age"))) +} + +func TestExposeHeaders_OnActualRequest(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + ExposeHeaders: []string{"X-Total-Count", "X-Page"}, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, "X-Total-Count,X-Page", string(ctx.Response.Header.Peek("Access-Control-Expose-Headers"))) +} + +func TestMultipleOrigins_OnlyMatchingEchoedBack(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://a.com", "https://b.com", "https://c.com"}, + } + m := New(cfg) + + for _, origin := range []string{"https://a.com", "https://b.com", "https://c.com"} { + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", origin) + handler := m.Process(next) + handler(ctx) + assert.Equal(t, origin, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) + } + + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://evil.com") + handler := m.Process(next) + handler(ctx) + assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestEmptyAllowedOrigins_PassesThrough(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{}, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode()) + assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestPreflight_WithCredentials(t *testing.T) { + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: []string{"GET", "POST"}, + AllowCredentials: true, + } + m := New(cfg) + ctx, next := newTestHandler() + ctx.Request.Header.SetMethod("OPTIONS") + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(next) + handler(ctx) + assert.Equal(t, fasthttp.StatusNoContent, ctx.Response.StatusCode()) + assert.Equal(t, "true", string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials"))) + assert.Equal(t, "https://example.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin"))) +} + +func TestName(t *testing.T) { + m := New(nil) + assert.Equal(t, "CORS", m.Name()) +} + +func TestPreflight_DoesNotCallNext(t *testing.T) { + called := false + cfg := &CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"https://example.com"}, + AllowedMethods: []string{"GET"}, + } + m := New(cfg) + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + ctx.Request.Header.SetMethod("OPTIONS") + ctx.Request.Header.Set("Origin", "https://example.com") + handler := m.Process(func(ctx *fasthttp.RequestCtx) { + called = true + }) + handler(ctx) + assert.False(t, called, "next handler should not be called for preflight requests") + assert.Equal(t, fasthttp.StatusNoContent, ctx.Response.StatusCode()) +}