diff --git a/Cargo.lock b/Cargo.lock index d40631c..e465185 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5,3 +5,69 @@ version = 4 [[package]] name = "pfs-tftp" version = "0.1.0" +dependencies = [ + "pfs-tftp-proto", + "thiserror", +] + +[[package]] +name = "pfs-tftp-proto" +version = "0.1.0" +dependencies = [ + "thiserror", +] + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "syn" +version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" diff --git a/Cargo.toml b/Cargo.toml index da8351b..4552b9d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,17 @@ -[package] -name = "pfs-tftp" +[workspace] +resolver = "3" +members = ["crates/*"] + +[workspace.package] version = "0.1.0" edition = "2024" +license = "MIT" +repository = "https://github.com/pfs/pfs-tftp" -[lints.rust] +[workspace.lints.rust] unsafe_code = "forbid" -[lints.clippy] +[workspace.lints.clippy] pedantic = { level = "warn", priority = -1 } todo = "warn" unwrap_used = "warn" - -[dependencies] diff --git a/crates/pfs-tftp-proto/Cargo.toml b/crates/pfs-tftp-proto/Cargo.toml new file mode 100644 index 0000000..f5a81fd --- /dev/null +++ b/crates/pfs-tftp-proto/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "pfs-tftp-proto" +description = "Sans-IO TFTP protocol implementation (RFC 1350)" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lints] +workspace = true + +[dependencies] +thiserror = "2" + +[dev-dependencies] diff --git a/crates/pfs-tftp-proto/src/error.rs b/crates/pfs-tftp-proto/src/error.rs new file mode 100644 index 0000000..90aae70 --- /dev/null +++ b/crates/pfs-tftp-proto/src/error.rs @@ -0,0 +1,193 @@ +//! TFTP Error Types +//! +//! Defines error codes and error handling for TFTP protocol. +//! +//! # RFC 1350 Error Codes (Appendix) +//! +//! ```text +//! Value Meaning +//! 0 Not defined, see error message (if any). +//! 1 File not found. +//! 2 Access violation. +//! 3 Disk full or allocation exceeded. +//! 4 Illegal TFTP operation. +//! 5 Unknown transfer ID. +//! 6 File already exists. +//! 7 No such user. +//! ``` + +use thiserror::Error; + +/// TFTP error codes as defined in RFC 1350 Appendix. +/// +/// These codes are sent in ERROR packets to indicate the nature of the error. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u16)] +pub enum ErrorCode { + /// Not defined, see error message (if any). + /// Used when the error doesn't fit other categories. + NotDefined = 0, + + /// File not found. + /// The requested file does not exist on the server. + FileNotFound = 1, + + /// Access violation. + /// The client does not have permission to access the file. + AccessViolation = 2, + + /// Disk full or allocation exceeded. + /// There is insufficient storage space to complete the transfer. + DiskFull = 3, + + /// Illegal TFTP operation. + /// The requested operation is not valid (e.g., malformed packet). + IllegalOperation = 4, + + /// Unknown transfer ID. + /// RFC 1350 Section 4: "If a source TID does not match, the packet should + /// be discarded as erroneously sent from somewhere else. An error packet + /// should be sent to the source of the incorrect packet." + UnknownTransferId = 5, + + /// File already exists. + /// Returned when attempting to write a file that already exists + /// and the server policy doesn't allow overwriting. + FileAlreadyExists = 6, + + /// No such user. + /// The specified user does not exist (relevant for mail mode, which is obsolete). + NoSuchUser = 7, +} + +impl ErrorCode { + /// Create an `ErrorCode` from a raw u16 value. + /// + /// Returns `NotDefined` for any unknown error code, as per RFC 1350 + /// which states error code 0 is "Not defined". + #[must_use] + pub const fn from_u16(value: u16) -> Self { + match value { + 1 => Self::FileNotFound, + 2 => Self::AccessViolation, + 3 => Self::DiskFull, + 4 => Self::IllegalOperation, + 5 => Self::UnknownTransferId, + 6 => Self::FileAlreadyExists, + 7 => Self::NoSuchUser, + // RFC 1350: Error code 0 is "Not defined", also used as fallback + _ => Self::NotDefined, + } + } + + /// Get the default error message for this error code. + #[must_use] + pub const fn default_message(&self) -> &'static str { + match self { + Self::NotDefined => "Not defined", + Self::FileNotFound => "File not found", + Self::AccessViolation => "Access violation", + Self::DiskFull => "Disk full or allocation exceeded", + Self::IllegalOperation => "Illegal TFTP operation", + Self::UnknownTransferId => "Unknown transfer ID", + Self::FileAlreadyExists => "File already exists", + Self::NoSuchUser => "No such user", + } + } +} + +impl From for u16 { + fn from(code: ErrorCode) -> Self { + code as Self + } +} + +/// Protocol-level errors for TFTP operations. +#[derive(Debug, Error)] +pub enum Error { + /// Packet is too short to be valid. + #[error("packet too short: expected at least {expected} bytes, got {actual}")] + PacketTooShort { expected: usize, actual: usize }, + + /// Invalid opcode in packet. + #[error("invalid opcode: {0}")] + InvalidOpcode(u16), + + /// Invalid transfer mode in RRQ/WRQ packet. + #[error("invalid mode: {0}")] + InvalidMode(String), + + /// Missing null terminator in string field. + #[error("missing null terminator in {field}")] + MissingNullTerminator { field: &'static str }, + + /// String field contains invalid UTF-8. + #[error("invalid UTF-8 in {field}: {source}")] + InvalidUtf8 { + field: &'static str, + #[source] + source: std::str::Utf8Error, + }, + + /// Data exceeds maximum block size. + #[error("data too large: {size} bytes exceeds maximum of 512")] + DataTooLarge { size: usize }, + + /// Buffer too small for serialization. + #[error("buffer too small: need {needed} bytes, have {available}")] + BufferTooSmall { needed: usize, available: usize }, + + /// TFTP error received from remote peer. + #[error("TFTP error {code:?}: {message}")] + TftpError { code: ErrorCode, message: String }, + + /// Protocol state error. + #[error("protocol state error: {0}")] + StateError(String), + + /// Unexpected packet type received. + #[error("unexpected packet type: expected {expected}, got {actual}")] + UnexpectedPacket { + expected: &'static str, + actual: &'static str, + }, + + /// Block number mismatch. + #[error("block number mismatch: expected {expected}, got {actual}")] + BlockMismatch { expected: u16, actual: u16 }, +} + +/// Result type for TFTP operations. +pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_code_from_u16() { + assert_eq!(ErrorCode::from_u16(0), ErrorCode::NotDefined); + assert_eq!(ErrorCode::from_u16(1), ErrorCode::FileNotFound); + assert_eq!(ErrorCode::from_u16(2), ErrorCode::AccessViolation); + assert_eq!(ErrorCode::from_u16(3), ErrorCode::DiskFull); + assert_eq!(ErrorCode::from_u16(4), ErrorCode::IllegalOperation); + assert_eq!(ErrorCode::from_u16(5), ErrorCode::UnknownTransferId); + assert_eq!(ErrorCode::from_u16(6), ErrorCode::FileAlreadyExists); + assert_eq!(ErrorCode::from_u16(7), ErrorCode::NoSuchUser); + // Unknown codes should map to NotDefined + assert_eq!(ErrorCode::from_u16(8), ErrorCode::NotDefined); + assert_eq!(ErrorCode::from_u16(255), ErrorCode::NotDefined); + } + + #[test] + fn test_error_code_to_u16() { + assert_eq!(u16::from(ErrorCode::NotDefined), 0); + assert_eq!(u16::from(ErrorCode::FileNotFound), 1); + assert_eq!(u16::from(ErrorCode::AccessViolation), 2); + assert_eq!(u16::from(ErrorCode::DiskFull), 3); + assert_eq!(u16::from(ErrorCode::IllegalOperation), 4); + assert_eq!(u16::from(ErrorCode::UnknownTransferId), 5); + assert_eq!(u16::from(ErrorCode::FileAlreadyExists), 6); + assert_eq!(u16::from(ErrorCode::NoSuchUser), 7); + } +} diff --git a/crates/pfs-tftp-proto/src/lib.rs b/crates/pfs-tftp-proto/src/lib.rs new file mode 100644 index 0000000..50f0235 --- /dev/null +++ b/crates/pfs-tftp-proto/src/lib.rs @@ -0,0 +1,21 @@ +//! Sans-IO TFTP Protocol Implementation (RFC 1350) +//! +//! This crate provides a pure protocol implementation without any I/O operations. +//! It handles packet parsing, serialization, and protocol state management. +//! +//! # RFC 1350 Overview +//! +//! TFTP (Trivial File Transfer Protocol) is a simple file transfer protocol +//! that operates over UDP. Key characteristics: +//! - Uses fixed 512-byte data blocks +//! - Stop-and-wait protocol (each packet must be acknowledged) +//! - Five packet types: RRQ, WRQ, DATA, ACK, ERROR +//! - Default port 69 for initial server connection + +mod error; +mod packet; +mod state; + +pub use error::{Error, ErrorCode, Result}; +pub use packet::{Mode, Opcode, Packet, MAX_DATA_SIZE, TFTP_PORT}; +pub use state::{ClientState, Event, ServerState, TransferDirection}; diff --git a/crates/pfs-tftp-proto/src/packet.rs b/crates/pfs-tftp-proto/src/packet.rs new file mode 100644 index 0000000..73ad7ed --- /dev/null +++ b/crates/pfs-tftp-proto/src/packet.rs @@ -0,0 +1,632 @@ +//! TFTP Packet Types and Serialization +//! +//! This module implements all five TFTP packet types as defined in RFC 1350: +//! - RRQ (Read Request) +//! - WRQ (Write Request) +//! - DATA +//! - ACK +//! - ERROR +//! +//! # RFC 1350 Section 5 - TFTP Packets +//! +//! ```text +//! TFTP supports five types of packets: +//! +//! opcode operation +//! 1 Read request (RRQ) +//! 2 Write request (WRQ) +//! 3 Data (DATA) +//! 4 Acknowledgment (ACK) +//! 5 Error (ERROR) +//! ``` + +use crate::{Error, ErrorCode, Result}; + +/// Standard TFTP server port as defined in RFC 1350. +/// +/// RFC 1350 Section 4: "A requesting host chooses its source TID as described +/// above, and sends its initial request to the known TID 69 decimal (105 octal) +/// on the serving host." +pub const TFTP_PORT: u16 = 69; + +/// Maximum data size in a DATA packet. +/// +/// RFC 1350 Section 2: "the connection is opened and the file is sent in fixed +/// length blocks of 512 bytes." +/// +/// RFC 1350 Section 5: "The data field is from zero to 512 bytes long." +pub const MAX_DATA_SIZE: usize = 512; + +/// Minimum packet size (just an opcode). +const MIN_PACKET_SIZE: usize = 2; + +/// TFTP packet opcodes. +/// +/// RFC 1350 Section 5: "The TFTP header of a packet contains the opcode +/// associated with that packet." +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u16)] +pub enum Opcode { + /// Read request (RRQ) + Rrq = 1, + /// Write request (WRQ) + Wrq = 2, + /// Data packet + Data = 3, + /// Acknowledgment + Ack = 4, + /// Error + Error = 5, +} + +impl Opcode { + /// Create an `Opcode` from a raw u16 value. + /// + /// # Errors + /// + /// Returns `Error::InvalidOpcode` if the value is not a valid opcode. + pub const fn from_u16(value: u16) -> Result { + match value { + 1 => Ok(Self::Rrq), + 2 => Ok(Self::Wrq), + 3 => Ok(Self::Data), + 4 => Ok(Self::Ack), + 5 => Ok(Self::Error), + _ => Err(Error::InvalidOpcode(value)), + } + } +} + +impl From for u16 { + fn from(opcode: Opcode) -> Self { + opcode as Self + } +} + +/// Transfer mode for TFTP operations. +/// +/// RFC 1350 Section 1: "Three modes of transfer are currently supported: +/// netascii [...]; octet [...]; mail [...]. (The mail mode is obsolete +/// and should not be implemented or used.)" +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum Mode { + /// `NetASCII` mode - text transfer with CR/LF line endings. + /// + /// RFC 1350: "netascii (This is ascii as defined in 'USA Standard Code + /// for Information Interchange' with the modifications specified in + /// 'Telnet Protocol Specification'.) Note that it is 8 bit ascii." + NetAscii, + + /// Octet mode - raw binary transfer. + /// + /// RFC 1350: "octet (This replaces the 'binary' mode of previous versions + /// of this document.) raw 8 bit bytes [...] If a host receives a octet + /// file and then returns it, the returned file must be identical to the + /// original." + #[default] + Octet, +} + +impl Mode { + /// Parse a mode string (case-insensitive). + /// + /// RFC 1350 Section 5: "The mode field contains the string 'netascii', + /// 'octet', or 'mail' (or any combination of upper and lower case, + /// such as 'NETASCII', '`NetAscii`', etc.)" + /// + /// # Errors + /// + /// Returns `Error::InvalidMode` if the string is not a recognized mode. + pub fn parse(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "netascii" => Ok(Self::NetAscii), + "octet" => Ok(Self::Octet), + // Mail mode is obsolete per RFC 1350 + "mail" => Err(Error::InvalidMode( + "mail mode is obsolete and not supported".to_string(), + )), + _ => Err(Error::InvalidMode(s.to_string())), + } + } + + /// Get the string representation of the mode. + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + Self::NetAscii => "netascii", + Self::Octet => "octet", + } + } +} + +/// A TFTP packet. +/// +/// All packet formats are defined in RFC 1350 Section 5 and Appendix I. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Packet { + /// Read Request (RRQ) - opcode 1 + /// + /// RFC 1350 Figure 5-1: + /// ```text + /// 2 bytes string 1 byte string 1 byte + /// ------------------------------------------------ + /// | Opcode | Filename | 0 | Mode | 0 | + /// ------------------------------------------------ + /// ``` + ReadRequest { filename: String, mode: Mode }, + + /// Write Request (WRQ) - opcode 2 + /// + /// Same format as RRQ (RFC 1350 Figure 5-1). + WriteRequest { filename: String, mode: Mode }, + + /// Data packet - opcode 3 + /// + /// RFC 1350 Figure 5-2: + /// ```text + /// 2 bytes 2 bytes n bytes + /// ---------------------------------- + /// | Opcode | Block # | Data | + /// ---------------------------------- + /// ``` + /// + /// RFC 1350 Section 5: "The block numbers on data packets begin with one + /// and increase by one for each new block of data. [...] The data field + /// is from zero to 512 bytes long. If it is 512 bytes long, the block is + /// not the last block of data; if it is from zero to 511 bytes long, it + /// signals the end of the transfer." + Data { block: u16, data: Vec }, + + /// Acknowledgment packet - opcode 4 + /// + /// RFC 1350 Figure 5-3: + /// ```text + /// 2 bytes 2 bytes + /// --------------------- + /// | Opcode | Block # | + /// --------------------- + /// ``` + /// + /// RFC 1350 Section 5: "The block number in an ACK echoes the block number + /// of the DATA packet being acknowledged. A WRQ is acknowledged with an + /// ACK packet having a block number of zero." + Ack { block: u16 }, + + /// Error packet - opcode 5 + /// + /// RFC 1350 Figure 5-4: + /// ```text + /// 2 bytes 2 bytes string 1 byte + /// ----------------------------------------- + /// | Opcode | ErrorCode | ErrMsg | 0 | + /// ----------------------------------------- + /// ``` + /// + /// RFC 1350 Section 5: "An ERROR packet can be the acknowledgment of any + /// other type of packet. The error code is an integer indicating the nature + /// of the error. [...] The error message is intended for human consumption, + /// and should be in netascii." + Error { code: ErrorCode, message: String }, +} + +impl Packet { + /// Parse a TFTP packet from raw bytes. + /// + /// # Errors + /// + /// Returns an error if the packet is malformed. + pub fn parse(data: &[u8]) -> Result { + if data.len() < MIN_PACKET_SIZE { + return Err(Error::PacketTooShort { + expected: MIN_PACKET_SIZE, + actual: data.len(), + }); + } + + let opcode = u16::from_be_bytes([data[0], data[1]]); + let opcode = Opcode::from_u16(opcode)?; + let payload = &data[2..]; + + match opcode { + Opcode::Rrq => Self::parse_request(payload, false), + Opcode::Wrq => Self::parse_request(payload, true), + Opcode::Data => Self::parse_data(payload), + Opcode::Ack => Self::parse_ack(payload), + Opcode::Error => Self::parse_error(payload), + } + } + + /// Parse a RRQ or WRQ packet payload. + fn parse_request(payload: &[u8], is_write: bool) -> Result { + // Find the null terminator for filename + let filename_end = payload + .iter() + .position(|&b| b == 0) + .ok_or(Error::MissingNullTerminator { field: "filename" })?; + + let filename = std::str::from_utf8(&payload[..filename_end]).map_err(|e| { + Error::InvalidUtf8 { + field: "filename", + source: e, + } + })?; + + // Find the mode string after the filename null terminator + let mode_start = filename_end + 1; + if mode_start >= payload.len() { + return Err(Error::MissingNullTerminator { field: "mode" }); + } + + let mode_end = payload[mode_start..] + .iter() + .position(|&b| b == 0) + .ok_or(Error::MissingNullTerminator { field: "mode" })? + + mode_start; + + let mode_str = + std::str::from_utf8(&payload[mode_start..mode_end]).map_err(|e| Error::InvalidUtf8 { + field: "mode", + source: e, + })?; + + let mode = Mode::parse(mode_str)?; + + if is_write { + Ok(Self::WriteRequest { + filename: filename.to_string(), + mode, + }) + } else { + Ok(Self::ReadRequest { + filename: filename.to_string(), + mode, + }) + } + } + + /// Parse a DATA packet payload. + fn parse_data(payload: &[u8]) -> Result { + if payload.len() < 2 { + return Err(Error::PacketTooShort { + expected: 4, // 2 opcode + 2 block + actual: payload.len() + 2, + }); + } + + let block = u16::from_be_bytes([payload[0], payload[1]]); + let data = payload[2..].to_vec(); + + if data.len() > MAX_DATA_SIZE { + return Err(Error::DataTooLarge { size: data.len() }); + } + + Ok(Self::Data { block, data }) + } + + /// Parse an ACK packet payload. + fn parse_ack(payload: &[u8]) -> Result { + if payload.len() < 2 { + return Err(Error::PacketTooShort { + expected: 4, // 2 opcode + 2 block + actual: payload.len() + 2, + }); + } + + let block = u16::from_be_bytes([payload[0], payload[1]]); + Ok(Self::Ack { block }) + } + + /// Parse an ERROR packet payload. + fn parse_error(payload: &[u8]) -> Result { + if payload.len() < 2 { + return Err(Error::PacketTooShort { + expected: 4, // 2 opcode + 2 error code + actual: payload.len() + 2, + }); + } + + let code = u16::from_be_bytes([payload[0], payload[1]]); + let code = ErrorCode::from_u16(code); + + // Find the null terminator for error message + let msg_start = 2; + let msg_end = payload[msg_start..] + .iter() + .position(|&b| b == 0) + .map_or(payload.len(), |pos| pos + msg_start); + + let message = std::str::from_utf8(&payload[msg_start..msg_end]) + .map_err(|e| Error::InvalidUtf8 { + field: "error message", + source: e, + })? + .to_string(); + + Ok(Self::Error { code, message }) + } + + /// Serialize the packet into bytes. + /// + /// # Errors + /// + /// Returns an error if the data is too large. + pub fn serialize(&self) -> Result> { + match self { + Self::ReadRequest { filename, mode } => { + Ok(Self::serialize_request(Opcode::Rrq, filename, *mode)) + } + Self::WriteRequest { filename, mode } => { + Ok(Self::serialize_request(Opcode::Wrq, filename, *mode)) + } + Self::Data { block, data } => Self::serialize_data(*block, data), + Self::Ack { block } => Ok(Self::serialize_ack(*block)), + Self::Error { code, message } => Ok(Self::serialize_error(*code, message)), + } + } + + /// Serialize a RRQ or WRQ packet. + fn serialize_request(opcode: Opcode, filename: &str, mode: Mode) -> Vec { + let mode_str = mode.as_str(); + let capacity = 2 + filename.len() + 1 + mode_str.len() + 1; + let mut buf = Vec::with_capacity(capacity); + + buf.extend_from_slice(&u16::from(opcode).to_be_bytes()); + buf.extend_from_slice(filename.as_bytes()); + buf.push(0); + buf.extend_from_slice(mode_str.as_bytes()); + buf.push(0); + + buf + } + + /// Serialize a DATA packet. + fn serialize_data(block: u16, data: &[u8]) -> Result> { + if data.len() > MAX_DATA_SIZE { + return Err(Error::DataTooLarge { size: data.len() }); + } + + let mut buf = Vec::with_capacity(4 + data.len()); + buf.extend_from_slice(&u16::from(Opcode::Data).to_be_bytes()); + buf.extend_from_slice(&block.to_be_bytes()); + buf.extend_from_slice(data); + + Ok(buf) + } + + /// Serialize an ACK packet. + fn serialize_ack(block: u16) -> Vec { + let mut buf = Vec::with_capacity(4); + buf.extend_from_slice(&u16::from(Opcode::Ack).to_be_bytes()); + buf.extend_from_slice(&block.to_be_bytes()); + buf + } + + /// Serialize an ERROR packet. + fn serialize_error(code: ErrorCode, message: &str) -> Vec { + let capacity = 4 + message.len() + 1; + let mut buf = Vec::with_capacity(capacity); + + buf.extend_from_slice(&u16::from(Opcode::Error).to_be_bytes()); + buf.extend_from_slice(&u16::from(code).to_be_bytes()); + buf.extend_from_slice(message.as_bytes()); + buf.push(0); + + buf + } + + /// Get the opcode of this packet. + #[must_use] + pub const fn opcode(&self) -> Opcode { + match self { + Self::ReadRequest { .. } => Opcode::Rrq, + Self::WriteRequest { .. } => Opcode::Wrq, + Self::Data { .. } => Opcode::Data, + Self::Ack { .. } => Opcode::Ack, + Self::Error { .. } => Opcode::Error, + } + } + + /// Get a human-readable name for the packet type. + #[must_use] + pub const fn type_name(&self) -> &'static str { + match self { + Self::ReadRequest { .. } => "RRQ", + Self::WriteRequest { .. } => "WRQ", + Self::Data { .. } => "DATA", + Self::Ack { .. } => "ACK", + Self::Error { .. } => "ERROR", + } + } + + /// Check if this is the final data packet (less than 512 bytes). + /// + /// RFC 1350 Section 5: "If it is 512 bytes long, the block is not the last + /// block of data; if it is from zero to 511 bytes long, it signals the end + /// of the transfer." + #[must_use] + pub const fn is_final_data(&self) -> bool { + matches!(self, Self::Data { data, .. } if data.len() < MAX_DATA_SIZE) + } + + /// Create a new RRQ packet. + #[must_use] + pub fn rrq(filename: impl Into, mode: Mode) -> Self { + Self::ReadRequest { + filename: filename.into(), + mode, + } + } + + /// Create a new WRQ packet. + #[must_use] + pub fn wrq(filename: impl Into, mode: Mode) -> Self { + Self::WriteRequest { + filename: filename.into(), + mode, + } + } + + /// Create a new DATA packet. + #[must_use] + pub fn data(block: u16, data: Vec) -> Self { + Self::Data { block, data } + } + + /// Create a new ACK packet. + #[must_use] + pub const fn ack(block: u16) -> Self { + Self::Ack { block } + } + + /// Create a new ERROR packet. + #[must_use] + pub fn error(code: ErrorCode, message: impl Into) -> Self { + Self::Error { + code, + message: message.into(), + } + } + + /// Create an ERROR packet with default message for the error code. + #[must_use] + pub fn error_default(code: ErrorCode) -> Self { + Self::Error { + code, + message: code.default_message().to_string(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rrq_roundtrip() { + let packet = Packet::rrq("test.txt", Mode::Octet); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + } + + #[test] + fn test_wrq_roundtrip() { + let packet = Packet::wrq("output.bin", Mode::NetAscii); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + } + + #[test] + fn test_data_roundtrip() { + let packet = Packet::data(1, vec![1, 2, 3, 4, 5]); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + } + + #[test] + fn test_data_empty() { + let packet = Packet::data(42, vec![]); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + assert!(packet.is_final_data()); + } + + #[test] + fn test_data_max_size() { + let packet = Packet::data(100, vec![0xAB; MAX_DATA_SIZE]); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + assert!(!packet.is_final_data()); + } + + #[test] + fn test_data_too_large() { + let packet = Packet::data(1, vec![0; MAX_DATA_SIZE + 1]); + let result = packet.serialize(); + assert!(matches!(result, Err(Error::DataTooLarge { .. }))); + } + + #[test] + fn test_ack_roundtrip() { + let packet = Packet::ack(0); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + + let packet = Packet::ack(65535); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + } + + #[test] + fn test_error_roundtrip() { + let packet = Packet::error(ErrorCode::FileNotFound, "File not found"); + let bytes = packet.serialize().expect("serialize failed"); + let parsed = Packet::parse(&bytes).expect("parse failed"); + assert_eq!(packet, parsed); + } + + #[test] + fn test_error_default_message() { + let packet = Packet::error_default(ErrorCode::AccessViolation); + assert!(matches!( + packet, + Packet::Error { + code: ErrorCode::AccessViolation, + message + } if message == "Access violation" + )); + } + + #[test] + fn test_mode_case_insensitive() { + assert_eq!(Mode::parse("OCTET").expect("parse"), Mode::Octet); + assert_eq!(Mode::parse("Octet").expect("parse"), Mode::Octet); + assert_eq!(Mode::parse("netascii").expect("parse"), Mode::NetAscii); + assert_eq!(Mode::parse("NETASCII").expect("parse"), Mode::NetAscii); + assert_eq!(Mode::parse("NetAscii").expect("parse"), Mode::NetAscii); + } + + #[test] + fn test_invalid_mode() { + assert!(matches!( + Mode::parse("binary"), + Err(Error::InvalidMode(_)) + )); + assert!(matches!(Mode::parse("mail"), Err(Error::InvalidMode(_)))); + } + + #[test] + fn test_invalid_opcode() { + let data = [0x00, 0x06]; // opcode 6 is invalid + let result = Packet::parse(&data); + assert!(matches!(result, Err(Error::InvalidOpcode(6)))); + } + + #[test] + fn test_packet_too_short() { + let data = [0x00]; + let result = Packet::parse(&data); + assert!(matches!(result, Err(Error::PacketTooShort { .. }))); + } + + #[test] + fn test_rrq_parse_raw() { + // RRQ for "test.txt" in octet mode + let data = [ + 0x00, 0x01, // opcode = RRQ + b't', b'e', b's', b't', b'.', b't', b'x', b't', 0x00, // filename + b'o', b'c', b't', b'e', b't', 0x00, // mode + ]; + let packet = Packet::parse(&data).expect("parse failed"); + assert!(matches!( + packet, + Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt" + )); + } +} diff --git a/crates/pfs-tftp-proto/src/state.rs b/crates/pfs-tftp-proto/src/state.rs new file mode 100644 index 0000000..c357728 --- /dev/null +++ b/crates/pfs-tftp-proto/src/state.rs @@ -0,0 +1,661 @@ +//! TFTP Protocol State Machine +//! +//! This module implements sans-io state machines for both client and server +//! sides of the TFTP protocol. +//! +//! # RFC 1350 Protocol Overview +//! +//! The protocol follows a lock-step acknowledgment pattern: +//! +//! ## Read Transfer (Client reading from Server) +//! ```text +//! Client -> Server: RRQ (filename, mode) +//! Server -> Client: DATA (block 1) +//! Client -> Server: ACK (block 1) +//! Server -> Client: DATA (block 2) +//! ... (continues until final DATA with < 512 bytes) +//! ``` +//! +//! ## Write Transfer (Client writing to Server) +//! ```text +//! Client -> Server: WRQ (filename, mode) +//! Server -> Client: ACK (block 0) +//! Client -> Server: DATA (block 1) +//! Server -> Client: ACK (block 1) +//! ... (continues until final DATA with < 512 bytes) +//! ``` + +use crate::{Error, Mode, Packet, Result, MAX_DATA_SIZE}; + +/// Direction of data transfer. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TransferDirection { + /// Reading data from remote (client downloads, server uploads) + Read, + /// Writing data to remote (client uploads, server downloads) + Write, +} + +/// Events emitted by the state machine. +/// +/// These events tell the I/O layer what action to take next. +/// Errors are returned via `Result` from the state machine methods, +/// not as events. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Event { + /// Send this packet to the remote peer. + Send(Packet), + + /// Request data from the application for the given block number. + /// + /// The application should provide up to 512 bytes of data. + /// A response with less than 512 bytes indicates the final block. + NeedData { block: u16 }, + + /// Received data from remote peer. + /// + /// The application should store this data. + ReceivedData { block: u16, data: Vec }, + + /// Transfer completed successfully. + Complete, +} + +/// Client-side protocol state machine. +/// +/// Handles the client side of TFTP transfers (both read and write). +#[derive(Debug)] +pub struct ClientState { + /// The direction of the transfer + direction: TransferDirection, + /// Transfer mode + mode: Mode, + /// Current block number we're expecting/sending + current_block: u16, + /// Whether the transfer is complete + complete: bool, + /// Whether we've received the initial response + got_initial_response: bool, +} + +impl ClientState { + /// Create a new client state for a read request. + /// + /// Call `start()` to get the initial RRQ packet to send. + #[must_use] + pub const fn new_read(mode: Mode) -> Self { + Self { + direction: TransferDirection::Read, + mode, + current_block: 1, // Expecting DATA block 1 + complete: false, + got_initial_response: false, + } + } + + /// Create a new client state for a write request. + /// + /// Call `start()` to get the initial WRQ packet to send. + #[must_use] + pub const fn new_write(mode: Mode) -> Self { + Self { + direction: TransferDirection::Write, + mode, + current_block: 0, // Expecting ACK block 0 + complete: false, + got_initial_response: false, + } + } + + /// Get the initial request packet to send. + #[must_use] + pub fn start(&self, filename: &str) -> Packet { + match self.direction { + TransferDirection::Read => Packet::rrq(filename, self.mode), + TransferDirection::Write => Packet::wrq(filename, self.mode), + } + } + + /// Check if the transfer is complete. + #[must_use] + pub const fn is_complete(&self) -> bool { + self.complete + } + + /// Get the current expected block number. + #[must_use] + pub const fn current_block(&self) -> u16 { + self.current_block + } + + /// Process a received packet and return events. + /// + /// # Errors + /// + /// Returns an error if the packet is invalid for the current state. + pub fn receive(&mut self, packet: &Packet) -> Result> { + if self.complete { + return Err(Error::StateError("transfer already complete".to_string())); + } + + match (self.direction, packet) { + // Reading: expect DATA packets + (TransferDirection::Read, Packet::Data { block, data }) => { + self.handle_read_data(*block, data) + } + + // Writing: expect ACK packets + (TransferDirection::Write, Packet::Ack { block }) => self.handle_write_ack(*block), + + // Error packet terminates transfer + (_, Packet::Error { code, message }) => { + self.complete = true; + Err(Error::TftpError { + code: *code, + message: message.clone(), + }) + } + + // Unexpected packet type + (TransferDirection::Read, other) => Err(Error::UnexpectedPacket { + expected: "DATA", + actual: other.type_name(), + }), + + (TransferDirection::Write, other) => Err(Error::UnexpectedPacket { + expected: "ACK", + actual: other.type_name(), + }), + } + } + + /// Handle a DATA packet during a read transfer. + fn handle_read_data(&mut self, block: u16, data: &[u8]) -> Result> { + self.got_initial_response = true; + + // RFC 1350: Block numbers start at 1 and increase sequentially + if block != self.current_block { + // Could be a duplicate - if it's the previous block, re-send ACK + if block == self.current_block.wrapping_sub(1) { + return Ok(vec![Event::Send(Packet::ack(block))]); + } + return Err(Error::BlockMismatch { + expected: self.current_block, + actual: block, + }); + } + + let is_final = data.len() < MAX_DATA_SIZE; + let mut events = vec![ + Event::ReceivedData { + block, + data: data.to_vec(), + }, + Event::Send(Packet::ack(block)), + ]; + + if is_final { + self.complete = true; + events.push(Event::Complete); + } else { + self.current_block = self.current_block.wrapping_add(1); + } + + Ok(events) + } + + /// Handle an ACK packet during a write transfer. + fn handle_write_ack(&mut self, block: u16) -> Result> { + // RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet having + // a block number of zero." + if !self.got_initial_response { + if block != 0 { + return Err(Error::BlockMismatch { + expected: 0, + actual: block, + }); + } + self.got_initial_response = true; + self.current_block = 1; + return Ok(vec![Event::NeedData { block: 1 }]); + } + + if block != self.current_block { + // Duplicate ACK - ignore + if block == self.current_block.wrapping_sub(1) { + return Ok(vec![]); + } + return Err(Error::BlockMismatch { + expected: self.current_block, + actual: block, + }); + } + + // Request next block of data + self.current_block = self.current_block.wrapping_add(1); + Ok(vec![Event::NeedData { + block: self.current_block, + }]) + } + + /// Provide data for a write transfer. + /// + /// Call this in response to a `NeedData` event. + /// + /// # Arguments + /// + /// * `block` - The block number this data is for + /// * `data` - The data to send (must be <= 512 bytes) + /// + /// # Returns + /// + /// Returns the DATA packet to send, or marks the transfer complete if + /// this was the final block. + /// + /// # Errors + /// + /// Returns an error if the data is too large or the block number is wrong. + pub fn provide_data(&mut self, block: u16, data: Vec) -> Result { + if data.len() > MAX_DATA_SIZE { + return Err(Error::DataTooLarge { size: data.len() }); + } + + if block != self.current_block { + return Err(Error::BlockMismatch { + expected: self.current_block, + actual: block, + }); + } + + let is_final = data.len() < MAX_DATA_SIZE; + if is_final { + self.complete = true; + } + + Ok(Event::Send(Packet::data(block, data))) + } + + /// Mark the transfer as complete after the final ACK is received. + pub fn finish(&mut self) -> Event { + self.complete = true; + Event::Complete + } +} + +/// Server-side protocol state machine. +/// +/// Handles the server side of TFTP transfers. +#[derive(Debug)] +pub struct ServerState { + /// The direction of the transfer (from server's perspective) + direction: TransferDirection, + /// Transfer mode + mode: Mode, + /// The filename being transferred + filename: String, + /// Current block number + current_block: u16, + /// Whether the transfer is complete + complete: bool, + /// Whether we've sent the first response + sent_initial_response: bool, + /// Track if we're waiting for final ACK + waiting_for_final_ack: bool, +} + +impl ServerState { + /// Create a new server state from a request packet. + /// + /// # Errors + /// + /// Returns an error if the packet is not an RRQ or WRQ. + pub fn from_request(packet: &Packet) -> Result { + match packet { + Packet::ReadRequest { filename, mode } => Ok(Self { + // Server uploads when client reads + direction: TransferDirection::Write, + mode: *mode, + filename: filename.clone(), + current_block: 1, + complete: false, + sent_initial_response: false, + waiting_for_final_ack: false, + }), + Packet::WriteRequest { filename, mode } => Ok(Self { + // Server downloads when client writes + direction: TransferDirection::Read, + mode: *mode, + filename: filename.clone(), + current_block: 0, + complete: false, + sent_initial_response: false, + waiting_for_final_ack: false, + }), + other => Err(Error::UnexpectedPacket { + expected: "RRQ or WRQ", + actual: other.type_name(), + }), + } + } + + /// Get the filename being transferred. + #[must_use] + pub fn filename(&self) -> &str { + &self.filename + } + + /// Get the transfer mode. + #[must_use] + pub const fn mode(&self) -> Mode { + self.mode + } + + /// Get the transfer direction (from server's perspective). + #[must_use] + pub const fn direction(&self) -> TransferDirection { + self.direction + } + + /// Check if the transfer is complete. + #[must_use] + pub const fn is_complete(&self) -> bool { + self.complete + } + + /// Get the initial response event. + /// + /// For RRQ: Returns `NeedData` to request the first block + /// For WRQ: Returns `Send(ACK 0)` to acknowledge the write request + pub fn start(&mut self) -> Event { + self.sent_initial_response = true; + + match self.direction { + // Server sending data (client RRQ) + TransferDirection::Write => Event::NeedData { block: 1 }, + + // Server receiving data (client WRQ) + // RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet + // having a block number of zero." + TransferDirection::Read => Event::Send(Packet::ack(0)), + } + } + + /// Process a received packet and return events. + /// + /// # Errors + /// + /// Returns an error if the packet is invalid for the current state. + pub fn receive(&mut self, packet: &Packet) -> Result> { + if self.complete { + return Err(Error::StateError("transfer already complete".to_string())); + } + + match (self.direction, packet) { + // Server sending data, receiving ACKs + (TransferDirection::Write, Packet::Ack { block }) => self.handle_send_ack(*block), + + // Server receiving data + (TransferDirection::Read, Packet::Data { block, data }) => { + self.handle_receive_data(*block, data) + } + + // Error terminates transfer + (_, Packet::Error { code, message }) => { + self.complete = true; + Err(Error::TftpError { + code: *code, + message: message.clone(), + }) + } + + // Unexpected packet + (TransferDirection::Write, other) => Err(Error::UnexpectedPacket { + expected: "ACK", + actual: other.type_name(), + }), + + (TransferDirection::Read, other) => Err(Error::UnexpectedPacket { + expected: "DATA", + actual: other.type_name(), + }), + } + } + + /// Handle an ACK when server is sending data. + fn handle_send_ack(&mut self, block: u16) -> Result> { + if block != self.current_block { + // Duplicate ACK - could retransmit, but we'll just ignore + if block == self.current_block.wrapping_sub(1) { + return Ok(vec![]); + } + return Err(Error::BlockMismatch { + expected: self.current_block, + actual: block, + }); + } + + // If we were waiting for the final ACK, we're done + if self.waiting_for_final_ack { + self.complete = true; + return Ok(vec![Event::Complete]); + } + + // Request next block + self.current_block = self.current_block.wrapping_add(1); + Ok(vec![Event::NeedData { + block: self.current_block, + }]) + } + + /// Handle a DATA packet when server is receiving data. + fn handle_receive_data(&mut self, block: u16, data: &[u8]) -> Result> { + let expected_block = self.current_block.wrapping_add(1); + + if block != expected_block { + // Duplicate - re-ACK + if block == self.current_block { + return Ok(vec![Event::Send(Packet::ack(block))]); + } + return Err(Error::BlockMismatch { + expected: expected_block, + actual: block, + }); + } + + self.current_block = block; + let is_final = data.len() < MAX_DATA_SIZE; + + let mut events = vec![ + Event::ReceivedData { + block, + data: data.to_vec(), + }, + Event::Send(Packet::ack(block)), + ]; + + if is_final { + self.complete = true; + events.push(Event::Complete); + } + + Ok(events) + } + + /// Provide data for a read transfer (server sending to client). + /// + /// # Errors + /// + /// Returns an error if the data is too large. + pub fn provide_data(&mut self, block: u16, data: Vec) -> Result { + if data.len() > MAX_DATA_SIZE { + return Err(Error::DataTooLarge { size: data.len() }); + } + + if block != self.current_block { + return Err(Error::BlockMismatch { + expected: self.current_block, + actual: block, + }); + } + + let is_final = data.len() < MAX_DATA_SIZE; + if is_final { + self.waiting_for_final_ack = true; + } + + Ok(Event::Send(Packet::data(block, data))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ErrorCode; + + #[test] + fn test_client_read_single_block() { + let mut client = ClientState::new_read(Mode::Octet); + let request = client.start("test.txt"); + assert!(matches!( + request, + Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt" + )); + + // Receive single DATA block (< 512 bytes = final) + let events = client + .receive(&Packet::data(1, vec![1, 2, 3])) + .expect("receive failed"); + + assert_eq!(events.len(), 3); + assert!(matches!(&events[0], Event::ReceivedData { block: 1, data } if data == &[1, 2, 3])); + assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 }))); + assert!(matches!(&events[2], Event::Complete)); + assert!(client.is_complete()); + } + + #[test] + fn test_client_read_multiple_blocks() { + let mut client = ClientState::new_read(Mode::Octet); + let _ = client.start("test.txt"); + + // Full block (512 bytes) + let full_data = vec![0u8; MAX_DATA_SIZE]; + let events = client + .receive(&Packet::data(1, full_data.clone())) + .expect("receive"); + assert_eq!(events.len(), 2); + assert!(!client.is_complete()); + assert_eq!(client.current_block(), 2); + + // Another full block + let events = client + .receive(&Packet::data(2, full_data)) + .expect("receive"); + assert_eq!(events.len(), 2); + assert!(!client.is_complete()); + + // Final block (< 512 bytes) + let events = client + .receive(&Packet::data(3, vec![1, 2, 3])) + .expect("receive"); + assert_eq!(events.len(), 3); + assert!(client.is_complete()); + } + + #[test] + fn test_client_write_flow() { + let mut client = ClientState::new_write(Mode::Octet); + let request = client.start("output.txt"); + assert!(matches!(request, Packet::WriteRequest { .. })); + + // Receive ACK 0 (write request acknowledged) + let events = client.receive(&Packet::ack(0)).expect("receive"); + assert_eq!(events.len(), 1); + assert!(matches!(events[0], Event::NeedData { block: 1 })); + + // Provide data for block 1 + let send_event = client.provide_data(1, vec![1, 2, 3]).expect("provide_data"); + assert!(matches!( + send_event, + Event::Send(Packet::Data { block: 1, .. }) + )); + assert!(client.is_complete()); // < 512 bytes = final + } + + #[test] + fn test_server_rrq() { + let request = Packet::rrq("file.txt", Mode::Octet); + let mut server = ServerState::from_request(&request).expect("from_request"); + + assert_eq!(server.filename(), "file.txt"); + assert_eq!(server.direction(), TransferDirection::Write); + + // Start should request data for block 1 + let event = server.start(); + assert!(matches!(event, Event::NeedData { block: 1 })); + + // Provide data + let event = server.provide_data(1, vec![1, 2, 3]).expect("provide_data"); + assert!(matches!(event, Event::Send(Packet::Data { block: 1, .. }))); + + // Receive ACK + let events = server.receive(&Packet::ack(1)).expect("receive"); + assert!(matches!(&events[0], Event::Complete)); + assert!(server.is_complete()); + } + + #[test] + fn test_server_wrq() { + let request = Packet::wrq("output.txt", Mode::Octet); + let mut server = ServerState::from_request(&request).expect("from_request"); + + assert_eq!(server.direction(), TransferDirection::Read); + + // Start should send ACK 0 + let event = server.start(); + assert!(matches!(event, Event::Send(Packet::Ack { block: 0 }))); + + // Receive DATA block 1 + let events = server + .receive(&Packet::data(1, vec![1, 2, 3])) + .expect("receive"); + assert_eq!(events.len(), 3); + assert!(matches!(&events[0], Event::ReceivedData { block: 1, .. })); + assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 }))); + assert!(matches!(&events[2], Event::Complete)); + } + + #[test] + fn test_error_terminates_transfer() { + let mut client = ClientState::new_read(Mode::Octet); + let _ = client.start("test.txt"); + + let result = client.receive(&Packet::error(ErrorCode::FileNotFound, "File not found")); + + assert!(result.is_err()); + assert!(client.is_complete()); + } + + #[test] + fn test_duplicate_ack_ignored() { + let request = Packet::rrq("file.txt", Mode::Octet); + let mut server = ServerState::from_request(&request).expect("from_request"); + let _ = server.start(); + + // Provide and send block 1 + let _ = server + .provide_data(1, vec![0u8; MAX_DATA_SIZE]) + .expect("provide_data"); + + // Receive ACK 1 + let events = server.receive(&Packet::ack(1)).expect("receive"); + assert!(!events.is_empty()); + + // Provide block 2 + let _ = server.provide_data(2, vec![1, 2, 3]).expect("provide_data"); + + // Duplicate ACK 1 should be ignored + let events = server.receive(&Packet::ack(1)).expect("receive"); + assert!(events.is_empty()); + } +} diff --git a/crates/pfs-tftp/Cargo.toml b/crates/pfs-tftp/Cargo.toml new file mode 100644 index 0000000..fbdefc0 --- /dev/null +++ b/crates/pfs-tftp/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "pfs-tftp" +description = "TFTP client and server library (RFC 1350)" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true + +[lints] +workspace = true + +[dependencies] +pfs-tftp-proto = { path = "../pfs-tftp-proto" } +thiserror = "2" + +[[bin]] +name = "tftpd" +path = "src/bin/tftpd.rs" + +[[bin]] +name = "tftp" +path = "src/bin/tftp.rs" diff --git a/crates/pfs-tftp/src/bin/tftp.rs b/crates/pfs-tftp/src/bin/tftp.rs new file mode 100644 index 0000000..ff263c9 --- /dev/null +++ b/crates/pfs-tftp/src/bin/tftp.rs @@ -0,0 +1,3 @@ +fn main() { + println!("tftp client placeholder"); +} diff --git a/crates/pfs-tftp/src/bin/tftpd.rs b/crates/pfs-tftp/src/bin/tftpd.rs new file mode 100644 index 0000000..e2a41ea --- /dev/null +++ b/crates/pfs-tftp/src/bin/tftpd.rs @@ -0,0 +1,3 @@ +fn main() { + println!("tftpd server placeholder"); +} diff --git a/crates/pfs-tftp/src/client.rs b/crates/pfs-tftp/src/client.rs new file mode 100644 index 0000000..b21116a --- /dev/null +++ b/crates/pfs-tftp/src/client.rs @@ -0,0 +1,479 @@ +//! TFTP Client implementation. +//! +//! Provides synchronous TFTP client operations for reading and writing files +//! to/from a TFTP server. + +// The expect() calls in this module cannot panic in practice because we always +// set the server_tid before using it, but clippy doesn't know this. +#![allow(clippy::missing_panics_doc)] + +use std::{ + io::{Read, Write}, + net::{SocketAddr, ToSocketAddrs, UdpSocket}, + time::Duration, +}; + +use pfs_tftp_proto::{ClientState, Event, Mode, Packet, MAX_DATA_SIZE, TFTP_PORT}; + +use crate::{Error, Result}; + +/// Default timeout for TFTP operations (in seconds). +const DEFAULT_TIMEOUT_SECS: u64 = 5; + +/// Default number of retries before giving up. +const DEFAULT_RETRIES: u32 = 3; + +/// Maximum packet size (opcode + block + max data). +const MAX_PACKET_SIZE: usize = 4 + MAX_DATA_SIZE; + +/// TFTP client for transferring files to/from a server. +/// +/// # Example +/// +/// ```no_run +/// use pfs_tftp::{Client, Mode}; +/// +/// let client = Client::new("192.168.1.1:69").expect("connect"); +/// let data = client.get("config.txt", Mode::Octet).expect("get"); +/// println!("Received {} bytes", data.len()); +/// ``` +pub struct Client { + /// Server address + server_addr: SocketAddr, + /// Local UDP socket + socket: UdpSocket, + /// Timeout duration + timeout: Duration, + /// Number of retries + retries: u32, +} + +impl Client { + /// Create a new TFTP client connected to the specified server. + /// + /// The address can be in any format accepted by `ToSocketAddrs`, such as + /// "192.168.1.1:69" or "tftp.example.com:69". + /// + /// # Errors + /// + /// Returns an error if the address is invalid or the socket cannot be bound. + pub fn new(server_addr: A) -> Result { + let server_addr = server_addr + .to_socket_addrs()? + .next() + .ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, "no addresses found") + })?; + + // Bind to any available local port + let socket = UdpSocket::bind("0.0.0.0:0")?; + socket.set_read_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?; + + Ok(Self { + server_addr, + socket, + timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS), + retries: DEFAULT_RETRIES, + }) + } + + /// Set the timeout for operations. + #[must_use] + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + let _ = self.socket.set_read_timeout(Some(timeout)); + self + } + + /// Set the number of retries. + #[must_use] + pub const fn with_retries(mut self, retries: u32) -> Self { + self.retries = retries; + self + } + + /// Download a file from the server. + /// + /// # Arguments + /// + /// * `filename` - The name of the file to download + /// * `mode` - The transfer mode (`Octet` or `NetAscii`) + /// + /// # Returns + /// + /// The file contents as a byte vector. + /// + /// # Errors + /// + /// Returns an error if the transfer fails. + pub fn get(&self, filename: &str, mode: Mode) -> Result> { + let mut state = ClientState::new_read(mode); + let request = state.start(filename); + + // Send initial request to well-known port + let request_bytes = request.serialize()?; + self.socket.send_to(&request_bytes, self.server_addr)?; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut data = Vec::new(); + let mut server_tid: Option = None; + let mut retries_left = self.retries; + + loop { + // Receive packet + let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) { + Ok(result) => result, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // Timeout - retransmit last packet + if retries_left == 0 { + return Err(Error::Timeout { + retries: self.retries, + }); + } + retries_left -= 1; + self.socket.send_to(&request_bytes, self.server_addr)?; + continue; + } + Err(e) => return Err(e.into()), + }; + + // RFC 1350 Section 4: First response establishes the TID + // Subsequent packets must come from the same TID + match server_tid { + None => server_tid = Some(from_addr), + Some(expected) if from_addr != expected => { + // RFC 1350: "If a source TID does not match, the packet should + // be discarded as erroneously sent from somewhere else. An error + // packet should be sent to the source of the incorrect packet." + let error = + Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID"); + let _ = self.socket.send_to(&error.serialize()?, from_addr); + continue; + } + _ => {} + } + + let packet = Packet::parse(&recv_buf[..len])?; + + // Handle error packets + if let Packet::Error { code, message } = packet { + return Err(Error::Remote { code, message }); + } + + let events = state.receive(&packet)?; + + for event in events { + match event { + Event::ReceivedData { data: block_data, .. } => { + data.extend_from_slice(&block_data); + } + Event::Send(packet) => { + let bytes = packet.serialize()?; + self.socket.send_to(&bytes, server_tid.expect("TID set"))?; + } + Event::Complete => { + return Ok(data); + } + Event::NeedData { .. } => { + // Not used in read transfers + } + } + } + + retries_left = self.retries; + } + } + + /// Upload a file to the server. + /// + /// # Arguments + /// + /// * `filename` - The name to give the file on the server + /// * `mode` - The transfer mode (`Octet` or `NetAscii`) + /// * `data` - The file contents to upload + /// + /// # Errors + /// + /// Returns an error if the transfer fails. + pub fn put(&self, filename: &str, mode: Mode, data: &[u8]) -> Result<()> { + let mut state = ClientState::new_write(mode); + let request = state.start(filename); + + // Send initial request + let request_bytes = request.serialize()?; + self.socket.send_to(&request_bytes, self.server_addr)?; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut server_tid: Option = None; + let mut retries_left = self.retries; + let mut last_sent: Option> = None; + let mut data_offset: usize = 0; + + loop { + // Receive packet + let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) { + Ok(result) => result, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // Timeout - retransmit + if retries_left == 0 { + return Err(Error::Timeout { + retries: self.retries, + }); + } + retries_left -= 1; + if let Some(ref last) = last_sent { + self.socket + .send_to(last, server_tid.unwrap_or(self.server_addr))?; + } else { + self.socket.send_to(&request_bytes, self.server_addr)?; + } + continue; + } + Err(e) => return Err(e.into()), + }; + + // TID validation + match server_tid { + None => server_tid = Some(from_addr), + Some(expected) if from_addr != expected => { + let error = + Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID"); + let _ = self.socket.send_to(&error.serialize()?, from_addr); + continue; + } + _ => {} + } + + let packet = Packet::parse(&recv_buf[..len])?; + + if let Packet::Error { code, message } = packet { + return Err(Error::Remote { code, message }); + } + + let events = state.receive(&packet)?; + + for event in events { + match event { + Event::NeedData { block } => { + // Calculate data for this block + let start = data_offset; + let end = (start + MAX_DATA_SIZE).min(data.len()); + let block_data = data[start..end].to_vec(); + data_offset = end; + + let send_event = state.provide_data(block, block_data)?; + if let Event::Send(packet) = send_event { + let bytes = packet.serialize()?; + self.socket.send_to(&bytes, server_tid.expect("TID set"))?; + last_sent = Some(bytes); + } + } + Event::Send(packet) => { + let bytes = packet.serialize()?; + self.socket.send_to(&bytes, server_tid.expect("TID set"))?; + last_sent = Some(bytes); + } + Event::Complete => { + return Ok(()); + } + Event::ReceivedData { .. } => { + // Not used in write transfers + } + } + } + + // Check if transfer is complete after sending final data + if state.is_complete() { + return Ok(()); + } + + retries_left = self.retries; + } + } + + /// Download a file from the server to a writer. + /// + /// This is more memory-efficient for large files as it streams data + /// directly to the writer. + /// + /// # Errors + /// + /// Returns an error if the transfer fails. + pub fn get_to_writer(&self, filename: &str, mode: Mode, writer: &mut W) -> Result { + let mut state = ClientState::new_read(mode); + let request = state.start(filename); + + let request_bytes = request.serialize()?; + self.socket.send_to(&request_bytes, self.server_addr)?; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut total_bytes: u64 = 0; + let mut server_tid: Option = None; + let mut retries_left = self.retries; + + loop { + let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) { + Ok(result) => result, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + if retries_left == 0 { + return Err(Error::Timeout { + retries: self.retries, + }); + } + retries_left -= 1; + self.socket.send_to(&request_bytes, self.server_addr)?; + continue; + } + Err(e) => return Err(e.into()), + }; + + match server_tid { + None => server_tid = Some(from_addr), + Some(expected) if from_addr != expected => { + let error = + Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID"); + let _ = self.socket.send_to(&error.serialize()?, from_addr); + continue; + } + _ => {} + } + + let packet = Packet::parse(&recv_buf[..len])?; + + if let Packet::Error { code, message } = packet { + return Err(Error::Remote { code, message }); + } + + let events = state.receive(&packet)?; + + for event in events { + match event { + Event::ReceivedData { data, .. } => { + writer.write_all(&data)?; + total_bytes += data.len() as u64; + } + Event::Send(packet) => { + let bytes = packet.serialize()?; + self.socket.send_to(&bytes, server_tid.expect("TID set"))?; + } + Event::Complete => { + writer.flush()?; + return Ok(total_bytes); + } + Event::NeedData { .. } => {} + } + } + + retries_left = self.retries; + } + } + + /// Upload data from a reader to the server. + /// + /// This is more memory-efficient for large files as it streams data + /// from the reader. + /// + /// # Errors + /// + /// Returns an error if the transfer fails. + pub fn put_from_reader(&self, filename: &str, mode: Mode, reader: &mut R) -> Result { + let mut state = ClientState::new_write(mode); + let request = state.start(filename); + + let request_bytes = request.serialize()?; + self.socket.send_to(&request_bytes, self.server_addr)?; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut read_buf = [0u8; MAX_DATA_SIZE]; + let mut total_bytes: u64 = 0; + let mut server_tid: Option = None; + let mut retries_left = self.retries; + let mut last_sent: Option> = None; + + loop { + let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) { + Ok(result) => result, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + if retries_left == 0 { + return Err(Error::Timeout { + retries: self.retries, + }); + } + retries_left -= 1; + if let Some(ref last) = last_sent { + self.socket + .send_to(last, server_tid.unwrap_or(self.server_addr))?; + } else { + self.socket.send_to(&request_bytes, self.server_addr)?; + } + continue; + } + Err(e) => return Err(e.into()), + }; + + match server_tid { + None => server_tid = Some(from_addr), + Some(expected) if from_addr != expected => { + let error = + Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID"); + let _ = self.socket.send_to(&error.serialize()?, from_addr); + continue; + } + _ => {} + } + + let packet = Packet::parse(&recv_buf[..len])?; + + if let Packet::Error { code, message } = packet { + return Err(Error::Remote { code, message }); + } + + let events = state.receive(&packet)?; + + for event in events { + match event { + Event::NeedData { block } => { + let n = reader.read(&mut read_buf)?; + total_bytes += n as u64; + + let send_event = state.provide_data(block, read_buf[..n].to_vec())?; + if let Event::Send(packet) = send_event { + let bytes = packet.serialize()?; + self.socket.send_to(&bytes, server_tid.expect("TID set"))?; + last_sent = Some(bytes); + } + } + Event::Send(packet) => { + let bytes = packet.serialize()?; + self.socket.send_to(&bytes, server_tid.expect("TID set"))?; + last_sent = Some(bytes); + } + Event::Complete => { + return Ok(total_bytes); + } + Event::ReceivedData { .. } => {} + } + } + + if state.is_complete() { + return Ok(total_bytes); + } + + retries_left = self.retries; + } + } +} + +/// Create a client connected to "host:port" or "host" (default port 69). +impl std::str::FromStr for Client { + type Err = Error; + + fn from_str(s: &str) -> Result { + let addr = if s.contains(':') { + s.to_string() + } else { + format!("{s}:{TFTP_PORT}") + }; + Self::new(addr) + } +} diff --git a/crates/pfs-tftp/src/error.rs b/crates/pfs-tftp/src/error.rs new file mode 100644 index 0000000..b149fea --- /dev/null +++ b/crates/pfs-tftp/src/error.rs @@ -0,0 +1,41 @@ +//! Error types for TFTP I/O operations. + +use thiserror::Error; + +/// Errors that can occur during TFTP operations. +#[derive(Debug, Error)] +pub enum Error { + /// I/O error during socket or file operations. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// Protocol-level error. + #[error("protocol error: {0}")] + Protocol(#[from] pfs_tftp_proto::Error), + + /// Address parsing error. + #[error("invalid address: {0}")] + InvalidAddress(#[from] std::net::AddrParseError), + + /// Remote peer sent an error. + #[error("remote error ({code:?}): {message}")] + Remote { + code: pfs_tftp_proto::ErrorCode, + message: String, + }, + + /// Transfer timed out. + #[error("transfer timed out after {retries} retries")] + Timeout { retries: u32 }, + + /// File access error. + #[error("file access error: {0}")] + FileAccess(String), + + /// Invalid filename (e.g., path traversal attempt). + #[error("invalid filename: {0}")] + InvalidFilename(String), +} + +/// Result type for TFTP operations. +pub type Result = std::result::Result; diff --git a/crates/pfs-tftp/src/lib.rs b/crates/pfs-tftp/src/lib.rs new file mode 100644 index 0000000..6fb9dfd --- /dev/null +++ b/crates/pfs-tftp/src/lib.rs @@ -0,0 +1,32 @@ +//! TFTP Client and Server Library (RFC 1350) +//! +//! This crate provides synchronous I/O implementations for TFTP client and +//! server functionality, built on top of the `pfs-tftp-proto` protocol crate. +//! +//! # Example: Client +//! +//! ```no_run +//! use pfs_tftp::{Client, Mode}; +//! +//! let mut client = Client::new("192.168.1.1:69").expect("connect"); +//! let data = client.get("config.txt", Mode::Octet).expect("get"); +//! ``` +//! +//! # Example: Server +//! +//! ```no_run +//! use pfs_tftp::Server; +//! use std::path::Path; +//! +//! let server = Server::bind("0.0.0.0:69", Path::new("/tftpboot")).expect("bind"); +//! server.run().expect("run"); +//! ``` + +mod client; +mod error; +mod server; + +pub use client::Client; +pub use error::{Error, Result}; +pub use pfs_tftp_proto::Mode; +pub use server::{Server, ServerBuilder, ServerConfig}; diff --git a/crates/pfs-tftp/src/server.rs b/crates/pfs-tftp/src/server.rs new file mode 100644 index 0000000..aff0dba --- /dev/null +++ b/crates/pfs-tftp/src/server.rs @@ -0,0 +1,569 @@ +//! TFTP Server implementation. +//! +//! Provides a synchronous TFTP server that serves files from a root directory. + +use std::{ + fs::{File, OpenOptions}, + io::{Read, Write}, + net::{SocketAddr, ToSocketAddrs, UdpSocket}, + path::{Path, PathBuf}, + time::Duration, +}; + +use pfs_tftp_proto::{ErrorCode, Event, Mode, Packet, ServerState, MAX_DATA_SIZE, TFTP_PORT}; + +use crate::{Error, Result}; + +/// Default timeout for transfer operations (in seconds). +const DEFAULT_TIMEOUT_SECS: u64 = 5; + +/// Default number of retries before giving up on a transfer. +const DEFAULT_RETRIES: u32 = 3; + +/// Maximum packet size. +const MAX_PACKET_SIZE: usize = 4 + MAX_DATA_SIZE; + +/// Configuration for the TFTP server. +#[derive(Debug, Clone)] +pub struct ServerConfig { + /// Root directory to serve files from/to. + pub root_dir: PathBuf, + /// Allow write operations (WRQ). + pub allow_write: bool, + /// Allow overwriting existing files. + pub allow_overwrite: bool, + /// Timeout for operations. + pub timeout: Duration, + /// Number of retries. + pub retries: u32, +} + +impl Default for ServerConfig { + fn default() -> Self { + Self { + root_dir: PathBuf::from("."), + allow_write: false, + allow_overwrite: false, + timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS), + retries: DEFAULT_RETRIES, + } + } +} + +/// TFTP server. +/// +/// Serves files from a configured root directory. +/// +/// # Example +/// +/// ```no_run +/// use pfs_tftp::Server; +/// use std::path::Path; +/// +/// let server = Server::bind("0.0.0.0:69", Path::new("/tftpboot")).expect("bind"); +/// server.run().expect("run"); +/// ``` +pub struct Server { + /// Listening socket on the well-known TFTP port. + socket: UdpSocket, + /// Server configuration. + config: ServerConfig, +} + +impl Server { + /// Create a new TFTP server bound to the specified address. + /// + /// # Arguments + /// + /// * `addr` - The address to bind to (e.g., "0.0.0.0:69") + /// * `root_dir` - The root directory to serve files from + /// + /// # Errors + /// + /// Returns an error if the socket cannot be bound or the root directory + /// is invalid. + pub fn bind(addr: A, root_dir: &Path) -> Result { + let socket = UdpSocket::bind(addr)?; + + let config = ServerConfig { + root_dir: root_dir.to_path_buf(), + ..Default::default() + }; + + Ok(Self { socket, config }) + } + + /// Create a server with custom configuration. + /// + /// # Errors + /// + /// Returns an error if the socket cannot be bound. + pub fn with_config(addr: A, config: ServerConfig) -> Result { + let socket = UdpSocket::bind(addr)?; + Ok(Self { socket, config }) + } + + /// Enable or disable write operations. + #[must_use] + pub fn allow_write(mut self, allow: bool) -> Self { + self.config.allow_write = allow; + self + } + + /// Enable or disable overwriting existing files. + #[must_use] + pub fn allow_overwrite(mut self, allow: bool) -> Self { + self.config.allow_overwrite = allow; + self + } + + /// Run the server, handling requests indefinitely. + /// + /// This method blocks and handles one request at a time. + /// + /// # Errors + /// + /// Returns an error if a fatal socket error occurs. + pub fn run(&self) -> Result<()> { + let mut buf = [0u8; MAX_PACKET_SIZE]; + + loop { + let (len, client_addr) = self.socket.recv_from(&mut buf)?; + + // Try to parse the request + let packet = match Packet::parse(&buf[..len]) { + Ok(p) => p, + Err(e) => { + eprintln!("Invalid packet from {client_addr}: {e}"); + continue; + } + }; + + // Handle the request + if let Err(e) = self.handle_request(client_addr, &packet) { + eprintln!("Error handling request from {client_addr}: {e}"); + } + } + } + + /// Handle a single request, then return. + /// + /// This is useful for testing or single-request operation. + /// + /// # Errors + /// + /// Returns an error if a fatal socket error occurs. + pub fn handle_one(&self) -> Result<()> { + let mut buf = [0u8; MAX_PACKET_SIZE]; + let (len, client_addr) = self.socket.recv_from(&mut buf)?; + let packet = Packet::parse(&buf[..len])?; + self.handle_request(client_addr, &packet) + } + + /// Handle a single request. + fn handle_request(&self, client_addr: SocketAddr, packet: &Packet) -> Result<()> { + // Create a new socket for this transfer with a random TID + // RFC 1350 Section 4: "each end of the connection chooses a TID for + // itself, to be used for the duration of that connection" + let transfer_socket = UdpSocket::bind("0.0.0.0:0")?; + transfer_socket.set_read_timeout(Some(self.config.timeout))?; + transfer_socket.connect(client_addr)?; + + match packet { + Packet::ReadRequest { filename, mode } => { + self.handle_rrq(&transfer_socket, filename, *mode) + } + Packet::WriteRequest { filename, mode } => { + self.handle_wrq(&transfer_socket, filename, *mode) + } + _ => { + // Unexpected packet type on the well-known port + let error = Packet::error(ErrorCode::IllegalOperation, "Expected RRQ or WRQ"); + let _ = transfer_socket.send(&error.serialize()?); + Ok(()) + } + } + } + + /// Validate and resolve a filename to a path within the root directory. + fn resolve_path(&self, filename: &str) -> Result { + // Security: Prevent path traversal attacks + // RFC 1350 Security Considerations: "care must be taken in the rights + // granted to a TFTP server process so as not to violate the security + // of the server hosts file system" + + // Reject absolute paths + if filename.starts_with('/') || filename.starts_with('\\') { + return Err(Error::InvalidFilename( + "absolute paths not allowed".to_string(), + )); + } + + // Reject path traversal + if filename.contains("..") { + return Err(Error::InvalidFilename( + "path traversal not allowed".to_string(), + )); + } + + // Reject paths with null bytes + if filename.contains('\0') { + return Err(Error::InvalidFilename( + "null bytes not allowed".to_string(), + )); + } + + let path = self.config.root_dir.join(filename); + + // Verify the resolved path is still within root_dir + let canonical_root = self + .config + .root_dir + .canonicalize() + .map_err(|e| Error::FileAccess(format!("cannot access root directory: {e}")))?; + + // For non-existent files (WRQ), we check the parent + let check_path = if path.exists() { + path.canonicalize()? + } else { + let parent = path + .parent() + .ok_or_else(|| Error::InvalidFilename("invalid path".to_string()))?; + if !parent.exists() { + return Err(Error::FileAccess(format!( + "parent directory does not exist: {}", + parent.display() + ))); + } + parent.canonicalize()?.join(path.file_name().ok_or_else(|| { + Error::InvalidFilename("missing filename".to_string()) + })?) + }; + + if !check_path.starts_with(&canonical_root) { + return Err(Error::InvalidFilename( + "path outside root directory".to_string(), + )); + } + + Ok(path) + } + + /// Handle a read request (RRQ). + fn handle_rrq(&self, socket: &UdpSocket, filename: &str, mode: Mode) -> Result<()> { + let path = match self.resolve_path(filename) { + Ok(p) => p, + Err(e) => { + let error = Packet::error(ErrorCode::AccessViolation, e.to_string()); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + }; + + // Open the file + let mut file = match File::open(&path) { + Ok(f) => f, + Err(e) => { + let (code, msg) = match e.kind() { + std::io::ErrorKind::NotFound => { + (ErrorCode::FileNotFound, "File not found".to_string()) + } + std::io::ErrorKind::PermissionDenied => { + (ErrorCode::AccessViolation, "Permission denied".to_string()) + } + _ => (ErrorCode::NotDefined, e.to_string()), + }; + let error = Packet::error(code, msg); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + }; + + // Create server state for RRQ (server sends data) + let request = Packet::rrq(filename, mode); + let mut state = match ServerState::from_request(&request) { + Ok(s) => s, + Err(e) => { + let error = Packet::error(ErrorCode::IllegalOperation, e.to_string()); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + }; + + // Start transfer - get NeedData event + let event = state.start(); + let mut read_buf = [0u8; MAX_DATA_SIZE]; + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut last_sent: Option> = None; + let mut retries_left = self.config.retries; + + // Handle initial NeedData event + if let Event::NeedData { block } = event { + let n = file.read(&mut read_buf)?; + let send_event = state.provide_data(block, read_buf[..n].to_vec())?; + if let Event::Send(packet) = send_event { + let bytes = packet.serialize()?; + socket.send(&bytes)?; + last_sent = Some(bytes); + } + } + + // Main transfer loop + loop { + let len = match socket.recv(&mut recv_buf) { + Ok(len) => len, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + if retries_left == 0 { + return Err(Error::Timeout { + retries: self.config.retries, + }); + } + retries_left -= 1; + if let Some(ref last) = last_sent { + socket.send(last)?; + } + continue; + } + Err(e) => return Err(e.into()), + }; + + let packet = Packet::parse(&recv_buf[..len])?; + + if let Packet::Error { code, message } = packet { + return Err(Error::Remote { code, message }); + } + + let events = state.receive(&packet)?; + + for event in events { + match event { + Event::NeedData { block } => { + let n = file.read(&mut read_buf)?; + let send_event = state.provide_data(block, read_buf[..n].to_vec())?; + if let Event::Send(packet) = send_event { + let bytes = packet.serialize()?; + socket.send(&bytes)?; + last_sent = Some(bytes); + } + } + Event::Complete => { + return Ok(()); + } + // Send and ReceivedData not expected in RRQ handling + Event::Send(_) | Event::ReceivedData { .. } => {} + } + } + + if state.is_complete() { + return Ok(()); + } + + retries_left = self.config.retries; + } + } + + /// Handle a write request (WRQ). + #[allow(clippy::too_many_lines)] + fn handle_wrq(&self, socket: &UdpSocket, filename: &str, mode: Mode) -> Result<()> { + // Check if writes are allowed + if !self.config.allow_write { + let error = Packet::error(ErrorCode::AccessViolation, "Write not allowed"); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + + let path = match self.resolve_path(filename) { + Ok(p) => p, + Err(e) => { + let error = Packet::error(ErrorCode::AccessViolation, e.to_string()); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + }; + + // Check if file exists and overwrite is not allowed + if path.exists() && !self.config.allow_overwrite { + let error = Packet::error(ErrorCode::FileAlreadyExists, "File already exists"); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + + // Open/create the file + let mut file = match OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + { + Ok(f) => f, + Err(e) => { + let (code, msg) = match e.kind() { + std::io::ErrorKind::PermissionDenied => { + (ErrorCode::AccessViolation, "Permission denied".to_string()) + } + _ => (ErrorCode::NotDefined, e.to_string()), + }; + let error = Packet::error(code, msg); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + }; + + // Create server state for WRQ (server receives data) + let request = Packet::wrq(filename, mode); + let mut state = match ServerState::from_request(&request) { + Ok(s) => s, + Err(e) => { + let error = Packet::error(ErrorCode::IllegalOperation, e.to_string()); + let _ = socket.send(&error.serialize()?); + return Ok(()); + } + }; + + // Start transfer - send ACK 0 + let event = state.start(); + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut last_sent: Option> = None; + let mut retries_left = self.config.retries; + + if let Event::Send(packet) = event { + let bytes = packet.serialize()?; + socket.send(&bytes)?; + last_sent = Some(bytes); + } + + // Main transfer loop + loop { + let len = match socket.recv(&mut recv_buf) { + Ok(len) => len, + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + if retries_left == 0 { + // Clean up partial file on timeout + let _ = std::fs::remove_file(&path); + return Err(Error::Timeout { + retries: self.config.retries, + }); + } + retries_left -= 1; + if let Some(ref last) = last_sent { + socket.send(last)?; + } + continue; + } + Err(e) => { + let _ = std::fs::remove_file(&path); + return Err(e.into()); + } + }; + + let packet = Packet::parse(&recv_buf[..len])?; + + if let Packet::Error { code, message } = packet { + let _ = std::fs::remove_file(&path); + return Err(Error::Remote { code, message }); + } + + let events = state.receive(&packet)?; + + for event in events { + match event { + Event::ReceivedData { data, .. } => { + if let Err(e) = file.write_all(&data) { + let error = match e.kind() { + std::io::ErrorKind::StorageFull => { + Packet::error(ErrorCode::DiskFull, "Disk full") + } + _ => Packet::error(ErrorCode::NotDefined, e.to_string()), + }; + let _ = socket.send(&error.serialize()?); + let _ = std::fs::remove_file(&path); + return Ok(()); + } + } + Event::Send(packet) => { + let bytes = packet.serialize()?; + socket.send(&bytes)?; + last_sent = Some(bytes); + } + Event::Complete => { + file.flush()?; + return Ok(()); + } + // NeedData not expected in WRQ handling + Event::NeedData { .. } => {} + } + } + + if state.is_complete() { + file.flush()?; + return Ok(()); + } + + retries_left = self.config.retries; + } + } +} + +/// Builder for creating a server with custom options. +pub struct ServerBuilder { + config: ServerConfig, +} + +impl ServerBuilder { + /// Create a new server builder with the specified root directory. + #[must_use] + pub fn new(root_dir: impl Into) -> Self { + Self { + config: ServerConfig { + root_dir: root_dir.into(), + ..Default::default() + }, + } + } + + /// Allow write operations. + #[must_use] + pub const fn allow_write(mut self, allow: bool) -> Self { + self.config.allow_write = allow; + self + } + + /// Allow overwriting existing files. + #[must_use] + pub const fn allow_overwrite(mut self, allow: bool) -> Self { + self.config.allow_overwrite = allow; + self + } + + /// Set the timeout for operations. + #[must_use] + pub const fn timeout(mut self, timeout: Duration) -> Self { + self.config.timeout = timeout; + self + } + + /// Set the number of retries. + #[must_use] + pub const fn retries(mut self, retries: u32) -> Self { + self.config.retries = retries; + self + } + + /// Build the server bound to the specified address. + /// + /// # Errors + /// + /// Returns an error if the socket cannot be bound. + pub fn bind(self, addr: A) -> Result { + Server::with_config(addr, self.config) + } + + /// Build the server bound to the default TFTP port. + /// + /// # Errors + /// + /// Returns an error if the socket cannot be bound. + pub fn bind_default(self) -> Result { + self.bind(("0.0.0.0", TFTP_PORT)) + } +} diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index e7a11a9..0000000 --- a/src/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello, world!"); -}