feat(reverse_proxy): add basic proxy

This commit is contained in:
xfy
2025-06-01 21:48:55 +08:00
parent 136345e872
commit 1b3fcf7443
5 changed files with 915 additions and 69 deletions

767
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -23,6 +23,7 @@ axum-extra = { version = "0.10.1", features = ["typed-header"] }
axum-server = { version = "0.7.2", features = ["tls-rustls"] } axum-server = { version = "0.7.2", features = ["tls-rustls"] }
tower = { version = "0.5.2", features = ["full"] } tower = { version = "0.5.2", features = ["full"] }
tower-http = { version = "0.6.4", features = ["full"] } tower-http = { version = "0.6.4", features = ["full"] }
reqwest = { version = "0.12.18", features = ["stream"] }
# tools # tools
thiserror = "2.0.12" thiserror = "2.0.12"
anyhow = "1.0.98" anyhow = "1.0.98"

View File

@ -16,7 +16,10 @@ use crate::{
}; };
pub mod error; pub mod error;
// handle static file
pub mod serve; pub mod serve;
// handle reverse proxy
pub mod reverse_proxy;
/// Host configuration /// Host configuration
/// use virtual host port as key /// use virtual host port as key
@ -40,61 +43,67 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
for host_route in &host.route { for host_route in &host.route {
// reverse proxy // reverse proxy
if host_route.proxy_pass.is_some() { if host_route.proxy_pass.is_some() {
continue; router = router.route(host_route.location.as_ref(), get(reverse_proxy::serve));
// router = router.route(host_route.location.as_ref(), get(hello)); // save route path to map
} {
host_to_save
// static file .route_map
if host_route.root.is_none() { .insert(host_route.location.clone(), host_route.clone());
warn!("root field not found for route: {:?}", host_route.location); }
continue;
}
// resister with location
// location = "/doc"
// route: GET /doc/*
// resister with file path
// index = ["index.html", "index.txt"]
// route: GET /doc/index.html
// route: GET /doc/index.txt
// register parent path /doc
let path_morethan_one = host_route.location.len() > 1;
let route_path = if path_morethan_one && host_route.location.ends_with('/') {
// first register path with slash /doc
router = router.route(&host_route.location, get(serve::serve));
debug!("registed route {}", host_route.location);
let len = host_route.location.len();
let path_without_slash = host_route.location.chars().collect::<Vec<_>>()[0..len - 1]
.iter()
.collect::<String>();
// then register path without slash /doc/
router = router.route(&path_without_slash, get(serve::serve));
debug!("registed route {}", path_without_slash);
host_route.location.clone()
} else if path_morethan_one {
// first register path without slash /doc
router = router.route(&host_route.location, get(serve::serve));
debug!("registed route {}", host_route.location);
// then register path with slash /doc/
let path = format!("{}/", host_route.location);
router = router.route(&path, get(serve::serve));
debug!("registed route {}", path);
path
} else { } else {
// register path /doc/ // static file
router = router.route(&host_route.location, get(serve::serve)); if host_route.root.is_none() {
debug!("registed route {}", host_route.location); warn!("root field not found for route: {:?}", host_route.location);
host_route.location.clone() continue;
}; }
// save route path to map // resister with location
{ // location = "/doc"
host_to_save // route: GET /doc/*
.route_map // resister with file path
.insert(route_path.clone(), host_route.clone()); // index = ["index.html", "index.txt"]
// route: GET /doc/index.html
// route: GET /doc/index.txt
// register parent path /doc
let path_morethan_one = host_route.location.len() > 1;
let route_path = if path_morethan_one && host_route.location.ends_with('/') {
// first register path with slash /doc
router = router.route(&host_route.location, get(serve::serve));
debug!("registed route {}", host_route.location);
let len = host_route.location.len();
let path_without_slash = host_route.location.chars().collect::<Vec<_>>()
[0..len - 1]
.iter()
.collect::<String>();
// then register path without slash /doc/
router = router.route(&path_without_slash, get(serve::serve));
debug!("registed route {}", path_without_slash);
host_route.location.clone()
} else if path_morethan_one {
// first register path without slash /doc
router = router.route(&host_route.location, get(serve::serve));
debug!("registed route {}", host_route.location);
// then register path with slash /doc/
let path = format!("{}/", host_route.location);
router = router.route(&path, get(serve::serve));
debug!("registed route {}", path);
path
} else {
// register path /doc/
router = router.route(&host_route.location, get(serve::serve));
debug!("registed route {}", host_route.location);
host_route.location.clone()
};
// save route path to map
{
host_to_save
.route_map
.insert(route_path.clone(), host_route.clone());
}
let route_path = format!("{route_path}{{*path}}");
// register wildcard path /doc/*
router = router.route(route_path.as_ref(), get(serve::serve));
debug!("registed route: {}", route_path);
} }
let route_path = format!("{route_path}{{*path}}");
// register wildcard path /doc/*
router = router.route(route_path.as_ref(), get(serve::serve));
debug!("registed route: {}", route_path);
} }
// save host to map // save host to map
@ -126,6 +135,7 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
.as_ref() .as_ref()
.ok_or(anyhow!("certificate_key not found"))?; .ok_or(anyhow!("certificate_key not found"))?;
debug!("certificate {} certificate_key {}", cert, key); debug!("certificate {} certificate_key {}", cert, key);
let rustls_config = RustlsConfig::from_pem_file(cert, key).await?; let rustls_config = RustlsConfig::from_pem_file(cert, key).await?;
let addr: SocketAddr = addr.parse()?; let addr: SocketAddr = addr.parse()?;
info!("listening on https://{}", addr); info!("listening on https://{}", addr);

92
src/http/reverse_proxy.rs Normal file
View File

@ -0,0 +1,92 @@
use axum::{
body::Body,
extract::{Path, Request},
response::{IntoResponse, Response},
};
use axum_extra::extract::Host;
use http::Uri;
use reqwest::Client;
use crate::utils::parse_port_from_host;
use super::{
HOSTS,
error::{RouteError, RouteResult},
};
#[axum::debug_handler]
pub async fn serve(
req_uri: Uri,
path: Option<Path<String>>,
Host(host): Host,
mut req: Request,
) -> RouteResult<impl IntoResponse> {
let req_path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(req_path);
// Resolve the parent path:
// - If `path` is provided, extract the parent segment from the URI.
// - If `path` is None, use the URI path directly (ensuring it ends with '/').
/// Resolves the parent path from the URI and optional path segment.
fn resolve_parent_path(uri: &Uri, path: Option<&Path<String>>) -> String {
match path {
Some(path) => {
let uri_path = uri.path();
// use path sub to this uri path
// to find parent path that store in ROUTE_MAP
// uri: /assets/css/styles.07713cb6.css, path: Some(Path("assets/css/styles.07713cb6.css")
let parent_path = uri_path.get(0..uri_path.len() - path.len());
parent_path.unwrap_or("/").to_string()
}
None => {
// uri need end with /
// because global ROUTE_MAP key is end with /
// so we need add / to uri path to get correct Route
let uri_path = uri.path().to_string();
if uri_path.ends_with('/') {
uri_path
} else {
format!("{uri_path}/")
}
}
}
}
let scheme = req.uri().scheme_str().unwrap_or("http");
let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?;
let route_map = &HOSTS.get(&port).ok_or(RouteError::BadRequest())?.route_map;
tracing::debug!("Route map entries: {:?}", route_map);
let parent_path = resolve_parent_path(&req_uri, path.as_ref());
tracing::debug!("parent path: {:?}", parent_path);
let proxy_pass = route_map
.get(&parent_path)
.ok_or(RouteError::RouteNotFound())?;
tracing::debug!("proxy pass: {:?}", proxy_pass);
let Some(ref proxy_pass) = proxy_pass.proxy_pass else {
// return custom_not_found!(host_route, request).await;
return Err(RouteError::RouteNotFound());
};
let uri = format!("{proxy_pass}{path_query}");
tracing::debug!("reverse proxy uri: {:?}", &uri);
*req.uri_mut() = Uri::try_from(uri.clone()).map_err(|_| RouteError::InternalError())?;
let client = Client::new();
let reqwest_response = client.get(uri).send().await.map_err(|e| {
tracing::error!("Failed to proxy request: {}", e);
RouteError::BadRequest()
})?;
let mut response_builder = Response::builder().status(reqwest_response.status());
*response_builder.headers_mut().unwrap() = reqwest_response.headers().clone();
let res = response_builder
.body(Body::from_stream(reqwest_response.bytes_stream()))
// This unwrap is fine because the body is empty here
.unwrap();
Ok(res)
}

View File

@ -1,3 +1,5 @@
use std::process::exit;
use tokio::signal; use tokio::signal;
use tracing::{debug, info}; use tracing::{debug, info};
@ -58,18 +60,18 @@ where
tokio::select! { tokio::select! {
_ = ctrl_c => { _ = ctrl_c => {
shutdown_cb()
}, },
_ = terminate => { _ = terminate => {
shutdown_cb()
}, },
} }
tracing::info!("Received termination signal shutting down"); tracing::info!("Received termination signal shutting down");
shutdown_cb()
} }
pub fn shutdown() { pub fn shutdown() {
info!("Server shuting down") info!("Server shuting down");
exit(0);
} }
/// Parse port from host /// Parse port from host