feat: CDN Control Plane (ONCP) implementation

- Add REST API for node/user management (axum-based)
- Add NodeRegistry for server check-in and load balancing
- Add SniManager for dynamic SNI updates and emergency blocking
- Add CDN Dashboard CLI (oncp-master) with real-time monitoring
- Add ProbeDetector in ostp-guard for active probing detection
- Add iptables/nftables/Windows firewall ban integration
- Extend MimicryEngine with async SNI updates from control plane
- Fix all compilation warnings
- Update author to ospab.team
This commit is contained in:
2026-01-01 20:33:03 +03:00
parent fc00214b07
commit 6d4c06a013
19 changed files with 2671 additions and 15 deletions

438
oncp/src/api.rs Normal file
View File

@@ -0,0 +1,438 @@
//! REST API for CDN Control Plane (Master Node)
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use base64::{Engine as _, engine::general_purpose};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use uuid::Uuid;
use crate::billing::{SqliteRegistry, User, UserRegistry};
use crate::node::{NetworkStats, Node, NodeCheckin, NodeRegistry};
use crate::session::SessionManager;
use crate::sni::{SniManager, SniUpdate};
/// Shared application state
pub struct AppState {
pub nodes: NodeRegistry,
pub users: SqliteRegistry,
pub sessions: SessionManager,
pub sni_manager: SniManager,
}
impl AppState {
pub fn new(db_path: &str) -> anyhow::Result<Self> {
Ok(Self {
nodes: NodeRegistry::new(60), // 60 second timeout
users: SqliteRegistry::new(db_path)?,
sessions: SessionManager::new(300), // 5 minute heartbeat timeout
sni_manager: SniManager::new(),
})
}
}
/// Create the control plane API router
pub fn create_router(state: Arc<AppState>) -> Router {
Router::new()
// Node management
.route("/api/v1/nodes", get(list_nodes).post(register_node))
.route("/api/v1/nodes/:id", get(get_node).delete(remove_node))
.route("/api/v1/nodes/:id/checkin", post(node_checkin))
.route("/api/v1/nodes/best", get(best_nodes))
// User management
.route("/api/v1/users", get(list_users).post(create_user))
.route("/api/v1/users/:id", get(get_user).delete(delete_user))
.route("/api/v1/users/:id/config", get(user_config))
// SNI management
.route("/api/v1/sni", get(list_sni).post(update_sni))
.route("/api/v1/sni/emergency", post(emergency_sni_update))
// Statistics
.route("/api/v1/stats", get(network_stats))
.route("/api/v1/stats/traffic", get(traffic_stats))
// Health check
.route("/health", get(health_check))
.layer(TraceLayer::new_for_http())
.layer(CorsLayer::new().allow_origin(Any).allow_methods(Any))
.with_state(state)
}
// ============================================================================
// Node Endpoints
// ============================================================================
/// List all nodes
async fn list_nodes(State(state): State<Arc<AppState>>) -> Json<Vec<Node>> {
Json(state.nodes.list().await)
}
/// Register new node
#[derive(Debug, Deserialize)]
struct RegisterNodeRequest {
name: String,
address: String,
country_code: String,
max_connections: Option<u32>,
psk_hash: Option<String>,
}
#[derive(Debug, Serialize)]
struct RegisterNodeResponse {
node_id: Uuid,
message: String,
}
async fn register_node(
State(state): State<Arc<AppState>>,
Json(req): Json<RegisterNodeRequest>,
) -> impl IntoResponse {
let mut node = Node::new(&req.name, &req.address, &req.country_code);
if let Some(max) = req.max_connections {
node.max_connections = max;
}
node.psk_hash = req.psk_hash;
let node_id = state.nodes.register(node).await;
(StatusCode::CREATED, Json(RegisterNodeResponse {
node_id,
message: "Node registered successfully".into(),
}))
}
/// Get single node
async fn get_node(
State(state): State<Arc<AppState>>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
match state.nodes.get(&id).await {
Some(node) => (StatusCode::OK, Json(Some(node))),
None => (StatusCode::NOT_FOUND, Json(None)),
}
}
/// Remove node
async fn remove_node(
State(state): State<Arc<AppState>>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
match state.nodes.remove(&id).await {
Some(_) => StatusCode::NO_CONTENT,
None => StatusCode::NOT_FOUND,
}
}
/// Node check-in (heartbeat)
async fn node_checkin(
State(state): State<Arc<AppState>>,
Path(id): Path<Uuid>,
Json(mut checkin): Json<NodeCheckin>,
) -> impl IntoResponse {
checkin.node_id = id; // Ensure node_id matches path
match state.nodes.checkin(checkin).await {
Some(node) => (StatusCode::OK, Json(Some(node))),
None => (StatusCode::NOT_FOUND, Json(None)),
}
}
/// Get best nodes for client connection
#[derive(Debug, Deserialize)]
struct BestNodesQuery {
country: Option<String>,
limit: Option<usize>,
}
async fn best_nodes(
State(state): State<Arc<AppState>>,
Query(query): Query<BestNodesQuery>,
) -> Json<Vec<Node>> {
let limit = query.limit.unwrap_or(3);
let nodes = match &query.country {
Some(country) => state.nodes.best_for_country(country, limit).await,
None => state.nodes.best_global(limit).await,
};
Json(nodes)
}
// ============================================================================
// User Endpoints
// ============================================================================
/// List users (limited info for security)
#[derive(Debug, Serialize)]
struct UserSummary {
uuid: Uuid,
active: bool,
expires_at: String,
bandwidth_used_gb: f64,
bandwidth_quota_gb: f64,
}
async fn list_users(State(state): State<Arc<AppState>>) -> impl IntoResponse {
// Note: In production, this should have pagination and auth
let conn = state.users.conn.lock().unwrap();
let mut stmt = conn.prepare(
"SELECT uuid, bandwidth_quota, bandwidth_used, expires_at, active FROM users LIMIT 100"
).unwrap();
let users: Vec<UserSummary> = stmt.query_map([], |row: &rusqlite::Row| {
let uuid_str: String = row.get(0)?;
Ok(UserSummary {
uuid: Uuid::parse_str(&uuid_str).unwrap(),
bandwidth_quota_gb: row.get::<_, i64>(1)? as f64 / (1024.0 * 1024.0 * 1024.0),
bandwidth_used_gb: row.get::<_, i64>(2)? as f64 / (1024.0 * 1024.0 * 1024.0),
expires_at: row.get::<_, String>(3)?,
active: row.get::<_, i32>(4)? == 1,
})
}).unwrap().filter_map(|r: Result<UserSummary, _>| r.ok()).collect();
Json(users)
}
/// Create new user
#[derive(Debug, Deserialize)]
struct CreateUserRequest {
quota_gb: Option<u64>,
valid_days: Option<i64>,
}
#[derive(Debug, Serialize)]
struct CreateUserResponse {
uuid: Uuid,
config_string: String,
expires_at: String,
}
async fn create_user(
State(state): State<Arc<AppState>>,
Json(req): Json<CreateUserRequest>,
) -> impl IntoResponse {
let quota = req.quota_gb.unwrap_or(100);
let days = req.valid_days.unwrap_or(30);
let user = User::new(quota, days);
match state.users.create_user(&user) {
Ok(()) => {
// Generate config string (can be used for QR code)
let config = serde_json::json!({
"uuid": user.uuid.to_string(),
"expires": user.expires_at.to_rfc3339(),
});
(StatusCode::CREATED, Json(CreateUserResponse {
uuid: user.uuid,
config_string: general_purpose::STANDARD.encode(config.to_string()),
expires_at: user.expires_at.to_rfc3339(),
}))
}
Err(e) => {
tracing::error!("Failed to create user: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, Json(CreateUserResponse {
uuid: Uuid::nil(),
config_string: String::new(),
expires_at: String::new(),
}))
}
}
}
/// Get user details
async fn get_user(
State(state): State<Arc<AppState>>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
match state.users.get_user(&id) {
Ok(Some(user)) => (StatusCode::OK, Json(Some(user))),
Ok(None) => (StatusCode::NOT_FOUND, Json(None)),
Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, Json(None)),
}
}
/// Delete user
async fn delete_user(
State(state): State<Arc<AppState>>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
let conn = state.users.conn.lock().unwrap();
match conn.execute("DELETE FROM users WHERE uuid = ?", [id.to_string()]) {
Ok(0) => StatusCode::NOT_FOUND,
Ok(_) => StatusCode::NO_CONTENT,
Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
}
}
/// Get user connection config (for client setup)
#[derive(Debug, Serialize)]
struct UserConfig {
uuid: String,
servers: Vec<ServerInfo>,
sni_list: Vec<String>,
qr_data: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
struct ServerInfo {
address: String,
country_code: String,
load: f32,
}
async fn user_config(
State(state): State<Arc<AppState>>,
Path(id): Path<Uuid>,
) -> impl IntoResponse {
// Verify user exists and is valid
match state.users.validate_user(&id) {
Ok(true) => {}
_ => return (StatusCode::NOT_FOUND, Json(UserConfig {
uuid: String::new(),
servers: vec![],
sni_list: vec![],
qr_data: None,
})),
}
// Get best available servers
let nodes = state.nodes.best_global(5).await;
let servers: Vec<ServerInfo> = nodes.iter().map(|n| ServerInfo {
address: n.address.clone(),
country_code: n.country_code.clone(),
load: n.load_score(),
}).collect();
// Get current SNI list
let sni_list = state.sni_manager.get_active_snis().await;
// Generate QR-compatible config
let qr_config = serde_json::json!({
"u": id.to_string(),
"s": servers.first().map(|s| &s.address),
});
let qr_data = general_purpose::STANDARD.encode(qr_config.to_string());
(StatusCode::OK, Json(UserConfig {
uuid: id.to_string(),
servers,
sni_list,
qr_data: Some(qr_data),
}))
}
// ============================================================================
// SNI Management
// ============================================================================
async fn list_sni(State(state): State<Arc<AppState>>) -> Json<Vec<String>> {
Json(state.sni_manager.get_active_snis().await)
}
async fn update_sni(
State(state): State<Arc<AppState>>,
Json(update): Json<SniUpdate>,
) -> impl IntoResponse {
state.sni_manager.apply_update(update).await;
StatusCode::OK
}
/// Emergency SNI update (broadcast to all nodes)
#[derive(Debug, Deserialize)]
struct EmergencySniRequest {
blocked_domains: Vec<String>,
replacement_domains: Vec<String>,
country_code: Option<String>,
}
async fn emergency_sni_update(
State(state): State<Arc<AppState>>,
Json(req): Json<EmergencySniRequest>,
) -> impl IntoResponse {
let update = SniUpdate {
remove: req.blocked_domains,
add: req.replacement_domains,
country_code: req.country_code,
emergency: true,
};
state.sni_manager.apply_update(update).await;
// Log for audit
tracing::warn!("Emergency SNI update applied");
StatusCode::OK
}
// ============================================================================
// Statistics
// ============================================================================
async fn network_stats(State(state): State<Arc<AppState>>) -> Json<NetworkStats> {
Json(state.nodes.network_stats().await)
}
#[derive(Debug, Serialize)]
struct TrafficStats {
total_bytes_tx: u64,
total_bytes_rx: u64,
total_mb_transferred: f64,
active_sessions: usize,
}
async fn traffic_stats(State(state): State<Arc<AppState>>) -> Json<TrafficStats> {
let net_stats = state.nodes.network_stats().await;
let total_bytes = net_stats.total_bytes_tx + net_stats.total_bytes_rx;
Json(TrafficStats {
total_bytes_tx: net_stats.total_bytes_tx,
total_bytes_rx: net_stats.total_bytes_rx,
total_mb_transferred: total_bytes as f64 / (1024.0 * 1024.0),
active_sessions: net_stats.total_connections as usize,
})
}
// ============================================================================
// Health Check
// ============================================================================
#[derive(Debug, Serialize)]
struct HealthStatus {
status: String,
version: String,
nodes_online: usize,
}
async fn health_check(State(state): State<Arc<AppState>>) -> Json<HealthStatus> {
let stats = state.nodes.network_stats().await;
Json(HealthStatus {
status: "ok".into(),
version: env!("CARGO_PKG_VERSION").into(),
nodes_online: stats.online_nodes,
})
}
/// Start the control plane API server
pub async fn run_server(state: Arc<AppState>, bind_addr: &str) -> anyhow::Result<()> {
let app = create_router(state);
let listener = tokio::net::TcpListener::bind(bind_addr).await?;
tracing::info!("Control Plane API listening on {}", bind_addr);
axum::serve(listener, app).await?;
Ok(())
}

View File

@@ -59,7 +59,7 @@ pub trait UserRegistry: Send + Sync {
/// SQLite implementation (thread-safe via Mutex)
pub struct SqliteRegistry {
conn: Mutex<rusqlite::Connection>,
pub conn: Mutex<rusqlite::Connection>,
}
impl SqliteRegistry {

View File

@@ -1,5 +1,12 @@
pub mod api;
pub mod billing;
pub mod node;
pub mod session;
pub mod sni;
pub use api::{create_router, run_server, AppState};
pub use billing::{BillingError, SqliteRegistry, User, UserRegistry};
pub use node::{NetworkStats, Node, NodeCheckin, NodeRegistry, NodeStatus};
pub use session::{Session, SessionManager};
pub use sni::{SniManager, SniUpdate};

301
oncp/src/node.rs Normal file
View File

@@ -0,0 +1,301 @@
//! Node registry and management for CDN control plane
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Node health status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeStatus {
Online,
Offline,
Maintenance,
Overloaded,
}
impl Default for NodeStatus {
fn default() -> Self {
Self::Offline
}
}
/// OSTP server node in the CDN network
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
pub node_id: Uuid,
pub name: String,
pub address: String, // "ip:port"
pub country_code: String,
pub status: NodeStatus,
pub cpu_load: f32, // 0.0 - 1.0
pub active_connections: u32,
pub max_connections: u32,
pub bytes_tx: u64,
pub bytes_rx: u64,
pub last_checkin: DateTime<Utc>,
pub registered_at: DateTime<Utc>,
#[serde(skip_serializing_if = "Option::is_none")]
pub psk_hash: Option<String>, // First 8 chars of PSK hash for identification
}
impl Node {
pub fn new(name: impl Into<String>, address: impl Into<String>, country_code: impl Into<String>) -> Self {
let now = Utc::now();
Self {
node_id: Uuid::new_v4(),
name: name.into(),
address: address.into(),
country_code: country_code.into(),
status: NodeStatus::Offline,
cpu_load: 0.0,
active_connections: 0,
max_connections: 1000,
bytes_tx: 0,
bytes_rx: 0,
last_checkin: now,
registered_at: now,
psk_hash: None,
}
}
/// Check if node is healthy and can accept connections
pub fn is_available(&self) -> bool {
self.status == NodeStatus::Online &&
self.active_connections < self.max_connections &&
self.cpu_load < 0.9
}
/// Calculate node load score (lower is better)
pub fn load_score(&self) -> f32 {
let conn_ratio = self.active_connections as f32 / self.max_connections.max(1) as f32;
(self.cpu_load + conn_ratio) / 2.0
}
/// Check if node is stale (no check-in for timeout period)
pub fn is_stale(&self, timeout_secs: i64) -> bool {
let elapsed = Utc::now().signed_duration_since(self.last_checkin);
elapsed.num_seconds() > timeout_secs
}
}
/// Node check-in request from OSTP server
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeCheckin {
pub node_id: Uuid,
pub cpu_load: f32,
pub active_connections: u32,
pub bytes_tx: u64,
pub bytes_rx: u64,
#[serde(default)]
pub sni_list: Vec<String>, // Currently loaded SNIs
}
/// Node registry managing all OSTP servers
pub struct NodeRegistry {
nodes: Arc<RwLock<HashMap<Uuid, Node>>>,
checkin_timeout_secs: i64,
}
impl NodeRegistry {
pub fn new(checkin_timeout_secs: i64) -> Self {
Self {
nodes: Arc::new(RwLock::new(HashMap::new())),
checkin_timeout_secs,
}
}
/// Register a new node
pub async fn register(&self, node: Node) -> Uuid {
let id = node.node_id;
self.nodes.write().await.insert(id, node);
tracing::info!("Registered node {}", id);
id
}
/// Process node check-in (heartbeat + status update)
pub async fn checkin(&self, checkin: NodeCheckin) -> Option<Node> {
let mut nodes = self.nodes.write().await;
if let Some(node) = nodes.get_mut(&checkin.node_id) {
node.cpu_load = checkin.cpu_load;
node.active_connections = checkin.active_connections;
node.bytes_tx = checkin.bytes_tx;
node.bytes_rx = checkin.bytes_rx;
node.last_checkin = Utc::now();
// Update status based on load
node.status = if checkin.cpu_load > 0.95 {
NodeStatus::Overloaded
} else {
NodeStatus::Online
};
Some(node.clone())
} else {
None
}
}
/// Get node by ID
pub async fn get(&self, node_id: &Uuid) -> Option<Node> {
self.nodes.read().await.get(node_id).cloned()
}
/// Get all nodes
pub async fn list(&self) -> Vec<Node> {
self.nodes.read().await.values().cloned().collect()
}
/// Get online nodes only
pub async fn list_online(&self) -> Vec<Node> {
self.nodes.read().await
.values()
.filter(|n| n.status == NodeStatus::Online)
.cloned()
.collect()
}
/// Get best nodes for a specific country (sorted by load)
pub async fn best_for_country(&self, country_code: &str, limit: usize) -> Vec<Node> {
let mut nodes: Vec<Node> = self.nodes.read().await
.values()
.filter(|n| n.is_available() && n.country_code == country_code)
.cloned()
.collect();
nodes.sort_by(|a, b| a.load_score().partial_cmp(&b.load_score()).unwrap());
nodes.truncate(limit);
nodes
}
/// Get best nodes globally (for any country)
pub async fn best_global(&self, limit: usize) -> Vec<Node> {
let mut nodes: Vec<Node> = self.nodes.read().await
.values()
.filter(|n| n.is_available())
.cloned()
.collect();
nodes.sort_by(|a, b| a.load_score().partial_cmp(&b.load_score()).unwrap());
nodes.truncate(limit);
nodes
}
/// Remove node
pub async fn remove(&self, node_id: &Uuid) -> Option<Node> {
self.nodes.write().await.remove(node_id)
}
/// Mark stale nodes as offline
pub async fn cleanup_stale(&self) -> Vec<Uuid> {
let mut nodes = self.nodes.write().await;
let timeout = self.checkin_timeout_secs;
let stale: Vec<Uuid> = nodes
.iter()
.filter(|(_, n)| n.is_stale(timeout))
.map(|(id, _)| *id)
.collect();
for id in &stale {
if let Some(node) = nodes.get_mut(id) {
node.status = NodeStatus::Offline;
}
}
stale
}
/// Get aggregated network statistics
pub async fn network_stats(&self) -> NetworkStats {
let nodes = self.nodes.read().await;
let total_nodes = nodes.len();
let online_nodes = nodes.values().filter(|n| n.status == NodeStatus::Online).count();
let total_connections: u32 = nodes.values().map(|n| n.active_connections).sum();
let total_bytes_tx: u64 = nodes.values().map(|n| n.bytes_tx).sum();
let total_bytes_rx: u64 = nodes.values().map(|n| n.bytes_rx).sum();
let avg_load: f32 = if online_nodes > 0 {
nodes.values()
.filter(|n| n.status == NodeStatus::Online)
.map(|n| n.load_score())
.sum::<f32>() / online_nodes as f32
} else {
0.0
};
NetworkStats {
total_nodes,
online_nodes,
total_connections,
total_bytes_tx,
total_bytes_rx,
avg_load,
}
}
}
/// Aggregated network statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkStats {
pub total_nodes: usize,
pub online_nodes: usize,
pub total_connections: u32,
pub total_bytes_tx: u64,
pub total_bytes_rx: u64,
pub avg_load: f32,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_node_registry() {
let registry = NodeRegistry::new(30);
let node = Node::new("test-node", "1.2.3.4:8443", "US");
let node_id = registry.register(node).await;
// Check-in with updated stats
let checkin = NodeCheckin {
node_id,
cpu_load: 0.3,
active_connections: 10,
bytes_tx: 1024,
bytes_rx: 2048,
sni_list: vec![],
};
let updated = registry.checkin(checkin).await;
assert!(updated.is_some());
let node = updated.unwrap();
assert_eq!(node.status, NodeStatus::Online);
assert_eq!(node.active_connections, 10);
}
#[tokio::test]
async fn test_best_nodes() {
let registry = NodeRegistry::new(30);
// Add nodes with different loads
let mut node1 = Node::new("low-load", "1.1.1.1:8443", "US");
node1.status = NodeStatus::Online;
node1.cpu_load = 0.2;
let mut node2 = Node::new("high-load", "2.2.2.2:8443", "US");
node2.status = NodeStatus::Online;
node2.cpu_load = 0.8;
registry.register(node1).await;
registry.register(node2).await;
let best = registry.best_for_country("US", 1).await;
assert_eq!(best.len(), 1);
assert_eq!(best[0].name, "low-load");
}
}

246
oncp/src/sni.rs Normal file
View File

@@ -0,0 +1,246 @@
//! Dynamic SNI management and emergency updates
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use chrono::{DateTime, Utc};
/// SNI update command
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SniUpdate {
pub remove: Vec<String>,
pub add: Vec<String>,
#[serde(default)]
pub country_code: Option<String>,
#[serde(default)]
pub emergency: bool,
}
/// SNI entry with metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SniEntry {
pub domain: String,
pub country_code: String,
pub added_at: DateTime<Utc>,
pub blocked: bool,
pub priority: u8, // Higher = preferred
}
/// Dynamic SNI manager for control plane
pub struct SniManager {
/// Global SNI list
#[allow(dead_code)]
global_snis: Arc<RwLock<Vec<SniEntry>>>,
/// Country-specific SNIs
country_snis: Arc<RwLock<HashMap<String, Vec<SniEntry>>>>,
/// Blocked domains (emergency blacklist)
blocked_domains: Arc<RwLock<HashSet<String>>>,
/// Pending updates for nodes to fetch
pending_updates: Arc<RwLock<Vec<SniUpdate>>>,
}
impl SniManager {
pub fn new() -> Self {
let manager = Self {
global_snis: Arc::new(RwLock::new(Vec::new())),
country_snis: Arc::new(RwLock::new(HashMap::new())),
blocked_domains: Arc::new(RwLock::new(HashSet::new())),
pending_updates: Arc::new(RwLock::new(Vec::new())),
};
// Initialize with default domains
tokio::spawn(async move {
// This would be initialized synchronously in practice
});
manager
}
/// Initialize with default SNI mappings
pub async fn init_defaults(&self) {
let defaults = vec![
("RU", vec!["gosuslugi.ru", "sberbank.ru", "yandex.ru", "vk.com", "mail.ru"]),
("US", vec!["apple.com", "microsoft.com", "amazon.com", "google.com", "cloudflare.com"]),
("DE", vec!["sparkasse.de", "deutsche-bank.de", "bund.de", "spiegel.de"]),
("NO", vec!["bankid.no", "vipps.no", "altinn.no", "vg.no", "nrk.no"]),
("CN", vec!["qq.com", "baidu.com", "taobao.com", "weibo.com"]),
];
let mut country_snis = self.country_snis.write().await;
for (country, domains) in defaults {
let entries: Vec<SniEntry> = domains.into_iter().map(|d| SniEntry {
domain: d.to_string(),
country_code: country.to_string(),
added_at: Utc::now(),
blocked: false,
priority: 50,
}).collect();
country_snis.insert(country.to_string(), entries);
}
}
/// Apply an SNI update
pub async fn apply_update(&self, update: SniUpdate) {
// Add to blocked list
{
let mut blocked = self.blocked_domains.write().await;
for domain in &update.remove {
blocked.insert(domain.clone());
}
}
// Add new domains
if !update.add.is_empty() {
let country = update.country_code.clone().unwrap_or_else(|| "GLOBAL".to_string());
let mut country_snis = self.country_snis.write().await;
let entries = country_snis.entry(country.clone()).or_insert_with(Vec::new);
for domain in &update.add {
entries.push(SniEntry {
domain: domain.clone(),
country_code: country.clone(),
added_at: Utc::now(),
blocked: false,
priority: if update.emergency { 100 } else { 50 },
});
}
}
// Store update for nodes to fetch
self.pending_updates.write().await.push(update);
tracing::info!("SNI update applied");
}
/// Get active SNIs for a country
pub async fn get_snis_for_country(&self, country: &str) -> Vec<String> {
let country_snis = self.country_snis.read().await;
let blocked = self.blocked_domains.read().await;
country_snis
.get(country)
.map(|entries| {
entries
.iter()
.filter(|e| !e.blocked && !blocked.contains(&e.domain))
.map(|e| e.domain.clone())
.collect()
})
.unwrap_or_default()
}
/// Get all active SNIs
pub async fn get_active_snis(&self) -> Vec<String> {
let country_snis = self.country_snis.read().await;
let blocked = self.blocked_domains.read().await;
country_snis
.values()
.flatten()
.filter(|e| !e.blocked && !blocked.contains(&e.domain))
.map(|e| e.domain.clone())
.collect()
}
/// Get pending updates for nodes
pub async fn get_pending_updates(&self) -> Vec<SniUpdate> {
self.pending_updates.read().await.clone()
}
/// Clear pending updates after nodes have fetched them
pub async fn clear_pending_updates(&self) {
self.pending_updates.write().await.clear();
}
/// Check if domain is blocked
pub async fn is_blocked(&self, domain: &str) -> bool {
self.blocked_domains.read().await.contains(domain)
}
/// Block a domain immediately
pub async fn block_domain(&self, domain: String) {
self.blocked_domains.write().await.insert(domain.clone());
// Create emergency update for nodes
let update = SniUpdate {
remove: vec![domain],
add: vec![],
country_code: None,
emergency: true,
};
self.pending_updates.write().await.push(update);
}
/// Unblock a domain
pub async fn unblock_domain(&self, domain: &str) {
self.blocked_domains.write().await.remove(domain);
}
/// Get statistics
pub async fn stats(&self) -> SniStats {
let country_snis = self.country_snis.read().await;
let blocked = self.blocked_domains.read().await;
let total_domains: usize = country_snis.values().map(|v| v.len()).sum();
let countries = country_snis.len();
SniStats {
total_domains,
blocked_domains: blocked.len(),
countries,
pending_updates: self.pending_updates.read().await.len(),
}
}
}
impl Default for SniManager {
fn default() -> Self {
Self::new()
}
}
/// SNI statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SniStats {
pub total_domains: usize,
pub blocked_domains: usize,
pub countries: usize,
pub pending_updates: usize,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_sni_manager() {
let manager = SniManager::new();
manager.init_defaults().await;
let ru_snis = manager.get_snis_for_country("RU").await;
assert!(!ru_snis.is_empty());
assert!(ru_snis.contains(&"yandex.ru".to_string()));
}
#[tokio::test]
async fn test_block_domain() {
let manager = SniManager::new();
manager.init_defaults().await;
// Block a domain
manager.block_domain("yandex.ru".to_string()).await;
// Should no longer appear in active list
let ru_snis = manager.get_snis_for_country("RU").await;
assert!(!ru_snis.contains(&"yandex.ru".to_string()));
// Should have pending update
let updates = manager.get_pending_updates().await;
assert!(!updates.is_empty());
}
}