use std::{ ffi::CString, io, os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}, }; use lanparty_proto::MacAddr; const ETH_P_ALL: u16 = libc::ETH_P_ALL as u16; #[derive(Debug)] pub struct PacketSocket { fd: OwnedFd, interface: String, interface_index: u32, interface_mac: MacAddr, } impl PacketSocket { pub fn open(interface: &str) -> io::Result { let interface_index = interface_index(interface)?; let protocol = i32::from(ETH_P_ALL.to_be()); let raw_fd = unsafe { // SAFETY: socket is called with constant domain/type/protocol values and returns // a new file descriptor or -1 without aliasing Rust-owned memory. libc::socket( libc::AF_PACKET, libc::SOCK_RAW | libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK, protocol, ) }; if raw_fd < 0 { return Err(io::Error::last_os_error()); } let fd = unsafe { // SAFETY: raw_fd was just returned by socket and is owned by this function. OwnedFd::from_raw_fd(raw_fd) }; let mut address = unsafe { // SAFETY: sockaddr_ll is a plain old data kernel ABI struct; zero is a valid base // before filling the fields required by bind. std::mem::zeroed::() }; address.sll_family = libc::AF_PACKET as u16; address.sll_protocol = ETH_P_ALL.to_be(); address.sll_ifindex = interface_index as i32; let result = unsafe { // SAFETY: address points to a properly initialized sockaddr_ll and the length // matches that struct. fd remains owned by this function across the call. libc::bind( fd.as_raw_fd(), (&address as *const libc::sockaddr_ll).cast::(), std::mem::size_of::() as libc::socklen_t, ) }; if result < 0 { return Err(io::Error::last_os_error()); } let interface_mac = interface_hardware_addr(fd.as_raw_fd(), interface)?; Ok(Self { fd, interface: interface.to_owned(), interface_index, interface_mac, }) } #[must_use] pub fn interface(&self) -> &str { &self.interface } #[must_use] pub const fn interface_index(&self) -> u32 { self.interface_index } #[must_use] pub const fn interface_mac(&self) -> MacAddr { self.interface_mac } pub fn send_frame(&self, frame: &[u8]) -> io::Result { let sent = unsafe { // SAFETY: frame.as_ptr() is valid for frame.len() bytes for the duration of send, // and send does not retain the pointer after returning. libc::send( self.fd.as_raw_fd(), frame.as_ptr().cast::(), frame.len(), 0, ) }; if sent < 0 { return Err(io::Error::last_os_error()); } Ok(sent as usize) } pub fn recv_frame(&self, buffer: &mut [u8]) -> io::Result { loop { let mut address = unsafe { // SAFETY: sockaddr_ll is a plain old data kernel ABI struct; zero is a valid // base before recvfrom initializes the peer address. std::mem::zeroed::() }; let mut address_len = std::mem::size_of::() as libc::socklen_t; let received = unsafe { // SAFETY: buffer.as_mut_ptr() is valid for buffer.len() bytes for the duration // of recvfrom, and recvfrom initializes at most that many bytes. address points // to a sockaddr_ll-sized output buffer and address_len carries that size. libc::recvfrom( self.fd.as_raw_fd(), buffer.as_mut_ptr().cast::(), buffer.len(), 0, (&mut address as *mut libc::sockaddr_ll).cast::(), &mut address_len, ) }; if received < 0 { return Err(io::Error::last_os_error()); } if is_inbound_packet_type(address.sll_pkttype) { return Ok(received as usize); } } } } impl AsRawFd for PacketSocket { fn as_raw_fd(&self) -> RawFd { self.fd.as_raw_fd() } } pub fn interface_index(interface: &str) -> io::Result { let name = interface_name(interface)?; let index = unsafe { // SAFETY: name is a valid NUL-terminated C string and if_nametoindex does not retain it. libc::if_nametoindex(name.as_ptr()) }; if index == 0 { return Err(io::Error::last_os_error()); } Ok(index) } fn interface_name(interface: &str) -> io::Result { if interface.trim().is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "interface name cannot be empty", )); } let name = CString::new(interface).map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "interface name cannot contain NUL bytes", ) })?; if name.as_bytes_with_nul().len() > libc::IFNAMSIZ { return Err(io::Error::new( io::ErrorKind::InvalidInput, "interface name is too long", )); } Ok(name) } fn interface_hardware_addr(fd: RawFd, interface: &str) -> io::Result { let name = interface_name(interface)?; let mut request = unsafe { // SAFETY: ifreq is a kernel ABI struct; zeroing it gives a valid base before filling the // interface name and asking ioctl to populate the hardware-address union field. std::mem::zeroed::() }; for (destination, source) in request .ifr_name .iter_mut() .zip(name.as_bytes_with_nul().iter()) { *destination = *source as libc::c_char; } let result = unsafe { // SAFETY: request points to an initialized ifreq with a NUL-terminated interface name. // ioctl writes the hardware address into the ifreq union and does not retain pointers. libc::ioctl(fd, libc::SIOCGIFHWADDR, &mut request) }; if result < 0 { return Err(io::Error::last_os_error()); } let hardware_addr = unsafe { // SAFETY: SIOCGIFHWADDR succeeded, so reading the hardware-address union field is valid. request.ifr_ifru.ifru_hwaddr }; if hardware_addr.sa_family != libc::ARPHRD_ETHER { return Err(io::Error::new( io::ErrorKind::InvalidData, "interface is not an Ethernet device", )); } let data = hardware_addr.sa_data; let mac = MacAddr::new([ data[0] as u8, data[1] as u8, data[2] as u8, data[3] as u8, data[4] as u8, data[5] as u8, ]); if !mac.is_valid_unicast() { return Err(io::Error::new( io::ErrorKind::InvalidData, format!("interface has invalid Ethernet MAC address {mac}"), )); } Ok(mac) } fn is_inbound_packet_type(packet_type: u8) -> bool { packet_type != libc::PACKET_OUTGOING } #[cfg(test)] mod tests { use super::*; #[test] fn rejects_invalid_interface_names() { assert_eq!( interface_index("").unwrap_err().kind(), io::ErrorKind::InvalidInput ); assert_eq!( interface_index("eth0\0bad").unwrap_err().kind(), io::ErrorKind::InvalidInput ); } #[test] fn reports_missing_interface() { let error = interface_index("lp-missing0").unwrap_err(); assert_ne!(error.kind(), io::ErrorKind::InvalidInput); } #[test] fn classifies_inbound_packet_types() { assert!(!is_inbound_packet_type(libc::PACKET_OUTGOING)); assert!(is_inbound_packet_type(libc::PACKET_HOST)); assert!(is_inbound_packet_type(libc::PACKET_BROADCAST)); assert!(is_inbound_packet_type(libc::PACKET_MULTICAST)); } }