diff --git a/crates/lanparty-relay/src/server.rs b/crates/lanparty-relay/src/server.rs index f0bc50f..1be164e 100644 --- a/crates/lanparty-relay/src/server.rs +++ b/crates/lanparty-relay/src/server.rs @@ -1013,7 +1013,9 @@ mod tests { use lanparty_client_core::{ClientSessionConfig, connect_client}; use lanparty_ctrl::{RoomCode, decode_control_frame, encode_control_message}; use lanparty_gateway::{GatewayConfig, connect_gateway}; - use lanparty_proto::{FrameType, MacAddr, decode_datagram, encode_datagram}; + use lanparty_proto::{ + ETHERNET_HEADER_LEN, ETHERTYPE_IPV4, FrameType, MacAddr, decode_datagram, encode_datagram, + }; use quinn::{ClientConfig, crypto::rustls::QuicClientConfig}; use crate::{DEFAULT_MAX_CLIENTS_PER_ROOM, ListenEndpoint}; @@ -1023,6 +1025,10 @@ mod tests { const ETHERTYPE_ARP: u16 = 0x0806; const ARP_REQUEST: u16 = 1; const ARP_REPLY: u16 = 2; + const IPV4_HEADER_LEN: usize = 20; + const IP_PROTOCOL_ICMPV4: u8 = 1; + const ICMPV4_ECHO_REPLY: u8 = 0; + const ICMPV4_ECHO_REQUEST: u8 = 8; #[tokio::test] async fn binds_quic_endpoint_on_configured_address() { @@ -1763,6 +1769,140 @@ mod tests { assert!(sessions.lock().await.is_empty()); } + #[tokio::test] + async fn bridges_icmpv4_ping_frames_between_client_and_gateway_sessions() { + let (server, certificate) = bind_test_server(DEFAULT_MAX_CLIENTS_PER_ROOM); + let rooms = Arc::clone(&server.rooms); + let sessions = Arc::clone(&server.sessions); + let server_addr = server.local_addr().unwrap(); + let server_task = tokio::spawn(async move { + let handles = server.accept_many_for_test(2).await.unwrap(); + let mut accepted = Vec::with_capacity(handles.len()); + + for handle in handles { + accepted.push(handle.await.unwrap().unwrap().unwrap()); + } + + server.shutdown("test complete").await; + accepted + }); + let cert_der = certificate.as_ref().to_vec(); + let room = RoomCode::new("TESTROOM").unwrap(); + let client_mac = client_mac(1); + let lan_host_mac = lan_host_mac(); + let client_ip = Ipv4Addr::new(10, 73, 42, 51); + let lan_host_ip = Ipv4Addr::new(10, 73, 42, 10); + let gateway = connect_gateway( + GatewayConfig::new( + server_addr, + "lanparty-relay.local", + cert_der.clone(), + room.clone(), + "eth0", + 1400, + ) + .unwrap(), + ) + .await + .unwrap(); + let client = connect_client( + ClientSessionConfig::new( + server_addr, + "lanparty-relay.local", + cert_der, + room, + client_mac, + 1400, + ) + .unwrap(), + ) + .await + .unwrap(); + + let ControlMessage::PeerJoined(peer) = + tokio::time::timeout(Duration::from_secs(5), gateway.recv_control_event()) + .await + .unwrap() + .unwrap() + else { + panic!("expected gateway to observe client join"); + }; + assert_eq!(peer.peer_id(), client.welcome().peer_id()); + assert_eq!(peer.role(), Role::Client); + assert_eq!(peer.mac(), Some(client_mac)); + + let ControlMessage::PeerJoined(peer) = + tokio::time::timeout(Duration::from_secs(5), client.recv_control_event()) + .await + .unwrap() + .unwrap() + else { + panic!("expected client to receive gateway catch-up event"); + }; + assert_eq!(peer.peer_id(), gateway.welcome().peer_id()); + assert_eq!(peer.role(), Role::Gateway); + + let echo_request = icmpv4_echo_frame( + lan_host_mac, + client_mac, + ICMPV4_ECHO_REQUEST, + client_ip, + lan_host_ip, + ); + let request_header = EthernetFrame::parse(&echo_request).unwrap(); + let request_ipv4 = &echo_request[ETHERNET_HEADER_LEN..]; + assert_eq!(request_header.ethertype_or_len(), ETHERTYPE_IPV4); + assert_eq!(request_header.destination(), lan_host_mac); + assert_eq!(request_ipv4[9], IP_PROTOCOL_ICMPV4); + assert_eq!(request_ipv4[IPV4_HEADER_LEN], ICMPV4_ECHO_REQUEST); + assert_eq!( + client + .relay_io() + .send_ethernet_with_outcome(&echo_request) + .unwrap(), + lanparty_client_core::ClientSendOutcome::Sent + ); + let received = tokio::time::timeout(Duration::from_secs(5), gateway.recv_ethernet()) + .await + .unwrap() + .unwrap(); + assert_eq!(received.source_peer_id(), client.welcome().peer_id()); + assert_eq!(received.payload(), echo_request.as_slice()); + + let echo_reply = icmpv4_echo_frame( + client_mac, + lan_host_mac, + ICMPV4_ECHO_REPLY, + lan_host_ip, + client_ip, + ); + let reply_header = EthernetFrame::parse(&echo_reply).unwrap(); + let reply_ipv4 = &echo_reply[ETHERNET_HEADER_LEN..]; + assert_eq!(reply_header.ethertype_or_len(), ETHERTYPE_IPV4); + assert_eq!(reply_header.destination(), client_mac); + assert_eq!(reply_ipv4[9], IP_PROTOCOL_ICMPV4); + assert_eq!(reply_ipv4[IPV4_HEADER_LEN], ICMPV4_ECHO_REPLY); + gateway.send_ethernet(&echo_reply).unwrap(); + let received = + tokio::time::timeout(Duration::from_secs(5), client.relay_io().recv_ethernet()) + .await + .unwrap() + .unwrap(); + assert_eq!(received.source_peer_id(), gateway.welcome().peer_id()); + assert_eq!(received.payload(), echo_reply.as_slice()); + + client.shutdown("test client done").await; + gateway.shutdown("test gateway done").await; + + let accepted = tokio::time::timeout(Duration::from_secs(5), server_task) + .await + .unwrap() + .unwrap(); + assert_eq!(accepted.len(), 2); + assert_eq!(rooms.lock().await.room_count(), 0); + assert!(sessions.lock().await.is_empty()); + } + #[tokio::test] async fn reconnects_gateway_while_client_stays_joined() { let (server, certificate) = bind_test_server(DEFAULT_MAX_CLIENTS_PER_ROOM); @@ -2158,8 +2298,12 @@ mod tests { MacAddr::new([0x0a, 0, 0, 0, 0, 1]) } + fn lan_host_mac() -> MacAddr { + MacAddr::new([0x0a, 0, 0, 0, 0, 2]) + } + fn ethernet_frame(destination: MacAddr, source: MacAddr) -> Vec { - ethernet_frame_with_payload(destination, source, 0x0800, b"payload") + ethernet_frame_with_payload(destination, source, ETHERTYPE_IPV4, b"payload") } fn arp_frame( @@ -2184,6 +2328,57 @@ mod tests { ethernet_frame_with_payload(destination, source, ETHERTYPE_ARP, &payload) } + fn icmpv4_echo_frame( + destination: MacAddr, + source: MacAddr, + message_type: u8, + source_ip: Ipv4Addr, + destination_ip: Ipv4Addr, + ) -> Vec { + let mut icmp = Vec::with_capacity(8); + icmp.push(message_type); + icmp.push(0); + icmp.extend_from_slice(&0_u16.to_be_bytes()); + icmp.extend_from_slice(&0x4242_u16.to_be_bytes()); + icmp.extend_from_slice(&1_u16.to_be_bytes()); + let checksum = internet_checksum(&icmp); + icmp[2..4].copy_from_slice(&checksum.to_be_bytes()); + + let mut ipv4 = Vec::with_capacity(IPV4_HEADER_LEN + icmp.len()); + ipv4.push(0x45); + ipv4.push(0); + let total_len = u16::try_from(IPV4_HEADER_LEN + icmp.len()).unwrap(); + ipv4.extend_from_slice(&total_len.to_be_bytes()); + ipv4.extend_from_slice(&0x1234_u16.to_be_bytes()); + ipv4.extend_from_slice(&0_u16.to_be_bytes()); + ipv4.push(64); + ipv4.push(IP_PROTOCOL_ICMPV4); + ipv4.extend_from_slice(&0_u16.to_be_bytes()); + ipv4.extend_from_slice(&source_ip.octets()); + ipv4.extend_from_slice(&destination_ip.octets()); + let checksum = internet_checksum(&ipv4); + ipv4[10..12].copy_from_slice(&checksum.to_be_bytes()); + ipv4.extend_from_slice(&icmp); + + ethernet_frame_with_payload(destination, source, ETHERTYPE_IPV4, &ipv4) + } + + fn internet_checksum(bytes: &[u8]) -> u16 { + let mut sum = 0_u32; + let mut chunks = bytes.chunks_exact(2); + for chunk in &mut chunks { + sum += u32::from(u16::from_be_bytes([chunk[0], chunk[1]])); + } + if let Some(byte) = chunks.remainder().first() { + sum += u32::from(*byte) << 8; + } + while sum > 0xffff { + sum = (sum & 0xffff) + (sum >> 16); + } + + !(sum as u16) + } + fn ethernet_frame_with_payload( destination: MacAddr, source: MacAddr,