feat: implement TFTP protocol crate (pfs-tftp-proto)

This commit introduces a sans-IO TFTP protocol implementation following
RFC 1350. The protocol crate provides:

- Packet types: RRQ, WRQ, DATA, ACK, ERROR with full serialization/parsing
- Error codes as defined in RFC 1350 Appendix
- Transfer modes: octet (binary) and netascii
- Client and server state machines for managing protocol flow
- Comprehensive tests for all packet types and state transitions

The sans-IO design separates protocol logic from I/O operations, making
the code testable and reusable across different I/O implementations.

Key design decisions:
- Mail mode is explicitly rejected as obsolete per RFC 1350
- Block numbers wrap around at 65535 for large file support
- State machines emit events that tell the I/O layer what to do next
- All protocol-specific values are documented with RFC citations

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-21 13:10:42 +01:00
parent 5e9be9d27f
commit fd8dace0dc
15 changed files with 2746 additions and 9 deletions

View File

@@ -0,0 +1,193 @@
//! TFTP Error Types
//!
//! Defines error codes and error handling for TFTP protocol.
//!
//! # RFC 1350 Error Codes (Appendix)
//!
//! ```text
//! Value Meaning
//! 0 Not defined, see error message (if any).
//! 1 File not found.
//! 2 Access violation.
//! 3 Disk full or allocation exceeded.
//! 4 Illegal TFTP operation.
//! 5 Unknown transfer ID.
//! 6 File already exists.
//! 7 No such user.
//! ```
use thiserror::Error;
/// TFTP error codes as defined in RFC 1350 Appendix.
///
/// These codes are sent in ERROR packets to indicate the nature of the error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ErrorCode {
/// Not defined, see error message (if any).
/// Used when the error doesn't fit other categories.
NotDefined = 0,
/// File not found.
/// The requested file does not exist on the server.
FileNotFound = 1,
/// Access violation.
/// The client does not have permission to access the file.
AccessViolation = 2,
/// Disk full or allocation exceeded.
/// There is insufficient storage space to complete the transfer.
DiskFull = 3,
/// Illegal TFTP operation.
/// The requested operation is not valid (e.g., malformed packet).
IllegalOperation = 4,
/// Unknown transfer ID.
/// RFC 1350 Section 4: "If a source TID does not match, the packet should
/// be discarded as erroneously sent from somewhere else. An error packet
/// should be sent to the source of the incorrect packet."
UnknownTransferId = 5,
/// File already exists.
/// Returned when attempting to write a file that already exists
/// and the server policy doesn't allow overwriting.
FileAlreadyExists = 6,
/// No such user.
/// The specified user does not exist (relevant for mail mode, which is obsolete).
NoSuchUser = 7,
}
impl ErrorCode {
/// Create an `ErrorCode` from a raw u16 value.
///
/// Returns `NotDefined` for any unknown error code, as per RFC 1350
/// which states error code 0 is "Not defined".
#[must_use]
pub const fn from_u16(value: u16) -> Self {
match value {
1 => Self::FileNotFound,
2 => Self::AccessViolation,
3 => Self::DiskFull,
4 => Self::IllegalOperation,
5 => Self::UnknownTransferId,
6 => Self::FileAlreadyExists,
7 => Self::NoSuchUser,
// RFC 1350: Error code 0 is "Not defined", also used as fallback
_ => Self::NotDefined,
}
}
/// Get the default error message for this error code.
#[must_use]
pub const fn default_message(&self) -> &'static str {
match self {
Self::NotDefined => "Not defined",
Self::FileNotFound => "File not found",
Self::AccessViolation => "Access violation",
Self::DiskFull => "Disk full or allocation exceeded",
Self::IllegalOperation => "Illegal TFTP operation",
Self::UnknownTransferId => "Unknown transfer ID",
Self::FileAlreadyExists => "File already exists",
Self::NoSuchUser => "No such user",
}
}
}
impl From<ErrorCode> for u16 {
fn from(code: ErrorCode) -> Self {
code as Self
}
}
/// Protocol-level errors for TFTP operations.
#[derive(Debug, Error)]
pub enum Error {
/// Packet is too short to be valid.
#[error("packet too short: expected at least {expected} bytes, got {actual}")]
PacketTooShort { expected: usize, actual: usize },
/// Invalid opcode in packet.
#[error("invalid opcode: {0}")]
InvalidOpcode(u16),
/// Invalid transfer mode in RRQ/WRQ packet.
#[error("invalid mode: {0}")]
InvalidMode(String),
/// Missing null terminator in string field.
#[error("missing null terminator in {field}")]
MissingNullTerminator { field: &'static str },
/// String field contains invalid UTF-8.
#[error("invalid UTF-8 in {field}: {source}")]
InvalidUtf8 {
field: &'static str,
#[source]
source: std::str::Utf8Error,
},
/// Data exceeds maximum block size.
#[error("data too large: {size} bytes exceeds maximum of 512")]
DataTooLarge { size: usize },
/// Buffer too small for serialization.
#[error("buffer too small: need {needed} bytes, have {available}")]
BufferTooSmall { needed: usize, available: usize },
/// TFTP error received from remote peer.
#[error("TFTP error {code:?}: {message}")]
TftpError { code: ErrorCode, message: String },
/// Protocol state error.
#[error("protocol state error: {0}")]
StateError(String),
/// Unexpected packet type received.
#[error("unexpected packet type: expected {expected}, got {actual}")]
UnexpectedPacket {
expected: &'static str,
actual: &'static str,
},
/// Block number mismatch.
#[error("block number mismatch: expected {expected}, got {actual}")]
BlockMismatch { expected: u16, actual: u16 },
}
/// Result type for TFTP operations.
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_code_from_u16() {
assert_eq!(ErrorCode::from_u16(0), ErrorCode::NotDefined);
assert_eq!(ErrorCode::from_u16(1), ErrorCode::FileNotFound);
assert_eq!(ErrorCode::from_u16(2), ErrorCode::AccessViolation);
assert_eq!(ErrorCode::from_u16(3), ErrorCode::DiskFull);
assert_eq!(ErrorCode::from_u16(4), ErrorCode::IllegalOperation);
assert_eq!(ErrorCode::from_u16(5), ErrorCode::UnknownTransferId);
assert_eq!(ErrorCode::from_u16(6), ErrorCode::FileAlreadyExists);
assert_eq!(ErrorCode::from_u16(7), ErrorCode::NoSuchUser);
// Unknown codes should map to NotDefined
assert_eq!(ErrorCode::from_u16(8), ErrorCode::NotDefined);
assert_eq!(ErrorCode::from_u16(255), ErrorCode::NotDefined);
}
#[test]
fn test_error_code_to_u16() {
assert_eq!(u16::from(ErrorCode::NotDefined), 0);
assert_eq!(u16::from(ErrorCode::FileNotFound), 1);
assert_eq!(u16::from(ErrorCode::AccessViolation), 2);
assert_eq!(u16::from(ErrorCode::DiskFull), 3);
assert_eq!(u16::from(ErrorCode::IllegalOperation), 4);
assert_eq!(u16::from(ErrorCode::UnknownTransferId), 5);
assert_eq!(u16::from(ErrorCode::FileAlreadyExists), 6);
assert_eq!(u16::from(ErrorCode::NoSuchUser), 7);
}
}

View File

@@ -0,0 +1,21 @@
//! Sans-IO TFTP Protocol Implementation (RFC 1350)
//!
//! This crate provides a pure protocol implementation without any I/O operations.
//! It handles packet parsing, serialization, and protocol state management.
//!
//! # RFC 1350 Overview
//!
//! TFTP (Trivial File Transfer Protocol) is a simple file transfer protocol
//! that operates over UDP. Key characteristics:
//! - Uses fixed 512-byte data blocks
//! - Stop-and-wait protocol (each packet must be acknowledged)
//! - Five packet types: RRQ, WRQ, DATA, ACK, ERROR
//! - Default port 69 for initial server connection
mod error;
mod packet;
mod state;
pub use error::{Error, ErrorCode, Result};
pub use packet::{Mode, Opcode, Packet, MAX_DATA_SIZE, TFTP_PORT};
pub use state::{ClientState, Event, ServerState, TransferDirection};

View File

@@ -0,0 +1,632 @@
//! TFTP Packet Types and Serialization
//!
//! This module implements all five TFTP packet types as defined in RFC 1350:
//! - RRQ (Read Request)
//! - WRQ (Write Request)
//! - DATA
//! - ACK
//! - ERROR
//!
//! # RFC 1350 Section 5 - TFTP Packets
//!
//! ```text
//! TFTP supports five types of packets:
//!
//! opcode operation
//! 1 Read request (RRQ)
//! 2 Write request (WRQ)
//! 3 Data (DATA)
//! 4 Acknowledgment (ACK)
//! 5 Error (ERROR)
//! ```
use crate::{Error, ErrorCode, Result};
/// Standard TFTP server port as defined in RFC 1350.
///
/// RFC 1350 Section 4: "A requesting host chooses its source TID as described
/// above, and sends its initial request to the known TID 69 decimal (105 octal)
/// on the serving host."
pub const TFTP_PORT: u16 = 69;
/// Maximum data size in a DATA packet.
///
/// RFC 1350 Section 2: "the connection is opened and the file is sent in fixed
/// length blocks of 512 bytes."
///
/// RFC 1350 Section 5: "The data field is from zero to 512 bytes long."
pub const MAX_DATA_SIZE: usize = 512;
/// Minimum packet size (just an opcode).
const MIN_PACKET_SIZE: usize = 2;
/// TFTP packet opcodes.
///
/// RFC 1350 Section 5: "The TFTP header of a packet contains the opcode
/// associated with that packet."
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum Opcode {
/// Read request (RRQ)
Rrq = 1,
/// Write request (WRQ)
Wrq = 2,
/// Data packet
Data = 3,
/// Acknowledgment
Ack = 4,
/// Error
Error = 5,
}
impl Opcode {
/// Create an `Opcode` from a raw u16 value.
///
/// # Errors
///
/// Returns `Error::InvalidOpcode` if the value is not a valid opcode.
pub const fn from_u16(value: u16) -> Result<Self> {
match value {
1 => Ok(Self::Rrq),
2 => Ok(Self::Wrq),
3 => Ok(Self::Data),
4 => Ok(Self::Ack),
5 => Ok(Self::Error),
_ => Err(Error::InvalidOpcode(value)),
}
}
}
impl From<Opcode> for u16 {
fn from(opcode: Opcode) -> Self {
opcode as Self
}
}
/// Transfer mode for TFTP operations.
///
/// RFC 1350 Section 1: "Three modes of transfer are currently supported:
/// netascii [...]; octet [...]; mail [...]. (The mail mode is obsolete
/// and should not be implemented or used.)"
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Mode {
/// `NetASCII` mode - text transfer with CR/LF line endings.
///
/// RFC 1350: "netascii (This is ascii as defined in 'USA Standard Code
/// for Information Interchange' with the modifications specified in
/// 'Telnet Protocol Specification'.) Note that it is 8 bit ascii."
NetAscii,
/// Octet mode - raw binary transfer.
///
/// RFC 1350: "octet (This replaces the 'binary' mode of previous versions
/// of this document.) raw 8 bit bytes [...] If a host receives a octet
/// file and then returns it, the returned file must be identical to the
/// original."
#[default]
Octet,
}
impl Mode {
/// Parse a mode string (case-insensitive).
///
/// RFC 1350 Section 5: "The mode field contains the string 'netascii',
/// 'octet', or 'mail' (or any combination of upper and lower case,
/// such as 'NETASCII', '`NetAscii`', etc.)"
///
/// # Errors
///
/// Returns `Error::InvalidMode` if the string is not a recognized mode.
pub fn parse(s: &str) -> Result<Self> {
match s.to_ascii_lowercase().as_str() {
"netascii" => Ok(Self::NetAscii),
"octet" => Ok(Self::Octet),
// Mail mode is obsolete per RFC 1350
"mail" => Err(Error::InvalidMode(
"mail mode is obsolete and not supported".to_string(),
)),
_ => Err(Error::InvalidMode(s.to_string())),
}
}
/// Get the string representation of the mode.
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::NetAscii => "netascii",
Self::Octet => "octet",
}
}
}
/// A TFTP packet.
///
/// All packet formats are defined in RFC 1350 Section 5 and Appendix I.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
/// Read Request (RRQ) - opcode 1
///
/// RFC 1350 Figure 5-1:
/// ```text
/// 2 bytes string 1 byte string 1 byte
/// ------------------------------------------------
/// | Opcode | Filename | 0 | Mode | 0 |
/// ------------------------------------------------
/// ```
ReadRequest { filename: String, mode: Mode },
/// Write Request (WRQ) - opcode 2
///
/// Same format as RRQ (RFC 1350 Figure 5-1).
WriteRequest { filename: String, mode: Mode },
/// Data packet - opcode 3
///
/// RFC 1350 Figure 5-2:
/// ```text
/// 2 bytes 2 bytes n bytes
/// ----------------------------------
/// | Opcode | Block # | Data |
/// ----------------------------------
/// ```
///
/// RFC 1350 Section 5: "The block numbers on data packets begin with one
/// and increase by one for each new block of data. [...] The data field
/// is from zero to 512 bytes long. If it is 512 bytes long, the block is
/// not the last block of data; if it is from zero to 511 bytes long, it
/// signals the end of the transfer."
Data { block: u16, data: Vec<u8> },
/// Acknowledgment packet - opcode 4
///
/// RFC 1350 Figure 5-3:
/// ```text
/// 2 bytes 2 bytes
/// ---------------------
/// | Opcode | Block # |
/// ---------------------
/// ```
///
/// RFC 1350 Section 5: "The block number in an ACK echoes the block number
/// of the DATA packet being acknowledged. A WRQ is acknowledged with an
/// ACK packet having a block number of zero."
Ack { block: u16 },
/// Error packet - opcode 5
///
/// RFC 1350 Figure 5-4:
/// ```text
/// 2 bytes 2 bytes string 1 byte
/// -----------------------------------------
/// | Opcode | ErrorCode | ErrMsg | 0 |
/// -----------------------------------------
/// ```
///
/// RFC 1350 Section 5: "An ERROR packet can be the acknowledgment of any
/// other type of packet. The error code is an integer indicating the nature
/// of the error. [...] The error message is intended for human consumption,
/// and should be in netascii."
Error { code: ErrorCode, message: String },
}
impl Packet {
/// Parse a TFTP packet from raw bytes.
///
/// # Errors
///
/// Returns an error if the packet is malformed.
pub fn parse(data: &[u8]) -> Result<Self> {
if data.len() < MIN_PACKET_SIZE {
return Err(Error::PacketTooShort {
expected: MIN_PACKET_SIZE,
actual: data.len(),
});
}
let opcode = u16::from_be_bytes([data[0], data[1]]);
let opcode = Opcode::from_u16(opcode)?;
let payload = &data[2..];
match opcode {
Opcode::Rrq => Self::parse_request(payload, false),
Opcode::Wrq => Self::parse_request(payload, true),
Opcode::Data => Self::parse_data(payload),
Opcode::Ack => Self::parse_ack(payload),
Opcode::Error => Self::parse_error(payload),
}
}
/// Parse a RRQ or WRQ packet payload.
fn parse_request(payload: &[u8], is_write: bool) -> Result<Self> {
// Find the null terminator for filename
let filename_end = payload
.iter()
.position(|&b| b == 0)
.ok_or(Error::MissingNullTerminator { field: "filename" })?;
let filename = std::str::from_utf8(&payload[..filename_end]).map_err(|e| {
Error::InvalidUtf8 {
field: "filename",
source: e,
}
})?;
// Find the mode string after the filename null terminator
let mode_start = filename_end + 1;
if mode_start >= payload.len() {
return Err(Error::MissingNullTerminator { field: "mode" });
}
let mode_end = payload[mode_start..]
.iter()
.position(|&b| b == 0)
.ok_or(Error::MissingNullTerminator { field: "mode" })?
+ mode_start;
let mode_str =
std::str::from_utf8(&payload[mode_start..mode_end]).map_err(|e| Error::InvalidUtf8 {
field: "mode",
source: e,
})?;
let mode = Mode::parse(mode_str)?;
if is_write {
Ok(Self::WriteRequest {
filename: filename.to_string(),
mode,
})
} else {
Ok(Self::ReadRequest {
filename: filename.to_string(),
mode,
})
}
}
/// Parse a DATA packet payload.
fn parse_data(payload: &[u8]) -> Result<Self> {
if payload.len() < 2 {
return Err(Error::PacketTooShort {
expected: 4, // 2 opcode + 2 block
actual: payload.len() + 2,
});
}
let block = u16::from_be_bytes([payload[0], payload[1]]);
let data = payload[2..].to_vec();
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
Ok(Self::Data { block, data })
}
/// Parse an ACK packet payload.
fn parse_ack(payload: &[u8]) -> Result<Self> {
if payload.len() < 2 {
return Err(Error::PacketTooShort {
expected: 4, // 2 opcode + 2 block
actual: payload.len() + 2,
});
}
let block = u16::from_be_bytes([payload[0], payload[1]]);
Ok(Self::Ack { block })
}
/// Parse an ERROR packet payload.
fn parse_error(payload: &[u8]) -> Result<Self> {
if payload.len() < 2 {
return Err(Error::PacketTooShort {
expected: 4, // 2 opcode + 2 error code
actual: payload.len() + 2,
});
}
let code = u16::from_be_bytes([payload[0], payload[1]]);
let code = ErrorCode::from_u16(code);
// Find the null terminator for error message
let msg_start = 2;
let msg_end = payload[msg_start..]
.iter()
.position(|&b| b == 0)
.map_or(payload.len(), |pos| pos + msg_start);
let message = std::str::from_utf8(&payload[msg_start..msg_end])
.map_err(|e| Error::InvalidUtf8 {
field: "error message",
source: e,
})?
.to_string();
Ok(Self::Error { code, message })
}
/// Serialize the packet into bytes.
///
/// # Errors
///
/// Returns an error if the data is too large.
pub fn serialize(&self) -> Result<Vec<u8>> {
match self {
Self::ReadRequest { filename, mode } => {
Ok(Self::serialize_request(Opcode::Rrq, filename, *mode))
}
Self::WriteRequest { filename, mode } => {
Ok(Self::serialize_request(Opcode::Wrq, filename, *mode))
}
Self::Data { block, data } => Self::serialize_data(*block, data),
Self::Ack { block } => Ok(Self::serialize_ack(*block)),
Self::Error { code, message } => Ok(Self::serialize_error(*code, message)),
}
}
/// Serialize a RRQ or WRQ packet.
fn serialize_request(opcode: Opcode, filename: &str, mode: Mode) -> Vec<u8> {
let mode_str = mode.as_str();
let capacity = 2 + filename.len() + 1 + mode_str.len() + 1;
let mut buf = Vec::with_capacity(capacity);
buf.extend_from_slice(&u16::from(opcode).to_be_bytes());
buf.extend_from_slice(filename.as_bytes());
buf.push(0);
buf.extend_from_slice(mode_str.as_bytes());
buf.push(0);
buf
}
/// Serialize a DATA packet.
fn serialize_data(block: u16, data: &[u8]) -> Result<Vec<u8>> {
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
let mut buf = Vec::with_capacity(4 + data.len());
buf.extend_from_slice(&u16::from(Opcode::Data).to_be_bytes());
buf.extend_from_slice(&block.to_be_bytes());
buf.extend_from_slice(data);
Ok(buf)
}
/// Serialize an ACK packet.
fn serialize_ack(block: u16) -> Vec<u8> {
let mut buf = Vec::with_capacity(4);
buf.extend_from_slice(&u16::from(Opcode::Ack).to_be_bytes());
buf.extend_from_slice(&block.to_be_bytes());
buf
}
/// Serialize an ERROR packet.
fn serialize_error(code: ErrorCode, message: &str) -> Vec<u8> {
let capacity = 4 + message.len() + 1;
let mut buf = Vec::with_capacity(capacity);
buf.extend_from_slice(&u16::from(Opcode::Error).to_be_bytes());
buf.extend_from_slice(&u16::from(code).to_be_bytes());
buf.extend_from_slice(message.as_bytes());
buf.push(0);
buf
}
/// Get the opcode of this packet.
#[must_use]
pub const fn opcode(&self) -> Opcode {
match self {
Self::ReadRequest { .. } => Opcode::Rrq,
Self::WriteRequest { .. } => Opcode::Wrq,
Self::Data { .. } => Opcode::Data,
Self::Ack { .. } => Opcode::Ack,
Self::Error { .. } => Opcode::Error,
}
}
/// Get a human-readable name for the packet type.
#[must_use]
pub const fn type_name(&self) -> &'static str {
match self {
Self::ReadRequest { .. } => "RRQ",
Self::WriteRequest { .. } => "WRQ",
Self::Data { .. } => "DATA",
Self::Ack { .. } => "ACK",
Self::Error { .. } => "ERROR",
}
}
/// Check if this is the final data packet (less than 512 bytes).
///
/// RFC 1350 Section 5: "If it is 512 bytes long, the block is not the last
/// block of data; if it is from zero to 511 bytes long, it signals the end
/// of the transfer."
#[must_use]
pub const fn is_final_data(&self) -> bool {
matches!(self, Self::Data { data, .. } if data.len() < MAX_DATA_SIZE)
}
/// Create a new RRQ packet.
#[must_use]
pub fn rrq(filename: impl Into<String>, mode: Mode) -> Self {
Self::ReadRequest {
filename: filename.into(),
mode,
}
}
/// Create a new WRQ packet.
#[must_use]
pub fn wrq(filename: impl Into<String>, mode: Mode) -> Self {
Self::WriteRequest {
filename: filename.into(),
mode,
}
}
/// Create a new DATA packet.
#[must_use]
pub fn data(block: u16, data: Vec<u8>) -> Self {
Self::Data { block, data }
}
/// Create a new ACK packet.
#[must_use]
pub const fn ack(block: u16) -> Self {
Self::Ack { block }
}
/// Create a new ERROR packet.
#[must_use]
pub fn error(code: ErrorCode, message: impl Into<String>) -> Self {
Self::Error {
code,
message: message.into(),
}
}
/// Create an ERROR packet with default message for the error code.
#[must_use]
pub fn error_default(code: ErrorCode) -> Self {
Self::Error {
code,
message: code.default_message().to_string(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rrq_roundtrip() {
let packet = Packet::rrq("test.txt", Mode::Octet);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_wrq_roundtrip() {
let packet = Packet::wrq("output.bin", Mode::NetAscii);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_data_roundtrip() {
let packet = Packet::data(1, vec![1, 2, 3, 4, 5]);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_data_empty() {
let packet = Packet::data(42, vec![]);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
assert!(packet.is_final_data());
}
#[test]
fn test_data_max_size() {
let packet = Packet::data(100, vec![0xAB; MAX_DATA_SIZE]);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
assert!(!packet.is_final_data());
}
#[test]
fn test_data_too_large() {
let packet = Packet::data(1, vec![0; MAX_DATA_SIZE + 1]);
let result = packet.serialize();
assert!(matches!(result, Err(Error::DataTooLarge { .. })));
}
#[test]
fn test_ack_roundtrip() {
let packet = Packet::ack(0);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
let packet = Packet::ack(65535);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_error_roundtrip() {
let packet = Packet::error(ErrorCode::FileNotFound, "File not found");
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_error_default_message() {
let packet = Packet::error_default(ErrorCode::AccessViolation);
assert!(matches!(
packet,
Packet::Error {
code: ErrorCode::AccessViolation,
message
} if message == "Access violation"
));
}
#[test]
fn test_mode_case_insensitive() {
assert_eq!(Mode::parse("OCTET").expect("parse"), Mode::Octet);
assert_eq!(Mode::parse("Octet").expect("parse"), Mode::Octet);
assert_eq!(Mode::parse("netascii").expect("parse"), Mode::NetAscii);
assert_eq!(Mode::parse("NETASCII").expect("parse"), Mode::NetAscii);
assert_eq!(Mode::parse("NetAscii").expect("parse"), Mode::NetAscii);
}
#[test]
fn test_invalid_mode() {
assert!(matches!(
Mode::parse("binary"),
Err(Error::InvalidMode(_))
));
assert!(matches!(Mode::parse("mail"), Err(Error::InvalidMode(_))));
}
#[test]
fn test_invalid_opcode() {
let data = [0x00, 0x06]; // opcode 6 is invalid
let result = Packet::parse(&data);
assert!(matches!(result, Err(Error::InvalidOpcode(6))));
}
#[test]
fn test_packet_too_short() {
let data = [0x00];
let result = Packet::parse(&data);
assert!(matches!(result, Err(Error::PacketTooShort { .. })));
}
#[test]
fn test_rrq_parse_raw() {
// RRQ for "test.txt" in octet mode
let data = [
0x00, 0x01, // opcode = RRQ
b't', b'e', b's', b't', b'.', b't', b'x', b't', 0x00, // filename
b'o', b'c', b't', b'e', b't', 0x00, // mode
];
let packet = Packet::parse(&data).expect("parse failed");
assert!(matches!(
packet,
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
));
}
}

View File

@@ -0,0 +1,661 @@
//! TFTP Protocol State Machine
//!
//! This module implements sans-io state machines for both client and server
//! sides of the TFTP protocol.
//!
//! # RFC 1350 Protocol Overview
//!
//! The protocol follows a lock-step acknowledgment pattern:
//!
//! ## Read Transfer (Client reading from Server)
//! ```text
//! Client -> Server: RRQ (filename, mode)
//! Server -> Client: DATA (block 1)
//! Client -> Server: ACK (block 1)
//! Server -> Client: DATA (block 2)
//! ... (continues until final DATA with < 512 bytes)
//! ```
//!
//! ## Write Transfer (Client writing to Server)
//! ```text
//! Client -> Server: WRQ (filename, mode)
//! Server -> Client: ACK (block 0)
//! Client -> Server: DATA (block 1)
//! Server -> Client: ACK (block 1)
//! ... (continues until final DATA with < 512 bytes)
//! ```
use crate::{Error, Mode, Packet, Result, MAX_DATA_SIZE};
/// Direction of data transfer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
/// Reading data from remote (client downloads, server uploads)
Read,
/// Writing data to remote (client uploads, server downloads)
Write,
}
/// Events emitted by the state machine.
///
/// These events tell the I/O layer what action to take next.
/// Errors are returned via `Result` from the state machine methods,
/// not as events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
/// Send this packet to the remote peer.
Send(Packet),
/// Request data from the application for the given block number.
///
/// The application should provide up to 512 bytes of data.
/// A response with less than 512 bytes indicates the final block.
NeedData { block: u16 },
/// Received data from remote peer.
///
/// The application should store this data.
ReceivedData { block: u16, data: Vec<u8> },
/// Transfer completed successfully.
Complete,
}
/// Client-side protocol state machine.
///
/// Handles the client side of TFTP transfers (both read and write).
#[derive(Debug)]
pub struct ClientState {
/// The direction of the transfer
direction: TransferDirection,
/// Transfer mode
mode: Mode,
/// Current block number we're expecting/sending
current_block: u16,
/// Whether the transfer is complete
complete: bool,
/// Whether we've received the initial response
got_initial_response: bool,
}
impl ClientState {
/// Create a new client state for a read request.
///
/// Call `start()` to get the initial RRQ packet to send.
#[must_use]
pub const fn new_read(mode: Mode) -> Self {
Self {
direction: TransferDirection::Read,
mode,
current_block: 1, // Expecting DATA block 1
complete: false,
got_initial_response: false,
}
}
/// Create a new client state for a write request.
///
/// Call `start()` to get the initial WRQ packet to send.
#[must_use]
pub const fn new_write(mode: Mode) -> Self {
Self {
direction: TransferDirection::Write,
mode,
current_block: 0, // Expecting ACK block 0
complete: false,
got_initial_response: false,
}
}
/// Get the initial request packet to send.
#[must_use]
pub fn start(&self, filename: &str) -> Packet {
match self.direction {
TransferDirection::Read => Packet::rrq(filename, self.mode),
TransferDirection::Write => Packet::wrq(filename, self.mode),
}
}
/// Check if the transfer is complete.
#[must_use]
pub const fn is_complete(&self) -> bool {
self.complete
}
/// Get the current expected block number.
#[must_use]
pub const fn current_block(&self) -> u16 {
self.current_block
}
/// Process a received packet and return events.
///
/// # Errors
///
/// Returns an error if the packet is invalid for the current state.
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
if self.complete {
return Err(Error::StateError("transfer already complete".to_string()));
}
match (self.direction, packet) {
// Reading: expect DATA packets
(TransferDirection::Read, Packet::Data { block, data }) => {
self.handle_read_data(*block, data)
}
// Writing: expect ACK packets
(TransferDirection::Write, Packet::Ack { block }) => self.handle_write_ack(*block),
// Error packet terminates transfer
(_, Packet::Error { code, message }) => {
self.complete = true;
Err(Error::TftpError {
code: *code,
message: message.clone(),
})
}
// Unexpected packet type
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
expected: "DATA",
actual: other.type_name(),
}),
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
expected: "ACK",
actual: other.type_name(),
}),
}
}
/// Handle a DATA packet during a read transfer.
fn handle_read_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
self.got_initial_response = true;
// RFC 1350: Block numbers start at 1 and increase sequentially
if block != self.current_block {
// Could be a duplicate - if it's the previous block, re-send ACK
if block == self.current_block.wrapping_sub(1) {
return Ok(vec![Event::Send(Packet::ack(block))]);
}
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
let is_final = data.len() < MAX_DATA_SIZE;
let mut events = vec![
Event::ReceivedData {
block,
data: data.to_vec(),
},
Event::Send(Packet::ack(block)),
];
if is_final {
self.complete = true;
events.push(Event::Complete);
} else {
self.current_block = self.current_block.wrapping_add(1);
}
Ok(events)
}
/// Handle an ACK packet during a write transfer.
fn handle_write_ack(&mut self, block: u16) -> Result<Vec<Event>> {
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet having
// a block number of zero."
if !self.got_initial_response {
if block != 0 {
return Err(Error::BlockMismatch {
expected: 0,
actual: block,
});
}
self.got_initial_response = true;
self.current_block = 1;
return Ok(vec![Event::NeedData { block: 1 }]);
}
if block != self.current_block {
// Duplicate ACK - ignore
if block == self.current_block.wrapping_sub(1) {
return Ok(vec![]);
}
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
// Request next block of data
self.current_block = self.current_block.wrapping_add(1);
Ok(vec![Event::NeedData {
block: self.current_block,
}])
}
/// Provide data for a write transfer.
///
/// Call this in response to a `NeedData` event.
///
/// # Arguments
///
/// * `block` - The block number this data is for
/// * `data` - The data to send (must be <= 512 bytes)
///
/// # Returns
///
/// Returns the DATA packet to send, or marks the transfer complete if
/// this was the final block.
///
/// # Errors
///
/// Returns an error if the data is too large or the block number is wrong.
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
if block != self.current_block {
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
let is_final = data.len() < MAX_DATA_SIZE;
if is_final {
self.complete = true;
}
Ok(Event::Send(Packet::data(block, data)))
}
/// Mark the transfer as complete after the final ACK is received.
pub fn finish(&mut self) -> Event {
self.complete = true;
Event::Complete
}
}
/// Server-side protocol state machine.
///
/// Handles the server side of TFTP transfers.
#[derive(Debug)]
pub struct ServerState {
/// The direction of the transfer (from server's perspective)
direction: TransferDirection,
/// Transfer mode
mode: Mode,
/// The filename being transferred
filename: String,
/// Current block number
current_block: u16,
/// Whether the transfer is complete
complete: bool,
/// Whether we've sent the first response
sent_initial_response: bool,
/// Track if we're waiting for final ACK
waiting_for_final_ack: bool,
}
impl ServerState {
/// Create a new server state from a request packet.
///
/// # Errors
///
/// Returns an error if the packet is not an RRQ or WRQ.
pub fn from_request(packet: &Packet) -> Result<Self> {
match packet {
Packet::ReadRequest { filename, mode } => Ok(Self {
// Server uploads when client reads
direction: TransferDirection::Write,
mode: *mode,
filename: filename.clone(),
current_block: 1,
complete: false,
sent_initial_response: false,
waiting_for_final_ack: false,
}),
Packet::WriteRequest { filename, mode } => Ok(Self {
// Server downloads when client writes
direction: TransferDirection::Read,
mode: *mode,
filename: filename.clone(),
current_block: 0,
complete: false,
sent_initial_response: false,
waiting_for_final_ack: false,
}),
other => Err(Error::UnexpectedPacket {
expected: "RRQ or WRQ",
actual: other.type_name(),
}),
}
}
/// Get the filename being transferred.
#[must_use]
pub fn filename(&self) -> &str {
&self.filename
}
/// Get the transfer mode.
#[must_use]
pub const fn mode(&self) -> Mode {
self.mode
}
/// Get the transfer direction (from server's perspective).
#[must_use]
pub const fn direction(&self) -> TransferDirection {
self.direction
}
/// Check if the transfer is complete.
#[must_use]
pub const fn is_complete(&self) -> bool {
self.complete
}
/// Get the initial response event.
///
/// For RRQ: Returns `NeedData` to request the first block
/// For WRQ: Returns `Send(ACK 0)` to acknowledge the write request
pub fn start(&mut self) -> Event {
self.sent_initial_response = true;
match self.direction {
// Server sending data (client RRQ)
TransferDirection::Write => Event::NeedData { block: 1 },
// Server receiving data (client WRQ)
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet
// having a block number of zero."
TransferDirection::Read => Event::Send(Packet::ack(0)),
}
}
/// Process a received packet and return events.
///
/// # Errors
///
/// Returns an error if the packet is invalid for the current state.
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
if self.complete {
return Err(Error::StateError("transfer already complete".to_string()));
}
match (self.direction, packet) {
// Server sending data, receiving ACKs
(TransferDirection::Write, Packet::Ack { block }) => self.handle_send_ack(*block),
// Server receiving data
(TransferDirection::Read, Packet::Data { block, data }) => {
self.handle_receive_data(*block, data)
}
// Error terminates transfer
(_, Packet::Error { code, message }) => {
self.complete = true;
Err(Error::TftpError {
code: *code,
message: message.clone(),
})
}
// Unexpected packet
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
expected: "ACK",
actual: other.type_name(),
}),
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
expected: "DATA",
actual: other.type_name(),
}),
}
}
/// Handle an ACK when server is sending data.
fn handle_send_ack(&mut self, block: u16) -> Result<Vec<Event>> {
if block != self.current_block {
// Duplicate ACK - could retransmit, but we'll just ignore
if block == self.current_block.wrapping_sub(1) {
return Ok(vec![]);
}
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
// If we were waiting for the final ACK, we're done
if self.waiting_for_final_ack {
self.complete = true;
return Ok(vec![Event::Complete]);
}
// Request next block
self.current_block = self.current_block.wrapping_add(1);
Ok(vec![Event::NeedData {
block: self.current_block,
}])
}
/// Handle a DATA packet when server is receiving data.
fn handle_receive_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
let expected_block = self.current_block.wrapping_add(1);
if block != expected_block {
// Duplicate - re-ACK
if block == self.current_block {
return Ok(vec![Event::Send(Packet::ack(block))]);
}
return Err(Error::BlockMismatch {
expected: expected_block,
actual: block,
});
}
self.current_block = block;
let is_final = data.len() < MAX_DATA_SIZE;
let mut events = vec![
Event::ReceivedData {
block,
data: data.to_vec(),
},
Event::Send(Packet::ack(block)),
];
if is_final {
self.complete = true;
events.push(Event::Complete);
}
Ok(events)
}
/// Provide data for a read transfer (server sending to client).
///
/// # Errors
///
/// Returns an error if the data is too large.
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
if block != self.current_block {
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
let is_final = data.len() < MAX_DATA_SIZE;
if is_final {
self.waiting_for_final_ack = true;
}
Ok(Event::Send(Packet::data(block, data)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ErrorCode;
#[test]
fn test_client_read_single_block() {
let mut client = ClientState::new_read(Mode::Octet);
let request = client.start("test.txt");
assert!(matches!(
request,
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
));
// Receive single DATA block (< 512 bytes = final)
let events = client
.receive(&Packet::data(1, vec![1, 2, 3]))
.expect("receive failed");
assert_eq!(events.len(), 3);
assert!(matches!(&events[0], Event::ReceivedData { block: 1, data } if data == &[1, 2, 3]));
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
assert!(matches!(&events[2], Event::Complete));
assert!(client.is_complete());
}
#[test]
fn test_client_read_multiple_blocks() {
let mut client = ClientState::new_read(Mode::Octet);
let _ = client.start("test.txt");
// Full block (512 bytes)
let full_data = vec![0u8; MAX_DATA_SIZE];
let events = client
.receive(&Packet::data(1, full_data.clone()))
.expect("receive");
assert_eq!(events.len(), 2);
assert!(!client.is_complete());
assert_eq!(client.current_block(), 2);
// Another full block
let events = client
.receive(&Packet::data(2, full_data))
.expect("receive");
assert_eq!(events.len(), 2);
assert!(!client.is_complete());
// Final block (< 512 bytes)
let events = client
.receive(&Packet::data(3, vec![1, 2, 3]))
.expect("receive");
assert_eq!(events.len(), 3);
assert!(client.is_complete());
}
#[test]
fn test_client_write_flow() {
let mut client = ClientState::new_write(Mode::Octet);
let request = client.start("output.txt");
assert!(matches!(request, Packet::WriteRequest { .. }));
// Receive ACK 0 (write request acknowledged)
let events = client.receive(&Packet::ack(0)).expect("receive");
assert_eq!(events.len(), 1);
assert!(matches!(events[0], Event::NeedData { block: 1 }));
// Provide data for block 1
let send_event = client.provide_data(1, vec![1, 2, 3]).expect("provide_data");
assert!(matches!(
send_event,
Event::Send(Packet::Data { block: 1, .. })
));
assert!(client.is_complete()); // < 512 bytes = final
}
#[test]
fn test_server_rrq() {
let request = Packet::rrq("file.txt", Mode::Octet);
let mut server = ServerState::from_request(&request).expect("from_request");
assert_eq!(server.filename(), "file.txt");
assert_eq!(server.direction(), TransferDirection::Write);
// Start should request data for block 1
let event = server.start();
assert!(matches!(event, Event::NeedData { block: 1 }));
// Provide data
let event = server.provide_data(1, vec![1, 2, 3]).expect("provide_data");
assert!(matches!(event, Event::Send(Packet::Data { block: 1, .. })));
// Receive ACK
let events = server.receive(&Packet::ack(1)).expect("receive");
assert!(matches!(&events[0], Event::Complete));
assert!(server.is_complete());
}
#[test]
fn test_server_wrq() {
let request = Packet::wrq("output.txt", Mode::Octet);
let mut server = ServerState::from_request(&request).expect("from_request");
assert_eq!(server.direction(), TransferDirection::Read);
// Start should send ACK 0
let event = server.start();
assert!(matches!(event, Event::Send(Packet::Ack { block: 0 })));
// Receive DATA block 1
let events = server
.receive(&Packet::data(1, vec![1, 2, 3]))
.expect("receive");
assert_eq!(events.len(), 3);
assert!(matches!(&events[0], Event::ReceivedData { block: 1, .. }));
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
assert!(matches!(&events[2], Event::Complete));
}
#[test]
fn test_error_terminates_transfer() {
let mut client = ClientState::new_read(Mode::Octet);
let _ = client.start("test.txt");
let result = client.receive(&Packet::error(ErrorCode::FileNotFound, "File not found"));
assert!(result.is_err());
assert!(client.is_complete());
}
#[test]
fn test_duplicate_ack_ignored() {
let request = Packet::rrq("file.txt", Mode::Octet);
let mut server = ServerState::from_request(&request).expect("from_request");
let _ = server.start();
// Provide and send block 1
let _ = server
.provide_data(1, vec![0u8; MAX_DATA_SIZE])
.expect("provide_data");
// Receive ACK 1
let events = server.receive(&Packet::ack(1)).expect("receive");
assert!(!events.is_empty());
// Provide block 2
let _ = server.provide_data(2, vec![1, 2, 3]).expect("provide_data");
// Duplicate ACK 1 should be ignored
let events = server.receive(&Packet::ack(1)).expect("receive");
assert!(events.is_empty());
}
}