//! Packet encoding/decoding for TFTP (RFC 1350).
//!
//! RFC 1350, Section 5 ("TFTP Packets") defines the packet formats:
//! - RRQ/WRQ: Figure 5-1
//! - DATA: Figure 5-2
//! - ACK: Figure 5-3
//! - ERROR: Figure 5-4
//!
//! This module implements strict parsing/serialization of those formats.

#![forbid(unsafe_code)]

use core::fmt;

/// RFC 1350, Section 2: data blocks are 512 bytes.
pub const BLOCK_SIZE: usize = 512;

/// Size of the largest DATA packet: 2-byte opcode + 2-byte block number +
/// [`BLOCK_SIZE`] data bytes (RFC 1350, Figure 5-2).
///
/// This bounds DATA packets only; RRQ/WRQ and ERROR packets carry
/// variable-length strings and are not limited by this constant.
pub const MAX_PACKET_SIZE: usize = 4 + BLOCK_SIZE;

/// TFTP opcode values (RFC 1350, Section 5).
///
/// The discriminants are the on-the-wire values, so `op as u16` yields the
/// value to serialize.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum Opcode {
    /// Read request (RRQ), RFC 1350 Figure 5-1.
    Rrq = 1,
    /// Write request (WRQ), RFC 1350 Figure 5-1.
    Wrq = 2,
    /// Data (DATA), RFC 1350 Figure 5-2.
    Data = 3,
    /// Acknowledgment (ACK), RFC 1350 Figure 5-3.
    Ack = 4,
    /// Error (ERROR), RFC 1350 Figure 5-4.
    Error = 5,
}

impl Opcode {
    /// Maps a wire opcode to an [`Opcode`].
    ///
    /// Returns `None` for any value outside `1..=5`.
    #[must_use]
    pub fn from_u16(value: u16) -> Option<Self> {
        match value {
            1 => Some(Self::Rrq),
            2 => Some(Self::Wrq),
            3 => Some(Self::Data),
            4 => Some(Self::Ack),
            5 => Some(Self::Error),
            _ => None,
        }
    }
}

/// Transfer mode for RRQ/WRQ (RFC 1350, Section 5, Figure 5-1).
///
/// RFC 1350 lists: "netascii", "octet", "mail" (mail is obsolete; Section 1).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// `"netascii"` transfer mode.
    NetAscii,
    /// `"octet"` (raw 8-bit bytes) transfer mode.
    Octet,
    /// `"mail"` transfer mode (obsolete).
    Mail,
}

impl Mode {
    /// Returns the lowercase wire spelling of this mode.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::NetAscii => "netascii",
            Self::Octet => "octet",
            Self::Mail => "mail",
        }
    }

    /// Parses a mode string, ignoring ASCII case.
    ///
    /// RFC 1350 allows any mix of upper- and lower-case letters in the mode
    /// field. Returns `None` for anything other than netascii, octet, or mail.
    #[must_use]
    pub fn parse_case_insensitive(s: &str) -> Option<Self> {
        if s.eq_ignore_ascii_case("netascii") {
            Some(Self::NetAscii)
        } else if s.eq_ignore_ascii_case("octet") {
            Some(Self::Octet)
        } else if s.eq_ignore_ascii_case("mail") {
            Some(Self::Mail)
        } else {
            None
        }
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u16)] pub enum ErrorCode { NotDefined = 0, FileNotFound = 1, AccessViolation = 2, DiskFull = 3, IllegalOperation = 4, UnknownTransferId = 5, FileAlreadyExists = 6, NoSuchUser = 7, } impl ErrorCode { #[must_use] pub fn from_u16(value: u16) -> Option { match value { 0 => Some(Self::NotDefined), 1 => Some(Self::FileNotFound), 2 => Some(Self::AccessViolation), 3 => Some(Self::DiskFull), 4 => Some(Self::IllegalOperation), 5 => Some(Self::UnknownTransferId), 6 => Some(Self::FileAlreadyExists), 7 => Some(Self::NoSuchUser), _ => None, } } } /// A RRQ/WRQ request (RFC 1350 Figure 5-1). #[derive(Debug, Clone, PartialEq, Eq)] pub struct Request { pub filename: String, pub mode: Mode, } /// A decoded TFTP packet (RFC 1350 Section 5). #[derive(Debug, Clone, PartialEq, Eq)] pub enum Packet { Rrq(Request), Wrq(Request), Data { block: u16, data: Vec }, Ack { block: u16 }, Error { code: ErrorCode, message: String }, } impl Packet { /// Encode a packet into a newly allocated buffer. #[must_use] pub fn encode(&self) -> Vec { let mut out = Vec::with_capacity(MAX_PACKET_SIZE); self.encode_into(&mut out); out } /// Encode a packet by appending to `out`. pub fn encode_into(&self, out: &mut Vec) { match self { Self::Rrq(req) => encode_request(out, Opcode::Rrq, req), Self::Wrq(req) => encode_request(out, Opcode::Wrq, req), Self::Data { block, data } => { out.extend_from_slice(&(Opcode::Data as u16).to_be_bytes()); out.extend_from_slice(&block.to_be_bytes()); out.extend_from_slice(data); } Self::Ack { block } => { out.extend_from_slice(&(Opcode::Ack as u16).to_be_bytes()); out.extend_from_slice(&block.to_be_bytes()); } Self::Error { code, message } => { out.extend_from_slice(&(Opcode::Error as u16).to_be_bytes()); out.extend_from_slice(&(*code as u16).to_be_bytes()); out.extend_from_slice(message.as_bytes()); out.push(0); } } } /// Decode a packet from UDP payload bytes. 
/// /// # Errors /// Returns a [`DecodeError`] if the input does not match an RFC 1350 packet format. pub fn decode(input: &[u8]) -> Result { let (opcode, rest) = parse_opcode(input)?; match opcode { Opcode::Rrq | Opcode::Wrq => { let (filename, rest) = parse_netascii_z(rest, "filename")?; let (mode_str, rest) = parse_netascii_z(rest, "mode")?; if !rest.is_empty() { return Err(DecodeError::TrailingBytes); } let Some(mode) = Mode::parse_case_insensitive(&mode_str) else { return Err(DecodeError::UnknownMode(mode_str)); }; let req = Request { filename, mode }; Ok(if opcode == Opcode::Rrq { Self::Rrq(req) } else { Self::Wrq(req) }) } Opcode::Data => { if rest.len() < 2 { return Err(DecodeError::Truncated("block")); } if rest.len() > 2 + BLOCK_SIZE { return Err(DecodeError::OversizeData(rest.len() - 2)); } let block = u16::from_be_bytes([rest[0], rest[1]]); let data = rest[2..].to_vec(); Ok(Self::Data { block, data }) } Opcode::Ack => { if rest.len() != 2 { return Err(DecodeError::InvalidLength { kind: "ACK", expected: 4, got: input.len(), }); } let block = u16::from_be_bytes([rest[0], rest[1]]); Ok(Self::Ack { block }) } Opcode::Error => { if rest.len() < 2 { return Err(DecodeError::Truncated("error code")); } let code_u16 = u16::from_be_bytes([rest[0], rest[1]]); let Some(code) = ErrorCode::from_u16(code_u16) else { return Err(DecodeError::UnknownErrorCode(code_u16)); }; let (message, trailing) = parse_netascii_z(&rest[2..], "error message")?; if !trailing.is_empty() { return Err(DecodeError::TrailingBytes); } Ok(Self::Error { code, message }) } } } } fn encode_request(out: &mut Vec, opcode: Opcode, req: &Request) { out.extend_from_slice(&(opcode as u16).to_be_bytes()); out.extend_from_slice(req.filename.as_bytes()); out.push(0); out.extend_from_slice(req.mode.as_str().as_bytes()); out.push(0); } fn parse_opcode(input: &[u8]) -> Result<(Opcode, &[u8]), DecodeError> { if input.len() < 2 { return Err(DecodeError::Truncated("opcode")); } let opcode_u16 = 
u16::from_be_bytes([input[0], input[1]]); let Some(opcode) = Opcode::from_u16(opcode_u16) else { return Err(DecodeError::UnknownOpcode(opcode_u16)); }; Ok((opcode, &input[2..])) } fn parse_netascii_z<'a>( input: &'a [u8], field: &'static str, ) -> Result<(String, &'a [u8]), DecodeError> { let Some(pos) = input.iter().position(|&b| b == 0) else { return Err(DecodeError::MissingTerminator(field)); }; if pos == 0 { return Err(DecodeError::EmptyField(field)); } let bytes = &input[..pos]; if bytes.iter().any(|&b| b > 0x7F) { return Err(DecodeError::NonNetascii(field)); } let s = core::str::from_utf8(bytes).map_err(|_| DecodeError::NonUtf8(field))?; Ok((s.to_owned(), &input[pos + 1..])) } /// A parse error for RFC 1350 packet formats. #[derive(Debug, Clone, PartialEq, Eq)] pub enum DecodeError { Truncated(&'static str), UnknownOpcode(u16), MissingTerminator(&'static str), EmptyField(&'static str), NonNetascii(&'static str), NonUtf8(&'static str), UnknownMode(String), UnknownErrorCode(u16), OversizeData(usize), InvalidLength { kind: &'static str, expected: usize, got: usize, }, TrailingBytes, } impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Truncated(what) => write!(f, "truncated packet ({what})"), Self::UnknownOpcode(op) => write!(f, "unknown opcode {op}"), Self::MissingTerminator(field) => write!(f, "missing NUL terminator for {field}"), Self::EmptyField(field) => write!(f, "empty {field}"), Self::NonNetascii(field) => write!(f, "{field} contains non-netascii bytes"), Self::NonUtf8(field) => write!(f, "{field} is not valid UTF-8"), Self::UnknownMode(mode) => write!(f, "unknown transfer mode {mode:?}"), Self::UnknownErrorCode(code) => write!(f, "unknown error code {code}"), Self::OversizeData(n) => write!(f, "DATA payload too large ({n} bytes)"), Self::InvalidLength { kind, expected, got, } => { write!( f, "{kind} packet has invalid length (expected {expected}, got {got})" ) } Self::TrailingBytes => 
write!(f, "trailing bytes after packet"), } } } impl std::error::Error for DecodeError {} #[cfg(test)] mod tests { use super::*; #[allow(clippy::unwrap_used)] #[test] fn rrq_roundtrip() { let pkt = Packet::Rrq(Request { filename: "hello.txt".to_string(), mode: Mode::Octet, }); let encoded = pkt.encode(); let decoded = Packet::decode(&encoded).unwrap(); assert_eq!(decoded, pkt); } #[allow(clippy::unwrap_used)] #[test] fn data_roundtrip_small() { let pkt = Packet::Data { block: 42, data: b"abc".to_vec(), }; let encoded = pkt.encode(); let decoded = Packet::decode(&encoded).unwrap(); assert_eq!(decoded, pkt); } #[allow(clippy::unwrap_used)] #[test] fn ack_requires_exact_length() { let mut bytes = Vec::new(); Packet::Ack { block: 1 }.encode_into(&mut bytes); bytes.push(0); let err = Packet::decode(&bytes).unwrap_err(); assert!(matches!( err, DecodeError::InvalidLength { kind: "ACK", .. } )); } #[allow(clippy::unwrap_used)] #[test] fn rejects_unknown_opcode() { let err = Packet::decode(&[0, 9]).unwrap_err(); assert!(matches!(err, DecodeError::UnknownOpcode(9))); } #[allow(clippy::unwrap_used)] #[test] fn rejects_data_larger_than_512() { let mut bytes = vec![0, 3, 0, 1]; bytes.extend(std::iter::repeat(0x61).take(BLOCK_SIZE + 1)); let err = Packet::decode(&bytes).unwrap_err(); assert!(matches!(err, DecodeError::OversizeData(513))); } }