From e79c23ea087100d00ec1fcb32ec28f6954962c35 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 18 Sep 2024 17:26:38 +0800 Subject: [PATCH] fix(axum): refactor logging middleware --- web/Rust/axum/src/middlewares/mod.rs | 50 ++++++++++++++++++++++++++-- web/Rust/axum/src/routes/mod.rs | 47 +++++--------------------- 2 files changed, 56 insertions(+), 41 deletions(-) diff --git a/web/Rust/axum/src/middlewares/mod.rs b/web/Rust/axum/src/middlewares/mod.rs index e9f49a1..3880d6a 100644 --- a/web/Rust/axum/src/middlewares/mod.rs +++ b/web/Rust/axum/src/middlewares/mod.rs @@ -1,6 +1,16 @@ +use std::time::Duration; + use axum::{ - body::Body, extract::Request, http::HeaderValue, middleware::Next, response::IntoResponse, + body::Bytes, + extract::Request, + http::{HeaderMap, HeaderValue}, + middleware::Next, + response::{IntoResponse, Response}, + Router, }; +use tower_http::classify::ServerErrorsFailureClass; +use tower_http::trace::TraceLayer; +use tracing::{error, info, info_span, Span}; use crate::error::AppResult; @@ -11,7 +21,10 @@ use crate::error::AppResult; /// obtaining the response. After receiving the response, it appends two headers: /// - "Server": The name of the server extracted from the Cargo package name. /// - "S-Version": The version of the server extracted from the Cargo package version. -pub async fn add_version(req: Request, next: Next) -> AppResult { +pub async fn add_version( + req: Request, + next: Next, +) -> AppResult { let mut res = next.run(req).await; let headers = res.headers_mut(); headers.append("Server", HeaderValue::from_static(env!("CARGO_PKG_NAME"))); @@ -21,3 +34,36 @@ pub async fn add_version(req: Request, next: Next) -> AppResult Router { + router.layer( + TraceLayer::new_for_http() + .make_span_with(|req: &Request<_>| { + let unknown = &HeaderValue::from_static("Unknown"); + let empty = &HeaderValue::from_static(""); + let headers = req.headers(); + let ua = headers + .get("User-Agent") + .unwrap_or(unknown) + .to_str() + .unwrap_or("Unknown"); + let host = headers.get("Host").unwrap_or(empty).to_str().unwrap_or(""); + info_span!("HTTP", method = ?req.method(), host, uri = ?req.uri(), ua) + }) + .on_request(|_req: &Request<_>, _span: &Span| {}) + .on_response(|res: &Response, latency: Duration, _span: &Span| { + info!("{} {}μs", res.status(), latency.as_micros()); + }) + .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {}) + .on_eos(|_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {}) + .on_failure( + |error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { + error!("{}", error); + }, + ), + ) +} diff --git a/web/Rust/axum/src/routes/mod.rs b/web/Rust/axum/src/routes/mod.rs index 9818e6c..3c44c9e 100644 --- a/web/Rust/axum/src/routes/mod.rs +++ b/web/Rust/axum/src/routes/mod.rs @@ -2,9 +2,8 @@ use std::{borrow::Cow, collections::HashMap, time::Duration}; use axum::{ async_trait, - body::Bytes, - extract::{FromRequestParts, Path, Request}, - http::{request::Parts, HeaderMap, HeaderValue, StatusCode, Uri}, + extract::{FromRequestParts, Path}, + http::{request::Parts, StatusCode, Uri}, middleware, response::{IntoResponse, Response}, routing::get, @@ -12,15 +11,12 @@ use axum::{ }; use serde::Serialize; use tower::ServiceBuilder; -use tower_http::{ - classify::ServerErrorsFailureClass, compression::CompressionLayer, cors::CorsLayer, - timeout::TimeoutLayer, trace::TraceLayer, -}; -use tracing::{error, info, info_span, Span}; +use tower_http::{compression::CompressionLayer, cors::CorsLayer, timeout::TimeoutLayer}; +use tracing::info; use crate::{ error::{AppResult, ErrorCode}, - middlewares::add_version, + middlewares::{add_version, logging_route}, }; #[derive(Debug, Serialize)] @@ -36,7 +32,7 @@ where pub type RouteResult = AppResult>>; pub fn routes() -> Router { - Router::new() + let router = Router::new() .route("/", get(hello).post(hello)) .layer( ServiceBuilder::new() @@ -45,35 +41,8 @@ pub fn routes() -> Router { .layer(TimeoutLayer::new(Duration::from_secs(15))) .layer(CompressionLayer::new()), ) - .fallback(fallback) - .layer( - TraceLayer::new_for_http() - .make_span_with(|req: &Request<_>| { - let unknown = &HeaderValue::from_static("Unknown"); - let empty = &HeaderValue::from_static(""); - let headers = req.headers(); - let ua = headers - .get("User-Agent") - .unwrap_or(unknown) - .to_str() - .unwrap_or("Unknown"); - let host = headers.get("Host").unwrap_or(empty).to_str().unwrap_or(""); - info_span!("HTTP", method = ?req.method(), host, uri = ?req.uri(), ua) - }) - .on_request(|_req: &Request<_>, _span: &Span| {}) - .on_response(|res: &Response, latency: Duration, _span: &Span| { - info!("{} {}μs", res.status(), latency.as_micros()); - }) - .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {}) - .on_eos( - |_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {}, - ) - .on_failure( - |error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { - error!("{}", error); - }, - ), - ) + .fallback(fallback); + logging_route(router) } /// hello world