feat: implement TFTP protocol crate (pfs-tftp-proto)

This commit introduces a sans-IO TFTP protocol implementation following
RFC 1350. The protocol crate provides:

- Packet types: RRQ, WRQ, DATA, ACK, ERROR with full serialization/parsing
- Error codes as defined in RFC 1350 Appendix
- Transfer modes: octet (binary) and netascii
- Client and server state machines for managing protocol flow
- Comprehensive tests for all packet types and state transitions

The sans-IO design separates protocol logic from I/O operations, making
the code testable and reusable across different I/O implementations.

Key design decisions:
- Mail mode is explicitly rejected as obsolete per RFC 1350
- Block numbers wrap around at 65535 for large file support
- State machines emit events that tell the I/O layer what to do next
- All protocol-specific values are documented with RFC citations

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-21 13:10:42 +01:00
parent 5e9be9d27f
commit fd8dace0dc
15 changed files with 2746 additions and 9 deletions

66
Cargo.lock generated
View File

@@ -5,3 +5,69 @@ version = 4
[[package]]
name = "pfs-tftp"
version = "0.1.0"
dependencies = [
"pfs-tftp-proto",
"thiserror",
]
[[package]]
name = "pfs-tftp-proto"
version = "0.1.0"
dependencies = [
"thiserror",
]
[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
[[package]]
name = "syn"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "2.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "2.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"

View File

@@ -1,14 +1,17 @@
[package]
name = "pfs-tftp"
[workspace]
resolver = "3"
members = ["crates/*"]
[workspace.package]
version = "0.1.0"
edition = "2024"
license = "MIT"
repository = "https://github.com/pfs/pfs-tftp"
[lints.rust]
[workspace.lints.rust]
unsafe_code = "forbid"
[lints.clippy]
[workspace.lints.clippy]
pedantic = { level = "warn", priority = -1 }
todo = "warn"
unwrap_used = "warn"
[dependencies]

View File

@@ -0,0 +1,15 @@
[package]
name = "pfs-tftp-proto"
description = "Sans-IO TFTP protocol implementation (RFC 1350)"
version.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
[lints]
workspace = true
[dependencies]
thiserror = "2"
[dev-dependencies]

View File

@@ -0,0 +1,193 @@
//! TFTP Error Types
//!
//! Defines error codes and error handling for TFTP protocol.
//!
//! # RFC 1350 Error Codes (Appendix)
//!
//! ```text
//! Value Meaning
//! 0 Not defined, see error message (if any).
//! 1 File not found.
//! 2 Access violation.
//! 3 Disk full or allocation exceeded.
//! 4 Illegal TFTP operation.
//! 5 Unknown transfer ID.
//! 6 File already exists.
//! 7 No such user.
//! ```
use thiserror::Error;
/// TFTP error codes as defined in RFC 1350 Appendix.
///
/// These codes are sent in ERROR packets to indicate the nature of the error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ErrorCode {
/// Not defined, see error message (if any).
/// Used when the error doesn't fit other categories.
NotDefined = 0,
/// File not found.
/// The requested file does not exist on the server.
FileNotFound = 1,
/// Access violation.
/// The client does not have permission to access the file.
AccessViolation = 2,
/// Disk full or allocation exceeded.
/// There is insufficient storage space to complete the transfer.
DiskFull = 3,
/// Illegal TFTP operation.
/// The requested operation is not valid (e.g., malformed packet).
IllegalOperation = 4,
/// Unknown transfer ID.
/// RFC 1350 Section 4: "If a source TID does not match, the packet should
/// be discarded as erroneously sent from somewhere else. An error packet
/// should be sent to the source of the incorrect packet."
UnknownTransferId = 5,
/// File already exists.
/// Returned when attempting to write a file that already exists
/// and the server policy doesn't allow overwriting.
FileAlreadyExists = 6,
/// No such user.
/// The specified user does not exist (relevant for mail mode, which is obsolete).
NoSuchUser = 7,
}
impl ErrorCode {
/// Create an `ErrorCode` from a raw u16 value.
///
/// Returns `NotDefined` for any unknown error code, as per RFC 1350
/// which states error code 0 is "Not defined".
#[must_use]
pub const fn from_u16(value: u16) -> Self {
match value {
1 => Self::FileNotFound,
2 => Self::AccessViolation,
3 => Self::DiskFull,
4 => Self::IllegalOperation,
5 => Self::UnknownTransferId,
6 => Self::FileAlreadyExists,
7 => Self::NoSuchUser,
// RFC 1350: Error code 0 is "Not defined", also used as fallback
_ => Self::NotDefined,
}
}
/// Get the default error message for this error code.
#[must_use]
pub const fn default_message(&self) -> &'static str {
match self {
Self::NotDefined => "Not defined",
Self::FileNotFound => "File not found",
Self::AccessViolation => "Access violation",
Self::DiskFull => "Disk full or allocation exceeded",
Self::IllegalOperation => "Illegal TFTP operation",
Self::UnknownTransferId => "Unknown transfer ID",
Self::FileAlreadyExists => "File already exists",
Self::NoSuchUser => "No such user",
}
}
}
impl From<ErrorCode> for u16 {
fn from(code: ErrorCode) -> Self {
code as Self
}
}
/// Protocol-level errors for TFTP operations.
#[derive(Debug, Error)]
pub enum Error {
/// Packet is too short to be valid.
#[error("packet too short: expected at least {expected} bytes, got {actual}")]
PacketTooShort { expected: usize, actual: usize },
/// Invalid opcode in packet.
#[error("invalid opcode: {0}")]
InvalidOpcode(u16),
/// Invalid transfer mode in RRQ/WRQ packet.
#[error("invalid mode: {0}")]
InvalidMode(String),
/// Missing null terminator in string field.
#[error("missing null terminator in {field}")]
MissingNullTerminator { field: &'static str },
/// String field contains invalid UTF-8.
#[error("invalid UTF-8 in {field}: {source}")]
InvalidUtf8 {
field: &'static str,
#[source]
source: std::str::Utf8Error,
},
/// Data exceeds maximum block size.
#[error("data too large: {size} bytes exceeds maximum of 512")]
DataTooLarge { size: usize },
/// Buffer too small for serialization.
#[error("buffer too small: need {needed} bytes, have {available}")]
BufferTooSmall { needed: usize, available: usize },
/// TFTP error received from remote peer.
#[error("TFTP error {code:?}: {message}")]
TftpError { code: ErrorCode, message: String },
/// Protocol state error.
#[error("protocol state error: {0}")]
StateError(String),
/// Unexpected packet type received.
#[error("unexpected packet type: expected {expected}, got {actual}")]
UnexpectedPacket {
expected: &'static str,
actual: &'static str,
},
/// Block number mismatch.
#[error("block number mismatch: expected {expected}, got {actual}")]
BlockMismatch { expected: u16, actual: u16 },
}
/// Result type for TFTP operations.
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_code_from_u16() {
assert_eq!(ErrorCode::from_u16(0), ErrorCode::NotDefined);
assert_eq!(ErrorCode::from_u16(1), ErrorCode::FileNotFound);
assert_eq!(ErrorCode::from_u16(2), ErrorCode::AccessViolation);
assert_eq!(ErrorCode::from_u16(3), ErrorCode::DiskFull);
assert_eq!(ErrorCode::from_u16(4), ErrorCode::IllegalOperation);
assert_eq!(ErrorCode::from_u16(5), ErrorCode::UnknownTransferId);
assert_eq!(ErrorCode::from_u16(6), ErrorCode::FileAlreadyExists);
assert_eq!(ErrorCode::from_u16(7), ErrorCode::NoSuchUser);
// Unknown codes should map to NotDefined
assert_eq!(ErrorCode::from_u16(8), ErrorCode::NotDefined);
assert_eq!(ErrorCode::from_u16(255), ErrorCode::NotDefined);
}
#[test]
fn test_error_code_to_u16() {
assert_eq!(u16::from(ErrorCode::NotDefined), 0);
assert_eq!(u16::from(ErrorCode::FileNotFound), 1);
assert_eq!(u16::from(ErrorCode::AccessViolation), 2);
assert_eq!(u16::from(ErrorCode::DiskFull), 3);
assert_eq!(u16::from(ErrorCode::IllegalOperation), 4);
assert_eq!(u16::from(ErrorCode::UnknownTransferId), 5);
assert_eq!(u16::from(ErrorCode::FileAlreadyExists), 6);
assert_eq!(u16::from(ErrorCode::NoSuchUser), 7);
}
}

View File

@@ -0,0 +1,21 @@
//! Sans-IO TFTP Protocol Implementation (RFC 1350)
//!
//! This crate provides a pure protocol implementation without any I/O operations.
//! It handles packet parsing, serialization, and protocol state management.
//!
//! # RFC 1350 Overview
//!
//! TFTP (Trivial File Transfer Protocol) is a simple file transfer protocol
//! that operates over UDP. Key characteristics:
//! - Uses fixed 512-byte data blocks
//! - Stop-and-wait protocol (each packet must be acknowledged)
//! - Five packet types: RRQ, WRQ, DATA, ACK, ERROR
//! - Default port 69 for initial server connection
mod error;
mod packet;
mod state;
pub use error::{Error, ErrorCode, Result};
pub use packet::{Mode, Opcode, Packet, MAX_DATA_SIZE, TFTP_PORT};
pub use state::{ClientState, Event, ServerState, TransferDirection};

View File

@@ -0,0 +1,632 @@
//! TFTP Packet Types and Serialization
//!
//! This module implements all five TFTP packet types as defined in RFC 1350:
//! - RRQ (Read Request)
//! - WRQ (Write Request)
//! - DATA
//! - ACK
//! - ERROR
//!
//! # RFC 1350 Section 5 - TFTP Packets
//!
//! ```text
//! TFTP supports five types of packets:
//!
//! opcode operation
//! 1 Read request (RRQ)
//! 2 Write request (WRQ)
//! 3 Data (DATA)
//! 4 Acknowledgment (ACK)
//! 5 Error (ERROR)
//! ```
use crate::{Error, ErrorCode, Result};
/// Standard TFTP server port as defined in RFC 1350.
///
/// RFC 1350 Section 4: "A requesting host chooses its source TID as described
/// above, and sends its initial request to the known TID 69 decimal (105 octal)
/// on the serving host."
pub const TFTP_PORT: u16 = 69;
/// Maximum data size in a DATA packet.
///
/// RFC 1350 Section 2: "the connection is opened and the file is sent in fixed
/// length blocks of 512 bytes."
///
/// RFC 1350 Section 5: "The data field is from zero to 512 bytes long."
pub const MAX_DATA_SIZE: usize = 512;
/// Minimum packet size (just an opcode).
const MIN_PACKET_SIZE: usize = 2;
/// TFTP packet opcodes.
///
/// RFC 1350 Section 5: "The TFTP header of a packet contains the opcode
/// associated with that packet."
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum Opcode {
/// Read request (RRQ)
Rrq = 1,
/// Write request (WRQ)
Wrq = 2,
/// Data packet
Data = 3,
/// Acknowledgment
Ack = 4,
/// Error
Error = 5,
}
impl Opcode {
/// Create an `Opcode` from a raw u16 value.
///
/// # Errors
///
/// Returns `Error::InvalidOpcode` if the value is not a valid opcode.
pub const fn from_u16(value: u16) -> Result<Self> {
match value {
1 => Ok(Self::Rrq),
2 => Ok(Self::Wrq),
3 => Ok(Self::Data),
4 => Ok(Self::Ack),
5 => Ok(Self::Error),
_ => Err(Error::InvalidOpcode(value)),
}
}
}
impl From<Opcode> for u16 {
fn from(opcode: Opcode) -> Self {
opcode as Self
}
}
/// Transfer mode for TFTP operations.
///
/// RFC 1350 Section 1: "Three modes of transfer are currently supported:
/// netascii [...]; octet [...]; mail [...]. (The mail mode is obsolete
/// and should not be implemented or used.)"
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Mode {
/// `NetASCII` mode - text transfer with CR/LF line endings.
///
/// RFC 1350: "netascii (This is ascii as defined in 'USA Standard Code
/// for Information Interchange' with the modifications specified in
/// 'Telnet Protocol Specification'.) Note that it is 8 bit ascii."
NetAscii,
/// Octet mode - raw binary transfer.
///
/// RFC 1350: "octet (This replaces the 'binary' mode of previous versions
/// of this document.) raw 8 bit bytes [...] If a host receives a octet
/// file and then returns it, the returned file must be identical to the
/// original."
#[default]
Octet,
}
impl Mode {
/// Parse a mode string (case-insensitive).
///
/// RFC 1350 Section 5: "The mode field contains the string 'netascii',
/// 'octet', or 'mail' (or any combination of upper and lower case,
/// such as 'NETASCII', '`NetAscii`', etc.)"
///
/// # Errors
///
/// Returns `Error::InvalidMode` if the string is not a recognized mode.
pub fn parse(s: &str) -> Result<Self> {
match s.to_ascii_lowercase().as_str() {
"netascii" => Ok(Self::NetAscii),
"octet" => Ok(Self::Octet),
// Mail mode is obsolete per RFC 1350
"mail" => Err(Error::InvalidMode(
"mail mode is obsolete and not supported".to_string(),
)),
_ => Err(Error::InvalidMode(s.to_string())),
}
}
/// Get the string representation of the mode.
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::NetAscii => "netascii",
Self::Octet => "octet",
}
}
}
/// A TFTP packet.
///
/// All packet formats are defined in RFC 1350 Section 5 and Appendix I.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
/// Read Request (RRQ) - opcode 1
///
/// RFC 1350 Figure 5-1:
/// ```text
/// 2 bytes string 1 byte string 1 byte
/// ------------------------------------------------
/// | Opcode | Filename | 0 | Mode | 0 |
/// ------------------------------------------------
/// ```
ReadRequest { filename: String, mode: Mode },
/// Write Request (WRQ) - opcode 2
///
/// Same format as RRQ (RFC 1350 Figure 5-1).
WriteRequest { filename: String, mode: Mode },
/// Data packet - opcode 3
///
/// RFC 1350 Figure 5-2:
/// ```text
/// 2 bytes 2 bytes n bytes
/// ----------------------------------
/// | Opcode | Block # | Data |
/// ----------------------------------
/// ```
///
/// RFC 1350 Section 5: "The block numbers on data packets begin with one
/// and increase by one for each new block of data. [...] The data field
/// is from zero to 512 bytes long. If it is 512 bytes long, the block is
/// not the last block of data; if it is from zero to 511 bytes long, it
/// signals the end of the transfer."
Data { block: u16, data: Vec<u8> },
/// Acknowledgment packet - opcode 4
///
/// RFC 1350 Figure 5-3:
/// ```text
/// 2 bytes 2 bytes
/// ---------------------
/// | Opcode | Block # |
/// ---------------------
/// ```
///
/// RFC 1350 Section 5: "The block number in an ACK echoes the block number
/// of the DATA packet being acknowledged. A WRQ is acknowledged with an
/// ACK packet having a block number of zero."
Ack { block: u16 },
/// Error packet - opcode 5
///
/// RFC 1350 Figure 5-4:
/// ```text
/// 2 bytes 2 bytes string 1 byte
/// -----------------------------------------
/// | Opcode | ErrorCode | ErrMsg | 0 |
/// -----------------------------------------
/// ```
///
/// RFC 1350 Section 5: "An ERROR packet can be the acknowledgment of any
/// other type of packet. The error code is an integer indicating the nature
/// of the error. [...] The error message is intended for human consumption,
/// and should be in netascii."
Error { code: ErrorCode, message: String },
}
impl Packet {
/// Parse a TFTP packet from raw bytes.
///
/// # Errors
///
/// Returns an error if the packet is malformed.
pub fn parse(data: &[u8]) -> Result<Self> {
if data.len() < MIN_PACKET_SIZE {
return Err(Error::PacketTooShort {
expected: MIN_PACKET_SIZE,
actual: data.len(),
});
}
let opcode = u16::from_be_bytes([data[0], data[1]]);
let opcode = Opcode::from_u16(opcode)?;
let payload = &data[2..];
match opcode {
Opcode::Rrq => Self::parse_request(payload, false),
Opcode::Wrq => Self::parse_request(payload, true),
Opcode::Data => Self::parse_data(payload),
Opcode::Ack => Self::parse_ack(payload),
Opcode::Error => Self::parse_error(payload),
}
}
/// Parse a RRQ or WRQ packet payload.
fn parse_request(payload: &[u8], is_write: bool) -> Result<Self> {
// Find the null terminator for filename
let filename_end = payload
.iter()
.position(|&b| b == 0)
.ok_or(Error::MissingNullTerminator { field: "filename" })?;
let filename = std::str::from_utf8(&payload[..filename_end]).map_err(|e| {
Error::InvalidUtf8 {
field: "filename",
source: e,
}
})?;
// Find the mode string after the filename null terminator
let mode_start = filename_end + 1;
if mode_start >= payload.len() {
return Err(Error::MissingNullTerminator { field: "mode" });
}
let mode_end = payload[mode_start..]
.iter()
.position(|&b| b == 0)
.ok_or(Error::MissingNullTerminator { field: "mode" })?
+ mode_start;
let mode_str =
std::str::from_utf8(&payload[mode_start..mode_end]).map_err(|e| Error::InvalidUtf8 {
field: "mode",
source: e,
})?;
let mode = Mode::parse(mode_str)?;
if is_write {
Ok(Self::WriteRequest {
filename: filename.to_string(),
mode,
})
} else {
Ok(Self::ReadRequest {
filename: filename.to_string(),
mode,
})
}
}
/// Parse a DATA packet payload.
fn parse_data(payload: &[u8]) -> Result<Self> {
if payload.len() < 2 {
return Err(Error::PacketTooShort {
expected: 4, // 2 opcode + 2 block
actual: payload.len() + 2,
});
}
let block = u16::from_be_bytes([payload[0], payload[1]]);
let data = payload[2..].to_vec();
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
Ok(Self::Data { block, data })
}
/// Parse an ACK packet payload.
fn parse_ack(payload: &[u8]) -> Result<Self> {
if payload.len() < 2 {
return Err(Error::PacketTooShort {
expected: 4, // 2 opcode + 2 block
actual: payload.len() + 2,
});
}
let block = u16::from_be_bytes([payload[0], payload[1]]);
Ok(Self::Ack { block })
}
/// Parse an ERROR packet payload.
fn parse_error(payload: &[u8]) -> Result<Self> {
if payload.len() < 2 {
return Err(Error::PacketTooShort {
expected: 4, // 2 opcode + 2 error code
actual: payload.len() + 2,
});
}
let code = u16::from_be_bytes([payload[0], payload[1]]);
let code = ErrorCode::from_u16(code);
// Find the null terminator for error message
let msg_start = 2;
let msg_end = payload[msg_start..]
.iter()
.position(|&b| b == 0)
.map_or(payload.len(), |pos| pos + msg_start);
let message = std::str::from_utf8(&payload[msg_start..msg_end])
.map_err(|e| Error::InvalidUtf8 {
field: "error message",
source: e,
})?
.to_string();
Ok(Self::Error { code, message })
}
/// Serialize the packet into bytes.
///
/// # Errors
///
/// Returns an error if the data is too large.
pub fn serialize(&self) -> Result<Vec<u8>> {
match self {
Self::ReadRequest { filename, mode } => {
Ok(Self::serialize_request(Opcode::Rrq, filename, *mode))
}
Self::WriteRequest { filename, mode } => {
Ok(Self::serialize_request(Opcode::Wrq, filename, *mode))
}
Self::Data { block, data } => Self::serialize_data(*block, data),
Self::Ack { block } => Ok(Self::serialize_ack(*block)),
Self::Error { code, message } => Ok(Self::serialize_error(*code, message)),
}
}
/// Serialize a RRQ or WRQ packet.
fn serialize_request(opcode: Opcode, filename: &str, mode: Mode) -> Vec<u8> {
let mode_str = mode.as_str();
let capacity = 2 + filename.len() + 1 + mode_str.len() + 1;
let mut buf = Vec::with_capacity(capacity);
buf.extend_from_slice(&u16::from(opcode).to_be_bytes());
buf.extend_from_slice(filename.as_bytes());
buf.push(0);
buf.extend_from_slice(mode_str.as_bytes());
buf.push(0);
buf
}
/// Serialize a DATA packet.
fn serialize_data(block: u16, data: &[u8]) -> Result<Vec<u8>> {
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
let mut buf = Vec::with_capacity(4 + data.len());
buf.extend_from_slice(&u16::from(Opcode::Data).to_be_bytes());
buf.extend_from_slice(&block.to_be_bytes());
buf.extend_from_slice(data);
Ok(buf)
}
/// Serialize an ACK packet.
fn serialize_ack(block: u16) -> Vec<u8> {
let mut buf = Vec::with_capacity(4);
buf.extend_from_slice(&u16::from(Opcode::Ack).to_be_bytes());
buf.extend_from_slice(&block.to_be_bytes());
buf
}
/// Serialize an ERROR packet.
fn serialize_error(code: ErrorCode, message: &str) -> Vec<u8> {
let capacity = 4 + message.len() + 1;
let mut buf = Vec::with_capacity(capacity);
buf.extend_from_slice(&u16::from(Opcode::Error).to_be_bytes());
buf.extend_from_slice(&u16::from(code).to_be_bytes());
buf.extend_from_slice(message.as_bytes());
buf.push(0);
buf
}
/// Get the opcode of this packet.
#[must_use]
pub const fn opcode(&self) -> Opcode {
match self {
Self::ReadRequest { .. } => Opcode::Rrq,
Self::WriteRequest { .. } => Opcode::Wrq,
Self::Data { .. } => Opcode::Data,
Self::Ack { .. } => Opcode::Ack,
Self::Error { .. } => Opcode::Error,
}
}
/// Get a human-readable name for the packet type.
#[must_use]
pub const fn type_name(&self) -> &'static str {
match self {
Self::ReadRequest { .. } => "RRQ",
Self::WriteRequest { .. } => "WRQ",
Self::Data { .. } => "DATA",
Self::Ack { .. } => "ACK",
Self::Error { .. } => "ERROR",
}
}
/// Check if this is the final data packet (less than 512 bytes).
///
/// RFC 1350 Section 5: "If it is 512 bytes long, the block is not the last
/// block of data; if it is from zero to 511 bytes long, it signals the end
/// of the transfer."
#[must_use]
pub const fn is_final_data(&self) -> bool {
matches!(self, Self::Data { data, .. } if data.len() < MAX_DATA_SIZE)
}
/// Create a new RRQ packet.
#[must_use]
pub fn rrq(filename: impl Into<String>, mode: Mode) -> Self {
Self::ReadRequest {
filename: filename.into(),
mode,
}
}
/// Create a new WRQ packet.
#[must_use]
pub fn wrq(filename: impl Into<String>, mode: Mode) -> Self {
Self::WriteRequest {
filename: filename.into(),
mode,
}
}
/// Create a new DATA packet.
#[must_use]
pub fn data(block: u16, data: Vec<u8>) -> Self {
Self::Data { block, data }
}
/// Create a new ACK packet.
#[must_use]
pub const fn ack(block: u16) -> Self {
Self::Ack { block }
}
/// Create a new ERROR packet.
#[must_use]
pub fn error(code: ErrorCode, message: impl Into<String>) -> Self {
Self::Error {
code,
message: message.into(),
}
}
/// Create an ERROR packet with default message for the error code.
#[must_use]
pub fn error_default(code: ErrorCode) -> Self {
Self::Error {
code,
message: code.default_message().to_string(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rrq_roundtrip() {
let packet = Packet::rrq("test.txt", Mode::Octet);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_wrq_roundtrip() {
let packet = Packet::wrq("output.bin", Mode::NetAscii);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_data_roundtrip() {
let packet = Packet::data(1, vec![1, 2, 3, 4, 5]);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_data_empty() {
let packet = Packet::data(42, vec![]);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
assert!(packet.is_final_data());
}
#[test]
fn test_data_max_size() {
let packet = Packet::data(100, vec![0xAB; MAX_DATA_SIZE]);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
assert!(!packet.is_final_data());
}
#[test]
fn test_data_too_large() {
let packet = Packet::data(1, vec![0; MAX_DATA_SIZE + 1]);
let result = packet.serialize();
assert!(matches!(result, Err(Error::DataTooLarge { .. })));
}
#[test]
fn test_ack_roundtrip() {
let packet = Packet::ack(0);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
let packet = Packet::ack(65535);
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_error_roundtrip() {
let packet = Packet::error(ErrorCode::FileNotFound, "File not found");
let bytes = packet.serialize().expect("serialize failed");
let parsed = Packet::parse(&bytes).expect("parse failed");
assert_eq!(packet, parsed);
}
#[test]
fn test_error_default_message() {
let packet = Packet::error_default(ErrorCode::AccessViolation);
assert!(matches!(
packet,
Packet::Error {
code: ErrorCode::AccessViolation,
message
} if message == "Access violation"
));
}
#[test]
fn test_mode_case_insensitive() {
assert_eq!(Mode::parse("OCTET").expect("parse"), Mode::Octet);
assert_eq!(Mode::parse("Octet").expect("parse"), Mode::Octet);
assert_eq!(Mode::parse("netascii").expect("parse"), Mode::NetAscii);
assert_eq!(Mode::parse("NETASCII").expect("parse"), Mode::NetAscii);
assert_eq!(Mode::parse("NetAscii").expect("parse"), Mode::NetAscii);
}
#[test]
fn test_invalid_mode() {
assert!(matches!(
Mode::parse("binary"),
Err(Error::InvalidMode(_))
));
assert!(matches!(Mode::parse("mail"), Err(Error::InvalidMode(_))));
}
#[test]
fn test_invalid_opcode() {
let data = [0x00, 0x06]; // opcode 6 is invalid
let result = Packet::parse(&data);
assert!(matches!(result, Err(Error::InvalidOpcode(6))));
}
#[test]
fn test_packet_too_short() {
let data = [0x00];
let result = Packet::parse(&data);
assert!(matches!(result, Err(Error::PacketTooShort { .. })));
}
#[test]
fn test_rrq_parse_raw() {
// RRQ for "test.txt" in octet mode
let data = [
0x00, 0x01, // opcode = RRQ
b't', b'e', b's', b't', b'.', b't', b'x', b't', 0x00, // filename
b'o', b'c', b't', b'e', b't', 0x00, // mode
];
let packet = Packet::parse(&data).expect("parse failed");
assert!(matches!(
packet,
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
));
}
}

View File

@@ -0,0 +1,661 @@
//! TFTP Protocol State Machine
//!
//! This module implements sans-io state machines for both client and server
//! sides of the TFTP protocol.
//!
//! # RFC 1350 Protocol Overview
//!
//! The protocol follows a lock-step acknowledgment pattern:
//!
//! ## Read Transfer (Client reading from Server)
//! ```text
//! Client -> Server: RRQ (filename, mode)
//! Server -> Client: DATA (block 1)
//! Client -> Server: ACK (block 1)
//! Server -> Client: DATA (block 2)
//! ... (continues until final DATA with < 512 bytes)
//! ```
//!
//! ## Write Transfer (Client writing to Server)
//! ```text
//! Client -> Server: WRQ (filename, mode)
//! Server -> Client: ACK (block 0)
//! Client -> Server: DATA (block 1)
//! Server -> Client: ACK (block 1)
//! ... (continues until final DATA with < 512 bytes)
//! ```
use crate::{Error, Mode, Packet, Result, MAX_DATA_SIZE};
/// Direction of data transfer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
/// Reading data from remote (client downloads, server uploads)
Read,
/// Writing data to remote (client uploads, server downloads)
Write,
}
/// Events emitted by the state machine.
///
/// These events tell the I/O layer what action to take next.
/// Errors are returned via `Result` from the state machine methods,
/// not as events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
/// Send this packet to the remote peer.
Send(Packet),
/// Request data from the application for the given block number.
///
/// The application should provide up to 512 bytes of data.
/// A response with less than 512 bytes indicates the final block.
NeedData { block: u16 },
/// Received data from remote peer.
///
/// The application should store this data.
ReceivedData { block: u16, data: Vec<u8> },
/// Transfer completed successfully.
Complete,
}
/// Client-side protocol state machine.
///
/// Handles the client side of TFTP transfers (both read and write).
#[derive(Debug)]
pub struct ClientState {
/// The direction of the transfer
direction: TransferDirection,
/// Transfer mode
mode: Mode,
/// Current block number we're expecting/sending
current_block: u16,
/// Whether the transfer is complete
complete: bool,
/// Whether we've received the initial response
got_initial_response: bool,
}
impl ClientState {
/// Create a new client state for a read request.
///
/// Call `start()` to get the initial RRQ packet to send.
#[must_use]
pub const fn new_read(mode: Mode) -> Self {
Self {
direction: TransferDirection::Read,
mode,
current_block: 1, // Expecting DATA block 1
complete: false,
got_initial_response: false,
}
}
/// Create a new client state for a write request.
///
/// Call `start()` to get the initial WRQ packet to send.
#[must_use]
pub const fn new_write(mode: Mode) -> Self {
Self {
direction: TransferDirection::Write,
mode,
current_block: 0, // Expecting ACK block 0
complete: false,
got_initial_response: false,
}
}
/// Get the initial request packet to send.
#[must_use]
pub fn start(&self, filename: &str) -> Packet {
match self.direction {
TransferDirection::Read => Packet::rrq(filename, self.mode),
TransferDirection::Write => Packet::wrq(filename, self.mode),
}
}
/// Check if the transfer is complete.
#[must_use]
pub const fn is_complete(&self) -> bool {
self.complete
}
/// Get the current expected block number.
#[must_use]
pub const fn current_block(&self) -> u16 {
self.current_block
}
/// Process a received packet and return events.
///
/// # Errors
///
/// Returns an error if the packet is invalid for the current state.
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
if self.complete {
return Err(Error::StateError("transfer already complete".to_string()));
}
match (self.direction, packet) {
// Reading: expect DATA packets
(TransferDirection::Read, Packet::Data { block, data }) => {
self.handle_read_data(*block, data)
}
// Writing: expect ACK packets
(TransferDirection::Write, Packet::Ack { block }) => self.handle_write_ack(*block),
// Error packet terminates transfer
(_, Packet::Error { code, message }) => {
self.complete = true;
Err(Error::TftpError {
code: *code,
message: message.clone(),
})
}
// Unexpected packet type
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
expected: "DATA",
actual: other.type_name(),
}),
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
expected: "ACK",
actual: other.type_name(),
}),
}
}
/// Handle a DATA packet during a read transfer.
fn handle_read_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
self.got_initial_response = true;
// RFC 1350: Block numbers start at 1 and increase sequentially
if block != self.current_block {
// Could be a duplicate - if it's the previous block, re-send ACK
if block == self.current_block.wrapping_sub(1) {
return Ok(vec![Event::Send(Packet::ack(block))]);
}
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
let is_final = data.len() < MAX_DATA_SIZE;
let mut events = vec![
Event::ReceivedData {
block,
data: data.to_vec(),
},
Event::Send(Packet::ack(block)),
];
if is_final {
self.complete = true;
events.push(Event::Complete);
} else {
self.current_block = self.current_block.wrapping_add(1);
}
Ok(events)
}
/// Handle an ACK packet during a write transfer.
fn handle_write_ack(&mut self, block: u16) -> Result<Vec<Event>> {
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet having
// a block number of zero."
if !self.got_initial_response {
if block != 0 {
return Err(Error::BlockMismatch {
expected: 0,
actual: block,
});
}
self.got_initial_response = true;
self.current_block = 1;
return Ok(vec![Event::NeedData { block: 1 }]);
}
if block != self.current_block {
// Duplicate ACK - ignore
if block == self.current_block.wrapping_sub(1) {
return Ok(vec![]);
}
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
// Request next block of data
self.current_block = self.current_block.wrapping_add(1);
Ok(vec![Event::NeedData {
block: self.current_block,
}])
}
/// Provide data for a write transfer.
///
/// Call this in response to a `NeedData` event.
///
/// # Arguments
///
/// * `block` - The block number this data is for
/// * `data` - The data to send (must be <= 512 bytes)
///
/// # Returns
///
/// Returns the DATA packet to send, or marks the transfer complete if
/// this was the final block.
///
/// # Errors
///
/// Returns an error if the data is too large or the block number is wrong.
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
if block != self.current_block {
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
let is_final = data.len() < MAX_DATA_SIZE;
if is_final {
self.complete = true;
}
Ok(Event::Send(Packet::data(block, data)))
}
/// Mark the transfer as complete after the final ACK is received.
pub fn finish(&mut self) -> Event {
self.complete = true;
Event::Complete
}
}
/// Server-side protocol state machine.
///
/// Handles the server side of TFTP transfers.
#[derive(Debug)]
pub struct ServerState {
/// The direction of the transfer (from server's perspective)
direction: TransferDirection,
/// Transfer mode
mode: Mode,
/// The filename being transferred
filename: String,
/// Current block number
current_block: u16,
/// Whether the transfer is complete
complete: bool,
/// Whether we've sent the first response
sent_initial_response: bool,
/// Track if we're waiting for final ACK
waiting_for_final_ack: bool,
}
impl ServerState {
/// Create a new server state from a request packet.
///
/// # Errors
///
/// Returns an error if the packet is not an RRQ or WRQ.
pub fn from_request(packet: &Packet) -> Result<Self> {
match packet {
Packet::ReadRequest { filename, mode } => Ok(Self {
// Server uploads when client reads
direction: TransferDirection::Write,
mode: *mode,
filename: filename.clone(),
current_block: 1,
complete: false,
sent_initial_response: false,
waiting_for_final_ack: false,
}),
Packet::WriteRequest { filename, mode } => Ok(Self {
// Server downloads when client writes
direction: TransferDirection::Read,
mode: *mode,
filename: filename.clone(),
current_block: 0,
complete: false,
sent_initial_response: false,
waiting_for_final_ack: false,
}),
other => Err(Error::UnexpectedPacket {
expected: "RRQ or WRQ",
actual: other.type_name(),
}),
}
}
/// Get the filename being transferred.
#[must_use]
pub fn filename(&self) -> &str {
&self.filename
}
/// Get the transfer mode.
#[must_use]
pub const fn mode(&self) -> Mode {
self.mode
}
/// Get the transfer direction (from server's perspective).
#[must_use]
pub const fn direction(&self) -> TransferDirection {
self.direction
}
/// Check if the transfer is complete.
#[must_use]
pub const fn is_complete(&self) -> bool {
self.complete
}
/// Get the initial response event.
///
/// For RRQ: Returns `NeedData` to request the first block
/// For WRQ: Returns `Send(ACK 0)` to acknowledge the write request
pub fn start(&mut self) -> Event {
self.sent_initial_response = true;
match self.direction {
// Server sending data (client RRQ)
TransferDirection::Write => Event::NeedData { block: 1 },
// Server receiving data (client WRQ)
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet
// having a block number of zero."
TransferDirection::Read => Event::Send(Packet::ack(0)),
}
}
/// Process a received packet and return events.
///
/// # Errors
///
/// Returns an error if the packet is invalid for the current state.
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
if self.complete {
return Err(Error::StateError("transfer already complete".to_string()));
}
match (self.direction, packet) {
// Server sending data, receiving ACKs
(TransferDirection::Write, Packet::Ack { block }) => self.handle_send_ack(*block),
// Server receiving data
(TransferDirection::Read, Packet::Data { block, data }) => {
self.handle_receive_data(*block, data)
}
// Error terminates transfer
(_, Packet::Error { code, message }) => {
self.complete = true;
Err(Error::TftpError {
code: *code,
message: message.clone(),
})
}
// Unexpected packet
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
expected: "ACK",
actual: other.type_name(),
}),
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
expected: "DATA",
actual: other.type_name(),
}),
}
}
/// Handle an ACK when server is sending data.
fn handle_send_ack(&mut self, block: u16) -> Result<Vec<Event>> {
if block != self.current_block {
// Duplicate ACK - could retransmit, but we'll just ignore
if block == self.current_block.wrapping_sub(1) {
return Ok(vec![]);
}
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
// If we were waiting for the final ACK, we're done
if self.waiting_for_final_ack {
self.complete = true;
return Ok(vec![Event::Complete]);
}
// Request next block
self.current_block = self.current_block.wrapping_add(1);
Ok(vec![Event::NeedData {
block: self.current_block,
}])
}
/// Handle a DATA packet when server is receiving data.
fn handle_receive_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
let expected_block = self.current_block.wrapping_add(1);
if block != expected_block {
// Duplicate - re-ACK
if block == self.current_block {
return Ok(vec![Event::Send(Packet::ack(block))]);
}
return Err(Error::BlockMismatch {
expected: expected_block,
actual: block,
});
}
self.current_block = block;
let is_final = data.len() < MAX_DATA_SIZE;
let mut events = vec![
Event::ReceivedData {
block,
data: data.to_vec(),
},
Event::Send(Packet::ack(block)),
];
if is_final {
self.complete = true;
events.push(Event::Complete);
}
Ok(events)
}
/// Provide data for a read transfer (server sending to client).
///
/// # Errors
///
/// Returns an error if the data is too large.
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
if data.len() > MAX_DATA_SIZE {
return Err(Error::DataTooLarge { size: data.len() });
}
if block != self.current_block {
return Err(Error::BlockMismatch {
expected: self.current_block,
actual: block,
});
}
let is_final = data.len() < MAX_DATA_SIZE;
if is_final {
self.waiting_for_final_ack = true;
}
Ok(Event::Send(Packet::data(block, data)))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ErrorCode;
#[test]
fn test_client_read_single_block() {
let mut client = ClientState::new_read(Mode::Octet);
let request = client.start("test.txt");
assert!(matches!(
request,
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
));
// Receive single DATA block (< 512 bytes = final)
let events = client
.receive(&Packet::data(1, vec![1, 2, 3]))
.expect("receive failed");
assert_eq!(events.len(), 3);
assert!(matches!(&events[0], Event::ReceivedData { block: 1, data } if data == &[1, 2, 3]));
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
assert!(matches!(&events[2], Event::Complete));
assert!(client.is_complete());
}
#[test]
fn test_client_read_multiple_blocks() {
let mut client = ClientState::new_read(Mode::Octet);
let _ = client.start("test.txt");
// Full block (512 bytes)
let full_data = vec![0u8; MAX_DATA_SIZE];
let events = client
.receive(&Packet::data(1, full_data.clone()))
.expect("receive");
assert_eq!(events.len(), 2);
assert!(!client.is_complete());
assert_eq!(client.current_block(), 2);
// Another full block
let events = client
.receive(&Packet::data(2, full_data))
.expect("receive");
assert_eq!(events.len(), 2);
assert!(!client.is_complete());
// Final block (< 512 bytes)
let events = client
.receive(&Packet::data(3, vec![1, 2, 3]))
.expect("receive");
assert_eq!(events.len(), 3);
assert!(client.is_complete());
}
#[test]
fn test_client_write_flow() {
let mut client = ClientState::new_write(Mode::Octet);
let request = client.start("output.txt");
assert!(matches!(request, Packet::WriteRequest { .. }));
// Receive ACK 0 (write request acknowledged)
let events = client.receive(&Packet::ack(0)).expect("receive");
assert_eq!(events.len(), 1);
assert!(matches!(events[0], Event::NeedData { block: 1 }));
// Provide data for block 1
let send_event = client.provide_data(1, vec![1, 2, 3]).expect("provide_data");
assert!(matches!(
send_event,
Event::Send(Packet::Data { block: 1, .. })
));
assert!(client.is_complete()); // < 512 bytes = final
}
#[test]
fn test_server_rrq() {
let request = Packet::rrq("file.txt", Mode::Octet);
let mut server = ServerState::from_request(&request).expect("from_request");
assert_eq!(server.filename(), "file.txt");
assert_eq!(server.direction(), TransferDirection::Write);
// Start should request data for block 1
let event = server.start();
assert!(matches!(event, Event::NeedData { block: 1 }));
// Provide data
let event = server.provide_data(1, vec![1, 2, 3]).expect("provide_data");
assert!(matches!(event, Event::Send(Packet::Data { block: 1, .. })));
// Receive ACK
let events = server.receive(&Packet::ack(1)).expect("receive");
assert!(matches!(&events[0], Event::Complete));
assert!(server.is_complete());
}
#[test]
fn test_server_wrq() {
let request = Packet::wrq("output.txt", Mode::Octet);
let mut server = ServerState::from_request(&request).expect("from_request");
assert_eq!(server.direction(), TransferDirection::Read);
// Start should send ACK 0
let event = server.start();
assert!(matches!(event, Event::Send(Packet::Ack { block: 0 })));
// Receive DATA block 1
let events = server
.receive(&Packet::data(1, vec![1, 2, 3]))
.expect("receive");
assert_eq!(events.len(), 3);
assert!(matches!(&events[0], Event::ReceivedData { block: 1, .. }));
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
assert!(matches!(&events[2], Event::Complete));
}
#[test]
fn test_error_terminates_transfer() {
let mut client = ClientState::new_read(Mode::Octet);
let _ = client.start("test.txt");
let result = client.receive(&Packet::error(ErrorCode::FileNotFound, "File not found"));
assert!(result.is_err());
assert!(client.is_complete());
}
#[test]
fn test_duplicate_ack_ignored() {
let request = Packet::rrq("file.txt", Mode::Octet);
let mut server = ServerState::from_request(&request).expect("from_request");
let _ = server.start();
// Provide and send block 1
let _ = server
.provide_data(1, vec![0u8; MAX_DATA_SIZE])
.expect("provide_data");
// Receive ACK 1
let events = server.receive(&Packet::ack(1)).expect("receive");
assert!(!events.is_empty());
// Provide block 2
let _ = server.provide_data(2, vec![1, 2, 3]).expect("provide_data");
// Duplicate ACK 1 should be ignored
let events = server.receive(&Packet::ack(1)).expect("receive");
assert!(events.is_empty());
}
}

View File

@@ -0,0 +1,22 @@
[package]
name = "pfs-tftp"
description = "TFTP client and server library (RFC 1350)"
version.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
[lints]
workspace = true
[dependencies]
pfs-tftp-proto = { path = "../pfs-tftp-proto" }
thiserror = "2"
[[bin]]
name = "tftpd"
path = "src/bin/tftpd.rs"
[[bin]]
name = "tftp"
path = "src/bin/tftp.rs"

View File

@@ -0,0 +1,3 @@
fn main() {
println!("tftp client placeholder");
}

View File

@@ -0,0 +1,3 @@
fn main() {
println!("tftpd server placeholder");
}

View File

@@ -0,0 +1,479 @@
//! TFTP Client implementation.
//!
//! Provides synchronous TFTP client operations for reading and writing files
//! to/from a TFTP server.
// The expect() calls in this module cannot panic in practice because we always
// set the server_tid before using it, but clippy doesn't know this.
#![allow(clippy::missing_panics_doc)]
use std::{
io::{Read, Write},
net::{SocketAddr, ToSocketAddrs, UdpSocket},
time::Duration,
};
use pfs_tftp_proto::{ClientState, Event, Mode, Packet, MAX_DATA_SIZE, TFTP_PORT};
use crate::{Error, Result};
/// Default timeout for TFTP operations (in seconds).
const DEFAULT_TIMEOUT_SECS: u64 = 5;
/// Default number of retries before giving up.
const DEFAULT_RETRIES: u32 = 3;
/// Maximum packet size (opcode + block + max data).
const MAX_PACKET_SIZE: usize = 4 + MAX_DATA_SIZE;
/// TFTP client for transferring files to/from a server.
///
/// # Example
///
/// ```no_run
/// use pfs_tftp::{Client, Mode};
///
/// let client = Client::new("192.168.1.1:69").expect("connect");
/// let data = client.get("config.txt", Mode::Octet).expect("get");
/// println!("Received {} bytes", data.len());
/// ```
pub struct Client {
/// Server address
server_addr: SocketAddr,
/// Local UDP socket
socket: UdpSocket,
/// Timeout duration
timeout: Duration,
/// Number of retries
retries: u32,
}
impl Client {
/// Create a new TFTP client connected to the specified server.
///
/// The address can be in any format accepted by `ToSocketAddrs`, such as
/// "192.168.1.1:69" or "tftp.example.com:69".
///
/// # Errors
///
/// Returns an error if the address is invalid or the socket cannot be bound.
pub fn new<A: ToSocketAddrs>(server_addr: A) -> Result<Self> {
let server_addr = server_addr
.to_socket_addrs()?
.next()
.ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "no addresses found")
})?;
// Bind to any available local port
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.set_read_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?;
Ok(Self {
server_addr,
socket,
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
retries: DEFAULT_RETRIES,
})
}
/// Set the timeout for operations.
#[must_use]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
let _ = self.socket.set_read_timeout(Some(timeout));
self
}
/// Set the number of retries.
#[must_use]
pub const fn with_retries(mut self, retries: u32) -> Self {
self.retries = retries;
self
}
/// Download a file from the server.
///
/// # Arguments
///
/// * `filename` - The name of the file to download
/// * `mode` - The transfer mode (`Octet` or `NetAscii`)
///
/// # Returns
///
/// The file contents as a byte vector.
///
/// # Errors
///
/// Returns an error if the transfer fails.
pub fn get(&self, filename: &str, mode: Mode) -> Result<Vec<u8>> {
let mut state = ClientState::new_read(mode);
let request = state.start(filename);
// Send initial request to well-known port
let request_bytes = request.serialize()?;
self.socket.send_to(&request_bytes, self.server_addr)?;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut data = Vec::new();
let mut server_tid: Option<SocketAddr> = None;
let mut retries_left = self.retries;
loop {
// Receive packet
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
Ok(result) => result,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// Timeout - retransmit last packet
if retries_left == 0 {
return Err(Error::Timeout {
retries: self.retries,
});
}
retries_left -= 1;
self.socket.send_to(&request_bytes, self.server_addr)?;
continue;
}
Err(e) => return Err(e.into()),
};
// RFC 1350 Section 4: First response establishes the TID
// Subsequent packets must come from the same TID
match server_tid {
None => server_tid = Some(from_addr),
Some(expected) if from_addr != expected => {
// RFC 1350: "If a source TID does not match, the packet should
// be discarded as erroneously sent from somewhere else. An error
// packet should be sent to the source of the incorrect packet."
let error =
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
let _ = self.socket.send_to(&error.serialize()?, from_addr);
continue;
}
_ => {}
}
let packet = Packet::parse(&recv_buf[..len])?;
// Handle error packets
if let Packet::Error { code, message } = packet {
return Err(Error::Remote { code, message });
}
let events = state.receive(&packet)?;
for event in events {
match event {
Event::ReceivedData { data: block_data, .. } => {
data.extend_from_slice(&block_data);
}
Event::Send(packet) => {
let bytes = packet.serialize()?;
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
}
Event::Complete => {
return Ok(data);
}
Event::NeedData { .. } => {
// Not used in read transfers
}
}
}
retries_left = self.retries;
}
}
/// Upload a file to the server.
///
/// # Arguments
///
/// * `filename` - The name to give the file on the server
/// * `mode` - The transfer mode (`Octet` or `NetAscii`)
/// * `data` - The file contents to upload
///
/// # Errors
///
/// Returns an error if the transfer fails.
pub fn put(&self, filename: &str, mode: Mode, data: &[u8]) -> Result<()> {
let mut state = ClientState::new_write(mode);
let request = state.start(filename);
// Send initial request
let request_bytes = request.serialize()?;
self.socket.send_to(&request_bytes, self.server_addr)?;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut server_tid: Option<SocketAddr> = None;
let mut retries_left = self.retries;
let mut last_sent: Option<Vec<u8>> = None;
let mut data_offset: usize = 0;
loop {
// Receive packet
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
Ok(result) => result,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// Timeout - retransmit
if retries_left == 0 {
return Err(Error::Timeout {
retries: self.retries,
});
}
retries_left -= 1;
if let Some(ref last) = last_sent {
self.socket
.send_to(last, server_tid.unwrap_or(self.server_addr))?;
} else {
self.socket.send_to(&request_bytes, self.server_addr)?;
}
continue;
}
Err(e) => return Err(e.into()),
};
// TID validation
match server_tid {
None => server_tid = Some(from_addr),
Some(expected) if from_addr != expected => {
let error =
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
let _ = self.socket.send_to(&error.serialize()?, from_addr);
continue;
}
_ => {}
}
let packet = Packet::parse(&recv_buf[..len])?;
if let Packet::Error { code, message } = packet {
return Err(Error::Remote { code, message });
}
let events = state.receive(&packet)?;
for event in events {
match event {
Event::NeedData { block } => {
// Calculate data for this block
let start = data_offset;
let end = (start + MAX_DATA_SIZE).min(data.len());
let block_data = data[start..end].to_vec();
data_offset = end;
let send_event = state.provide_data(block, block_data)?;
if let Event::Send(packet) = send_event {
let bytes = packet.serialize()?;
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
last_sent = Some(bytes);
}
}
Event::Send(packet) => {
let bytes = packet.serialize()?;
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
last_sent = Some(bytes);
}
Event::Complete => {
return Ok(());
}
Event::ReceivedData { .. } => {
// Not used in write transfers
}
}
}
// Check if transfer is complete after sending final data
if state.is_complete() {
return Ok(());
}
retries_left = self.retries;
}
}
/// Download a file from the server to a writer.
///
/// This is more memory-efficient for large files as it streams data
/// directly to the writer.
///
/// # Errors
///
/// Returns an error if the transfer fails.
pub fn get_to_writer<W: Write>(&self, filename: &str, mode: Mode, writer: &mut W) -> Result<u64> {
let mut state = ClientState::new_read(mode);
let request = state.start(filename);
let request_bytes = request.serialize()?;
self.socket.send_to(&request_bytes, self.server_addr)?;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut total_bytes: u64 = 0;
let mut server_tid: Option<SocketAddr> = None;
let mut retries_left = self.retries;
loop {
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
Ok(result) => result,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if retries_left == 0 {
return Err(Error::Timeout {
retries: self.retries,
});
}
retries_left -= 1;
self.socket.send_to(&request_bytes, self.server_addr)?;
continue;
}
Err(e) => return Err(e.into()),
};
match server_tid {
None => server_tid = Some(from_addr),
Some(expected) if from_addr != expected => {
let error =
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
let _ = self.socket.send_to(&error.serialize()?, from_addr);
continue;
}
_ => {}
}
let packet = Packet::parse(&recv_buf[..len])?;
if let Packet::Error { code, message } = packet {
return Err(Error::Remote { code, message });
}
let events = state.receive(&packet)?;
for event in events {
match event {
Event::ReceivedData { data, .. } => {
writer.write_all(&data)?;
total_bytes += data.len() as u64;
}
Event::Send(packet) => {
let bytes = packet.serialize()?;
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
}
Event::Complete => {
writer.flush()?;
return Ok(total_bytes);
}
Event::NeedData { .. } => {}
}
}
retries_left = self.retries;
}
}
/// Upload data from a reader to the server.
///
/// This is more memory-efficient for large files as it streams data
/// from the reader.
///
/// # Errors
///
/// Returns an error if the transfer fails.
pub fn put_from_reader<R: Read>(&self, filename: &str, mode: Mode, reader: &mut R) -> Result<u64> {
let mut state = ClientState::new_write(mode);
let request = state.start(filename);
let request_bytes = request.serialize()?;
self.socket.send_to(&request_bytes, self.server_addr)?;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut read_buf = [0u8; MAX_DATA_SIZE];
let mut total_bytes: u64 = 0;
let mut server_tid: Option<SocketAddr> = None;
let mut retries_left = self.retries;
let mut last_sent: Option<Vec<u8>> = None;
loop {
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
Ok(result) => result,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if retries_left == 0 {
return Err(Error::Timeout {
retries: self.retries,
});
}
retries_left -= 1;
if let Some(ref last) = last_sent {
self.socket
.send_to(last, server_tid.unwrap_or(self.server_addr))?;
} else {
self.socket.send_to(&request_bytes, self.server_addr)?;
}
continue;
}
Err(e) => return Err(e.into()),
};
match server_tid {
None => server_tid = Some(from_addr),
Some(expected) if from_addr != expected => {
let error =
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
let _ = self.socket.send_to(&error.serialize()?, from_addr);
continue;
}
_ => {}
}
let packet = Packet::parse(&recv_buf[..len])?;
if let Packet::Error { code, message } = packet {
return Err(Error::Remote { code, message });
}
let events = state.receive(&packet)?;
for event in events {
match event {
Event::NeedData { block } => {
let n = reader.read(&mut read_buf)?;
total_bytes += n as u64;
let send_event = state.provide_data(block, read_buf[..n].to_vec())?;
if let Event::Send(packet) = send_event {
let bytes = packet.serialize()?;
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
last_sent = Some(bytes);
}
}
Event::Send(packet) => {
let bytes = packet.serialize()?;
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
last_sent = Some(bytes);
}
Event::Complete => {
return Ok(total_bytes);
}
Event::ReceivedData { .. } => {}
}
}
if state.is_complete() {
return Ok(total_bytes);
}
retries_left = self.retries;
}
}
}
/// Create a client connected to "host:port" or "host" (default port 69).
impl std::str::FromStr for Client {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let addr = if s.contains(':') {
s.to_string()
} else {
format!("{s}:{TFTP_PORT}")
};
Self::new(addr)
}
}

View File

@@ -0,0 +1,41 @@
//! Error types for TFTP I/O operations.
use thiserror::Error;
/// Errors that can occur during TFTP operations.
#[derive(Debug, Error)]
pub enum Error {
/// I/O error during socket or file operations.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// Protocol-level error.
#[error("protocol error: {0}")]
Protocol(#[from] pfs_tftp_proto::Error),
/// Address parsing error.
#[error("invalid address: {0}")]
InvalidAddress(#[from] std::net::AddrParseError),
/// Remote peer sent an error.
#[error("remote error ({code:?}): {message}")]
Remote {
code: pfs_tftp_proto::ErrorCode,
message: String,
},
/// Transfer timed out.
#[error("transfer timed out after {retries} retries")]
Timeout { retries: u32 },
/// File access error.
#[error("file access error: {0}")]
FileAccess(String),
/// Invalid filename (e.g., path traversal attempt).
#[error("invalid filename: {0}")]
InvalidFilename(String),
}
/// Result type for TFTP operations.
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -0,0 +1,32 @@
//! TFTP Client and Server Library (RFC 1350)
//!
//! This crate provides synchronous I/O implementations for TFTP client and
//! server functionality, built on top of the `pfs-tftp-proto` protocol crate.
//!
//! # Example: Client
//!
//! ```no_run
//! use pfs_tftp::{Client, Mode};
//!
//! let mut client = Client::new("192.168.1.1:69").expect("connect");
//! let data = client.get("config.txt", Mode::Octet).expect("get");
//! ```
//!
//! # Example: Server
//!
//! ```no_run
//! use pfs_tftp::Server;
//! use std::path::Path;
//!
//! let server = Server::bind("0.0.0.0:69", Path::new("/tftpboot")).expect("bind");
//! server.run().expect("run");
//! ```
mod client;
mod error;
mod server;
pub use client::Client;
pub use error::{Error, Result};
pub use pfs_tftp_proto::Mode;
pub use server::{Server, ServerBuilder, ServerConfig};

View File

@@ -0,0 +1,569 @@
//! TFTP Server implementation.
//!
//! Provides a synchronous TFTP server that serves files from a root directory.
use std::{
fs::{File, OpenOptions},
io::{Read, Write},
net::{SocketAddr, ToSocketAddrs, UdpSocket},
path::{Path, PathBuf},
time::Duration,
};
use pfs_tftp_proto::{ErrorCode, Event, Mode, Packet, ServerState, MAX_DATA_SIZE, TFTP_PORT};
use crate::{Error, Result};
/// Default timeout for transfer operations (in seconds).
const DEFAULT_TIMEOUT_SECS: u64 = 5;
/// Default number of retries before giving up on a transfer.
const DEFAULT_RETRIES: u32 = 3;
/// Maximum packet size.
const MAX_PACKET_SIZE: usize = 4 + MAX_DATA_SIZE;
/// Configuration for the TFTP server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
/// Root directory to serve files from/to.
pub root_dir: PathBuf,
/// Allow write operations (WRQ).
pub allow_write: bool,
/// Allow overwriting existing files.
pub allow_overwrite: bool,
/// Timeout for operations.
pub timeout: Duration,
/// Number of retries.
pub retries: u32,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
root_dir: PathBuf::from("."),
allow_write: false,
allow_overwrite: false,
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
retries: DEFAULT_RETRIES,
}
}
}
/// TFTP server.
///
/// Serves files from a configured root directory.
///
/// # Example
///
/// ```no_run
/// use pfs_tftp::Server;
/// use std::path::Path;
///
/// let server = Server::bind("0.0.0.0:69", Path::new("/tftpboot")).expect("bind");
/// server.run().expect("run");
/// ```
pub struct Server {
/// Listening socket on the well-known TFTP port.
socket: UdpSocket,
/// Server configuration.
config: ServerConfig,
}
impl Server {
/// Create a new TFTP server bound to the specified address.
///
/// # Arguments
///
/// * `addr` - The address to bind to (e.g., "0.0.0.0:69")
/// * `root_dir` - The root directory to serve files from
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or the root directory
/// is invalid.
pub fn bind<A: ToSocketAddrs>(addr: A, root_dir: &Path) -> Result<Self> {
let socket = UdpSocket::bind(addr)?;
let config = ServerConfig {
root_dir: root_dir.to_path_buf(),
..Default::default()
};
Ok(Self { socket, config })
}
/// Create a server with custom configuration.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound.
pub fn with_config<A: ToSocketAddrs>(addr: A, config: ServerConfig) -> Result<Self> {
let socket = UdpSocket::bind(addr)?;
Ok(Self { socket, config })
}
/// Enable or disable write operations.
#[must_use]
pub fn allow_write(mut self, allow: bool) -> Self {
self.config.allow_write = allow;
self
}
/// Enable or disable overwriting existing files.
#[must_use]
pub fn allow_overwrite(mut self, allow: bool) -> Self {
self.config.allow_overwrite = allow;
self
}
/// Run the server, handling requests indefinitely.
///
/// This method blocks and handles one request at a time.
///
/// # Errors
///
/// Returns an error if a fatal socket error occurs.
pub fn run(&self) -> Result<()> {
let mut buf = [0u8; MAX_PACKET_SIZE];
loop {
let (len, client_addr) = self.socket.recv_from(&mut buf)?;
// Try to parse the request
let packet = match Packet::parse(&buf[..len]) {
Ok(p) => p,
Err(e) => {
eprintln!("Invalid packet from {client_addr}: {e}");
continue;
}
};
// Handle the request
if let Err(e) = self.handle_request(client_addr, &packet) {
eprintln!("Error handling request from {client_addr}: {e}");
}
}
}
/// Handle a single request, then return.
///
/// This is useful for testing or single-request operation.
///
/// # Errors
///
/// Returns an error if a fatal socket error occurs.
pub fn handle_one(&self) -> Result<()> {
let mut buf = [0u8; MAX_PACKET_SIZE];
let (len, client_addr) = self.socket.recv_from(&mut buf)?;
let packet = Packet::parse(&buf[..len])?;
self.handle_request(client_addr, &packet)
}
/// Handle a single request.
fn handle_request(&self, client_addr: SocketAddr, packet: &Packet) -> Result<()> {
// Create a new socket for this transfer with a random TID
// RFC 1350 Section 4: "each end of the connection chooses a TID for
// itself, to be used for the duration of that connection"
let transfer_socket = UdpSocket::bind("0.0.0.0:0")?;
transfer_socket.set_read_timeout(Some(self.config.timeout))?;
transfer_socket.connect(client_addr)?;
match packet {
Packet::ReadRequest { filename, mode } => {
self.handle_rrq(&transfer_socket, filename, *mode)
}
Packet::WriteRequest { filename, mode } => {
self.handle_wrq(&transfer_socket, filename, *mode)
}
_ => {
// Unexpected packet type on the well-known port
let error = Packet::error(ErrorCode::IllegalOperation, "Expected RRQ or WRQ");
let _ = transfer_socket.send(&error.serialize()?);
Ok(())
}
}
}
/// Validate and resolve a filename to a path within the root directory.
fn resolve_path(&self, filename: &str) -> Result<PathBuf> {
// Security: Prevent path traversal attacks
// RFC 1350 Security Considerations: "care must be taken in the rights
// granted to a TFTP server process so as not to violate the security
// of the server hosts file system"
// Reject absolute paths
if filename.starts_with('/') || filename.starts_with('\\') {
return Err(Error::InvalidFilename(
"absolute paths not allowed".to_string(),
));
}
// Reject path traversal
if filename.contains("..") {
return Err(Error::InvalidFilename(
"path traversal not allowed".to_string(),
));
}
// Reject paths with null bytes
if filename.contains('\0') {
return Err(Error::InvalidFilename(
"null bytes not allowed".to_string(),
));
}
let path = self.config.root_dir.join(filename);
// Verify the resolved path is still within root_dir
let canonical_root = self
.config
.root_dir
.canonicalize()
.map_err(|e| Error::FileAccess(format!("cannot access root directory: {e}")))?;
// For non-existent files (WRQ), we check the parent
let check_path = if path.exists() {
path.canonicalize()?
} else {
let parent = path
.parent()
.ok_or_else(|| Error::InvalidFilename("invalid path".to_string()))?;
if !parent.exists() {
return Err(Error::FileAccess(format!(
"parent directory does not exist: {}",
parent.display()
)));
}
parent.canonicalize()?.join(path.file_name().ok_or_else(|| {
Error::InvalidFilename("missing filename".to_string())
})?)
};
if !check_path.starts_with(&canonical_root) {
return Err(Error::InvalidFilename(
"path outside root directory".to_string(),
));
}
Ok(path)
}
/// Handle a read request (RRQ).
fn handle_rrq(&self, socket: &UdpSocket, filename: &str, mode: Mode) -> Result<()> {
let path = match self.resolve_path(filename) {
Ok(p) => p,
Err(e) => {
let error = Packet::error(ErrorCode::AccessViolation, e.to_string());
let _ = socket.send(&error.serialize()?);
return Ok(());
}
};
// Open the file
let mut file = match File::open(&path) {
Ok(f) => f,
Err(e) => {
let (code, msg) = match e.kind() {
std::io::ErrorKind::NotFound => {
(ErrorCode::FileNotFound, "File not found".to_string())
}
std::io::ErrorKind::PermissionDenied => {
(ErrorCode::AccessViolation, "Permission denied".to_string())
}
_ => (ErrorCode::NotDefined, e.to_string()),
};
let error = Packet::error(code, msg);
let _ = socket.send(&error.serialize()?);
return Ok(());
}
};
// Create server state for RRQ (server sends data)
let request = Packet::rrq(filename, mode);
let mut state = match ServerState::from_request(&request) {
Ok(s) => s,
Err(e) => {
let error = Packet::error(ErrorCode::IllegalOperation, e.to_string());
let _ = socket.send(&error.serialize()?);
return Ok(());
}
};
// Start transfer - get NeedData event
let event = state.start();
let mut read_buf = [0u8; MAX_DATA_SIZE];
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut last_sent: Option<Vec<u8>> = None;
let mut retries_left = self.config.retries;
// Handle initial NeedData event
if let Event::NeedData { block } = event {
let n = file.read(&mut read_buf)?;
let send_event = state.provide_data(block, read_buf[..n].to_vec())?;
if let Event::Send(packet) = send_event {
let bytes = packet.serialize()?;
socket.send(&bytes)?;
last_sent = Some(bytes);
}
}
// Main transfer loop
loop {
let len = match socket.recv(&mut recv_buf) {
Ok(len) => len,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if retries_left == 0 {
return Err(Error::Timeout {
retries: self.config.retries,
});
}
retries_left -= 1;
if let Some(ref last) = last_sent {
socket.send(last)?;
}
continue;
}
Err(e) => return Err(e.into()),
};
let packet = Packet::parse(&recv_buf[..len])?;
if let Packet::Error { code, message } = packet {
return Err(Error::Remote { code, message });
}
let events = state.receive(&packet)?;
for event in events {
match event {
Event::NeedData { block } => {
let n = file.read(&mut read_buf)?;
let send_event = state.provide_data(block, read_buf[..n].to_vec())?;
if let Event::Send(packet) = send_event {
let bytes = packet.serialize()?;
socket.send(&bytes)?;
last_sent = Some(bytes);
}
}
Event::Complete => {
return Ok(());
}
// Send and ReceivedData not expected in RRQ handling
Event::Send(_) | Event::ReceivedData { .. } => {}
}
}
if state.is_complete() {
return Ok(());
}
retries_left = self.config.retries;
}
}
/// Handle a write request (WRQ).
#[allow(clippy::too_many_lines)]
fn handle_wrq(&self, socket: &UdpSocket, filename: &str, mode: Mode) -> Result<()> {
// Check if writes are allowed
if !self.config.allow_write {
let error = Packet::error(ErrorCode::AccessViolation, "Write not allowed");
let _ = socket.send(&error.serialize()?);
return Ok(());
}
let path = match self.resolve_path(filename) {
Ok(p) => p,
Err(e) => {
let error = Packet::error(ErrorCode::AccessViolation, e.to_string());
let _ = socket.send(&error.serialize()?);
return Ok(());
}
};
// Check if file exists and overwrite is not allowed
if path.exists() && !self.config.allow_overwrite {
let error = Packet::error(ErrorCode::FileAlreadyExists, "File already exists");
let _ = socket.send(&error.serialize()?);
return Ok(());
}
// Open/create the file
let mut file = match OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&path)
{
Ok(f) => f,
Err(e) => {
let (code, msg) = match e.kind() {
std::io::ErrorKind::PermissionDenied => {
(ErrorCode::AccessViolation, "Permission denied".to_string())
}
_ => (ErrorCode::NotDefined, e.to_string()),
};
let error = Packet::error(code, msg);
let _ = socket.send(&error.serialize()?);
return Ok(());
}
};
// Create server state for WRQ (server receives data)
let request = Packet::wrq(filename, mode);
let mut state = match ServerState::from_request(&request) {
Ok(s) => s,
Err(e) => {
let error = Packet::error(ErrorCode::IllegalOperation, e.to_string());
let _ = socket.send(&error.serialize()?);
return Ok(());
}
};
// Start transfer - send ACK 0
let event = state.start();
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut last_sent: Option<Vec<u8>> = None;
let mut retries_left = self.config.retries;
if let Event::Send(packet) = event {
let bytes = packet.serialize()?;
socket.send(&bytes)?;
last_sent = Some(bytes);
}
// Main transfer loop
loop {
let len = match socket.recv(&mut recv_buf) {
Ok(len) => len,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
if retries_left == 0 {
// Clean up partial file on timeout
let _ = std::fs::remove_file(&path);
return Err(Error::Timeout {
retries: self.config.retries,
});
}
retries_left -= 1;
if let Some(ref last) = last_sent {
socket.send(last)?;
}
continue;
}
Err(e) => {
let _ = std::fs::remove_file(&path);
return Err(e.into());
}
};
let packet = Packet::parse(&recv_buf[..len])?;
if let Packet::Error { code, message } = packet {
let _ = std::fs::remove_file(&path);
return Err(Error::Remote { code, message });
}
let events = state.receive(&packet)?;
for event in events {
match event {
Event::ReceivedData { data, .. } => {
if let Err(e) = file.write_all(&data) {
let error = match e.kind() {
std::io::ErrorKind::StorageFull => {
Packet::error(ErrorCode::DiskFull, "Disk full")
}
_ => Packet::error(ErrorCode::NotDefined, e.to_string()),
};
let _ = socket.send(&error.serialize()?);
let _ = std::fs::remove_file(&path);
return Ok(());
}
}
Event::Send(packet) => {
let bytes = packet.serialize()?;
socket.send(&bytes)?;
last_sent = Some(bytes);
}
Event::Complete => {
file.flush()?;
return Ok(());
}
// NeedData not expected in WRQ handling
Event::NeedData { .. } => {}
}
}
if state.is_complete() {
file.flush()?;
return Ok(());
}
retries_left = self.config.retries;
}
}
}
/// Builder for creating a server with custom options.
pub struct ServerBuilder {
config: ServerConfig,
}
impl ServerBuilder {
/// Create a new server builder with the specified root directory.
#[must_use]
pub fn new(root_dir: impl Into<PathBuf>) -> Self {
Self {
config: ServerConfig {
root_dir: root_dir.into(),
..Default::default()
},
}
}
/// Allow write operations.
#[must_use]
pub const fn allow_write(mut self, allow: bool) -> Self {
self.config.allow_write = allow;
self
}
/// Allow overwriting existing files.
#[must_use]
pub const fn allow_overwrite(mut self, allow: bool) -> Self {
self.config.allow_overwrite = allow;
self
}
/// Set the timeout for operations.
#[must_use]
pub const fn timeout(mut self, timeout: Duration) -> Self {
self.config.timeout = timeout;
self
}
/// Set the number of retries.
#[must_use]
pub const fn retries(mut self, retries: u32) -> Self {
self.config.retries = retries;
self
}
/// Build the server bound to the specified address.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound.
pub fn bind<A: ToSocketAddrs>(self, addr: A) -> Result<Server> {
Server::with_config(addr, self.config)
}
/// Build the server bound to the default TFTP port.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound.
pub fn bind_default(self) -> Result<Server> {
self.bind(("0.0.0.0", TFTP_PORT))
}
}

View File

@@ -1,3 +0,0 @@
fn main() {
println!("Hello, world!");
}