diff --git a/internal/middleware/requestid/requestid.go b/internal/middleware/requestid/requestid.go new file mode 100644 index 0000000..cfa7854 --- /dev/null +++ b/internal/middleware/requestid/requestid.go @@ -0,0 +1,59 @@ +package requestid + +import ( + "bytes" + + "github.com/google/uuid" + "github.com/valyala/fasthttp" + + "rua.plus/lolly/internal/middleware" +) + +var requestIDHeader = []byte("X-Request-ID") + +// RequestIDMiddleware generates or propagates X-Request-ID for request tracing. +type RequestIDMiddleware struct{} + +var _ middleware.Middleware = (*RequestIDMiddleware)(nil) + +// New creates a new Request-ID middleware. +func New() *RequestIDMiddleware { + return &RequestIDMiddleware{} +} + +// Name returns the middleware name. +func (m *RequestIDMiddleware) Name() string { return "request_id" } + +// Process implements the middleware.Middleware interface. +func (m *RequestIDMiddleware) Process(next fasthttp.RequestHandler) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + var id string + + incoming := ctx.Request.Header.PeekBytes(requestIDHeader) + if len(incoming) > 0 { + trimmed := bytes.TrimSpace(incoming) + if len(trimmed) > 0 { + id = string(trimmed) + } + } + + if id == "" { + id = uuid.New().String() + } + + ctx.SetUserValue("request_id", id) + ctx.Response.Header.SetBytesKV(requestIDHeader, []byte(id)) + + next(ctx) + } +} + +// GetRequestID extracts the request ID from the request context. +func GetRequestID(ctx *fasthttp.RequestCtx) string { + if v := ctx.UserValue("request_id"); v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} diff --git a/internal/middleware/requestid/requestid_test.go b/internal/middleware/requestid/requestid_test.go new file mode 100644 index 0000000..ee21b5a --- /dev/null +++ b/internal/middleware/requestid/requestid_test.go @@ -0,0 +1,145 @@ +package requestid + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/valyala/fasthttp" +) + +func TestRequestID_GeneratesUUID(t *testing.T) { + m := New() + var capturedID string + + next := func(ctx *fasthttp.RequestCtx) { + capturedID = GetRequestID(ctx) + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := m.Process(next) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + handler(ctx) + + assert.NotEmpty(t, capturedID, "request ID should be generated") + _, err := uuid.Parse(capturedID) + assert.NoError(t, err, "generated ID should be valid UUID") + + assert.Equal(t, capturedID, string(ctx.Response.Header.Peek("X-Request-ID"))) +} + +func TestRequestID_ReusesIncoming(t *testing.T) { + m := New() + incomingID := "existing-id-12345" + var capturedID string + + next := func(ctx *fasthttp.RequestCtx) { + capturedID = GetRequestID(ctx) + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := m.Process(next) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + ctx.Request.Header.Set("X-Request-ID", incomingID) + + handler(ctx) + + assert.Equal(t, incomingID, capturedID) + assert.Equal(t, incomingID, string(ctx.Response.Header.Peek("X-Request-ID"))) +} + +func TestRequestID_EmptyHeaderGeneratesNew(t *testing.T) { + m := New() + var capturedID string + + next := func(ctx *fasthttp.RequestCtx) { + capturedID = GetRequestID(ctx) + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := m.Process(next) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + ctx.Request.Header.Set("X-Request-ID", " ") + + handler(ctx) + + assert.NotEmpty(t, capturedID, "empty header should generate new UUID") + _, err := uuid.Parse(capturedID) + assert.NoError(t, err) +} + +func TestRequestID_UserValueAccessible(t *testing.T) { + m := New() + + next := func(ctx *fasthttp.RequestCtx) { + val := ctx.UserValue("request_id") + assert.NotNil(t, val) + s, ok := val.(string) + assert.True(t, ok) + assert.NotEmpty(t, s) + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := m.Process(next) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + handler(ctx) +} + +func TestRequestID_ResponseHeaderSet(t *testing.T) { + m := New() + + next := func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := m.Process(next) + + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + handler(ctx) + + respHeader := string(ctx.Response.Header.Peek("X-Request-ID")) + assert.NotEmpty(t, respHeader) +} + +func TestRequestID_GeneratedUUIDValid(t *testing.T) { + m := New() + + next := func(ctx *fasthttp.RequestCtx) { + ctx.SetStatusCode(fasthttp.StatusOK) + } + + handler := m.Process(next) + + for range 10 { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI("/test") + + handler(ctx) + + respHeader := string(ctx.Response.Header.Peek("X-Request-ID")) + _, err := uuid.Parse(respHeader) + assert.NoError(t, err, "generated UUID should be valid: %s", respHeader) + } +} + +func TestRequestID_Name(t *testing.T) { + m := New() + assert.Equal(t, "request_id", m.Name()) +} + +func TestGetRequestID_Empty(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + assert.Equal(t, "", GetRequestID(ctx)) +}