fix(axum): refactor logging middleware

This commit is contained in:
xfy
2024-09-18 17:26:38 +08:00
parent b01d97242f
commit e79c23ea08
2 changed files with 56 additions and 41 deletions

View File

@ -1,6 +1,16 @@
use std::time::Duration;
use axum::{
body::Body, extract::Request, http::HeaderValue, middleware::Next, response::IntoResponse,
body::Bytes,
extract::Request,
http::{HeaderMap, HeaderValue},
middleware::Next,
response::{IntoResponse, Response},
Router,
};
use tower_http::classify::ServerErrorsFailureClass;
use tower_http::trace::TraceLayer;
use tracing::{error, info, info_span, Span};
use crate::error::AppResult;
@ -11,7 +21,10 @@ use crate::error::AppResult;
/// obtaining the response. After receiving the response, it appends two headers:
/// - "Server": The name of the server extracted from the Cargo package name.
/// - "S-Version": The version of the server extracted from the Cargo package version.
pub async fn add_version(req: Request<Body>, next: Next) -> AppResult<impl IntoResponse> {
pub async fn add_version(
req: Request<axum::body::Body>,
next: Next,
) -> AppResult<impl IntoResponse> {
let mut res = next.run(req).await;
let headers = res.headers_mut();
headers.append("Server", HeaderValue::from_static(env!("CARGO_PKG_NAME")));
@ -21,3 +34,36 @@ pub async fn add_version(req: Request<Body>, next: Next) -> AppResult<impl IntoR
);
Ok(res)
}
/// Middleware for logging each request.
///
/// This middleware will calculate each request latency
/// and add request's information to each info_span.
pub fn logging_route(router: Router) -> Router {
router.layer(
TraceLayer::new_for_http()
.make_span_with(|req: &Request<_>| {
let unknown = &HeaderValue::from_static("Unknown");
let empty = &HeaderValue::from_static("");
let headers = req.headers();
let ua = headers
.get("User-Agent")
.unwrap_or(unknown)
.to_str()
.unwrap_or("Unknown");
let host = headers.get("Host").unwrap_or(empty).to_str().unwrap_or("");
info_span!("HTTP", method = ?req.method(), host, uri = ?req.uri(), ua)
})
.on_request(|_req: &Request<_>, _span: &Span| {})
.on_response(|res: &Response, latency: Duration, _span: &Span| {
info!("{} {}μs", res.status(), latency.as_micros());
})
.on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {})
.on_eos(|_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {})
.on_failure(
|error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
error!("{}", error);
},
),
)
}

View File

@ -2,9 +2,8 @@ use std::{borrow::Cow, collections::HashMap, time::Duration};
use axum::{
async_trait,
body::Bytes,
extract::{FromRequestParts, Path, Request},
http::{request::Parts, HeaderMap, HeaderValue, StatusCode, Uri},
extract::{FromRequestParts, Path},
http::{request::Parts, StatusCode, Uri},
middleware,
response::{IntoResponse, Response},
routing::get,
@ -12,15 +11,12 @@ use axum::{
};
use serde::Serialize;
use tower::ServiceBuilder;
use tower_http::{
classify::ServerErrorsFailureClass, compression::CompressionLayer, cors::CorsLayer,
timeout::TimeoutLayer, trace::TraceLayer,
};
use tracing::{error, info, info_span, Span};
use tower_http::{compression::CompressionLayer, cors::CorsLayer, timeout::TimeoutLayer};
use tracing::info;
use crate::{
error::{AppResult, ErrorCode},
middlewares::add_version,
middlewares::{add_version, logging_route},
};
#[derive(Debug, Serialize)]
@ -36,7 +32,7 @@ where
pub type RouteResult<T> = AppResult<Json<RouteResponse<T>>>;
pub fn routes() -> Router {
Router::new()
let router = Router::new()
.route("/", get(hello).post(hello))
.layer(
ServiceBuilder::new()
@ -45,35 +41,8 @@ pub fn routes() -> Router {
.layer(TimeoutLayer::new(Duration::from_secs(15)))
.layer(CompressionLayer::new()),
)
.fallback(fallback)
.layer(
TraceLayer::new_for_http()
.make_span_with(|req: &Request<_>| {
let unknown = &HeaderValue::from_static("Unknown");
let empty = &HeaderValue::from_static("");
let headers = req.headers();
let ua = headers
.get("User-Agent")
.unwrap_or(unknown)
.to_str()
.unwrap_or("Unknown");
let host = headers.get("Host").unwrap_or(empty).to_str().unwrap_or("");
info_span!("HTTP", method = ?req.method(), host, uri = ?req.uri(), ua)
})
.on_request(|_req: &Request<_>, _span: &Span| {})
.on_response(|res: &Response, latency: Duration, _span: &Span| {
info!("{} {}μs", res.status(), latency.as_micros());
})
.on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {})
.on_eos(
|_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {},
)
.on_failure(
|error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
error!("{}", error);
},
),
)
.fallback(fallback);
logging_route(router)
}
/// hello world