mirror of
https://github.com/DefectingCat/candy
synced 2025-07-15 08:41:35 +00:00
Compare commits
7 Commits
f0da74a06e
...
25ff344474
Author | SHA1 | Date | |
---|---|---|---|
25ff344474 | |||
13c4570e20 | |||
e92b191360 | |||
14196c937c | |||
091f0c8eb5 | |||
f96e440264 | |||
36cd153fab |
@ -5,7 +5,7 @@ use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use const_format::{concatcp, formatcp};
|
||||
use const_format::formatcp;
|
||||
use serde_repr::*;
|
||||
use tracing::error;
|
||||
|
||||
@ -36,7 +36,7 @@ pub enum ErrorCode {
|
||||
}
|
||||
|
||||
/// Normal error message
|
||||
const SERVER_ERROR_STR: &str = concatcp!(
|
||||
const SERVER_ERROR_STR: &str = formatcp!(
|
||||
r#"Internal Server Error
|
||||
{NAME} v{VERSION}
|
||||
Powered by RUA
|
||||
|
144
src/http/lua.rs
144
src/http/lua.rs
@ -1,27 +1,149 @@
|
||||
use std::str::FromStr;
|
||||
|
||||
use anyhow::{Context, anyhow};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, Request},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use axum_extra::extract::Host;
|
||||
use http::{HeaderMap, HeaderName, HeaderValue, Uri};
|
||||
use mlua::{UserData, UserDataMethods, UserDataRef};
|
||||
use tokio::fs::{self};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
http::{HOSTS, LUA_ENGINE, error::RouteError, serve::resolve_parent_path},
|
||||
utils::parse_port_from_host,
|
||||
};
|
||||
|
||||
use super::error::RouteResult;
|
||||
|
||||
/// 为 Lua 脚本提供 HTTP 请求上下文
|
||||
#[derive(Clone, Debug)]
|
||||
struct CandyRequest {
|
||||
#[allow(dead_code)]
|
||||
method: String,
|
||||
/// Uri 在路由中被添加到上下文中
|
||||
uri: Uri,
|
||||
}
|
||||
/// 为 Lua 脚本提供 HTTP 响应上下文
|
||||
#[derive(Clone, Debug)]
|
||||
struct CandyResponse {
|
||||
status: u16,
|
||||
headers: HeaderMap,
|
||||
body: String,
|
||||
}
|
||||
// HTTP 请求上下文,可在 Lua 中使用
|
||||
#[derive(Clone, Debug)]
|
||||
struct RequestContext {
|
||||
req: CandyRequest,
|
||||
res: CandyResponse,
|
||||
}
|
||||
|
||||
impl UserData for RequestContext {
|
||||
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
|
||||
// 获取请求路径
|
||||
methods.add_method("get_path", |_, this, ()| {
|
||||
Ok(this.req.uri.path().to_string())
|
||||
});
|
||||
|
||||
// 获取请求方法
|
||||
methods.add_method("get_method", |_, this, ()| Ok(this.req.method.to_string()));
|
||||
|
||||
// 设置响应状态码
|
||||
methods.add_method_mut("set_status", |_, this, status: u16| {
|
||||
this.res.status = status;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
// 设置响应内容
|
||||
methods.add_method_mut("set_body", |_, this, body: String| {
|
||||
this.res.body = format!("{}{}", this.res.body, body);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
// 设置响应头
|
||||
methods.add_method_mut("set_header", |_, this, (key, value): (String, String)| {
|
||||
this.res.headers.insert(
|
||||
HeaderName::from_str(&key).map_err(|err| anyhow!("header name error: {err}"))?,
|
||||
HeaderValue::from_str(&value)
|
||||
.map_err(|err| anyhow!("header value error: {err}"))?,
|
||||
);
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn lua(
|
||||
req_uri: Uri,
|
||||
path: Option<Path<String>>,
|
||||
Host(host): Host,
|
||||
mut req: Request<Body>,
|
||||
req: Request<Body>,
|
||||
) -> RouteResult<impl IntoResponse> {
|
||||
let req_path = req.uri().path();
|
||||
let path_query = req
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map(|v| v.as_str())
|
||||
.unwrap_or(req_path);
|
||||
|
||||
let scheme = req.uri().scheme_str().unwrap_or("http");
|
||||
let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?;
|
||||
let route_map = &HOSTS.get(&port).ok_or(RouteError::BadRequest())?.route_map;
|
||||
let route_map = &HOSTS
|
||||
.get(&port)
|
||||
.ok_or(RouteError::BadRequest())
|
||||
.with_context(|| {
|
||||
format!("Hosts not found for port: {port}, host: {host}, scheme: {scheme}")
|
||||
})?
|
||||
.route_map;
|
||||
tracing::debug!("Route map entries: {:?}", route_map);
|
||||
|
||||
let parent_path = resolve_parent_path(&req_uri, path.as_ref());
|
||||
let route_config = route_map
|
||||
.get(&parent_path)
|
||||
.ok_or(RouteError::RouteNotFound())?;
|
||||
.ok_or(RouteError::RouteNotFound())
|
||||
.with_context(|| format!("route not found: {parent_path}"))?;
|
||||
let lua_script = route_config
|
||||
.lua_script
|
||||
.as_ref()
|
||||
.ok_or(RouteError::InternalError())?;
|
||||
.ok_or(RouteError::InternalError())
|
||||
.with_context(|| "lua script not found")?;
|
||||
|
||||
let method = req.method().to_string();
|
||||
|
||||
let lua = &LUA_ENGINE.lua;
|
||||
let script = fs::read_to_string(lua_script)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read lua script file: {lua_script}",))?;
|
||||
lua.globals()
|
||||
.set(
|
||||
"ctx",
|
||||
RequestContext {
|
||||
req: CandyRequest {
|
||||
method,
|
||||
uri: req_uri,
|
||||
},
|
||||
res: CandyResponse {
|
||||
status: 200,
|
||||
headers: HeaderMap::new(),
|
||||
body: "".to_string(),
|
||||
},
|
||||
},
|
||||
)
|
||||
.map_err(|err| {
|
||||
error!("Lua script {lua_script} exec error: {err}");
|
||||
RouteError::InternalError()
|
||||
})?;
|
||||
lua.load(script).exec_async().await.map_err(|err| {
|
||||
error!("Lua script {lua_script} exec error: {err}");
|
||||
RouteError::InternalError()
|
||||
})?;
|
||||
// 获取修改后的上下文并返回响应
|
||||
let ctx: UserDataRef<RequestContext> = lua.globals().get("ctx").map_err(|err| {
|
||||
error!("Lua script {lua_script} exec error: {err}");
|
||||
RouteError::InternalError()
|
||||
})?;
|
||||
let res = ctx.res.clone();
|
||||
|
||||
let mut response = Response::builder();
|
||||
let body = Body::from(res.body);
|
||||
response = response.status(res.status);
|
||||
let response = response
|
||||
.body(body)
|
||||
.with_context(|| "Failed to build HTTP response with lua")?;
|
||||
Ok(response)
|
||||
}
|
||||
|
@ -1,4 +1,8 @@
|
||||
use std::{net::SocketAddr, sync::LazyLock, time::Duration};
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{Arc, LazyLock},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use axum::{Router, extract::DefaultBodyLimit, middleware, routing::get};
|
||||
@ -11,6 +15,7 @@ use tracing::{debug, info, warn};
|
||||
|
||||
use crate::{
|
||||
config::SettingHost,
|
||||
consts::{ARCH, COMMIT, COMPILER, NAME, OS, VERSION},
|
||||
middlewares::{add_headers, add_version, logging_route},
|
||||
utils::graceful_shutdown,
|
||||
};
|
||||
@ -20,6 +25,8 @@ pub mod error;
|
||||
pub mod serve;
|
||||
// handle reverse proxy
|
||||
pub mod reverse_proxy;
|
||||
// handle lua script
|
||||
pub mod lua;
|
||||
|
||||
/// Host configuration
|
||||
/// use virtual host port as key
|
||||
@ -34,8 +41,79 @@ pub mod reverse_proxy;
|
||||
/// }
|
||||
pub static HOSTS: LazyLock<DashMap<u16, SettingHost>> = LazyLock::new(DashMap::new);
|
||||
|
||||
pub struct LuaEngine {
|
||||
pub lua: Lua,
|
||||
/// Lua 共享字典
|
||||
#[allow(dead_code)]
|
||||
pub shared_table: Arc<DashMap<String, String>>,
|
||||
}
|
||||
impl LuaEngine {
|
||||
pub fn new() -> Self {
|
||||
let lua = Lua::new();
|
||||
let shared_table: DashMap<String, String> = DashMap::new();
|
||||
let shared_table = Arc::new(shared_table);
|
||||
|
||||
let module = lua.create_table().expect("create table failed");
|
||||
let shared_api = lua.create_table().expect("create shared table failed");
|
||||
|
||||
// 创建共享字典到 lua 中
|
||||
let shared_table_get = shared_table.clone();
|
||||
shared_api
|
||||
.set(
|
||||
"set",
|
||||
lua.create_function(move |_, (key, value): (String, String)| {
|
||||
shared_table_get.insert(key, value.clone());
|
||||
Ok(())
|
||||
})
|
||||
.expect("create set function failed"),
|
||||
)
|
||||
.expect("set failed");
|
||||
let shared_table_get = shared_table.clone();
|
||||
shared_api
|
||||
.set(
|
||||
"get",
|
||||
lua.create_function(move |_, key: String| {
|
||||
let t = shared_table_get.get(&key).ok_or(anyhow!("key not found"))?;
|
||||
Ok(t.clone())
|
||||
})
|
||||
.expect("create get function failed"),
|
||||
)
|
||||
.expect("get failed");
|
||||
module
|
||||
.set("shared", shared_api)
|
||||
.expect("set shared_api failed");
|
||||
|
||||
// 日志函数
|
||||
module
|
||||
.set(
|
||||
"log",
|
||||
lua.create_function(move |_, msg: String| {
|
||||
tracing::info!("Lua: {}", msg);
|
||||
Ok(())
|
||||
})
|
||||
.expect("create log function failed"),
|
||||
)
|
||||
.expect("set log failed");
|
||||
|
||||
module.set("version", VERSION).expect("set version failed");
|
||||
module.set("name", NAME).expect("set name failed");
|
||||
module.set("os", OS).expect("set os failed");
|
||||
module.set("arch", ARCH).expect("set arch failed");
|
||||
module
|
||||
.set("compiler", COMPILER)
|
||||
.expect("set compiler failed");
|
||||
module.set("commit", COMMIT).expect("set commit failed");
|
||||
|
||||
// 全局变量 candy
|
||||
lua.globals()
|
||||
.set("candy", module)
|
||||
.expect("set candy table to lua engine failed");
|
||||
|
||||
Self { lua, shared_table }
|
||||
}
|
||||
}
|
||||
/// lua 脚本执行器
|
||||
pub static LUA_EXECUTOR: LazyLock<Lua> = LazyLock::new(Lua::new);
|
||||
pub static LUA_ENGINE: LazyLock<LuaEngine> = LazyLock::new(LuaEngine::new);
|
||||
|
||||
pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
|
||||
let mut router = Router::new();
|
||||
@ -45,8 +123,17 @@ pub async fn make_server(host: SettingHost) -> anyhow::Result<()> {
|
||||
// register routes
|
||||
for host_route in &host.route {
|
||||
// lua script
|
||||
if let Some(lua_path) = &host_route.lua_script {
|
||||
if host_route.lua_script.is_some() {
|
||||
// papare lua script
|
||||
router = router.route(host_route.location.as_ref(), get(lua::lua));
|
||||
let route_path = format!("{}{{*path}}", host_route.location);
|
||||
router = router.route(route_path.as_ref(), get(lua::lua));
|
||||
// save route path to map
|
||||
{
|
||||
host_to_save
|
||||
.route_map
|
||||
.insert(host_route.location.clone(), host_route.clone());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -157,21 +157,30 @@ pub async fn serve(
|
||||
|
||||
let scheme = req.uri().scheme_str().unwrap_or("http");
|
||||
let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?;
|
||||
let route_map = &HOSTS.get(&port).ok_or(RouteError::BadRequest())?.route_map;
|
||||
let route_map = &HOSTS
|
||||
.get(&port)
|
||||
.ok_or(RouteError::BadRequest())
|
||||
.with_context(|| {
|
||||
format!("Hosts not found for port: {port}, host: {host}, scheme: {scheme}")
|
||||
})?
|
||||
.route_map;
|
||||
tracing::debug!("Route map entries: {:?}", route_map);
|
||||
|
||||
let parent_path = resolve_parent_path(&req_uri, path.as_ref());
|
||||
tracing::debug!("parent path: {:?}", parent_path);
|
||||
let proxy_config = route_map
|
||||
.get(&parent_path)
|
||||
.ok_or(RouteError::RouteNotFound())?;
|
||||
.ok_or(RouteError::RouteNotFound())
|
||||
.with_context(|| format!("route not found: {parent_path}"))?;
|
||||
tracing::debug!("proxy pass: {:?}", proxy_config);
|
||||
let Some(ref proxy_pass) = proxy_config.proxy_pass else {
|
||||
return handle_custom_page(proxy_config, req, true).await;
|
||||
};
|
||||
let uri = format!("{proxy_pass}{path_query}");
|
||||
tracing::debug!("reverse proxy uri: {:?}", &uri);
|
||||
*req.uri_mut() = Uri::try_from(uri.clone()).map_err(|_| RouteError::InternalError())?;
|
||||
*req.uri_mut() = Uri::try_from(uri.clone())
|
||||
.map_err(|_| RouteError::InternalError())
|
||||
.with_context(|| format!("uri not found: {uri}"))?;
|
||||
|
||||
let timeout = proxy_config.proxy_timeout;
|
||||
|
||||
@ -210,7 +219,8 @@ pub async fn serve(
|
||||
reqwest_response.headers(),
|
||||
response_builder
|
||||
.headers_mut()
|
||||
.ok_or(RouteError::InternalError())?,
|
||||
.ok_or(RouteError::InternalError())
|
||||
.with_context(|| "headers not found")?,
|
||||
);
|
||||
let res = response_builder
|
||||
.body(Body::from_stream(reqwest_response.bytes_stream()))
|
||||
|
@ -52,23 +52,27 @@ async fn custom_page(
|
||||
host_route
|
||||
.error_page
|
||||
.as_ref()
|
||||
.ok_or(RouteError::RouteNotFound())?
|
||||
.ok_or(RouteError::RouteNotFound())
|
||||
.with_context(|| "error page not found")?
|
||||
} else {
|
||||
host_route
|
||||
.not_found_page
|
||||
.as_ref()
|
||||
.ok_or(RouteError::RouteNotFound())?
|
||||
.ok_or(RouteError::RouteNotFound())
|
||||
.with_context(|| "not found page not found")?
|
||||
};
|
||||
|
||||
let root = host_route
|
||||
.root
|
||||
.as_ref()
|
||||
.ok_or(RouteError::InternalError())?;
|
||||
.ok_or(RouteError::InternalError())
|
||||
.with_context(|| "root not found")?;
|
||||
|
||||
let path = format!("{}/{}", root, page.page);
|
||||
|
||||
let status = StatusCode::from_str(page.status.to_string().as_ref())
|
||||
.map_err(|_| RouteError::BadRequest())?;
|
||||
.map_err(|_| RouteError::BadRequest())
|
||||
.with_context(|| format!("status code not found: {}", page.status))?;
|
||||
|
||||
tracing::debug!("custom not found path: {:?}", path);
|
||||
|
||||
@ -125,11 +129,18 @@ pub async fn serve(
|
||||
// which is `host_route.location`
|
||||
let scheme = request.uri().scheme_str().unwrap_or("http");
|
||||
let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?;
|
||||
let route_map = &HOSTS.get(&port).ok_or(RouteError::BadRequest())?.route_map;
|
||||
let route_map = &HOSTS
|
||||
.get(&port)
|
||||
.ok_or(RouteError::BadRequest())
|
||||
.with_context(|| {
|
||||
format!("Hosts not found for port: {port}, host: {host}, scheme: {scheme}")
|
||||
})?
|
||||
.route_map;
|
||||
debug!("Route map entries: {:?}", route_map);
|
||||
let host_route = route_map
|
||||
.get(&parent_path)
|
||||
.ok_or(RouteError::RouteNotFound())?;
|
||||
.ok_or(RouteError::RouteNotFound())
|
||||
.with_context(|| format!("route not found: {parent_path}"))?;
|
||||
debug!("route: {:?}", host_route);
|
||||
// after route found
|
||||
// check static file root configuration
|
||||
|
@ -71,7 +71,6 @@ pub async fn add_headers(Host(host): Host, req: Request, next: Next) -> impl Int
|
||||
debug!("port {:?}", port);
|
||||
let mut res = next.run(req).await;
|
||||
let req_headers = res.headers_mut();
|
||||
// let host = HOSTS.read().await;
|
||||
let Some(host) = HOSTS.get(&port) else {
|
||||
return res;
|
||||
};
|
||||
|
Reference in New Issue
Block a user