refactor(ssl): use axum_server to handle tls

This commit is contained in:
xfy
2025-06-01 20:11:48 +08:00
parent 8d35c26d90
commit bcf088c347
7 changed files with 73 additions and 120 deletions

52
Cargo.lock generated
View File

@ -97,6 +97,12 @@ version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]] [[package]]
name = "async-compression" name = "async-compression"
version = "0.4.22" version = "0.4.22"
@ -237,6 +243,28 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "axum-server"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "495c05f60d6df0093e8fb6e74aa5846a0ad06abaf96d76166283720bf740f8ab"
dependencies = [
"arc-swap",
"bytes",
"fs-err",
"http",
"http-body",
"hyper",
"hyper-util",
"pin-project-lite",
"rustls",
"rustls-pemfile",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.74" version = "0.3.74"
@ -342,6 +370,7 @@ dependencies = [
"anyhow", "anyhow",
"axum", "axum",
"axum-extra", "axum-extra",
"axum-server",
"bytes", "bytes",
"clap", "clap",
"const_format", "const_format",
@ -599,6 +628,16 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "fs-err"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f89bda4c2a21204059a977ed3bfe746677dfd137b83c339e702b0ac91d482aa"
dependencies = [
"autocfg",
"tokio",
]
[[package]] [[package]]
name = "fs_extra" name = "fs_extra"
version = "1.3.0" version = "1.3.0"
@ -1321,6 +1360,15 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "rustls-pemfile"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50"
dependencies = [
"rustls-pki-types",
]
[[package]] [[package]]
name = "rustls-pki-types" name = "rustls-pki-types"
version = "1.11.0" version = "1.11.0"
@ -1565,9 +1613,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.44.2" version = "1.45.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",

View File

@ -9,7 +9,7 @@ edition = "2024"
[dependencies] [dependencies]
# core # core
tokio = { version = "1.44.2", features = ["full"] } tokio = { version = "1.45.1", features = ["full"] }
tokio-util = "0.7.15" tokio-util = "0.7.15"
tokio-rustls = "0.26.2" tokio-rustls = "0.26.2"
hyper = { version = "1.6.0", features = ["full"] } hyper = { version = "1.6.0", features = ["full"] }
@ -20,6 +20,7 @@ futures-util = "0.3.31"
mimalloc = "0.1.46" mimalloc = "0.1.46"
axum = { version = "0.8.4", features = ["macros"] } axum = { version = "0.8.4", features = ["macros"] }
axum-extra = { version = "0.10.1", features = ["typed-header"] } axum-extra = { version = "0.10.1", features = ["typed-header"] }
axum-server = { version = "0.7.2", features = ["tls-rustls"] }
tower = { version = "0.5.2", features = ["full"] } tower = { version = "0.5.2", features = ["full"] }
tower-http = { version = "0.6.4", features = ["full"] } tower-http = { version = "0.6.4", features = ["full"] }
# tools # tools

View File

@ -14,7 +14,7 @@ dev:
CANDY_LOG=debug $(CARGO) watch -x run CANDY_LOG=debug $(CARGO) watch -x run
run: run:
$(CARGO) run CANDY_LOG=debug $(CARGO) run
test: test:
$(CARGO) test $(CARGO) test

View File

@ -1,26 +1,13 @@
use std::{ use std::{net::SocketAddr, sync::LazyLock, time::Duration};
path::Path,
sync::{Arc, LazyLock},
time::Duration,
};
use anyhow::anyhow; use anyhow::anyhow;
use axum::{Router, extract::Request, middleware, routing::get}; use axum::{Router, middleware, routing::get};
use axum_server::tls_rustls::RustlsConfig;
use dashmap::DashMap; use dashmap::DashMap;
use futures_util::pin_mut;
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_rustls::{ use tower::ServiceBuilder;
TlsAcceptor,
rustls::{
ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
},
};
use tower::{Service, ServiceBuilder};
use tower_http::{compression::CompressionLayer, timeout::TimeoutLayer}; use tower_http::{compression::CompressionLayer, timeout::TimeoutLayer};
use tracing::{debug, error, info, warn}; use tracing::{debug, info, warn};
use crate::{ use crate::{
config::SettingHost, config::SettingHost,
@ -124,8 +111,6 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
router = logging_route(router); router = logging_route(router);
let addr = format!("{}:{}", host.ip, host.port); let addr = format!("{}:{}", host.ip, host.port);
let listener = TcpListener::bind(&addr).await?;
info!("listening on {}", addr);
// check ssl eanbled or not // check ssl eanbled or not
// if ssl enabled // if ssl enabled
@ -141,57 +126,15 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
.as_ref() .as_ref()
.ok_or(anyhow!("certificate_key not found"))?; .ok_or(anyhow!("certificate_key not found"))?;
debug!("certificate {} certificate_key {}", cert, key); debug!("certificate {} certificate_key {}", cert, key);
let rustls_config = rustls_server_config(key, cert)?; let rustls_config = RustlsConfig::from_pem_file(cert, key).await?;
let tls_acceptor = TlsAcceptor::from(rustls_config); let addr: SocketAddr = addr.parse()?;
info!("listening on https://{}", addr);
pin_mut!(listener); axum_server::bind_rustls(addr, rustls_config)
loop { .serve(router.into_make_service())
let tower_service = router.clone(); .await?;
let tls_acceptor = tls_acceptor.clone();
// Wait for new tcp connecttion
let (cnx, addr) = match listener.accept().await {
Ok((cnx, addr)) => (cnx, addr),
Err(err) => {
error!("TCP connection accept error: {:?}", err);
continue;
}
};
let tls_handler = async move {
// Wait for tls handshake to happen
let Ok(stream) = tls_acceptor.accept(cnx).await else {
error!("error during tls handshake connection from {}", addr);
return;
};
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
// `TokioIo` converts between them.
let stream = TokioIo::new(stream);
// Hyper also has its own `Service` trait and doesn't use tower. We can use
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
// `tower::Service::call`.
let hyper_service =
hyper::service::service_fn(move |request: Request<Incoming>| {
// We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
// tower's `Service` requires `&mut self`.
//
// We don't need to call `poll_ready` since `Router` is always ready.
tower_service.clone().call(request)
});
let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(stream, hyper_service)
.await;
if let Err(err) = ret {
warn!("error serving connection from {}: {}", addr, err);
}
};
tokio::spawn(tls_handler);
}
} else { } else {
let listener = TcpListener::bind(&addr).await?;
info!("listening on http://{}", addr);
axum::serve(listener, router) axum::serve(listener, router)
.with_graceful_shutdown(shutdown_signal(shutdown)) .with_graceful_shutdown(shutdown_signal(shutdown))
.await?; .await?;
@ -199,43 +142,3 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
/// Creates a Rustls `ServerConfig` for TLS-enabled connections.
///
/// # Arguments
/// - `key`: Path to the PEM-encoded private key file.
/// - `cert`: Path to the PEM-encoded certificate chain file.
///
/// # Returns
/// - `Ok(Arc<ServerConfig>)`: A configured `ServerConfig` with:
/// - No client authentication.
/// - ALPN protocols `h2` and `http/1.1` for HTTP/2 and HTTP/1.1 support.
/// - The provided certificate and private key.
/// - `Err(anyhow::Error)`: If the key/cert files are missing, malformed, or invalid.
///
/// # Errors
/// - Fails if:
/// - The private key or certificate files cannot be read or parsed.
/// - The key/cert pair is incompatible (e.g., mismatched algorithms).
/// - The certificate chain is empty or invalid.
///
/// # Example
/// ```rust
/// let config = rustls_server_config("key.pem", "cert.pem")?;
fn rustls_server_config(
key: impl AsRef<Path>,
cert: impl AsRef<Path>,
) -> anyhow::Result<Arc<ServerConfig>> {
let key = PrivateKeyDer::from_pem_file(key)?;
let certs = CertificateDer::pem_file_iter(cert)?.try_collect()?;
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.expect("bad certificate/key");
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Ok(Arc::new(config))
}

View File

@ -144,7 +144,6 @@ pub async fn serve(
// check static file root configuration // check static file root configuration
// if root is None, then return InternalError // if root is None, then return InternalError
let Some(ref root) = host_route.root else { let Some(ref root) = host_route.root else {
// return Err(RouteError::InternalError());
return custom_not_found!(host_route, request).await; return custom_not_found!(host_route, request).await;
}; };
// try find index file first // try find index file first

View File

@ -45,10 +45,11 @@ async fn main() -> Result<()> {
info!("server started"); info!("server started");
while let Some(res) = servers.join_next().await { while let Some(res) = servers.join_next().await {
if let Err(err) = res { error!("server error: {:?}", res);
error!("server error: {}", err); // if let Err(err) = res {
continue; // error!("server error: {}", err);
} // continue;
// }
} }
Ok(()) Ok(())

View File

@ -59,12 +59,13 @@ where
tokio::select! { tokio::select! {
_ = ctrl_c => { _ = ctrl_c => {
shutdown_cb() shutdown_cb()
// let _ = stop_core().map_err(log_err);
}, },
_ = terminate => { _ = terminate => {
shutdown_cb() shutdown_cb()
}, },
} }
tracing::info!("Received termination signal shutting down");
} }
pub fn shutdown() { pub fn shutdown() {