From 14196c937cb9443971f394ae199043bb9fce54c5 Mon Sep 17 00:00:00 2001 From: xfy Date: Wed, 2 Jul 2025 00:55:52 +0800 Subject: [PATCH] feat: impl lua context --- src/http/lua.rs | 74 +++++++++++++++++++++++++++++++++++++++++++------ src/http/mod.rs | 14 +++++----- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/src/http/lua.rs b/src/http/lua.rs index a96b97a..c35991b 100644 --- a/src/http/lua.rs +++ b/src/http/lua.rs @@ -2,11 +2,11 @@ use anyhow::Context; use axum::{ body::Body, extract::{Path, Request}, - response::IntoResponse, + response::{IntoResponse, Response}, }; use axum_extra::extract::Host; use http::Uri; -use mlua::{UserData, Value}; +use mlua::{UserData, UserDataMethods, UserDataRef}; use tokio::fs::{self}; use tracing::error; @@ -18,9 +18,41 @@ use crate::{ use super::error::RouteResult; #[derive(Clone, Debug)] -struct Candy {} +struct CandyRequest { + uri: Uri, +} +#[derive(Clone, Debug)] +struct CandyResponse { + status: u16, + body: String, +} +// HTTP 请求上下文,可在 Lua 中使用 +#[derive(Clone, Debug)] +struct RequestContext { + req: CandyRequest, + res: CandyResponse, +} -impl UserData for Candy {} +impl UserData for RequestContext { + fn add_methods>(methods: &mut M) { + // 获取请求路径 + methods.add_method("get_path", |_, this, ()| { + Ok(this.req.uri.path().to_string()) + }); + + // 设置响应状态码 + methods.add_method_mut("set_status", |_, this, status: u16| { + this.res.status = status; + Ok(()) + }); + + // 设置响应内容 + methods.add_method_mut("set_body", |_, this, body: String| { + this.res.body = body; + Ok(()) + }); + } +} pub async fn lua( req_uri: Uri, @@ -28,7 +60,7 @@ pub async fn lua( Host(host): Host, req: Request, ) -> RouteResult { - let req_path = req.uri().path(); + // let req_path = req.uri().path(); let scheme = req.uri().scheme_str().unwrap_or("http"); let port = parse_port_from_host(&host, scheme).ok_or(RouteError::BadRequest())?; @@ -56,11 +88,37 @@ pub async fn lua( let script = fs::read_to_string(lua_script) .await .with_context(|| format!("Failed to read lua script file: {lua_script}",))?; - let data: Value = lua.load(script).eval_async().await.map_err(|err| { + lua.globals() + .set( + "ctx", + RequestContext { + req: CandyRequest { uri: req_uri }, + res: CandyResponse { + status: 200, + body: "".to_string(), + }, + }, + ) + .map_err(|err| { + error!("Lua script {lua_script} exec error: {err}"); + RouteError::InternalError() + })?; + lua.load(script).exec_async().await.map_err(|err| { error!("Lua script {lua_script} exec error: {err}"); RouteError::InternalError() })?; - tracing::debug!("Lua script: {data:?}"); + // 获取修改后的上下文并返回响应 + let ctx: UserDataRef = lua.globals().get("ctx").map_err(|err| { + error!("Lua script {lua_script} exec error: {err}"); + RouteError::InternalError() + })?; + let res = ctx.res.clone(); - Ok(()) + let mut response = Response::builder(); + let body = Body::from(res.body); + response = response.status(res.status); + let response = response + .body(body) + .with_context(|| "Failed to build HTTP response with lua")?; + Ok(response) } diff --git a/src/http/mod.rs b/src/http/mod.rs index bf0fae8..280ec6d 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -4,11 +4,11 @@ use std::{ time::Duration, }; -use anyhow::{Context, anyhow}; +use anyhow::anyhow; use axum::{Router, extract::DefaultBodyLimit, middleware, routing::get}; use axum_server::{Handle, tls_rustls::RustlsConfig}; use dashmap::DashMap; -use mlua::{Lua, Value}; +use mlua::Lua; use tower::ServiceBuilder; use tower_http::{compression::CompressionLayer, timeout::TimeoutLayer}; use tracing::{debug, info, warn}; @@ -42,6 +42,8 @@ pub static HOSTS: LazyLock> = LazyLock::new(DashMap::n pub struct LuaEngine { pub lua: Lua, + /// Lua 共享字典 + #[allow(dead_code)] pub shared_table: Arc>, } impl LuaEngine { @@ -59,10 +61,8 @@ impl LuaEngine { .set( "set", lua.create_function(move |_, (key, value): (String, String)| { - let t = shared_table_get - .insert(key, value.clone()) - .ok_or(anyhow!("key not found"))?; - Ok(t.clone()) + shared_table_get.insert(key, value.clone()); + Ok(()) }) .expect("create set function failed"), ) @@ -97,7 +97,7 @@ impl LuaEngine { // 全局变量 candy lua.globals() .set("candy", module) - .expect("set candy failed"); + .expect("set candy table to lua engine failed"); Self { lua, shared_table } }