use thiserror::Error; pub const OVERLAY_MAGIC: u32 = 0x534c_414e; // "SLAN" pub const OVERLAY_VERSION: u8 = 1; pub const OVERLAY_HEADER_LEN: usize = 22; pub const OVERLAY_FLAGS_NONE: u16 = 0; #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[repr(u8)] pub enum FrameType { Ethernet = 1, Control = 2, Keepalive = 3, } impl FrameType { #[must_use] pub const fn as_u8(self) -> u8 { self as u8 } pub const fn from_u8(value: u8) -> Result { match value { 1 => Ok(Self::Ethernet), 2 => Ok(Self::Control), 3 => Ok(Self::Keepalive), other => Err(ProtoError::UnknownFrameType(other)), } } } #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct OverlayHeader { frame_type: FrameType, room_id: u64, peer_id: u32, flags: u16, payload_len: u16, } impl OverlayHeader { pub fn new( frame_type: FrameType, room_id: u64, peer_id: u32, flags: u16, payload_len: usize, ) -> Result { validate_overlay_flags(flags)?; let payload_len = u16::try_from(payload_len).map_err(|_| ProtoError::PayloadTooLarge { len: payload_len, max: u16::MAX as usize, })?; Ok(Self { frame_type, room_id, peer_id, flags, payload_len, }) } #[must_use] pub const fn frame_type(self) -> FrameType { self.frame_type } #[must_use] pub const fn room_id(self) -> u64 { self.room_id } #[must_use] pub const fn peer_id(self) -> u32 { self.peer_id } #[must_use] pub const fn flags(self) -> u16 { self.flags } #[must_use] pub const fn payload_len(self) -> u16 { self.payload_len } #[must_use] pub fn encode(self) -> [u8; OVERLAY_HEADER_LEN] { let mut bytes = [0; OVERLAY_HEADER_LEN]; bytes[0..4].copy_from_slice(&OVERLAY_MAGIC.to_be_bytes()); bytes[4] = OVERLAY_VERSION; bytes[5] = self.frame_type.as_u8(); bytes[6..14].copy_from_slice(&self.room_id.to_be_bytes()); bytes[14..18].copy_from_slice(&self.peer_id.to_be_bytes()); bytes[18..20].copy_from_slice(&self.flags.to_be_bytes()); bytes[20..22].copy_from_slice(&self.payload_len.to_be_bytes()); bytes } pub fn decode(bytes: &[u8]) -> Result { if bytes.len() < OVERLAY_HEADER_LEN { return Err(ProtoError::DatagramTooShort { actual: bytes.len(), minimum: OVERLAY_HEADER_LEN, }); } let magic = u32::from_be_bytes(bytes[0..4].try_into().expect("header magic slice length")); if magic != OVERLAY_MAGIC { return Err(ProtoError::BadMagic { actual: magic }); } let version = bytes[4]; if version != OVERLAY_VERSION { return Err(ProtoError::UnsupportedVersion { actual: version }); } let flags = u16::from_be_bytes(bytes[18..20].try_into().expect("flags slice length")); validate_overlay_flags(flags)?; Ok(Self { frame_type: FrameType::from_u8(bytes[5])?, room_id: u64::from_be_bytes(bytes[6..14].try_into().expect("room id slice length")), peer_id: u32::from_be_bytes(bytes[14..18].try_into().expect("peer id slice length")), flags, payload_len: u16::from_be_bytes( bytes[20..22].try_into().expect("payload len slice length"), ), }) } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct OverlayPacket<'a> { header: OverlayHeader, payload: &'a [u8], } impl<'a> OverlayPacket<'a> { pub fn new(header: OverlayHeader, payload: &'a [u8]) -> Result { let declared = usize::from(header.payload_len); if payload.len() != declared { return Err(ProtoError::PayloadLengthMismatch { declared, actual: payload.len(), }); } Ok(Self { header, payload }) } #[must_use] pub const fn header(self) -> OverlayHeader { self.header } #[must_use] pub const fn payload(self) -> &'a [u8] { self.payload } } #[derive(Debug, Clone, PartialEq, Eq, Error)] pub enum ProtoError { #[error("datagram is too short: got {actual} bytes, need at least {minimum}")] DatagramTooShort { actual: usize, minimum: usize }, #[error("bad overlay magic 0x{actual:08x}")] BadMagic { actual: u32 }, #[error("unsupported overlay version {actual}")] UnsupportedVersion { actual: u8 }, #[error("unknown overlay frame type {0}")] UnknownFrameType(u8), #[error("unsupported overlay flags 0x{actual:04x}")] UnsupportedFlags { actual: u16 }, #[error("payload length {len} exceeds wire maximum {max}")] PayloadTooLarge { len: usize, max: usize }, #[error("encoded datagram length {len} exceeds negotiated QUIC datagram budget {max}")] DatagramExceedsBudget { len: usize, max: usize }, #[error("declared payload length {declared} does not match actual length {actual}")] PayloadLengthMismatch { declared: usize, actual: usize }, #[error("Ethernet frame is too short: got {actual} bytes, need at least {minimum}")] EthernetFrameTooShort { actual: usize, minimum: usize }, } fn validate_overlay_flags(flags: u16) -> Result<(), ProtoError> { if flags != OVERLAY_FLAGS_NONE { return Err(ProtoError::UnsupportedFlags { actual: flags }); } Ok(()) } pub fn encode_datagram( frame_type: FrameType, room_id: u64, peer_id: u32, flags: u16, payload: &[u8], ) -> Result, ProtoError> { let header = OverlayHeader::new(frame_type, room_id, peer_id, flags, payload.len())?; let mut datagram = Vec::with_capacity(OVERLAY_HEADER_LEN + payload.len()); datagram.extend_from_slice(&header.encode()); datagram.extend_from_slice(payload); Ok(datagram) } pub fn validate_datagram_budget( datagram_len: usize, max_datagram_size: usize, ) -> Result<(), ProtoError> { if datagram_len > max_datagram_size { return Err(ProtoError::DatagramExceedsBudget { len: datagram_len, max: max_datagram_size, }); } Ok(()) } pub fn decode_datagram(bytes: &[u8]) -> Result, ProtoError> { let header = OverlayHeader::decode(bytes)?; let payload = &bytes[OVERLAY_HEADER_LEN..]; OverlayPacket::new(header, payload) } #[cfg(test)] mod tests { use super::*; #[test] fn encodes_and_decodes_datagrams() { let payload = [1, 2, 3, 4]; let datagram = encode_datagram( FrameType::Ethernet, 0x0102_0304_0506_0708, 0x0a0b_0c0d, OVERLAY_FLAGS_NONE, &payload, ) .unwrap(); assert_eq!(datagram.len(), OVERLAY_HEADER_LEN + payload.len()); assert_eq!(&datagram[0..4], &OVERLAY_MAGIC.to_be_bytes()); assert_eq!(datagram[4], OVERLAY_VERSION); assert_eq!(datagram[5], FrameType::Ethernet.as_u8()); let packet = decode_datagram(&datagram).unwrap(); let header = packet.header(); assert_eq!(header.frame_type(), FrameType::Ethernet); assert_eq!(header.room_id(), 0x0102_0304_0506_0708); assert_eq!(header.peer_id(), 0x0a0b_0c0d); assert_eq!(header.flags(), OVERLAY_FLAGS_NONE); assert_eq!(header.payload_len(), 4); assert_eq!(packet.payload(), payload); } #[test] fn rejects_payload_length_mismatch() { let mut datagram = encode_datagram(FrameType::Keepalive, 1, 2, 0, &[1, 2, 3]).unwrap(); datagram.pop(); let error = decode_datagram(&datagram).unwrap_err(); assert!(matches!( error, ProtoError::PayloadLengthMismatch { declared: 3, actual: 2 } )); } #[test] fn rejects_bad_magic_and_version() { let mut datagram = encode_datagram(FrameType::Control, 1, 2, 0, &[]).unwrap(); datagram[0] = 0; assert!(matches!( decode_datagram(&datagram).unwrap_err(), ProtoError::BadMagic { .. } )); let mut datagram = encode_datagram(FrameType::Control, 1, 2, 0, &[]).unwrap(); datagram[4] = 99; assert!(matches!( decode_datagram(&datagram).unwrap_err(), ProtoError::UnsupportedVersion { actual: 99 } )); } #[test] fn rejects_unknown_frame_type() { let mut datagram = encode_datagram(FrameType::Control, 1, 2, 0, &[]).unwrap(); datagram[5] = 99; assert_eq!( decode_datagram(&datagram).unwrap_err(), ProtoError::UnknownFrameType(99) ); } #[test] fn rejects_reserved_overlay_flags() { assert_eq!( OverlayHeader::new(FrameType::Ethernet, 1, 2, 1, 0).unwrap_err(), ProtoError::UnsupportedFlags { actual: 1 } ); let mut datagram = encode_datagram(FrameType::Ethernet, 1, 2, 0, &[]).unwrap(); datagram[18..20].copy_from_slice(&0x8000_u16.to_be_bytes()); assert_eq!( decode_datagram(&datagram).unwrap_err(), ProtoError::UnsupportedFlags { actual: 0x8000 } ); } #[test] fn rejects_payloads_too_large_for_header() { let payload = vec![0; usize::from(u16::MAX) + 1]; assert!(matches!( encode_datagram(FrameType::Ethernet, 1, 2, 0, &payload).unwrap_err(), ProtoError::PayloadTooLarge { .. } )); } #[test] fn rejects_datagrams_over_negotiated_budget() { let datagram = encode_datagram(FrameType::Ethernet, 1, 2, 0, &[1, 2, 3]).unwrap(); assert!(validate_datagram_budget(datagram.len(), datagram.len()).is_ok()); assert_eq!( validate_datagram_budget(datagram.len(), datagram.len() - 1).unwrap_err(), ProtoError::DatagramExceedsBudget { len: datagram.len(), max: datagram.len() - 1 } ); } }