feat: implement TFTP protocol crate (pfs-tftp-proto)
This commit introduces a sans-IO TFTP protocol implementation following RFC 1350. The protocol crate provides: - Packet types: RRQ, WRQ, DATA, ACK, ERROR with full serialization/parsing - Error codes as defined in RFC 1350 Appendix - Transfer modes: octet (binary) and netascii - Client and server state machines for managing protocol flow - Comprehensive tests for all packet types and state transitions The sans-IO design separates protocol logic from I/O operations, making the code testable and reusable across different I/O implementations. Key design decisions: - Mail mode is explicitly rejected as obsolete per RFC 1350 - Block numbers wrap around at 65535 for large file support - State machines emit events that tell the I/O layer what to do next - All protocol-specific values are documented with RFC citations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
66
Cargo.lock
generated
66
Cargo.lock
generated
@@ -5,3 +5,69 @@ version = 4
|
|||||||
[[package]]
|
[[package]]
|
||||||
name = "pfs-tftp"
|
name = "pfs-tftp"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"pfs-tftp-proto",
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pfs-tftp-proto"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.103"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.42"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "2.0.111"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror"
|
||||||
|
version = "2.0.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
|
||||||
|
dependencies = [
|
||||||
|
"thiserror-impl",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "thiserror-impl"
|
||||||
|
version = "2.0.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-ident"
|
||||||
|
version = "1.0.22"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||||
|
|||||||
15
Cargo.toml
15
Cargo.toml
@@ -1,14 +1,17 @@
|
|||||||
[package]
|
[workspace]
|
||||||
name = "pfs-tftp"
|
resolver = "3"
|
||||||
|
members = ["crates/*"]
|
||||||
|
|
||||||
|
[workspace.package]
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
license = "MIT"
|
||||||
|
repository = "https://github.com/pfs/pfs-tftp"
|
||||||
|
|
||||||
[lints.rust]
|
[workspace.lints.rust]
|
||||||
unsafe_code = "forbid"
|
unsafe_code = "forbid"
|
||||||
|
|
||||||
[lints.clippy]
|
[workspace.lints.clippy]
|
||||||
pedantic = { level = "warn", priority = -1 }
|
pedantic = { level = "warn", priority = -1 }
|
||||||
todo = "warn"
|
todo = "warn"
|
||||||
unwrap_used = "warn"
|
unwrap_used = "warn"
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
|
|||||||
15
crates/pfs-tftp-proto/Cargo.toml
Normal file
15
crates/pfs-tftp-proto/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
[package]
|
||||||
|
name = "pfs-tftp-proto"
|
||||||
|
description = "Sans-IO TFTP protocol implementation (RFC 1350)"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
repository.workspace = true
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
thiserror = "2"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
193
crates/pfs-tftp-proto/src/error.rs
Normal file
193
crates/pfs-tftp-proto/src/error.rs
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
//! TFTP Error Types
|
||||||
|
//!
|
||||||
|
//! Defines error codes and error handling for TFTP protocol.
|
||||||
|
//!
|
||||||
|
//! # RFC 1350 Error Codes (Appendix)
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! Value Meaning
|
||||||
|
//! 0 Not defined, see error message (if any).
|
||||||
|
//! 1 File not found.
|
||||||
|
//! 2 Access violation.
|
||||||
|
//! 3 Disk full or allocation exceeded.
|
||||||
|
//! 4 Illegal TFTP operation.
|
||||||
|
//! 5 Unknown transfer ID.
|
||||||
|
//! 6 File already exists.
|
||||||
|
//! 7 No such user.
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// TFTP error codes as defined in RFC 1350 Appendix.
|
||||||
|
///
|
||||||
|
/// These codes are sent in ERROR packets to indicate the nature of the error.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
#[repr(u16)]
|
||||||
|
pub enum ErrorCode {
|
||||||
|
/// Not defined, see error message (if any).
|
||||||
|
/// Used when the error doesn't fit other categories.
|
||||||
|
NotDefined = 0,
|
||||||
|
|
||||||
|
/// File not found.
|
||||||
|
/// The requested file does not exist on the server.
|
||||||
|
FileNotFound = 1,
|
||||||
|
|
||||||
|
/// Access violation.
|
||||||
|
/// The client does not have permission to access the file.
|
||||||
|
AccessViolation = 2,
|
||||||
|
|
||||||
|
/// Disk full or allocation exceeded.
|
||||||
|
/// There is insufficient storage space to complete the transfer.
|
||||||
|
DiskFull = 3,
|
||||||
|
|
||||||
|
/// Illegal TFTP operation.
|
||||||
|
/// The requested operation is not valid (e.g., malformed packet).
|
||||||
|
IllegalOperation = 4,
|
||||||
|
|
||||||
|
/// Unknown transfer ID.
|
||||||
|
/// RFC 1350 Section 4: "If a source TID does not match, the packet should
|
||||||
|
/// be discarded as erroneously sent from somewhere else. An error packet
|
||||||
|
/// should be sent to the source of the incorrect packet."
|
||||||
|
UnknownTransferId = 5,
|
||||||
|
|
||||||
|
/// File already exists.
|
||||||
|
/// Returned when attempting to write a file that already exists
|
||||||
|
/// and the server policy doesn't allow overwriting.
|
||||||
|
FileAlreadyExists = 6,
|
||||||
|
|
||||||
|
/// No such user.
|
||||||
|
/// The specified user does not exist (relevant for mail mode, which is obsolete).
|
||||||
|
NoSuchUser = 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErrorCode {
|
||||||
|
/// Create an `ErrorCode` from a raw u16 value.
|
||||||
|
///
|
||||||
|
/// Returns `NotDefined` for any unknown error code, as per RFC 1350
|
||||||
|
/// which states error code 0 is "Not defined".
|
||||||
|
#[must_use]
|
||||||
|
pub const fn from_u16(value: u16) -> Self {
|
||||||
|
match value {
|
||||||
|
1 => Self::FileNotFound,
|
||||||
|
2 => Self::AccessViolation,
|
||||||
|
3 => Self::DiskFull,
|
||||||
|
4 => Self::IllegalOperation,
|
||||||
|
5 => Self::UnknownTransferId,
|
||||||
|
6 => Self::FileAlreadyExists,
|
||||||
|
7 => Self::NoSuchUser,
|
||||||
|
// RFC 1350: Error code 0 is "Not defined", also used as fallback
|
||||||
|
_ => Self::NotDefined,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the default error message for this error code.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn default_message(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::NotDefined => "Not defined",
|
||||||
|
Self::FileNotFound => "File not found",
|
||||||
|
Self::AccessViolation => "Access violation",
|
||||||
|
Self::DiskFull => "Disk full or allocation exceeded",
|
||||||
|
Self::IllegalOperation => "Illegal TFTP operation",
|
||||||
|
Self::UnknownTransferId => "Unknown transfer ID",
|
||||||
|
Self::FileAlreadyExists => "File already exists",
|
||||||
|
Self::NoSuchUser => "No such user",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ErrorCode> for u16 {
|
||||||
|
fn from(code: ErrorCode) -> Self {
|
||||||
|
code as Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Protocol-level errors for TFTP operations.
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum Error {
|
||||||
|
/// Packet is too short to be valid.
|
||||||
|
#[error("packet too short: expected at least {expected} bytes, got {actual}")]
|
||||||
|
PacketTooShort { expected: usize, actual: usize },
|
||||||
|
|
||||||
|
/// Invalid opcode in packet.
|
||||||
|
#[error("invalid opcode: {0}")]
|
||||||
|
InvalidOpcode(u16),
|
||||||
|
|
||||||
|
/// Invalid transfer mode in RRQ/WRQ packet.
|
||||||
|
#[error("invalid mode: {0}")]
|
||||||
|
InvalidMode(String),
|
||||||
|
|
||||||
|
/// Missing null terminator in string field.
|
||||||
|
#[error("missing null terminator in {field}")]
|
||||||
|
MissingNullTerminator { field: &'static str },
|
||||||
|
|
||||||
|
/// String field contains invalid UTF-8.
|
||||||
|
#[error("invalid UTF-8 in {field}: {source}")]
|
||||||
|
InvalidUtf8 {
|
||||||
|
field: &'static str,
|
||||||
|
#[source]
|
||||||
|
source: std::str::Utf8Error,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Data exceeds maximum block size.
|
||||||
|
#[error("data too large: {size} bytes exceeds maximum of 512")]
|
||||||
|
DataTooLarge { size: usize },
|
||||||
|
|
||||||
|
/// Buffer too small for serialization.
|
||||||
|
#[error("buffer too small: need {needed} bytes, have {available}")]
|
||||||
|
BufferTooSmall { needed: usize, available: usize },
|
||||||
|
|
||||||
|
/// TFTP error received from remote peer.
|
||||||
|
#[error("TFTP error {code:?}: {message}")]
|
||||||
|
TftpError { code: ErrorCode, message: String },
|
||||||
|
|
||||||
|
/// Protocol state error.
|
||||||
|
#[error("protocol state error: {0}")]
|
||||||
|
StateError(String),
|
||||||
|
|
||||||
|
/// Unexpected packet type received.
|
||||||
|
#[error("unexpected packet type: expected {expected}, got {actual}")]
|
||||||
|
UnexpectedPacket {
|
||||||
|
expected: &'static str,
|
||||||
|
actual: &'static str,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Block number mismatch.
|
||||||
|
#[error("block number mismatch: expected {expected}, got {actual}")]
|
||||||
|
BlockMismatch { expected: u16, actual: u16 },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result type for TFTP operations.
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_error_code_from_u16() {
|
||||||
|
assert_eq!(ErrorCode::from_u16(0), ErrorCode::NotDefined);
|
||||||
|
assert_eq!(ErrorCode::from_u16(1), ErrorCode::FileNotFound);
|
||||||
|
assert_eq!(ErrorCode::from_u16(2), ErrorCode::AccessViolation);
|
||||||
|
assert_eq!(ErrorCode::from_u16(3), ErrorCode::DiskFull);
|
||||||
|
assert_eq!(ErrorCode::from_u16(4), ErrorCode::IllegalOperation);
|
||||||
|
assert_eq!(ErrorCode::from_u16(5), ErrorCode::UnknownTransferId);
|
||||||
|
assert_eq!(ErrorCode::from_u16(6), ErrorCode::FileAlreadyExists);
|
||||||
|
assert_eq!(ErrorCode::from_u16(7), ErrorCode::NoSuchUser);
|
||||||
|
// Unknown codes should map to NotDefined
|
||||||
|
assert_eq!(ErrorCode::from_u16(8), ErrorCode::NotDefined);
|
||||||
|
assert_eq!(ErrorCode::from_u16(255), ErrorCode::NotDefined);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_error_code_to_u16() {
|
||||||
|
assert_eq!(u16::from(ErrorCode::NotDefined), 0);
|
||||||
|
assert_eq!(u16::from(ErrorCode::FileNotFound), 1);
|
||||||
|
assert_eq!(u16::from(ErrorCode::AccessViolation), 2);
|
||||||
|
assert_eq!(u16::from(ErrorCode::DiskFull), 3);
|
||||||
|
assert_eq!(u16::from(ErrorCode::IllegalOperation), 4);
|
||||||
|
assert_eq!(u16::from(ErrorCode::UnknownTransferId), 5);
|
||||||
|
assert_eq!(u16::from(ErrorCode::FileAlreadyExists), 6);
|
||||||
|
assert_eq!(u16::from(ErrorCode::NoSuchUser), 7);
|
||||||
|
}
|
||||||
|
}
|
||||||
21
crates/pfs-tftp-proto/src/lib.rs
Normal file
21
crates/pfs-tftp-proto/src/lib.rs
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
//! Sans-IO TFTP Protocol Implementation (RFC 1350)
|
||||||
|
//!
|
||||||
|
//! This crate provides a pure protocol implementation without any I/O operations.
|
||||||
|
//! It handles packet parsing, serialization, and protocol state management.
|
||||||
|
//!
|
||||||
|
//! # RFC 1350 Overview
|
||||||
|
//!
|
||||||
|
//! TFTP (Trivial File Transfer Protocol) is a simple file transfer protocol
|
||||||
|
//! that operates over UDP. Key characteristics:
|
||||||
|
//! - Uses fixed 512-byte data blocks
|
||||||
|
//! - Stop-and-wait protocol (each packet must be acknowledged)
|
||||||
|
//! - Five packet types: RRQ, WRQ, DATA, ACK, ERROR
|
||||||
|
//! - Default port 69 for initial server connection
|
||||||
|
|
||||||
|
mod error;
|
||||||
|
mod packet;
|
||||||
|
mod state;
|
||||||
|
|
||||||
|
pub use error::{Error, ErrorCode, Result};
|
||||||
|
pub use packet::{Mode, Opcode, Packet, MAX_DATA_SIZE, TFTP_PORT};
|
||||||
|
pub use state::{ClientState, Event, ServerState, TransferDirection};
|
||||||
632
crates/pfs-tftp-proto/src/packet.rs
Normal file
632
crates/pfs-tftp-proto/src/packet.rs
Normal file
@@ -0,0 +1,632 @@
|
|||||||
|
//! TFTP Packet Types and Serialization
|
||||||
|
//!
|
||||||
|
//! This module implements all five TFTP packet types as defined in RFC 1350:
|
||||||
|
//! - RRQ (Read Request)
|
||||||
|
//! - WRQ (Write Request)
|
||||||
|
//! - DATA
|
||||||
|
//! - ACK
|
||||||
|
//! - ERROR
|
||||||
|
//!
|
||||||
|
//! # RFC 1350 Section 5 - TFTP Packets
|
||||||
|
//!
|
||||||
|
//! ```text
|
||||||
|
//! TFTP supports five types of packets:
|
||||||
|
//!
|
||||||
|
//! opcode operation
|
||||||
|
//! 1 Read request (RRQ)
|
||||||
|
//! 2 Write request (WRQ)
|
||||||
|
//! 3 Data (DATA)
|
||||||
|
//! 4 Acknowledgment (ACK)
|
||||||
|
//! 5 Error (ERROR)
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use crate::{Error, ErrorCode, Result};
|
||||||
|
|
||||||
|
/// Standard TFTP server port as defined in RFC 1350.
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 4: "A requesting host chooses its source TID as described
|
||||||
|
/// above, and sends its initial request to the known TID 69 decimal (105 octal)
|
||||||
|
/// on the serving host."
|
||||||
|
pub const TFTP_PORT: u16 = 69;
|
||||||
|
|
||||||
|
/// Maximum data size in a DATA packet.
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 2: "the connection is opened and the file is sent in fixed
|
||||||
|
/// length blocks of 512 bytes."
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 5: "The data field is from zero to 512 bytes long."
|
||||||
|
pub const MAX_DATA_SIZE: usize = 512;
|
||||||
|
|
||||||
|
/// Minimum packet size (just an opcode).
|
||||||
|
const MIN_PACKET_SIZE: usize = 2;
|
||||||
|
|
||||||
|
/// TFTP packet opcodes.
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 5: "The TFTP header of a packet contains the opcode
|
||||||
|
/// associated with that packet."
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
#[repr(u16)]
|
||||||
|
pub enum Opcode {
|
||||||
|
/// Read request (RRQ)
|
||||||
|
Rrq = 1,
|
||||||
|
/// Write request (WRQ)
|
||||||
|
Wrq = 2,
|
||||||
|
/// Data packet
|
||||||
|
Data = 3,
|
||||||
|
/// Acknowledgment
|
||||||
|
Ack = 4,
|
||||||
|
/// Error
|
||||||
|
Error = 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Opcode {
|
||||||
|
/// Create an `Opcode` from a raw u16 value.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Error::InvalidOpcode` if the value is not a valid opcode.
|
||||||
|
pub const fn from_u16(value: u16) -> Result<Self> {
|
||||||
|
match value {
|
||||||
|
1 => Ok(Self::Rrq),
|
||||||
|
2 => Ok(Self::Wrq),
|
||||||
|
3 => Ok(Self::Data),
|
||||||
|
4 => Ok(Self::Ack),
|
||||||
|
5 => Ok(Self::Error),
|
||||||
|
_ => Err(Error::InvalidOpcode(value)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Opcode> for u16 {
|
||||||
|
fn from(opcode: Opcode) -> Self {
|
||||||
|
opcode as Self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transfer mode for TFTP operations.
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 1: "Three modes of transfer are currently supported:
|
||||||
|
/// netascii [...]; octet [...]; mail [...]. (The mail mode is obsolete
|
||||||
|
/// and should not be implemented or used.)"
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||||
|
pub enum Mode {
|
||||||
|
/// `NetASCII` mode - text transfer with CR/LF line endings.
|
||||||
|
///
|
||||||
|
/// RFC 1350: "netascii (This is ascii as defined in 'USA Standard Code
|
||||||
|
/// for Information Interchange' with the modifications specified in
|
||||||
|
/// 'Telnet Protocol Specification'.) Note that it is 8 bit ascii."
|
||||||
|
NetAscii,
|
||||||
|
|
||||||
|
/// Octet mode - raw binary transfer.
|
||||||
|
///
|
||||||
|
/// RFC 1350: "octet (This replaces the 'binary' mode of previous versions
|
||||||
|
/// of this document.) raw 8 bit bytes [...] If a host receives a octet
|
||||||
|
/// file and then returns it, the returned file must be identical to the
|
||||||
|
/// original."
|
||||||
|
#[default]
|
||||||
|
Octet,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mode {
|
||||||
|
/// Parse a mode string (case-insensitive).
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 5: "The mode field contains the string 'netascii',
|
||||||
|
/// 'octet', or 'mail' (or any combination of upper and lower case,
|
||||||
|
/// such as 'NETASCII', '`NetAscii`', etc.)"
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns `Error::InvalidMode` if the string is not a recognized mode.
|
||||||
|
pub fn parse(s: &str) -> Result<Self> {
|
||||||
|
match s.to_ascii_lowercase().as_str() {
|
||||||
|
"netascii" => Ok(Self::NetAscii),
|
||||||
|
"octet" => Ok(Self::Octet),
|
||||||
|
// Mail mode is obsolete per RFC 1350
|
||||||
|
"mail" => Err(Error::InvalidMode(
|
||||||
|
"mail mode is obsolete and not supported".to_string(),
|
||||||
|
)),
|
||||||
|
_ => Err(Error::InvalidMode(s.to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the string representation of the mode.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn as_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::NetAscii => "netascii",
|
||||||
|
Self::Octet => "octet",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A TFTP packet.
|
||||||
|
///
|
||||||
|
/// All packet formats are defined in RFC 1350 Section 5 and Appendix I.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum Packet {
|
||||||
|
/// Read Request (RRQ) - opcode 1
|
||||||
|
///
|
||||||
|
/// RFC 1350 Figure 5-1:
|
||||||
|
/// ```text
|
||||||
|
/// 2 bytes string 1 byte string 1 byte
|
||||||
|
/// ------------------------------------------------
|
||||||
|
/// | Opcode | Filename | 0 | Mode | 0 |
|
||||||
|
/// ------------------------------------------------
|
||||||
|
/// ```
|
||||||
|
ReadRequest { filename: String, mode: Mode },
|
||||||
|
|
||||||
|
/// Write Request (WRQ) - opcode 2
|
||||||
|
///
|
||||||
|
/// Same format as RRQ (RFC 1350 Figure 5-1).
|
||||||
|
WriteRequest { filename: String, mode: Mode },
|
||||||
|
|
||||||
|
/// Data packet - opcode 3
|
||||||
|
///
|
||||||
|
/// RFC 1350 Figure 5-2:
|
||||||
|
/// ```text
|
||||||
|
/// 2 bytes 2 bytes n bytes
|
||||||
|
/// ----------------------------------
|
||||||
|
/// | Opcode | Block # | Data |
|
||||||
|
/// ----------------------------------
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 5: "The block numbers on data packets begin with one
|
||||||
|
/// and increase by one for each new block of data. [...] The data field
|
||||||
|
/// is from zero to 512 bytes long. If it is 512 bytes long, the block is
|
||||||
|
/// not the last block of data; if it is from zero to 511 bytes long, it
|
||||||
|
/// signals the end of the transfer."
|
||||||
|
Data { block: u16, data: Vec<u8> },
|
||||||
|
|
||||||
|
/// Acknowledgment packet - opcode 4
|
||||||
|
///
|
||||||
|
/// RFC 1350 Figure 5-3:
|
||||||
|
/// ```text
|
||||||
|
/// 2 bytes 2 bytes
|
||||||
|
/// ---------------------
|
||||||
|
/// | Opcode | Block # |
|
||||||
|
/// ---------------------
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 5: "The block number in an ACK echoes the block number
|
||||||
|
/// of the DATA packet being acknowledged. A WRQ is acknowledged with an
|
||||||
|
/// ACK packet having a block number of zero."
|
||||||
|
Ack { block: u16 },
|
||||||
|
|
||||||
|
/// Error packet - opcode 5
|
||||||
|
///
|
||||||
|
/// RFC 1350 Figure 5-4:
|
||||||
|
/// ```text
|
||||||
|
/// 2 bytes 2 bytes string 1 byte
|
||||||
|
/// -----------------------------------------
|
||||||
|
/// | Opcode | ErrorCode | ErrMsg | 0 |
|
||||||
|
/// -----------------------------------------
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 5: "An ERROR packet can be the acknowledgment of any
|
||||||
|
/// other type of packet. The error code is an integer indicating the nature
|
||||||
|
/// of the error. [...] The error message is intended for human consumption,
|
||||||
|
/// and should be in netascii."
|
||||||
|
Error { code: ErrorCode, message: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Packet {
|
||||||
|
/// Parse a TFTP packet from raw bytes.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the packet is malformed.
|
||||||
|
pub fn parse(data: &[u8]) -> Result<Self> {
|
||||||
|
if data.len() < MIN_PACKET_SIZE {
|
||||||
|
return Err(Error::PacketTooShort {
|
||||||
|
expected: MIN_PACKET_SIZE,
|
||||||
|
actual: data.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let opcode = u16::from_be_bytes([data[0], data[1]]);
|
||||||
|
let opcode = Opcode::from_u16(opcode)?;
|
||||||
|
let payload = &data[2..];
|
||||||
|
|
||||||
|
match opcode {
|
||||||
|
Opcode::Rrq => Self::parse_request(payload, false),
|
||||||
|
Opcode::Wrq => Self::parse_request(payload, true),
|
||||||
|
Opcode::Data => Self::parse_data(payload),
|
||||||
|
Opcode::Ack => Self::parse_ack(payload),
|
||||||
|
Opcode::Error => Self::parse_error(payload),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a RRQ or WRQ packet payload.
|
||||||
|
fn parse_request(payload: &[u8], is_write: bool) -> Result<Self> {
|
||||||
|
// Find the null terminator for filename
|
||||||
|
let filename_end = payload
|
||||||
|
.iter()
|
||||||
|
.position(|&b| b == 0)
|
||||||
|
.ok_or(Error::MissingNullTerminator { field: "filename" })?;
|
||||||
|
|
||||||
|
let filename = std::str::from_utf8(&payload[..filename_end]).map_err(|e| {
|
||||||
|
Error::InvalidUtf8 {
|
||||||
|
field: "filename",
|
||||||
|
source: e,
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Find the mode string after the filename null terminator
|
||||||
|
let mode_start = filename_end + 1;
|
||||||
|
if mode_start >= payload.len() {
|
||||||
|
return Err(Error::MissingNullTerminator { field: "mode" });
|
||||||
|
}
|
||||||
|
|
||||||
|
let mode_end = payload[mode_start..]
|
||||||
|
.iter()
|
||||||
|
.position(|&b| b == 0)
|
||||||
|
.ok_or(Error::MissingNullTerminator { field: "mode" })?
|
||||||
|
+ mode_start;
|
||||||
|
|
||||||
|
let mode_str =
|
||||||
|
std::str::from_utf8(&payload[mode_start..mode_end]).map_err(|e| Error::InvalidUtf8 {
|
||||||
|
field: "mode",
|
||||||
|
source: e,
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mode = Mode::parse(mode_str)?;
|
||||||
|
|
||||||
|
if is_write {
|
||||||
|
Ok(Self::WriteRequest {
|
||||||
|
filename: filename.to_string(),
|
||||||
|
mode,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(Self::ReadRequest {
|
||||||
|
filename: filename.to_string(),
|
||||||
|
mode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a DATA packet payload.
|
||||||
|
fn parse_data(payload: &[u8]) -> Result<Self> {
|
||||||
|
if payload.len() < 2 {
|
||||||
|
return Err(Error::PacketTooShort {
|
||||||
|
expected: 4, // 2 opcode + 2 block
|
||||||
|
actual: payload.len() + 2,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let block = u16::from_be_bytes([payload[0], payload[1]]);
|
||||||
|
let data = payload[2..].to_vec();
|
||||||
|
|
||||||
|
if data.len() > MAX_DATA_SIZE {
|
||||||
|
return Err(Error::DataTooLarge { size: data.len() });
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self::Data { block, data })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse an ACK packet payload.
|
||||||
|
fn parse_ack(payload: &[u8]) -> Result<Self> {
|
||||||
|
if payload.len() < 2 {
|
||||||
|
return Err(Error::PacketTooShort {
|
||||||
|
expected: 4, // 2 opcode + 2 block
|
||||||
|
actual: payload.len() + 2,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let block = u16::from_be_bytes([payload[0], payload[1]]);
|
||||||
|
Ok(Self::Ack { block })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse an ERROR packet payload.
|
||||||
|
fn parse_error(payload: &[u8]) -> Result<Self> {
|
||||||
|
if payload.len() < 2 {
|
||||||
|
return Err(Error::PacketTooShort {
|
||||||
|
expected: 4, // 2 opcode + 2 error code
|
||||||
|
actual: payload.len() + 2,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let code = u16::from_be_bytes([payload[0], payload[1]]);
|
||||||
|
let code = ErrorCode::from_u16(code);
|
||||||
|
|
||||||
|
// Find the null terminator for error message
|
||||||
|
let msg_start = 2;
|
||||||
|
let msg_end = payload[msg_start..]
|
||||||
|
.iter()
|
||||||
|
.position(|&b| b == 0)
|
||||||
|
.map_or(payload.len(), |pos| pos + msg_start);
|
||||||
|
|
||||||
|
let message = std::str::from_utf8(&payload[msg_start..msg_end])
|
||||||
|
.map_err(|e| Error::InvalidUtf8 {
|
||||||
|
field: "error message",
|
||||||
|
source: e,
|
||||||
|
})?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
Ok(Self::Error { code, message })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialize the packet into bytes.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the data is too large.
|
||||||
|
pub fn serialize(&self) -> Result<Vec<u8>> {
|
||||||
|
match self {
|
||||||
|
Self::ReadRequest { filename, mode } => {
|
||||||
|
Ok(Self::serialize_request(Opcode::Rrq, filename, *mode))
|
||||||
|
}
|
||||||
|
Self::WriteRequest { filename, mode } => {
|
||||||
|
Ok(Self::serialize_request(Opcode::Wrq, filename, *mode))
|
||||||
|
}
|
||||||
|
Self::Data { block, data } => Self::serialize_data(*block, data),
|
||||||
|
Self::Ack { block } => Ok(Self::serialize_ack(*block)),
|
||||||
|
Self::Error { code, message } => Ok(Self::serialize_error(*code, message)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialize a RRQ or WRQ packet.
|
||||||
|
fn serialize_request(opcode: Opcode, filename: &str, mode: Mode) -> Vec<u8> {
|
||||||
|
let mode_str = mode.as_str();
|
||||||
|
let capacity = 2 + filename.len() + 1 + mode_str.len() + 1;
|
||||||
|
let mut buf = Vec::with_capacity(capacity);
|
||||||
|
|
||||||
|
buf.extend_from_slice(&u16::from(opcode).to_be_bytes());
|
||||||
|
buf.extend_from_slice(filename.as_bytes());
|
||||||
|
buf.push(0);
|
||||||
|
buf.extend_from_slice(mode_str.as_bytes());
|
||||||
|
buf.push(0);
|
||||||
|
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialize a DATA packet.
|
||||||
|
fn serialize_data(block: u16, data: &[u8]) -> Result<Vec<u8>> {
|
||||||
|
if data.len() > MAX_DATA_SIZE {
|
||||||
|
return Err(Error::DataTooLarge { size: data.len() });
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut buf = Vec::with_capacity(4 + data.len());
|
||||||
|
buf.extend_from_slice(&u16::from(Opcode::Data).to_be_bytes());
|
||||||
|
buf.extend_from_slice(&block.to_be_bytes());
|
||||||
|
buf.extend_from_slice(data);
|
||||||
|
|
||||||
|
Ok(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialize an ACK packet.
|
||||||
|
fn serialize_ack(block: u16) -> Vec<u8> {
|
||||||
|
let mut buf = Vec::with_capacity(4);
|
||||||
|
buf.extend_from_slice(&u16::from(Opcode::Ack).to_be_bytes());
|
||||||
|
buf.extend_from_slice(&block.to_be_bytes());
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serialize an ERROR packet.
|
||||||
|
fn serialize_error(code: ErrorCode, message: &str) -> Vec<u8> {
|
||||||
|
let capacity = 4 + message.len() + 1;
|
||||||
|
let mut buf = Vec::with_capacity(capacity);
|
||||||
|
|
||||||
|
buf.extend_from_slice(&u16::from(Opcode::Error).to_be_bytes());
|
||||||
|
buf.extend_from_slice(&u16::from(code).to_be_bytes());
|
||||||
|
buf.extend_from_slice(message.as_bytes());
|
||||||
|
buf.push(0);
|
||||||
|
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the opcode of this packet.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn opcode(&self) -> Opcode {
|
||||||
|
match self {
|
||||||
|
Self::ReadRequest { .. } => Opcode::Rrq,
|
||||||
|
Self::WriteRequest { .. } => Opcode::Wrq,
|
||||||
|
Self::Data { .. } => Opcode::Data,
|
||||||
|
Self::Ack { .. } => Opcode::Ack,
|
||||||
|
Self::Error { .. } => Opcode::Error,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a human-readable name for the packet type.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn type_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::ReadRequest { .. } => "RRQ",
|
||||||
|
Self::WriteRequest { .. } => "WRQ",
|
||||||
|
Self::Data { .. } => "DATA",
|
||||||
|
Self::Ack { .. } => "ACK",
|
||||||
|
Self::Error { .. } => "ERROR",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this is the final data packet (less than 512 bytes).
|
||||||
|
///
|
||||||
|
/// RFC 1350 Section 5: "If it is 512 bytes long, the block is not the last
|
||||||
|
/// block of data; if it is from zero to 511 bytes long, it signals the end
|
||||||
|
/// of the transfer."
|
||||||
|
#[must_use]
|
||||||
|
pub const fn is_final_data(&self) -> bool {
|
||||||
|
matches!(self, Self::Data { data, .. } if data.len() < MAX_DATA_SIZE)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new RRQ packet.
|
||||||
|
#[must_use]
|
||||||
|
pub fn rrq(filename: impl Into<String>, mode: Mode) -> Self {
|
||||||
|
Self::ReadRequest {
|
||||||
|
filename: filename.into(),
|
||||||
|
mode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new WRQ packet.
|
||||||
|
#[must_use]
|
||||||
|
pub fn wrq(filename: impl Into<String>, mode: Mode) -> Self {
|
||||||
|
Self::WriteRequest {
|
||||||
|
filename: filename.into(),
|
||||||
|
mode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new DATA packet.
|
||||||
|
#[must_use]
|
||||||
|
pub fn data(block: u16, data: Vec<u8>) -> Self {
|
||||||
|
Self::Data { block, data }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ACK packet.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn ack(block: u16) -> Self {
|
||||||
|
Self::Ack { block }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new ERROR packet.
|
||||||
|
#[must_use]
|
||||||
|
pub fn error(code: ErrorCode, message: impl Into<String>) -> Self {
|
||||||
|
Self::Error {
|
||||||
|
code,
|
||||||
|
message: message.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an ERROR packet with default message for the error code.
|
||||||
|
#[must_use]
|
||||||
|
pub fn error_default(code: ErrorCode) -> Self {
|
||||||
|
Self::Error {
|
||||||
|
code,
|
||||||
|
message: code.default_message().to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rrq_roundtrip() {
|
||||||
|
let packet = Packet::rrq("test.txt", Mode::Octet);
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wrq_roundtrip() {
|
||||||
|
let packet = Packet::wrq("output.bin", Mode::NetAscii);
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_data_roundtrip() {
|
||||||
|
let packet = Packet::data(1, vec![1, 2, 3, 4, 5]);
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_data_empty() {
|
||||||
|
let packet = Packet::data(42, vec![]);
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
assert!(packet.is_final_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_data_max_size() {
|
||||||
|
let packet = Packet::data(100, vec![0xAB; MAX_DATA_SIZE]);
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
assert!(!packet.is_final_data());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_data_too_large() {
|
||||||
|
let packet = Packet::data(1, vec![0; MAX_DATA_SIZE + 1]);
|
||||||
|
let result = packet.serialize();
|
||||||
|
assert!(matches!(result, Err(Error::DataTooLarge { .. })));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ack_roundtrip() {
|
||||||
|
let packet = Packet::ack(0);
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
|
||||||
|
let packet = Packet::ack(65535);
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_error_roundtrip() {
|
||||||
|
let packet = Packet::error(ErrorCode::FileNotFound, "File not found");
|
||||||
|
let bytes = packet.serialize().expect("serialize failed");
|
||||||
|
let parsed = Packet::parse(&bytes).expect("parse failed");
|
||||||
|
assert_eq!(packet, parsed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_error_default_message() {
|
||||||
|
let packet = Packet::error_default(ErrorCode::AccessViolation);
|
||||||
|
assert!(matches!(
|
||||||
|
packet,
|
||||||
|
Packet::Error {
|
||||||
|
code: ErrorCode::AccessViolation,
|
||||||
|
message
|
||||||
|
} if message == "Access violation"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mode_case_insensitive() {
|
||||||
|
assert_eq!(Mode::parse("OCTET").expect("parse"), Mode::Octet);
|
||||||
|
assert_eq!(Mode::parse("Octet").expect("parse"), Mode::Octet);
|
||||||
|
assert_eq!(Mode::parse("netascii").expect("parse"), Mode::NetAscii);
|
||||||
|
assert_eq!(Mode::parse("NETASCII").expect("parse"), Mode::NetAscii);
|
||||||
|
assert_eq!(Mode::parse("NetAscii").expect("parse"), Mode::NetAscii);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_mode() {
|
||||||
|
assert!(matches!(
|
||||||
|
Mode::parse("binary"),
|
||||||
|
Err(Error::InvalidMode(_))
|
||||||
|
));
|
||||||
|
assert!(matches!(Mode::parse("mail"), Err(Error::InvalidMode(_))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_opcode() {
|
||||||
|
let data = [0x00, 0x06]; // opcode 6 is invalid
|
||||||
|
let result = Packet::parse(&data);
|
||||||
|
assert!(matches!(result, Err(Error::InvalidOpcode(6))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_packet_too_short() {
|
||||||
|
let data = [0x00];
|
||||||
|
let result = Packet::parse(&data);
|
||||||
|
assert!(matches!(result, Err(Error::PacketTooShort { .. })));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rrq_parse_raw() {
|
||||||
|
// RRQ for "test.txt" in octet mode
|
||||||
|
let data = [
|
||||||
|
0x00, 0x01, // opcode = RRQ
|
||||||
|
b't', b'e', b's', b't', b'.', b't', b'x', b't', 0x00, // filename
|
||||||
|
b'o', b'c', b't', b'e', b't', 0x00, // mode
|
||||||
|
];
|
||||||
|
let packet = Packet::parse(&data).expect("parse failed");
|
||||||
|
assert!(matches!(
|
||||||
|
packet,
|
||||||
|
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
661
crates/pfs-tftp-proto/src/state.rs
Normal file
661
crates/pfs-tftp-proto/src/state.rs
Normal file
@@ -0,0 +1,661 @@
|
|||||||
|
//! TFTP Protocol State Machine
|
||||||
|
//!
|
||||||
|
//! This module implements sans-io state machines for both client and server
|
||||||
|
//! sides of the TFTP protocol.
|
||||||
|
//!
|
||||||
|
//! # RFC 1350 Protocol Overview
|
||||||
|
//!
|
||||||
|
//! The protocol follows a lock-step acknowledgment pattern:
|
||||||
|
//!
|
||||||
|
//! ## Read Transfer (Client reading from Server)
|
||||||
|
//! ```text
|
||||||
|
//! Client -> Server: RRQ (filename, mode)
|
||||||
|
//! Server -> Client: DATA (block 1)
|
||||||
|
//! Client -> Server: ACK (block 1)
|
||||||
|
//! Server -> Client: DATA (block 2)
|
||||||
|
//! ... (continues until final DATA with < 512 bytes)
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! ## Write Transfer (Client writing to Server)
|
||||||
|
//! ```text
|
||||||
|
//! Client -> Server: WRQ (filename, mode)
|
||||||
|
//! Server -> Client: ACK (block 0)
|
||||||
|
//! Client -> Server: DATA (block 1)
|
||||||
|
//! Server -> Client: ACK (block 1)
|
||||||
|
//! ... (continues until final DATA with < 512 bytes)
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use crate::{Error, Mode, Packet, Result, MAX_DATA_SIZE};
|
||||||
|
|
||||||
|
/// Direction of data transfer.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum TransferDirection {
|
||||||
|
/// Reading data from remote (client downloads, server uploads)
|
||||||
|
Read,
|
||||||
|
/// Writing data to remote (client uploads, server downloads)
|
||||||
|
Write,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Events emitted by the state machine.
|
||||||
|
///
|
||||||
|
/// These events tell the I/O layer what action to take next.
|
||||||
|
/// Errors are returned via `Result` from the state machine methods,
|
||||||
|
/// not as events.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub enum Event {
|
||||||
|
/// Send this packet to the remote peer.
|
||||||
|
Send(Packet),
|
||||||
|
|
||||||
|
/// Request data from the application for the given block number.
|
||||||
|
///
|
||||||
|
/// The application should provide up to 512 bytes of data.
|
||||||
|
/// A response with less than 512 bytes indicates the final block.
|
||||||
|
NeedData { block: u16 },
|
||||||
|
|
||||||
|
/// Received data from remote peer.
|
||||||
|
///
|
||||||
|
/// The application should store this data.
|
||||||
|
ReceivedData { block: u16, data: Vec<u8> },
|
||||||
|
|
||||||
|
/// Transfer completed successfully.
|
||||||
|
Complete,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Client-side protocol state machine.
|
||||||
|
///
|
||||||
|
/// Handles the client side of TFTP transfers (both read and write).
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ClientState {
|
||||||
|
/// The direction of the transfer
|
||||||
|
direction: TransferDirection,
|
||||||
|
/// Transfer mode
|
||||||
|
mode: Mode,
|
||||||
|
/// Current block number we're expecting/sending
|
||||||
|
current_block: u16,
|
||||||
|
/// Whether the transfer is complete
|
||||||
|
complete: bool,
|
||||||
|
/// Whether we've received the initial response
|
||||||
|
got_initial_response: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientState {
|
||||||
|
/// Create a new client state for a read request.
|
||||||
|
///
|
||||||
|
/// Call `start()` to get the initial RRQ packet to send.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn new_read(mode: Mode) -> Self {
|
||||||
|
Self {
|
||||||
|
direction: TransferDirection::Read,
|
||||||
|
mode,
|
||||||
|
current_block: 1, // Expecting DATA block 1
|
||||||
|
complete: false,
|
||||||
|
got_initial_response: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new client state for a write request.
|
||||||
|
///
|
||||||
|
/// Call `start()` to get the initial WRQ packet to send.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn new_write(mode: Mode) -> Self {
|
||||||
|
Self {
|
||||||
|
direction: TransferDirection::Write,
|
||||||
|
mode,
|
||||||
|
current_block: 0, // Expecting ACK block 0
|
||||||
|
complete: false,
|
||||||
|
got_initial_response: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the initial request packet to send.
|
||||||
|
#[must_use]
|
||||||
|
pub fn start(&self, filename: &str) -> Packet {
|
||||||
|
match self.direction {
|
||||||
|
TransferDirection::Read => Packet::rrq(filename, self.mode),
|
||||||
|
TransferDirection::Write => Packet::wrq(filename, self.mode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the transfer is complete.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn is_complete(&self) -> bool {
|
||||||
|
self.complete
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current expected block number.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn current_block(&self) -> u16 {
|
||||||
|
self.current_block
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a received packet and return events.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the packet is invalid for the current state.
|
||||||
|
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
|
||||||
|
if self.complete {
|
||||||
|
return Err(Error::StateError("transfer already complete".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match (self.direction, packet) {
|
||||||
|
// Reading: expect DATA packets
|
||||||
|
(TransferDirection::Read, Packet::Data { block, data }) => {
|
||||||
|
self.handle_read_data(*block, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writing: expect ACK packets
|
||||||
|
(TransferDirection::Write, Packet::Ack { block }) => self.handle_write_ack(*block),
|
||||||
|
|
||||||
|
// Error packet terminates transfer
|
||||||
|
(_, Packet::Error { code, message }) => {
|
||||||
|
self.complete = true;
|
||||||
|
Err(Error::TftpError {
|
||||||
|
code: *code,
|
||||||
|
message: message.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unexpected packet type
|
||||||
|
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
|
||||||
|
expected: "DATA",
|
||||||
|
actual: other.type_name(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
|
||||||
|
expected: "ACK",
|
||||||
|
actual: other.type_name(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a DATA packet during a read transfer.
|
||||||
|
fn handle_read_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
|
||||||
|
self.got_initial_response = true;
|
||||||
|
|
||||||
|
// RFC 1350: Block numbers start at 1 and increase sequentially
|
||||||
|
if block != self.current_block {
|
||||||
|
// Could be a duplicate - if it's the previous block, re-send ACK
|
||||||
|
if block == self.current_block.wrapping_sub(1) {
|
||||||
|
return Ok(vec![Event::Send(Packet::ack(block))]);
|
||||||
|
}
|
||||||
|
return Err(Error::BlockMismatch {
|
||||||
|
expected: self.current_block,
|
||||||
|
actual: block,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_final = data.len() < MAX_DATA_SIZE;
|
||||||
|
let mut events = vec![
|
||||||
|
Event::ReceivedData {
|
||||||
|
block,
|
||||||
|
data: data.to_vec(),
|
||||||
|
},
|
||||||
|
Event::Send(Packet::ack(block)),
|
||||||
|
];
|
||||||
|
|
||||||
|
if is_final {
|
||||||
|
self.complete = true;
|
||||||
|
events.push(Event::Complete);
|
||||||
|
} else {
|
||||||
|
self.current_block = self.current_block.wrapping_add(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(events)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle an ACK packet during a write transfer.
|
||||||
|
fn handle_write_ack(&mut self, block: u16) -> Result<Vec<Event>> {
|
||||||
|
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet having
|
||||||
|
// a block number of zero."
|
||||||
|
if !self.got_initial_response {
|
||||||
|
if block != 0 {
|
||||||
|
return Err(Error::BlockMismatch {
|
||||||
|
expected: 0,
|
||||||
|
actual: block,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
self.got_initial_response = true;
|
||||||
|
self.current_block = 1;
|
||||||
|
return Ok(vec![Event::NeedData { block: 1 }]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if block != self.current_block {
|
||||||
|
// Duplicate ACK - ignore
|
||||||
|
if block == self.current_block.wrapping_sub(1) {
|
||||||
|
return Ok(vec![]);
|
||||||
|
}
|
||||||
|
return Err(Error::BlockMismatch {
|
||||||
|
expected: self.current_block,
|
||||||
|
actual: block,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request next block of data
|
||||||
|
self.current_block = self.current_block.wrapping_add(1);
|
||||||
|
Ok(vec![Event::NeedData {
|
||||||
|
block: self.current_block,
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provide data for a write transfer.
|
||||||
|
///
|
||||||
|
/// Call this in response to a `NeedData` event.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `block` - The block number this data is for
|
||||||
|
/// * `data` - The data to send (must be <= 512 bytes)
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// Returns the DATA packet to send, or marks the transfer complete if
|
||||||
|
/// this was the final block.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the data is too large or the block number is wrong.
|
||||||
|
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
|
||||||
|
if data.len() > MAX_DATA_SIZE {
|
||||||
|
return Err(Error::DataTooLarge { size: data.len() });
|
||||||
|
}
|
||||||
|
|
||||||
|
if block != self.current_block {
|
||||||
|
return Err(Error::BlockMismatch {
|
||||||
|
expected: self.current_block,
|
||||||
|
actual: block,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_final = data.len() < MAX_DATA_SIZE;
|
||||||
|
if is_final {
|
||||||
|
self.complete = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Event::Send(Packet::data(block, data)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mark the transfer as complete after the final ACK is received.
|
||||||
|
pub fn finish(&mut self) -> Event {
|
||||||
|
self.complete = true;
|
||||||
|
Event::Complete
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Server-side protocol state machine.
|
||||||
|
///
|
||||||
|
/// Handles the server side of TFTP transfers.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ServerState {
|
||||||
|
/// The direction of the transfer (from server's perspective)
|
||||||
|
direction: TransferDirection,
|
||||||
|
/// Transfer mode
|
||||||
|
mode: Mode,
|
||||||
|
/// The filename being transferred
|
||||||
|
filename: String,
|
||||||
|
/// Current block number
|
||||||
|
current_block: u16,
|
||||||
|
/// Whether the transfer is complete
|
||||||
|
complete: bool,
|
||||||
|
/// Whether we've sent the first response
|
||||||
|
sent_initial_response: bool,
|
||||||
|
/// Track if we're waiting for final ACK
|
||||||
|
waiting_for_final_ack: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerState {
|
||||||
|
/// Create a new server state from a request packet.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the packet is not an RRQ or WRQ.
|
||||||
|
pub fn from_request(packet: &Packet) -> Result<Self> {
|
||||||
|
match packet {
|
||||||
|
Packet::ReadRequest { filename, mode } => Ok(Self {
|
||||||
|
// Server uploads when client reads
|
||||||
|
direction: TransferDirection::Write,
|
||||||
|
mode: *mode,
|
||||||
|
filename: filename.clone(),
|
||||||
|
current_block: 1,
|
||||||
|
complete: false,
|
||||||
|
sent_initial_response: false,
|
||||||
|
waiting_for_final_ack: false,
|
||||||
|
}),
|
||||||
|
Packet::WriteRequest { filename, mode } => Ok(Self {
|
||||||
|
// Server downloads when client writes
|
||||||
|
direction: TransferDirection::Read,
|
||||||
|
mode: *mode,
|
||||||
|
filename: filename.clone(),
|
||||||
|
current_block: 0,
|
||||||
|
complete: false,
|
||||||
|
sent_initial_response: false,
|
||||||
|
waiting_for_final_ack: false,
|
||||||
|
}),
|
||||||
|
other => Err(Error::UnexpectedPacket {
|
||||||
|
expected: "RRQ or WRQ",
|
||||||
|
actual: other.type_name(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the filename being transferred.
|
||||||
|
#[must_use]
|
||||||
|
pub fn filename(&self) -> &str {
|
||||||
|
&self.filename
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the transfer mode.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn mode(&self) -> Mode {
|
||||||
|
self.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the transfer direction (from server's perspective).
|
||||||
|
#[must_use]
|
||||||
|
pub const fn direction(&self) -> TransferDirection {
|
||||||
|
self.direction
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the transfer is complete.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn is_complete(&self) -> bool {
|
||||||
|
self.complete
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the initial response event.
|
||||||
|
///
|
||||||
|
/// For RRQ: Returns `NeedData` to request the first block
|
||||||
|
/// For WRQ: Returns `Send(ACK 0)` to acknowledge the write request
|
||||||
|
pub fn start(&mut self) -> Event {
|
||||||
|
self.sent_initial_response = true;
|
||||||
|
|
||||||
|
match self.direction {
|
||||||
|
// Server sending data (client RRQ)
|
||||||
|
TransferDirection::Write => Event::NeedData { block: 1 },
|
||||||
|
|
||||||
|
// Server receiving data (client WRQ)
|
||||||
|
// RFC 1350 Section 4: "A WRQ is acknowledged with an ACK packet
|
||||||
|
// having a block number of zero."
|
||||||
|
TransferDirection::Read => Event::Send(Packet::ack(0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a received packet and return events.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the packet is invalid for the current state.
|
||||||
|
pub fn receive(&mut self, packet: &Packet) -> Result<Vec<Event>> {
|
||||||
|
if self.complete {
|
||||||
|
return Err(Error::StateError("transfer already complete".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match (self.direction, packet) {
|
||||||
|
// Server sending data, receiving ACKs
|
||||||
|
(TransferDirection::Write, Packet::Ack { block }) => self.handle_send_ack(*block),
|
||||||
|
|
||||||
|
// Server receiving data
|
||||||
|
(TransferDirection::Read, Packet::Data { block, data }) => {
|
||||||
|
self.handle_receive_data(*block, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error terminates transfer
|
||||||
|
(_, Packet::Error { code, message }) => {
|
||||||
|
self.complete = true;
|
||||||
|
Err(Error::TftpError {
|
||||||
|
code: *code,
|
||||||
|
message: message.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unexpected packet
|
||||||
|
(TransferDirection::Write, other) => Err(Error::UnexpectedPacket {
|
||||||
|
expected: "ACK",
|
||||||
|
actual: other.type_name(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
(TransferDirection::Read, other) => Err(Error::UnexpectedPacket {
|
||||||
|
expected: "DATA",
|
||||||
|
actual: other.type_name(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle an ACK when server is sending data.
|
||||||
|
fn handle_send_ack(&mut self, block: u16) -> Result<Vec<Event>> {
|
||||||
|
if block != self.current_block {
|
||||||
|
// Duplicate ACK - could retransmit, but we'll just ignore
|
||||||
|
if block == self.current_block.wrapping_sub(1) {
|
||||||
|
return Ok(vec![]);
|
||||||
|
}
|
||||||
|
return Err(Error::BlockMismatch {
|
||||||
|
expected: self.current_block,
|
||||||
|
actual: block,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we were waiting for the final ACK, we're done
|
||||||
|
if self.waiting_for_final_ack {
|
||||||
|
self.complete = true;
|
||||||
|
return Ok(vec![Event::Complete]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request next block
|
||||||
|
self.current_block = self.current_block.wrapping_add(1);
|
||||||
|
Ok(vec![Event::NeedData {
|
||||||
|
block: self.current_block,
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a DATA packet when server is receiving data.
|
||||||
|
fn handle_receive_data(&mut self, block: u16, data: &[u8]) -> Result<Vec<Event>> {
|
||||||
|
let expected_block = self.current_block.wrapping_add(1);
|
||||||
|
|
||||||
|
if block != expected_block {
|
||||||
|
// Duplicate - re-ACK
|
||||||
|
if block == self.current_block {
|
||||||
|
return Ok(vec![Event::Send(Packet::ack(block))]);
|
||||||
|
}
|
||||||
|
return Err(Error::BlockMismatch {
|
||||||
|
expected: expected_block,
|
||||||
|
actual: block,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
self.current_block = block;
|
||||||
|
let is_final = data.len() < MAX_DATA_SIZE;
|
||||||
|
|
||||||
|
let mut events = vec![
|
||||||
|
Event::ReceivedData {
|
||||||
|
block,
|
||||||
|
data: data.to_vec(),
|
||||||
|
},
|
||||||
|
Event::Send(Packet::ack(block)),
|
||||||
|
];
|
||||||
|
|
||||||
|
if is_final {
|
||||||
|
self.complete = true;
|
||||||
|
events.push(Event::Complete);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(events)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provide data for a read transfer (server sending to client).
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the data is too large.
|
||||||
|
pub fn provide_data(&mut self, block: u16, data: Vec<u8>) -> Result<Event> {
|
||||||
|
if data.len() > MAX_DATA_SIZE {
|
||||||
|
return Err(Error::DataTooLarge { size: data.len() });
|
||||||
|
}
|
||||||
|
|
||||||
|
if block != self.current_block {
|
||||||
|
return Err(Error::BlockMismatch {
|
||||||
|
expected: self.current_block,
|
||||||
|
actual: block,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_final = data.len() < MAX_DATA_SIZE;
|
||||||
|
if is_final {
|
||||||
|
self.waiting_for_final_ack = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Event::Send(Packet::data(block, data)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::ErrorCode;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_read_single_block() {
|
||||||
|
let mut client = ClientState::new_read(Mode::Octet);
|
||||||
|
let request = client.start("test.txt");
|
||||||
|
assert!(matches!(
|
||||||
|
request,
|
||||||
|
Packet::ReadRequest { filename, mode: Mode::Octet } if filename == "test.txt"
|
||||||
|
));
|
||||||
|
|
||||||
|
// Receive single DATA block (< 512 bytes = final)
|
||||||
|
let events = client
|
||||||
|
.receive(&Packet::data(1, vec![1, 2, 3]))
|
||||||
|
.expect("receive failed");
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 3);
|
||||||
|
assert!(matches!(&events[0], Event::ReceivedData { block: 1, data } if data == &[1, 2, 3]));
|
||||||
|
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
|
||||||
|
assert!(matches!(&events[2], Event::Complete));
|
||||||
|
assert!(client.is_complete());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_read_multiple_blocks() {
|
||||||
|
let mut client = ClientState::new_read(Mode::Octet);
|
||||||
|
let _ = client.start("test.txt");
|
||||||
|
|
||||||
|
// Full block (512 bytes)
|
||||||
|
let full_data = vec![0u8; MAX_DATA_SIZE];
|
||||||
|
let events = client
|
||||||
|
.receive(&Packet::data(1, full_data.clone()))
|
||||||
|
.expect("receive");
|
||||||
|
assert_eq!(events.len(), 2);
|
||||||
|
assert!(!client.is_complete());
|
||||||
|
assert_eq!(client.current_block(), 2);
|
||||||
|
|
||||||
|
// Another full block
|
||||||
|
let events = client
|
||||||
|
.receive(&Packet::data(2, full_data))
|
||||||
|
.expect("receive");
|
||||||
|
assert_eq!(events.len(), 2);
|
||||||
|
assert!(!client.is_complete());
|
||||||
|
|
||||||
|
// Final block (< 512 bytes)
|
||||||
|
let events = client
|
||||||
|
.receive(&Packet::data(3, vec![1, 2, 3]))
|
||||||
|
.expect("receive");
|
||||||
|
assert_eq!(events.len(), 3);
|
||||||
|
assert!(client.is_complete());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_client_write_flow() {
|
||||||
|
let mut client = ClientState::new_write(Mode::Octet);
|
||||||
|
let request = client.start("output.txt");
|
||||||
|
assert!(matches!(request, Packet::WriteRequest { .. }));
|
||||||
|
|
||||||
|
// Receive ACK 0 (write request acknowledged)
|
||||||
|
let events = client.receive(&Packet::ack(0)).expect("receive");
|
||||||
|
assert_eq!(events.len(), 1);
|
||||||
|
assert!(matches!(events[0], Event::NeedData { block: 1 }));
|
||||||
|
|
||||||
|
// Provide data for block 1
|
||||||
|
let send_event = client.provide_data(1, vec![1, 2, 3]).expect("provide_data");
|
||||||
|
assert!(matches!(
|
||||||
|
send_event,
|
||||||
|
Event::Send(Packet::Data { block: 1, .. })
|
||||||
|
));
|
||||||
|
assert!(client.is_complete()); // < 512 bytes = final
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_server_rrq() {
|
||||||
|
let request = Packet::rrq("file.txt", Mode::Octet);
|
||||||
|
let mut server = ServerState::from_request(&request).expect("from_request");
|
||||||
|
|
||||||
|
assert_eq!(server.filename(), "file.txt");
|
||||||
|
assert_eq!(server.direction(), TransferDirection::Write);
|
||||||
|
|
||||||
|
// Start should request data for block 1
|
||||||
|
let event = server.start();
|
||||||
|
assert!(matches!(event, Event::NeedData { block: 1 }));
|
||||||
|
|
||||||
|
// Provide data
|
||||||
|
let event = server.provide_data(1, vec![1, 2, 3]).expect("provide_data");
|
||||||
|
assert!(matches!(event, Event::Send(Packet::Data { block: 1, .. })));
|
||||||
|
|
||||||
|
// Receive ACK
|
||||||
|
let events = server.receive(&Packet::ack(1)).expect("receive");
|
||||||
|
assert!(matches!(&events[0], Event::Complete));
|
||||||
|
assert!(server.is_complete());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_server_wrq() {
|
||||||
|
let request = Packet::wrq("output.txt", Mode::Octet);
|
||||||
|
let mut server = ServerState::from_request(&request).expect("from_request");
|
||||||
|
|
||||||
|
assert_eq!(server.direction(), TransferDirection::Read);
|
||||||
|
|
||||||
|
// Start should send ACK 0
|
||||||
|
let event = server.start();
|
||||||
|
assert!(matches!(event, Event::Send(Packet::Ack { block: 0 })));
|
||||||
|
|
||||||
|
// Receive DATA block 1
|
||||||
|
let events = server
|
||||||
|
.receive(&Packet::data(1, vec![1, 2, 3]))
|
||||||
|
.expect("receive");
|
||||||
|
assert_eq!(events.len(), 3);
|
||||||
|
assert!(matches!(&events[0], Event::ReceivedData { block: 1, .. }));
|
||||||
|
assert!(matches!(&events[1], Event::Send(Packet::Ack { block: 1 })));
|
||||||
|
assert!(matches!(&events[2], Event::Complete));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_error_terminates_transfer() {
|
||||||
|
let mut client = ClientState::new_read(Mode::Octet);
|
||||||
|
let _ = client.start("test.txt");
|
||||||
|
|
||||||
|
let result = client.receive(&Packet::error(ErrorCode::FileNotFound, "File not found"));
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(client.is_complete());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_duplicate_ack_ignored() {
|
||||||
|
let request = Packet::rrq("file.txt", Mode::Octet);
|
||||||
|
let mut server = ServerState::from_request(&request).expect("from_request");
|
||||||
|
let _ = server.start();
|
||||||
|
|
||||||
|
// Provide and send block 1
|
||||||
|
let _ = server
|
||||||
|
.provide_data(1, vec![0u8; MAX_DATA_SIZE])
|
||||||
|
.expect("provide_data");
|
||||||
|
|
||||||
|
// Receive ACK 1
|
||||||
|
let events = server.receive(&Packet::ack(1)).expect("receive");
|
||||||
|
assert!(!events.is_empty());
|
||||||
|
|
||||||
|
// Provide block 2
|
||||||
|
let _ = server.provide_data(2, vec![1, 2, 3]).expect("provide_data");
|
||||||
|
|
||||||
|
// Duplicate ACK 1 should be ignored
|
||||||
|
let events = server.receive(&Packet::ack(1)).expect("receive");
|
||||||
|
assert!(events.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
22
crates/pfs-tftp/Cargo.toml
Normal file
22
crates/pfs-tftp/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
[package]
|
||||||
|
name = "pfs-tftp"
|
||||||
|
description = "TFTP client and server library (RFC 1350)"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
repository.workspace = true
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
pfs-tftp-proto = { path = "../pfs-tftp-proto" }
|
||||||
|
thiserror = "2"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "tftpd"
|
||||||
|
path = "src/bin/tftpd.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "tftp"
|
||||||
|
path = "src/bin/tftp.rs"
|
||||||
3
crates/pfs-tftp/src/bin/tftp.rs
Normal file
3
crates/pfs-tftp/src/bin/tftp.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
fn main() {
|
||||||
|
println!("tftp client placeholder");
|
||||||
|
}
|
||||||
3
crates/pfs-tftp/src/bin/tftpd.rs
Normal file
3
crates/pfs-tftp/src/bin/tftpd.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
fn main() {
|
||||||
|
println!("tftpd server placeholder");
|
||||||
|
}
|
||||||
479
crates/pfs-tftp/src/client.rs
Normal file
479
crates/pfs-tftp/src/client.rs
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
//! TFTP Client implementation.
|
||||||
|
//!
|
||||||
|
//! Provides synchronous TFTP client operations for reading and writing files
|
||||||
|
//! to/from a TFTP server.
|
||||||
|
|
||||||
|
// The expect() calls in this module cannot panic in practice because we always
|
||||||
|
// set the server_tid before using it, but clippy doesn't know this.
|
||||||
|
#![allow(clippy::missing_panics_doc)]
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
io::{Read, Write},
|
||||||
|
net::{SocketAddr, ToSocketAddrs, UdpSocket},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use pfs_tftp_proto::{ClientState, Event, Mode, Packet, MAX_DATA_SIZE, TFTP_PORT};
|
||||||
|
|
||||||
|
use crate::{Error, Result};
|
||||||
|
|
||||||
|
/// Default timeout for TFTP operations (in seconds).
|
||||||
|
const DEFAULT_TIMEOUT_SECS: u64 = 5;
|
||||||
|
|
||||||
|
/// Default number of retries before giving up.
|
||||||
|
const DEFAULT_RETRIES: u32 = 3;
|
||||||
|
|
||||||
|
/// Maximum packet size (opcode + block + max data).
|
||||||
|
const MAX_PACKET_SIZE: usize = 4 + MAX_DATA_SIZE;
|
||||||
|
|
||||||
|
/// TFTP client for transferring files to/from a server.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use pfs_tftp::{Client, Mode};
|
||||||
|
///
|
||||||
|
/// let client = Client::new("192.168.1.1:69").expect("connect");
|
||||||
|
/// let data = client.get("config.txt", Mode::Octet).expect("get");
|
||||||
|
/// println!("Received {} bytes", data.len());
|
||||||
|
/// ```
|
||||||
|
pub struct Client {
|
||||||
|
/// Server address
|
||||||
|
server_addr: SocketAddr,
|
||||||
|
/// Local UDP socket
|
||||||
|
socket: UdpSocket,
|
||||||
|
/// Timeout duration
|
||||||
|
timeout: Duration,
|
||||||
|
/// Number of retries
|
||||||
|
retries: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Client {
|
||||||
|
/// Create a new TFTP client connected to the specified server.
|
||||||
|
///
|
||||||
|
/// The address can be in any format accepted by `ToSocketAddrs`, such as
|
||||||
|
/// "192.168.1.1:69" or "tftp.example.com:69".
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the address is invalid or the socket cannot be bound.
|
||||||
|
pub fn new<A: ToSocketAddrs>(server_addr: A) -> Result<Self> {
|
||||||
|
let server_addr = server_addr
|
||||||
|
.to_socket_addrs()?
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::InvalidInput, "no addresses found")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Bind to any available local port
|
||||||
|
let socket = UdpSocket::bind("0.0.0.0:0")?;
|
||||||
|
socket.set_read_timeout(Some(Duration::from_secs(DEFAULT_TIMEOUT_SECS)))?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
server_addr,
|
||||||
|
socket,
|
||||||
|
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
|
||||||
|
retries: DEFAULT_RETRIES,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the timeout for operations.
|
||||||
|
#[must_use]
|
||||||
|
pub fn with_timeout(mut self, timeout: Duration) -> Self {
|
||||||
|
self.timeout = timeout;
|
||||||
|
let _ = self.socket.set_read_timeout(Some(timeout));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the number of retries.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn with_retries(mut self, retries: u32) -> Self {
|
||||||
|
self.retries = retries;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download a file from the server.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `filename` - The name of the file to download
|
||||||
|
/// * `mode` - The transfer mode (`Octet` or `NetAscii`)
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// The file contents as a byte vector.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the transfer fails.
|
||||||
|
pub fn get(&self, filename: &str, mode: Mode) -> Result<Vec<u8>> {
|
||||||
|
let mut state = ClientState::new_read(mode);
|
||||||
|
let request = state.start(filename);
|
||||||
|
|
||||||
|
// Send initial request to well-known port
|
||||||
|
let request_bytes = request.serialize()?;
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
|
||||||
|
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
let mut data = Vec::new();
|
||||||
|
let mut server_tid: Option<SocketAddr> = None;
|
||||||
|
let mut retries_left = self.retries;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Receive packet
|
||||||
|
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||||
|
// Timeout - retransmit last packet
|
||||||
|
if retries_left == 0 {
|
||||||
|
return Err(Error::Timeout {
|
||||||
|
retries: self.retries,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
retries_left -= 1;
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// RFC 1350 Section 4: First response establishes the TID
|
||||||
|
// Subsequent packets must come from the same TID
|
||||||
|
match server_tid {
|
||||||
|
None => server_tid = Some(from_addr),
|
||||||
|
Some(expected) if from_addr != expected => {
|
||||||
|
// RFC 1350: "If a source TID does not match, the packet should
|
||||||
|
// be discarded as erroneously sent from somewhere else. An error
|
||||||
|
// packet should be sent to the source of the incorrect packet."
|
||||||
|
let error =
|
||||||
|
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
|
||||||
|
let _ = self.socket.send_to(&error.serialize()?, from_addr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
let packet = Packet::parse(&recv_buf[..len])?;
|
||||||
|
|
||||||
|
// Handle error packets
|
||||||
|
if let Packet::Error { code, message } = packet {
|
||||||
|
return Err(Error::Remote { code, message });
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = state.receive(&packet)?;
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
Event::ReceivedData { data: block_data, .. } => {
|
||||||
|
data.extend_from_slice(&block_data);
|
||||||
|
}
|
||||||
|
Event::Send(packet) => {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
|
||||||
|
}
|
||||||
|
Event::Complete => {
|
||||||
|
return Ok(data);
|
||||||
|
}
|
||||||
|
Event::NeedData { .. } => {
|
||||||
|
// Not used in read transfers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
retries_left = self.retries;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upload a file to the server.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `filename` - The name to give the file on the server
|
||||||
|
/// * `mode` - The transfer mode (`Octet` or `NetAscii`)
|
||||||
|
/// * `data` - The file contents to upload
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the transfer fails.
|
||||||
|
pub fn put(&self, filename: &str, mode: Mode, data: &[u8]) -> Result<()> {
|
||||||
|
let mut state = ClientState::new_write(mode);
|
||||||
|
let request = state.start(filename);
|
||||||
|
|
||||||
|
// Send initial request
|
||||||
|
let request_bytes = request.serialize()?;
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
|
||||||
|
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
let mut server_tid: Option<SocketAddr> = None;
|
||||||
|
let mut retries_left = self.retries;
|
||||||
|
let mut last_sent: Option<Vec<u8>> = None;
|
||||||
|
let mut data_offset: usize = 0;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Receive packet
|
||||||
|
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||||
|
// Timeout - retransmit
|
||||||
|
if retries_left == 0 {
|
||||||
|
return Err(Error::Timeout {
|
||||||
|
retries: self.retries,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
retries_left -= 1;
|
||||||
|
if let Some(ref last) = last_sent {
|
||||||
|
self.socket
|
||||||
|
.send_to(last, server_tid.unwrap_or(self.server_addr))?;
|
||||||
|
} else {
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// TID validation
|
||||||
|
match server_tid {
|
||||||
|
None => server_tid = Some(from_addr),
|
||||||
|
Some(expected) if from_addr != expected => {
|
||||||
|
let error =
|
||||||
|
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
|
||||||
|
let _ = self.socket.send_to(&error.serialize()?, from_addr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
let packet = Packet::parse(&recv_buf[..len])?;
|
||||||
|
|
||||||
|
if let Packet::Error { code, message } = packet {
|
||||||
|
return Err(Error::Remote { code, message });
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = state.receive(&packet)?;
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
Event::NeedData { block } => {
|
||||||
|
// Calculate data for this block
|
||||||
|
let start = data_offset;
|
||||||
|
let end = (start + MAX_DATA_SIZE).min(data.len());
|
||||||
|
let block_data = data[start..end].to_vec();
|
||||||
|
data_offset = end;
|
||||||
|
|
||||||
|
let send_event = state.provide_data(block, block_data)?;
|
||||||
|
if let Event::Send(packet) = send_event {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::Send(packet) => {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
Event::Complete => {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
Event::ReceivedData { .. } => {
|
||||||
|
// Not used in write transfers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if transfer is complete after sending final data
|
||||||
|
if state.is_complete() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
retries_left = self.retries;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Download a file from the server to a writer.
|
||||||
|
///
|
||||||
|
/// This is more memory-efficient for large files as it streams data
|
||||||
|
/// directly to the writer.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the transfer fails.
|
||||||
|
pub fn get_to_writer<W: Write>(&self, filename: &str, mode: Mode, writer: &mut W) -> Result<u64> {
|
||||||
|
let mut state = ClientState::new_read(mode);
|
||||||
|
let request = state.start(filename);
|
||||||
|
|
||||||
|
let request_bytes = request.serialize()?;
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
|
||||||
|
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
let mut total_bytes: u64 = 0;
|
||||||
|
let mut server_tid: Option<SocketAddr> = None;
|
||||||
|
let mut retries_left = self.retries;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||||
|
if retries_left == 0 {
|
||||||
|
return Err(Error::Timeout {
|
||||||
|
retries: self.retries,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
retries_left -= 1;
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
match server_tid {
|
||||||
|
None => server_tid = Some(from_addr),
|
||||||
|
Some(expected) if from_addr != expected => {
|
||||||
|
let error =
|
||||||
|
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
|
||||||
|
let _ = self.socket.send_to(&error.serialize()?, from_addr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
let packet = Packet::parse(&recv_buf[..len])?;
|
||||||
|
|
||||||
|
if let Packet::Error { code, message } = packet {
|
||||||
|
return Err(Error::Remote { code, message });
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = state.receive(&packet)?;
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
Event::ReceivedData { data, .. } => {
|
||||||
|
writer.write_all(&data)?;
|
||||||
|
total_bytes += data.len() as u64;
|
||||||
|
}
|
||||||
|
Event::Send(packet) => {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
|
||||||
|
}
|
||||||
|
Event::Complete => {
|
||||||
|
writer.flush()?;
|
||||||
|
return Ok(total_bytes);
|
||||||
|
}
|
||||||
|
Event::NeedData { .. } => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
retries_left = self.retries;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upload data from a reader to the server.
|
||||||
|
///
|
||||||
|
/// This is more memory-efficient for large files as it streams data
|
||||||
|
/// from the reader.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the transfer fails.
|
||||||
|
pub fn put_from_reader<R: Read>(&self, filename: &str, mode: Mode, reader: &mut R) -> Result<u64> {
|
||||||
|
let mut state = ClientState::new_write(mode);
|
||||||
|
let request = state.start(filename);
|
||||||
|
|
||||||
|
let request_bytes = request.serialize()?;
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
|
||||||
|
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
let mut read_buf = [0u8; MAX_DATA_SIZE];
|
||||||
|
let mut total_bytes: u64 = 0;
|
||||||
|
let mut server_tid: Option<SocketAddr> = None;
|
||||||
|
let mut retries_left = self.retries;
|
||||||
|
let mut last_sent: Option<Vec<u8>> = None;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let (len, from_addr) = match self.socket.recv_from(&mut recv_buf) {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||||
|
if retries_left == 0 {
|
||||||
|
return Err(Error::Timeout {
|
||||||
|
retries: self.retries,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
retries_left -= 1;
|
||||||
|
if let Some(ref last) = last_sent {
|
||||||
|
self.socket
|
||||||
|
.send_to(last, server_tid.unwrap_or(self.server_addr))?;
|
||||||
|
} else {
|
||||||
|
self.socket.send_to(&request_bytes, self.server_addr)?;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
match server_tid {
|
||||||
|
None => server_tid = Some(from_addr),
|
||||||
|
Some(expected) if from_addr != expected => {
|
||||||
|
let error =
|
||||||
|
Packet::error(pfs_tftp_proto::ErrorCode::UnknownTransferId, "Unknown TID");
|
||||||
|
let _ = self.socket.send_to(&error.serialize()?, from_addr);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
let packet = Packet::parse(&recv_buf[..len])?;
|
||||||
|
|
||||||
|
if let Packet::Error { code, message } = packet {
|
||||||
|
return Err(Error::Remote { code, message });
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = state.receive(&packet)?;
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
Event::NeedData { block } => {
|
||||||
|
let n = reader.read(&mut read_buf)?;
|
||||||
|
total_bytes += n as u64;
|
||||||
|
|
||||||
|
let send_event = state.provide_data(block, read_buf[..n].to_vec())?;
|
||||||
|
if let Event::Send(packet) = send_event {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::Send(packet) => {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
self.socket.send_to(&bytes, server_tid.expect("TID set"))?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
Event::Complete => {
|
||||||
|
return Ok(total_bytes);
|
||||||
|
}
|
||||||
|
Event::ReceivedData { .. } => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.is_complete() {
|
||||||
|
return Ok(total_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
retries_left = self.retries;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a client connected to "host:port" or "host" (default port 69).
|
||||||
|
impl std::str::FromStr for Client {
|
||||||
|
type Err = Error;
|
||||||
|
|
||||||
|
fn from_str(s: &str) -> Result<Self> {
|
||||||
|
let addr = if s.contains(':') {
|
||||||
|
s.to_string()
|
||||||
|
} else {
|
||||||
|
format!("{s}:{TFTP_PORT}")
|
||||||
|
};
|
||||||
|
Self::new(addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
41
crates/pfs-tftp/src/error.rs
Normal file
41
crates/pfs-tftp/src/error.rs
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//! Error types for TFTP I/O operations.
|
||||||
|
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// Errors that can occur during TFTP operations.
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum Error {
|
||||||
|
/// I/O error during socket or file operations.
|
||||||
|
#[error("I/O error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
|
||||||
|
/// Protocol-level error.
|
||||||
|
#[error("protocol error: {0}")]
|
||||||
|
Protocol(#[from] pfs_tftp_proto::Error),
|
||||||
|
|
||||||
|
/// Address parsing error.
|
||||||
|
#[error("invalid address: {0}")]
|
||||||
|
InvalidAddress(#[from] std::net::AddrParseError),
|
||||||
|
|
||||||
|
/// Remote peer sent an error.
|
||||||
|
#[error("remote error ({code:?}): {message}")]
|
||||||
|
Remote {
|
||||||
|
code: pfs_tftp_proto::ErrorCode,
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Transfer timed out.
|
||||||
|
#[error("transfer timed out after {retries} retries")]
|
||||||
|
Timeout { retries: u32 },
|
||||||
|
|
||||||
|
/// File access error.
|
||||||
|
#[error("file access error: {0}")]
|
||||||
|
FileAccess(String),
|
||||||
|
|
||||||
|
/// Invalid filename (e.g., path traversal attempt).
|
||||||
|
#[error("invalid filename: {0}")]
|
||||||
|
InvalidFilename(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result type for TFTP operations.
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
32
crates/pfs-tftp/src/lib.rs
Normal file
32
crates/pfs-tftp/src/lib.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
//! TFTP Client and Server Library (RFC 1350)
|
||||||
|
//!
|
||||||
|
//! This crate provides synchronous I/O implementations for TFTP client and
|
||||||
|
//! server functionality, built on top of the `pfs-tftp-proto` protocol crate.
|
||||||
|
//!
|
||||||
|
//! # Example: Client
|
||||||
|
//!
|
||||||
|
//! ```no_run
|
||||||
|
//! use pfs_tftp::{Client, Mode};
|
||||||
|
//!
|
||||||
|
//! let mut client = Client::new("192.168.1.1:69").expect("connect");
|
||||||
|
//! let data = client.get("config.txt", Mode::Octet).expect("get");
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! # Example: Server
|
||||||
|
//!
|
||||||
|
//! ```no_run
|
||||||
|
//! use pfs_tftp::Server;
|
||||||
|
//! use std::path::Path;
|
||||||
|
//!
|
||||||
|
//! let server = Server::bind("0.0.0.0:69", Path::new("/tftpboot")).expect("bind");
|
||||||
|
//! server.run().expect("run");
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
mod client;
|
||||||
|
mod error;
|
||||||
|
mod server;
|
||||||
|
|
||||||
|
pub use client::Client;
|
||||||
|
pub use error::{Error, Result};
|
||||||
|
pub use pfs_tftp_proto::Mode;
|
||||||
|
pub use server::{Server, ServerBuilder, ServerConfig};
|
||||||
569
crates/pfs-tftp/src/server.rs
Normal file
569
crates/pfs-tftp/src/server.rs
Normal file
@@ -0,0 +1,569 @@
|
|||||||
|
//! TFTP Server implementation.
|
||||||
|
//!
|
||||||
|
//! Provides a synchronous TFTP server that serves files from a root directory.
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
fs::{File, OpenOptions},
|
||||||
|
io::{Read, Write},
|
||||||
|
net::{SocketAddr, ToSocketAddrs, UdpSocket},
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use pfs_tftp_proto::{ErrorCode, Event, Mode, Packet, ServerState, MAX_DATA_SIZE, TFTP_PORT};
|
||||||
|
|
||||||
|
use crate::{Error, Result};
|
||||||
|
|
||||||
|
/// Default timeout for transfer operations (in seconds).
|
||||||
|
const DEFAULT_TIMEOUT_SECS: u64 = 5;
|
||||||
|
|
||||||
|
/// Default number of retries before giving up on a transfer.
|
||||||
|
const DEFAULT_RETRIES: u32 = 3;
|
||||||
|
|
||||||
|
/// Maximum packet size.
|
||||||
|
const MAX_PACKET_SIZE: usize = 4 + MAX_DATA_SIZE;
|
||||||
|
|
||||||
|
/// Configuration for the TFTP server.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ServerConfig {
|
||||||
|
/// Root directory to serve files from/to.
|
||||||
|
pub root_dir: PathBuf,
|
||||||
|
/// Allow write operations (WRQ).
|
||||||
|
pub allow_write: bool,
|
||||||
|
/// Allow overwriting existing files.
|
||||||
|
pub allow_overwrite: bool,
|
||||||
|
/// Timeout for operations.
|
||||||
|
pub timeout: Duration,
|
||||||
|
/// Number of retries.
|
||||||
|
pub retries: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ServerConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
root_dir: PathBuf::from("."),
|
||||||
|
allow_write: false,
|
||||||
|
allow_overwrite: false,
|
||||||
|
timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
|
||||||
|
retries: DEFAULT_RETRIES,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// TFTP server.
|
||||||
|
///
|
||||||
|
/// Serves files from a configured root directory.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// use pfs_tftp::Server;
|
||||||
|
/// use std::path::Path;
|
||||||
|
///
|
||||||
|
/// let server = Server::bind("0.0.0.0:69", Path::new("/tftpboot")).expect("bind");
|
||||||
|
/// server.run().expect("run");
|
||||||
|
/// ```
|
||||||
|
pub struct Server {
|
||||||
|
/// Listening socket on the well-known TFTP port.
|
||||||
|
socket: UdpSocket,
|
||||||
|
/// Server configuration.
|
||||||
|
config: ServerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Server {
|
||||||
|
/// Create a new TFTP server bound to the specified address.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `addr` - The address to bind to (e.g., "0.0.0.0:69")
|
||||||
|
/// * `root_dir` - The root directory to serve files from
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the socket cannot be bound or the root directory
|
||||||
|
/// is invalid.
|
||||||
|
pub fn bind<A: ToSocketAddrs>(addr: A, root_dir: &Path) -> Result<Self> {
|
||||||
|
let socket = UdpSocket::bind(addr)?;
|
||||||
|
|
||||||
|
let config = ServerConfig {
|
||||||
|
root_dir: root_dir.to_path_buf(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self { socket, config })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a server with custom configuration.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the socket cannot be bound.
|
||||||
|
pub fn with_config<A: ToSocketAddrs>(addr: A, config: ServerConfig) -> Result<Self> {
|
||||||
|
let socket = UdpSocket::bind(addr)?;
|
||||||
|
Ok(Self { socket, config })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enable or disable write operations.
|
||||||
|
#[must_use]
|
||||||
|
pub fn allow_write(mut self, allow: bool) -> Self {
|
||||||
|
self.config.allow_write = allow;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enable or disable overwriting existing files.
|
||||||
|
#[must_use]
|
||||||
|
pub fn allow_overwrite(mut self, allow: bool) -> Self {
|
||||||
|
self.config.allow_overwrite = allow;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the server, handling requests indefinitely.
|
||||||
|
///
|
||||||
|
/// This method blocks and handles one request at a time.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if a fatal socket error occurs.
|
||||||
|
pub fn run(&self) -> Result<()> {
|
||||||
|
let mut buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let (len, client_addr) = self.socket.recv_from(&mut buf)?;
|
||||||
|
|
||||||
|
// Try to parse the request
|
||||||
|
let packet = match Packet::parse(&buf[..len]) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Invalid packet from {client_addr}: {e}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle the request
|
||||||
|
if let Err(e) = self.handle_request(client_addr, &packet) {
|
||||||
|
eprintln!("Error handling request from {client_addr}: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a single request, then return.
|
||||||
|
///
|
||||||
|
/// This is useful for testing or single-request operation.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if a fatal socket error occurs.
|
||||||
|
pub fn handle_one(&self) -> Result<()> {
|
||||||
|
let mut buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
let (len, client_addr) = self.socket.recv_from(&mut buf)?;
|
||||||
|
let packet = Packet::parse(&buf[..len])?;
|
||||||
|
self.handle_request(client_addr, &packet)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a single request.
|
||||||
|
fn handle_request(&self, client_addr: SocketAddr, packet: &Packet) -> Result<()> {
|
||||||
|
// Create a new socket for this transfer with a random TID
|
||||||
|
// RFC 1350 Section 4: "each end of the connection chooses a TID for
|
||||||
|
// itself, to be used for the duration of that connection"
|
||||||
|
let transfer_socket = UdpSocket::bind("0.0.0.0:0")?;
|
||||||
|
transfer_socket.set_read_timeout(Some(self.config.timeout))?;
|
||||||
|
transfer_socket.connect(client_addr)?;
|
||||||
|
|
||||||
|
match packet {
|
||||||
|
Packet::ReadRequest { filename, mode } => {
|
||||||
|
self.handle_rrq(&transfer_socket, filename, *mode)
|
||||||
|
}
|
||||||
|
Packet::WriteRequest { filename, mode } => {
|
||||||
|
self.handle_wrq(&transfer_socket, filename, *mode)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Unexpected packet type on the well-known port
|
||||||
|
let error = Packet::error(ErrorCode::IllegalOperation, "Expected RRQ or WRQ");
|
||||||
|
let _ = transfer_socket.send(&error.serialize()?);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate and resolve a filename to a path within the root directory.
|
||||||
|
fn resolve_path(&self, filename: &str) -> Result<PathBuf> {
|
||||||
|
// Security: Prevent path traversal attacks
|
||||||
|
// RFC 1350 Security Considerations: "care must be taken in the rights
|
||||||
|
// granted to a TFTP server process so as not to violate the security
|
||||||
|
// of the server hosts file system"
|
||||||
|
|
||||||
|
// Reject absolute paths
|
||||||
|
if filename.starts_with('/') || filename.starts_with('\\') {
|
||||||
|
return Err(Error::InvalidFilename(
|
||||||
|
"absolute paths not allowed".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject path traversal
|
||||||
|
if filename.contains("..") {
|
||||||
|
return Err(Error::InvalidFilename(
|
||||||
|
"path traversal not allowed".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reject paths with null bytes
|
||||||
|
if filename.contains('\0') {
|
||||||
|
return Err(Error::InvalidFilename(
|
||||||
|
"null bytes not allowed".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = self.config.root_dir.join(filename);
|
||||||
|
|
||||||
|
// Verify the resolved path is still within root_dir
|
||||||
|
let canonical_root = self
|
||||||
|
.config
|
||||||
|
.root_dir
|
||||||
|
.canonicalize()
|
||||||
|
.map_err(|e| Error::FileAccess(format!("cannot access root directory: {e}")))?;
|
||||||
|
|
||||||
|
// For non-existent files (WRQ), we check the parent
|
||||||
|
let check_path = if path.exists() {
|
||||||
|
path.canonicalize()?
|
||||||
|
} else {
|
||||||
|
let parent = path
|
||||||
|
.parent()
|
||||||
|
.ok_or_else(|| Error::InvalidFilename("invalid path".to_string()))?;
|
||||||
|
if !parent.exists() {
|
||||||
|
return Err(Error::FileAccess(format!(
|
||||||
|
"parent directory does not exist: {}",
|
||||||
|
parent.display()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
parent.canonicalize()?.join(path.file_name().ok_or_else(|| {
|
||||||
|
Error::InvalidFilename("missing filename".to_string())
|
||||||
|
})?)
|
||||||
|
};
|
||||||
|
|
||||||
|
if !check_path.starts_with(&canonical_root) {
|
||||||
|
return Err(Error::InvalidFilename(
|
||||||
|
"path outside root directory".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a read request (RRQ).
|
||||||
|
fn handle_rrq(&self, socket: &UdpSocket, filename: &str, mode: Mode) -> Result<()> {
|
||||||
|
let path = match self.resolve_path(filename) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
let error = Packet::error(ErrorCode::AccessViolation, e.to_string());
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Open the file
|
||||||
|
let mut file = match File::open(&path) {
|
||||||
|
Ok(f) => f,
|
||||||
|
Err(e) => {
|
||||||
|
let (code, msg) = match e.kind() {
|
||||||
|
std::io::ErrorKind::NotFound => {
|
||||||
|
(ErrorCode::FileNotFound, "File not found".to_string())
|
||||||
|
}
|
||||||
|
std::io::ErrorKind::PermissionDenied => {
|
||||||
|
(ErrorCode::AccessViolation, "Permission denied".to_string())
|
||||||
|
}
|
||||||
|
_ => (ErrorCode::NotDefined, e.to_string()),
|
||||||
|
};
|
||||||
|
let error = Packet::error(code, msg);
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create server state for RRQ (server sends data)
|
||||||
|
let request = Packet::rrq(filename, mode);
|
||||||
|
let mut state = match ServerState::from_request(&request) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
let error = Packet::error(ErrorCode::IllegalOperation, e.to_string());
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start transfer - get NeedData event
|
||||||
|
let event = state.start();
|
||||||
|
let mut read_buf = [0u8; MAX_DATA_SIZE];
|
||||||
|
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
let mut last_sent: Option<Vec<u8>> = None;
|
||||||
|
let mut retries_left = self.config.retries;
|
||||||
|
|
||||||
|
// Handle initial NeedData event
|
||||||
|
if let Event::NeedData { block } = event {
|
||||||
|
let n = file.read(&mut read_buf)?;
|
||||||
|
let send_event = state.provide_data(block, read_buf[..n].to_vec())?;
|
||||||
|
if let Event::Send(packet) = send_event {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
socket.send(&bytes)?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main transfer loop
|
||||||
|
loop {
|
||||||
|
let len = match socket.recv(&mut recv_buf) {
|
||||||
|
Ok(len) => len,
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||||
|
if retries_left == 0 {
|
||||||
|
return Err(Error::Timeout {
|
||||||
|
retries: self.config.retries,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
retries_left -= 1;
|
||||||
|
if let Some(ref last) = last_sent {
|
||||||
|
socket.send(last)?;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let packet = Packet::parse(&recv_buf[..len])?;
|
||||||
|
|
||||||
|
if let Packet::Error { code, message } = packet {
|
||||||
|
return Err(Error::Remote { code, message });
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = state.receive(&packet)?;
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
Event::NeedData { block } => {
|
||||||
|
let n = file.read(&mut read_buf)?;
|
||||||
|
let send_event = state.provide_data(block, read_buf[..n].to_vec())?;
|
||||||
|
if let Event::Send(packet) = send_event {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
socket.send(&bytes)?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::Complete => {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
// Send and ReceivedData not expected in RRQ handling
|
||||||
|
Event::Send(_) | Event::ReceivedData { .. } => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.is_complete() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
retries_left = self.config.retries;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle a write request (WRQ).
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
fn handle_wrq(&self, socket: &UdpSocket, filename: &str, mode: Mode) -> Result<()> {
|
||||||
|
// Check if writes are allowed
|
||||||
|
if !self.config.allow_write {
|
||||||
|
let error = Packet::error(ErrorCode::AccessViolation, "Write not allowed");
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = match self.resolve_path(filename) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
let error = Packet::error(ErrorCode::AccessViolation, e.to_string());
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if file exists and overwrite is not allowed
|
||||||
|
if path.exists() && !self.config.allow_overwrite {
|
||||||
|
let error = Packet::error(ErrorCode::FileAlreadyExists, "File already exists");
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open/create the file
|
||||||
|
let mut file = match OpenOptions::new()
|
||||||
|
.write(true)
|
||||||
|
.create(true)
|
||||||
|
.truncate(true)
|
||||||
|
.open(&path)
|
||||||
|
{
|
||||||
|
Ok(f) => f,
|
||||||
|
Err(e) => {
|
||||||
|
let (code, msg) = match e.kind() {
|
||||||
|
std::io::ErrorKind::PermissionDenied => {
|
||||||
|
(ErrorCode::AccessViolation, "Permission denied".to_string())
|
||||||
|
}
|
||||||
|
_ => (ErrorCode::NotDefined, e.to_string()),
|
||||||
|
};
|
||||||
|
let error = Packet::error(code, msg);
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create server state for WRQ (server receives data)
|
||||||
|
let request = Packet::wrq(filename, mode);
|
||||||
|
let mut state = match ServerState::from_request(&request) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
let error = Packet::error(ErrorCode::IllegalOperation, e.to_string());
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start transfer - send ACK 0
|
||||||
|
let event = state.start();
|
||||||
|
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||||
|
let mut last_sent: Option<Vec<u8>> = None;
|
||||||
|
let mut retries_left = self.config.retries;
|
||||||
|
|
||||||
|
if let Event::Send(packet) = event {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
socket.send(&bytes)?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main transfer loop
|
||||||
|
loop {
|
||||||
|
let len = match socket.recv(&mut recv_buf) {
|
||||||
|
Ok(len) => len,
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
|
||||||
|
if retries_left == 0 {
|
||||||
|
// Clean up partial file on timeout
|
||||||
|
let _ = std::fs::remove_file(&path);
|
||||||
|
return Err(Error::Timeout {
|
||||||
|
retries: self.config.retries,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
retries_left -= 1;
|
||||||
|
if let Some(ref last) = last_sent {
|
||||||
|
socket.send(last)?;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = std::fs::remove_file(&path);
|
||||||
|
return Err(e.into());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let packet = Packet::parse(&recv_buf[..len])?;
|
||||||
|
|
||||||
|
if let Packet::Error { code, message } = packet {
|
||||||
|
let _ = std::fs::remove_file(&path);
|
||||||
|
return Err(Error::Remote { code, message });
|
||||||
|
}
|
||||||
|
|
||||||
|
let events = state.receive(&packet)?;
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
Event::ReceivedData { data, .. } => {
|
||||||
|
if let Err(e) = file.write_all(&data) {
|
||||||
|
let error = match e.kind() {
|
||||||
|
std::io::ErrorKind::StorageFull => {
|
||||||
|
Packet::error(ErrorCode::DiskFull, "Disk full")
|
||||||
|
}
|
||||||
|
_ => Packet::error(ErrorCode::NotDefined, e.to_string()),
|
||||||
|
};
|
||||||
|
let _ = socket.send(&error.serialize()?);
|
||||||
|
let _ = std::fs::remove_file(&path);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Event::Send(packet) => {
|
||||||
|
let bytes = packet.serialize()?;
|
||||||
|
socket.send(&bytes)?;
|
||||||
|
last_sent = Some(bytes);
|
||||||
|
}
|
||||||
|
Event::Complete => {
|
||||||
|
file.flush()?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
// NeedData not expected in WRQ handling
|
||||||
|
Event::NeedData { .. } => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.is_complete() {
|
||||||
|
file.flush()?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
retries_left = self.config.retries;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builder for creating a server with custom options.
|
||||||
|
pub struct ServerBuilder {
|
||||||
|
config: ServerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerBuilder {
|
||||||
|
/// Create a new server builder with the specified root directory.
|
||||||
|
#[must_use]
|
||||||
|
pub fn new(root_dir: impl Into<PathBuf>) -> Self {
|
||||||
|
Self {
|
||||||
|
config: ServerConfig {
|
||||||
|
root_dir: root_dir.into(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allow write operations.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn allow_write(mut self, allow: bool) -> Self {
|
||||||
|
self.config.allow_write = allow;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allow overwriting existing files.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn allow_overwrite(mut self, allow: bool) -> Self {
|
||||||
|
self.config.allow_overwrite = allow;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the timeout for operations.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn timeout(mut self, timeout: Duration) -> Self {
|
||||||
|
self.config.timeout = timeout;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the number of retries.
|
||||||
|
#[must_use]
|
||||||
|
pub const fn retries(mut self, retries: u32) -> Self {
|
||||||
|
self.config.retries = retries;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the server bound to the specified address.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the socket cannot be bound.
|
||||||
|
pub fn bind<A: ToSocketAddrs>(self, addr: A) -> Result<Server> {
|
||||||
|
Server::with_config(addr, self.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the server bound to the default TFTP port.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns an error if the socket cannot be bound.
|
||||||
|
pub fn bind_default(self) -> Result<Server> {
|
||||||
|
self.bind(("0.0.0.0", TFTP_PORT))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
fn main() {
|
|
||||||
println!("Hello, world!");
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user