diff --git a/src/api/auth.rs b/src/api/auth.rs index 44e9166..6e78fac 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -8,7 +8,7 @@ use crate::auth::{password, session}; #[cfg(feature = "server")] use crate::auth::session::get_session_from_ctx; #[cfg(feature = "server")] -use crate::api::utils::{db_conn_error, query_error}; +use crate::api::error::AppError; use crate::db::pool::get_conn; use crate::models::user::{PublicUser, User, UserRole}; @@ -90,12 +90,12 @@ pub async fn register( }); } - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let admin_count: i64 = client .query_one("SELECT COUNT(*) FROM users WHERE role = 'admin'", &[]) .await - .map_err(query_error)? + .map_err(AppError::query)? .get(0); if admin_count > 0 { @@ -106,10 +106,7 @@ pub async fn register( }); } - let password_hash = password::hash_password(&password).map_err(|e| { - tracing::error!("Register password hash failed: {:?}", e); - ServerFnError::new(format!("密码哈希失败: {}", e)) - })?; + let password_hash = password::hash_password(&password).map_err(|_| AppError::Internal("密码处理失败"))?; let result = client .query_one( @@ -156,7 +153,7 @@ pub async fn login(username: String, password: String) -> Result Result { - return Err(query_error(e)); + return Err(AppError::query(e).into()); } }; let password_hash: String = row.get("password_hash"); - let valid = password::verify_password(&password, &password_hash).map_err(|e| { - tracing::error!("Login password verify failed: {:?}", e); - ServerFnError::new(format!("密码验证失败: {}", e)) - })?; + let valid = password::verify_password(&password, &password_hash).map_err(|_| AppError::Internal("密码处理失败"))?; if !valid { return Ok(AuthResponse { @@ -202,10 +196,7 @@ pub async fn login(username: String, password: String) -> Result Result Result { let token = get_session_from_ctx(); - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; - // 清除 cookie if let Some(ctx) = dioxus::fullstack::FullstackContext::current() { ctx.add_response_header( SET_COOKIE, @@ -238,15 +228,11 @@ pub async fn logout() -> Result { ); } - // 删除当前 session if let Some(t) = token { client .execute("DELETE FROM sessions WHERE token = $1", &[&t]) .await - .map_err(|e| { - tracing::error!("Logout session delete failed: {:?}", e); - ServerFnError::new(format!("删除 session 失败: {}", e)) - })?; + .map_err(AppError::query)?; } Ok(AuthResponse { @@ -263,7 +249,7 @@ pub struct CurrentUserResponse { #[cfg(feature = "server")] pub async fn get_user_by_token(token: &str) -> Result, ServerFnError> { - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let row = client .query_opt( @@ -274,7 +260,7 @@ pub async fn get_user_by_token(token: &str) -> Result, ServerFnErro &[&token], ) .await - .map_err(query_error)?; + .map_err(AppError::query)?; let user = match row { Some(row) => { @@ -302,10 +288,7 @@ pub async fn get_current_user() -> Result { None => return Ok(CurrentUserResponse { user: None }), }; - let user = match get_user_by_token(&token).await? { - Some(u) => Some(PublicUser::from(u)), - None => None, - }; + let user = get_user_by_token(&token).await?.map(PublicUser::from); Ok(CurrentUserResponse { user }) } diff --git a/src/api/error.rs b/src/api/error.rs new file mode 100644 index 0000000..1c8ccb2 --- /dev/null +++ b/src/api/error.rs @@ -0,0 +1,88 @@ +use dioxus::prelude::ServerFnError; + +#[derive(Debug)] +pub enum AppError { + Unauthorized(&'static str), + Forbidden(&'static str), + NotFound(&'static str), + DbConn(String), + Query(String), + Transaction(String), + Internal(&'static str), +} + +impl AppError { + pub fn db_conn(e: impl std::fmt::Display) -> Self { + tracing::error!("DB connection failed: {e}"); + AppError::DbConn(e.to_string()) + } + + pub fn query(e: impl std::fmt::Display) -> Self { + tracing::error!("Query failed: {e}"); + AppError::Query(e.to_string()) + } + + pub fn tx(e: impl std::fmt::Display) -> Self { + tracing::error!("Transaction failed: {e}"); + AppError::Transaction(e.to_string()) + } +} + +impl From for ServerFnError { + fn from(err: AppError) -> ServerFnError { + let msg = match &err { + AppError::Unauthorized(m) => m.to_string(), + AppError::Forbidden(m) => m.to_string(), + AppError::NotFound(m) => m.to_string(), + AppError::DbConn(_) => "服务暂时不可用".to_string(), + AppError::Query(_) => "操作失败".to_string(), + AppError::Transaction(_) => "操作失败".to_string(), + AppError::Internal(m) => m.to_string(), + }; + ServerFnError::new(msg) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn unauthorized_message_passthrough() { + let err: ServerFnError = AppError::Unauthorized("未登录").into(); + let msg = err.to_string(); + assert!(msg.contains("未登录"), "expected '未登录' in: {msg}"); + } + + #[test] + fn db_conn_hides_internal_details() { + let err: ServerFnError = AppError::db_conn("connection refused on port 5432").into(); + let msg = err.to_string(); + assert!( + !msg.contains("5432"), + "should not leak internal details: {msg}" + ); + assert!(msg.contains("服务暂时不可用"), "expected generic message: {msg}"); + } + + #[test] + fn query_hides_sql_details() { + let err: ServerFnError = AppError::query("syntax error at SELECT * FROM").into(); + let msg = err.to_string(); + assert!(!msg.contains("SELECT"), "should not leak SQL: {msg}"); + } + + #[test] + fn forbidden_message_passthrough() { + let err: ServerFnError = AppError::Forbidden("权限不足").into(); + let msg = err.to_string(); + assert!(msg.contains("权限不足"), "expected '权限不足': {msg}"); + } + + #[test] + fn not_found_message_passthrough() { + let err: ServerFnError = AppError::NotFound("文章不存在").into(); + let msg = err.to_string(); + assert!(msg.contains("文章不存在"), "expected passthrough: {msg}"); + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs index e918db9..8de4e81 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,8 +1,8 @@ pub mod auth; +pub mod error; pub mod image; pub mod markdown; pub mod posts; pub mod rate_limit; pub mod slug; pub mod upload; -pub mod utils; diff --git a/src/api/posts.rs b/src/api/posts.rs index 4c15856..ffd3e92 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -1,9 +1,9 @@ -#![allow(clippy::unused_unit, deprecated, unused_imports)] +#![allow(clippy::unused_unit, deprecated, unused_imports, clippy::too_many_arguments)] use dioxus::prelude::*; #[cfg(feature = "server")] -use crate::api::utils::{db_conn_error, query_error, tx_error}; +use crate::api::error::AppError; #[cfg(feature = "server")] use crate::auth::session::get_session_from_ctx; use crate::db::pool::get_conn; @@ -26,19 +26,16 @@ pub use crate::api::slug::{ensure_unique_slug, is_valid_slug, slugify}; // ============================================================================ #[cfg(feature = "server")] -async fn get_current_admin_user() -> Result { - let token = match get_session_from_ctx() { - Some(t) => t, - None => return Err(ServerFnError::new("未登录")), - }; +async fn get_current_admin_user() -> Result { + let token = get_session_from_ctx().ok_or(AppError::Unauthorized("未登录"))?; - let user = match crate::api::auth::get_user_by_token(&token).await? { - Some(u) => u, - None => return Err(ServerFnError::new("会话已过期")), - }; + let user = crate::api::auth::get_user_by_token(&token) + .await + .map_err(AppError::query)? + .ok_or(AppError::Unauthorized("会话已过期"))?; if user.role != UserRole::Admin { - return Err(ServerFnError::new("权限不足")); + return Err(AppError::Forbidden("权限不足")); } Ok(user) @@ -249,7 +246,7 @@ pub async fn create_post( _ => crate::api::slug::slugify(&title), }; - let mut client = get_conn().await.map_err(db_conn_error)?; + let mut client = get_conn().await.map_err(AppError::db_conn)?; let final_slug = crate::api::slug::ensure_unique_slug(&client, &base_slug, None).await?; let rendered = crate::api::markdown::render_markdown_enhanced(&content_md); @@ -266,7 +263,7 @@ pub async fn create_post( None }; - let tx = client.transaction().await.map_err(tx_error)?; + let tx = client.transaction().await.map_err(AppError::db_conn)?; let row = tx .query_one( @@ -286,10 +283,7 @@ pub async fn create_post( ], ) .await - .map_err(|e| { - tracing::error!("create post failed: {:?}", e); - ServerFnError::new(format!("创建文章失败: {}", e)) - })?; + .map_err(AppError::tx)?; let post_id: i32 = row.get(0); @@ -308,10 +302,7 @@ pub async fn create_post( &[&tag_name.as_str()], ) .await - .map_err(|e| { - tracing::error!("create tag failed: {:?}", e); - ServerFnError::new(format!("创建标签失败: {}", e)) - })?; + .map_err(AppError::tx)?; match row { Some(r) => r.get(0), @@ -319,13 +310,8 @@ pub async fn create_post( let row = tx .query_opt("SELECT id FROM tags WHERE name = $1", &[&tag_name.as_str()]) .await - .map_err(|e| { - tracing::error!("query tag failed: {:?}", e); - ServerFnError::new(format!("查询标签失败: {}", e)) - })?; - row.map(|r| r.get(0)).ok_or_else(|| { - ServerFnError::new(format!("标签不存在: {}", tag_name)) - })? + .map_err(AppError::query)?; + row.map(|r| r.get(0)).ok_or(AppError::NotFound("标签不存在"))? } } }; @@ -335,14 +321,11 @@ pub async fn create_post( &[&post_id, &tag_id], ) .await - .map_err(|e| { - tracing::error!("link tag failed: {:?}", e); - ServerFnError::new(format!("关联标签失败: {}", e)) - })?; + .map_err(AppError::tx)?; } } - tx.commit().await.map_err(tx_error)?; + tx.commit().await.map_err(AppError::tx)?; // Invalidate caches after successful creation #[cfg(feature = "server")] @@ -377,13 +360,12 @@ pub async fn update_post( ) -> Result { let user = get_current_admin_user().await?; - let mut client = get_conn().await.map_err(db_conn_error)?; + let mut client = get_conn().await.map_err(AppError::db_conn)?; - // Get old slug before updating (for cache invalidation) let old_slug: Option = client .query_opt("SELECT slug FROM posts WHERE id = $1", &[&post_id]) .await - .map_err(query_error)? + .map_err(AppError::query)? .map(|r| r.get(0)); let exists: bool = client @@ -392,7 +374,7 @@ pub async fn update_post( &[&post_id, &user.id], ) .await - .map_err(query_error)? + .map_err(AppError::query)? .is_some(); if !exists { @@ -429,9 +411,8 @@ pub async fn update_post( let post_status = PostStatus::from_str(&status).unwrap_or(PostStatus::Draft); let cover_image = cover_image.filter(|s| !s.trim().is_empty()); - let tx = client.transaction().await.map_err(tx_error)?; + let tx = client.transaction().await.map_err(AppError::db_conn)?; - // Get old tags before deleting them (for cache invalidation) let old_tags: Vec = { let rows = tx .query( @@ -439,7 +420,7 @@ pub async fn update_post( &[&post_id], ) .await - .map_err(query_error)?; + .map_err(AppError::query)?; rows.iter().map(|r| r.get(0)).collect() }; @@ -449,7 +430,7 @@ pub async fn update_post( &[&post_id], ) .await - .map_err(query_error)?; + .map_err(AppError::query)?; let published_at = if post_status == PostStatus::Published { let was_published = old_status_row @@ -487,10 +468,7 @@ pub async fn update_post( ], ) .await - .map_err(|e| { - tracing::error!("update post failed: {:?}", e); - ServerFnError::new(format!("更新文章失败: {}", e)) - })?; + .map_err(AppError::tx)?; let tags_cleaned: Vec = tags .into_iter() @@ -502,10 +480,7 @@ pub async fn update_post( tx.execute("DELETE FROM post_tags WHERE post_id = $1", &[&post_id]) .await - .map_err(|e| { - tracing::error!("delete old tags failed: {:?}", e); - ServerFnError::new(format!("删除旧标签失败: {}", e)) - })?; + .map_err(AppError::tx)?; for tag_name in &tags_cleaned { let tag_id: i32 = { @@ -515,10 +490,7 @@ pub async fn update_post( &[&tag_name.as_str()], ) .await - .map_err(|e| { - tracing::error!("create tag failed: {:?}", e); - ServerFnError::new(format!("创建标签失败: {}", e)) - })?; + .map_err(AppError::tx)?; match row { Some(r) => r.get(0), @@ -526,12 +498,9 @@ pub async fn update_post( let row = tx .query_opt("SELECT id FROM tags WHERE name = $1", &[&tag_name.as_str()]) .await - .map_err(|e| { - tracing::error!("query tag failed: {:?}", e); - ServerFnError::new(format!("查询标签失败: {}", e)) - })?; + .map_err(AppError::query)?; row.map(|r| r.get(0)) - .ok_or_else(|| ServerFnError::new(format!("标签不存在: {}", tag_name)))? + .ok_or(AppError::NotFound("标签不存在"))? } } }; @@ -541,13 +510,10 @@ pub async fn update_post( &[&post_id, &tag_id], ) .await - .map_err(|e| { - tracing::error!("link tag failed: {:?}", e); - ServerFnError::new(format!("关联标签失败: {}", e)) - })?; + .map_err(AppError::tx)?; } - tx.commit().await.map_err(tx_error)?; + tx.commit().await.map_err(AppError::tx)?; // Invalidate caches after successful update #[cfg(feature = "server")] @@ -586,7 +552,7 @@ pub async fn update_post( pub async fn get_post_by_id(post_id: i32) -> Result { let _user = get_current_admin_user().await?; - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let row = client .query_opt( @@ -602,7 +568,7 @@ pub async fn get_post_by_id(post_id: i32) -> Result Some(row_to_post_list(&client, &row).await), @@ -618,7 +584,7 @@ pub async fn get_post_by_slug(slug: String) -> Result Result Some(row_to_post_full(&client, &row).await), @@ -675,7 +641,7 @@ pub async fn list_published_posts( return Ok(PostListResponse { posts: cached }); } - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let offset = ((page - 1).max(0) as i64) * (per_page as i64); let limit = per_page as i64; @@ -695,7 +661,7 @@ pub async fn list_published_posts( &[&limit, &offset], ) .await - .map_err(query_error)?; + .map_err(AppError::query)?; let mut posts = Vec::new(); for row in &rows { @@ -710,7 +676,7 @@ pub async fn list_published_posts( pub async fn list_posts() -> Result { let _user = get_current_admin_user().await?; - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let rows = client .query( @@ -727,7 +693,7 @@ pub async fn list_posts() -> Result { &[], ) .await - .map_err(query_error)?; + .map_err(AppError::query)?; let mut posts = Vec::new(); for row in &rows { @@ -741,7 +707,7 @@ pub async fn list_posts() -> Result { pub async fn delete_post(post_id: i32) -> Result { let _user = get_current_admin_user().await?; - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let result = client .execute( @@ -749,10 +715,7 @@ pub async fn delete_post(post_id: i32) -> Result Result { return Ok(TagListResponse { tags: cached }); } - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let rows = client .query( @@ -796,7 +759,7 @@ pub async fn list_tags() -> Result { &[], ) .await - .map_err(query_error)?; + .map_err(AppError::query)?; let tags: Vec = rows .iter() @@ -817,7 +780,7 @@ pub async fn get_posts_by_tag(tag_name: String) -> Result Result Result { let _user = get_current_admin_user().await?; - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let total: i64 = client .query_one("SELECT COUNT(*) FROM posts WHERE deleted_at IS NULL", &[]) .await - .map_err(query_error)? + .map_err(AppError::query)? .get(0); let drafts: i64 = client @@ -869,7 +832,7 @@ pub async fn get_post_stats() -> Result { &[], ) .await - .map_err(query_error)? + .map_err(AppError::query)? .get(0); let published: i64 = client @@ -878,7 +841,7 @@ pub async fn get_post_stats() -> Result { &[], ) .await - .map_err(query_error)? + .map_err(AppError::query)? .get(0); let stats = PostStats { @@ -892,7 +855,7 @@ pub async fn get_post_stats() -> Result { #[server(SearchPosts, "/api")] pub async fn search_posts(query: String) -> Result { - let client = get_conn().await.map_err(db_conn_error)?; + let client = get_conn().await.map_err(AppError::db_conn)?; let q = query.trim(); if q.is_empty() { @@ -917,7 +880,7 @@ pub async fn search_posts(query: String) -> Result, ) -> Result { - use crate::api::utils::query_error; + use crate::api::error::AppError; let mut candidate = base.to_string(); let mut suffix = 2; @@ -54,7 +54,7 @@ pub async fn ensure_unique_slug( &[&candidate, &exclude], ) .await - .map_err(query_error)? + .map_err(AppError::query)? .is_some() } else { client @@ -63,7 +63,7 @@ pub async fn ensure_unique_slug( &[&candidate], ) .await - .map_err(query_error)? + .map_err(AppError::query)? .is_some() }; @@ -75,7 +75,7 @@ pub async fn ensure_unique_slug( suffix += 1; if candidate.len() > 200 { - return Err(ServerFnError::new("无法生成唯一 slug")); + return Err(AppError::Internal("无法生成唯一 slug").into()); } } } diff --git a/src/api/utils.rs b/src/api/utils.rs deleted file mode 100644 index 93c4036..0000000 --- a/src/api/utils.rs +++ /dev/null @@ -1,19 +0,0 @@ -#![allow(clippy::unused_unit)] - -#[cfg(feature = "server")] -pub fn db_conn_error(e: impl std::fmt::Display) -> dioxus::prelude::ServerFnError { - tracing::error!("DB connection failed: {}", e); - dioxus::prelude::ServerFnError::new(format!("数据库连接失败: {}", e)) -} - -#[cfg(feature = "server")] -pub fn query_error(e: impl std::fmt::Display) -> dioxus::prelude::ServerFnError { - tracing::error!("Query failed: {}", e); - dioxus::prelude::ServerFnError::new(format!("查询失败: {}", e)) -} - -#[cfg(feature = "server")] -pub fn tx_error(e: impl std::fmt::Display) -> dioxus::prelude::ServerFnError { - tracing::error!("Transaction failed: {}", e); - dioxus::prelude::ServerFnError::new(format!("事务失败: {}", e)) -} diff --git a/src/pages/admin/write.rs b/src/pages/admin/write.rs index 628023e..ce0fdeb 100644 --- a/src/pages/admin/write.rs +++ b/src/pages/admin/write.rs @@ -211,6 +211,7 @@ fn write_editor(post_id: Option) -> Element { let on_submit = move |_| { if title().trim().is_empty() { error.set(Some("标题不能为空".to_string())); + #[allow(clippy::needless_return)] return; }