//! Shared context types for the peer system. use std::{ collections::{HashMap, HashSet}, net::SocketAddr, path::PathBuf, sync::Arc, }; use lanspread_db::db::GameDB; use tokio::sync::RwLock; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use crate::{PeerEvent, Unpacker, library::LocalLibraryState, peer_db::PeerGameDB}; /// Mutating filesystem operation currently in flight for a game root. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum OperationKind { /// Downloading or replacing archive files. Downloading, /// Extracting into a previously uninstalled game root. Installing, /// Replacing an existing `local/` install. Updating, /// Removing an existing `local/` install. Uninstalling, } /// Main context for the peer system. #[derive(Clone)] pub struct Ctx { pub game_dir: Arc>, pub local_game_db: Arc>>, pub local_library: Arc>, pub peer_game_db: Arc>, pub local_peer_addr: Arc>>, pub active_operations: Arc>>, pub active_downloads: Arc>>, pub unpacker: Arc, pub catalog: Arc>>, pub peer_id: Arc, pub enable_mdns: bool, pub shutdown: CancellationToken, pub task_tracker: TaskTracker, } /// Context for peer connection handling. #[derive(Clone)] pub struct PeerCtx { pub game_dir: Arc>, pub local_game_db: Arc>>, pub local_library: Arc>, pub local_peer_addr: Arc>>, pub active_operations: Arc>>, pub peer_game_db: Arc>, pub catalog: Arc>>, pub peer_id: Arc, pub enable_mdns: bool, pub tx_notify_ui: tokio::sync::mpsc::UnboundedSender, pub shutdown: CancellationToken, pub task_tracker: TaskTracker, } impl std::fmt::Debug for PeerCtx { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PeerCtx") .field("game_dir", &"...") .field("local_game_db", &"...") .field("local_peer_addr", &"...") .field("active_operations", &"...") .finish() } } impl Ctx { /// Creates a new context with the given peer game database. #[allow(clippy::too_many_arguments)] pub fn new( peer_game_db: Arc>, peer_id: String, game_dir: PathBuf, unpacker: Arc, shutdown: CancellationToken, task_tracker: TaskTracker, catalog: Arc>>, enable_mdns: bool, ) -> Self { Self { game_dir: Arc::new(RwLock::new(game_dir)), local_game_db: Arc::new(RwLock::new(None)), local_library: Arc::new(RwLock::new(LocalLibraryState::empty())), peer_game_db, local_peer_addr: Arc::new(RwLock::new(None)), active_operations: Arc::new(RwLock::new(HashMap::new())), active_downloads: Arc::new(RwLock::new(HashMap::new())), unpacker, catalog, peer_id: Arc::new(peer_id), enable_mdns, shutdown, task_tracker, } } /// Creates a `PeerCtx` from this context. pub fn to_peer_ctx( &self, tx_notify_ui: tokio::sync::mpsc::UnboundedSender, ) -> PeerCtx { PeerCtx { game_dir: self.game_dir.clone(), local_game_db: self.local_game_db.clone(), local_library: self.local_library.clone(), local_peer_addr: self.local_peer_addr.clone(), active_operations: self.active_operations.clone(), peer_game_db: self.peer_game_db.clone(), catalog: self.catalog.clone(), peer_id: self.peer_id.clone(), enable_mdns: self.enable_mdns, tx_notify_ui, shutdown: self.shutdown.clone(), task_tracker: self.task_tracker.clone(), } } } /// Removes operation tracking no matter how a task exits. pub(crate) struct OperationGuard { id: String, active_operations: Arc>>, active_downloads: Arc>>, clears_download: bool, armed: bool, } impl OperationGuard { pub(crate) fn new( id: String, active_operations: Arc>>, ) -> Self { Self { id, active_operations, active_downloads: Arc::new(RwLock::new(HashMap::new())), clears_download: false, armed: true, } } pub(crate) fn download( id: String, active_operations: Arc>>, active_downloads: Arc>>, ) -> Self { Self { id, active_operations, active_downloads, clears_download: true, armed: true, } } pub(crate) fn disarm(mut self) { self.armed = false; } } impl Drop for OperationGuard { fn drop(&mut self) { if !self.armed { return; } let id = self.id.clone(); log::error!( "Operation guard is cleaning up {id}; operation ended without explicit state cleanup" ); if let Ok(mut guard) = self.active_operations.try_write() { guard.remove(&id); } else if let Ok(handle) = tokio::runtime::Handle::try_current() { let active_operations = self.active_operations.clone(); handle.spawn({ let id = id.clone(); async move { active_operations.write().await.remove(&id); } }); } else { log::error!("Failed to clean operation state for {id}: no Tokio runtime"); } if !self.clears_download { return; } if let Ok(mut guard) = self.active_downloads.try_write() { guard.remove(&id); } else if let Ok(handle) = tokio::runtime::Handle::try_current() { let active_downloads = self.active_downloads.clone(); handle.spawn({ let id = id.clone(); async move { active_downloads.write().await.remove(&id); } }); } else { log::error!("Failed to clean active download state for {id}: no Tokio runtime"); } } } #[cfg(test)] mod tests { use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; use super::{OperationGuard, OperationKind}; type OperationTracking = ( Arc>>, Arc>>, CancellationToken, ); async fn wait_for_tracking_clear( id: &str, active_operations: &Arc>>, active_downloads: &Arc>>, ) { tokio::time::timeout(Duration::from_secs(1), async { loop { let operation_contains = active_operations.read().await.contains_key(id); let active_contains = active_downloads.read().await.contains_key(id); if !operation_contains && !active_contains { break; } tokio::task::yield_now().await; } }) .await .expect("download tracking should be cleared"); } fn tracked_download_state(id: &str) -> OperationTracking { let active_operations = Arc::new(RwLock::new(HashMap::from([( id.to_string(), OperationKind::Downloading, )]))); let cancel = CancellationToken::new(); let active_downloads = Arc::new(RwLock::new(HashMap::from([( id.to_string(), cancel.clone(), )]))); (active_operations, active_downloads, cancel) } #[tokio::test] async fn operation_guard_cleans_tracking_when_not_disarmed() { let id = "game-complete"; let (active_operations, active_downloads, _) = tracked_download_state(id); drop(OperationGuard::download( id.to_string(), active_operations.clone(), active_downloads.clone(), )); wait_for_tracking_clear(id, &active_operations, &active_downloads).await; } #[tokio::test] async fn operation_guard_cleans_tracking_after_cancellation() { let id = "game-cancelled"; let (active_operations, active_downloads, cancel) = tracked_download_state(id); cancel.cancel(); drop(OperationGuard::download( id.to_string(), active_operations.clone(), active_downloads.clone(), )); wait_for_tracking_clear(id, &active_operations, &active_downloads).await; } #[tokio::test] async fn disarmed_operation_guard_does_not_clean_tracking() { let id = "game-finished"; let (active_operations, active_downloads, _) = tracked_download_state(id); OperationGuard::download( id.to_string(), active_operations.clone(), active_downloads.clone(), ) .disarm(); assert!(active_operations.read().await.contains_key(id)); assert!(active_downloads.read().await.contains_key(id)); } #[tokio::test] async fn operation_guard_cleans_tracking_when_task_is_dropped() { let id = "game-aborted"; let (active_operations, active_downloads, _) = tracked_download_state(id); let (ready_tx, ready_rx) = tokio::sync::oneshot::channel(); let handle = tokio::spawn({ let active_operations = active_operations.clone(); let active_downloads = active_downloads.clone(); async move { let _guard = OperationGuard::download(id.to_string(), active_operations, active_downloads); let _ = ready_tx.send(()); std::future::pending::<()>().await; } }); ready_rx.await.expect("download guard should be created"); handle.abort(); let _ = handle.await; wait_for_tracking_clear(id, &active_operations, &active_downloads).await; } }