use std::{fs, net::SocketAddr, path::Path, sync::Arc}; use anyhow::{Context, Result, anyhow}; use bytes::Bytes; use lanparty_ctrl::{ CONTROL_LENGTH_PREFIX_LEN, ControlCodecError, ControlMessage, EndpointHello, MAX_CONTROL_MESSAGE_LEN, PeerInfo, RELAY_ALPN, Reject, RejectReason, Role, RoomCode, ServerWelcome, decode_control_frame, encode_control_message, }; use lanparty_obs::{FrameDirection, FrameLog}; use lanparty_proto::{EthernetFrame, FrameType, decode_datagram, encode_datagram}; use quinn::crypto::rustls::QuicServerConfig; use quinn::{Endpoint, Incoming, SendStream, ServerConfig, TransportConfig}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use std::collections::HashMap; use tokio::sync::Mutex; use crate::{ForwardingDecision, RelayConfig, RoomRegistry}; const DATAGRAM_BUFFER_BYTES: usize = 4 * 1024 * 1024; const MAX_CONTROL_FRAME_LEN: usize = CONTROL_LENGTH_PREFIX_LEN + MAX_CONTROL_MESSAGE_LEN; #[derive(Debug)] pub struct RelayServer { endpoint: Endpoint, rooms: Arc>, sessions: Arc>>, } #[derive(Debug, Clone, PartialEq, Eq)] struct AcceptedPeer { room: RoomCode, welcome: ServerWelcome, peer: PeerInfo, remote_addr: SocketAddr, max_datagram_size: usize, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct PeerKey { room: RoomCode, peer_id: u32, } impl PeerKey { fn new(room: RoomCode, peer_id: u32) -> Self { Self { room, peer_id } } fn from_accepted(accepted: &AcceptedPeer) -> Self { Self::new(accepted.room.clone(), accepted.peer.peer_id()) } } #[derive(Debug, Clone)] struct PeerSession { connection: quinn::Connection, max_datagram_size: usize, } impl RelayServer { pub fn bind(config: &RelayConfig) -> Result { let (server_config, certificate) = development_server_config_with_certificate()?; if let Some(path) = config.dev_cert_der_out() { write_development_certificate(path, &certificate)?; } let endpoint = Endpoint::server(server_config, config.listen().socket_addr()) .context("failed to bind QUIC relay endpoint")?; Ok(Self::from_endpoint(endpoint, config.max_clients_per_room())) } fn from_endpoint(endpoint: Endpoint, max_clients_per_room: usize) -> Self { Self { endpoint, rooms: Arc::new(Mutex::new(RoomRegistry::new(max_clients_per_room))), sessions: Arc::new(Mutex::new(HashMap::new())), } } pub fn local_addr(&self) -> Result { self.endpoint .local_addr() .context("failed to read relay local address") } pub async fn run_until_shutdown(self) -> Result<()> { let endpoint = self.endpoint.clone(); let rooms = Arc::clone(&self.rooms); let sessions = Arc::clone(&self.sessions); loop { tokio::select! { shutdown = tokio::signal::ctrl_c() => { shutdown.context("failed to wait for Ctrl-C")?; self.shutdown("relay shutting down").await; return Ok(()); } incoming = endpoint.accept() => { let Some(incoming) = incoming else { return Ok(()); }; let rooms = Arc::clone(&rooms); let sessions = Arc::clone(&sessions); tokio::spawn(async move { if let Err(error) = handle_incoming_connection(rooms, sessions, incoming).await { eprintln!("relay connection failed: {error:#}"); } }); } } } } #[cfg(test)] async fn accept_once(&self) -> Result> { let Some(incoming) = self.endpoint.accept().await else { return Ok(None); }; handle_incoming_connection( Arc::clone(&self.rooms), Arc::clone(&self.sessions), incoming, ) .await } #[cfg(test)] async fn accept_many_for_test( &self, count: usize, ) -> Result>>>> { let mut handles = Vec::with_capacity(count); for _ in 0..count { let Some(incoming) = self.endpoint.accept().await else { return Err(anyhow!("relay endpoint stopped while accepting test peer")); }; let rooms = Arc::clone(&self.rooms); let sessions = Arc::clone(&self.sessions); handles.push(tokio::spawn(async move { handle_incoming_connection(rooms, sessions, incoming).await })); } Ok(handles) } pub async fn shutdown(self, reason: &str) { self.endpoint.close(0_u32.into(), reason.as_bytes()); self.endpoint.wait_idle().await; } } async fn handle_incoming_connection( rooms: Arc>, sessions: Arc>>, incoming: Incoming, ) -> Result> { let remote_addr = incoming.remote_address(); let connection = incoming .await .with_context(|| format!("failed to establish QUIC connection from {remote_addr}"))?; let Some(accepted) = accept_control_handshake(&rooms, &sessions, &connection).await? else { return Ok(None); }; println!( "accepted {:?} peer {} in room {} from {} with peer datagram budget {} and TAP MTU {}", accepted.peer.role(), accepted.peer.peer_id(), accepted.room, accepted.remote_addr, accepted.max_datagram_size, accepted.welcome.effective_tap_mtu() ); let close_reason = run_peer_datagrams(&rooms, &sessions, &accepted, &connection).await; leave_peer(&rooms, &sessions, &accepted.room, accepted.peer.peer_id()).await?; println!( "peer {} left room {}: {}", accepted.peer.peer_id(), accepted.room, close_reason ); Ok(Some(accepted)) } async fn accept_control_handshake( rooms: &Arc>, sessions: &Arc>>, connection: &quinn::Connection, ) -> Result> { let (mut send, mut recv) = connection .accept_bi() .await .context("failed to accept relay control stream")?; let frame = recv .read_to_end(MAX_CONTROL_FRAME_LEN) .await .context("failed to read relay control hello")?; let (accepted, response) = build_handshake_response(rooms, connection, frame.as_slice()).await; if let Some(accepted) = &accepted { register_peer(sessions, accepted, connection.clone()).await; } if let Err(error) = send_control_message(&mut send, &response).await { if let Some(accepted) = &accepted { leave_peer(rooms, sessions, &accepted.room, accepted.peer.peer_id()).await?; } return Err(error); } Ok(accepted) } async fn register_peer( sessions: &Arc>>, accepted: &AcceptedPeer, connection: quinn::Connection, ) { sessions.lock().await.insert( PeerKey::from_accepted(accepted), PeerSession { connection, max_datagram_size: accepted.max_datagram_size, }, ); } async fn run_peer_datagrams( rooms: &Arc>, sessions: &Arc>>, accepted: &AcceptedPeer, connection: &quinn::Connection, ) -> quinn::ConnectionError { loop { match connection.read_datagram().await { Ok(datagram) => { if let Err(error) = forward_peer_datagram(rooms, sessions, accepted, datagram).await { eprintln!( "failed to forward datagram from peer {} in room {}: {error:#}", accepted.peer.peer_id(), accepted.room ); } } Err(error) => return error, } } } async fn forward_peer_datagram( rooms: &Arc>, sessions: &Arc>>, accepted: &AcceptedPeer, datagram: Bytes, ) -> Result<()> { let Ok(packet) = decode_datagram(&datagram) else { return Ok(()); }; let header = packet.header(); if header.frame_type() != FrameType::Ethernet || header.room_id() != accepted.welcome.room_id() || header.peer_id() != accepted.peer.peer_id() { return Ok(()); } let decision = rooms.lock().await.forward_ethernet( &accepted.room, accepted.peer.peer_id(), packet.payload(), )?; println!( "{}", relay_frame_log_line( &accepted.room, accepted.peer.peer_id(), packet.payload(), &decision ) ); let target_peer_ids = decision.targets().to_vec(); if target_peer_ids.is_empty() { return Ok(()); } let outgoing = encode_datagram( FrameType::Ethernet, accepted.welcome.room_id(), accepted.peer.peer_id(), header.flags(), packet.payload(), )?; let target_sessions = collect_target_sessions(sessions, &accepted.room, &target_peer_ids).await; for target in target_sessions { if outgoing.len() > target.max_datagram_size { continue; } if let Err(error) = target .connection .send_datagram(Bytes::from(outgoing.clone())) { eprintln!( "failed to send datagram from peer {} in room {}: {error}", accepted.peer.peer_id(), accepted.room ); } } Ok(()) } fn relay_frame_log_line( room: &RoomCode, ingress_peer_id: u32, frame_bytes: &[u8], decision: &ForwardingDecision, ) -> String { let log = match EthernetFrame::parse(frame_bytes) { Ok(frame) => FrameLog::from_ethernet( FrameDirection::RelayIngress, Some(ingress_peer_id), decision.action(), decision.drop_reason(), frame, ), Err(_) => FrameLog::malformed( FrameDirection::RelayIngress, Some(ingress_peer_id), frame_bytes.len(), ), }; let source_mac = log .source_mac() .map(|mac| mac.to_string()) .unwrap_or_else(|| "-".to_owned()); let destination_mac = log .destination_mac() .map(|mac| mac.to_string()) .unwrap_or_else(|| "-".to_owned()); let ethertype_or_len = log .ethertype_or_len() .map(|value| format!("0x{value:04x}")) .unwrap_or_else(|| "-".to_owned()); let drop_reason = log .drop_reason() .map(|reason| format!("{reason:?}")) .unwrap_or_else(|| "-".to_owned()); format!( "relay frame room={} direction={:?} peer_id={} src={} dst={} ethertype_or_len={} len={} action={:?} drop_reason={} targets={}", room, log.direction(), log.peer_id().unwrap_or(ingress_peer_id), source_mac, destination_mac, ethertype_or_len, log.frame_len(), log.action(), drop_reason, decision.targets().len() ) } async fn collect_target_sessions( sessions: &Arc>>, room: &RoomCode, target_peer_ids: &[u32], ) -> Vec { let sessions = sessions.lock().await; target_peer_ids .iter() .filter_map(|peer_id| sessions.get(&PeerKey::new(room.clone(), *peer_id)).cloned()) .collect() } async fn build_handshake_response( rooms: &Arc>, connection: &quinn::Connection, frame: &[u8], ) -> (Option, ControlMessage) { let Some(connection_max_datagram_size) = connection.max_datagram_size() else { return reject(Reject::new( RejectReason::MtuTooSmall, "QUIC DATAGRAM support was not negotiated", )); }; let message = match decode_control_frame(frame) { Ok(message) => message, Err(error) => return reject(reject_codec_error(error)), }; let ControlMessage::Hello(hello) = message else { return reject(Reject::new( RejectReason::MalformedHello, "first relay control frame must be hello", )); }; let room = hello.room().clone(); let hello = match limit_hello_to_connection(hello, connection_max_datagram_size) { Ok(hello) => hello, Err(reject) => return (None, ControlMessage::Reject(reject)), }; let join = rooms.lock().await.join(hello); match join { Ok(join) => { let accepted = AcceptedPeer { room, welcome: join.welcome().clone(), peer: join.peer().clone(), remote_addr: connection.remote_address(), max_datagram_size: connection_max_datagram_size, }; ( Some(accepted), ControlMessage::Welcome(join.welcome().clone()), ) } Err(reject) => (None, ControlMessage::Reject(reject)), } } fn limit_hello_to_connection( hello: EndpointHello, connection_max_datagram_size: usize, ) -> Result { let max_datagram_size = usize::from(hello.max_datagram_size()) .min(connection_max_datagram_size) .min(usize::from(u16::MAX)) as u16; match hello.role() { Role::Client => EndpointHello::client( hello.room().clone(), hello .announced_mac() .expect("validated client hello has MAC address"), max_datagram_size, ), Role::Gateway => EndpointHello::gateway(hello.room().clone(), max_datagram_size), } .map_err(crate::reject_control_error) } fn reject(reject: Reject) -> (Option, ControlMessage) { (None, ControlMessage::Reject(reject)) } fn reject_codec_error(error: ControlCodecError) -> Reject { match error { ControlCodecError::InvalidMessage(error) => crate::reject_control_error(error), ControlCodecError::FrameTooShort { .. } | ControlCodecError::MessageTooLarge { .. } | ControlCodecError::IncompletePayload { .. } | ControlCodecError::TrailingBytes { .. } | ControlCodecError::Json(_) => { Reject::new(RejectReason::MalformedHello, error.to_string()) } } } async fn send_control_message(send: &mut SendStream, message: &ControlMessage) -> Result<()> { let response = encode_control_message(message).context("failed to encode control response")?; send.write_all(&response) .await .context("failed to write control response")?; send.finish() .map_err(|error| anyhow!("failed to finish control response stream: {error}"))?; Ok(()) } async fn leave_peer( rooms: &Arc>, sessions: &Arc>>, room: &RoomCode, peer_id: u32, ) -> Result<()> { sessions .lock() .await .remove(&PeerKey::new(room.clone(), peer_id)); rooms .lock() .await .leave(room, peer_id) .with_context(|| format!("failed to remove peer {peer_id} from room {room}"))?; Ok(()) } fn write_development_certificate(path: &Path, certificate: &CertificateDer<'_>) -> Result<()> { if let Some(parent) = path .parent() .filter(|parent| !parent.as_os_str().is_empty()) { fs::create_dir_all(parent).with_context(|| { format!( "failed to create certificate directory {}", parent.display() ) })?; } fs::write(path, certificate.as_ref()) .with_context(|| format!("failed to write development certificate {}", path.display()))?; Ok(()) } fn development_server_config_with_certificate() -> Result<(ServerConfig, CertificateDer<'static>)> { let certified_key = rcgen::generate_simple_self_signed(vec!["lanparty-relay.local".into()]) .context("failed to generate development relay certificate")?; let certificate = certified_key.cert.der().clone(); let cert_chain = vec![certificate.clone()]; let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from( certified_key.signing_key.serialize_der(), )); let mut tls_config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(cert_chain, private_key) .context("failed to build relay TLS config")?; tls_config.alpn_protocols = vec![RELAY_ALPN.to_vec()]; let mut server_config = ServerConfig::with_crypto(Arc::new( QuicServerConfig::try_from(tls_config).context("failed to build QUIC TLS config")?, )); let mut transport = TransportConfig::default(); transport.datagram_receive_buffer_size(Some(DATAGRAM_BUFFER_BYTES)); transport.datagram_send_buffer_size(DATAGRAM_BUFFER_BYTES); server_config.transport_config(Arc::new(transport)); Ok((server_config, certificate)) } #[cfg(test)] mod tests { use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, time::{Duration, SystemTime, UNIX_EPOCH}, }; use bytes::Bytes; use lanparty_ctrl::{RoomCode, decode_control_frame, encode_control_message}; use lanparty_proto::{FrameType, MacAddr, decode_datagram, encode_datagram}; use quinn::{ClientConfig, crypto::rustls::QuicClientConfig}; use crate::{DEFAULT_MAX_CLIENTS_PER_ROOM, ListenEndpoint}; use super::*; #[tokio::test] async fn binds_quic_endpoint_on_configured_address() { let config = RelayConfig::new( ListenEndpoint::new(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)), DEFAULT_MAX_CLIENTS_PER_ROOM, ) .unwrap(); let server = RelayServer::bind(&config).unwrap(); let local_addr = server.local_addr().unwrap(); assert_eq!(local_addr.ip(), IpAddr::V4(Ipv4Addr::LOCALHOST)); assert_ne!(local_addr.port(), 0); server.shutdown("test complete").await; } #[tokio::test] async fn accepts_client_hello_and_replies_welcome() { let (server, certificate) = bind_test_server(DEFAULT_MAX_CLIENTS_PER_ROOM); let rooms = Arc::clone(&server.rooms); let server_addr = server.local_addr().unwrap(); let server_task = tokio::spawn(async move { let accepted = server .accept_once() .await .unwrap() .expect("connection should be accepted"); server.shutdown("test complete").await; accepted }); let client = client_endpoint(certificate).unwrap(); let connection = client .connect(server_addr, "lanparty-relay.local") .unwrap() .await .unwrap(); let hello = EndpointHello::client( RoomCode::new("TESTROOM").unwrap(), MacAddr::new([0x02, 0, 0, 0, 0, 1]), 1400, ) .unwrap(); let response = request_control_message(&connection, ControlMessage::Hello(hello)) .await .unwrap(); let ControlMessage::Welcome(welcome) = response else { panic!("expected welcome response"); }; assert_eq!(welcome.room_id(), 1); assert_eq!(welcome.peer_id(), 1); assert!(welcome.effective_tap_mtu() <= 1400); connection.close(0_u32.into(), b"test complete"); client.wait_idle().await; let accepted = tokio::time::timeout(Duration::from_secs(5), server_task) .await .unwrap() .unwrap(); assert_eq!(accepted.room.as_str(), "TESTROOM"); assert_eq!(accepted.peer.peer_id(), 1); assert_eq!(accepted.welcome, welcome); assert_eq!(rooms.lock().await.room_count(), 0); } #[tokio::test] async fn writes_development_certificate_when_configured() { let cert_path = unique_temp_cert_path(); let config = RelayConfig::with_dev_cert_der_out( ListenEndpoint::new(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)), DEFAULT_MAX_CLIENTS_PER_ROOM, Some(cert_path.clone()), ) .unwrap(); let server = RelayServer::bind(&config).unwrap(); let cert = std::fs::read(&cert_path).unwrap(); assert!(!cert.is_empty()); server.shutdown("test complete").await; std::fs::remove_file(cert_path).unwrap(); } #[test] fn formats_relay_forwarding_log_line() { let decision = ForwardingDecision::forwarded(vec![2, 3]); let line = relay_frame_log_line( &RoomCode::new("TESTROOM").unwrap(), 1, ðernet_frame(client_mac(2), client_mac(1)), &decision, ); assert!(line.contains("room=TESTROOM")); assert!(line.contains("direction=RelayIngress")); assert!(line.contains("peer_id=1")); assert!(line.contains("src=02:00:00:00:00:01")); assert!(line.contains("dst=02:00:00:00:00:02")); assert!(line.contains("ethertype_or_len=0x0800")); assert!(line.contains("action=Forwarded")); assert!(line.contains("drop_reason=-")); assert!(line.contains("targets=2")); } #[tokio::test] async fn forwards_ethernet_datagrams_between_joined_peers() { let (server, certificate) = bind_test_server(DEFAULT_MAX_CLIENTS_PER_ROOM); let rooms = Arc::clone(&server.rooms); let sessions = Arc::clone(&server.sessions); let server_addr = server.local_addr().unwrap(); let server_task = tokio::spawn(async move { let handles = server.accept_many_for_test(2).await.unwrap(); let mut accepted = Vec::with_capacity(handles.len()); for handle in handles { accepted.push(handle.await.unwrap().unwrap().unwrap()); } server.shutdown("test complete").await; accepted }); let first_endpoint = client_endpoint(certificate.clone()).unwrap(); let first_connection = first_endpoint .connect(server_addr, "lanparty-relay.local") .unwrap() .await .unwrap(); let second_endpoint = client_endpoint(certificate).unwrap(); let second_connection = second_endpoint .connect(server_addr, "lanparty-relay.local") .unwrap() .await .unwrap(); let first_mac = client_mac(1); let second_mac = client_mac(2); let first_welcome = welcome_for_client(&first_connection, first_mac).await; let second_welcome = welcome_for_client(&second_connection, second_mac).await; let ethernet = ethernet_frame(second_mac, first_mac); let datagram = encode_datagram( FrameType::Ethernet, first_welcome.room_id(), first_welcome.peer_id(), 0, ðernet, ) .unwrap(); first_connection .send_datagram(Bytes::from(datagram)) .unwrap(); let received = tokio::time::timeout(Duration::from_secs(5), second_connection.read_datagram()) .await .unwrap() .unwrap(); let packet = decode_datagram(&received).unwrap(); let header = packet.header(); assert_eq!(header.frame_type(), FrameType::Ethernet); assert_eq!(header.room_id(), first_welcome.room_id()); assert_eq!(header.peer_id(), first_welcome.peer_id()); assert_eq!(packet.payload(), ethernet.as_slice()); assert_eq!(first_welcome.room_id(), second_welcome.room_id()); first_connection.close(0_u32.into(), b"test complete"); second_connection.close(0_u32.into(), b"test complete"); first_endpoint.wait_idle().await; second_endpoint.wait_idle().await; let accepted = tokio::time::timeout(Duration::from_secs(5), server_task) .await .unwrap() .unwrap(); assert_eq!(accepted.len(), 2); assert_eq!(rooms.lock().await.room_count(), 0); assert!(sessions.lock().await.is_empty()); } fn bind_test_server(max_clients_per_room: usize) -> (RelayServer, CertificateDer<'static>) { let (server_config, certificate) = development_server_config_with_certificate().unwrap(); let endpoint = Endpoint::server( server_config, SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), ) .unwrap(); ( RelayServer::from_endpoint(endpoint, max_clients_per_room), certificate, ) } fn client_endpoint(certificate: CertificateDer<'static>) -> Result { let mut roots = rustls::RootCertStore::empty(); roots .add(certificate) .context("failed to trust relay test certificate")?; let mut client_crypto = rustls::ClientConfig::builder() .with_root_certificates(roots) .with_no_client_auth(); client_crypto.alpn_protocols = vec![RELAY_ALPN.to_vec()]; let client_config = ClientConfig::new(Arc::new( QuicClientConfig::try_from(client_crypto).context("failed to build client config")?, )); let mut endpoint = Endpoint::client(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)) .context("failed to bind client endpoint")?; endpoint.set_default_client_config(client_config); Ok(endpoint) } async fn request_control_message( connection: &quinn::Connection, message: ControlMessage, ) -> Result { let (mut send, mut recv) = connection.open_bi().await?; let request = encode_control_message(&message)?; send.write_all(&request).await?; send.finish() .map_err(|error| anyhow!("failed to finish request stream: {error}"))?; let response = recv.read_to_end(MAX_CONTROL_FRAME_LEN).await?; Ok(decode_control_frame(&response)?) } async fn welcome_for_client(connection: &quinn::Connection, mac: MacAddr) -> ServerWelcome { let hello = EndpointHello::client(RoomCode::new("TESTROOM").unwrap(), mac, 1400).unwrap(); let response = request_control_message(connection, ControlMessage::Hello(hello)) .await .unwrap(); let ControlMessage::Welcome(welcome) = response else { panic!("expected welcome response"); }; welcome } fn client_mac(last: u8) -> MacAddr { MacAddr::new([0x02, 0, 0, 0, 0, last]) } fn ethernet_frame(destination: MacAddr, source: MacAddr) -> Vec { let mut frame = Vec::new(); frame.extend_from_slice(&destination.octets()); frame.extend_from_slice(&source.octets()); frame.extend_from_slice(&0x0800_u16.to_be_bytes()); frame.extend_from_slice(b"payload"); frame } fn unique_temp_cert_path() -> std::path::PathBuf { let nanos = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_nanos(); std::env::temp_dir().join(format!( "lanparty-relay-dev-cert-{}-{nanos}.der", std::process::id() )) } }