diff --git a/Cargo.lock b/Cargo.lock index 573bbc6..f7ca9b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2117,6 +2117,7 @@ dependencies = [ "s2n-quic", "serde", "serde_json", + "strum", "tokio", "tokio-util", "uuid", @@ -4191,6 +4192,27 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7695ce3845ea4b33927c055a39dc438a45b059f7c1b3d91d38d10355fb8cbca7" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "subtle" version = "2.6.1" @@ -4812,6 +4834,7 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", "pin-project-lite", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 2bc19f3..af95992 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,13 +28,14 @@ sqlx = { version = "0.8", default-features = false, features = [ "runtime-tokio", "sqlite", ] } +strum = { version = "0.27", features = ["derive"] } tauri = { version = "2", features = [] } tauri-plugin-dialog = "2" tauri-plugin-log = "2" tauri-plugin-shell = "2" tauri-plugin-store = "2" tokio = { version = "1", features = ["full"] } -tokio-util = { version = "0.7", features = ["codec"] } +tokio-util = { version = "0.7", features = ["codec", "rt"] } tracing = "0.1" uuid = { version = "1", features = ["v7"] } walkdir = "2" diff --git a/crates/lanspread-mdns/src/lib.rs b/crates/lanspread-mdns/src/lib.rs index b6fab05..af398d2 100644 --- a/crates/lanspread-mdns/src/lib.rs +++ b/crates/lanspread-mdns/src/lib.rs @@ -1,12 +1,17 @@ #![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)] -use std::{collections::HashMap, net::SocketAddr}; +use std::{ + collections::HashMap, + net::SocketAddr, + time::{Duration, Instant}, +}; use eyre::bail; pub use mdns_sd::DaemonEvent; -use mdns_sd::{Receiver, ServiceDaemon, ServiceEvent, ServiceInfo}; +use mdns_sd::{Receiver, ResolvedService, ServiceDaemon, ServiceEvent, ServiceInfo}; pub const LANSPREAD_SERVICE_TYPE: &str = "_lanspread._udp.local."; +pub type MdnsMonitor = Receiver; pub struct MdnsAdvertiser { daemon: ServiceDaemon, @@ -66,6 +71,13 @@ pub struct MdnsService { pub properties: HashMap, } +#[derive(Debug, Clone)] +pub enum MdnsServicePoll { + Service(MdnsService), + Timeout, + Closed, +} + impl MdnsBrowser { pub fn new(service_type: &str) -> eyre::Result { let daemon = ServiceDaemon::new()?; @@ -83,50 +95,10 @@ impl MdnsBrowser { ) -> eyre::Result> { loop { match self.receiver.recv() { - Ok(ServiceEvent::ServiceResolved(info)) => { - log::trace!("mdns ServiceResolved event: {info:?}"); - - if info.ty_domain != self.service_type { - log::trace!( - "Got mDNS with uninteresting service type: {} (expected: {})", - info.ty_domain, - self.service_type, - ); - continue; + Ok(event) => { + if let Some(service) = self.service_from_event(event, ignore_addr) { + return Ok(Some(service)); } - - let mut ignored_match = false; - for address in info.get_addresses() { - let addr = SocketAddr::new(address.to_ip_addr(), info.get_port()); - - if ignore_addr.is_some_and(|ignore| ignore == addr) { - ignored_match = true; - log::trace!("Ignoring mDNS advertisement for local server at {addr}"); - continue; - } - - log::info!("Found server at {addr}"); - let properties = info.get_properties().clone().into_property_map_str(); - return Ok(Some(MdnsService { - addr, - fullname: info.get_fullname().to_string(), - hostname: info.get_hostname().to_string(), - properties, - })); - } - - if ignored_match { - log::trace!( - "Only saw ignored mDNS advertisements (probably ourselves) for {:?}", - info.get_fullname() - ); - continue; - } - - log::error!("No address found in mDNS response: {info:?}"); - } - Ok(other_event) => { - log::trace!("mdns unrelated event: {other_event:?}"); } Err(err) => { log::error!("mDNS browse channel closed: {err}"); @@ -136,12 +108,105 @@ impl MdnsBrowser { } } + pub fn next_service_timeout( + &self, + ignore_addr: Option, + timeout: Duration, + ) -> eyre::Result { + let deadline = Instant::now() + timeout; + + loop { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return Ok(MdnsServicePoll::Timeout); + } + + match self.receiver.recv_timeout(remaining) { + Ok(event) => { + if let Some(service) = self.service_from_event(event, ignore_addr) { + return Ok(MdnsServicePoll::Service(service)); + } + } + Err(err) if self.receiver.is_disconnected() => { + log::error!("mDNS browse channel closed: {err}"); + return Ok(MdnsServicePoll::Closed); + } + Err(err) => { + log::trace!("mDNS browse timeout: {err}"); + return Ok(MdnsServicePoll::Timeout); + } + } + } + } + pub fn next_address( &self, ignore_addr: Option, ) -> eyre::Result> { Ok(self.next_service(ignore_addr)?.map(|service| service.addr)) } + + fn service_from_event( + &self, + event: ServiceEvent, + ignore_addr: Option, + ) -> Option { + match event { + ServiceEvent::ServiceResolved(info) => self.service_from_resolved(&info, ignore_addr), + other_event => { + log::trace!("mdns unrelated event: {other_event:?}"); + None + } + } + } + + fn service_from_resolved( + &self, + info: &ResolvedService, + ignore_addr: Option, + ) -> Option { + log::trace!("mdns ServiceResolved event: {info:?}"); + + if info.ty_domain != self.service_type { + log::trace!( + "Got mDNS with uninteresting service type: {} (expected: {})", + info.ty_domain, + self.service_type, + ); + return None; + } + + let mut ignored_match = false; + for address in info.get_addresses() { + let addr = SocketAddr::new(address.to_ip_addr(), info.get_port()); + + if ignore_addr.is_some_and(|ignore| ignore == addr) { + ignored_match = true; + log::trace!("Ignoring mDNS advertisement for local server at {addr}"); + continue; + } + + log::info!("Found server at {addr}"); + let properties = info.get_properties().clone().into_property_map_str(); + return Some(MdnsService { + addr, + fullname: info.get_fullname().to_string(), + hostname: info.get_hostname().to_string(), + properties, + }); + } + + if ignored_match { + log::trace!( + "Only saw ignored mDNS advertisements (probably ourselves) for {:?}", + info.get_fullname() + ); + return None; + } + + log::error!("No address found in mDNS response: {info:?}"); + None + } } impl Drop for MdnsBrowser { diff --git a/crates/lanspread-peer/Cargo.toml b/crates/lanspread-peer/Cargo.toml index bb87534..a4060b4 100644 --- a/crates/lanspread-peer/Cargo.toml +++ b/crates/lanspread-peer/Cargo.toml @@ -24,6 +24,7 @@ log = { workspace = true } s2n-quic = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +strum = { workspace = true } tokio = { workspace = true } tokio-util = { workspace = true } uuid = { workspace = true } diff --git a/crates/lanspread-peer/src/context.rs b/crates/lanspread-peer/src/context.rs index 6f070b9..0eca907 100644 --- a/crates/lanspread-peer/src/context.rs +++ b/crates/lanspread-peer/src/context.rs @@ -8,7 +8,8 @@ use std::{ }; use lanspread_db::db::GameDB; -use tokio::{sync::RwLock, task::JoinHandle}; +use tokio::sync::RwLock; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use crate::{PeerEvent, library::LocalLibraryState, peer_db::PeerGameDB}; @@ -21,8 +22,10 @@ pub struct Ctx { pub peer_game_db: Arc>, pub local_peer_addr: Arc>>, pub downloading_games: Arc>>, - pub active_downloads: Arc>>>, + pub active_downloads: Arc>>, pub peer_id: Arc, + pub shutdown: CancellationToken, + pub task_tracker: TaskTracker, } /// Context for peer connection handling. @@ -36,6 +39,8 @@ pub struct PeerCtx { pub peer_game_db: Arc>, pub peer_id: Arc, pub tx_notify_ui: tokio::sync::mpsc::UnboundedSender, + pub shutdown: CancellationToken, + pub task_tracker: TaskTracker, } impl std::fmt::Debug for PeerCtx { @@ -51,7 +56,13 @@ impl std::fmt::Debug for PeerCtx { impl Ctx { /// Creates a new context with the given peer game database. - pub fn new(peer_game_db: Arc>, peer_id: String, game_dir: PathBuf) -> Self { + pub fn new( + peer_game_db: Arc>, + peer_id: String, + game_dir: PathBuf, + shutdown: CancellationToken, + task_tracker: TaskTracker, + ) -> Self { Self { game_dir: Arc::new(RwLock::new(game_dir)), local_game_db: Arc::new(RwLock::new(None)), @@ -61,6 +72,8 @@ impl Ctx { downloading_games: Arc::new(RwLock::new(HashSet::new())), active_downloads: Arc::new(RwLock::new(HashMap::new())), peer_id: Arc::new(peer_id), + shutdown, + task_tracker, } } @@ -78,6 +91,164 @@ impl Ctx { peer_game_db: self.peer_game_db.clone(), peer_id: self.peer_id.clone(), tx_notify_ui, + shutdown: self.shutdown.clone(), + task_tracker: self.task_tracker.clone(), } } } + +/// Removes download tracking no matter how a download task exits. +pub(crate) struct DownloadStateGuard { + id: String, + downloading_games: Arc>>, + active_downloads: Arc>>, +} + +impl DownloadStateGuard { + pub(crate) fn new( + id: String, + downloading_games: Arc>>, + active_downloads: Arc>>, + ) -> Self { + Self { + id, + downloading_games, + active_downloads, + } + } +} + +impl Drop for DownloadStateGuard { + fn drop(&mut self) { + let id = self.id.clone(); + if let Ok(mut guard) = self.downloading_games.try_write() { + guard.remove(&id); + } else if let Ok(handle) = tokio::runtime::Handle::try_current() { + let downloading_games = self.downloading_games.clone(); + handle.spawn({ + let id = id.clone(); + async move { + downloading_games.write().await.remove(&id); + } + }); + } else { + log::error!("Failed to clean downloading state for {id}: no Tokio runtime"); + } + + if let Ok(mut guard) = self.active_downloads.try_write() { + guard.remove(&id); + } else if let Ok(handle) = tokio::runtime::Handle::try_current() { + let active_downloads = self.active_downloads.clone(); + handle.spawn({ + let id = id.clone(); + async move { + active_downloads.write().await.remove(&id); + } + }); + } else { + log::error!("Failed to clean active download state for {id}: no Tokio runtime"); + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + time::Duration, + }; + + use tokio::sync::RwLock; + use tokio_util::sync::CancellationToken; + + use super::DownloadStateGuard; + + type DownloadTracking = ( + Arc>>, + Arc>>, + CancellationToken, + ); + + async fn wait_for_tracking_clear( + id: &str, + downloading_games: &Arc>>, + active_downloads: &Arc>>, + ) { + tokio::time::timeout(Duration::from_secs(1), async { + loop { + let downloading_contains = downloading_games.read().await.contains(id); + let active_contains = active_downloads.read().await.contains_key(id); + if !downloading_contains && !active_contains { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("download tracking should be cleared"); + } + + fn tracked_download_state(id: &str) -> DownloadTracking { + let downloading_games = Arc::new(RwLock::new(HashSet::from([id.to_string()]))); + let cancel = CancellationToken::new(); + let active_downloads = Arc::new(RwLock::new(HashMap::from([( + id.to_string(), + cancel.clone(), + )]))); + (downloading_games, active_downloads, cancel) + } + + #[tokio::test] + async fn download_state_guard_clears_tracking_on_completion() { + let id = "game-complete"; + let (downloading_games, active_downloads, _) = tracked_download_state(id); + + drop(DownloadStateGuard::new( + id.to_string(), + downloading_games.clone(), + active_downloads.clone(), + )); + + wait_for_tracking_clear(id, &downloading_games, &active_downloads).await; + } + + #[tokio::test] + async fn download_state_guard_clears_tracking_after_cancellation() { + let id = "game-cancelled"; + let (downloading_games, active_downloads, cancel) = tracked_download_state(id); + cancel.cancel(); + + drop(DownloadStateGuard::new( + id.to_string(), + downloading_games.clone(), + active_downloads.clone(), + )); + + wait_for_tracking_clear(id, &downloading_games, &active_downloads).await; + } + + #[tokio::test] + async fn download_state_guard_clears_tracking_when_task_is_dropped() { + let id = "game-aborted"; + let (downloading_games, active_downloads, _) = tracked_download_state(id); + let (ready_tx, ready_rx) = tokio::sync::oneshot::channel(); + + let handle = tokio::spawn({ + let downloading_games = downloading_games.clone(); + let active_downloads = active_downloads.clone(); + async move { + let _guard = + DownloadStateGuard::new(id.to_string(), downloading_games, active_downloads); + let _ = ready_tx.send(()); + std::future::pending::<()>().await; + } + }); + + ready_rx.await.expect("download guard should be created"); + handle.abort(); + let _ = handle.await; + + wait_for_tracking_clear(id, &downloading_games, &active_downloads).await; + } +} diff --git a/crates/lanspread-peer/src/download.rs b/crates/lanspread-peer/src/download.rs index 8e9aca9..215a4cf 100644 --- a/crates/lanspread-peer/src/download.rs +++ b/crates/lanspread-peer/src/download.rs @@ -12,7 +12,10 @@ use tokio::{ io::{AsyncSeekExt, AsyncWriteExt}, sync::mpsc::UnboundedSender, }; -use tokio_util::codec::{FramedWrite, LengthDelimitedCodec}; +use tokio_util::{ + codec::{FramedWrite, LengthDelimitedCodec}, + sync::CancellationToken, +}; use crate::{ PeerEvent, @@ -50,6 +53,16 @@ pub struct ChunkDownloadResult { pub peer_addr: SocketAddr, } +fn ensure_download_not_cancelled( + cancel_token: &CancellationToken, + game_id: &str, +) -> eyre::Result<()> { + if cancel_token.is_cancelled() { + eyre::bail!("download cancelled for game {game_id}"); + } + Ok(()) +} + // ============================================================================= // Storage preparation // ============================================================================= @@ -315,11 +328,14 @@ pub async fn download_from_peer( game_id: &str, plan: PeerDownloadPlan, games_folder: PathBuf, + cancel_token: &CancellationToken, ) -> eyre::Result> { if plan.chunks.is_empty() && plan.whole_files.is_empty() { return Ok(Vec::new()); } + ensure_download_not_cancelled(cancel_token, game_id)?; + let mut conn = connect_to_peer(peer_addr).await?; conn.keep_alive(true)?; conn.keep_alive(true)?; @@ -329,6 +345,8 @@ pub async fn download_from_peer( // Download chunks with error handling for chunk in &plan.chunks { + ensure_download_not_cancelled(cancel_token, game_id)?; + log::info!( "Downloading chunk {} (offset {}, length {}) from {}", chunk.relative_path, @@ -346,6 +364,8 @@ pub async fn download_from_peer( // Download whole files for desc in &plan.whole_files { + ensure_download_not_cancelled(cancel_token, game_id)?; + let chunk = DownloadChunk { relative_path: desc.relative_path.clone(), offset: 0, @@ -404,11 +424,16 @@ pub async fn retry_failed_chunks( base_dir: &Path, game_id: &str, file_peer_map: &HashMap>, -) -> Vec { + cancel_token: &CancellationToken, +) -> eyre::Result> { let mut exhausted = Vec::new(); let mut queue: VecDeque = failed_chunks.into_iter().collect(); while let Some(mut chunk) = queue.pop_front() { + if cancel_token.is_cancelled() { + return Ok(exhausted); + } + let eligible_peers = resolve_file_peers(&chunk.relative_path, file_peer_map, peers); if chunk.retry_count >= MAX_RETRY_COUNT { @@ -445,8 +470,20 @@ pub async fn retry_failed_chunks( whole_files: Vec::new(), }; - match download_from_peer(peer_addr, game_id, plan, base_dir.to_path_buf()).await { + match download_from_peer( + peer_addr, + game_id, + plan, + base_dir.to_path_buf(), + cancel_token, + ) + .await + { Ok(results) => { + if cancel_token.is_cancelled() { + return Ok(exhausted); + } + for result in results { match result.result { Ok(()) => {} @@ -473,6 +510,10 @@ pub async fn retry_failed_chunks( } } Err(e) => { + if cancel_token.is_cancelled() { + return Ok(exhausted); + } + chunk.retry_count += 1; chunk.last_peer = Some(peer_addr); @@ -492,7 +533,7 @@ pub async fn retry_failed_chunks( } } - exhausted + Ok(exhausted) } // ============================================================================= @@ -500,6 +541,7 @@ pub async fn retry_failed_chunks( // ============================================================================= /// Downloads all game files from available peers. +#[allow(clippy::too_many_lines)] pub async fn download_game_files( game_id: &str, game_file_descs: Vec, @@ -507,12 +549,20 @@ pub async fn download_game_files( peers: Vec, file_peer_map: HashMap>, tx_notify_ui: UnboundedSender, + cancel_token: CancellationToken, ) -> eyre::Result<()> { if peers.is_empty() { eyre::bail!("no peers available for game {game_id}"); } + if cancel_token.is_cancelled() { + return Ok(()); + } + prepare_game_storage(&games_folder, &game_file_descs).await?; + if cancel_token.is_cancelled() { + return Ok(()); + } tx_notify_ui.send(PeerEvent::DownloadGameFilesBegin { id: game_id.to_string(), @@ -524,8 +574,9 @@ pub async fn download_game_files( for (peer_addr, plan) in plans { let base_dir = games_folder.clone(); let game_id = game_id.to_string(); + let cancel_token = cancel_token.clone(); tasks.push(tokio::spawn(async move { - download_from_peer(peer_addr, &game_id, plan, base_dir).await + download_from_peer(peer_addr, &game_id, plan, base_dir, &cancel_token).await })); } @@ -533,8 +584,16 @@ pub async fn download_game_files( let mut last_err: Option = None; for handle in tasks { + if cancel_token.is_cancelled() { + return Ok(()); + } + match handle.await { Ok(Ok(results)) => { + if cancel_token.is_cancelled() { + return Ok(()); + } + for chunk_result in results { if let Err(e) = chunk_result.result { log::warn!( @@ -555,6 +614,7 @@ pub async fn download_game_files( } } } + Ok(Err(_)) | Err(_) if cancel_token.is_cancelled() => return Ok(()), Ok(Err(e)) => last_err = Some(e), Err(e) => last_err = Some(eyre::eyre!("task join error: {e}")), } @@ -562,18 +622,35 @@ pub async fn download_game_files( // Retry failed chunks if any if !failed_chunks.is_empty() && !peers.is_empty() { + if cancel_token.is_cancelled() { + return Ok(()); + } + log::info!("Retrying {} failed chunks", failed_chunks.len()); - let retry_results = retry_failed_chunks( + let retry_results = match retry_failed_chunks( failed_chunks, &peers, &games_folder, game_id, &file_peer_map, + &cancel_token, ) - .await; + .await + { + Ok(results) => results, + Err(_) if cancel_token.is_cancelled() => return Ok(()), + Err(err) => { + last_err = Some(err); + Vec::new() + } + }; for chunk_result in retry_results { + if cancel_token.is_cancelled() { + return Ok(()); + } + if let Err(e) = chunk_result.result { log::error!("Retry failed for chunk: {e}"); last_err = Some(e); @@ -581,6 +658,10 @@ pub async fn download_game_files( } } + if cancel_token.is_cancelled() { + return Ok(()); + } + if let Some(err) = last_err { tx_notify_ui.send(PeerEvent::DownloadGameFilesFailed { id: game_id.to_string(), diff --git a/crates/lanspread-peer/src/events.rs b/crates/lanspread-peer/src/events.rs index e637272..6166cae 100644 --- a/crates/lanspread-peer/src/events.rs +++ b/crates/lanspread-peer/src/events.rs @@ -6,9 +6,10 @@ use tokio::sync::{RwLock, mpsc::UnboundedSender}; use crate::{PeerEvent, peer_db::PeerGameDB}; -pub fn send(tx_notify_ui: &UnboundedSender, event: PeerEvent, label: &str) { +pub fn send(tx_notify_ui: &UnboundedSender, event: PeerEvent) { if let Err(err) = tx_notify_ui.send(event) { - log::error!("Failed to send {label} event: {err}"); + let kind: &'static str = (&err.0).into(); + log::error!("Failed to send {kind} event: channel closed"); } } @@ -17,7 +18,7 @@ pub async fn emit_peer_game_list( tx_notify_ui: &UnboundedSender, ) { let games = { peer_game_db.read().await.get_all_games() }; - send(tx_notify_ui, PeerEvent::ListGames(games), "ListGames"); + send(tx_notify_ui, PeerEvent::ListGames(games)); } pub async fn emit_peer_count( @@ -25,11 +26,7 @@ pub async fn emit_peer_count( tx_notify_ui: &UnboundedSender, ) { let peer_count = { peer_game_db.read().await.get_peer_addresses().len() }; - send( - tx_notify_ui, - PeerEvent::PeerCountUpdated(peer_count), - "PeerCountUpdated", - ); + send(tx_notify_ui, PeerEvent::PeerCountUpdated(peer_count)); } pub async fn emit_peer_discovered( @@ -37,11 +34,7 @@ pub async fn emit_peer_discovered( tx_notify_ui: &UnboundedSender, peer_addr: SocketAddr, ) { - send( - tx_notify_ui, - PeerEvent::PeerDiscovered(peer_addr), - "PeerDiscovered", - ); + send(tx_notify_ui, PeerEvent::PeerDiscovered(peer_addr)); emit_peer_count(peer_game_db, tx_notify_ui).await; } @@ -50,6 +43,6 @@ pub async fn emit_peer_lost( tx_notify_ui: &UnboundedSender, peer_addr: SocketAddr, ) { - send(tx_notify_ui, PeerEvent::PeerLost(peer_addr), "PeerLost"); + send(tx_notify_ui, PeerEvent::PeerLost(peer_addr)); emit_peer_count(peer_game_db, tx_notify_ui).await; } diff --git a/crates/lanspread-peer/src/handlers.rs b/crates/lanspread-peer/src/handlers.rs index f4d0b83..97c1b36 100644 --- a/crates/lanspread-peer/src/handlers.rs +++ b/crates/lanspread-peer/src/handlers.rs @@ -7,7 +7,7 @@ use tokio::sync::{RwLock, mpsc::UnboundedSender}; use crate::{ PeerEvent, - context::Ctx, + context::{Ctx, DownloadStateGuard}, download::download_game_files, events, identity::FEATURE_LIBRARY_DELTA, @@ -86,7 +86,7 @@ pub async fn handle_get_game_command( let peer_game_db = ctx.peer_game_db.clone(); let tx_notify_ui = tx_notify_ui.clone(); - tokio::spawn(async move { + ctx.task_tracker.spawn(async move { let mut fetched_any = false; for peer_addr in peers { match request_game_details_and_update(peer_addr, &id, peer_game_db.clone()).await { @@ -221,8 +221,17 @@ pub async fn handle_download_game_files_command( let active_downloads = ctx.active_downloads.clone(); let tx_notify_ui_clone = tx_notify_ui.clone(); let download_id = id.clone(); + let cancel_token = ctx.shutdown.child_token(); + + ctx.active_downloads + .write() + .await + .insert(id, cancel_token.clone()); + + ctx.task_tracker.spawn(async move { + let _download_state_guard = + DownloadStateGuard::new(download_id.clone(), downloading_games, active_downloads); - let handle = tokio::spawn(async move { let result = download_game_files( &download_id, resolved_descriptions, @@ -230,27 +239,14 @@ pub async fn handle_download_game_files_command( peer_whitelist, file_peer_map, tx_notify_ui_clone.clone(), + cancel_token, ) .await; - { - let mut guard = downloading_games.write().await; - guard.remove(&download_id); - } - if let Err(e) = result { log::error!("Download failed for {download_id}: {e}"); - if let Err(send_err) = tx_notify_ui_clone.send(PeerEvent::DownloadGameFilesFailed { - id: download_id.clone(), - }) { - log::error!("Failed to send DownloadGameFilesFailed event: {send_err}"); - } } - - let _ = active_downloads.write().await.remove(&download_id); }); - - ctx.active_downloads.write().await.insert(id, handle); } /// Handles the `SetGameDir` command. @@ -265,7 +261,7 @@ pub async fn handle_set_game_dir_command( let tx_notify_ui = tx_notify_ui.clone(); let ctx_clone = ctx.clone(); - tokio::spawn(async move { + ctx.task_tracker.spawn(async move { match load_local_library(&ctx_clone, &tx_notify_ui).await { Ok(()) => log::info!("Local game database loaded successfully"), Err(e) => { @@ -345,14 +341,14 @@ pub async fn update_and_announce_games( .any(|feature| feature == FEATURE_LIBRARY_DELTA) { let delta = delta.clone(); - tokio::spawn(async move { + ctx.task_tracker.spawn(async move { if let Err(e) = send_library_delta(peer_addr, delta).await { log::warn!("Failed to send library delta to {peer_addr}: {e}"); } }); } else { let games_clone = all_games.clone(); - tokio::spawn(async move { + ctx.task_tracker.spawn(async move { if let Err(e) = announce_games_to_peer(peer_addr, games_clone).await { log::warn!("Failed to announce games to {peer_addr}: {e}"); } diff --git a/crates/lanspread-peer/src/lib.rs b/crates/lanspread-peer/src/lib.rs index a08e76c..2ac864f 100644 --- a/crates/lanspread-peer/src/lib.rs +++ b/crates/lanspread-peer/src/lib.rs @@ -43,7 +43,9 @@ use tokio::sync::{ RwLock, mpsc::{UnboundedReceiver, UnboundedSender}, }; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; +pub use crate::startup::PeerRuntimeHandle; use crate::{ context::Ctx, handlers::{ @@ -61,7 +63,7 @@ use crate::{ // ============================================================================= /// Events sent from the peer system to the UI. -#[derive(Debug)] +#[derive(Debug, strum::IntoStaticStr)] pub enum PeerEvent { /// List of available games from peers. ListGames(Vec), @@ -92,6 +94,26 @@ pub enum PeerEvent { PeerCountUpdated(usize), /// Local games have been updated. LocalGamesUpdated(Vec), + /// A required peer runtime component failed. + RuntimeFailed { + component: PeerRuntimeComponent, + error: String, + }, +} + +/// Long-running peer runtime components reported in failure events. +#[derive(Clone, Copy, Debug, strum::IntoStaticStr)] +pub enum PeerRuntimeComponent { + /// Command/control message loop. + CommandLoop, + /// Inbound QUIC server and its mDNS advertisement. + QuicServer, + /// mDNS peer discovery. + Discovery, + /// Peer liveness monitoring. + Liveness, + /// Local game directory monitor. + LocalMonitor, } /// Commands sent to the peer system from the UI. @@ -119,23 +141,20 @@ pub enum PeerCommand { /// Initialize and start the peer system. /// /// This is the main entry point for the peer system. It starts all background -/// services (server, discovery, ping, local monitor) and returns a channel -/// for sending commands. +/// services (server, discovery, ping, local monitor) and returns a handle that +/// owns the command sender plus a shutdown signal callers can use for clean +/// teardown. /// /// # Arguments /// /// * `game_dir` - Path to the local game directory /// * `tx_notify_ui` - Channel for sending events to the UI /// * `peer_game_db` - Shared peer game database -/// -/// # Returns -/// -/// A channel sender for sending commands to the peer system. pub fn start_peer( game_dir: impl Into, tx_notify_ui: UnboundedSender, peer_game_db: Arc>, -) -> eyre::Result> { +) -> eyre::Result { let game_dir = game_dir.into(); log::info!( "Starting peer system with game directory: {}", @@ -145,9 +164,14 @@ pub fn start_peer( let (tx_control, rx_control) = tokio::sync::mpsc::unbounded_channel(); - startup::spawn_peer_runtime(rx_control, tx_notify_ui, peer_game_db, peer_id, game_dir); - - Ok(tx_control) + Ok(startup::spawn_peer_runtime( + tx_control, + rx_control, + tx_notify_ui, + peer_game_db, + peer_id, + game_dir, + )) } /// Main peer execution loop that handles peer commands and manages the peer system. @@ -157,13 +181,26 @@ async fn run_peer( peer_game_db: Arc>, peer_id: String, game_dir: PathBuf, + shutdown: CancellationToken, + task_tracker: TaskTracker, ) -> eyre::Result<()> { - let ctx = Ctx::new(peer_game_db, peer_id, game_dir); + let ctx = Ctx::new(peer_game_db, peer_id, game_dir, shutdown, task_tracker); if let Err(err) = load_local_library(&ctx, &tx_notify_ui).await { log::error!("Failed to load initial local game database: {err}"); } startup::spawn_startup_services(&ctx, &tx_notify_ui); - handle_peer_commands(&ctx, &tx_notify_ui, &mut rx_control).await; + if let Err(err) = handle_peer_commands(&ctx, &tx_notify_ui, &mut rx_control).await { + let error = err.to_string(); + log::error!("Peer command loop failed: {error}"); + events::send( + &tx_notify_ui, + PeerEvent::RuntimeFailed { + component: PeerRuntimeComponent::CommandLoop, + error, + }, + ); + ctx.shutdown.cancel(); + } startup::send_goodbye_notifications(&ctx).await; Ok(()) @@ -173,8 +210,20 @@ async fn handle_peer_commands( ctx: &Ctx, tx_notify_ui: &UnboundedSender, rx_control: &mut UnboundedReceiver, -) { - while let Some(cmd) = rx_control.recv().await { +) -> eyre::Result<()> { + loop { + let cmd = tokio::select! { + () = ctx.shutdown.cancelled() => return Ok(()), + cmd = rx_control.recv() => cmd, + }; + + let Some(cmd) = cmd else { + if ctx.shutdown.is_cancelled() { + return Ok(()); + } + eyre::bail!("peer command channel closed unexpectedly"); + }; + match cmd { PeerCommand::ListGames => { handle_list_games_command(ctx, tx_notify_ui).await; diff --git a/crates/lanspread-peer/src/services/advertise.rs b/crates/lanspread-peer/src/services/advertise.rs index 620c909..e504154 100644 --- a/crates/lanspread-peer/src/services/advertise.rs +++ b/crates/lanspread-peer/src/services/advertise.rs @@ -2,15 +2,16 @@ use std::{collections::HashMap, net::SocketAddr, time::Duration}; -use lanspread_mdns::{LANSPREAD_SERVICE_TYPE, MdnsAdvertiser}; +use lanspread_mdns::{DaemonEvent, LANSPREAD_SERVICE_TYPE, MdnsAdvertiser, MdnsMonitor}; use lanspread_proto::PROTOCOL_VERSION; +use tokio_util::sync::CancellationToken; use crate::{context::PeerCtx, network::select_advertise_ip}; pub(super) async fn start_mdns_advertiser( ctx: &PeerCtx, server_addr: SocketAddr, -) -> eyre::Result<()> { +) -> eyre::Result { let advertise_ip = select_advertise_ip()?; let advertise_addr = SocketAddr::new(advertise_ip, server_addr.port()); log::info!("Advertising peer via mDNS from {advertise_addr}"); @@ -36,22 +37,34 @@ pub(super) async fn start_mdns_advertiser( }) .await??; - tokio::spawn(async move { - log::info!("Registered mDNS service with name: {monitor_name}"); - while let Ok(event) = mdns.monitor.recv() { - match event { - lanspread_mdns::DaemonEvent::Error(err) => { - log::error!("mDNS error: {err}"); - tokio::time::sleep(Duration::from_secs(1)).await; - } - _ => { - log::trace!("mDNS event: {event:?}"); + log::info!("Registered mDNS service with name: {monitor_name}"); + Ok(mdns) +} + +pub(super) async fn monitor_mdns_events(monitor: MdnsMonitor, shutdown: CancellationToken) { + loop { + let event = tokio::select! { + () = shutdown.cancelled() => break, + event = monitor.recv_async() => event, + }; + + match event { + Ok(DaemonEvent::Error(err)) => { + log::error!("mDNS error: {err}"); + tokio::select! { + () = shutdown.cancelled() => break, + () = tokio::time::sleep(Duration::from_secs(1)) => {} } } + Ok(other_event) => { + log::trace!("mDNS event: {other_event:?}"); + } + Err(err) => { + log::debug!("mDNS monitor channel closed: {err}"); + break; + } } - }); - - Ok(()) + } } fn advertised_service_name(hostname: &str, peer_id: &str) -> String { diff --git a/crates/lanspread-peer/src/services/discovery.rs b/crates/lanspread-peer/src/services/discovery.rs index 6291f88..407f238 100644 --- a/crates/lanspread-peer/src/services/discovery.rs +++ b/crates/lanspread-peer/src/services/discovery.rs @@ -2,7 +2,7 @@ use std::time::Duration; -use lanspread_mdns::{LANSPREAD_SERVICE_TYPE, MdnsBrowser, MdnsService}; +use lanspread_mdns::{LANSPREAD_SERVICE_TYPE, MdnsBrowser, MdnsService, MdnsServicePoll}; use lanspread_proto::PROTOCOL_VERSION; use tokio::sync::mpsc::UnboundedSender; @@ -23,54 +23,73 @@ struct MdnsPeerInfo { } /// Runs the peer discovery service using mDNS. -pub async fn run_peer_discovery(tx_notify_ui: UnboundedSender, ctx: Ctx) { +pub async fn run_peer_discovery( + tx_notify_ui: UnboundedSender, + ctx: Ctx, +) -> eyre::Result<()> { log::info!("Starting peer discovery task"); let service_type = LANSPREAD_SERVICE_TYPE.to_string(); + let (service_tx, mut service_rx) = tokio::sync::mpsc::unbounded_channel(); + let worker_shutdown = ctx.shutdown.clone(); + let service_type_clone = service_type.clone(); - loop { - let (service_tx, mut service_rx) = tokio::sync::mpsc::unbounded_channel(); - let service_type_clone = service_type.clone(); - - let worker_handle = tokio::task::spawn_blocking(move || -> eyre::Result<()> { + let worker_handle = ctx + .task_tracker + .spawn_blocking(move || -> eyre::Result<()> { let browser = MdnsBrowser::new(&service_type_clone)?; - loop { - if let Some(service) = browser.next_service(None)? { - if service_tx.send(service).is_err() { - log::debug!("Peer discovery consumer dropped; stopping worker"); + while !worker_shutdown.is_cancelled() { + match browser.next_service_timeout(None, Duration::from_millis(250))? { + MdnsServicePoll::Service(service) => { + if service_tx.send(service).is_err() { + log::debug!("Peer discovery consumer dropped; stopping worker"); + break; + } + } + MdnsServicePoll::Timeout => {} + MdnsServicePoll::Closed => { + log::warn!("mDNS browser closed; stopping peer discovery worker"); break; } - } else { - log::warn!("mDNS browser closed; stopping peer discovery worker"); - break; } } Ok(()) }); - while let Some(service) = service_rx.recv().await { - let info = parse_mdns_peer(&service); - if is_self_advertisement(&info, &ctx).await { - log::trace!("Ignoring self advertisement at {}", info.addr); - continue; - } + loop { + tokio::select! { + () = ctx.shutdown.cancelled() => break, + service = service_rx.recv() => { + let Some(service) = service else { + break; + }; - handle_discovered_peer(info, &ctx, &tx_notify_ui).await; - } + let info = parse_mdns_peer(&service); + if is_self_advertisement(&info, &ctx).await { + log::trace!("Ignoring self advertisement at {}", info.addr); + continue; + } - match worker_handle.await { - Ok(Ok(())) => { - log::warn!("Peer discovery worker exited; restarting shortly"); - } - Ok(Err(err)) => { - log::error!("Peer discovery worker failed: {err}"); - } - Err(err) => { - log::error!("Peer discovery worker join error: {err}"); + handle_discovered_peer(info, &ctx, &tx_notify_ui).await; } } + } - tokio::time::sleep(Duration::from_secs(5)).await; + match worker_handle.await { + Ok(Ok(())) if ctx.shutdown.is_cancelled() => Ok(()), + Ok(Ok(())) => { + eyre::bail!("mDNS discovery worker exited unexpectedly"); + } + Ok(Err(err)) if ctx.shutdown.is_cancelled() => { + log::debug!("Peer discovery worker stopped during shutdown: {err}"); + Ok(()) + } + Ok(Err(err)) => Err(err.wrap_err("peer discovery worker failed")), + Err(err) if ctx.shutdown.is_cancelled() => { + log::debug!("Peer discovery worker join ended during shutdown: {err}"); + Ok(()) + } + Err(err) => Err(eyre::eyre!("peer discovery worker join error: {err}")), } } @@ -146,7 +165,7 @@ fn spawn_protocol_negotiation( let local_library = ctx.local_library.clone(); let peer_game_db = ctx.peer_game_db.clone(); - tokio::spawn(async move { + ctx.task_tracker.spawn(async move { let handshake_result = if proto_ver.is_none() || proto_ver == Some(PROTOCOL_VERSION) { perform_handshake_with_peer( peer_id_arc, diff --git a/crates/lanspread-peer/src/services/legacy.rs b/crates/lanspread-peer/src/services/legacy.rs index b3c7504..e1de4a5 100644 --- a/crates/lanspread-peer/src/services/legacy.rs +++ b/crates/lanspread-peer/src/services/legacy.rs @@ -31,11 +31,7 @@ pub(super) async fn request_games_from_peer( } let aggregated_games = update_peer_from_game_list(&peer_game_db, peer_addr, &games).await; - events::send( - &tx_notify_ui, - PeerEvent::ListGames(aggregated_games), - "ListGames", - ); + events::send(&tx_notify_ui, PeerEvent::ListGames(aggregated_games)); return Ok(()); } } diff --git a/crates/lanspread-peer/src/services/liveness.rs b/crates/lanspread-peer/src/services/liveness.rs index a5f4ce3..0b45b10 100644 --- a/crates/lanspread-peer/src/services/liveness.rs +++ b/crates/lanspread-peer/src/services/liveness.rs @@ -6,10 +6,8 @@ use std::{ time::Duration, }; -use tokio::{ - sync::{RwLock, mpsc::UnboundedSender}, - task::JoinHandle, -}; +use tokio::sync::{RwLock, mpsc::UnboundedSender}; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use crate::{ PeerEvent, @@ -24,8 +22,10 @@ pub async fn run_ping_service( tx_notify_ui: UnboundedSender, peer_game_db: Arc>, downloading_games: Arc>>, - active_downloads: Arc>>>, -) { + active_downloads: Arc>>, + shutdown: CancellationToken, + task_tracker: TaskTracker, +) -> eyre::Result<()> { log::info!( "Starting ping service ({PEER_PING_INTERVAL_SECS}s interval, \ {}s idle threshold, {}s timeout)", @@ -36,12 +36,18 @@ pub async fn run_ping_service( let mut interval = tokio::time::interval(Duration::from_secs(PEER_PING_INTERVAL_SECS)); loop { - interval.tick().await; + tokio::select! { + () = shutdown.cancelled() => return Ok(()), + _ = interval.tick() => {} + } + ping_idle_peers( &peer_game_db, &downloading_games, &active_downloads, &tx_notify_ui, + &shutdown, + &task_tracker, ) .await; @@ -58,8 +64,10 @@ pub async fn run_ping_service( async fn ping_idle_peers( peer_game_db: &Arc>, downloading_games: &Arc>>, - active_downloads: &Arc>>>, + active_downloads: &Arc>>, tx_notify_ui: &UnboundedSender, + shutdown: &CancellationToken, + task_tracker: &TaskTracker, ) { let peer_snapshots = { peer_game_db.read().await.peer_liveness_snapshot() }; @@ -72,9 +80,15 @@ async fn ping_idle_peers( let peer_game_db = peer_game_db.clone(); let downloading_games = downloading_games.clone(); let active_downloads = active_downloads.clone(); + let shutdown = shutdown.clone(); - tokio::spawn(async move { - match ping_peer(peer_addr).await { + task_tracker.spawn(async move { + let ping_result = tokio::select! { + () = shutdown.cancelled() => return, + result = ping_peer(peer_addr) => result, + }; + + match ping_result { Ok(true) => { peer_game_db.write().await.update_last_seen(&peer_id); } @@ -110,7 +124,7 @@ async fn ping_idle_peers( async fn prune_stale_peers( peer_game_db: &Arc>, downloading_games: &Arc>>, - active_downloads: &Arc>>>, + active_downloads: &Arc>>, tx_notify_ui: &UnboundedSender, ) { let stale_peers = { @@ -140,7 +154,7 @@ async fn prune_stale_peers( async fn remove_peer_and_refresh( peer_game_db: &Arc>, downloading_games: &Arc>>, - active_downloads: &Arc>>>, + active_downloads: &Arc>>, tx_notify_ui: &UnboundedSender, peer_id: PeerId, log_label: &str, @@ -176,7 +190,7 @@ async fn remove_peer( async fn handle_active_downloads_without_peers( peer_game_db: &Arc>, downloading_games: &Arc>>, - active_downloads: &Arc>>>, + active_downloads: &Arc>>, tx_notify_ui: &UnboundedSender, ) { let active_ids = { @@ -196,23 +210,14 @@ async fn handle_active_downloads_without_peers( continue; } - let removed_from_tracking = { - let mut guard = downloading_games.write().await; - guard.remove(&id) - }; - - if !removed_from_tracking { + let Some(cancel_token) = active_downloads.write().await.remove(&id) else { continue; - } - - if let Some(handle) = { active_downloads.write().await.remove(&id) } { - handle.abort(); - } + }; + cancel_token.cancel(); events::send( tx_notify_ui, PeerEvent::DownloadGameFilesAllPeersGone { id }, - "DownloadGameFilesAllPeersGone", ); } } @@ -221,3 +226,50 @@ async fn peers_still_have_game(peer_game_db: &Arc>, game_id: let guard = peer_game_db.read().await; !guard.peers_with_game(game_id).is_empty() } + +#[cfg(test)] +mod tests { + use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + }; + + use tokio::sync::RwLock; + use tokio_util::sync::CancellationToken; + + use super::handle_active_downloads_without_peers; + use crate::{PeerEvent, peer_db::PeerGameDB}; + + #[tokio::test] + async fn all_peers_gone_cancels_download_and_emits_only_peers_gone() { + let peer_game_db = Arc::new(RwLock::new(PeerGameDB::new())); + let downloading_games = Arc::new(RwLock::new(HashSet::from(["game".to_string()]))); + let cancel = CancellationToken::new(); + let active_downloads = Arc::new(RwLock::new(HashMap::from([( + "game".to_string(), + cancel.clone(), + )]))); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + + handle_active_downloads_without_peers( + &peer_game_db, + &downloading_games, + &active_downloads, + &tx, + ) + .await; + + assert!(cancel.is_cancelled()); + assert!(!active_downloads.read().await.contains_key("game")); + + let event = rx.recv().await.expect("peers-gone event should be emitted"); + assert!(matches!( + event, + PeerEvent::DownloadGameFilesAllPeersGone { id } if id == "game" + )); + assert!( + rx.try_recv().is_err(), + "peers-gone cancellation must not emit a duplicate failure event" + ); + } +} diff --git a/crates/lanspread-peer/src/services/local_monitor.rs b/crates/lanspread-peer/src/services/local_monitor.rs index 4de8017..fbfe802 100644 --- a/crates/lanspread-peer/src/services/local_monitor.rs +++ b/crates/lanspread-peer/src/services/local_monitor.rs @@ -13,7 +13,10 @@ use crate::{ }; /// Monitors the local game directory for changes. -pub async fn run_local_game_monitor(tx_notify_ui: UnboundedSender, ctx: Ctx) { +pub async fn run_local_game_monitor( + tx_notify_ui: UnboundedSender, + ctx: Ctx, +) -> eyre::Result<()> { log::info!( "Starting local game directory monitor ({LOCAL_GAME_MONITOR_INTERVAL_SECS}s interval)" ); @@ -21,7 +24,10 @@ pub async fn run_local_game_monitor(tx_notify_ui: UnboundedSender, ct let mut interval = tokio::time::interval(Duration::from_secs(LOCAL_GAME_MONITOR_INTERVAL_SECS)); loop { - interval.tick().await; + tokio::select! { + () = ctx.shutdown.cancelled() => return Ok(()), + _ = interval.tick() => {} + } let game_dir = { ctx.game_dir.read().await.clone() }; match scan_local_library(&game_dir).await { diff --git a/crates/lanspread-peer/src/services/server.rs b/crates/lanspread-peer/src/services/server.rs index b0976c5..c4200d5 100644 --- a/crates/lanspread-peer/src/services/server.rs +++ b/crates/lanspread-peer/src/services/server.rs @@ -10,7 +10,10 @@ use crate::{ config::{CERT_PEM, KEY_PEM}, context::PeerCtx, events, - services::{advertise::start_mdns_advertiser, stream::handle_peer_stream}, + services::{ + advertise::{monitor_mdns_events, start_mdns_advertiser}, + stream::handle_peer_stream, + }, }; /// Runs the QUIC server and mDNS advertiser. @@ -32,20 +35,33 @@ pub async fn run_server_component( let server_addr = server.local_addr()?; log::info!("Peer server listening on {server_addr}"); - start_mdns_advertiser(&ctx, server_addr).await?; + let mdns_advertiser = start_mdns_advertiser(&ctx, server_addr).await?; + let mdns_monitor = mdns_advertiser.monitor.clone(); + let mdns_shutdown = ctx.shutdown.clone(); + ctx.task_tracker.spawn(async move { + monitor_mdns_events(mdns_monitor, mdns_shutdown).await; + }); + + loop { + let connection = tokio::select! { + () = ctx.shutdown.cancelled() => return Ok(()), + connection = server.accept() => connection, + }; + + let Some(connection) = connection else { + eyre::bail!("QUIC server accept loop ended unexpectedly"); + }; - while let Some(connection) = server.accept().await { let ctx = ctx.clone(); let tx_notify_ui = tx_notify_ui.clone(); + let task_tracker = ctx.task_tracker.clone(); - tokio::spawn(async move { + task_tracker.spawn(async move { if let Err(err) = handle_peer_connection(connection, ctx, tx_notify_ui).await { log::error!("Peer connection error: {err}"); } }); } - - Ok(()) } async fn handle_peer_connection( @@ -55,25 +71,27 @@ async fn handle_peer_connection( ) -> eyre::Result<()> { let remote_addr = connection.remote_addr()?; log::info!("{remote_addr} peer connected"); - events::send( - &tx_notify_ui, - PeerEvent::PeerConnected(remote_addr), - "PeerConnected", - ); + events::send(&tx_notify_ui, PeerEvent::PeerConnected(remote_addr)); + + loop { + let stream = tokio::select! { + () = ctx.shutdown.cancelled() => break, + stream = connection.accept_bidirectional_stream() => stream, + }; + + let Some(stream) = stream? else { + break; + }; - while let Ok(Some(stream)) = connection.accept_bidirectional_stream().await { let ctx = ctx.clone(); - tokio::spawn(async move { + let task_tracker = ctx.task_tracker.clone(); + task_tracker.spawn(async move { if let Err(err) = handle_peer_stream(stream, ctx, Some(remote_addr)).await { log::error!("{remote_addr:?} peer stream error: {err}"); } }); } - events::send( - &tx_notify_ui, - PeerEvent::PeerDisconnected(remote_addr), - "PeerDisconnected", - ); + events::send(&tx_notify_ui, PeerEvent::PeerDisconnected(remote_addr)); Ok(()) } diff --git a/crates/lanspread-peer/src/services/stream.rs b/crates/lanspread-peer/src/services/stream.rs index b58eee3..7808442 100644 --- a/crates/lanspread-peer/src/services/stream.rs +++ b/crates/lanspread-peer/src/services/stream.rs @@ -38,7 +38,12 @@ pub(super) async fn handle_peer_stream( log::trace!("{remote_addr:?} peer stream opened"); loop { - match framed_rx.next().await { + let next_message = tokio::select! { + () = ctx.shutdown.cancelled() => break, + next_message = framed_rx.next() => next_message, + }; + + match next_message { Some(Ok(data)) => { log::trace!( "{:?} msg: (raw): {}", @@ -191,7 +196,7 @@ async fn handle_library_summary( } if summary.library_digest != previous_digest || previous_count == 0 { - tokio::spawn({ + ctx.task_tracker.spawn({ let peer_id_arc = ctx.peer_id.clone(); let local_library = ctx.local_library.clone(); let peer_game_db = ctx.peer_game_db.clone(); @@ -357,10 +362,6 @@ async fn handle_announce_games(ctx: &PeerCtx, remote_addr: Option, g if let Some(addr) = remote_addr { let aggregated_games = update_peer_from_game_list(&ctx.peer_game_db, addr, &games).await; - events::send( - &ctx.tx_notify_ui, - PeerEvent::ListGames(aggregated_games), - "ListGames", - ); + events::send(&ctx.tx_notify_ui, PeerEvent::ListGames(aggregated_games)); } } diff --git a/crates/lanspread-peer/src/startup.rs b/crates/lanspread-peer/src/startup.rs index 24dfacd..7c92490 100644 --- a/crates/lanspread-peer/src/startup.rs +++ b/crates/lanspread-peer/src/startup.rs @@ -1,16 +1,29 @@ //! Peer runtime task startup and shutdown orchestration. -use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; +use std::{ + any::Any, + future::Future, + net::SocketAddr, + panic::AssertUnwindSafe, + path::PathBuf, + sync::Arc, + time::Duration, +}; +use futures::FutureExt as _; use tokio::sync::{ RwLock, mpsc::{UnboundedReceiver, UnboundedSender}, + watch, }; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; use crate::{ PeerCommand, PeerEvent, + PeerRuntimeComponent, context::Ctx, + events, network::send_goodbye, peer_db::PeerGameDB, run_peer, @@ -22,19 +35,86 @@ use crate::{ }, }; +/// Handle to a running peer runtime. +/// +/// Holds the command sender plus the runtime's shutdown token and a `stopped` +/// signal so callers can request a clean shutdown and wait for goodbye +/// notifications to flush. +pub struct PeerRuntimeHandle { + tx: UnboundedSender, + shutdown: CancellationToken, + stopped: watch::Receiver, +} + +impl PeerRuntimeHandle { + /// Returns a clone of the command channel sender. + #[must_use] + pub fn sender(&self) -> UnboundedSender { + self.tx.clone() + } + + /// Signals the runtime to shut down. Idempotent. + pub fn shutdown(&self) { + self.shutdown.cancel(); + } + + /// Resolves once the runtime task has fully stopped (services drained, + /// goodbye notifications sent). Returns even if the runtime stopped + /// without an explicit shutdown request. + pub async fn wait_stopped(&mut self) { + let _ = self.stopped.wait_for(|stopped| *stopped).await; + } +} + +#[derive(Clone, Copy, Debug)] +pub(crate) enum SupervisionPolicy { + Required, + Restart { backoff: Duration }, + BestEffort, +} + pub(crate) fn spawn_peer_runtime( + tx_control: UnboundedSender, rx_control: UnboundedReceiver, tx_notify_ui: UnboundedSender, peer_game_db: Arc>, peer_id: String, game_dir: PathBuf, -) { +) -> PeerRuntimeHandle { + let shutdown = CancellationToken::new(); + let task_tracker = TaskTracker::new(); + let (tx_stopped, stopped) = watch::channel(false); + + let runtime_shutdown = shutdown.clone(); + let runtime_tracker = task_tracker.clone(); tokio::spawn(async move { - if let Err(err) = run_peer(rx_control, tx_notify_ui, peer_game_db, peer_id, game_dir).await + if let Err(err) = run_peer( + rx_control, + tx_notify_ui, + peer_game_db, + peer_id, + game_dir, + runtime_shutdown.clone(), + runtime_tracker.clone(), + ) + .await { log::error!("Peer system failed: {err}"); } + + runtime_shutdown.cancel(); + runtime_tracker.close(); + runtime_tracker.wait().await; + if tx_stopped.send(true).is_err() { + log::debug!("Peer runtime stopped after handle was dropped"); + } }); + + PeerRuntimeHandle { + tx: tx_control, + shutdown, + stopped, + } } pub(crate) fn spawn_startup_services(ctx: &Ctx, tx_notify_ui: &UnboundedSender) { @@ -60,21 +140,43 @@ fn spawn_quic_server(ctx: &Ctx, tx_notify_ui: &UnboundedSender) { let server_addr = SocketAddr::from(([0, 0, 0, 0], 0)); let peer_ctx = ctx.to_peer_ctx(tx_notify_ui.clone()); let tx_notify_ui = tx_notify_ui.clone(); + let supervisor_tx = tx_notify_ui.clone(); - tokio::spawn(async move { - if let Err(err) = run_server_component(server_addr, peer_ctx, tx_notify_ui).await { - log::error!("Server component error: {err}"); - } - }); + spawn_supervised_service( + &ctx.task_tracker, + &ctx.shutdown, + &supervisor_tx, + PeerRuntimeComponent::QuicServer, + SupervisionPolicy::Required, + move || { + let peer_ctx = peer_ctx.clone(); + let tx_notify_ui = tx_notify_ui.clone(); + async move { run_server_component(server_addr, peer_ctx, tx_notify_ui).await } + }, + ); } fn spawn_peer_discovery_service(ctx: &Ctx, tx_notify_ui: &UnboundedSender) { let ctx = ctx.clone(); let tx_notify_ui = tx_notify_ui.clone(); + let task_tracker = ctx.task_tracker.clone(); + let shutdown = ctx.shutdown.clone(); + let supervisor_tx = tx_notify_ui.clone(); - tokio::spawn(async move { - run_peer_discovery(tx_notify_ui, ctx).await; - }); + spawn_supervised_service( + &task_tracker, + &shutdown, + &supervisor_tx, + PeerRuntimeComponent::Discovery, + SupervisionPolicy::Restart { + backoff: Duration::from_secs(5), + }, + move || { + let ctx = ctx.clone(); + let tx_notify_ui = tx_notify_ui.clone(); + async move { run_peer_discovery(tx_notify_ui, ctx).await } + }, + ); } fn spawn_peer_liveness_service(ctx: &Ctx, tx_notify_ui: &UnboundedSender) { @@ -82,25 +184,59 @@ fn spawn_peer_liveness_service(ctx: &Ctx, tx_notify_ui: &UnboundedSender) { let ctx = ctx.clone(); let tx_notify_ui = tx_notify_ui.clone(); + let task_tracker = ctx.task_tracker.clone(); + let shutdown = ctx.shutdown.clone(); + let supervisor_tx = tx_notify_ui.clone(); - tokio::spawn(async move { - run_local_game_monitor(tx_notify_ui, ctx).await; - }); + spawn_supervised_service( + &task_tracker, + &shutdown, + &supervisor_tx, + PeerRuntimeComponent::LocalMonitor, + SupervisionPolicy::BestEffort, + move || { + let ctx = ctx.clone(); + let tx_notify_ui = tx_notify_ui.clone(); + async move { run_local_game_monitor(tx_notify_ui, ctx).await } + }, + ); } async fn send_goodbye_notification(peer_addr: SocketAddr, peer_id: String) { @@ -110,3 +246,210 @@ async fn send_goodbye_notification(peer_addr: SocketAddr, peer_id: String) { Err(_) => log::warn!("Timed out sending Goodbye to {peer_addr}"), } } + +fn spawn_supervised_service( + task_tracker: &TaskTracker, + shutdown: &CancellationToken, + tx_notify_ui: &UnboundedSender, + component: PeerRuntimeComponent, + policy: SupervisionPolicy, + mut make_service: F, +) where + F: FnMut() -> Fut + Send + 'static, + Fut: Future> + Send + 'static, +{ + let task_tracker = task_tracker.clone(); + let shutdown = shutdown.clone(); + let tx_notify_ui = tx_notify_ui.clone(); + + task_tracker.spawn(async move { + loop { + if shutdown.is_cancelled() { + break; + } + + let result = match AssertUnwindSafe(make_service()).catch_unwind().await { + Ok(result) => result, + Err(payload) => Err(eyre::eyre!( + "component panicked: {}", + panic_payload_to_string(&payload) + )), + }; + if shutdown.is_cancelled() { + break; + } + + match policy { + SupervisionPolicy::Required => { + let error = match result { + Ok(()) => "component exited unexpectedly".to_string(), + Err(err) => err.to_string(), + }; + report_required_service_failure(&tx_notify_ui, component, error, &shutdown); + break; + } + SupervisionPolicy::Restart { backoff } => { + match result { + Ok(()) => log::warn!("{component:?} exited; restarting in {backoff:?}"), + Err(err) => { + log::error!("{component:?} failed: {err}; restarting in {backoff:?}"); + } + } + + tokio::select! { + () = shutdown.cancelled() => break, + () = tokio::time::sleep(backoff) => {} + } + } + SupervisionPolicy::BestEffort => { + match result { + Ok(()) => log::warn!("{component:?} exited"), + Err(err) => log::error!("{component:?} failed: {err}"), + } + break; + } + } + } + }); +} + +fn report_required_service_failure( + tx_notify_ui: &UnboundedSender, + component: PeerRuntimeComponent, + error: String, + shutdown: &CancellationToken, +) { + log::error!("{component:?} failed: {error}"); + events::send(tx_notify_ui, PeerEvent::RuntimeFailed { component, error }); + shutdown.cancel(); +} + +fn panic_payload_to_string(payload: &(dyn Any + Send)) -> String { + if let Some(message) = payload.downcast_ref::<&'static str>() { + return (*message).to_string(); + } + + if let Some(message) = payload.downcast_ref::() { + return message.clone(); + } + + "unknown panic payload".to_string() +} + +#[cfg(test)] +mod tests { + use std::{ + sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, + }; + + use tokio_util::{sync::CancellationToken, task::TaskTracker}; + + use super::{SupervisionPolicy, spawn_supervised_service}; + use crate::{PeerRuntimeComponent, startup::PeerRuntimeHandle}; + + #[tokio::test] + async fn required_service_failure_cancels_runtime_and_emits_event() { + let tracker = TaskTracker::new(); + let shutdown = CancellationToken::new(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + + spawn_supervised_service( + &tracker, + &shutdown, + &tx, + PeerRuntimeComponent::QuicServer, + SupervisionPolicy::Required, + || async { Err(eyre::eyre!("bind failed")) }, + ); + + let event = tokio::time::timeout(Duration::from_secs(1), rx.recv()) + .await + .expect("runtime failure event should arrive") + .expect("event channel should stay open"); + + assert!(shutdown.is_cancelled()); + assert!(matches!( + event, + crate::PeerEvent::RuntimeFailed { + component: PeerRuntimeComponent::QuicServer, + .. + } + )); + + tracker.close(); + tokio::time::timeout(Duration::from_secs(1), tracker.wait()) + .await + .expect("supervisor task should stop"); + } + + #[tokio::test] + async fn restart_service_restarts_until_shutdown() { + let tracker = TaskTracker::new(); + let shutdown = CancellationToken::new(); + let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); + let attempts = Arc::new(AtomicUsize::new(0)); + + spawn_supervised_service( + &tracker, + &shutdown, + &tx, + PeerRuntimeComponent::Discovery, + SupervisionPolicy::Restart { + backoff: Duration::from_millis(10), + }, + { + let attempts = attempts.clone(); + move || { + let attempts = attempts.clone(); + async move { + attempts.fetch_add(1, Ordering::SeqCst); + Err(eyre::eyre!("discovery worker stopped")) + } + } + }, + ); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + if attempts.load(Ordering::SeqCst) >= 2 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("restartable service should run more than once"); + + shutdown.cancel(); + tracker.close(); + tokio::time::timeout(Duration::from_secs(1), tracker.wait()) + .await + .expect("restart supervisor should stop after shutdown"); + } + + #[tokio::test] + async fn runtime_handle_can_shutdown_and_await_stopped() { + let (tx, _rx) = tokio::sync::mpsc::unbounded_channel(); + let shutdown = CancellationToken::new(); + let (tx_stopped, stopped) = tokio::sync::watch::channel(false); + let mut handle = PeerRuntimeHandle { + tx, + shutdown: shutdown.clone(), + stopped, + }; + + tokio::spawn(async move { + shutdown.cancelled().await; + let _ = tx_stopped.send(true); + }); + + handle.shutdown(); + tokio::time::timeout(Duration::from_secs(1), handle.wait_stopped()) + .await + .expect("runtime handle should observe stopped"); + } +} diff --git a/crates/lanspread-tauri-deno-ts/src-tauri/src/lib.rs b/crates/lanspread-tauri-deno-ts/src-tauri/src/lib.rs index c898ce2..5954c2c 100644 --- a/crates/lanspread-tauri-deno-ts/src-tauri/src/lib.rs +++ b/crates/lanspread-tauri-deno-ts/src-tauri/src/lib.rs @@ -10,7 +10,7 @@ use std::{ use eyre::bail; use lanspread_compat::eti::get_games; use lanspread_db::db::{Game, GameDB, GameFileDescription}; -use lanspread_peer::{PeerCommand, PeerEvent, PeerGameDB, start_peer}; +use lanspread_peer::{PeerCommand, PeerEvent, PeerGameDB, PeerRuntimeHandle, start_peer}; use tauri::{AppHandle, Emitter as _, Manager}; use tauri_plugin_shell::{ShellExt, process::Command}; use tokio::sync::{ @@ -24,6 +24,7 @@ use tokio::sync::{ #[derive(Default)] struct LanSpreadState { peer_ctrl: Arc>>>, + peer_runtime: Arc>>, games: Arc>, games_in_download: Arc>>, games_folder: Arc>, @@ -832,9 +833,11 @@ async fn ensure_peer_started(app_handle: &AppHandle, games_folder: &Path) { tx_peer_event, state.peer_game_db.clone(), ) { - Ok(new_peer_ctrl) => { - *peer_ctrl = Some(new_peer_ctrl.clone()); - if let Err(e) = new_peer_ctrl.send(PeerCommand::ListGames) { + Ok(handle) => { + let sender = handle.sender(); + *peer_ctrl = Some(sender.clone()); + *state.peer_runtime.write().await = Some(handle); + if let Err(e) = sender.send(PeerCommand::ListGames) { log::error!("Failed to send initial PeerCommand::ListGames: {e}"); } log::info!("Peer system initialized successfully with games directory"); @@ -865,6 +868,7 @@ fn spawn_peer_event_loop(app_handle: AppHandle, mut rx_peer_event: UnboundedRece }); } +#[allow(clippy::too_many_lines)] async fn handle_peer_event(app_handle: &AppHandle, event: PeerEvent) { match event { PeerEvent::ListGames(games) => { @@ -956,6 +960,16 @@ async fn handle_peer_event(app_handle: &AppHandle, event: PeerEvent) { log::error!("Failed to emit peer-count-updated event: {e}"); } } + PeerEvent::RuntimeFailed { component, error } => { + let component_name: &'static str = (&component).into(); + log::error!("Peer runtime component {component_name} failed: {error}"); + if let Err(e) = app_handle.emit( + "peer-runtime-failed", + Some((component_name.to_string(), error)), + ) { + log::error!("Failed to emit peer-runtime-failed event: {e}"); + } + } } } @@ -1066,6 +1080,29 @@ pub fn run() { spawn_peer_event_loop(app.handle().clone(), rx_peer_event); Ok(()) }) - .run(tauri::generate_context!()) - .expect("error while running tauri application"); + .build(tauri::generate_context!()) + .expect("error while building tauri application") + .run(|app_handle, event| { + if matches!(event, tauri::RunEvent::Exit) { + shutdown_peer_runtime(app_handle); + } + }); +} + +fn shutdown_peer_runtime(app_handle: &AppHandle) { + let state = app_handle.state::(); + let peer_runtime = state.peer_runtime.clone(); + + tauri::async_runtime::block_on(async move { + let Some(mut handle) = peer_runtime.write().await.take() else { + return; + }; + handle.shutdown(); + if tokio::time::timeout(std::time::Duration::from_secs(2), handle.wait_stopped()) + .await + .is_err() + { + log::warn!("Peer runtime did not stop within 2s of shutdown request"); + } + }); }