//! Active probing protection and IP ban management //! //! Detects suspicious patterns: //! - Failed PSK handshakes from same IP //! - Rapid connection attempts //! - Protocol fingerprinting probes use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; /// Probe detection thresholds pub const MAX_FAILED_HANDSHAKES: u32 = 5; // Ban after 5 failed attempts pub const FAILED_WINDOW_SECS: u64 = 60; // Within 60 seconds pub const RAPID_CONNECT_THRESHOLD: u32 = 20; // 20 connections per minute pub const BAN_DURATION_SECS: u64 = 3600; // 1 hour ban /// IP tracking entry #[derive(Debug, Clone)] struct IpEntry { failed_handshakes: u32, first_failure: Instant, connection_count: u32, first_connection: Instant, banned_until: Option, } impl Default for IpEntry { fn default() -> Self { let now = Instant::now(); Self { failed_handshakes: 0, first_failure: now, connection_count: 0, first_connection: now, banned_until: None, } } } /// Active probing detector and IP ban manager pub struct ProbeDetector { entries: Arc>>, ban_callback: Option>, } impl ProbeDetector { pub fn new() -> Self { Self { entries: Arc::new(RwLock::new(HashMap::new())), ban_callback: None, } } /// Set callback to execute when IP is banned (e.g., iptables) pub fn on_ban(mut self, callback: F) -> Self where F: Fn(IpAddr) + Send + Sync + 'static, { self.ban_callback = Some(Box::new(callback)); self } /// Record a failed handshake attempt pub async fn record_failure(&self, ip: IpAddr) -> bool { let mut entries = self.entries.write().await; let entry = entries.entry(ip).or_default(); let now = Instant::now(); // Reset counter if window expired if now.duration_since(entry.first_failure) > Duration::from_secs(FAILED_WINDOW_SECS) { entry.failed_handshakes = 0; entry.first_failure = now; } entry.failed_handshakes += 1; // Check threshold if entry.failed_handshakes >= MAX_FAILED_HANDSHAKES { self.ban_ip_internal(ip, entry); return true; } false } /// Record a connection attempt (even if successful) pub async fn record_connection(&self, ip: IpAddr) -> bool { let mut entries = self.entries.write().await; let entry = entries.entry(ip).or_default(); let now = Instant::now(); // Reset counter if window expired if now.duration_since(entry.first_connection) > Duration::from_secs(60) { entry.connection_count = 0; entry.first_connection = now; } entry.connection_count += 1; // Check rapid connection threshold if entry.connection_count >= RAPID_CONNECT_THRESHOLD { self.ban_ip_internal(ip, entry); return true; } false } /// Check if IP is currently banned pub async fn is_banned(&self, ip: &IpAddr) -> bool { let entries = self.entries.read().await; if let Some(entry) = entries.get(ip) { if let Some(banned_until) = entry.banned_until { return Instant::now() < banned_until; } } false } /// Internal ban logic fn ban_ip_internal(&self, ip: IpAddr, entry: &mut IpEntry) { entry.banned_until = Some(Instant::now() + Duration::from_secs(BAN_DURATION_SECS)); // Execute OS-level ban if callback set if let Some(ref callback) = self.ban_callback { callback(ip); } tracing::warn!("Banned IP {} for active probing", ip); } /// Manually ban an IP pub async fn ban(&self, ip: IpAddr) { let mut entries = self.entries.write().await; let entry = entries.entry(ip).or_default(); entry.banned_until = Some(Instant::now() + Duration::from_secs(BAN_DURATION_SECS)); if let Some(ref callback) = self.ban_callback { callback(ip); } } /// Unban an IP pub async fn unban(&self, ip: &IpAddr) { let mut entries = self.entries.write().await; if let Some(entry) = entries.get_mut(ip) { entry.banned_until = None; entry.failed_handshakes = 0; entry.connection_count = 0; } } /// Get list of banned IPs pub async fn banned_list(&self) -> Vec { let entries = self.entries.read().await; let now = Instant::now(); entries .iter() .filter(|(_, e)| e.banned_until.map(|t| now < t).unwrap_or(false)) .map(|(ip, _)| *ip) .collect() } /// Cleanup expired entries pub async fn cleanup(&self) { let mut entries = self.entries.write().await; let now = Instant::now(); entries.retain(|_, e| { // Keep if banned and ban not expired if let Some(banned_until) = e.banned_until { if now < banned_until { return true; } } // Keep if recent activity now.duration_since(e.first_connection) < Duration::from_secs(300) }); } /// Get statistics pub async fn stats(&self) -> ProbeStats { let entries = self.entries.read().await; let now = Instant::now(); let banned = entries .iter() .filter(|(_, e)| e.banned_until.map(|t| now < t).unwrap_or(false)) .count(); let total_failures: u32 = entries.values().map(|e| e.failed_handshakes).sum(); ProbeStats { tracked_ips: entries.len(), banned_ips: banned, total_failures, } } } impl Default for ProbeDetector { fn default() -> Self { Self::new() } } /// Probe detection statistics #[derive(Debug, Clone)] pub struct ProbeStats { pub tracked_ips: usize, pub banned_ips: usize, pub total_failures: u32, } /// Execute iptables ban command (Linux) #[cfg(target_os = "linux")] pub fn iptables_ban(ip: IpAddr) { use std::process::Command; let ip_str = ip.to_string(); // Add to INPUT chain let result = Command::new("iptables") .args(["-A", "INPUT", "-s", &ip_str, "-j", "DROP"]) .output(); match result { Ok(output) if output.status.success() => { tracing::info!("iptables: Banned {}", ip); } Ok(output) => { tracing::error!("iptables failed: {}", String::from_utf8_lossy(&output.stderr)); } Err(e) => { tracing::error!("Failed to execute iptables: {}", e); } } } /// Execute nftables ban command (Linux) #[cfg(target_os = "linux")] pub fn nftables_ban(ip: IpAddr) { use std::process::Command; let ip_str = ip.to_string(); // Add to blocklist set (assumes set exists) let result = Command::new("nft") .args(["add", "element", "inet", "filter", "blocklist", &format!("{{ {} }}", ip_str)]) .output(); match result { Ok(output) if output.status.success() => { tracing::info!("nftables: Banned {}", ip); } Ok(output) => { tracing::error!("nftables failed: {}", String::from_utf8_lossy(&output.stderr)); } Err(e) => { tracing::error!("Failed to execute nft: {}", e); } } } /// Execute Windows Firewall ban command #[cfg(target_os = "windows")] pub fn firewall_ban(ip: IpAddr) { use std::process::Command; let ip_str = ip.to_string(); let rule_name = format!("OSTP_BAN_{}", ip_str.replace('.', "_").replace(':', "_")); let result = Command::new("netsh") .args([ "advfirewall", "firewall", "add", "rule", &format!("name={}", rule_name), "dir=in", "action=block", &format!("remoteip={}", ip_str), ]) .output(); match result { Ok(output) if output.status.success() => { tracing::info!("Windows Firewall: Banned {}", ip); } Ok(output) => { tracing::error!("netsh failed: {}", String::from_utf8_lossy(&output.stderr)); } Err(e) => { tracing::error!("Failed to execute netsh: {}", e); } } } /// Dummy ban function for non-supported platforms #[cfg(not(any(target_os = "linux", target_os = "windows")))] pub fn firewall_ban(_ip: IpAddr) { tracing::warn!("Firewall banning not implemented for this platform"); } #[cfg(not(any(target_os = "linux", target_os = "windows")))] pub fn iptables_ban(_ip: IpAddr) {} #[cfg(not(any(target_os = "linux", target_os = "windows")))] pub fn nftables_ban(_ip: IpAddr) {} #[cfg(test)] mod tests { use super::*; use std::net::Ipv4Addr; #[tokio::test] async fn test_probe_detector() { let detector = ProbeDetector::new(); let ip: IpAddr = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)); // Should not be banned initially assert!(!detector.is_banned(&ip).await); // Record failures below threshold for _ in 0..4 { let banned = detector.record_failure(ip).await; assert!(!banned); } // 5th failure should trigger ban let banned = detector.record_failure(ip).await; assert!(banned); assert!(detector.is_banned(&ip).await); } #[tokio::test] async fn test_rapid_connection() { let detector = ProbeDetector::new(); let ip: IpAddr = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 50)); // Record many connections for _ in 0..19 { detector.record_connection(ip).await; } assert!(!detector.is_banned(&ip).await); // 20th connection triggers ban detector.record_connection(ip).await; assert!(detector.is_banned(&ip).await); } }