Files
pfs-tftp-codex/crates/pfs-tftp-proto/src/packet.rs

371 lines
12 KiB
Rust

//! Packet encoding/decoding for TFTP (RFC 1350).
//!
//! RFC 1350, Section 5 ("TFTP Packets") defines the packet formats:
//! - RRQ/WRQ: Figure 5-1
//! - DATA: Figure 5-2
//! - ACK: Figure 5-3
//! - ERROR: Figure 5-4
//!
//! This module implements strict parsing of those formats and straightforward
//! serialization (encoding writes fields verbatim and performs no validation).
#![forbid(unsafe_code)]
use core::fmt;
/// Number of payload bytes in a full DATA packet (RFC 1350, Section 2 uses
/// fixed 512-byte blocks; a shorter block ends the transfer).
pub const BLOCK_SIZE: usize = 512;
/// Size of a full DATA packet: 2-byte opcode + 2-byte block number +
/// [`BLOCK_SIZE`] payload bytes (RFC 1350 Figure 5-2).
pub const MAX_PACKET_SIZE: usize = 2 + 2 + BLOCK_SIZE;
/// TFTP opcode values (RFC 1350, Section 5).
///
/// The discriminant of each variant is its on-the-wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum Opcode {
    /// Read request (RRQ), RFC 1350 Figure 5-1.
    Rrq = 1,
    /// Write request (WRQ), RFC 1350 Figure 5-1.
    Wrq = 2,
    /// Data (DATA), RFC 1350 Figure 5-2.
    Data = 3,
    /// Acknowledgment (ACK), RFC 1350 Figure 5-3.
    Ack = 4,
    /// Error (ERROR), RFC 1350 Figure 5-4.
    Error = 5,
}

impl Opcode {
    /// Map a raw opcode value to its variant.
    ///
    /// Returns `None` for any value that is not one of the five RFC 1350
    /// opcodes (i.e. outside `1..=5`).
    #[must_use]
    pub fn from_u16(value: u16) -> Option<Self> {
        [Self::Rrq, Self::Wrq, Self::Data, Self::Ack, Self::Error]
            .iter()
            .copied()
            .find(|&candidate| candidate as u16 == value)
    }
}
/// Transfer mode for RRQ/WRQ (RFC 1350, Section 5, Figure 5-1).
///
/// RFC 1350 lists: "netascii", "octet", "mail" (mail is obsolete; Section 1).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// The `"netascii"` text mode.
    NetAscii,
    /// The `"octet"` raw-byte mode.
    Octet,
    /// The obsolete `"mail"` mode.
    Mail,
}

impl Mode {
    /// The canonical lowercase name used on the wire for this mode.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        let name = match self {
            Self::NetAscii => "netascii",
            Self::Octet => "octet",
            Self::Mail => "mail",
        };
        name
    }

    /// Parse a mode name, ignoring ASCII case (e.g. `"OCTET"` -> `Octet`).
    ///
    /// Returns `None` when `s` is not exactly one of the three mode names;
    /// no whitespace trimming is done.
    #[must_use]
    pub fn parse_case_insensitive(s: &str) -> Option<Self> {
        [Self::NetAscii, Self::Octet, Self::Mail]
            .iter()
            .copied()
            .find(|mode| mode.as_str().eq_ignore_ascii_case(s))
    }
}
/// Error codes (RFC 1350 Appendix, "Error Codes").
///
/// The discriminant of each variant is its on-the-wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ErrorCode {
    /// 0: Not defined, see error message (if any).
    NotDefined = 0,
    /// 1: File not found.
    FileNotFound = 1,
    /// 2: Access violation.
    AccessViolation = 2,
    /// 3: Disk full or allocation exceeded.
    DiskFull = 3,
    /// 4: Illegal TFTP operation.
    IllegalOperation = 4,
    /// 5: Unknown transfer ID.
    UnknownTransferId = 5,
    /// 6: File already exists.
    FileAlreadyExists = 6,
    /// 7: No such user.
    NoSuchUser = 7,
}

impl ErrorCode {
    /// Map a raw error-code value to its variant.
    ///
    /// Returns `None` for any value outside `0..=7`.
    #[must_use]
    pub fn from_u16(value: u16) -> Option<Self> {
        // The discriminants are 0..=7 in declaration order, so the wire value
        // doubles as an index into this table.
        [
            Self::NotDefined,
            Self::FileNotFound,
            Self::AccessViolation,
            Self::DiskFull,
            Self::IllegalOperation,
            Self::UnknownTransferId,
            Self::FileAlreadyExists,
            Self::NoSuchUser,
        ]
        .get(usize::from(value))
        .copied()
    }
}
/// A RRQ/WRQ request (RFC 1350 Figure 5-1).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Request {
    /// Requested file name, without its NUL terminator.
    ///
    /// When produced by `Packet::decode` it is non-empty 7-bit ASCII; when
    /// encoding it is written verbatim, followed by a NUL.
    pub filename: String,
    /// Transfer mode named in the request.
    pub mode: Mode,
}
/// A decoded TFTP packet (RFC 1350 Section 5).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
    /// Read request (opcode 1), RFC 1350 Figure 5-1.
    Rrq(Request),
    /// Write request (opcode 2), RFC 1350 Figure 5-1.
    Wrq(Request),
    /// A block of file data (opcode 3), RFC 1350 Figure 5-2.
    ///
    /// `block` is the block number and `data` the payload. `Packet::decode`
    /// accepts at most `BLOCK_SIZE` payload bytes; RFC 1350 treats a payload
    /// shorter than `BLOCK_SIZE` as the final block of a transfer.
    Data { block: u16, data: Vec<u8> },
    /// Acknowledgment of block number `block` (opcode 4), RFC 1350 Figure 5-3.
    Ack { block: u16 },
    /// Error report (opcode 5), RFC 1350 Figure 5-4: an error code plus a
    /// human-readable message.
    Error { code: ErrorCode, message: String },
}
impl Packet {
/// Encode a packet into a newly allocated buffer.
#[must_use]
pub fn encode(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(MAX_PACKET_SIZE);
self.encode_into(&mut out);
out
}
/// Encode a packet by appending to `out`.
pub fn encode_into(&self, out: &mut Vec<u8>) {
match self {
Self::Rrq(req) => encode_request(out, Opcode::Rrq, req),
Self::Wrq(req) => encode_request(out, Opcode::Wrq, req),
Self::Data { block, data } => {
out.extend_from_slice(&(Opcode::Data as u16).to_be_bytes());
out.extend_from_slice(&block.to_be_bytes());
out.extend_from_slice(data);
}
Self::Ack { block } => {
out.extend_from_slice(&(Opcode::Ack as u16).to_be_bytes());
out.extend_from_slice(&block.to_be_bytes());
}
Self::Error { code, message } => {
out.extend_from_slice(&(Opcode::Error as u16).to_be_bytes());
out.extend_from_slice(&(*code as u16).to_be_bytes());
out.extend_from_slice(message.as_bytes());
out.push(0);
}
}
}
/// Decode a packet from UDP payload bytes.
///
/// # Errors
/// Returns a [`DecodeError`] if the input does not match an RFC 1350 packet format.
pub fn decode(input: &[u8]) -> Result<Self, DecodeError> {
let (opcode, rest) = parse_opcode(input)?;
match opcode {
Opcode::Rrq | Opcode::Wrq => {
let (filename, rest) = parse_netascii_z(rest, "filename")?;
let (mode_str, rest) = parse_netascii_z(rest, "mode")?;
if !rest.is_empty() {
return Err(DecodeError::TrailingBytes);
}
let Some(mode) = Mode::parse_case_insensitive(&mode_str) else {
return Err(DecodeError::UnknownMode(mode_str));
};
let req = Request { filename, mode };
Ok(if opcode == Opcode::Rrq {
Self::Rrq(req)
} else {
Self::Wrq(req)
})
}
Opcode::Data => {
if rest.len() < 2 {
return Err(DecodeError::Truncated("block"));
}
if rest.len() > 2 + BLOCK_SIZE {
return Err(DecodeError::OversizeData(rest.len() - 2));
}
let block = u16::from_be_bytes([rest[0], rest[1]]);
let data = rest[2..].to_vec();
Ok(Self::Data { block, data })
}
Opcode::Ack => {
if rest.len() != 2 {
return Err(DecodeError::InvalidLength {
kind: "ACK",
expected: 4,
got: input.len(),
});
}
let block = u16::from_be_bytes([rest[0], rest[1]]);
Ok(Self::Ack { block })
}
Opcode::Error => {
if rest.len() < 2 {
return Err(DecodeError::Truncated("error code"));
}
let code_u16 = u16::from_be_bytes([rest[0], rest[1]]);
let Some(code) = ErrorCode::from_u16(code_u16) else {
return Err(DecodeError::UnknownErrorCode(code_u16));
};
let (message, trailing) = parse_netascii_z(&rest[2..], "error message")?;
if !trailing.is_empty() {
return Err(DecodeError::TrailingBytes);
}
Ok(Self::Error { code, message })
}
}
}
}
/// Append an RRQ/WRQ packet (RFC 1350 Figure 5-1) to `out`.
///
/// Layout: 2-byte big-endian opcode, filename, NUL, mode name, NUL.
/// The filename is copied verbatim; nothing checks it for embedded NUL
/// bytes, which would end the field early on the wire.
fn encode_request(out: &mut Vec<u8>, opcode: Opcode, req: &Request) {
    let header = (opcode as u16).to_be_bytes();
    out.extend_from_slice(&header);
    for field in &[req.filename.as_bytes(), req.mode.as_str().as_bytes()] {
        out.extend_from_slice(field);
        out.push(0);
    }
}
/// Split the 2-byte big-endian opcode off the front of `input`.
///
/// On success returns the opcode and the bytes that follow it. Fails with
/// `Truncated("opcode")` if fewer than two bytes are present, or with
/// `UnknownOpcode` if the value is not an RFC 1350 opcode.
fn parse_opcode(input: &[u8]) -> Result<(Opcode, &[u8]), DecodeError> {
    match input {
        [hi, lo, body @ ..] => {
            let raw = u16::from_be_bytes([*hi, *lo]);
            Opcode::from_u16(raw)
                .map(|op| (op, body))
                .ok_or(DecodeError::UnknownOpcode(raw))
        }
        _ => Err(DecodeError::Truncated("opcode")),
    }
}
/// Parse a NUL-terminated text field from the front of `input`.
///
/// `field` names the field in any returned [`DecodeError`]. On success
/// returns the text (terminator excluded) and the bytes after the NUL.
/// Only the 7-bit range is enforced; control characters are accepted.
///
/// # Errors
/// - `MissingTerminator` if `input` contains no NUL byte.
/// - `EmptyField` if the NUL is the first byte.
/// - `NonNetascii` if any byte before the NUL is above 0x7F.
fn parse_netascii_z<'a>(
    input: &'a [u8],
    field: &'static str,
) -> Result<(String, &'a [u8]), DecodeError> {
    let nul = input
        .iter()
        .position(|&b| b == 0)
        .ok_or(DecodeError::MissingTerminator(field))?;
    let (text, after) = (&input[..nul], &input[nul + 1..]);
    if text.is_empty() {
        return Err(DecodeError::EmptyField(field));
    }
    if !text.is_ascii() {
        return Err(DecodeError::NonNetascii(field));
    }
    // ASCII is always valid UTF-8, so this cannot fail after the check
    // above; the mapping is kept defensively.
    let owned = std::str::from_utf8(text)
        .map_err(|_| DecodeError::NonUtf8(field))?
        .to_owned();
    Ok((owned, after))
}
/// A parse error for RFC 1350 packet formats.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// Input ended before the named field could be read.
    Truncated(&'static str),
    /// The leading opcode is not one of the RFC 1350 opcodes.
    UnknownOpcode(u16),
    /// The named string field has no NUL terminator.
    MissingTerminator(&'static str),
    /// The named string field is empty.
    EmptyField(&'static str),
    /// The named string field contains bytes above 0x7F.
    NonNetascii(&'static str),
    /// The named string field is not valid UTF-8.
    NonUtf8(&'static str),
    /// The request's mode string is not a known transfer mode.
    UnknownMode(String),
    /// An ERROR packet carries a code outside the RFC 1350 table.
    UnknownErrorCode(u16),
    /// A DATA packet's payload (length given) exceeds the block size.
    OversizeData(usize),
    /// A fixed-size packet has the wrong total length.
    InvalidLength {
        /// Packet kind, e.g. `"ACK"`.
        kind: &'static str,
        /// Required total length in bytes.
        expected: usize,
        /// Actual total length in bytes.
        got: usize,
    },
    /// Extra bytes follow the last field of the packet.
    TrailingBytes,
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Errors that name the offending field.
            Self::Truncated(what) => write!(f, "truncated packet ({what})"),
            Self::MissingTerminator(field) => write!(f, "missing NUL terminator for {field}"),
            Self::EmptyField(field) => write!(f, "empty {field}"),
            Self::NonNetascii(field) => write!(f, "{field} contains non-netascii bytes"),
            Self::NonUtf8(field) => write!(f, "{field} is not valid UTF-8"),
            // Errors that report an unrecognised value.
            Self::UnknownOpcode(op) => write!(f, "unknown opcode {op}"),
            Self::UnknownErrorCode(code) => write!(f, "unknown error code {code}"),
            Self::UnknownMode(mode) => write!(f, "unknown transfer mode {mode:?}"),
            // Errors about packet size/shape.
            Self::OversizeData(n) => write!(f, "DATA payload too large ({n} bytes)"),
            Self::InvalidLength { kind, expected, got } => write!(
                f,
                "{kind} packet has invalid length (expected {expected}, got {got})"
            ),
            Self::TrailingBytes => f.write_str("trailing bytes after packet"),
        }
    }
}

impl std::error::Error for DecodeError {}
#[cfg(test)]
mod tests {
    // In tests a panic from `unwrap`/`unwrap_err` *is* the failure report,
    // so the lint is relaxed for the whole module rather than per function.
    #![allow(clippy::unwrap_used)]

    use super::*;

    #[test]
    fn rrq_roundtrip() {
        let pkt = Packet::Rrq(Request {
            filename: "hello.txt".to_string(),
            mode: Mode::Octet,
        });
        let encoded = pkt.encode();
        let decoded = Packet::decode(&encoded).unwrap();
        assert_eq!(decoded, pkt);
    }

    #[test]
    fn wrq_roundtrip_netascii() {
        let pkt = Packet::Wrq(Request {
            filename: "dir/upload.bin".to_string(),
            mode: Mode::NetAscii,
        });
        assert_eq!(Packet::decode(&pkt.encode()).unwrap(), pkt);
    }

    #[test]
    fn mode_is_case_insensitive() {
        let decoded = Packet::decode(b"\x00\x01f\x00OcTeT\x00").unwrap();
        assert_eq!(
            decoded,
            Packet::Rrq(Request {
                filename: "f".to_string(),
                mode: Mode::Octet,
            })
        );
    }

    #[test]
    fn rejects_unknown_mode() {
        let err = Packet::decode(b"\x00\x01f\x00bogus\x00").unwrap_err();
        assert_eq!(err, DecodeError::UnknownMode("bogus".to_string()));
    }

    #[test]
    fn rejects_trailing_bytes_after_request() {
        let err = Packet::decode(b"\x00\x01f\x00octet\x00x").unwrap_err();
        assert_eq!(err, DecodeError::TrailingBytes);
    }

    #[test]
    fn rejects_non_ascii_filename() {
        let err = Packet::decode(b"\x00\x01caf\xc3\xa9\x00octet\x00").unwrap_err();
        assert_eq!(err, DecodeError::NonNetascii("filename"));
    }

    #[test]
    fn rejects_empty_filename() {
        let err = Packet::decode(b"\x00\x01\x00octet\x00").unwrap_err();
        assert_eq!(err, DecodeError::EmptyField("filename"));
    }

    #[test]
    fn rejects_missing_terminator() {
        let err = Packet::decode(b"\x00\x01file").unwrap_err();
        assert_eq!(err, DecodeError::MissingTerminator("filename"));
    }

    #[test]
    fn data_roundtrip_small() {
        let pkt = Packet::Data {
            block: 42,
            data: b"abc".to_vec(),
        };
        let encoded = pkt.encode();
        let decoded = Packet::decode(&encoded).unwrap();
        assert_eq!(decoded, pkt);
    }

    #[test]
    fn data_accepts_empty_and_full_blocks() {
        for &len in &[0, BLOCK_SIZE] {
            let pkt = Packet::Data {
                block: 7,
                data: vec![0xAB; len],
            };
            assert_eq!(Packet::decode(&pkt.encode()).unwrap(), pkt);
        }
    }

    #[test]
    fn rejects_data_larger_than_512() {
        let mut bytes = vec![0, 3, 0, 1];
        bytes.resize(4 + BLOCK_SIZE + 1, b'a');
        let err = Packet::decode(&bytes).unwrap_err();
        assert!(matches!(err, DecodeError::OversizeData(513)));
    }

    #[test]
    fn ack_requires_exact_length() {
        let mut bytes = Vec::new();
        Packet::Ack { block: 1 }.encode_into(&mut bytes);
        bytes.push(0);
        let err = Packet::decode(&bytes).unwrap_err();
        assert!(matches!(
            err,
            DecodeError::InvalidLength { kind: "ACK", .. }
        ));
    }

    #[test]
    fn error_roundtrip() {
        let pkt = Packet::Error {
            code: ErrorCode::FileNotFound,
            message: "File not found".to_string(),
        };
        assert_eq!(Packet::decode(&pkt.encode()).unwrap(), pkt);
    }

    #[test]
    fn rejects_unknown_error_code() {
        let err = Packet::decode(b"\x00\x05\x00\x09oops\x00").unwrap_err();
        assert_eq!(err, DecodeError::UnknownErrorCode(9));
    }

    #[test]
    fn rejects_unknown_opcode() {
        let err = Packet::decode(&[0, 9]).unwrap_err();
        assert!(matches!(err, DecodeError::UnknownOpcode(9)));
    }

    #[test]
    fn reports_truncation() {
        assert_eq!(
            Packet::decode(&[]).unwrap_err(),
            DecodeError::Truncated("opcode")
        );
        assert_eq!(
            Packet::decode(&[0]).unwrap_err(),
            DecodeError::Truncated("opcode")
        );
        assert_eq!(
            Packet::decode(&[0, 3, 0]).unwrap_err(),
            DecodeError::Truncated("block")
        );
        assert_eq!(
            Packet::decode(&[0, 5, 0]).unwrap_err(),
            DecodeError::Truncated("error code")
        );
    }
}