feat: impl lua context

This commit is contained in:
xfy
2025-07-02 00:55:52 +08:00
parent 091f0c8eb5
commit 14196c937c
2 changed files with 73 additions and 15 deletions

View File

@ -2,11 +2,11 @@ use anyhow::Context;
use axum::{ use axum::{
body::Body, body::Body,
extract::{Path, Request}, extract::{Path, Request},
response::IntoResponse, response::{IntoResponse, Response},
}; };
use axum_extra::extract::Host; use axum_extra::extract::Host;
use http::Uri; use http::Uri;
use mlua::{UserData, Value}; use mlua::{UserData, UserDataMethods, UserDataRef};
use tokio::fs::{self}; use tokio::fs::{self};
use tracing::error; use tracing::error;
@ -18,9 +18,41 @@ use crate::{
use super::error::RouteResult; use super::error::RouteResult;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct Candy {} struct CandyRequest {
uri: Uri,
}
#[derive(Clone, Debug)]
struct CandyResponse {
status: u16,
body: String,
}
// HTTP 请求上下文,可在 Lua 中使用
#[derive(Clone, Debug)]
struct RequestContext {
req: CandyRequest,
res: CandyResponse,
}
impl UserData for Candy {} impl UserData for RequestContext {
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
// 获取请求路径
methods.add_method("get_path", |_, this, ()| {
Ok(this.req.uri.path().to_string())
});
// 设置响应状态码
methods.add_method_mut("set_status", |_, this, status: u16| {
this.res.status = status;
Ok(())
});
// 设置响应内容
methods.add_method_mut("set_body", |_, this, body: String| {
this.res.body = body;
Ok(())
});
}
}
pub async fn lua( pub async fn lua(
req_uri: Uri, req_uri: Uri,
@ -28,7 +60,7 @@ pub async fn lua(
Host(host): Host, Host(host): Host,
req: Request<Body>, req: Request<Body>,
) -> RouteResult<impl IntoResponse> { ) -> RouteResult<impl IntoResponse> {
let req_path = req.uri().path(); // let req_path = req.uri().path();
let scheme = req.uri().scheme_str().unwrap_or("http"); let scheme = req.uri().scheme_str().unwrap_or("http");
let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?; let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?;
@ -56,11 +88,37 @@ pub async fn lua(
let script = fs::read_to_string(lua_script) let script = fs::read_to_string(lua_script)
.await .await
.with_context(|| format!("Failed to read lua script file: {lua_script}",))?; .with_context(|| format!("Failed to read lua script file: {lua_script}",))?;
let data: Value = lua.load(script).eval_async().await.map_err(|err| { lua.globals()
.set(
"ctx",
RequestContext {
req: CandyRequest { uri: req_uri },
res: CandyResponse {
status: 200,
body: "".to_string(),
},
},
)
.map_err(|err| {
error!("Lua script {lua_script} exec error: {err}");
RouteError::InternalError()
})?;
lua.load(script).exec_async().await.map_err(|err| {
error!("Lua script {lua_script} exec error: {err}"); error!("Lua script {lua_script} exec error: {err}");
RouteError::InternalError() RouteError::InternalError()
})?; })?;
tracing::debug!("Lua script: {data:?}"); // 获取修改后的上下文并返回响应
let ctx: UserDataRef<RequestContext> = lua.globals().get("ctx").map_err(|err| {
error!("Lua script {lua_script} exec error: {err}");
RouteError::InternalError()
})?;
let res = ctx.res.clone();
Ok(()) let mut response = Response::builder();
let body = Body::from(res.body);
response = response.status(res.status);
let response = response
.body(body)
.with_context(|| "Failed to build HTTP response with lua")?;
Ok(response)
} }

View File

@ -4,11 +4,11 @@ use std::{
time::Duration, time::Duration,
}; };
use anyhow::{Context, anyhow}; use anyhow::anyhow;
use axum::{Router, extract::DefaultBodyLimit, middleware, routing::get}; use axum::{Router, extract::DefaultBodyLimit, middleware, routing::get};
use axum_server::{Handle, tls_rustls::RustlsConfig}; use axum_server::{Handle, tls_rustls::RustlsConfig};
use dashmap::DashMap; use dashmap::DashMap;
use mlua::{Lua, Value}; use mlua::Lua;
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_http::{compression::CompressionLayer, timeout::TimeoutLayer}; use tower_http::{compression::CompressionLayer, timeout::TimeoutLayer};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
@ -42,6 +42,8 @@ pub static HOSTS: LazyLock<DashMap<u16, SettingHost>> = LazyLock::new(DashMap::n
pub struct LuaEngine { pub struct LuaEngine {
pub lua: Lua, pub lua: Lua,
/// Lua 共享字典
#[allow(dead_code)]
pub shared_table: Arc<DashMap<String, String>>, pub shared_table: Arc<DashMap<String, String>>,
} }
impl LuaEngine { impl LuaEngine {
@ -59,10 +61,8 @@ impl LuaEngine {
.set( .set(
"set", "set",
lua.create_function(move |_, (key, value): (String, String)| { lua.create_function(move |_, (key, value): (String, String)| {
let t = shared_table_get shared_table_get.insert(key, value.clone());
.insert(key, value.clone()) Ok(())
.ok_or(anyhow!("key not found"))?;
Ok(t.clone())
}) })
.expect("create set function failed"), .expect("create set function failed"),
) )
@ -97,7 +97,7 @@ impl LuaEngine {
// 全局变量 candy // 全局变量 candy
lua.globals() lua.globals()
.set("candy", module) .set("candy", module)
.expect("set candy failed"); .expect("set candy table to lua engine failed");
Self { lua, shared_table } Self { lua, shared_table }
} }