371 lines
12 KiB
Rust
371 lines
12 KiB
Rust
//! Packet encoding/decoding for TFTP (RFC 1350).
|
|
//!
|
|
//! RFC 1350, Section 5 ("TFTP Packets") defines the packet formats:
|
|
//! - RRQ/WRQ: Figure 5-1
|
|
//! - DATA: Figure 5-2
|
|
//! - ACK: Figure 5-3
|
|
//! - ERROR: Figure 5-4
|
|
//!
|
|
//! This module implements strict parsing of those formats and a direct
//! serialization of them; encoding writes fields verbatim without validating them.
|
|
|
|
#![forbid(unsafe_code)]
|
|
|
|
use core::fmt;
|
|
|
|
/// Size of a full DATA payload: RFC 1350, Section 2 fixes data blocks at 512 bytes.
pub const BLOCK_SIZE: usize = 512;

/// Size of the largest DATA packet defined by RFC 1350: a 2-byte opcode, a
/// 2-byte block number, and a full [`BLOCK_SIZE`] payload.
///
/// RRQ/WRQ/ERROR packets carry NUL-terminated strings; this module does not
/// bound their length by this value.
pub const MAX_PACKET_SIZE: usize = 2 + 2 + BLOCK_SIZE;
|
|
|
|
/// TFTP opcode values (RFC 1350, Section 5).
///
/// The discriminants are the wire values carried, big-endian, in the first two
/// bytes of every packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum Opcode {
    /// Read request (RRQ), RFC 1350 Figure 5-1.
    Rrq = 1,
    /// Write request (WRQ), RFC 1350 Figure 5-1.
    Wrq = 2,
    /// Data (DATA), RFC 1350 Figure 5-2.
    Data = 3,
    /// Acknowledgment (ACK), RFC 1350 Figure 5-3.
    Ack = 4,
    /// Error (ERROR), RFC 1350 Figure 5-4.
    Error = 5,
}

impl Opcode {
    /// Every opcode defined by RFC 1350; lookup compares against each discriminant.
    const ALL: [Self; 5] = [Self::Rrq, Self::Wrq, Self::Data, Self::Ack, Self::Error];

    /// Map a raw wire value to its [`Opcode`].
    ///
    /// Returns `None` for any value RFC 1350 does not define (anything outside `1..=5`).
    #[must_use]
    pub fn from_u16(value: u16) -> Option<Self> {
        Self::ALL.iter().copied().find(|&op| op as u16 == value)
    }
}
|
|
|
|
/// Transfer mode for RRQ/WRQ (RFC 1350, Section 5, Figure 5-1).
///
/// RFC 1350 lists: "netascii", "octet", "mail" (mail is obsolete; Section 1).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Wire name `"netascii"`.
    NetAscii,
    /// Wire name `"octet"`.
    Octet,
    /// Wire name `"mail"`.
    Mail,
}

impl Mode {
    /// All modes, used for name lookup in [`Mode::parse_case_insensitive`].
    const ALL: [Self; 3] = [Self::NetAscii, Self::Octet, Self::Mail];

    /// The lowercase wire name of this mode, as written into RRQ/WRQ packets.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Mode::NetAscii => "netascii",
            Mode::Octet => "octet",
            Mode::Mail => "mail",
        }
    }

    /// Look up a mode by its wire name, ignoring ASCII case (so `"OCTET"` yields
    /// [`Mode::Octet`]).
    ///
    /// Returns `None` for any other string; surrounding whitespace is not trimmed.
    #[must_use]
    pub fn parse_case_insensitive(s: &str) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| s.eq_ignore_ascii_case(candidate.as_str()))
    }
}
|
|
|
|
/// Error codes (RFC 1350 Appendix, "Error Codes").
///
/// The discriminants are the wire values carried in an ERROR packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ErrorCode {
    /// 0: not defined; see the accompanying error message.
    NotDefined = 0,
    /// 1: file not found.
    FileNotFound = 1,
    /// 2: access violation.
    AccessViolation = 2,
    /// 3: disk full or allocation exceeded.
    DiskFull = 3,
    /// 4: illegal TFTP operation.
    IllegalOperation = 4,
    /// 5: unknown transfer ID.
    UnknownTransferId = 5,
    /// 6: file already exists.
    FileAlreadyExists = 6,
    /// 7: no such user.
    NoSuchUser = 7,
}

impl ErrorCode {
    /// Every error code defined by RFC 1350; lookup compares against each discriminant.
    const ALL: [Self; 8] = [
        Self::NotDefined,
        Self::FileNotFound,
        Self::AccessViolation,
        Self::DiskFull,
        Self::IllegalOperation,
        Self::UnknownTransferId,
        Self::FileAlreadyExists,
        Self::NoSuchUser,
    ];

    /// Map a raw wire value to its [`ErrorCode`].
    ///
    /// Returns `None` for any value outside `0..=7`.
    #[must_use]
    pub fn from_u16(value: u16) -> Option<Self> {
        Self::ALL.iter().copied().find(|&code| code as u16 == value)
    }
}
|
|
|
|
/// A RRQ/WRQ request (RFC 1350 Figure 5-1).
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub struct Request {
|
|
pub filename: String,
|
|
pub mode: Mode,
|
|
}
|
|
|
|
/// A decoded TFTP packet (RFC 1350 Section 5).
|
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
|
pub enum Packet {
|
|
Rrq(Request),
|
|
Wrq(Request),
|
|
Data { block: u16, data: Vec<u8> },
|
|
Ack { block: u16 },
|
|
Error { code: ErrorCode, message: String },
|
|
}
|
|
|
|
impl Packet {
|
|
/// Encode a packet into a newly allocated buffer.
|
|
#[must_use]
|
|
pub fn encode(&self) -> Vec<u8> {
|
|
let mut out = Vec::with_capacity(MAX_PACKET_SIZE);
|
|
self.encode_into(&mut out);
|
|
out
|
|
}
|
|
|
|
/// Encode a packet by appending to `out`.
|
|
pub fn encode_into(&self, out: &mut Vec<u8>) {
|
|
match self {
|
|
Self::Rrq(req) => encode_request(out, Opcode::Rrq, req),
|
|
Self::Wrq(req) => encode_request(out, Opcode::Wrq, req),
|
|
Self::Data { block, data } => {
|
|
out.extend_from_slice(&(Opcode::Data as u16).to_be_bytes());
|
|
out.extend_from_slice(&block.to_be_bytes());
|
|
out.extend_from_slice(data);
|
|
}
|
|
Self::Ack { block } => {
|
|
out.extend_from_slice(&(Opcode::Ack as u16).to_be_bytes());
|
|
out.extend_from_slice(&block.to_be_bytes());
|
|
}
|
|
Self::Error { code, message } => {
|
|
out.extend_from_slice(&(Opcode::Error as u16).to_be_bytes());
|
|
out.extend_from_slice(&(*code as u16).to_be_bytes());
|
|
out.extend_from_slice(message.as_bytes());
|
|
out.push(0);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Decode a packet from UDP payload bytes.
|
|
///
|
|
/// # Errors
|
|
/// Returns a [`DecodeError`] if the input does not match an RFC 1350 packet format.
|
|
pub fn decode(input: &[u8]) -> Result<Self, DecodeError> {
|
|
let (opcode, rest) = parse_opcode(input)?;
|
|
match opcode {
|
|
Opcode::Rrq | Opcode::Wrq => {
|
|
let (filename, rest) = parse_netascii_z(rest, "filename")?;
|
|
let (mode_str, rest) = parse_netascii_z(rest, "mode")?;
|
|
if !rest.is_empty() {
|
|
return Err(DecodeError::TrailingBytes);
|
|
}
|
|
let Some(mode) = Mode::parse_case_insensitive(&mode_str) else {
|
|
return Err(DecodeError::UnknownMode(mode_str));
|
|
};
|
|
let req = Request { filename, mode };
|
|
Ok(if opcode == Opcode::Rrq {
|
|
Self::Rrq(req)
|
|
} else {
|
|
Self::Wrq(req)
|
|
})
|
|
}
|
|
Opcode::Data => {
|
|
if rest.len() < 2 {
|
|
return Err(DecodeError::Truncated("block"));
|
|
}
|
|
if rest.len() > 2 + BLOCK_SIZE {
|
|
return Err(DecodeError::OversizeData(rest.len() - 2));
|
|
}
|
|
let block = u16::from_be_bytes([rest[0], rest[1]]);
|
|
let data = rest[2..].to_vec();
|
|
Ok(Self::Data { block, data })
|
|
}
|
|
Opcode::Ack => {
|
|
if rest.len() != 2 {
|
|
return Err(DecodeError::InvalidLength {
|
|
kind: "ACK",
|
|
expected: 4,
|
|
got: input.len(),
|
|
});
|
|
}
|
|
let block = u16::from_be_bytes([rest[0], rest[1]]);
|
|
Ok(Self::Ack { block })
|
|
}
|
|
Opcode::Error => {
|
|
if rest.len() < 2 {
|
|
return Err(DecodeError::Truncated("error code"));
|
|
}
|
|
let code_u16 = u16::from_be_bytes([rest[0], rest[1]]);
|
|
let Some(code) = ErrorCode::from_u16(code_u16) else {
|
|
return Err(DecodeError::UnknownErrorCode(code_u16));
|
|
};
|
|
let (message, trailing) = parse_netascii_z(&rest[2..], "error message")?;
|
|
if !trailing.is_empty() {
|
|
return Err(DecodeError::TrailingBytes);
|
|
}
|
|
Ok(Self::Error { code, message })
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn encode_request(out: &mut Vec<u8>, opcode: Opcode, req: &Request) {
|
|
out.extend_from_slice(&(opcode as u16).to_be_bytes());
|
|
out.extend_from_slice(req.filename.as_bytes());
|
|
out.push(0);
|
|
out.extend_from_slice(req.mode.as_str().as_bytes());
|
|
out.push(0);
|
|
}
|
|
|
|
fn parse_opcode(input: &[u8]) -> Result<(Opcode, &[u8]), DecodeError> {
|
|
if input.len() < 2 {
|
|
return Err(DecodeError::Truncated("opcode"));
|
|
}
|
|
let opcode_u16 = u16::from_be_bytes([input[0], input[1]]);
|
|
let Some(opcode) = Opcode::from_u16(opcode_u16) else {
|
|
return Err(DecodeError::UnknownOpcode(opcode_u16));
|
|
};
|
|
Ok((opcode, &input[2..]))
|
|
}
|
|
|
|
fn parse_netascii_z<'a>(
|
|
input: &'a [u8],
|
|
field: &'static str,
|
|
) -> Result<(String, &'a [u8]), DecodeError> {
|
|
let Some(pos) = input.iter().position(|&b| b == 0) else {
|
|
return Err(DecodeError::MissingTerminator(field));
|
|
};
|
|
if pos == 0 {
|
|
return Err(DecodeError::EmptyField(field));
|
|
}
|
|
let bytes = &input[..pos];
|
|
if bytes.iter().any(|&b| b > 0x7F) {
|
|
return Err(DecodeError::NonNetascii(field));
|
|
}
|
|
let s = core::str::from_utf8(bytes).map_err(|_| DecodeError::NonUtf8(field))?;
|
|
Ok((s.to_owned(), &input[pos + 1..]))
|
|
}
|
|
|
|
/// A parse error for RFC 1350 packet formats.
///
/// Variants carrying a `&'static str` name the packet field that failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// The input ended before the named field could be read.
    Truncated(&'static str),
    /// The leading opcode is not one defined by RFC 1350.
    UnknownOpcode(u16),
    /// The named string field has no terminating NUL byte.
    MissingTerminator(&'static str),
    /// The named string field is zero-length.
    EmptyField(&'static str),
    /// The named string field contains bytes that are not 7-bit netascii.
    NonNetascii(&'static str),
    /// The named string field is not valid UTF-8.
    NonUtf8(&'static str),
    /// The request named a transfer mode that is not recognized.
    UnknownMode(String),
    /// The ERROR packet carried an error code not defined by RFC 1350.
    UnknownErrorCode(u16),
    /// A DATA payload was too large; carries the payload length in bytes.
    OversizeData(usize),
    /// A fixed-size packet had the wrong length.
    InvalidLength {
        /// Packet type that was being decoded, e.g. `"ACK"`.
        kind: &'static str,
        /// Required length in bytes.
        expected: usize,
        /// Observed length in bytes.
        got: usize,
    },
    /// Bytes remained after the last field of the packet.
    TrailingBytes,
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Truncated(what) => write!(f, "truncated packet ({what})"),
            Self::UnknownOpcode(op) => write!(f, "unknown opcode {op}"),
            Self::MissingTerminator(field) => write!(f, "missing NUL terminator for {field}"),
            Self::EmptyField(field) => write!(f, "empty {field}"),
            Self::NonNetascii(field) => write!(f, "{field} contains non-netascii bytes"),
            Self::NonUtf8(field) => write!(f, "{field} is not valid UTF-8"),
            Self::UnknownMode(ref mode) => write!(f, "unknown transfer mode {mode:?}"),
            Self::UnknownErrorCode(code) => write!(f, "unknown error code {code}"),
            Self::OversizeData(n) => write!(f, "DATA payload too large ({n} bytes)"),
            Self::InvalidLength {
                kind,
                expected,
                got,
            } => write!(
                f,
                "{kind} packet has invalid length (expected {expected}, got {got})"
            ),
            Self::TrailingBytes => f.write_str("trailing bytes after packet"),
        }
    }
}

impl std::error::Error for DecodeError {}
|
|
|
|
#[cfg(test)]
// `unwrap`/`unwrap_err` are the most direct way to assert success or failure in
// tests: a panic is exactly the failure signal the test harness reports.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn rrq_roundtrip() {
        let pkt = Packet::Rrq(Request {
            filename: "hello.txt".to_string(),
            mode: Mode::Octet,
        });
        let encoded = pkt.encode();
        let decoded = Packet::decode(&encoded).unwrap();
        assert_eq!(decoded, pkt);
    }

    #[test]
    fn wrq_roundtrip_netascii() {
        let pkt = Packet::Wrq(Request {
            filename: "upload.txt".to_string(),
            mode: Mode::NetAscii,
        });
        let encoded = pkt.encode();
        assert_eq!(encoded, b"\x00\x02upload.txt\x00netascii\x00".to_vec());
        assert_eq!(Packet::decode(&encoded).unwrap(), pkt);
    }

    #[test]
    fn mode_is_case_insensitive() {
        let decoded = Packet::decode(b"\x00\x01f\x00OcTeT\x00").unwrap();
        assert_eq!(
            decoded,
            Packet::Rrq(Request {
                filename: "f".to_string(),
                mode: Mode::Octet,
            })
        );
    }

    #[test]
    fn rejects_unknown_mode() {
        let err = Packet::decode(b"\x00\x01f\x00binary\x00").unwrap_err();
        assert_eq!(err, DecodeError::UnknownMode("binary".to_string()));
    }

    #[test]
    fn rejects_trailing_bytes_after_request() {
        let err = Packet::decode(b"\x00\x01f\x00octet\x00x").unwrap_err();
        assert_eq!(err, DecodeError::TrailingBytes);
    }

    #[test]
    fn rejects_non_netascii_filename() {
        let err = Packet::decode(b"\x00\x01\xc3\xa9\x00octet\x00").unwrap_err();
        assert_eq!(err, DecodeError::NonNetascii("filename"));
    }

    #[test]
    fn data_roundtrip_small() {
        let pkt = Packet::Data {
            block: 42,
            data: b"abc".to_vec(),
        };
        let encoded = pkt.encode();
        let decoded = Packet::decode(&encoded).unwrap();
        assert_eq!(decoded, pkt);
    }

    #[test]
    fn data_full_block_roundtrip() {
        let pkt = Packet::Data {
            block: 1,
            data: vec![0xAB; BLOCK_SIZE],
        };
        let encoded = pkt.encode();
        assert_eq!(encoded.len(), MAX_PACKET_SIZE);
        assert_eq!(Packet::decode(&encoded).unwrap(), pkt);
    }

    #[test]
    fn data_with_empty_payload() {
        let bytes = [0u8, 3, 0, 7];
        let decoded = Packet::decode(&bytes).unwrap();
        assert_eq!(
            decoded,
            Packet::Data {
                block: 7,
                data: Vec::new(),
            }
        );
        assert_eq!(decoded.encode(), bytes.to_vec());
    }

    #[test]
    fn rejects_data_larger_than_512() {
        let mut bytes = vec![0, 3, 0, 1];
        bytes.resize(4 + BLOCK_SIZE + 1, b'a');
        let err = Packet::decode(&bytes).unwrap_err();
        assert_eq!(err, DecodeError::OversizeData(BLOCK_SIZE + 1));
    }

    #[test]
    fn ack_roundtrip() {
        let pkt = Packet::Ack { block: 0xBEEF };
        let encoded = pkt.encode();
        assert_eq!(encoded, vec![0, 4, 0xBE, 0xEF]);
        assert_eq!(Packet::decode(&encoded).unwrap(), pkt);
    }

    #[test]
    fn ack_requires_exact_length() {
        let mut bytes = Vec::new();
        Packet::Ack { block: 1 }.encode_into(&mut bytes);
        bytes.push(0);
        let err = Packet::decode(&bytes).unwrap_err();
        assert_eq!(
            err,
            DecodeError::InvalidLength {
                kind: "ACK",
                expected: 4,
                got: 5,
            }
        );
    }

    #[test]
    fn error_roundtrip() {
        let pkt = Packet::Error {
            code: ErrorCode::FileNotFound,
            message: "File not found".to_string(),
        };
        let encoded = pkt.encode();
        assert_eq!(encoded, b"\x00\x05\x00\x01File not found\x00".to_vec());
        assert_eq!(Packet::decode(&encoded).unwrap(), pkt);
    }

    #[test]
    fn rejects_unknown_error_code() {
        let err = Packet::decode(b"\x00\x05\x00\x08oops\x00").unwrap_err();
        assert_eq!(err, DecodeError::UnknownErrorCode(8));
    }

    #[test]
    fn rejects_unterminated_error_message() {
        let err = Packet::decode(b"\x00\x05\x00\x01x").unwrap_err();
        assert_eq!(err, DecodeError::MissingTerminator("error message"));
    }

    #[test]
    fn rejects_unknown_opcode() {
        let err = Packet::decode(&[0, 9]).unwrap_err();
        assert_eq!(err, DecodeError::UnknownOpcode(9));
    }

    #[test]
    fn rejects_truncated_packets() {
        assert_eq!(
            Packet::decode(&[]).unwrap_err(),
            DecodeError::Truncated("opcode")
        );
        assert_eq!(
            Packet::decode(&[0]).unwrap_err(),
            DecodeError::Truncated("opcode")
        );
        assert_eq!(
            Packet::decode(&[0, 3, 0]).unwrap_err(),
            DecodeError::Truncated("block")
        );
        assert_eq!(
            Packet::decode(&[0, 5, 0]).unwrap_err(),
            DecodeError::Truncated("error code")
        );
    }
}
|