460 lines
14 KiB
Rust
460 lines
14 KiB
Rust
//! 认证相关的 Dioxus server function 与辅助函数。
|
||
//!
|
||
//! 提供注册、登录、登出、获取当前用户等接口,
|
||
//! 通过 HttpOnly Cookie 维护会话,首个注册用户自动成为 admin。
|
||
//! 所有 server function 均在 `#[server(Name, "/api")]` 下注册,供客户端与服务端调用。
|
||
//! 仅在 `feature = "server"` 启用的服务端构建中执行数据库操作与 Cookie 写入。
|
||
|
||
#![allow(clippy::unused_unit, deprecated)]
|
||
|
||
use dioxus::prelude::*;
|
||
#[cfg(feature = "server")]
|
||
use http::header::{HeaderValue, SET_COOKIE};
|
||
|
||
#[cfg(feature = "server")]
|
||
use crate::api::error::AppError;
|
||
#[cfg(feature = "server")]
|
||
use crate::auth::session::get_session_from_ctx;
|
||
use crate::auth::{password, session};
|
||
use crate::db::pool::get_conn;
|
||
use crate::models::user::{PublicUser, User, UserRole};
|
||
|
||
/// Validates a registration username.
///
/// A valid username is 3 to 50 characters long (counted as Unicode scalar values,
/// not bytes) and consists only of characters for which `char::is_alphanumeric`
/// holds (so non-ASCII letters such as CJK are accepted) or underscores.
///
/// # Errors
///
/// Returns a user-facing message describing the first rule that failed.
// `dead_code` is allowed because the validators are only called from `#[server]`
// function bodies and the unit tests; presumably those bodies are not compiled in
// client builds — confirm against the Dioxus feature setup.
#[allow(dead_code)]
fn validate_username(username: &str) -> Result<(), String> {
    // Count characters rather than bytes: the message promises a character limit, and
    // byte length would reject a 20-character CJK name (60 bytes) while accepting a
    // 2-character one (6 bytes).
    let length = username.chars().count();
    if !(3..=50).contains(&length) {
        return Err("用户名长度必须在 3-50 字符之间".to_string());
    }
    if !username.chars().all(|c| c.is_alphanumeric() || c == '_') {
        return Err("用户名只能包含字母、数字和下划线".to_string());
    }
    Ok(())
}
|
||
|
||
#[allow(dead_code)]
|
||
fn validate_email(email: &str) -> Result<(), String> {
|
||
let re = regex::Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
|
||
if !re.is_match(email) {
|
||
return Err("邮箱格式不正确".to_string());
|
||
}
|
||
Ok(())
|
||
}
|
||
|
||
/// Validates a registration password.
///
/// The only rule is a minimum length of 8 characters (counted as Unicode scalar
/// values, not bytes).
///
/// # Errors
///
/// Returns a user-facing message when the password is too short.
// `dead_code` is allowed because the validators are only called from `#[server]`
// function bodies and the unit tests; presumably those bodies are not compiled in
// client builds — confirm against the Dioxus feature setup.
#[allow(dead_code)]
fn validate_password(password: &str) -> Result<(), String> {
    // Count characters, consistent with the username rule; byte length would let a
    // 3-character CJK password (9 bytes) satisfy the "at least 8" requirement.
    if password.chars().count() < 8 {
        return Err("密码长度至少 8 位".to_string());
    }
    Ok(())
}
|
||
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||
/// 认证接口统一响应结构。
|
||
pub struct AuthResponse {
|
||
/// 操作是否成功。
|
||
pub success: bool,
|
||
/// 提示信息。
|
||
pub message: String,
|
||
/// 登录成功后的会话 token(已废弃,实际通过 Cookie 传递)。
|
||
pub token: Option<String>,
|
||
}
|
||
|
||
/// 用户注册。
|
||
///
|
||
/// 校验用户名、邮箱、密码,首个注册用户自动设为 admin;
|
||
/// 已有 admin 时返回 "Registration is closed"。
|
||
/// Dioxus server function,注册在 `/api` 路径下。
|
||
#[server(Register, "/api")]
|
||
pub async fn register(
|
||
username: String,
|
||
email: String,
|
||
password: String,
|
||
) -> Result<AuthResponse, ServerFnError> {
|
||
// 服务端构建时先进行严格限流检查。
|
||
#[cfg(feature = "server")]
|
||
{
|
||
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
|
||
let parts = ctx.parts_mut();
|
||
let ip = crate::api::rate_limit::get_client_ip(&parts.headers);
|
||
if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: msg,
|
||
token: None,
|
||
});
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Err(e) = validate_username(&username) {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: e,
|
||
token: None,
|
||
});
|
||
}
|
||
if let Err(e) = validate_email(&email) {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: e,
|
||
token: None,
|
||
});
|
||
}
|
||
if let Err(e) = validate_password(&password) {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: e,
|
||
token: None,
|
||
});
|
||
}
|
||
|
||
let client = get_conn().await.map_err(AppError::db_conn)?;
|
||
|
||
// 仅允许第一个注册用户注册为 admin,其余拒绝。
|
||
let admin_count: i64 = client
|
||
.query_one("SELECT COUNT(*) FROM users WHERE role = 'admin'", &[])
|
||
.await
|
||
.map_err(AppError::query)?
|
||
.get(0);
|
||
|
||
if admin_count > 0 {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: "Registration is closed".to_string(),
|
||
token: None,
|
||
});
|
||
}
|
||
|
||
let password_hash =
|
||
password::hash_password(&password).map_err(|_| AppError::Internal("密码处理失败"))?;
|
||
|
||
let result = client
|
||
.query_one(
|
||
"INSERT INTO users (username, email, password_hash, role) VALUES ($1, $2, $3, 'admin') RETURNING id",
|
||
&[&username, &email, &password_hash],
|
||
)
|
||
.await;
|
||
|
||
match result {
|
||
Ok(_) => Ok(AuthResponse {
|
||
success: true,
|
||
message: "注册成功".to_string(),
|
||
token: None,
|
||
}),
|
||
Err(e) => {
|
||
let msg = if e.to_string().contains("unique constraint") {
|
||
"用户名或邮箱已存在".to_string()
|
||
} else {
|
||
format!("注册失败: {}", e)
|
||
};
|
||
Ok(AuthResponse {
|
||
success: false,
|
||
message: msg,
|
||
token: None,
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 用户登录。
|
||
///
|
||
/// 验证用户名/邮箱与密码,生成会话并写入 HttpOnly Cookie;
|
||
/// 同一用户活跃会话数超过 `MAX_SESSIONS_PER_USER` 时删除最早会话。
|
||
/// Dioxus server function,注册在 `/api` 路径下。
|
||
#[server(Login, "/api")]
|
||
pub async fn login(username: String, password: String) -> Result<AuthResponse, ServerFnError> {
|
||
// 服务端构建时先进行严格限流检查。
|
||
#[cfg(feature = "server")]
|
||
{
|
||
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
|
||
let parts = ctx.parts_mut();
|
||
let ip = crate::api::rate_limit::get_client_ip(&parts.headers);
|
||
if let Err(msg) = crate::api::rate_limit::check_strict_limit(&ip) {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: msg,
|
||
token: None,
|
||
});
|
||
}
|
||
}
|
||
}
|
||
|
||
let client = get_conn().await.map_err(AppError::db_conn)?;
|
||
|
||
let row = match client
|
||
.query_opt(
|
||
"SELECT id, username, email, password_hash, role, created_at FROM users WHERE username = $1 OR email = $1",
|
||
&[&username],
|
||
)
|
||
.await
|
||
{
|
||
Ok(Some(row)) => row,
|
||
Ok(None) => {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: "Invalid credentials".to_string(),
|
||
token: None,
|
||
});
|
||
}
|
||
Err(e) => {
|
||
return Err(AppError::query(e).into());
|
||
}
|
||
};
|
||
|
||
let password_hash: String = row.get("password_hash");
|
||
let valid = password::verify_password(&password, &password_hash)
|
||
.map_err(|_| AppError::Internal("密码处理失败"))?;
|
||
|
||
if !valid {
|
||
return Ok(AuthResponse {
|
||
success: false,
|
||
message: "Invalid credentials".to_string(),
|
||
token: None,
|
||
});
|
||
}
|
||
|
||
let user_id: i32 = row.get("id");
|
||
let token = session::generate_token();
|
||
let token_hash = session::hash_token(&token);
|
||
let expires_at = session::default_expiry();
|
||
|
||
let max_sessions = std::env::var("MAX_SESSIONS_PER_USER")
|
||
.ok()
|
||
.and_then(|s| s.parse::<i64>().ok())
|
||
.unwrap_or(5)
|
||
.max(1);
|
||
|
||
// 查询当前活跃会话数,超出限制时删除最早的一条。
|
||
let session_count: i64 = client
|
||
.query_one(
|
||
"SELECT COUNT(*) FROM sessions WHERE user_id = $1 AND expires_at > NOW()",
|
||
&[&user_id],
|
||
)
|
||
.await
|
||
.map_err(AppError::query)?
|
||
.get(0);
|
||
|
||
if session_count >= max_sessions {
|
||
client
|
||
.execute(
|
||
"DELETE FROM sessions WHERE id IN (
|
||
SELECT id FROM sessions
|
||
WHERE user_id = $1 AND expires_at > NOW()
|
||
ORDER BY created_at ASC
|
||
LIMIT 1
|
||
)",
|
||
&[&user_id],
|
||
)
|
||
.await
|
||
.map_err(AppError::query)?;
|
||
}
|
||
|
||
client
|
||
.execute(
|
||
"INSERT INTO sessions (user_id, token_hash, user_agent, expires_at) VALUES ($1, $2, $3, $4)",
|
||
&[&user_id, &token_hash, &None::<String>, &expires_at],
|
||
)
|
||
.await
|
||
.map_err(AppError::query)?;
|
||
|
||
let cookie = session::session_cookie(&token, 30 * 24 * 60 * 60, session::cookie_secure());
|
||
// 通过 Dioxus FullstackContext 设置 HttpOnly Cookie 响应头。
|
||
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
|
||
if let Ok(value) = HeaderValue::try_from(cookie.as_str()) {
|
||
ctx.add_response_header(SET_COOKIE, value);
|
||
}
|
||
}
|
||
|
||
Ok(AuthResponse {
|
||
success: true,
|
||
message: "登录成功".to_string(),
|
||
token: Some(token),
|
||
})
|
||
}
|
||
|
||
/// 用户登出。
|
||
///
|
||
/// 清空客户端 session Cookie,并删除数据库中对应会话记录。
|
||
/// Dioxus server function,注册在 `/api` 路径下。
|
||
#[server(Logout, "/api")]
|
||
pub async fn logout() -> Result<AuthResponse, ServerFnError> {
|
||
let token = get_session_from_ctx();
|
||
|
||
let client = get_conn().await.map_err(AppError::db_conn)?;
|
||
|
||
// 设置过期时间为 0 的 Cookie,通知浏览器清除会话。
|
||
let cookie = session::session_cookie("", 0, session::cookie_secure());
|
||
if let Some(ctx) = dioxus::fullstack::FullstackContext::current() {
|
||
if let Ok(value) = HeaderValue::try_from(cookie.as_str()) {
|
||
ctx.add_response_header(SET_COOKIE, value);
|
||
}
|
||
}
|
||
|
||
if let Some(t) = token {
|
||
let token_hash = session::hash_token(&t);
|
||
client
|
||
.execute("DELETE FROM sessions WHERE token_hash = $1", &[&token_hash])
|
||
.await
|
||
.map_err(AppError::query)?;
|
||
}
|
||
|
||
Ok(AuthResponse {
|
||
success: true,
|
||
message: "登出成功".to_string(),
|
||
token: None,
|
||
})
|
||
}
|
||
|
||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||
/// 当前用户查询响应。
|
||
pub struct CurrentUserResponse {
|
||
/// 当前已登录用户的公开信息;未登录时为 `None`。
|
||
pub user: Option<PublicUser>,
|
||
}
|
||
|
||
#[cfg(feature = "server")]
/// Resolves a session token to the full user record, including the password hash.
///
/// The token is hashed with `session::hash_token` and matched against a session row
/// whose `expires_at` is still in the future. Returns `Ok(None)` when no such
/// session exists. A stored `role` that `UserRole::from_str` does not accept is
/// mapped to `UserRole::Blocked`.
///
/// Server-side only: the returned `User` carries the password hash and must never
/// be sent to clients.
///
/// # Errors
///
/// Fails when a connection cannot be obtained or the lookup query errors.
pub async fn get_user_by_token(token: &str) -> Result<Option<User>, ServerFnError> {
    let client = get_conn().await.map_err(AppError::db_conn)?;

    let token_hash = session::hash_token(token);
    let found = client
        .query_opt(
            "SELECT u.id, u.username, u.email, u.password_hash, u.role, u.created_at
             FROM sessions s
             JOIN users u ON s.user_id = u.id
             WHERE s.token_hash = $1 AND s.expires_at > NOW()",
            &[&token_hash],
        )
        .await
        .map_err(AppError::query)?;

    Ok(found.map(|record| {
        // Unknown role strings fall back to the least-privileged variant.
        let role_name: String = record.get("role");
        let role = UserRole::from_str(&role_name).unwrap_or(UserRole::Blocked);
        User {
            id: record.get("id"),
            username: record.get("username"),
            email: record.get("email"),
            password_hash: record.get("password_hash"),
            role,
            created_at: record.get("created_at"),
        }
    }))
}
|
||
|
||
/// 获取当前登录用户的公开信息。
|
||
///
|
||
/// Dioxus server function,注册在 `/api` 路径下。
|
||
#[server(GetCurrentUser, "/api")]
|
||
pub async fn get_current_user() -> Result<CurrentUserResponse, ServerFnError> {
|
||
let token = match get_session_from_ctx() {
|
||
Some(t) => t,
|
||
None => return Ok(CurrentUserResponse { user: None }),
|
||
};
|
||
|
||
let user = get_user_by_token(&token).await?.map(PublicUser::from);
|
||
|
||
Ok(CurrentUserResponse { user })
|
||
}
|
||
|
||
#[cfg(feature = "server")]
/// Returns the signed-in user, requiring the `Admin` role.
///
/// Meant as an admin guard for other server-side handlers.
///
/// # Errors
///
/// - `AppError::Unauthorized` when the request has no session token, or the token no
///   longer resolves to a user (401 per the original documentation).
/// - `AppError::Forbidden` when the user is not an admin (403).
/// - Lookup failures from `get_user_by_token`, wrapped with `AppError::query`.
pub async fn get_current_admin_user() -> Result<User, AppError> {
    let Some(session_token) = get_session_from_ctx() else {
        return Err(AppError::Unauthorized("未登录"));
    };

    let lookup = get_user_by_token(&session_token)
        .await
        .map_err(AppError::query)?;
    let user = match lookup {
        Some(found) => found,
        None => return Err(AppError::Unauthorized("会话已过期")),
    };

    if user.role == UserRole::Admin {
        Ok(user)
    } else {
        Err(AppError::Forbidden("权限不足"))
    }
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    // ---- validate_username ----

    #[test]
    fn validate_username_valid() {
        for candidate in ["admin", "user_123", "abc"] {
            assert!(validate_username(candidate).is_ok());
        }
    }

    #[test]
    fn validate_username_too_short() {
        let outcome = validate_username("ab");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_username_too_long() {
        let name = "a".repeat(51);
        assert!(validate_username(&name).is_err());
    }

    #[test]
    fn validate_username_max_length() {
        let name = "a".repeat(50);
        assert!(validate_username(&name).is_ok());
    }

    #[test]
    fn validate_username_special_chars() {
        for candidate in ["user name", "user@name", "user-name"] {
            assert!(validate_username(candidate).is_err());
        }
    }

    #[test]
    fn validate_username_unicode() {
        assert!(validate_username("用户名").is_ok());
    }

    // ---- validate_email ----

    #[test]
    fn validate_email_valid() {
        for candidate in ["user@example.com", "a.b+c@domain.co"] {
            assert!(validate_email(candidate).is_ok());
        }
    }

    #[test]
    fn validate_email_invalid() {
        for candidate in ["notanemail", "@domain.com", "user@", "user@.com", ""] {
            assert!(validate_email(candidate).is_err());
        }
    }

    // ---- validate_password ----

    #[test]
    fn validate_password_valid() {
        for candidate in ["12345678", "a very long password with spaces"] {
            assert!(validate_password(candidate).is_ok());
        }
    }

    #[test]
    fn validate_password_too_short() {
        assert!(validate_password("1234567").is_err());
    }

    #[test]
    fn validate_password_exactly_8() {
        assert!(validate_password("12345678").is_ok());
    }

    #[test]
    fn validate_password_empty() {
        assert!(validate_password("").is_err());
    }
}
|