From e74b9f3c390029647e8ce11c6b91e7350e10b877 Mon Sep 17 00:00:00 2001 From: xfy Date: Mon, 8 Jun 2026 17:30:26 +0800 Subject: [PATCH] feat: apply rate limiting to Register, Login, and upload endpoints - Remove unused tags.rs and related re-exports - Convert strict_limit/upload_limit from Layer to manual check functions - Add IP-based rate limiting checks to Register, Login, and upload_image - Keep general_limit as global middleware for all other routes --- src/api/auth.rs | 30 +++++++++++++++ src/api/mod.rs | 1 - src/api/posts.rs | 3 +- src/api/rate_limit.rs | 79 ++++++++++++++++++++++++++++----------- src/api/tags.rs | 87 ------------------------------------------- src/api/upload.rs | 12 ++++++ 6 files changed, 100 insertions(+), 112 deletions(-) delete mode 100644 src/api/tags.rs diff --git a/src/api/auth.rs b/src/api/auth.rs index af68c99..4d63eae 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -53,6 +53,21 @@ pub async fn register( email: String, password: String, ) -> Result { + #[cfg(feature = "server")] + { + if let Some(ctx) = dioxus::fullstack::FullstackContext::current() { + let parts = ctx.parts_mut(); + let ip = crate::api::rate_limit::get_client_ip(&parts.headers); + if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) { + return Ok(AuthResponse { + success: false, + message: msg, + token: None, + }); + } + } + } + if let Err(e) = validate_username(&username) { return Ok(AuthResponse { success: false, @@ -126,6 +141,21 @@ pub async fn register( #[server(Login, "/api")] pub async fn login(username: String, password: String) -> Result { + #[cfg(feature = "server")] + { + if let Some(ctx) = dioxus::fullstack::FullstackContext::current() { + let parts = ctx.parts_mut(); + let ip = crate::api::rate_limit::get_client_ip(&parts.headers); + if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) { + return Ok(AuthResponse { + success: false, + message: msg, + token: None, + }); + } + } + } + let client = get_conn().await.map_err(db_conn_error)?; let row = match client diff --git a/src/api/mod.rs b/src/api/mod.rs index 2e94105..e918db9 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -4,6 +4,5 @@ pub mod markdown; pub mod posts; pub mod rate_limit; pub mod slug; -pub mod tags; pub mod upload; pub mod utils; diff --git a/src/api/posts.rs b/src/api/posts.rs index cb45d4a..f66e82d 100644 --- a/src/api/posts.rs +++ b/src/api/posts.rs @@ -17,8 +17,7 @@ use crate::utils::text::{auto_summary, count_words}; pub use crate::api::markdown::render_markdown_enhanced; #[cfg(feature = "server")] pub use crate::api::slug::{ensure_unique_slug, is_valid_slug, slugify}; -#[cfg(feature = "server")] -pub use crate::api::tags::{get_post_tags, set_post_tags}; + // ============================================================================ // Server-side helpers (only compiled when server feature is enabled) diff --git a/src/api/rate_limit.rs b/src/api/rate_limit.rs index 1e3acbe..02fbcdf 100644 --- a/src/api/rate_limit.rs +++ b/src/api/rate_limit.rs @@ -3,6 +3,10 @@ #[cfg(feature = "server")] use std::sync::Arc; #[cfg(feature = "server")] +use std::sync::LazyLock; +#[cfg(feature = "server")] +use std::num::NonZeroU32; +#[cfg(feature = "server")] use tower_governor::governor::GovernorConfigBuilder; #[cfg(feature = "server")] use tower_governor::GovernorLayer; @@ -10,8 +14,10 @@ use tower_governor::GovernorLayer; use tower_governor::key_extractor::SmartIpKeyExtractor; #[cfg(feature = "server")] use governor::middleware::NoOpMiddleware; +#[cfg(feature = "server")] +use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter}; -/// 通用限流配置:每分钟 60 请求 +/// 通用限流配置:每秒 1 请求,突发 30 #[cfg(feature = "server")] pub fn general_limit() -> GovernorLayer { let config = GovernorConfigBuilder::default() @@ -25,30 +31,59 @@ pub fn general_limit() -> GovernorLayer { } } -/// 严格限流配置:每分钟 10 请求(用于登录、注册等敏感操作) +// 严格限流:每秒 1,突发 5(用于登录、注册等敏感操作) #[cfg(feature = "server")] -pub fn strict_limit() -> GovernorLayer { - let config = GovernorConfigBuilder::default() - .per_second(1) - .burst_size(5) - .key_extractor(SmartIpKeyExtractor) - .finish() - .unwrap(); - GovernorLayer { - config: Arc::new(config), +static STRICT_LIMITER: LazyLock> = LazyLock::new(|| { + RateLimiter::keyed( + Quota::per_second(NonZeroU32::new(1).unwrap()) + .allow_burst(NonZeroU32::new(5).unwrap()) + ) +}); + +// 上传限流:每秒 1,突发 10 +#[cfg(feature = "server")] +static UPLOAD_LIMITER: LazyLock> = LazyLock::new(|| { + RateLimiter::keyed( + Quota::per_second(NonZeroU32::new(1).unwrap()) + .allow_burst(NonZeroU32::new(10).unwrap()) + ) +}); + +/// 从请求 headers 中提取客户端 IP +#[cfg(feature = "server")] +pub fn get_client_ip(headers: &http::HeaderMap) -> String { + // 1. X-Forwarded-For + if let Some(ip) = headers + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.split(',').next()) + { + return ip.trim().to_string(); } + // 2. X-Real-Ip + if let Some(ip) = headers + .get("x-real-ip") + .and_then(|v| v.to_str().ok()) + { + return ip.trim().to_string(); + } + "unknown".to_string() } -/// 上传限流配置:每分钟 20 请求 +/// 检查严格限流(用于登录、注册等敏感操作) #[cfg(feature = "server")] -pub fn upload_limit() -> GovernorLayer { - let config = GovernorConfigBuilder::default() - .per_second(1) - .burst_size(10) - .key_extractor(SmartIpKeyExtractor) - .finish() - .unwrap(); - GovernorLayer { - config: Arc::new(config), - } +pub fn check_strict_limit(ip: &str) -> Result<(), String> { + STRICT_LIMITER + .check_key(&ip.to_string()) + .map(|_| ()) + .map_err(|_| "请求过于频繁,请稍后再试".to_string()) +} + +/// 检查上传限流 +#[cfg(feature = "server")] +pub fn check_upload_limit(ip: &str) -> Result<(), String> { + UPLOAD_LIMITER + .check_key(&ip.to_string()) + .map(|_| ()) + .map_err(|_| "上传过于频繁,请稍后再试".to_string()) } diff --git a/src/api/tags.rs b/src/api/tags.rs deleted file mode 100644 index eb0f567..0000000 --- a/src/api/tags.rs +++ /dev/null @@ -1,87 +0,0 @@ -#![allow(clippy::unused_unit, deprecated, unused_imports)] - -use dioxus::prelude::*; - -#[cfg(feature = "server")] -use crate::api::utils::query_error; - -#[cfg(feature = "server")] -pub async fn set_post_tags( - client: &tokio_postgres::Client, - post_id: i32, - tags: &[String], -) -> Result<(), ServerFnError> { - // Remove existing tags - client - .execute("DELETE FROM post_tags WHERE post_id = $1", &[&post_id]) - .await - .map_err(|e| { - tracing::error!("delete tag links failed: {:?}", e); - ServerFnError::new(format!("删除标签关联失败: {}", e)) - })?; - - for tag_name in tags { - let tag_name = tag_name.trim(); - if tag_name.is_empty() { - continue; - } - - // Insert or get tag - let tag_id: i32 = { - let row = client - .query_opt( - "INSERT INTO tags (name) VALUES ($1) ON CONFLICT (name) DO NOTHING RETURNING id", - &[&tag_name], - ) - .await - .map_err(|e| { - tracing::error!("create tag failed: {:?}", e); - ServerFnError::new(format!("创建标签失败: {}", e)) - })?; - - match row { - Some(r) => r.get(0), - None => { - // Tag already exists, fetch its id - let row = client - .query_opt("SELECT id FROM tags WHERE name = $1", &[&tag_name]) - .await - .map_err(|e| { - tracing::error!("query tag failed: {:?}", e); - ServerFnError::new(format!("查询标签失败: {}", e)) - })?; - row.map(|r| r.get(0)) - .ok_or_else(|| ServerFnError::new(format!("标签不存在: {}", tag_name)))? - } - } - }; - - client - .execute( - "INSERT INTO post_tags (post_id, tag_id) VALUES ($1, $2)", - &[&post_id, &tag_id], - ) - .await - .map_err(|e| { - tracing::error!("link tag failed: {:?}", e); - ServerFnError::new(format!("关联标签失败: {}", e)) - })?; - } - - Ok(()) -} - -#[cfg(feature = "server")] -pub async fn get_post_tags(client: &tokio_postgres::Client, post_id: i32) -> Vec { - let rows = client - .query( - "SELECT t.name FROM tags t JOIN post_tags pt ON t.id = pt.tag_id WHERE pt.post_id = $1 ORDER BY t.name", - &[&post_id], - ) - .await; - - match rows { - Ok(rows) => rows.iter().map(|r| r.get(0)).collect(), - Err(_) => vec![], - } -} diff --git a/src/api/upload.rs b/src/api/upload.rs index ed52b88..56e9d3d 100644 --- a/src/api/upload.rs +++ b/src/api/upload.rs @@ -22,6 +22,18 @@ pub async fn upload_image( headers: HeaderMap, mut multipart: Multipart, ) -> Result, (StatusCode, Json)> { + // 0. Rate limit check + let ip = crate::api::rate_limit::get_client_ip(&headers); + if let Err(msg) = crate::api::rate_limit::check_upload_limit(&ip) { + return Err(( + StatusCode::TOO_MANY_REQUESTS, + Json(json!({ + "success": false, + "error": msg + })), + )); + } + // 1. Extract session from cookie let cookie_header = headers .get("cookie")