mirror of
https://github.com/DefectingCat/candy
synced 2025-07-15 16:51:34 +00:00
feat(reverse_proxy): add basic proxy
This commit is contained in:
767
Cargo.lock
generated
767
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -23,6 +23,7 @@ axum-extra = { version = "0.10.1", features = ["typed-header"] }
|
||||
axum-server = { version = "0.7.2", features = ["tls-rustls"] }
|
||||
tower = { version = "0.5.2", features = ["full"] }
|
||||
tower-http = { version = "0.6.4", features = ["full"] }
|
||||
reqwest = { version = "0.12.18", features = ["stream"] }
|
||||
# tools
|
||||
thiserror = "2.0.12"
|
||||
anyhow = "1.0.98"
|
||||
|
@ -16,7 +16,10 @@ use crate::{
|
||||
};
|
||||
|
||||
pub mod error;
|
||||
// handle static file
|
||||
pub mod serve;
|
||||
// handle reverse proxy
|
||||
pub mod reverse_proxy;
|
||||
|
||||
/// Host configuration
|
||||
/// use virtual host port as key
|
||||
@ -40,10 +43,14 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
|
||||
for host_route in &host.route {
|
||||
// reverse proxy
|
||||
if host_route.proxy_pass.is_some() {
|
||||
continue;
|
||||
// router = router.route(host_route.location.as_ref(), get(hello));
|
||||
router = router.route(host_route.location.as_ref(), get(reverse_proxy::serve));
|
||||
// save route path to map
|
||||
{
|
||||
host_to_save
|
||||
.route_map
|
||||
.insert(host_route.location.clone(), host_route.clone());
|
||||
}
|
||||
|
||||
} else {
|
||||
// static file
|
||||
if host_route.root.is_none() {
|
||||
warn!("root field not found for route: {:?}", host_route.location);
|
||||
@ -63,7 +70,8 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
|
||||
router = router.route(&host_route.location, get(serve::serve));
|
||||
debug!("registed route {}", host_route.location);
|
||||
let len = host_route.location.len();
|
||||
let path_without_slash = host_route.location.chars().collect::<Vec<_>>()[0..len - 1]
|
||||
let path_without_slash = host_route.location.chars().collect::<Vec<_>>()
|
||||
[0..len - 1]
|
||||
.iter()
|
||||
.collect::<String>();
|
||||
// then register path without slash /doc/
|
||||
@ -96,6 +104,7 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
|
||||
router = router.route(route_path.as_ref(), get(serve::serve));
|
||||
debug!("registed route: {}", route_path);
|
||||
}
|
||||
}
|
||||
|
||||
// save host to map
|
||||
HOSTS.insert(host.port, host_to_save);
|
||||
@ -126,6 +135,7 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
|
||||
.as_ref()
|
||||
.ok_or(anyhow!("certificate_key not found"))?;
|
||||
debug!("certificate {} certificate_key {}", cert, key);
|
||||
|
||||
let rustls_config = RustlsConfig::from_pem_file(cert, key).await?;
|
||||
let addr: SocketAddr = addr.parse()?;
|
||||
info!("listening on https://{}", addr);
|
||||
|
92
src/http/reverse_proxy.rs
Normal file
92
src/http/reverse_proxy.rs
Normal file
@ -0,0 +1,92 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, Request},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use axum_extra::extract::Host;
|
||||
use http::Uri;
|
||||
use reqwest::Client;
|
||||
|
||||
use crate::utils::parse_port_from_host;
|
||||
|
||||
use super::{
|
||||
HOSTS,
|
||||
error::{RouteError, RouteResult},
|
||||
};
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn serve(
|
||||
req_uri: Uri,
|
||||
path: Option<Path<String>>,
|
||||
Host(host): Host,
|
||||
mut req: Request,
|
||||
) -> RouteResult<impl IntoResponse> {
|
||||
let req_path = req.uri().path();
|
||||
let path_query = req
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map(|v| v.as_str())
|
||||
.unwrap_or(req_path);
|
||||
|
||||
// Resolve the parent path:
|
||||
// - If `path` is provided, extract the parent segment from the URI.
|
||||
// - If `path` is None, use the URI path directly (ensuring it ends with '/').
|
||||
/// Resolves the parent path from the URI and optional path segment.
|
||||
fn resolve_parent_path(uri: &Uri, path: Option<&Path<String>>) -> String {
|
||||
match path {
|
||||
Some(path) => {
|
||||
let uri_path = uri.path();
|
||||
// use path sub to this uri path
|
||||
// to find parent path that store in ROUTE_MAP
|
||||
// uri: /assets/css/styles.07713cb6.css, path: Some(Path("assets/css/styles.07713cb6.css")
|
||||
let parent_path = uri_path.get(0..uri_path.len() - path.len());
|
||||
parent_path.unwrap_or("/").to_string()
|
||||
}
|
||||
None => {
|
||||
// uri need end with /
|
||||
// because global ROUTE_MAP key is end with /
|
||||
// so we need add / to uri path to get correct Route
|
||||
let uri_path = uri.path().to_string();
|
||||
if uri_path.ends_with('/') {
|
||||
uri_path
|
||||
} else {
|
||||
format!("{uri_path}/")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let scheme = req.uri().scheme_str().unwrap_or("http");
|
||||
let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?;
|
||||
let route_map = &HOSTS.get(&port).ok_or(RouteError::BadRequest())?.route_map;
|
||||
tracing::debug!("Route map entries: {:?}", route_map);
|
||||
|
||||
let parent_path = resolve_parent_path(&req_uri, path.as_ref());
|
||||
tracing::debug!("parent path: {:?}", parent_path);
|
||||
let proxy_pass = route_map
|
||||
.get(&parent_path)
|
||||
.ok_or(RouteError::RouteNotFound())?;
|
||||
tracing::debug!("proxy pass: {:?}", proxy_pass);
|
||||
let Some(ref proxy_pass) = proxy_pass.proxy_pass else {
|
||||
// return custom_not_found!(host_route, request).await;
|
||||
return Err(RouteError::RouteNotFound());
|
||||
};
|
||||
let uri = format!("{proxy_pass}{path_query}");
|
||||
tracing::debug!("reverse proxy uri: {:?}", &uri);
|
||||
*req.uri_mut() = Uri::try_from(uri.clone()).map_err(|_| RouteError::InternalError())?;
|
||||
|
||||
let client = Client::new();
|
||||
let reqwest_response = client.get(uri).send().await.map_err(|e| {
|
||||
tracing::error!("Failed to proxy request: {}", e);
|
||||
RouteError::BadRequest()
|
||||
})?;
|
||||
|
||||
let mut response_builder = Response::builder().status(reqwest_response.status());
|
||||
*response_builder.headers_mut().unwrap() = reqwest_response.headers().clone();
|
||||
let res = response_builder
|
||||
.body(Body::from_stream(reqwest_response.bytes_stream()))
|
||||
// This unwrap is fine because the body is empty here
|
||||
.unwrap();
|
||||
|
||||
Ok(res)
|
||||
}
|
@ -1,3 +1,5 @@
|
||||
use std::process::exit;
|
||||
|
||||
use tokio::signal;
|
||||
use tracing::{debug, info};
|
||||
|
||||
@ -58,18 +60,18 @@ where
|
||||
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {
|
||||
shutdown_cb()
|
||||
},
|
||||
_ = terminate => {
|
||||
shutdown_cb()
|
||||
},
|
||||
}
|
||||
|
||||
tracing::info!("Received termination signal shutting down");
|
||||
shutdown_cb()
|
||||
}
|
||||
|
||||
pub fn shutdown() {
|
||||
info!("Server shuting down")
|
||||
info!("Server shuting down");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
/// Parse port from host
|
||||
|
Reference in New Issue
Block a user