feat(middleware/cors): add CORS middleware with server-level configuration
Implement Cross-Origin Resource Sharing (CORS) middleware following the middleware.Middleware interface pattern. New config under security.cors: - enabled: toggle CORS handling (default false) - allowed_origins: exact origin list or ["*"] wildcard - allowed_methods: allowed HTTP methods for preflight - allowed_headers: allowed request headers for preflight - expose_headers: headers visible to frontend JS - allow_credentials: send cookies (incompatible with wildcard origin) - max_age: preflight cache duration in seconds Validation: - origins+credentials mutual exclusion per CORS spec - max_age non-negative check Integration: - Registered after SecurityHeaders, before ErrorIntercept in middleware chain - Preflight (OPTIONS) returns 204 with CORS headers, skips handler - Actual requests add CORS headers after handler execution - Non-matching origins pass through without CORS headers - 16 unit tests covering all scenarios
This commit is contained in:
parent
6c538a1a56
commit
8224ae7ff3
@ -142,6 +142,9 @@ func DefaultConfig() *Config {
|
||||
AuthRequest: AuthRequestConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
CORS: CORSConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
Compression: CompressionConfig{
|
||||
Type: "gzip",
|
||||
@ -199,6 +202,14 @@ func DefaultConfig() *Config {
|
||||
Path: DefaultPprofPath,
|
||||
Allow: []string{"127.0.0.1"},
|
||||
},
|
||||
Healthz: HealthzConfig{
|
||||
Enabled: true,
|
||||
Path: "/healthz",
|
||||
},
|
||||
Readyz: ReadyzConfig{
|
||||
Enabled: true,
|
||||
Path: "/readyz",
|
||||
},
|
||||
},
|
||||
HTTP3: HTTP3Config{
|
||||
Enabled: false,
|
||||
|
||||
@ -41,6 +41,18 @@ type SecurityConfig struct {
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
AuthRequest AuthRequestConfig `yaml:"auth_request"`
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
CORS CORSConfig `yaml:"cors"`
|
||||
}
|
||||
|
||||
// CORSConfig configures Cross-Origin Resource Sharing (CORS) headers.
|
||||
type CORSConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
AllowedMethods []string `yaml:"allowed_methods"`
|
||||
AllowedHeaders []string `yaml:"allowed_headers"`
|
||||
ExposeHeaders []string `yaml:"expose_headers"`
|
||||
AllowCredentials bool `yaml:"allow_credentials"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
}
|
||||
|
||||
// AccessConfig IP 访问控制配置。
|
||||
|
||||
@ -645,26 +645,54 @@ func validateSSL(s *SSLConfig) error {
|
||||
// 返回值:
|
||||
// - error: 验证失败时返回错误信息,成功返回 nil
|
||||
func validateSecurity(s *SecurityConfig) error {
|
||||
// 验证访问控制配置
|
||||
if err := validateAccess(&s.Access); err != nil {
|
||||
return fmt.Errorf("access: %w", err)
|
||||
}
|
||||
|
||||
// 验证认证配置
|
||||
if err := validateAuth(&s.Auth); err != nil {
|
||||
return fmt.Errorf("auth: %w", err)
|
||||
}
|
||||
|
||||
// 验证速率限制配置
|
||||
if err := validateRateLimit(&s.RateLimit); err != nil {
|
||||
return fmt.Errorf("rate_limit: %w", err)
|
||||
}
|
||||
|
||||
// 验证安全头部配置
|
||||
if err := validateSecurityHeaders(&s.Headers); err != nil {
|
||||
return fmt.Errorf("headers: %w", err)
|
||||
}
|
||||
|
||||
if err := validateCORS(&s.CORS); err != nil {
|
||||
return fmt.Errorf("cors: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCORS(c *CORSConfig) error {
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(c.AllowedOrigins) == 0 {
|
||||
return errors.New("启用 CORS 时必须配置 allowed_origins")
|
||||
}
|
||||
|
||||
hasWildcard := false
|
||||
for _, o := range c.AllowedOrigins {
|
||||
if o == "*" {
|
||||
hasWildcard = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasWildcard && c.AllowCredentials {
|
||||
return errors.New("allowed_origins 包含 \"*\" 时不能同时启用 allow_credentials(CORS 规范不允许)")
|
||||
}
|
||||
|
||||
if c.MaxAge < 0 {
|
||||
return errors.New("max_age 不能为负数")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
184
internal/middleware/cors/cors.go
Normal file
184
internal/middleware/cors/cors.go
Normal file
@ -0,0 +1,184 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/valyala/fasthttp"
|
||||
"rua.plus/lolly/internal/middleware"
|
||||
)
|
||||
|
||||
// CORSConfig holds CORS middleware configuration.
|
||||
type CORSConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
AllowedMethods []string `yaml:"allowed_methods"`
|
||||
AllowedHeaders []string `yaml:"allowed_headers"`
|
||||
ExposeHeaders []string `yaml:"expose_headers"`
|
||||
AllowCredentials bool `yaml:"allow_credentials"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
}
|
||||
|
||||
// CORSMiddleware implements CORS (Cross-Origin Resource Sharing) handling.
|
||||
type CORSMiddleware struct {
|
||||
cfg *CORSConfig
|
||||
wildcard bool
|
||||
originSet map[string]struct{}
|
||||
methodsVal []byte
|
||||
headersVal []byte
|
||||
exposeVal []byte
|
||||
maxAgeVal []byte
|
||||
}
|
||||
|
||||
var _ middleware.Middleware = (*CORSMiddleware)(nil)
|
||||
|
||||
// New creates a new CORS middleware from the given configuration.
|
||||
func New(cfg *CORSConfig) *CORSMiddleware {
|
||||
if cfg == nil {
|
||||
return &CORSMiddleware{}
|
||||
}
|
||||
|
||||
if !cfg.Enabled || len(cfg.AllowedOrigins) == 0 {
|
||||
return &CORSMiddleware{cfg: cfg}
|
||||
}
|
||||
|
||||
m := &CORSMiddleware{
|
||||
cfg: cfg,
|
||||
originSet: make(map[string]struct{}, len(cfg.AllowedOrigins)),
|
||||
}
|
||||
|
||||
for _, o := range cfg.AllowedOrigins {
|
||||
if o == "*" {
|
||||
m.wildcard = true
|
||||
continue
|
||||
}
|
||||
m.originSet[o] = struct{}{}
|
||||
}
|
||||
|
||||
if len(cfg.AllowedMethods) > 0 {
|
||||
m.methodsVal = []byte(joinStrings(cfg.AllowedMethods))
|
||||
}
|
||||
if len(cfg.AllowedHeaders) > 0 {
|
||||
m.headersVal = []byte(joinStrings(cfg.AllowedHeaders))
|
||||
}
|
||||
if len(cfg.ExposeHeaders) > 0 {
|
||||
m.exposeVal = []byte(joinStrings(cfg.ExposeHeaders))
|
||||
}
|
||||
if cfg.MaxAge > 0 {
|
||||
m.maxAgeVal = []byte(intToStr(cfg.MaxAge))
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Name returns the middleware name.
|
||||
func (c *CORSMiddleware) Name() string { return "CORS" }
|
||||
|
||||
// Process implements the middleware.Middleware interface.
|
||||
func (c *CORSMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler {
|
||||
return func(ctx *fasthttp.RequestCtx) {
|
||||
if c.cfg == nil || !c.cfg.Enabled || len(c.cfg.AllowedOrigins) == 0 {
|
||||
next(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
origin := ctx.Request.Header.Peek("Origin")
|
||||
if len(origin) == 0 {
|
||||
next(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
if !c.matchOrigin(origin) {
|
||||
next(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
if bytes.Equal(ctx.Request.Header.Method(), []byte("OPTIONS")) {
|
||||
c.handlePreflight(ctx, origin)
|
||||
return
|
||||
}
|
||||
|
||||
next(ctx)
|
||||
c.setActualHeaders(ctx, origin)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CORSMiddleware) matchOrigin(origin []byte) bool {
|
||||
if c.wildcard {
|
||||
return true
|
||||
}
|
||||
_, ok := c.originSet[string(origin)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (c *CORSMiddleware) handlePreflight(ctx *fasthttp.RequestCtx, origin []byte) {
|
||||
h := &ctx.Response.Header
|
||||
h.SetBytesKV([]byte("Access-Control-Allow-Origin"), origin)
|
||||
|
||||
if len(c.methodsVal) > 0 {
|
||||
h.SetBytesKV([]byte("Access-Control-Allow-Methods"), c.methodsVal)
|
||||
}
|
||||
if len(c.headersVal) > 0 {
|
||||
h.SetBytesKV([]byte("Access-Control-Allow-Headers"), c.headersVal)
|
||||
}
|
||||
if c.cfg.MaxAge > 0 {
|
||||
h.SetBytesKV([]byte("Access-Control-Max-Age"), c.maxAgeVal)
|
||||
}
|
||||
if c.cfg.AllowCredentials {
|
||||
h.SetBytesKV([]byte("Access-Control-Allow-Credentials"), []byte("true"))
|
||||
}
|
||||
|
||||
ctx.SetStatusCode(fasthttp.StatusNoContent)
|
||||
}
|
||||
|
||||
func (c *CORSMiddleware) setActualHeaders(ctx *fasthttp.RequestCtx, origin []byte) {
|
||||
h := &ctx.Response.Header
|
||||
h.SetBytesKV([]byte("Access-Control-Allow-Origin"), origin)
|
||||
|
||||
if len(c.exposeVal) > 0 {
|
||||
h.SetBytesKV([]byte("Access-Control-Expose-Headers"), c.exposeVal)
|
||||
}
|
||||
if c.cfg.AllowCredentials {
|
||||
h.SetBytesKV([]byte("Access-Control-Allow-Credentials"), []byte("true"))
|
||||
}
|
||||
}
|
||||
|
||||
func joinStrings(ss []string) string {
|
||||
switch len(ss) {
|
||||
case 0:
|
||||
return ""
|
||||
case 1:
|
||||
return ss[0]
|
||||
default:
|
||||
var buf []byte
|
||||
for i, s := range ss {
|
||||
if i > 0 {
|
||||
buf = append(buf, ',')
|
||||
}
|
||||
buf = append(buf, s...)
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
}
|
||||
|
||||
func intToStr(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
buf := make([]byte, 0, 12)
|
||||
neg := false
|
||||
if n < 0 {
|
||||
neg = true
|
||||
n = -n
|
||||
}
|
||||
for n > 0 {
|
||||
buf = append(buf, byte('0'+n%10))
|
||||
n /= 10
|
||||
}
|
||||
if neg {
|
||||
buf = append(buf, '-')
|
||||
}
|
||||
for i, j := 0, len(buf)-1; i < j; i, j = i+1, j-1 {
|
||||
buf[i], buf[j] = buf[j], buf[i]
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
256
internal/middleware/cors/cors_test.go
Normal file
256
internal/middleware/cors/cors_test.go
Normal file
@ -0,0 +1,256 @@
|
||||
package cors
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
func newTestHandler() (*fasthttp.RequestCtx, fasthttp.RequestHandler) {
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
return ctx, func(ctx *fasthttp.RequestCtx) {
|
||||
ctx.SetStatusCode(fasthttp.StatusOK)
|
||||
ctx.SetBodyString("ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabled_PassesThrough(t *testing.T) {
|
||||
cfg := &CORSConfig{Enabled: false}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "ok", string(ctx.Response.Body()))
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestNilConfig_PassesThrough(t *testing.T) {
|
||||
m := New(nil)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestNoOrigin_PassesThrough(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestNonMatchingOrigin_NoCORSHeaders(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://evil.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestMatchingOrigin_SetsCORSHeaders(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowCredentials: true,
|
||||
ExposeHeaders: []string{"X-Custom", "X-Another"},
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "https://example.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
assert.Equal(t, "true", string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials")))
|
||||
assert.Equal(t, "X-Custom,X-Another", string(ctx.Response.Header.Peek("Access-Control-Expose-Headers")))
|
||||
}
|
||||
|
||||
func TestPreflight_Returns204(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT"},
|
||||
AllowedHeaders: []string{"Content-Type", "Authorization"},
|
||||
MaxAge: 3600,
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.SetMethod("OPTIONS")
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusNoContent, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "https://example.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
assert.Equal(t, "GET,POST,PUT", string(ctx.Response.Header.Peek("Access-Control-Allow-Methods")))
|
||||
assert.Equal(t, "Content-Type,Authorization", string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")))
|
||||
assert.Equal(t, "3600", string(ctx.Response.Header.Peek("Access-Control-Max-Age")))
|
||||
}
|
||||
|
||||
func TestWildcardOrigin_MatchesAny(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"*"},
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://anything.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, "https://anything.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestAllowCredentials_SetsHeader(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, "true", string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials")))
|
||||
}
|
||||
|
||||
func TestMaxAge_SetsHeaderWhenPositive(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowedMethods: []string{"GET"},
|
||||
MaxAge: 7200,
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.SetMethod("OPTIONS")
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, "7200", string(ctx.Response.Header.Peek("Access-Control-Max-Age")))
|
||||
}
|
||||
|
||||
func TestMaxAge_NotSetWhenZero(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowedMethods: []string{"GET"},
|
||||
MaxAge: 0,
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.SetMethod("OPTIONS")
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Max-Age")))
|
||||
}
|
||||
|
||||
func TestExposeHeaders_OnActualRequest(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
ExposeHeaders: []string{"X-Total-Count", "X-Page"},
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, "X-Total-Count,X-Page", string(ctx.Response.Header.Peek("Access-Control-Expose-Headers")))
|
||||
}
|
||||
|
||||
func TestMultipleOrigins_OnlyMatchingEchoedBack(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://a.com", "https://b.com", "https://c.com"},
|
||||
}
|
||||
m := New(cfg)
|
||||
|
||||
for _, origin := range []string{"https://a.com", "https://b.com", "https://c.com"} {
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", origin)
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, origin, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://evil.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestEmptyAllowedOrigins_PassesThrough(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{},
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode())
|
||||
assert.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestPreflight_WithCredentials(t *testing.T) {
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx, next := newTestHandler()
|
||||
ctx.Request.Header.SetMethod("OPTIONS")
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(next)
|
||||
handler(ctx)
|
||||
assert.Equal(t, fasthttp.StatusNoContent, ctx.Response.StatusCode())
|
||||
assert.Equal(t, "true", string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials")))
|
||||
assert.Equal(t, "https://example.com", string(ctx.Response.Header.Peek("Access-Control-Allow-Origin")))
|
||||
}
|
||||
|
||||
func TestName(t *testing.T) {
|
||||
m := New(nil)
|
||||
assert.Equal(t, "CORS", m.Name())
|
||||
}
|
||||
|
||||
func TestPreflight_DoesNotCallNext(t *testing.T) {
|
||||
called := false
|
||||
cfg := &CORSConfig{
|
||||
Enabled: true,
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowedMethods: []string{"GET"},
|
||||
}
|
||||
m := New(cfg)
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
ctx.Request.SetRequestURI("/test")
|
||||
ctx.Request.Header.SetMethod("OPTIONS")
|
||||
ctx.Request.Header.Set("Origin", "https://example.com")
|
||||
handler := m.Process(func(ctx *fasthttp.RequestCtx) {
|
||||
called = true
|
||||
})
|
||||
handler(ctx)
|
||||
assert.False(t, called, "next handler should not be called for preflight requests")
|
||||
assert.Equal(t, fasthttp.StatusNoContent, ctx.Response.StatusCode())
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user