diff --git a/crates/lanparty-gateway/src/packet.rs b/crates/lanparty-gateway/src/packet.rs index f80d673..1955fcf 100644 --- a/crates/lanparty-gateway/src/packet.rs +++ b/crates/lanparty-gateway/src/packet.rs @@ -92,21 +92,33 @@ impl PacketSocket { } pub fn recv_frame(&self, buffer: &mut [u8]) -> io::Result { - let received = unsafe { - // SAFETY: buffer.as_mut_ptr() is valid for buffer.len() bytes for the duration of - // recv, and recv initializes at most that many bytes. - libc::recv( - self.fd.as_raw_fd(), - buffer.as_mut_ptr().cast::(), - buffer.len(), - 0, - ) - }; - if received < 0 { - return Err(io::Error::last_os_error()); + loop { + let mut address = unsafe { + // SAFETY: sockaddr_ll is a plain old data kernel ABI struct; zero is a valid + // base before recvfrom initializes the peer address. + std::mem::zeroed::() + }; + let mut address_len = std::mem::size_of::() as libc::socklen_t; + let received = unsafe { + // SAFETY: buffer.as_mut_ptr() is valid for buffer.len() bytes for the duration + // of recvfrom, and recvfrom initializes at most that many bytes. address points + // to a sockaddr_ll-sized output buffer and address_len carries that size. + libc::recvfrom( + self.fd.as_raw_fd(), + buffer.as_mut_ptr().cast::(), + buffer.len(), + 0, + (&mut address as *mut libc::sockaddr_ll).cast::(), + &mut address_len, + ) + }; + if received < 0 { + return Err(io::Error::last_os_error()); + } + if is_inbound_packet_type(address.sll_pkttype) { + return Ok(received as usize); + } } - - Ok(received as usize) } } @@ -142,6 +154,10 @@ pub fn interface_index(interface: &str) -> io::Result { Ok(index) } +fn is_inbound_packet_type(packet_type: u8) -> bool { + packet_type != libc::PACKET_OUTGOING +} + #[cfg(test)] mod tests { use super::*; @@ -164,4 +180,12 @@ mod tests { assert_ne!(error.kind(), io::ErrorKind::InvalidInput); } + + #[test] + fn classifies_inbound_packet_types() { + assert!(!is_inbound_packet_type(libc::PACKET_OUTGOING)); + assert!(is_inbound_packet_type(libc::PACKET_HOST)); + assert!(is_inbound_packet_type(libc::PACKET_BROADCAST)); + assert!(is_inbound_packet_type(libc::PACKET_MULTICAST)); + } }