From 71ac08c373825ef2305a1ff326cf5c2b5e640927 Mon Sep 17 00:00:00 2001 From: xfy Date: Fri, 12 Jun 2026 17:06:14 +0800 Subject: [PATCH] feat(rate_limit): derive real client IP from X-Forwarded-For with TRUSTED_PROXY_COUNT --- src/api/rate_limit.rs | 121 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 98 insertions(+), 23 deletions(-) diff --git a/src/api/rate_limit.rs b/src/api/rate_limit.rs index 7fd4fd0..5afd42a 100644 --- a/src/api/rate_limit.rs +++ b/src/api/rate_limit.rs @@ -1,5 +1,3 @@ -#![allow(clippy::unused_unit)] - #[cfg(feature = "server")] use std::sync::LazyLock; #[cfg(feature = "server")] @@ -67,23 +65,55 @@ pub fn check_image_limit(ip: &str) -> Result<(), StatusCode> { } #[cfg(feature = "server")] -pub fn get_client_ip(headers: &http::HeaderMap) -> String { - if let Some(ip) = headers - .get("x-forwarded-for") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.split(',').next()) - { - return ip.trim().to_string(); +fn trusted_proxy_count() -> usize { + std::env::var("TRUSTED_PROXY_COUNT") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(0) +} + +#[cfg(feature = "server")] +fn ip_from_x_forwarded_for(value: &str, trusted_proxy_count: usize) -> Option { + let parts: Vec<&str> = value + .split(',') + .map(str::trim) + .filter(|s| !s.is_empty()) + .collect(); + if parts.is_empty() || trusted_proxy_count == 0 { + return None; } - if let Some(ip) = headers - .get("x-real-ip") - .and_then(|v| v.to_str().ok()) - { - return ip.trim().to_string(); + if parts.len() <= trusted_proxy_count { + return None; } + let idx = parts.len() - 1 - trusted_proxy_count; + parts.get(idx).map(|s| s.to_string()) +} + +#[cfg(feature = "server")] +pub fn get_client_ip_with_trusted( + headers: &http::HeaderMap, + trusted_proxy_count: usize, +) -> String { + if let Some(value) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) { + if let Some(ip) = ip_from_x_forwarded_for(value, trusted_proxy_count) { + return ip; + } + } + + if trusted_proxy_count > 0 { + if let Some(ip) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) { + return ip.trim().to_string(); + } + } + "unknown".to_string() } +#[cfg(feature = "server")] +pub fn get_client_ip(headers: &http::HeaderMap) -> String { + get_client_ip_with_trusted(headers, trusted_proxy_count()) +} + #[cfg(feature = "server")] pub fn check_strict_limit(ip: &str) -> Result<(), String> { STRICT_LIMITER @@ -106,30 +136,75 @@ mod tests { use http::HeaderMap; #[test] - fn get_client_ip_from_x_forwarded_for() { + fn get_client_ip_from_x_forwarded_for_with_one_trusted_proxy() { let mut headers = HeaderMap::new(); headers.insert("x-forwarded-for", "1.2.3.4, 5.6.7.8".parse().unwrap()); - assert_eq!(get_client_ip(&headers), "1.2.3.4"); + assert_eq!(get_client_ip_with_trusted(&headers, 1), "1.2.3.4"); } #[test] - fn get_client_ip_from_x_real_ip() { + fn get_client_ip_from_x_forwarded_for_with_two_trusted_proxies() { + let mut headers = HeaderMap::new(); + headers.insert( + "x-forwarded-for", + "1.2.3.4, 5.6.7.8, 9.10.11.12".parse().unwrap(), + ); + assert_eq!(get_client_ip_with_trusted(&headers, 2), "1.2.3.4"); + } + + #[test] + fn get_client_ip_ignores_x_forwarded_for_when_no_trusted_proxies() { + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "1.2.3.4, 5.6.7.8".parse().unwrap()); + assert_eq!(get_client_ip_with_trusted(&headers, 0), "unknown"); + } + + #[test] + fn get_client_ip_from_x_real_ip_when_trusted() { let mut headers = HeaderMap::new(); headers.insert("x-real-ip", "9.8.7.6".parse().unwrap()); - assert_eq!(get_client_ip(&headers), "9.8.7.6"); + assert_eq!(get_client_ip_with_trusted(&headers, 1), "9.8.7.6"); } #[test] - fn get_client_ip_x_forwarded_for_takes_priority() { + fn get_client_ip_x_real_ip_ignored_when_not_trusted() { let mut headers = HeaderMap::new(); - headers.insert("x-forwarded-for", "1.1.1.1".parse().unwrap()); - headers.insert("x-real-ip", "2.2.2.2".parse().unwrap()); - assert_eq!(get_client_ip(&headers), "1.1.1.1"); + headers.insert("x-real-ip", "9.8.7.6".parse().unwrap()); + assert_eq!(get_client_ip_with_trusted(&headers, 0), "unknown"); + } + + #[test] + fn get_client_ip_x_forwarded_for_takes_priority_over_x_real_ip() { + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "1.1.1.1, 2.2.2.2".parse().unwrap()); + headers.insert("x-real-ip", "3.3.3.3".parse().unwrap()); + assert_eq!(get_client_ip_with_trusted(&headers, 1), "1.1.1.1"); } #[test] fn get_client_ip_no_headers_returns_unknown() { let headers = HeaderMap::new(); - assert_eq!(get_client_ip(&headers), "unknown"); + assert_eq!(get_client_ip_with_trusted(&headers, 1), "unknown"); + } + + #[test] + fn get_client_ip_ignores_short_x_forwarded_for_list() { + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "1.2.3.4".parse().unwrap()); + assert_eq!(get_client_ip_with_trusted(&headers, 2), "unknown"); + } + + #[test] + fn get_client_ip_ignores_x_forwarded_for_equal_to_proxy_count() { + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", "1.2.3.4, 5.6.7.8".parse().unwrap()); + assert_eq!(get_client_ip_with_trusted(&headers, 2), "unknown"); + } + + #[test] + fn get_client_ip_ignores_empty_x_forwarded_for_entries() { + let mut headers = HeaderMap::new(); + headers.insert("x-forwarded-for", " , 1.2.3.4 , 5.6.7.8 , ".parse().unwrap()); + assert_eq!(get_client_ip_with_trusted(&headers, 1), "1.2.3.4"); } }