diff --git a/.env.example b/.env.example index d283bad..2e896fe 100644 --- a/.env.example +++ b/.env.example @@ -1,2 +1,12 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5432/yggdrasil RUST_LOG=info + +# Rate Limit — 严格限流(登录、注册) +RATE_LIMIT_STRICT_PER_SEC=1 +RATE_LIMIT_STRICT_BURST=5 +# Rate Limit — 上传限流(图片上传) +RATE_LIMIT_UPLOAD_PER_SEC=2 +RATE_LIMIT_UPLOAD_BURST=15 +# Rate Limit — 图片访问限流(/uploads/*) +RATE_LIMIT_IMAGE_PER_SEC=10 +RATE_LIMIT_IMAGE_BURST=50 diff --git a/src/api/image.rs b/src/api/image.rs index 63ce695..ed0e5b2 100644 --- a/src/api/image.rs +++ b/src/api/image.rs @@ -224,8 +224,19 @@ fn is_path_safe(path: &str) -> bool { } #[cfg(feature = "server")] -pub async fn serve_image(Path(path): Path, Query(params): Query) -> Response { - // Path traversal protection +use axum::http::HeaderMap; + +#[cfg(feature = "server")] +pub async fn serve_image( + Path(path): Path, + Query(params): Query, + headers: HeaderMap, +) -> Response { + let ip = crate::api::rate_limit::get_client_ip(&headers); + if let Err(status) = crate::api::rate_limit::check_image_limit(&ip) { + return status.into_response(); + } + if !is_path_safe(&path) { return StatusCode::FORBIDDEN.into_response(); } diff --git a/src/api/rate_limit.rs b/src/api/rate_limit.rs index c208294..f579e10 100644 --- a/src/api/rate_limit.rs +++ b/src/api/rate_limit.rs @@ -1,58 +1,57 @@ #![allow(clippy::unused_unit)] -#[cfg(feature = "server")] -use std::sync::Arc; #[cfg(feature = "server")] use std::sync::LazyLock; #[cfg(feature = "server")] use std::num::NonZeroU32; #[cfg(feature = "server")] -use tower_governor::governor::GovernorConfigBuilder; -#[cfg(feature = "server")] -use tower_governor::GovernorLayer; -#[cfg(feature = "server")] -use tower_governor::key_extractor::SmartIpKeyExtractor; -#[cfg(feature = "server")] -use governor::middleware::NoOpMiddleware; -#[cfg(feature = "server")] use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter}; - -/// 通用限流配置:每秒 1 请求,突发 30 #[cfg(feature = "server")] -pub fn general_limit() -> GovernorLayer { - let config = GovernorConfigBuilder::default() - .per_second(1) - .burst_size(30) - .key_extractor(SmartIpKeyExtractor) - .finish() - .unwrap(); - GovernorLayer { - config: Arc::new(config), - } +use axum::http::StatusCode; + +#[cfg(feature = "server")] +fn env_or(key: &str, default: u32) -> NonZeroU32 { + let val = std::env::var(key) + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(default); + NonZeroU32::new(val.max(1)).unwrap() } -// 严格限流:每秒 1,突发 5(用于登录、注册等敏感操作) #[cfg(feature = "server")] static STRICT_LIMITER: LazyLock> = LazyLock::new(|| { RateLimiter::keyed( - Quota::per_second(NonZeroU32::new(1).unwrap()) - .allow_burst(NonZeroU32::new(5).unwrap()) + Quota::per_second(env_or("RATE_LIMIT_STRICT_PER_SEC", 1)) + .allow_burst(env_or("RATE_LIMIT_STRICT_BURST", 5)), ) }); -// 上传限流:每秒 1,突发 10 #[cfg(feature = "server")] static UPLOAD_LIMITER: LazyLock> = LazyLock::new(|| { RateLimiter::keyed( - Quota::per_second(NonZeroU32::new(1).unwrap()) - .allow_burst(NonZeroU32::new(10).unwrap()) + Quota::per_second(env_or("RATE_LIMIT_UPLOAD_PER_SEC", 2)) + .allow_burst(env_or("RATE_LIMIT_UPLOAD_BURST", 15)), ) }); -/// 从请求 headers 中提取客户端 IP +#[cfg(feature = "server")] +static IMAGE_LIMITER: LazyLock> = LazyLock::new(|| { + RateLimiter::keyed( + Quota::per_second(env_or("RATE_LIMIT_IMAGE_PER_SEC", 10)) + .allow_burst(env_or("RATE_LIMIT_IMAGE_BURST", 50)), + ) +}); + +#[cfg(feature = "server")] +pub fn check_image_limit(ip: &str) -> Result<(), StatusCode> { + IMAGE_LIMITER + .check_key(&ip.to_string()) + .map(|_| ()) + .map_err(|_| StatusCode::TOO_MANY_REQUESTS) +} + #[cfg(feature = "server")] pub fn get_client_ip(headers: &http::HeaderMap) -> String { - // 1. X-Forwarded-For if let Some(ip) = headers .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) @@ -60,7 +59,6 @@ pub fn get_client_ip(headers: &http::HeaderMap) -> String { { return ip.trim().to_string(); } - // 2. X-Real-Ip if let Some(ip) = headers .get("x-real-ip") .and_then(|v| v.to_str().ok()) @@ -70,7 +68,6 @@ pub fn get_client_ip(headers: &http::HeaderMap) -> String { "unknown".to_string() } -/// 检查严格限流(用于登录、注册等敏感操作) #[cfg(feature = "server")] pub fn check_strict_limit(ip: &str) -> Result<(), String> { STRICT_LIMITER @@ -79,7 +76,6 @@ pub fn check_strict_limit(ip: &str) -> Result<(), String> { .map_err(|_| "请求过于频繁,请稍后再试".to_string()) } -/// 检查上传限流 #[cfg(feature = "server")] pub fn check_upload_limit(ip: &str) -> Result<(), String> { UPLOAD_LIMITER diff --git a/src/main.rs b/src/main.rs index 7bc4cc3..9c43cb2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,15 +44,22 @@ fn main() { dioxus::server::IncrementalRendererConfig::default() .invalidate_after(std::time::Duration::from_secs(300)), ); - let router = axum::Router::new() + let api_routes = axum::Router::new() .route( "/api/upload", axum::routing::post(crate::api::upload::upload_image) .layer(axum::extract::DefaultBodyLimit::disable()), - ) - .route("/uploads/{*path}", axum::routing::get(crate::api::image::serve_image)) - .layer(crate::api::rate_limit::general_limit()) - .serve_dioxus_application(config, router::AppRouter) + ); + + let static_routes = axum::Router::new() + .route("/uploads/{*path}", axum::routing::get(crate::api::image::serve_image)); + + let dioxus_app = axum::Router::new() + .serve_dioxus_application(config, router::AppRouter); + + let router = api_routes + .merge(static_routes) + .merge(dioxus_app) .layer( TraceLayer::new_for_http() .make_span_with(