diff --git a/web/Rust/axum/Cargo.lock b/web/Rust/axum/Cargo.lock index 02cbfc1..ec2da51 100644 --- a/web/Rust/axum/Cargo.lock +++ b/web/Rust/axum/Cargo.lock @@ -575,6 +575,7 @@ dependencies = [ "dotenvy", "serde", "serde_json", + "serde_repr", "thiserror", "tokio", "tower 0.5.1", @@ -758,6 +759,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_repr" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" diff --git a/web/Rust/axum/Cargo.toml b/web/Rust/axum/Cargo.toml index 559b367..110e575 100644 --- a/web/Rust/axum/Cargo.toml +++ b/web/Rust/axum/Cargo.toml @@ -18,6 +18,7 @@ thiserror = "1.0.63" dotenvy = "0.15.7" serde = { version = "1.0.210", features = ["derive", "serde_derive"] } serde_json = { version = "1.0.128" } +serde_repr = "0.1.19" [profile.release] lto = true diff --git a/web/Rust/axum/src/consts.rs b/web/Rust/axum/src/consts.rs new file mode 100644 index 0000000..608a2a9 --- /dev/null +++ b/web/Rust/axum/src/consts.rs @@ -0,0 +1,3 @@ +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); +pub const NAME: &str = env!("CARGO_PKG_NAME"); +pub const DEFAULT_PORT: u16 = 4000; diff --git a/web/Rust/axum/src/error.rs b/web/Rust/axum/src/error.rs new file mode 100644 index 0000000..3147417 --- /dev/null +++ b/web/Rust/axum/src/error.rs @@ -0,0 +1,100 @@ +use std::{borrow::Cow, fmt::Display}; + +use axum::{ + extract::rejection::{FormRejection, JsonRejection}, + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde_json::json; +use serde_repr::*; +use tracing::error; + +#[derive(thiserror::Error, Debug)] +pub enum AppError { + #[error("{0}")] + Any(#[from] anyhow::Error), + + // axum + #[error(transparent)] + AxumFormRejection(#[from] FormRejection), + #[error(transparent)] + AxumJsonRejection(#[from] JsonRejection), + + // route + // 路由通常错误 错误信息直接返回用户 + #[error("{0}")] + AuthorizeFailed(Cow<'static, str>), + #[error("{0}")] + UserConflict(Cow<'static, str>), +} + +#[derive(Serialize_repr, Deserialize_repr, PartialEq, Debug)] +#[repr(u16)] +pub enum ErrorCode { + Normal = 200, + InternalError = 1000, + //NotAuthorized = 1001, + AuthorizeFailed = 1002, + UserConflict = 1003, + ParameterIncorrect = 1004, +} + +impl Display for ErrorCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use ErrorCode::*; + + let res = match self { + Normal => "", + InternalError => "服务器内部错误", + //NotAuthorized => "未登录", + AuthorizeFailed => "用户名或密码错误", + UserConflict => "该用户已经存在", + ParameterIncorrect => "请求参数错误", + }; + f.write_str(res)?; + Ok(()) + } +} + +/// Log and return INTERNAL_SERVER_ERROR +fn log_internal_error(err: T) -> (StatusCode, ErrorCode, String) { + use ErrorCode::*; + + error!("{err}"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + InternalError, + "internal server error".to_string(), + ) +} + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for AppError { + fn into_response(self) -> Response { + use ErrorCode::*; + + let (status_code, code, err_message) = match self { + AppError::Any(err) => log_internal_error(err), + AppError::AxumFormRejection(_) | AppError::AxumJsonRejection(_) => ( + StatusCode::BAD_REQUEST, + ParameterIncorrect, + self.to_string(), + ), + + // route + AppError::AuthorizeFailed(err) => { + (StatusCode::UNAUTHORIZED, AuthorizeFailed, err.to_string()) + } + AppError::UserConflict(err) => (StatusCode::CONFLICT, UserConflict, err.to_string()), + }; + let body = Json(json!({ + "code": code, + "message": code.to_string(), + "error": err_message + })); + (status_code, body).into_response() + } +} + +pub type AppResult = Result; diff --git a/web/Rust/axum/src/main.rs b/web/Rust/axum/src/main.rs index 1afe56a..9c6de44 100644 --- a/web/Rust/axum/src/main.rs +++ b/web/Rust/axum/src/main.rs @@ -1,12 +1,47 @@ -use dotenvy::dotenv; -use tracing::info; -use utils::init_logger; +use std::{env, error::Error, net::SocketAddr}; +use axum::Router; +use consts::DEFAULT_PORT; +use dotenvy::dotenv; +use routes::routes; +use tokio::net::TcpListener; +use tracing::info; +use utils::{init_logger, shutdown_signal}; + +mod consts; +mod error; +mod middlewares; +mod routes; mod utils; +type Result = std::result::Result>; + #[tokio::main] -async fn main() { +async fn main() -> Result<()> { dotenv().ok(); init_logger(); info!("Hello, world!"); + + let port = env::var("VENUS_PORT") + .map(|port| port.parse::().unwrap_or(DEFAULT_PORT)) + .unwrap_or(DEFAULT_PORT); + let addr = SocketAddr::from(([0, 0, 0, 0], port)); + let listener = TcpListener::bind(addr).await?; + info!("listening on {}", addr); + + axum::serve(listener, app()) + .with_graceful_shutdown(shutdown_signal(shutdown)) + .await?; + Ok(()) +} + +#[derive(Debug, Clone)] +pub struct AppState {} + +fn app() -> Router { + Router::new().merge(routes()) +} + +fn shutdown() { + info!("Server shuting down") } diff --git a/web/Rust/axum/src/middlewares/mod.rs b/web/Rust/axum/src/middlewares/mod.rs new file mode 100644 index 0000000..e9f49a1 --- /dev/null +++ b/web/Rust/axum/src/middlewares/mod.rs @@ -0,0 +1,23 @@ +use axum::{ + body::Body, extract::Request, http::HeaderValue, middleware::Next, response::IntoResponse, +}; + +use crate::error::AppResult; + +/// Middleware for adding version information to each response's headers. +/// +/// This middleware takes an incoming `Request` and a `Next` handler, which represents the +/// subsequent middleware or route in the chain. It then asynchronously runs the next handler, +/// obtaining the response. After receiving the response, it appends two headers: +/// - "Server": The name of the server extracted from the Cargo package name. +/// - "S-Version": The version of the server extracted from the Cargo package version. +pub async fn add_version(req: Request, next: Next) -> AppResult { + let mut res = next.run(req).await; + let headers = res.headers_mut(); + headers.append("Server", HeaderValue::from_static(env!("CARGO_PKG_NAME"))); + headers.append( + "Phthonus-Version", + HeaderValue::from_static(env!("CARGO_PKG_VERSION")), + ); + Ok(res) +} diff --git a/web/Rust/axum/src/routes/mod.rs b/web/Rust/axum/src/routes/mod.rs new file mode 100644 index 0000000..9818e6c --- /dev/null +++ b/web/Rust/axum/src/routes/mod.rs @@ -0,0 +1,132 @@ +use std::{borrow::Cow, collections::HashMap, time::Duration}; + +use axum::{ + async_trait, + body::Bytes, + extract::{FromRequestParts, Path, Request}, + http::{request::Parts, HeaderMap, HeaderValue, StatusCode, Uri}, + middleware, + response::{IntoResponse, Response}, + routing::get, + Json, RequestPartsExt, Router, +}; +use serde::Serialize; +use tower::ServiceBuilder; +use tower_http::{ + classify::ServerErrorsFailureClass, compression::CompressionLayer, cors::CorsLayer, + timeout::TimeoutLayer, trace::TraceLayer, +}; +use tracing::{error, info, info_span, Span}; + +use crate::{ + error::{AppResult, ErrorCode}, + middlewares::add_version, +}; + +#[derive(Debug, Serialize)] +pub struct RouteResponse +where + T: Serialize, +{ + code: ErrorCode, + #[serde(skip_serializing_if = "Option::is_none")] + message: Option>, + data: T, +} +pub type RouteResult = AppResult>>; + +pub fn routes() -> Router { + Router::new() + .route("/", get(hello).post(hello)) + .layer( + ServiceBuilder::new() + .layer(middleware::from_fn(add_version)) + .layer(CorsLayer::permissive()) + .layer(TimeoutLayer::new(Duration::from_secs(15))) + .layer(CompressionLayer::new()), + ) + .fallback(fallback) + .layer( + TraceLayer::new_for_http() + .make_span_with(|req: &Request<_>| { + let unknown = &HeaderValue::from_static("Unknown"); + let empty = &HeaderValue::from_static(""); + let headers = req.headers(); + let ua = headers + .get("User-Agent") + .unwrap_or(unknown) + .to_str() + .unwrap_or("Unknown"); + let host = headers.get("Host").unwrap_or(empty).to_str().unwrap_or(""); + info_span!("HTTP", method = ?req.method(), host, uri = ?req.uri(), ua) + }) + .on_request(|_req: &Request<_>, _span: &Span| {}) + .on_response(|res: &Response, latency: Duration, _span: &Span| { + info!("{} {}μs", res.status(), latency.as_micros()); + }) + .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {}) + .on_eos( + |_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {}, + ) + .on_failure( + |error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| { + error!("{}", error); + }, + ), + ) +} + +/// hello world +pub async fn hello() -> String { + format!("hello {}", env!("CARGO_PKG_NAME")) +} + +/// Fallback route handler for handling unmatched routes. +/// +/// This asynchronous function takes a `Uri` as an argument, representing the unmatched route. +/// It logs a message indicating that the specified route is not found and returns a standard +/// "Not Found" response with a `StatusCode` of `404`. +/// +/// # Arguments +/// +/// - `uri`: The `Uri` representing the unmatched route. +/// +/// # Returns +/// +/// Returns a tuple `(StatusCode, &str)` where `StatusCode` is set to `NOT_FOUND` (404), +/// indicating that the route was not found, and the string "Not found" as the response body. +pub async fn fallback(uri: Uri) -> impl IntoResponse { + info!("route {} not found", uri); + (StatusCode::NOT_FOUND, "Not found") +} + +#[derive(Debug)] +enum Version { + V1, + V2, + V3, +} + +#[async_trait] +impl FromRequestParts for Version +where + S: Send + Sync, +{ + type Rejection = Response; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let params: Path> = + parts.extract().await.map_err(IntoResponse::into_response)?; + + let version = params + .get("version") + .ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?; + + match version.as_str() { + "v1" => Ok(Version::V1), + "v2" => Ok(Version::V2), + "v3" => Ok(Version::V3), + _ => Err((StatusCode::NOT_FOUND, "unknown version").into_response()), + } + } +} diff --git a/web/Rust/axum/src/utils/mod.rs b/web/Rust/axum/src/utils/mod.rs index 85a3f63..eb7c3a6 100644 --- a/web/Rust/axum/src/utils/mod.rs +++ b/web/Rust/axum/src/utils/mod.rs @@ -1,3 +1,4 @@ +use tokio::signal; use tracing_subscriber::{fmt, prelude::*, registry, EnvFilter}; /// Initializes the logger for tracing. @@ -26,3 +27,69 @@ pub fn init_logger() { registry().with(env_layer).with(formatting_layer).init(); } + +/// Asynchronously waits for a shutdown signal and executes a callback function when a signal is received. +/// +/// This function listens for shutdown signals in the form of `Ctrl+C` and termination signals. When one of +/// these signals is received, it invokes the provided callback function `shutdown_cb`. +/// +/// The behavior of the signal handling depends on the operating system: +/// +/// - On Unix-based systems (e.g., Linux, macOS), it listens for termination signals (such as SIGTERM). +/// - On non-Unix systems (e.g., Windows), it only listens for `Ctrl+C` and ignores termination signals. +/// +/// The `shutdown_cb` callback function is executed when either signal is received. This function should +/// contain the logic needed to gracefully shut down the application or perform any necessary cleanup tasks. +/// # Parameters +/// +/// - `shutdown_cb`: A closure or function to call when a shutdown signal is received. The function should +/// have the signature `Fn()`. This callback is executed without any parameters. +/// +/// # Errors +/// +/// - If setting up the signal handlers fails, the function will panic with an error message. +/// +/// # Panics +/// +/// - Panics if the setup for `Ctrl+C` or termination signal handlers fails. +/// +/// # Platform-specific behavior +/// +/// - On Unix-based systems, termination signals are handled using the `signal` crate for Unix signals. +/// - On non-Unix systems, only `Ctrl+C` signals are handled, and termination signals are not supported. +/// +/// # Future +/// +/// This function returns a future that resolves when either `Ctrl+C` or a termination signal is received +/// and the callback function has been executed. +pub async fn shutdown_signal(shutdown_cb: F) +where + F: Fn(), +{ + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => { + shutdown_cb() + // let _ = stop_core().map_err(log_err); + }, + _ = terminate => { + shutdown_cb() + }, + } +}