diff --git a/README.md b/README.md index 200cc3e..552627a 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ Platform-neutral remote client relay session: - Ethernet frame send/receive helpers over QUIC DATAGRAM - client tunnel statistics for frame/datagram rx/tx and drops - reliable client stats snapshot sends for relay diagnostics +- best-effort graceful disconnect messages before QUIC close ### `lanparty-client-route` @@ -140,7 +141,8 @@ frame logs include direction, peer id when present, MACs, ethertype/length, frame length, action, and drop reason. The gateway also tracks frame/datagram counters and periodically sends stats snapshots to the relay. Relay lifecycle events seed and retire remote-client MACs for CAM refresh even before that -client sends traffic. +client sends traffic. On shutdown, the gateway sends a best-effort disconnect +control message before closing QUIC so the relay can report the intended reason. ## Windows Client diff --git a/crates/lanparty-client-core/src/lib.rs b/crates/lanparty-client-core/src/lib.rs index 57e9c55..da8fd5d 100644 --- a/crates/lanparty-client-core/src/lib.rs +++ b/crates/lanparty-client-core/src/lib.rs @@ -7,6 +7,7 @@ use std::{ fs, + future::poll_fn, io::ErrorKind, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::{Path, PathBuf}, @@ -14,13 +15,15 @@ use std::{ Arc, atomic::{AtomicU64, Ordering}, }, + time::Duration, }; use anyhow::{Context, Result, bail}; use bytes::Bytes; use lanparty_ctrl::{ - CONTROL_LENGTH_PREFIX_LEN, ControlMessage, EndpointHello, MAX_CONTROL_MESSAGE_LEN, RELAY_ALPN, - RoomCode, ServerWelcome, decode_control_frame, encode_control_message, + CONTROL_LENGTH_PREFIX_LEN, ControlMessage, DisconnectReason, EndpointHello, + MAX_CONTROL_MESSAGE_LEN, RELAY_ALPN, RoomCode, ServerWelcome, decode_control_frame, + encode_control_message, }; use lanparty_obs::{QuicDiagnostics, TunnelStats}; use lanparty_proto::{EthernetFrame, FrameType, MacAddr, decode_datagram, encode_datagram}; @@ -28,6 +31,7 @@ use quinn::{ClientConfig, Endpoint, crypto::rustls::QuicClientConfig}; use rustls::pki_types::CertificateDer; const MAX_CONTROL_FRAME_LEN: usize = CONTROL_LENGTH_PREFIX_LEN + MAX_CONTROL_MESSAGE_LEN; +const DISCONNECT_DRAIN_TIMEOUT: Duration = Duration::from_millis(250); #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ClientIdentity { @@ -288,6 +292,10 @@ impl ClientSession { } pub async fn shutdown(self, reason: &str) { + if send_disconnect(&self.connection, reason).await.is_ok() { + drain_disconnect().await; + } + self.connection.close(0_u32.into(), reason.as_bytes()); self.endpoint.wait_idle().await; } @@ -537,6 +545,25 @@ async fn send_control_event(connection: &quinn::Connection, message: ControlMess Ok(()) } +async fn send_disconnect(connection: &quinn::Connection, message: &str) -> Result<()> { + send_control_event( + connection, + ControlMessage::Disconnect { + reason: DisconnectReason::Normal, + message: message.to_owned(), + }, + ) + .await +} + +async fn drain_disconnect() { + let Some(runtime) = quinn::default_runtime() else { + return; + }; + let mut timer = runtime.new_timer(runtime.now() + DISCONNECT_DRAIN_TIMEOUT); + poll_fn(|cx| timer.as_mut().poll(cx)).await; +} + #[cfg(test)] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -711,6 +738,18 @@ mod tests { assert_eq!(stats, TunnelStats::new(1, 1, 1, 1, 1, 1)); stats_received_tx.send(()).unwrap(); + let mut disconnect_recv = connection.accept_uni().await.unwrap(); + let disconnect_frame = disconnect_recv + .read_to_end(MAX_CONTROL_FRAME_LEN) + .await + .unwrap(); + let disconnect_message = decode_control_frame(&disconnect_frame).unwrap(); + let ControlMessage::Disconnect { reason, message } = disconnect_message else { + panic!("expected client disconnect event"); + }; + assert_eq!(reason, DisconnectReason::Normal); + assert_eq!(message, "test complete"); + connection.closed().await; endpoint.close(0_u32.into(), b"test complete"); endpoint.wait_idle().await; diff --git a/crates/lanparty-gateway/src/lib.rs b/crates/lanparty-gateway/src/lib.rs index fcf4cbd..d074c81 100644 --- a/crates/lanparty-gateway/src/lib.rs +++ b/crates/lanparty-gateway/src/lib.rs @@ -7,7 +7,7 @@ mod packet; #[cfg(target_os = "linux")] -use std::{collections::BTreeMap, time::Duration}; +use std::collections::BTreeMap; use std::{ fs, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -16,14 +16,16 @@ use std::{ Arc, atomic::{AtomicU64, Ordering}, }, + time::Duration, }; use anyhow::{Context, Result, bail}; use bytes::Bytes; use clap::Parser; use lanparty_ctrl::{ - CONTROL_LENGTH_PREFIX_LEN, ControlMessage, EndpointHello, MAX_CONTROL_MESSAGE_LEN, PeerInfo, - RELAY_ALPN, Role, RoomCode, ServerWelcome, decode_control_frame, encode_control_message, + CONTROL_LENGTH_PREFIX_LEN, ControlMessage, DisconnectReason, EndpointHello, + MAX_CONTROL_MESSAGE_LEN, PeerInfo, RELAY_ALPN, Role, RoomCode, ServerWelcome, + decode_control_frame, encode_control_message, }; use lanparty_obs::TunnelStats; #[cfg(target_os = "linux")] @@ -41,6 +43,7 @@ use tokio::io::unix::AsyncFd; pub use packet::PacketSocket; const MAX_CONTROL_FRAME_LEN: usize = CONTROL_LENGTH_PREFIX_LEN + MAX_CONTROL_MESSAGE_LEN; +const DISCONNECT_DRAIN_TIMEOUT: Duration = Duration::from_millis(250); #[cfg(target_os = "linux")] const CAM_REFRESH_INTERVAL: Duration = Duration::from_secs(60); #[cfg(target_os = "linux")] @@ -267,6 +270,16 @@ impl GatewayConnection { tokio::select! { shutdown = tokio::signal::ctrl_c() => { shutdown.context("failed to wait for Ctrl-C")?; + if let Err(error) = + send_gateway_disconnect(&connection, "gateway shutting down").await + { + eprintln!("failed to send gateway disconnect to relay: {error:#}"); + } + let _ = tokio::time::timeout( + DISCONNECT_DRAIN_TIMEOUT, + connection.closed(), + ) + .await; connection.close(0_u32.into(), b"gateway shutting down"); endpoint.wait_idle().await; return Ok(()); @@ -325,6 +338,17 @@ impl GatewayConnection { } pub async fn shutdown(self, reason: &str) { + if send_gateway_disconnect(&self.connection, reason) + .await + .is_ok() + && tokio::time::timeout(DISCONNECT_DRAIN_TIMEOUT, self.connection.closed()) + .await + .is_ok() + { + self.endpoint.wait_idle().await; + return; + } + self.connection.close(0_u32.into(), reason.as_bytes()); self.endpoint.wait_idle().await; } @@ -391,15 +415,35 @@ async fn recv_gateway_ethernet( } async fn send_gateway_stats(connection: &quinn::Connection, stats: TunnelStats) -> Result<()> { + send_gateway_control_event(connection, ControlMessage::Stats(stats), "gateway stats").await +} + +async fn send_gateway_disconnect(connection: &quinn::Connection, message: &str) -> Result<()> { + send_gateway_control_event( + connection, + ControlMessage::Disconnect { + reason: DisconnectReason::Normal, + message: message.to_owned(), + }, + "gateway disconnect", + ) + .await +} + +async fn send_gateway_control_event( + connection: &quinn::Connection, + message: ControlMessage, + context: &str, +) -> Result<()> { let mut send = connection .open_uni() .await - .context("failed to open gateway stats stream")?; - let frame = encode_control_message(&ControlMessage::Stats(stats)) - .context("failed to encode gateway stats")?; + .with_context(|| format!("failed to open {context} stream"))?; + let frame = + encode_control_message(&message).with_context(|| format!("failed to encode {context}"))?; send.write_all(&frame) .await - .context("failed to write gateway stats")?; + .with_context(|| format!("failed to write {context}"))?; send.finish()?; Ok(()) @@ -732,7 +776,6 @@ mod tests { use std::time::Duration; use bytes::Bytes; - use lanparty_ctrl::DisconnectReason; use quinn::{ServerConfig, TransportConfig, crypto::rustls::QuicServerConfig}; use rustls::pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}; @@ -850,6 +893,18 @@ mod tests { assert_eq!(stats, TunnelStats::new(1, 1, 1, 1, 1, 1)); stats_received_tx.send(()).unwrap(); + let mut disconnect_recv = connection.accept_uni().await.unwrap(); + let disconnect_frame = disconnect_recv + .read_to_end(MAX_CONTROL_FRAME_LEN) + .await + .unwrap(); + let disconnect_message = decode_control_frame(&disconnect_frame).unwrap(); + let ControlMessage::Disconnect { reason, message } = disconnect_message else { + panic!("expected gateway disconnect event"); + }; + assert_eq!(reason, DisconnectReason::Normal); + assert_eq!(message, "test complete"); + connection.closed().await; endpoint.close(0_u32.into(), b"test complete"); endpoint.wait_idle().await;