feat: apply rate limiting to Register, Login, and upload endpoints
- Remove unused tags.rs and related re-exports - Convert strict_limit/upload_limit from Layer to manual check functions - Add IP-based rate limiting checks to Register, Login, and upload_image - Keep general_limit as global middleware for all other routes
This commit is contained in:
parent
dd8477c7a3
commit
e74b9f3c39
@ -53,6 +53,21 @@ pub async fn register(
|
||||
email: String,
|
||||
password: String,
|
||||
) -> Result<AuthResponse, ServerFnError> {
|
||||
#[cfg(feature = "server")]
|
||||
{
|
||||
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
|
||||
let parts = ctx.parts_mut();
|
||||
let ip = crate::api::rate_limit::get_client_ip(&parts.headers);
|
||||
if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) {
|
||||
return Ok(AuthResponse {
|
||||
success: false,
|
||||
message: msg,
|
||||
token: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = validate_username(&username) {
|
||||
return Ok(AuthResponse {
|
||||
success: false,
|
||||
@ -126,6 +141,21 @@ pub async fn register(
|
||||
|
||||
#[server(Login, "/api")]
|
||||
pub async fn login(username: String, password: String) -> Result<AuthResponse, ServerFnError> {
|
||||
#[cfg(feature = "server")]
|
||||
{
|
||||
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
|
||||
let parts = ctx.parts_mut();
|
||||
let ip = crate::api::rate_limit::get_client_ip(&parts.headers);
|
||||
if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) {
|
||||
return Ok(AuthResponse {
|
||||
success: false,
|
||||
message: msg,
|
||||
token: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let client = get_conn().await.map_err(db_conn_error)?;
|
||||
|
||||
let row = match client
|
||||
|
||||
@ -4,6 +4,5 @@ pub mod markdown;
|
||||
pub mod posts;
|
||||
pub mod rate_limit;
|
||||
pub mod slug;
|
||||
pub mod tags;
|
||||
pub mod upload;
|
||||
pub mod utils;
|
||||
|
||||
@ -17,8 +17,7 @@ use crate::utils::text::{auto_summary, count_words};
|
||||
pub use crate::api::markdown::render_markdown_enhanced;
|
||||
#[cfg(feature = "server")]
|
||||
pub use crate::api::slug::{ensure_unique_slug, is_valid_slug, slugify};
|
||||
#[cfg(feature = "server")]
|
||||
pub use crate::api::tags::{get_post_tags, set_post_tags};
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// Server-side helpers (only compiled when server feature is enabled)
|
||||
|
||||
@ -3,6 +3,10 @@
|
||||
#[cfg(feature = "server")]
|
||||
use std::sync::Arc;
|
||||
#[cfg(feature = "server")]
|
||||
use std::sync::LazyLock;
|
||||
#[cfg(feature = "server")]
|
||||
use std::num::NonZeroU32;
|
||||
#[cfg(feature = "server")]
|
||||
use tower_governor::governor::GovernorConfigBuilder;
|
||||
#[cfg(feature = "server")]
|
||||
use tower_governor::GovernorLayer;
|
||||
@ -10,8 +14,10 @@ use tower_governor::GovernorLayer;
|
||||
use tower_governor::key_extractor::SmartIpKeyExtractor;
|
||||
#[cfg(feature = "server")]
|
||||
use governor::middleware::NoOpMiddleware;
|
||||
#[cfg(feature = "server")]
|
||||
use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};
|
||||
|
||||
/// 通用限流配置:每分钟 60 请求
|
||||
/// 通用限流配置:每秒 1 请求,突发 30
|
||||
#[cfg(feature = "server")]
|
||||
pub fn general_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
|
||||
let config = GovernorConfigBuilder::default()
|
||||
@ -25,30 +31,59 @@ pub fn general_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
|
||||
}
|
||||
}
|
||||
|
||||
/// 严格限流配置:每分钟 10 请求(用于登录、注册等敏感操作)
|
||||
// 严格限流:每秒 1,突发 5(用于登录、注册等敏感操作)
|
||||
#[cfg(feature = "server")]
|
||||
pub fn strict_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
|
||||
let config = GovernorConfigBuilder::default()
|
||||
.per_second(1)
|
||||
.burst_size(5)
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.finish()
|
||||
.unwrap();
|
||||
GovernorLayer {
|
||||
config: Arc::new(config),
|
||||
static STRICT_LIMITER: LazyLock<DefaultKeyedRateLimiter<String>> = LazyLock::new(|| {
|
||||
RateLimiter::keyed(
|
||||
Quota::per_second(NonZeroU32::new(1).unwrap())
|
||||
.allow_burst(NonZeroU32::new(5).unwrap())
|
||||
)
|
||||
});
|
||||
|
||||
// 上传限流:每秒 1,突发 10
|
||||
#[cfg(feature = "server")]
|
||||
static UPLOAD_LIMITER: LazyLock<DefaultKeyedRateLimiter<String>> = LazyLock::new(|| {
|
||||
RateLimiter::keyed(
|
||||
Quota::per_second(NonZeroU32::new(1).unwrap())
|
||||
.allow_burst(NonZeroU32::new(10).unwrap())
|
||||
)
|
||||
});
|
||||
|
||||
/// 从请求 headers 中提取客户端 IP
|
||||
#[cfg(feature = "server")]
|
||||
pub fn get_client_ip(headers: &http::HeaderMap) -> String {
|
||||
// 1. X-Forwarded-For
|
||||
if let Some(ip) = headers
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.split(',').next())
|
||||
{
|
||||
return ip.trim().to_string();
|
||||
}
|
||||
// 2. X-Real-Ip
|
||||
if let Some(ip) = headers
|
||||
.get("x-real-ip")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
return ip.trim().to_string();
|
||||
}
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
/// 上传限流配置:每分钟 20 请求
|
||||
/// 检查严格限流(用于登录、注册等敏感操作)
|
||||
#[cfg(feature = "server")]
|
||||
pub fn upload_limit() -> GovernorLayer<SmartIpKeyExtractor, NoOpMiddleware> {
|
||||
let config = GovernorConfigBuilder::default()
|
||||
.per_second(1)
|
||||
.burst_size(10)
|
||||
.key_extractor(SmartIpKeyExtractor)
|
||||
.finish()
|
||||
.unwrap();
|
||||
GovernorLayer {
|
||||
config: Arc::new(config),
|
||||
}
|
||||
pub fn check_strict_limit(ip: &str) -> Result<(), String> {
|
||||
STRICT_LIMITER
|
||||
.check_key(&ip.to_string())
|
||||
.map(|_| ())
|
||||
.map_err(|_| "请求过于频繁,请稍后再试".to_string())
|
||||
}
|
||||
|
||||
/// 检查上传限流
|
||||
#[cfg(feature = "server")]
|
||||
pub fn check_upload_limit(ip: &str) -> Result<(), String> {
|
||||
UPLOAD_LIMITER
|
||||
.check_key(&ip.to_string())
|
||||
.map(|_| ())
|
||||
.map_err(|_| "上传过于频繁,请稍后再试".to_string())
|
||||
}
|
||||
|
||||
@ -1,87 +0,0 @@
|
||||
#![allow(clippy::unused_unit, deprecated, unused_imports)]
|
||||
|
||||
use dioxus::prelude::*;
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
use crate::api::utils::query_error;
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
pub async fn set_post_tags(
|
||||
client: &tokio_postgres::Client,
|
||||
post_id: i32,
|
||||
tags: &[String],
|
||||
) -> Result<(), ServerFnError> {
|
||||
// Remove existing tags
|
||||
client
|
||||
.execute("DELETE FROM post_tags WHERE post_id = $1", &[&post_id])
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("delete tag links failed: {:?}", e);
|
||||
ServerFnError::new(format!("删除标签关联失败: {}", e))
|
||||
})?;
|
||||
|
||||
for tag_name in tags {
|
||||
let tag_name = tag_name.trim();
|
||||
if tag_name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Insert or get tag
|
||||
let tag_id: i32 = {
|
||||
let row = client
|
||||
.query_opt(
|
||||
"INSERT INTO tags (name) VALUES ($1) ON CONFLICT (name) DO NOTHING RETURNING id",
|
||||
&[&tag_name],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("create tag failed: {:?}", e);
|
||||
ServerFnError::new(format!("创建标签失败: {}", e))
|
||||
})?;
|
||||
|
||||
match row {
|
||||
Some(r) => r.get(0),
|
||||
None => {
|
||||
// Tag already exists, fetch its id
|
||||
let row = client
|
||||
.query_opt("SELECT id FROM tags WHERE name = $1", &[&tag_name])
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("query tag failed: {:?}", e);
|
||||
ServerFnError::new(format!("查询标签失败: {}", e))
|
||||
})?;
|
||||
row.map(|r| r.get(0))
|
||||
.ok_or_else(|| ServerFnError::new(format!("标签不存在: {}", tag_name)))?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.execute(
|
||||
"INSERT INTO post_tags (post_id, tag_id) VALUES ($1, $2)",
|
||||
&[&post_id, &tag_id],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("link tag failed: {:?}", e);
|
||||
ServerFnError::new(format!("关联标签失败: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
pub async fn get_post_tags(client: &tokio_postgres::Client, post_id: i32) -> Vec<String> {
|
||||
let rows = client
|
||||
.query(
|
||||
"SELECT t.name FROM tags t JOIN post_tags pt ON t.id = pt.tag_id WHERE pt.post_id = $1 ORDER BY t.name",
|
||||
&[&post_id],
|
||||
)
|
||||
.await;
|
||||
|
||||
match rows {
|
||||
Ok(rows) => rows.iter().map(|r| r.get(0)).collect(),
|
||||
Err(_) => vec![],
|
||||
}
|
||||
}
|
||||
@ -22,6 +22,18 @@ pub async fn upload_image(
|
||||
headers: HeaderMap,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<Json<Value>, (StatusCode, Json<Value>)> {
|
||||
// 0. Rate limit check
|
||||
let ip = crate::api::rate_limit::get_client_ip(&headers);
|
||||
if let Err(msg) = crate::api::rate_limit::check_upload_limit(&ip) {
|
||||
return Err((
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
Json(json!({
|
||||
"success": false,
|
||||
"error": msg
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
// 1. Extract session from cookie
|
||||
let cookie_header = headers
|
||||
.get("cookie")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user