feat: apply rate limiting to Register, Login, and upload endpoints

- Remove unused tags.rs and related re-exports
- Convert strict_limit/upload_limit from Layer to manual check functions
- Add IP-based rate limiting checks to Register, Login, and upload_image
- Keep general_limit as global middleware for all other routes
This commit is contained in:
xfy 2026-06-08 17:30:26 +08:00
parent dd8477c7a3
commit e74b9f3c39
6 changed files with 100 additions and 112 deletions

View File

@ -53,6 +53,21 @@ pub async fn register(
email: String,
password: String,
) -> Result<AuthResponse, ServerFnError> {
#[cfg(feature = "server")]
{
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
let parts = ctx.parts_mut();
let ip = crate::api::rate_limit::get_client_ip(&parts.headers);
if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) {
return Ok(AuthResponse {
success: false,
message: msg,
token: None,
});
}
}
}
if let Err(e) = validate_username(&username) {
return Ok(AuthResponse {
success: false,
@ -126,6 +141,21 @@ pub async fn register(
#[server(Login, "/api")]
pub async fn login(username: String, password: String) -> Result<AuthResponse, ServerFnError> {
#[cfg(feature = "server")]
{
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
let parts = ctx.parts_mut();
let ip = crate::api::rate_limit::get_client_ip(&parts.headers);
if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) {
return Ok(AuthResponse {
success: false,
message: msg,
token: None,
});
}
}
}
let client = get_conn().await.map_err(db_conn_error)?;
let row = match client

View File

@ -4,6 +4,5 @@ pub mod markdown;
pub mod posts;
pub mod rate_limit;
pub mod slug;
pub mod tags;
pub mod upload;
pub mod utils;

View File

@ -17,8 +17,7 @@ use crate::utils::text::{auto_summary, count_words};
pub use crate::api::markdown::render_markdown_enhanced;
#[cfg(feature = "server")]
pub use crate::api::slug::{ensure_unique_slug, is_valid_slug, slugify};
#[cfg(feature = "server")]
pub use crate::api::tags::{get_post_tags, set_post_tags};
// ============================================================================
// Server-side helpers (only compiled when server feature is enabled)

View File

@ -3,6 +3,10 @@
#[cfg(feature = "server")]
use std::sync::Arc;
#[cfg(feature = "server")]
use std::sync::LazyLock;
#[cfg(feature = "server")]
use std::num::NonZeroU32;
#[cfg(feature = "server")]
use tower_governor::governor::GovernorConfigBuilder;
#[cfg(feature = "server")]
use tower_governor::GovernorLayer;
@ -10,8 +14,10 @@ use tower_governor::GovernorLayer;
use tower_governor::key_extractor::SmartIpKeyExtractor;
#[cfg(feature = "server")]
use governor::middleware::NoOpMiddleware;
#[cfg(feature = "server")]
use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};
/// 通用限流配置:每分钟 60 请求
/// 通用限流配置:每秒 1 请求,突发 30
#[cfg(feature = "server")]
pub fn general_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
let config = GovernorConfigBuilder::default()
@ -25,30 +31,59 @@ pub fn general_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
}
}
/// 严格限流配置:每分钟 10 请求(用于登录、注册等敏感操作)
// 严格限流:每秒 1突发 5(用于登录、注册等敏感操作)
#[cfg(feature = "server")]
pub fn strict_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
let config = GovernorConfigBuilder::default()
.per_second(1)
.burst_size(5)
.key_extractor(SmartIpKeyExtractor)
.finish()
.unwrap();
GovernorLayer {
config: Arc::new(config),
static STRICT_LIMITER: LazyLock<DefaultKeyedRateLimiter<String>> = LazyLock::new(|| {
RateLimiter::keyed(
Quota::per_second(NonZeroU32::new(1).unwrap())
.allow_burst(NonZeroU32::new(5).unwrap())
)
});
// 上传限流:每秒 1突发 10
#[cfg(feature = "server")]
static UPLOAD_LIMITER: LazyLock<DefaultKeyedRateLimiter<String>> = LazyLock::new(|| {
RateLimiter::keyed(
Quota::per_second(NonZeroU32::new(1).unwrap())
.allow_burst(NonZeroU32::new(10).unwrap())
)
});
/// 从请求 headers 中提取客户端 IP
#[cfg(feature = "server")]
pub fn get_client_ip(headers: &http::HeaderMap) -> String {
// 1. X-Forwarded-For
if let Some(ip) = headers
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.split(',').next())
{
return ip.trim().to_string();
}
// 2. X-Real-Ip
if let Some(ip) = headers
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
{
return ip.trim().to_string();
}
"unknown".to_string()
}
/// 上传限流配置:每分钟 20 请求
/// 检查严格限流(用于登录、注册等敏感操作)
#[cfg(feature = "server")]
pub fn upload_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
let config = GovernorConfigBuilder::default()
.per_second(1)
.burst_size(10)
.key_extractor(SmartIpKeyExtractor)
.finish()
.unwrap();
GovernorLayer {
config: Arc::new(config),
}
pub fn check_strict_limit(ip: &str) -> Result<(), String> {
STRICT_LIMITER
.check_key(&ip.to_string())
.map(|_| ())
.map_err(|_| "请求过于频繁,请稍后再试".to_string())
}
/// 检查上传限流
#[cfg(feature = "server")]
pub fn check_upload_limit(ip: &str) -> Result<(), String> {
UPLOAD_LIMITER
.check_key(&ip.to_string())
.map(|_| ())
.map_err(|_| "上传过于频繁,请稍后再试".to_string())
}

View File

@ -1,87 +0,0 @@
#![allow(clippy::unused_unit, deprecated, unused_imports)]
use dioxus::prelude::*;
#[cfg(feature = "server")]
use crate::api::utils::query_error;
#[cfg(feature = "server")]
pub async fn set_post_tags(
client: &tokio_postgres::Client,
post_id: i32,
tags: &[String],
) -> Result<(), ServerFnError> {
// Remove existing tags
client
.execute("DELETE FROM post_tags WHERE post_id = $1", &[&post_id])
.await
.map_err(|e| {
tracing::error!("delete tag links failed: {:?}", e);
ServerFnError::new(format!("删除标签关联失败: {}", e))
})?;
for tag_name in tags {
let tag_name = tag_name.trim();
if tag_name.is_empty() {
continue;
}
// Insert or get tag
let tag_id: i32 = {
let row = client
.query_opt(
"INSERT INTO tags (name) VALUES ($1) ON CONFLICT (name) DO NOTHING RETURNING id",
&[&tag_name],
)
.await
.map_err(|e| {
tracing::error!("create tag failed: {:?}", e);
ServerFnError::new(format!("创建标签失败: {}", e))
})?;
match row {
Some(r) => r.get(0),
None => {
// Tag already exists, fetch its id
let row = client
.query_opt("SELECT id FROM tags WHERE name = $1", &[&tag_name])
.await
.map_err(|e| {
tracing::error!("query tag failed: {:?}", e);
ServerFnError::new(format!("查询标签失败: {}", e))
})?;
row.map(|r| r.get(0))
.ok_or_else(|| ServerFnError::new(format!("标签不存在: {}", tag_name)))?
}
}
};
client
.execute(
"INSERT INTO post_tags (post_id, tag_id) VALUES ($1, $2)",
&[&post_id, &tag_id],
)
.await
.map_err(|e| {
tracing::error!("link tag failed: {:?}", e);
ServerFnError::new(format!("关联标签失败: {}", e))
})?;
}
Ok(())
}
#[cfg(feature = "server")]
pub async fn get_post_tags(client: &tokio_postgres::Client, post_id: i32) -> Vec<String> {
let rows = client
.query(
"SELECT t.name FROM tags t JOIN post_tags pt ON t.id = pt.tag_id WHERE pt.post_id = $1 ORDER BY t.name",
&[&post_id],
)
.await;
match rows {
Ok(rows) => rows.iter().map(|r| r.get(0)).collect(),
Err(_) => vec![],
}
}

View File

@ -22,6 +22,18 @@ pub async fn upload_image(
headers: HeaderMap,
mut multipart: Multipart,
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
// 0. Rate limit check
let ip = crate::api::rate_limit::get_client_ip(&headers);
if let Err(msg) = crate::api::rate_limit::check_upload_limit(&ip) {
return Err((
StatusCode::TOO_MANY_REQUESTS,
Json(json!({
"success": false,
"error": msg
})),
));
}
// 1. Extract session from cookie
let cookie_header = headers
.get("cookie")