feat: implement TFTP protocol crate (pfs-tftp-proto)
This commit introduces a sans-IO TFTP protocol implementation following RFC 1350. The protocol crate provides: - Packet types: RRQ, WRQ, DATA, ACK, ERROR with full serialization/parsing - Error codes as defined in RFC 1350 Appendix - Transfer modes: octet (binary) and netascii - Client and server state machines for managing protocol flow - Comprehensive tests for all packet types and state transitions The sans-IO design separates protocol logic from I/O operations, making the code testable and reusable across different I/O implementations. Key design decisions: - Mail mode is explicitly rejected as obsolete per RFC 1350 - Block numbers wrap around at 65535 for large file support - State machines emit events that tell the I/O layer what to do next - All protocol-specific values are documented with RFC citations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
15
crates/pfs-tftp-proto/Cargo.toml
Normal file
15
crates/pfs-tftp-proto/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "pfs-tftp-proto"
|
||||
description = "Sans-IO TFTP protocol implementation (RFC 1350)"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
thiserror = "2"
|
||||
|
||||
[dev-dependencies]
|
||||
193
crates/pfs-tftp-proto/src/error.rs
Normal file
193
crates/pfs-tftp-proto/src/error.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
//! TFTP Error Types
|
||||
//!
|
||||
//! Defines error codes and error handling for TFTP protocol.
|
||||
//!
|
||||
//! # RFC 1350 Error Codes (Appendix)
|
||||
//!
|
||||
//! ```text
|
||||
//! Value Meaning
|
||||
//! 0 Not defined, see error message (if any).
|
||||
//! 1 File not found.
|
||||
//! 2 Access violation.
|
||||
//! 3 Disk full or allocation exceeded.
|
||||
//! 4 Illegal TFTP operation.
|
||||
//! 5 Unknown transfer ID.
|
||||
//! 6 File already exists.
|
||||
//! 7 No such user.
|
||||
//! ```
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// TFTP error codes as defined in RFC 1350 Appendix.
|
||||
///
|
||||
/// These codes are sent in ERROR packets to indicate the nature of the error.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u16)]
|
||||
pub enum ErrorCode {
|
||||
/// Not defined, see error message (if any).
|
||||
/// Used when the error doesn't fit other categories.
|
||||
NotDefined = 0,
|
||||
|
||||
/// File not found.
|
||||
/// The requested file does not exist on the server.
|
||||
FileNotFound = 1,
|
||||
|
||||
/// Access violation.
|
||||
/// The client does not have permission to access the file.
|
||||
AccessViolation = 2,
|
||||
|
||||
/// Disk full or allocation exceeded.
|
||||
/// There is insufficient storage space to complete the transfer.
|
||||
DiskFull = 3,
|
||||
|
||||
/// Illegal TFTP operation.
|
||||
/// The requested operation is not valid (e.g., malformed packet).
|
||||
IllegalOperation = 4,
|
||||
|
||||
/// Unknown transfer ID.
|
||||
/// RFC 1350 Section 4: "If a source TID does not match, the packet should
|
||||
/// be discarded as erroneously sent from somewhere else. An error packet
|
||||
/// should be sent to the source of the incorrect packet."
|
||||
UnknownTransferId = 5,
|
||||
|
||||
/// File already exists.
|
||||
/// Returned when attempting to write a file that already exists
|
||||
/// and the server policy doesn't allow overwriting.
|
||||
FileAlreadyExists = 6,
|
||||
|
||||
/// No such user.
|
||||
/// The specified user does not exist (relevant for mail mode, which is obsolete).
|
||||
NoSuchUser = 7,
|
||||
}
|
||||
|
||||
impl ErrorCode {
|
||||
/// Create an `ErrorCode` from a raw u16 value.
|
||||
///
|
||||
/// Returns `NotDefined` for any unknown error code, as per RFC 1350
|
||||
/// which states error code 0 is "Not defined".
|
||||
#[must_use]
|
||||
pub const fn from_u16(value: u16) -> Self {
|
||||
match value {
|
||||
1 => Self::FileNotFound,
|
||||
2 => Self::AccessViolation,
|
||||
3 => Self::DiskFull,
|
||||
4 => Self::IllegalOperation,
|
||||
5 => Self::UnknownTransferId,
|
||||
6 => Self::FileAlreadyExists,
|
||||
7 => Self::NoSuchUser,
|
||||
// RFC 1350: Error code 0 is "Not defined", also used as fallback
|
||||
_ => Self::NotDefined,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the default error message for this error code.
|
||||
#[must_use]
|
||||
pub const fn default_message(&self) -> &'static str {
|
||||
match self {
|
||||
Self::NotDefined => "Not defined",
|
||||
Self::FileNotFound => "File not found",
|
||||
Self::AccessViolation => "Access violation",
|
||||
Self::DiskFull => "Disk full or allocation exceeded",
|
||||
Self::IllegalOperation => "Illegal TFTP operation",
|
||||
Self::UnknownTransferId => "Unknown transfer ID",
|
||||
Self::FileAlreadyExists => "File already exists",
|
||||
Self::NoSuchUser => "No such user",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ErrorCode> for u16 {
|
||||
fn from(code: ErrorCode) -> Self {
|
||||
code as Self
|
||||
}
|
||||
}
|
||||
|
||||
/// Protocol-level errors for TFTP operations.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
/// Packet is too short to be valid.
|
||||
#[error("packet too short: expected at least {expected} bytes, got {actual}")]
|
||||
PacketTooShort { expected: usize, actual: usize },
|
||||
|
||||
/// Invalid opcode in packet.
|
||||
#[error("invalid opcode: {0}")]
|
||||
InvalidOpcode(u16),
|
||||
|
||||
/// Invalid transfer mode in RRQ/WRQ packet.
|
||||
#[error("invalid mode: {0}")]
|
||||
InvalidMode(String),
|
||||
|
||||
/// Missing null terminator in string field.
|
||||
#[error("missing null terminator in {field}")]
|
||||
MissingNullTerminator { field: &'static str },
|
||||
|
||||
/// String field contains invalid UTF-8.
|
||||
#[error("invalid UTF-8 in {field}: {source}")]
|
||||
InvalidUtf8 {
|
||||
field: &'static str,
|
||||
#[source]
|
||||
source: std::str::Utf8Error,
|
||||
},
|
||||
|
||||
/// Data exceeds maximum block size.
|
||||
#[error("data too large: {size} bytes exceeds maximum of 512")]
|
||||
DataTooLarge { size: usize },
|
||||
|
||||
/// Buffer too small for serialization.
|
||||
#[error("buffer too small: need {needed} bytes, have {available}")]
|
||||
BufferTooSmall { needed: usize, available: usize },
|
||||
|
||||
/// TFTP error received from remote peer.
|
||||
#[error("TFTP error {code:?}: {message}")]
|
||||
TftpError { code: ErrorCode, message: String },
|
||||
|
||||
/// Protocol state error.
|
||||
#[error("protocol state error: {0}")]
|
||||
StateError(String),
|
||||
|
||||
/// Unexpected packet type received.
|
||||
#[error("unexpected packet type: expected {expected}, got {actual}")]
|
||||
UnexpectedPacket {
|
||||
expected: &'static str,
|
||||
actual: &'static str,
|
||||
},
|
||||
|
||||
/// Block number mismatch.
|
||||
#[error("block number mismatch: expected {expected}, got {actual}")]
|
||||
BlockMismatch { expected: u16, actual: u16 },
|
||||
}
|
||||
|
||||
/// Result type for TFTP operations.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_error_code_from_u16() {
|
||||
assert_eq!(ErrorCode::from_u16(0), ErrorCode::NotDefined);
|
||||
assert_eq!(ErrorCode::from_u16(1), ErrorCode::FileNotFound);
|
||||
assert_eq!(ErrorCode::from_u16(2), ErrorCode::AccessViolation);
|
||||
assert_eq!(ErrorCode::from_u16(3), ErrorCode::DiskFull);
|
||||
assert_eq!(ErrorCode::from_u16(4), ErrorCode::IllegalOperation);
|
||||
assert_eq!(ErrorCode::from_u16(5), ErrorCode::UnknownTransferId);
|
||||
assert_eq!(ErrorCode::from_u16(6), ErrorCode::FileAlreadyExists);
|
||||
assert_eq!(ErrorCode::from_u16(7), ErrorCode::NoSuchUser);
|
||||
// Unknown codes should map to NotDefined
|
||||
assert_eq!(ErrorCode::from_u16(8), ErrorCode::NotDefined);
|
||||
assert_eq!(ErrorCode::from_u16(255), ErrorCode::NotDefined);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_code_to_u16() {
|
||||
assert_eq!(u16::from(ErrorCode::NotDefined), 0);
|
||||
assert_eq!(u16::from(ErrorCode::FileNotFound), 1);
|
||||
assert_eq!(u16::from(ErrorCode::AccessViolation), 2);
|
||||
assert_eq!(u16::from(ErrorCode::DiskFull), 3);
|
||||
assert_eq!(u16::from(ErrorCode::IllegalOperation), 4);
|
||||
assert_eq!(u16::from(ErrorCode::UnknownTransferId), 5);
|
||||
assert_eq!(u16::from(ErrorCode::FileAlreadyExists), 6);
|
||||
assert_eq!(u16::from(ErrorCode::NoSuchUser), 7);
|
||||
}
|
||||
}
|
||||
21
crates/pfs-tftp-proto/src/lib.rs
Normal file
21
crates/pfs-tftp-proto/src/lib.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
//! Sans-IO TFTP Protocol Implementation (RFC 1350)
|
||||
//!
|
||||
//! This crate provides a pure protocol implementation without any I/O operations.
|
||||
//! It handles packet parsing, serialization, and protocol state management.
|
||||
//!
|
||||
//! # RFC 1350 Overview
|
||||
//!
|
||||
//! TFTP (Trivial File Transfer Protocol) is a simple file transfer protocol
|
||||
//! that operates over UDP. Key characteristics:
|
||||
//! - Uses fixed 512-byte data blocks
|
||||
//! - Stop-and-wait protocol (each packet must be acknowledged)
|
||||
//! - Five packet types: RRQ, WRQ, DATA, ACK, ERROR
|
||||
//! - Default port 69 for initial server connection
|
||||
|
||||
mod error;
|
||||
mod packet;
|
||||
mod state;
|
||||
|
||||
pub use error::{Error, ErrorCode, Result};
|
||||
pub use packet::{Mode, Opcode, Packet, MAX_DATA_SIZE, TFTP_PORT};
|
||||
pub use state::{ClientState, Event, ServerState, TransferDirection};
|
||||
632
crates/pfs-tftp-proto/src/packet.rs
Normal file
632
crates/pfs-tftp-proto/src/packet.rs
Normal file
@@ -0,0 +1,632 @@
|
||||
//! TFTP Packet Types and Serialization
|
||||
//!
|
||||
//! This module implements all five TFTP packet types as defined in RFC 1350:
|
||||
//! - RRQ (Read Request)
|
||||
//! - WRQ (Write Request)
|
||||
//! - DATA
|
||||
//! - ACK
|
||||
//! - ERROR
|
||||
//!
|
||||
//! # RFC 1350 Section 5 - TFTP Packets
|
||||
//!
|
||||
//! ```text
|
||||
//! TFTP supports five types of packets:
|
||||
//!
|
||||
//! opcode operation
|
||||
//! 1 Read request (RRQ)
|
||||
//! 2 Write request (WRQ)
|
||||
//! 3 Data (DATA)
|
||||
//! 4 Acknowledgment (ACK)
|
||||
//! 5 Error (ERROR)
|
||||
//! ```
|
||||
|
||||
use crate::{Error, ErrorCode, Result};
|
||||
|
||||
/// Standard TFTP server port as defined in RFC 1350.
|
||||
///
|
||||
/// RFC 1350 Section 4: "A requesting host chooses its source TID as described
|
||||
/// above, and sends its initial request to the known TID 69 decimal (105 octal)
|
||||
/// on the serving host."
|
||||
pub const TFTP_PORT: u16 = 69;
|
||||
|
||||
/// Maximum data size in a DATA packet.
|
||||
///
|
||||
/// RFC 1350 Section 2: "the connection is opened and the file is sent in fixed
|
||||
/// length blocks of 512 bytes."
|
||||
///
|
||||
/// RFC 1350 Section 5: "The data field is from zero to 512 bytes long."
|
||||
pub const MAX_DATA_SIZE: usize = 512;
|
||||
|
||||
/// Minimum packet size (just an opcode).
|
||||
const MIN_PACKET_SIZE: usize = 2;
|
||||
|
||||
/// TFTP packet opcodes.
|
||||
///
|
||||
/// RFC 1350 Section 5: "The TFTP header of a packet contains the opcode
|
||||
/// associated with that packet."
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(u16)]
|
||||
pub enum Opcode {
|
||||
/// Read request (RRQ)
|
||||
Rrq = 1,
|
||||
/// Write request (WRQ)
|
||||
Wrq = 2,
|
||||
/// Data packet
|
||||
Data = 3,
|
||||
/// Acknowledgment
|
||||
Ack = 4,
|
||||
/// Error
|
||||
Error = 5,
|
||||
}
|
||||
|
||||
impl Opcode {
|
||||
/// Create an `Opcode` from a raw u16 value.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error::InvalidOpcode` if the value is not a valid opcode.
|
||||
pub const fn from_u16(value: u16) -> Result<Self> {
|
||||
match value {
|
||||
1 => Ok(Self::Rrq),
|
||||
2 => Ok(Self::Wrq),
|
||||
3 => Ok(Self::Data),
|
||||
4 => Ok(Self::Ack),
|
||||
5 => Ok(Self::Error),
|
||||
_ => Err(Error::InvalidOpcode(value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Opcode> for u16 {
|
||||
fn from(opcode: Opcode) -> Self {
|
||||
opcode as Self
|
||||
}
|
||||
}
|
||||
|
||||
/// Transfer mode for TFTP operations.
|
||||
///
|
||||
/// RFC 1350 Section 1: "Three modes of transfer are currently supported:
|
||||
/// netascii [...]; octet [...]; mail [...]. (The mail mode is obsolete
|
||||
/// and should not be implemented or used.)"
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum Mode {
|
||||
/// `NetASCII` mode - text transfer with CR/LF line endings.
|
||||
///
|
||||
/// RFC 1350: "netascii (This is ascii as defined in 'USA Standard Code
|
||||
/// for Information Interchange' with the modifications specified in
|
||||
/// 'Telnet Protocol Specification'.) Note that it is 8 bit ascii."
|
||||
NetAscii,
|
||||
|
||||
/// Octet mode - raw binary transfer.
|
||||
///
|
||||
/// RFC 1350: "octet (This replaces the 'binary' mode of previous versions
|
||||
/// of this document.) raw 8 bit bytes [...] If a host receives a octet
|
||||
/// file and then returns it, the returned file must be identical to the
|
||||
/// original."
|
||||
#[default]
|
||||
Octet,
|
||||
}
|
||||
|
||||
impl Mode {
|
||||
/// Parse a mode string (case-insensitive).
|
||||
///
|
||||
/// RFC 1350 Section 5: "The mode field contains the string 'netascii',
|
||||
/// 'octet', or 'mail' (or any combination of upper and lower case,
|
||||
/// such as 'NETASCII', '`NetAscii`', etc.)"
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Error::InvalidMode` if the string is not a recognized mode.
|
||||
pub fn parse(s: &str) -> Result<Self> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"netascii" => Ok(Self::NetAscii),
|
||||
"octet" => Ok(Self::Octet),
|
||||
// Mail mode is obsolete per RFC 1350
|
||||
"mail" => Err(Error::InvalidMode(
|
||||
"mail mode is obsolete and not supported".to_string(),
|
||||
)),
|
||||
_ => Err(Error::InvalidMode(s.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the string representation of the mode.
|
||||
#[must_use]
|
||||
pub const fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::NetAscii => "netascii",
|
||||
Self::Octet => "octet",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A TFTP packet.
|
||||
///
|
||||
/// All packet formats are defined in RFC 1350 Section 5 and Appendix I.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Packet {
|
||||
/// Read Request (RRQ) - opcode 1
|
||||
///
|
||||
/// RFC 1350 Figure 5-1:
|
||||
/// ```text
|
||||
/// 2 bytes string 1 byte string 1 byte
|
||||
/// ------------------------------------------------
|
||||
/// | Opcode | Filename | 0 | Mode | 0 |
|
||||
/// ------------------------------------------------
|
||||
/// ```
|
||||
ReadRequest { filename: String, mode: Mode },
|
||||
|
||||
/// Write Request (WRQ) - opcode 2
|
||||
///
|
||||
/// Same format as RRQ (RFC 1350 Figure 5-1).
|
||||
WriteRequest { filename: String, mode: Mode },
|
||||
|
||||
/// Data packet - opcode 3
|
||||
///
|
||||
/// RFC 1350 Figure 5-2:
|
||||
/// ```text
|
||||
/// 2 bytes 2 bytes n bytes
|
||||
/// ----------------------------------
|
||||
/// | Opcode | Block # | Data |
|
||||
/// ----------------------------------
|
||||
/// ```
|
||||
///
|
||||
/// RFC 1350 Section 5: "The block numbers on data packets begin with one
|
||||
/// and increase by one for each new block of data. [...] The data field
|
||||
/// is from zero to 512 bytes long. If it is 512 bytes long, the block is
|
||||
/// not the last block of data; if it is from zero to 511 bytes long, it
|
||||
/// signals the end of the transfer."
|
||||
Data { block: u16, data: Vec<u8> },
|
||||
|
||||
/// Acknowledgment packet - opcode 4
|
||||
///
|
||||
/// RFC 1350 Figure 5-3:
|
||||
/// ```text
|
||||
/// 2 bytes 2 bytes
|
||||
/// ---------------------
|
||||
/// | Opcode | Block # |
|
||||
/// ---------------------
|
||||
/// ```
|
||||
///
|
||||
/// RFC 1350 Section 5: "The block number in an ACK echoes the block number
|
||||
/// of the DATA packet being acknowledged. A WRQ is acknowledged with an
|
||||
/// ACK packet having a block number of zero."
|
||||
Ack { block: u16 },
|
||||
|
||||
/// Error packet - opcode 5
|
||||
///
|
||||
/// RFC 1350 Figure 5-4:
|
||||
/// ```text
|
||||
/// 2 bytes 2 bytes string 1 byte
|
||||
/// -----------------------------------------
|
||||
/// | Opcode | ErrorCode | ErrMsg | 0 |
|
||||
/// -----------------------------------------
|
||||
/// ```
|
||||
///
|
||||
/// RFC 1350 Section 5: "An ERROR packet can be the acknowledgment of any
|
||||
/// other type of packet. The error code is an integer indicating the nature
|
||||
/// of the error. [...] The error message is intended for human consumption,
|
||||
/// and should be in netascii."
|
||||
Error { code: ErrorCode, message: String },
|
||||
}
|
||||
|
||||
impl Packet {
|
||||
/// Parse a TFTP packet from raw bytes.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the packet is malformed.
|
||||
pub fn parse(data: &[u8]) -> Result<Self> {
|
||||
if data.len() < MIN_PACKET_SIZE {
|
||||
return Err(Error::PacketTooShort {
|
||||
expected: MIN_PACKET_SIZE,
|
||||
actual: data.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let opcode = u16::from_be_bytes([data[0], data[1]]);
|
||||
let opcode = Opcode::from_u16(opcode)?;
|
||||
let payload = &data[2..];
|
||||
|
||||
match opcode {
|
||||
Opcode::Rrq => Self::parse_request(payload, false),
|
||||
Opcode::Wrq => Self::parse_request(payload, true),
|
||||
Opcode::Data => Self::parse_data(payload),
|
||||
Opcode::Ack => Self::parse_ack(payload),
|
||||
Opcode::Error => Self::parse_error(payload),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a RRQ or WRQ packet payload.
|
||||
fn parse_request(payload: &[u8], is_write: bool) -> Result<Self> {
|
||||
// Find the null terminator for filename
|
||||
let filename_end = payload
|
||||
.iter()
|
||||
.position(|&b| b == 0)
|
||||
.ok_or(Error::MissingNullTerminator { field: "filename" })?;
|
||||
|
||||
let filename = std::str::from_utf8(&payload[..filename_end]).map_err(|e| {
|
||||
Error::InvalidUtf8 {
|
||||
field: "filename",
|
||||
source: e,
|
||||
}
|
||||
})?;
|
||||
|
||||
// Find the mode string after the filename null terminator
|
||||
let mode_start = filename_end + 1;
|
||||
if mode_start >= payload.len() {
|
||||
return Err(Error::MissingNullTerminator { field: "mode" });
|
||||
}
|
||||
|
||||
let mode_end = payload[mode_start..]
|
||||
.iter()
|
||||
.position(|&b| b == 0)
|
||||
.ok_or(Error::MissingNullTerminator { field: "mode" })?
|
||||
+ mode_start;
|
||||
|
||||
let mode_str =
|
||||
std::str::from_utf8(&payload[mode_start..mode_end]).map_err(|e| Error::InvalidUtf8 {
|
||||
field: "mode",
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
let mode = Mode::parse(mode_str)?;
|
||||
|
||||
if is_write {
|
||||
Ok(Self::WriteRequest {
|
||||
filename: filename.to_string(),
|
||||
mode,
|
||||
})
|
||||
} else {
|
||||
Ok(Self::ReadRequest {
|
||||
filename: filename.to_string(),
|
||||
mode,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a DATA packet payload.
|
||||
fn parse_data(payload: &[u8]) -> Result<Self> {
|
||||
if payload.len() < 2 {
|
||||
return Err(Error::PacketTooShort {
|
||||
expected: 4, // 2 opcode + 2 block
|
||||
actual: payload.len() + 2,
|
||||
});
|
||||
}
|
||||
|
||||
let block = u16::from_be_bytes([payload[0], payload[1]]);
|
||||
let data = payload[2..].to_vec();
|
||||
|
||||
if data.len() > MAX_DATA_SIZE {
|
||||
return Err(Error::DataTooLarge { size: data.len() });
|
||||
}
|
||||
|
||||
Ok(Self::Data { block, data })
|
||||
}
|
||||
|
||||
/// Parse an ACK packet payload.
|
||||
fn parse_ack(payload: &[u8]) -> Result<Self> {
|
||||
if payload.len() < 2 {
|
||||
return Err(Error::PacketTooShort {
|
||||
expected: 4, // 2 opcode + 2 block
|
||||
actual: payload.len() + 2,
|
||||
});
|
||||
}
|
||||
|
||||
let block = u16::from_be_bytes([payload[0], payload[1]]);
|
||||
Ok(Self::Ack { block })
|
||||
}
|
||||
|
||||
/// Parse an ERROR packet payload.
|
||||
fn parse_error(payload: &[u8]) -> Result<Self> {
|
||||
if payload.len() < 2 {
|
||||
return Err(Error::PacketTooShort {
|
||||
expected: 4, // 2 opcode + 2 error code
|
||||
actual: payload.len() + 2,
|
||||
});
|
||||
}
|
||||
|
||||
let code = u16::from_be_bytes([payload[0], payload[1]]);
|
||||
let code = ErrorCode::from_u16(code);
|
||||
|
||||
// Find the null terminator for error message
|
||||
let msg_start = 2;
|
||||
let msg_end = payload[msg_start..]
|
||||
.iter()
|
||||
.position(|&b| b == 0)
|
||||
.map_or(payload.len(), |pos| pos + msg_start);
|
||||
|
||||
let message = std::str::from_utf8(&payload[msg_start..msg_end])
|
||||
.map_err(|e| Error::InvalidUtf8 {
|
||||
field: "error message",
|
||||
source: e,
|
||||
})?
|
||||
.to_string();
|
||||
|
||||
Ok(Self::Error { code, message })
|
||||
}
|
||||
|
||||
/// Serialize the packet into bytes.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the data is too large.
|
||||
pub fn serialize(&self) -> Result<Vec<u8>> {
|
||||
match self {
|
||||
Self::ReadRequest { filename, mode } => {
|
||||
Ok(Self::serialize_request(Opcode::Rrq, filename, *mode))
|
||||
}
|
||||
Self::WriteRequest { filename, mode } => {
|
||||
Ok(Self::serialize_request(Opcode::Wrq, filename, *mode))
|
||||
}
|
||||
Self::Data { block, data } => Self::serialize_data(*block, data),
|
||||
Self::Ack { block } => Ok(Self::serialize_ack(*block)),
|
||||
Self::Error { code, message } => Ok(Self::serialize_error(*code, message)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize a RRQ or WRQ packet.
|
||||
fn serialize_request(opcode: Opcode, filename: &str, mode: Mode) -> Vec<u8> {
|
||||
let mode_str = mode.as_str();
|
||||
let capacity = 2 + filename.len() + 1 + mode_str.len() + 1;
|
||||
let mut buf = Vec::with_capacity(capacity);
|
||||
|
||||
buf.extend_from_slice(&u16::from(opcode).to_be_bytes());
|
||||
buf.extend_from_slice(filename.as_bytes());
|
||||
buf.push(0);
|
||||
buf.extend_from_slice(mode_str.as_bytes());
|
||||
buf.push(0);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Serialize a DATA packet.
|
||||
fn serialize_data(block: u16, data: &[u8]) -> Result<Vec<u8>> {
|
||||
if data.len() > MAX_DATA_SIZE {
|
||||
return Err(Error::DataTooLarge { size: data.len() });
|
||||
}
|
||||
|
||||
let mut buf = Vec::with_capacity(4 + data.len());
|
||||
buf.extend_from_slice(&u16::from(Opcode::Data).to_be_bytes());
|
||||
buf.extend_from_slice(&block.to_be_bytes());
|
||||
buf.extend_from_slice(data);
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Serialize an ACK packet.
|
||||
fn serialize_ack(block: u16) -> Vec<u8> {
|
||||
let mut buf = Vec::with_capacity(4);
|
||||
buf.extend_from_slice(&u16::from(Opcode::Ack).to_be_bytes());
|
||||
buf.extend_from_slice(&block.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
|
||||
/// Serialize an ERROR packet.
|
||||
fn serialize_error(code: ErrorCode, message: &str) -> Vec<u8> {
|
||||
let capacity = 4 + message.len() + 1;
|
||||
let mut buf = Vec::with_capacity(capacity);
|
||||
|
||||
buf.extend_from_slice(&u16::from(Opcode::Error).to_be_bytes());
|
||||
buf.extend_from_slice(&u16::from(code).to_be_bytes());
|
||||
buf.extend_from_slice(message.as_bytes());
|
||||
buf.push(0);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Get the opcode of this packet.
|
||||
#[must_use]
|
||||
pub const fn opcode(&self) -> Opcode {
|
||||
match self {
|
||||
Self::ReadRequest { .. } => Opcode::Rrq,
|
||||
Self::WriteRequest { .. } => Opcode::Wrq,
|
||||
Self::Data { .. } => Opcode::Data,
|
||||
Self::Ack { .. } => Opcode::Ack,
|
||||
Self::Error { .. } => Opcode::Error,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a human-readable name for the packet type.
|
||||
#[must_use]
|
||||
pub const fn type_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::ReadRequest { .. } => "RRQ",
|
||||
Self::WriteRequest { .. } => "WRQ",
|
||||
Self::Data { .. } => "DATA",
|
||||
Self::Ack { .. } => "ACK",
|
||||
Self::Error { .. } => "ERROR",
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is the final data packet (less than 512 bytes).
|
||||
///
|
||||
/// RFC 1350 Section 5: "If it is 512 bytes long, the block is not the last
|
||||
/// block of data; if it is from zero to 511 bytes long, it signals the end
|
||||
/// of the transfer."
|
||||
#[must_use]
|
||||
pub const fn is_final_data(&self) -> bool {
|
||||
matches!(self, Self::Data { data, .. } if data.len() < MAX_DATA_SIZE)
|
||||
}
|
||||
|
||||
/// Create a new RRQ packet.
|
||||
#[must_use]
|
||||
pub fn rrq(filename: impl Into<String>, mode: Mode) -> Self {
|
||||
Self::ReadRequest {
|
||||
filename: filename.into(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new WRQ packet.
|
||||
#[must_use]
|
||||
pub fn wrq(filename: impl Into<String>, mode: Mode) -> Self {
|
||||
Self::WriteRequest {
|
||||
filename: filename.into(),
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new DATA packet.
|
||||
#[must_use]
|
||||
pub fn data(block: u16, data: Vec<u8>) -> Self {
|
||||
Self::Data { block, data }
|
||||
}
|
||||
|
||||
/// Create a new ACK packet.
|
||||
#[must_use]
|
||||
pub const fn ack(block: u16) -> Self {
|
||||
Self::Ack { block }
|
||||
}
|
||||
|
||||
/// Create a new ERROR packet.
|
||||
#[must_use]
|
||||
pub fn error(code: ErrorCode, message: impl Into<String>) -> Self {
|
||||
Self::Error {
|
||||
code,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an ERROR packet with default message for the error code.
|
||||
#[must_use]
|
||||
pub fn error_default(code: ErrorCode) -> Self {
|
||||
Self::Error {
|
||||
code,
|
||||
message: code.default_message().to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rrq_roundtrip() {
|
||||
let packet = Packet::rrq("test.txt", Mode::Octet);
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrq_roundtrip() {
|
||||
let packet = Packet::wrq("output.bin", Mode::NetAscii);
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_roundtrip() {
|
||||
let packet = Packet::data(1, vec![1, 2, 3, 4, 5]);
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_empty() {
|
||||
let packet = Packet::data(42, vec![]);
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
assert!(packet.is_final_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_max_size() {
|
||||
let packet = Packet::data(100, vec![0xAB; MAX_DATA_SIZE]);
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
assert!(!packet.is_final_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_too_large() {
|
||||
let packet = Packet::data(1, vec![0; MAX_DATA_SIZE + 1]);
|
||||
let result = packet.serialize();
|
||||
assert!(matches!(result, Err(Error::DataTooLarge { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ack_roundtrip() {
|
||||
let packet = Packet::ack(0);
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
|
||||
let packet = Packet::ack(65535);
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_roundtrip() {
|
||||
let packet = Packet::error(ErrorCode::FileNotFound, "File not found");
|
||||
let bytes = packet.serialize().expect("serialize failed");
|
||||
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||
assert_eq!(packet, parsed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_default_message() {
|
||||
let packet = Packet::error_default(ErrorCode::AccessViolation);
|
||||
assert!(matches!(
|
||||
packet,
|
||||
Packet::Error {
|
||||
code: ErrorCode::AccessViolation,
|
||||
message
|
||||
} if message == "Access violation"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mode_case_insensitive() {
|
||||
assert_eq!(Mode::parse("OCTET").expect("parse"), Mode::Octet);
|
||||
assert_eq!(Mode::parse("Octet").expect("parse"), Mode::Octet);
|
||||
assert_eq!(Mode::parse("netascii").expect("parse"), Mode::NetAscii);
|
||||
assert_eq!(Mode::parse("NETASCII").expect("parse"), Mode::NetAscii);
|
||||
assert_eq!(Mode::parse("NetAscii").expect("parse"), Mode::NetAscii);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_mode() {
|
||||
assert!(matches!(
|
||||
Mode::parse("binary"),
|
||||
Err(Error::InvalidMode(_))
|
||||
));
|
||||
assert!(matches!(Mode::parse("mail"), Err(Error::InvalidMode(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_opcode() {
|
||||
let data = [0x00, 0x06]; // opcode 6 is invalid
|
||||
let result = Packet::parse(&data);
|
||||
assert!(matches!(result, Err(Error::InvalidOpcode(6))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_packet_too_short() {
|
||||
let data = [0x00];
|
||||
let result = Packet::parse(&data);
|
||||
assert!(matches!(result, Err(Error::PacketTooShort { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rrq_parse_raw() {
|
||||
// RRQ for "test.txt" in octet mode
|
||||
let data = [
|
||||
0x00, 0x01, // opcode = RRQ
|
||||
b't', b'e', b's', b't', b'.', b't', b'x', b't', 0x00, // filename
|
||||
b'o', b'c', b't', b'e', b't', 0x00, // mode
|
||||
];
|
||||
let packet = Packet::parse(&data).expect("parse failed");
|
||||
assert!(matches!(
|
||||
packet,
|
||||
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
|
||||
));
|
||||
}
|
||||
}
|
||||
661
crates/pfs-tftp-proto/src/state.rs
Normal file
661
crates/pfs-tftp-proto/src/state.rs
Normal file
@@ -0,0 +1,661 @@
|
||||
//! TFTP Protocol State Machine
|
||||
//!
|
||||
//! This module implements sans-io state machines for both client and server
|
||||
//! sides of the TFTP protocol.
|
||||
//!
|
||||
//! # RFC 1350 Protocol Overview
|
||||
//!
|
||||
//! The protocol follows a lock-step acknowledgment pattern:
|
||||
//!
|
||||
//! ## Read Transfer (Client reading from Server)
|
||||
//! ```text
|
||||
//! Client -> Server: RRQ (filename, mode)
|
||||
//! Server -> Client: DATA (block 1)
|
||||
//! Client -> Server: ACK (block 1)
|
||||
//! Server -> Client: DATA (block 2)
|
||||
//! ... (continues until final DATA with < 512 bytes)
|
||||
//! ```
|
||||
//!
|
||||
//! ## Write Transfer (Client writing to Server)
|
||||
//! ```text
|
||||
//! Client -> Server: WRQ (filename, mode)
|
||||
//! Server -> Client: ACK (block 0)
|
||||
//! Client -> Server: DATA (block 1)
|
||||
//! Server -> Client: ACK (block 1)
|
||||
//! ... (continues until final DATA with < 512 bytes)
|
||||
//! ```
|
||||
|
||||
use crate::{Error, Mode, Packet, Result, MAX_DATA_SIZE};
|
||||
|
||||
/// Direction of data transfer.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TransferDirection {
|
||||
/// Reading data from remote (client downloads, server uploads)
|
||||
Read,
|
||||
/// Writing data to remote (client uploads, server downloads)
|
||||
Write,
|
||||
}
|
||||
|
||||
/// Events emitted by the state machine.
|
||||
///
|
||||
/// These events tell the I/O layer what action to take next.
|
||||
/// Errors are returned via `Result` from the state machine methods,
|
||||
/// not as events.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum Event {
|
||||
/// Send this packet to the remote peer.
|
||||
Send(Packet),
|
||||
|
||||
/// Request data from the application for the given block number.
|
||||
///
|
||||
/// The application should provide up to 512 bytes of data.
|
||||
/// A response with less than 512 bytes indicates the final block.
|
||||
NeedData { block: u16 },
|
||||
|
||||
/// Received data from remote peer.
|
||||
///
|
||||
/// The application should store this data.
|
||||
ReceivedData { block: u16, data: Vec<u8> },
|
||||
|
||||
/// Transfer completed successfully.
|
||||
Complete,
|
||||
}
|
||||
|
||||
/// Client-side protocol state machine.
|
||||
///
|
||||
/// Handles the client side of TFTP transfers (both read and write).
|
||||
#[derive(Debug)]
|
||||
pub struct ClientState {
|
||||
/// The direction of the transfer
|
||||
direction: TransferDirection,
|
||||
/// Transfer mode
|
||||
mode: Mode,
|
||||
/// Current block number we're expecting/sending
|
||||
current_block: u16,
|
||||
/// Whether the transfer is complete
|
||||
complete: bool,
|
||||
/// Whether we've received the initial response
|
||||
got_initial_response: bool,
|
||||
}
|
||||
|
||||
impl ClientState {
|
||||
/// Create a new client state for a read request.
|
||||
///
|
||||
/// Call `start()` to get the initial RRQ packet to send.
|
||||
#[must_use]
|
||||
pub const fn new_read(mode: Mode) -> Self {
|
||||
Self {
|
||||
direction: TransferDirection::Read,
|
||||
mode,
|
||||
current_block: 1, // Expecting DATA block 1
|
||||
complete: false,
|
||||
got_initial_response: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client state for a write request.
|
||||
///
|
||||
/// Call `start()` to get the initial WRQ packet to send.
|
||||
#[must_use]
|
||||
pub const fn new_write(mode: Mode) -> Self {
|
||||
Self {
|
||||
direction: TransferDirection::Write,
|
||||
mode,
|
||||
current_block: 0, // Expecting ACK block 0
|
||||
complete: false,
|
||||
got_initial_response: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the initial request packet to send.
|
||||
#[must_use]
|
||||
pub fn start(&self, filename: &str) -> Packet {
|
||||
match self.direction {
|
||||
TransferDirection::Read => Packet::rrq(filename, self.mode),
|
||||
TransferDirection::Write => Packet::wrq(filename, self.mode),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the transfer is complete.
|
||||
#[must_use]
|
||||
pub const fn is_complete(&self) -> bool {
|
||||
self.complete
|
||||
}
|
||||
|
||||
/// Get the current expected block number.
|
||||
#[must_use]
|
||||
pub const fn current_block(&self) -> u16 {
|
||||
self.current_block
|
||||
}
|
||||
|
||||
/// Process a received packet and return events.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the packet is invalid for the current state.
|
||||
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
|
||||
if self.complete {
|
||||
return Err(Error::StateError("transfer already complete".to_string()));
|
||||
}
|
||||
|
||||
match (self.direction, packet) {
|
||||
// Reading: expect DATA packets
|
||||
(TransferDirection::Read, Packet::Data { block, data }) => {
|
||||
self.handle_read_data(*block, data)
|
||||
}
|
||||
|
||||
// Writing: expect ACK packets
|
||||
(TransferDirection::Write, Packet::Ack { block }) => self.handle_write_ack(*block),
|
||||
|
||||
// Error packet terminates transfer
|
||||
(_, Packet::Error { code, message }) => {
|
||||
self.complete = true;
|
||||
Err(Error::TftpError {
|
||||
code: *code,
|
||||
message: message.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
// Unexpected packet type
|
||||
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
|
||||
expected: "DATA",
|
||||
actual: other.type_name(),
|
||||
}),
|
||||
|
||||
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
|
||||
expected: "ACK",
|
||||
actual: other.type_name(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a DATA packet during a read transfer.
|
||||
fn handle_read_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
|
||||
self.got_initial_response = true;
|
||||
|
||||
// RFC 1350: Block numbers start at 1 and increase sequentially
|
||||
if block != self.current_block {
|
||||
// Could be a duplicate - if it's the previous block, re-send ACK
|
||||
if block == self.current_block.wrapping_sub(1) {
|
||||
return Ok(vec![Event::Send(Packet::ack(block))]);
|
||||
}
|
||||
return Err(Error::BlockMismatch {
|
||||
expected: self.current_block,
|
||||
actual: block,
|
||||
});
|
||||
}
|
||||
|
||||
let is_final = data.len() < MAX_DATA_SIZE;
|
||||
let mut events = vec![
|
||||
Event::ReceivedData {
|
||||
block,
|
||||
data: data.to_vec(),
|
||||
},
|
||||
Event::Send(Packet::ack(block)),
|
||||
];
|
||||
|
||||
if is_final {
|
||||
self.complete = true;
|
||||
events.push(Event::Complete);
|
||||
} else {
|
||||
self.current_block = self.current_block.wrapping_add(1);
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
/// Handle an ACK packet during a write transfer.
|
||||
fn handle_write_ack(&mut self, block: u16) -> Result<Vec<Event>> {
|
||||
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet having
|
||||
// a block number of zero."
|
||||
if !self.got_initial_response {
|
||||
if block != 0 {
|
||||
return Err(Error::BlockMismatch {
|
||||
expected: 0,
|
||||
actual: block,
|
||||
});
|
||||
}
|
||||
self.got_initial_response = true;
|
||||
self.current_block = 1;
|
||||
return Ok(vec![Event::NeedData { block: 1 }]);
|
||||
}
|
||||
|
||||
if block != self.current_block {
|
||||
// Duplicate ACK - ignore
|
||||
if block == self.current_block.wrapping_sub(1) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
return Err(Error::BlockMismatch {
|
||||
expected: self.current_block,
|
||||
actual: block,
|
||||
});
|
||||
}
|
||||
|
||||
// Request next block of data
|
||||
self.current_block = self.current_block.wrapping_add(1);
|
||||
Ok(vec![Event::NeedData {
|
||||
block: self.current_block,
|
||||
}])
|
||||
}
|
||||
|
||||
/// Provide data for a write transfer.
|
||||
///
|
||||
/// Call this in response to a `NeedData` event.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `block` - The block number this data is for
|
||||
/// * `data` - The data to send (must be <= 512 bytes)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns the DATA packet to send, or marks the transfer complete if
|
||||
/// this was the final block.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the data is too large or the block number is wrong.
|
||||
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
|
||||
if data.len() > MAX_DATA_SIZE {
|
||||
return Err(Error::DataTooLarge { size: data.len() });
|
||||
}
|
||||
|
||||
if block != self.current_block {
|
||||
return Err(Error::BlockMismatch {
|
||||
expected: self.current_block,
|
||||
actual: block,
|
||||
});
|
||||
}
|
||||
|
||||
let is_final = data.len() < MAX_DATA_SIZE;
|
||||
if is_final {
|
||||
self.complete = true;
|
||||
}
|
||||
|
||||
Ok(Event::Send(Packet::data(block, data)))
|
||||
}
|
||||
|
||||
/// Mark the transfer as complete after the final ACK is received.
|
||||
pub fn finish(&mut self) -> Event {
|
||||
self.complete = true;
|
||||
Event::Complete
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-side protocol state machine.
|
||||
///
|
||||
/// Handles the server side of TFTP transfers.
|
||||
#[derive(Debug)]
|
||||
pub struct ServerState {
|
||||
/// The direction of the transfer (from server's perspective)
|
||||
direction: TransferDirection,
|
||||
/// Transfer mode
|
||||
mode: Mode,
|
||||
/// The filename being transferred
|
||||
filename: String,
|
||||
/// Current block number
|
||||
current_block: u16,
|
||||
/// Whether the transfer is complete
|
||||
complete: bool,
|
||||
/// Whether we've sent the first response
|
||||
sent_initial_response: bool,
|
||||
/// Track if we're waiting for final ACK
|
||||
waiting_for_final_ack: bool,
|
||||
}
|
||||
|
||||
impl ServerState {
|
||||
/// Create a new server state from a request packet.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the packet is not an RRQ or WRQ.
|
||||
pub fn from_request(packet: &Packet) -> Result<Self> {
|
||||
match packet {
|
||||
Packet::ReadRequest { filename, mode } => Ok(Self {
|
||||
// Server uploads when client reads
|
||||
direction: TransferDirection::Write,
|
||||
mode: *mode,
|
||||
filename: filename.clone(),
|
||||
current_block: 1,
|
||||
complete: false,
|
||||
sent_initial_response: false,
|
||||
waiting_for_final_ack: false,
|
||||
}),
|
||||
Packet::WriteRequest { filename, mode } => Ok(Self {
|
||||
// Server downloads when client writes
|
||||
direction: TransferDirection::Read,
|
||||
mode: *mode,
|
||||
filename: filename.clone(),
|
||||
current_block: 0,
|
||||
complete: false,
|
||||
sent_initial_response: false,
|
||||
waiting_for_final_ack: false,
|
||||
}),
|
||||
other => Err(Error::UnexpectedPacket {
|
||||
expected: "RRQ or WRQ",
|
||||
actual: other.type_name(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the filename being transferred.
|
||||
#[must_use]
|
||||
pub fn filename(&self) -> &str {
|
||||
&self.filename
|
||||
}
|
||||
|
||||
/// Get the transfer mode.
|
||||
#[must_use]
|
||||
pub const fn mode(&self) -> Mode {
|
||||
self.mode
|
||||
}
|
||||
|
||||
/// Get the transfer direction (from server's perspective).
|
||||
#[must_use]
|
||||
pub const fn direction(&self) -> TransferDirection {
|
||||
self.direction
|
||||
}
|
||||
|
||||
/// Check if the transfer is complete.
|
||||
#[must_use]
|
||||
pub const fn is_complete(&self) -> bool {
|
||||
self.complete
|
||||
}
|
||||
|
||||
/// Get the initial response event.
|
||||
///
|
||||
/// For RRQ: Returns `NeedData` to request the first block
|
||||
/// For WRQ: Returns `Send(ACK 0)` to acknowledge the write request
|
||||
pub fn start(&mut self) -> Event {
|
||||
self.sent_initial_response = true;
|
||||
|
||||
match self.direction {
|
||||
// Server sending data (client RRQ)
|
||||
TransferDirection::Write => Event::NeedData { block: 1 },
|
||||
|
||||
// Server receiving data (client WRQ)
|
||||
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet
|
||||
// having a block number of zero."
|
||||
TransferDirection::Read => Event::Send(Packet::ack(0)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a received packet and return events.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the packet is invalid for the current state.
|
||||
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
|
||||
if self.complete {
|
||||
return Err(Error::StateError("transfer already complete".to_string()));
|
||||
}
|
||||
|
||||
match (self.direction, packet) {
|
||||
// Server sending data, receiving ACKs
|
||||
(TransferDirection::Write, Packet::Ack { block }) => self.handle_send_ack(*block),
|
||||
|
||||
// Server receiving data
|
||||
(TransferDirection::Read, Packet::Data { block, data }) => {
|
||||
self.handle_receive_data(*block, data)
|
||||
}
|
||||
|
||||
// Error terminates transfer
|
||||
(_, Packet::Error { code, message }) => {
|
||||
self.complete = true;
|
||||
Err(Error::TftpError {
|
||||
code: *code,
|
||||
message: message.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
// Unexpected packet
|
||||
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
|
||||
expected: "ACK",
|
||||
actual: other.type_name(),
|
||||
}),
|
||||
|
||||
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
|
||||
expected: "DATA",
|
||||
actual: other.type_name(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an ACK when server is sending data.
|
||||
fn handle_send_ack(&mut self, block: u16) -> Result<Vec<Event>> {
|
||||
if block != self.current_block {
|
||||
// Duplicate ACK - could retransmit, but we'll just ignore
|
||||
if block == self.current_block.wrapping_sub(1) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
return Err(Error::BlockMismatch {
|
||||
expected: self.current_block,
|
||||
actual: block,
|
||||
});
|
||||
}
|
||||
|
||||
// If we were waiting for the final ACK, we're done
|
||||
if self.waiting_for_final_ack {
|
||||
self.complete = true;
|
||||
return Ok(vec![Event::Complete]);
|
||||
}
|
||||
|
||||
// Request next block
|
||||
self.current_block = self.current_block.wrapping_add(1);
|
||||
Ok(vec![Event::NeedData {
|
||||
block: self.current_block,
|
||||
}])
|
||||
}
|
||||
|
||||
/// Handle a DATA packet when server is receiving data.
|
||||
fn handle_receive_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
|
||||
let expected_block = self.current_block.wrapping_add(1);
|
||||
|
||||
if block != expected_block {
|
||||
// Duplicate - re-ACK
|
||||
if block == self.current_block {
|
||||
return Ok(vec![Event::Send(Packet::ack(block))]);
|
||||
}
|
||||
return Err(Error::BlockMismatch {
|
||||
expected: expected_block,
|
||||
actual: block,
|
||||
});
|
||||
}
|
||||
|
||||
self.current_block = block;
|
||||
let is_final = data.len() < MAX_DATA_SIZE;
|
||||
|
||||
let mut events = vec![
|
||||
Event::ReceivedData {
|
||||
block,
|
||||
data: data.to_vec(),
|
||||
},
|
||||
Event::Send(Packet::ack(block)),
|
||||
];
|
||||
|
||||
if is_final {
|
||||
self.complete = true;
|
||||
events.push(Event::Complete);
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
/// Provide data for a read transfer (server sending to client).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the data is too large.
|
||||
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
|
||||
if data.len() > MAX_DATA_SIZE {
|
||||
return Err(Error::DataTooLarge { size: data.len() });
|
||||
}
|
||||
|
||||
if block != self.current_block {
|
||||
return Err(Error::BlockMismatch {
|
||||
expected: self.current_block,
|
||||
actual: block,
|
||||
});
|
||||
}
|
||||
|
||||
let is_final = data.len() < MAX_DATA_SIZE;
|
||||
if is_final {
|
||||
self.waiting_for_final_ack = true;
|
||||
}
|
||||
|
||||
Ok(Event::Send(Packet::data(block, data)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ErrorCode;
|
||||
|
||||
#[test]
|
||||
fn test_client_read_single_block() {
|
||||
let mut client = ClientState::new_read(Mode::Octet);
|
||||
let request = client.start("test.txt");
|
||||
assert!(matches!(
|
||||
request,
|
||||
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
|
||||
));
|
||||
|
||||
// Receive single DATA block (< 512 bytes = final)
|
||||
let events = client
|
||||
.receive(&Packet::data(1, vec![1, 2, 3]))
|
||||
.expect("receive failed");
|
||||
|
||||
assert_eq!(events.len(), 3);
|
||||
assert!(matches!(&events[0], Event::ReceivedData { block: 1, data } if data == &[1, 2, 3]));
|
||||
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
|
||||
assert!(matches!(&events[2], Event::Complete));
|
||||
assert!(client.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_read_multiple_blocks() {
|
||||
let mut client = ClientState::new_read(Mode::Octet);
|
||||
let _ = client.start("test.txt");
|
||||
|
||||
// Full block (512 bytes)
|
||||
let full_data = vec![0u8; MAX_DATA_SIZE];
|
||||
let events = client
|
||||
.receive(&Packet::data(1, full_data.clone()))
|
||||
.expect("receive");
|
||||
assert_eq!(events.len(), 2);
|
||||
assert!(!client.is_complete());
|
||||
assert_eq!(client.current_block(), 2);
|
||||
|
||||
// Another full block
|
||||
let events = client
|
||||
.receive(&Packet::data(2, full_data))
|
||||
.expect("receive");
|
||||
assert_eq!(events.len(), 2);
|
||||
assert!(!client.is_complete());
|
||||
|
||||
// Final block (< 512 bytes)
|
||||
let events = client
|
||||
.receive(&Packet::data(3, vec![1, 2, 3]))
|
||||
.expect("receive");
|
||||
assert_eq!(events.len(), 3);
|
||||
assert!(client.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_client_write_flow() {
|
||||
let mut client = ClientState::new_write(Mode::Octet);
|
||||
let request = client.start("output.txt");
|
||||
assert!(matches!(request, Packet::WriteRequest { .. }));
|
||||
|
||||
// Receive ACK 0 (write request acknowledged)
|
||||
let events = client.receive(&Packet::ack(0)).expect("receive");
|
||||
assert_eq!(events.len(), 1);
|
||||
assert!(matches!(events[0], Event::NeedData { block: 1 }));
|
||||
|
||||
// Provide data for block 1
|
||||
let send_event = client.provide_data(1, vec![1, 2, 3]).expect("provide_data");
|
||||
assert!(matches!(
|
||||
send_event,
|
||||
Event::Send(Packet::Data { block: 1, .. })
|
||||
));
|
||||
assert!(client.is_complete()); // < 512 bytes = final
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_rrq() {
|
||||
let request = Packet::rrq("file.txt", Mode::Octet);
|
||||
let mut server = ServerState::from_request(&request).expect("from_request");
|
||||
|
||||
assert_eq!(server.filename(), "file.txt");
|
||||
assert_eq!(server.direction(), TransferDirection::Write);
|
||||
|
||||
// Start should request data for block 1
|
||||
let event = server.start();
|
||||
assert!(matches!(event, Event::NeedData { block: 1 }));
|
||||
|
||||
// Provide data
|
||||
let event = server.provide_data(1, vec![1, 2, 3]).expect("provide_data");
|
||||
assert!(matches!(event, Event::Send(Packet::Data { block: 1, .. })));
|
||||
|
||||
// Receive ACK
|
||||
let events = server.receive(&Packet::ack(1)).expect("receive");
|
||||
assert!(matches!(&events[0], Event::Complete));
|
||||
assert!(server.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_server_wrq() {
|
||||
let request = Packet::wrq("output.txt", Mode::Octet);
|
||||
let mut server = ServerState::from_request(&request).expect("from_request");
|
||||
|
||||
assert_eq!(server.direction(), TransferDirection::Read);
|
||||
|
||||
// Start should send ACK 0
|
||||
let event = server.start();
|
||||
assert!(matches!(event, Event::Send(Packet::Ack { block: 0 })));
|
||||
|
||||
// Receive DATA block 1
|
||||
let events = server
|
||||
.receive(&Packet::data(1, vec![1, 2, 3]))
|
||||
.expect("receive");
|
||||
assert_eq!(events.len(), 3);
|
||||
assert!(matches!(&events[0], Event::ReceivedData { block: 1, .. }));
|
||||
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
|
||||
assert!(matches!(&events[2], Event::Complete));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_terminates_transfer() {
|
||||
let mut client = ClientState::new_read(Mode::Octet);
|
||||
let _ = client.start("test.txt");
|
||||
|
||||
let result = client.receive(&Packet::error(ErrorCode::FileNotFound, "File not found"));
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(client.is_complete());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_ack_ignored() {
|
||||
let request = Packet::rrq("file.txt", Mode::Octet);
|
||||
let mut server = ServerState::from_request(&request).expect("from_request");
|
||||
let _ = server.start();
|
||||
|
||||
// Provide and send block 1
|
||||
let _ = server
|
||||
.provide_data(1, vec![0u8; MAX_DATA_SIZE])
|
||||
.expect("provide_data");
|
||||
|
||||
// Receive ACK 1
|
||||
let events = server.receive(&Packet::ack(1)).expect("receive");
|
||||
assert!(!events.is_empty());
|
||||
|
||||
// Provide block 2
|
||||
let _ = server.provide_data(2, vec![1, 2, 3]).expect("provide_data");
|
||||
|
||||
// Duplicate ACK 1 should be ignored
|
||||
let events = server.receive(&Packet::ack(1)).expect("receive");
|
||||
assert!(events.is_empty());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user