use thiserror::Error; use crate::{ControlError, ControlMessage}; pub const CONTROL_LENGTH_PREFIX_LEN: usize = 4; pub const MAX_CONTROL_MESSAGE_LEN: usize = 64 * 1024; #[derive(Debug, Error)] pub enum ControlCodecError { #[error("control frame is too short: got {actual} bytes, need at least {minimum}")] FrameTooShort { actual: usize, minimum: usize }, #[error("control message length {len} exceeds maximum {max}")] MessageTooLarge { len: usize, max: usize }, #[error( "control frame payload is incomplete: declared {declared} bytes, available {available}" )] IncompletePayload { declared: usize, available: usize }, #[error("control frame has {trailing} trailing bytes after one message")] TrailingBytes { trailing: usize }, #[error("control message JSON is invalid: {0}")] Json(#[from] serde_json::Error), #[error("control message failed validation: {0}")] InvalidMessage(#[from] ControlError), } pub fn encode_control_message(message: &ControlMessage) -> Result, ControlCodecError> { message.validate()?; let payload = serde_json::to_vec(message)?; let payload_len = payload.len(); if payload_len > MAX_CONTROL_MESSAGE_LEN { return Err(ControlCodecError::MessageTooLarge { len: payload_len, max: MAX_CONTROL_MESSAGE_LEN, }); } let mut frame = Vec::with_capacity(CONTROL_LENGTH_PREFIX_LEN + payload_len); frame.extend_from_slice(&(payload_len as u32).to_be_bytes()); frame.extend_from_slice(&payload); Ok(frame) } pub fn decode_control_frame(frame: &[u8]) -> Result { let Some(total_len) = complete_control_frame_len(frame)? else { return Err(incomplete_frame_error(frame)); }; if frame.len() > total_len { return Err(ControlCodecError::TrailingBytes { trailing: frame.len() - total_len, }); } let payload = &frame[CONTROL_LENGTH_PREFIX_LEN..total_len]; let message: ControlMessage = serde_json::from_slice(payload)?; message.validate()?; Ok(message) } pub fn complete_control_frame_len(buffer: &[u8]) -> Result, ControlCodecError> { if buffer.len() < CONTROL_LENGTH_PREFIX_LEN { return Ok(None); } let declared = declared_payload_len(buffer)?; if declared > MAX_CONTROL_MESSAGE_LEN { return Err(ControlCodecError::MessageTooLarge { len: declared, max: MAX_CONTROL_MESSAGE_LEN, }); } let total = CONTROL_LENGTH_PREFIX_LEN + declared; if buffer.len() < total { return Ok(None); } Ok(Some(total)) } fn declared_payload_len(buffer: &[u8]) -> Result { if buffer.len() < CONTROL_LENGTH_PREFIX_LEN { return Err(ControlCodecError::FrameTooShort { actual: buffer.len(), minimum: CONTROL_LENGTH_PREFIX_LEN, }); } Ok(u32::from_be_bytes( buffer[0..CONTROL_LENGTH_PREFIX_LEN] .try_into() .expect("length prefix slice has exact size"), ) as usize) } fn incomplete_frame_error(frame: &[u8]) -> ControlCodecError { if frame.len() < CONTROL_LENGTH_PREFIX_LEN { return ControlCodecError::FrameTooShort { actual: frame.len(), minimum: CONTROL_LENGTH_PREFIX_LEN, }; } let declared = declared_payload_len(frame).expect("frame has length prefix"); ControlCodecError::IncompletePayload { declared, available: frame.len() - CONTROL_LENGTH_PREFIX_LEN, } } #[cfg(test)] mod tests { use lanparty_proto::{MIN_USEFUL_TAP_MTU, MacAddr}; use super::*; use crate::{ControlMessage, DisconnectReason, EndpointHello, PeerInfo, Role, RoomCode}; fn room() -> RoomCode { RoomCode::new("ROOM_1").unwrap() } fn mac() -> MacAddr { MacAddr::new([0x02, 0, 0, 0, 0, 1]) } #[test] fn encodes_and_decodes_control_messages() { let message = ControlMessage::Hello(EndpointHello::client(room(), mac(), 1400).unwrap()); let frame = encode_control_message(&message).unwrap(); let declared = u32::from_be_bytes(frame[0..4].try_into().unwrap()) as usize; assert_eq!(declared, frame.len() - CONTROL_LENGTH_PREFIX_LEN); assert_eq!( complete_control_frame_len(&frame).unwrap(), Some(frame.len()) ); assert_eq!(decode_control_frame(&frame).unwrap(), message); } #[test] fn reports_incomplete_frames_for_stream_buffering() { assert_eq!(complete_control_frame_len(&[0, 0, 0]).unwrap(), None); let frame = encode_control_message(&ControlMessage::PeerLeft { peer_id: 1, reason: DisconnectReason::Normal, }) .unwrap(); assert_eq!(complete_control_frame_len(&frame[..8]).unwrap(), None); assert!(matches!( decode_control_frame(&frame[..8]).unwrap_err(), ControlCodecError::IncompletePayload { .. } )); } #[test] fn rejects_trailing_bytes_after_one_frame() { let mut frame = encode_control_message(&ControlMessage::PeerJoined( PeerInfo::new(1, Role::Client, Some(mac())).unwrap(), )) .unwrap(); frame.push(0); assert!(matches!( decode_control_frame(&frame).unwrap_err(), ControlCodecError::TrailingBytes { trailing: 1 } )); } #[test] fn rejects_oversized_declared_length() { let mut frame = [0; CONTROL_LENGTH_PREFIX_LEN]; frame.copy_from_slice(&((MAX_CONTROL_MESSAGE_LEN as u32) + 1).to_be_bytes()); assert!(matches!( complete_control_frame_len(&frame).unwrap_err(), ControlCodecError::MessageTooLarge { .. } )); } #[test] fn validates_decoded_messages() { let json = format!( r#"{{"type":"welcome","payload":{{"protocol_version":1,"room_id":1,"peer_id":0,"effective_tap_mtu":{}}}}}"#, MIN_USEFUL_TAP_MTU ); let mut frame = Vec::new(); frame.extend_from_slice(&(json.len() as u32).to_be_bytes()); frame.extend_from_slice(json.as_bytes()); assert!(matches!( decode_control_frame(&frame).unwrap_err(), ControlCodecError::InvalidMessage(ControlError::InvalidPeerId) )); } }