From 9ebfdb8cb28432141d6ada2dcf8d93208c270075 Mon Sep 17 00:00:00 2001 From: ddidderr Date: Sun, 21 Dec 2025 11:36:07 +0100 Subject: [PATCH] feat: add sync TFTP client/server library --- crates/pfs-tftp-proto/src/packet.rs | 16 +- crates/pfs-tftp-sync/src/client.rs | 448 +++++++++++++++++++++++++- crates/pfs-tftp-sync/src/error.rs | 76 +++++ crates/pfs-tftp-sync/src/lib.rs | 5 +- crates/pfs-tftp-sync/src/server.rs | 471 +++++++++++++++++++++++++++- crates/pfs-tftp-sync/src/util.rs | 202 +++++++++++- 6 files changed, 1206 insertions(+), 12 deletions(-) create mode 100644 crates/pfs-tftp-sync/src/error.rs diff --git a/crates/pfs-tftp-proto/src/packet.rs b/crates/pfs-tftp-proto/src/packet.rs index 58f20fe..9615305 100644 --- a/crates/pfs-tftp-proto/src/packet.rs +++ b/crates/pfs-tftp-proto/src/packet.rs @@ -294,8 +294,15 @@ impl fmt::Display for DecodeError { Self::UnknownMode(mode) => write!(f, "unknown transfer mode {mode:?}"), Self::UnknownErrorCode(code) => write!(f, "unknown error code {code}"), Self::OversizeData(n) => write!(f, "DATA payload too large ({n} bytes)"), - Self::InvalidLength { kind, expected, got } => { - write!(f, "{kind} packet has invalid length (expected {expected}, got {got})") + Self::InvalidLength { + kind, + expected, + got, + } => { + write!( + f, + "{kind} packet has invalid length (expected {expected}, got {got})" + ) } Self::TrailingBytes => write!(f, "trailing bytes after packet"), } @@ -339,7 +346,10 @@ mod tests { Packet::Ack { block: 1 }.encode_into(&mut bytes); bytes.push(0); let err = Packet::decode(&bytes).unwrap_err(); - assert!(matches!(err, DecodeError::InvalidLength { kind: "ACK", .. })); + assert!(matches!( + err, + DecodeError::InvalidLength { kind: "ACK", .. } + )); } #[allow(clippy::unwrap_used)] diff --git a/crates/pfs-tftp-sync/src/client.rs b/crates/pfs-tftp-sync/src/client.rs index 0c84d59..d9df3f4 100644 --- a/crates/pfs-tftp-sync/src/client.rs +++ b/crates/pfs-tftp-sync/src/client.rs @@ -1,20 +1,38 @@ #![forbid(unsafe_code)] -use std::net::SocketAddr; -use std::time::Duration; +use std::{ + fs::File, + io::{Read, Write}, + net::SocketAddr, + path::Path, + time::Duration, +}; + +use pfs_tftp_proto::packet::{BLOCK_SIZE, ErrorCode, MAX_PACKET_SIZE, Packet, Request}; + +use crate::{ + Mode, + error::{Error, Result}, + util, +}; /// Configuration for a synchronous TFTP client. #[derive(Debug, Clone)] pub struct ClientConfig { pub timeout: Duration, pub retries: u32, + pub dally_timeout: Duration, + pub dally_retries: u32, } impl Default for ClientConfig { fn default() -> Self { + let timeout = Duration::from_secs(5); Self { - timeout: Duration::from_secs(5), + timeout, retries: 5, + dally_timeout: timeout, + dally_retries: 2, } } } @@ -41,4 +59,428 @@ impl Client { pub fn config(&self) -> &ClientConfig { &self.config } + + /// Downloads `remote_filename` from the configured server into `output`. + /// + /// RFC reference: + /// - Packet formats: RFC 1350, Section 5 (Figures 5-1..5-4) + /// - Lock-step ACK/DATA: RFC 1350, Section 2 + /// - Termination (final <512 DATA): RFC 1350, Section 6 + /// - Transfer identifiers (TID): RFC 1350, Section 4 (page 4) + /// + /// # Errors + /// Returns an error on I/O failures, protocol decode errors, timeouts after + /// exhausting retries, or if the peer sends an ERROR packet. + pub fn download_to_writer( + &self, + remote_filename: &str, + mode: Mode, + output: &mut impl Write, + ) -> Result<()> { + let socket = util::bind_ephemeral_for(self.server)?; + socket.set_read_timeout(Some(self.config.timeout))?; + Self::validate_mode(mode)?; + + let mut sink = util::DataSink::new(output, mode); + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + + let (peer, first) = self.rrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?; + let Packet::Data { block, data } = first else { + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected packet type", + ))); + }; + if block != 1 { + util::send_error( + &socket, + peer, + ErrorCode::IllegalOperation, + "unexpected block number", + ); + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected block number", + ))); + } + + sink.write_data(&data).map_err(map_sink_error)?; + let ack_bytes = Packet::Ack { block }.encode(); + socket.send_to(&ack_bytes, peer)?; + let mut last_ack = Some(ack_bytes.clone()); + + let mut last_sent = ack_bytes; + let mut expected_block = 2u16; + + if data.len() < BLOCK_SIZE { + sink.finish().map_err(map_sink_error)?; + self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?; + return Ok(()); + } + + let mut attempts = 0u32; + loop { + let n = self.recv_from_peer_with_retries( + &socket, + &mut recv_buf, + &mut attempts, + "ACK", + &last_sent, + peer, + )?; + let pkt = Packet::decode(&recv_buf[..n])?; + match pkt { + Packet::Error { code, message } => { + return Err(Error::RemoteError { code, message }); + } + Packet::Data { block, data } if block == expected_block => { + sink.write_data(&data).map_err(map_sink_error)?; + let ack = Packet::Ack { block }.encode(); + socket.send_to(&ack, peer)?; + last_ack = Some(ack.clone()); + last_sent = ack; + if data.len() < BLOCK_SIZE { + sink.finish().map_err(map_sink_error)?; + self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?; + return Ok(()); + } + expected_block = expected_block.wrapping_add(1); + } + Packet::Data { block, .. } if block == expected_block.wrapping_sub(1) => { + if let Some(ack) = &last_ack { + socket.send_to(ack, peer)?; + } + } + _ => { + util::send_error( + &socket, + peer, + ErrorCode::IllegalOperation, + "unexpected packet type", + ); + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected packet type", + ))); + } + } + } + } + + /// Downloads `remote_filename` from the configured server into `local_path`. + /// + /// # Errors + /// Returns any error from [`Self::download_to_writer`] or filesystem I/O. + pub fn get(&self, remote_filename: &str, local_path: &Path, mode: Mode) -> Result<()> { + let mut file = File::create(local_path)?; + self.download_to_writer(remote_filename, mode, &mut file) + } + + /// Uploads `input` to the configured server as `remote_filename`. + /// + /// RFC reference: + /// - WRQ/ACK(0) handshake: RFC 1350, Section 4 (page 4) and Section 5 (ACK block=0) + /// - Lock-step ACK/DATA: RFC 1350, Section 2 + /// - Termination (final <512 DATA): RFC 1350, Section 6 + /// + /// # Errors + /// Returns an error on I/O failures, protocol decode errors, timeouts after + /// exhausting retries, or if the peer sends an ERROR packet. + pub fn upload_from_reader( + &self, + remote_filename: &str, + mode: Mode, + input: &mut impl Read, + ) -> Result<()> { + let socket = util::bind_ephemeral_for(self.server)?; + socket.set_read_timeout(Some(self.config.timeout))?; + Self::validate_mode(mode)?; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let peer = self.wrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?; + + let mut source = util::DataSource::new(input, mode); + let mut next_block: u16 = 1; + let mut attempts = 0u32; + while let Some(data) = source.next_block()? { + let data_bytes = Packet::Data { + block: next_block, + data, + } + .encode(); + socket.send_to(&data_bytes, peer)?; + + let expected_ack = next_block; + let last_sent = data_bytes; + let last_sent_kind = "DATA"; + + loop { + let n = self.recv_from_peer_with_retries( + &socket, + &mut recv_buf, + &mut attempts, + last_sent_kind, + &last_sent, + peer, + )?; + match Packet::decode(&recv_buf[..n])? { + Packet::Ack { block } if block == expected_ack => break, + Packet::Ack { block } if block == expected_ack.wrapping_sub(1) => { + socket.send_to(&last_sent, peer)?; + } + Packet::Error { code, message } => { + return Err(Error::RemoteError { code, message }); + } + _ => { + util::send_error( + &socket, + peer, + ErrorCode::IllegalOperation, + "unexpected packet type", + ); + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected packet type", + ))); + } + } + } + + attempts = 0; + next_block = next_block.wrapping_add(1); + } + + Ok(()) + } + + /// Uploads `local_path` to the configured server as `remote_filename`. + /// + /// # Errors + /// Returns any error from [`Self::upload_from_reader`] or filesystem I/O. + pub fn put(&self, local_path: &Path, remote_filename: &str, mode: Mode) -> Result<()> { + let mut file = File::open(local_path)?; + self.upload_from_reader(remote_filename, mode, &mut file) + } + + fn dally_final_ack( + &self, + socket: &std::net::UdpSocket, + peer: SocketAddr, + final_block: u16, + ack_bytes: Option<&[u8]>, + ) -> Result<()> { + if self.config.dally_retries == 0 { + return Ok(()); + } + let Some(ack_bytes) = ack_bytes else { + return Ok(()); + }; + + socket.set_read_timeout(Some(self.config.dally_timeout))?; + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut timeouts = 0u32; + + // RFC 1350, Section 6: "dallying is encouraged" to retransmit the final ACK + // if the last DATA is retransmitted. + while timeouts < self.config.dally_retries { + match socket.recv_from(&mut recv_buf) { + Ok((n, from)) => { + if from != peer { + util::send_error( + socket, + from, + ErrorCode::UnknownTransferId, + "unknown transfer id", + ); + continue; + } + if let Ok(Packet::Data { block, .. }) = Packet::decode(&recv_buf[..n]) + && block == final_block + { + let _ignored = socket.send_to(ack_bytes, peer); + } + } + Err(e) if util::is_timeout(&e) => timeouts += 1, + Err(e) => return Err(e.into()), + } + } + + Ok(()) + } + + fn validate_mode(mode: Mode) -> Result<()> { + if mode == Mode::Mail { + return Err(Error::UnsupportedMode(mode)); + } + Ok(()) + } + + fn rrq_handshake( + &self, + socket: &std::net::UdpSocket, + recv_buf: &mut [u8], + remote_filename: &str, + mode: Mode, + ) -> Result<(SocketAddr, Packet)> { + let rrq_bytes = Packet::Rrq(Request { + filename: remote_filename.to_string(), + mode, + }) + .encode(); + socket.send_to(&rrq_bytes, self.server)?; + + let mut attempts = 0u32; + loop { + let (n, from) = self.recv_from_with_retries( + socket, + recv_buf, + &mut attempts, + "RRQ", + &rrq_bytes, + self.server, + )?; + if from.ip() != self.server.ip() { + continue; + } + + let pkt = Packet::decode(&recv_buf[..n])?; + match pkt { + Packet::Data { .. } => return Ok((from, pkt)), + Packet::Error { code, message } => { + return Err(Error::RemoteError { code, message }); + } + _ => { + util::send_error( + socket, + from, + ErrorCode::IllegalOperation, + "unexpected packet", + ); + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected packet", + ))); + } + } + } + } + + fn wrq_handshake( + &self, + socket: &std::net::UdpSocket, + recv_buf: &mut [u8], + remote_filename: &str, + mode: Mode, + ) -> Result { + let wrq_bytes = Packet::Wrq(Request { + filename: remote_filename.to_string(), + mode, + }) + .encode(); + socket.send_to(&wrq_bytes, self.server)?; + + let mut attempts = 0u32; + loop { + let (n, from) = self.recv_from_with_retries( + socket, + recv_buf, + &mut attempts, + "WRQ", + &wrq_bytes, + self.server, + )?; + if from.ip() != self.server.ip() { + continue; + } + match Packet::decode(&recv_buf[..n])? { + Packet::Ack { block: 0 } => return Ok(from), + Packet::Error { code, message } => { + return Err(Error::RemoteError { code, message }); + } + _ => { + util::send_error( + socket, + from, + ErrorCode::IllegalOperation, + "unexpected packet", + ); + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected packet", + ))); + } + } + } + } + + fn recv_from_with_retries( + &self, + socket: &std::net::UdpSocket, + recv_buf: &mut [u8], + attempts: &mut u32, + last_sent_kind: &'static str, + last_sent: &[u8], + resend_to: SocketAddr, + ) -> Result<(usize, SocketAddr)> { + loop { + match socket.recv_from(recv_buf) { + Ok(v) => { + *attempts = 0; + return Ok(v); + } + Err(e) if util::is_timeout(&e) => { + *attempts += 1; + if *attempts > self.config.retries { + return Err(Error::Timeout { + last_packet: last_sent_kind, + attempts: *attempts, + }); + } + socket.send_to(last_sent, resend_to)?; + } + Err(e) => return Err(e.into()), + } + } + } + + fn recv_from_peer_with_retries( + &self, + socket: &std::net::UdpSocket, + recv_buf: &mut [u8], + attempts: &mut u32, + last_sent_kind: &'static str, + last_sent: &[u8], + peer: SocketAddr, + ) -> Result { + loop { + let (n, from) = self.recv_from_with_retries( + socket, + recv_buf, + attempts, + last_sent_kind, + last_sent, + peer, + )?; + if from == peer { + return Ok(n); + } + + // RFC 1350, Section 4 (page 4): incorrect source port => Unknown transfer ID. + if from.ip() == peer.ip() { + util::send_error( + socket, + from, + ErrorCode::UnknownTransferId, + "unknown transfer id", + ); + } + } + } +} + +fn map_sink_error(err: util::DataSinkError) -> Error { + match err { + util::DataSinkError::Io(e) => Error::Io(e), + util::DataSinkError::Netascii(e) => Error::Netascii(e), + } } diff --git a/crates/pfs-tftp-sync/src/error.rs b/crates/pfs-tftp-sync/src/error.rs new file mode 100644 index 0000000..8cb4a2f --- /dev/null +++ b/crates/pfs-tftp-sync/src/error.rs @@ -0,0 +1,76 @@ +#![forbid(unsafe_code)] + +use core::fmt; + +use pfs_tftp_proto::{netascii, packet}; + +/// Errors returned by the synchronous client APIs. +#[derive(Debug)] +pub enum Error { + Io(std::io::Error), + Protocol(packet::DecodeError), + Netascii(netascii::DecodeError), + UnsupportedMode(packet::Mode), + RemoteError { + code: packet::ErrorCode, + message: String, + }, + Timeout { + last_packet: &'static str, + attempts: u32, + }, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(e) => write!(f, "{e}"), + Self::Protocol(e) => write!(f, "{e}"), + Self::Netascii(e) => write!(f, "{e}"), + Self::UnsupportedMode(mode) => { + write!(f, "unsupported transfer mode: {}", mode.as_str()) + } + Self::RemoteError { code, message } => { + write!(f, "remote error {code:?}: {message}") + } + Self::Timeout { + last_packet, + attempts, + } => write!( + f, + "timed out waiting for response after {attempts} attempts (last sent: {last_packet})" + ), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Io(e) => Some(e), + Self::Protocol(e) => Some(e), + Self::Netascii(e) => Some(e), + Self::UnsupportedMode(_) | Self::RemoteError { .. } | Self::Timeout { .. } => None, + } + } +} + +impl From for Error { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} + +impl From for Error { + fn from(value: packet::DecodeError) -> Self { + Self::Protocol(value) + } +} + +impl From for Error { + fn from(value: netascii::DecodeError) -> Self { + Self::Netascii(value) + } +} + +pub type Result = core::result::Result; diff --git a/crates/pfs-tftp-sync/src/lib.rs b/crates/pfs-tftp-sync/src/lib.rs index 1fc881b..e8df65b 100644 --- a/crates/pfs-tftp-sync/src/lib.rs +++ b/crates/pfs-tftp-sync/src/lib.rs @@ -6,8 +6,11 @@ #![forbid(unsafe_code)] pub mod client; +pub mod error; pub mod server; -pub mod util; +mod util; pub use client::{Client, ClientConfig}; +pub use error::{Error, Result}; +pub use pfs_tftp_proto::packet::Mode; pub use server::{Server, ServerConfig}; diff --git a/crates/pfs-tftp-sync/src/server.rs b/crates/pfs-tftp-sync/src/server.rs index b2ab8e6..d94f8dc 100644 --- a/crates/pfs-tftp-sync/src/server.rs +++ b/crates/pfs-tftp-sync/src/server.rs @@ -1,8 +1,16 @@ #![forbid(unsafe_code)] -use std::net::SocketAddr; -use std::path::PathBuf; -use std::time::Duration; +use std::{ + fs::{File, OpenOptions}, + net::SocketAddr, + path::{Path, PathBuf}, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; + +use pfs_tftp_proto::packet::{BLOCK_SIZE, ErrorCode, MAX_PACKET_SIZE, Packet}; + +use crate::{Mode, util}; /// Configuration for a synchronous TFTP server. #[derive(Debug, Clone)] @@ -13,17 +21,24 @@ pub struct ServerConfig { pub overwrite: bool, pub timeout: Duration, pub retries: u32, + pub dally_timeout: Duration, + pub dally_retries: u32, + pub max_request_size: usize, } impl Default for ServerConfig { fn default() -> Self { + let timeout = Duration::from_secs(5); Self { bind: SocketAddr::from(([0, 0, 0, 0], 6969)), root: PathBuf::from("."), allow_write: false, overwrite: false, - timeout: Duration::from_secs(5), + timeout, retries: 5, + dally_timeout: timeout, + dally_retries: 2, + max_request_size: util::MAX_REQUEST_SIZE, } } } @@ -44,4 +59,452 @@ impl Server { pub fn config(&self) -> &ServerConfig { &self.config } + + /// Serves requests forever. + /// + /// # Errors + /// Returns an error if the server socket cannot be bound or configured. + pub fn serve(&self) -> std::io::Result<()> { + self.serve_inner(None) + } + + /// Serves requests until `shutdown` becomes `true`. + /// + /// This is mainly intended for tests. + /// + /// # Errors + /// Returns an error if the server socket cannot be bound or configured. + pub fn serve_until(&self, shutdown: &AtomicBool) -> std::io::Result<()> { + self.serve_inner(Some(shutdown)) + } + + fn serve_inner(&self, shutdown: Option<&AtomicBool>) -> std::io::Result<()> { + let socket = std::net::UdpSocket::bind(self.config.bind)?; + if shutdown.is_some() { + socket.set_read_timeout(Some(Duration::from_millis(200)))?; + } + + let mut buf = vec![0u8; self.config.max_request_size]; + + loop { + if let Some(flag) = shutdown + && flag.load(Ordering::Relaxed) + { + return Ok(()); + } + + let (n, peer) = match socket.recv_from(&mut buf) { + Ok(v) => v, + Err(e) if shutdown.is_some() && util::is_timeout(&e) => continue, + Err(e) => return Err(e), + }; + + let Ok(pkt) = Packet::decode(&buf[..n]) else { + util::send_error( + &socket, + peer, + ErrorCode::IllegalOperation, + "malformed packet", + ); + continue; + }; + + match pkt { + Packet::Rrq(req) => { + let cfg = self.config.clone(); + std::thread::spawn(move || { + let _ignored = handle_rrq(cfg, peer, req); + }); + } + Packet::Wrq(req) => { + let cfg = self.config.clone(); + std::thread::spawn(move || { + let _ignored = handle_wrq(cfg, peer, req); + }); + } + Packet::Data { .. } | Packet::Ack { .. } | Packet::Error { .. } => { + // RFC 1350, Section 4 (page 4): incorrect source port => Unknown transfer ID. + util::send_error( + &socket, + peer, + ErrorCode::UnknownTransferId, + "unknown transfer id", + ); + } + } + } + } +} + +fn handle_rrq( + cfg: ServerConfig, + client: SocketAddr, + req: pfs_tftp_proto::packet::Request, +) -> std::io::Result<()> { + let pfs_tftp_proto::packet::Request { filename, mode } = req; + let ServerConfig { + bind, + root, + timeout, + retries, + .. + } = cfg; + + if mode == Mode::Mail { + let socket = util::bind_ephemeral_on_ip(bind.ip())?; + util::send_error( + &socket, + client, + ErrorCode::NotDefined, + "mail mode is obsolete", + ); + return Ok(()); + } + + let socket = util::bind_ephemeral_on_ip(bind.ip())?; + socket.set_read_timeout(Some(timeout))?; + + let Some(path) = resolve_path(&root, &filename) else { + util::send_error(&socket, client, ErrorCode::AccessViolation, "invalid path"); + return Ok(()); + }; + + let mut file = match File::open(&path) { + Ok(f) => f, + Err(e) => { + util::send_error( + &socket, + client, + map_io_error_code(&e), + "failed to open file", + ); + return Ok(()); + } + }; + + let mut source = util::DataSource::new(&mut file, mode); + let mut block: u16 = 1; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut attempts = 0u32; + + while let Some(data) = source.next_block()? { + let pkt = Packet::Data { block, data }; + let bytes = pkt.encode(); + socket.send_to(&bytes, client)?; + + let last_sent = bytes; + let expected_ack = block; + + loop { + let (n, from) = match socket.recv_from(&mut recv_buf) { + Ok(v) => v, + Err(e) if util::is_timeout(&e) => { + attempts += 1; + if attempts > retries { + return Ok(()); + } + socket.send_to(&last_sent, client)?; + continue; + } + Err(e) => return Err(e), + }; + if from != client { + util::send_error( + &socket, + from, + ErrorCode::UnknownTransferId, + "unknown transfer id", + ); + continue; + } + + attempts = 0; + match Packet::decode(&recv_buf[..n]) { + Ok(Packet::Ack { block }) if block == expected_ack => break, + Ok(Packet::Ack { block }) if block == expected_ack.wrapping_sub(1) => { + // Duplicate ACK, retransmit last DATA. + socket.send_to(&last_sent, client)?; + } + Ok(Packet::Error { .. }) => return Ok(()), + Ok(_) | Err(_) => { + util::send_error( + &socket, + client, + ErrorCode::IllegalOperation, + "unexpected packet", + ); + return Ok(()); + } + } + } + + // RFC 1350, Section 6: final DATA has <512 bytes. + if last_sent.len() < 4 + BLOCK_SIZE { + return Ok(()); + } + + block = block.wrapping_add(1); + } + + Ok(()) +} + +#[allow(clippy::too_many_lines)] +fn handle_wrq( + cfg: ServerConfig, + client: SocketAddr, + req: pfs_tftp_proto::packet::Request, +) -> std::io::Result<()> { + let pfs_tftp_proto::packet::Request { filename, mode } = req; + let ServerConfig { + bind, + root, + allow_write, + overwrite, + timeout, + retries, + dally_timeout, + dally_retries, + .. + } = cfg; + + let socket = util::bind_ephemeral_on_ip(bind.ip())?; + socket.set_read_timeout(Some(timeout))?; + + if mode == Mode::Mail { + util::send_error( + &socket, + client, + ErrorCode::NotDefined, + "mail mode is obsolete", + ); + return Ok(()); + } + + if !allow_write { + util::send_error( + &socket, + client, + ErrorCode::AccessViolation, + "writes disabled", + ); + return Ok(()); + } + + let Some(path) = resolve_path(&root, &filename) else { + util::send_error(&socket, client, ErrorCode::AccessViolation, "invalid path"); + return Ok(()); + }; + + let mut file = match open_for_write(&path, overwrite) { + Ok(f) => f, + Err(e) => { + util::send_error( + &socket, + client, + map_io_error_code(&e), + "failed to create file", + ); + return Ok(()); + } + }; + + // RFC 1350, Section 5: WRQ is acknowledged by ACK(block=0). + let ack0 = Packet::Ack { block: 0 }.encode(); + socket.send_to(&ack0, client)?; + + let mut last_ack = ack0; + let mut expected_block: u16 = 1; + let mut attempts = 0u32; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut sink = util::DataSink::new(&mut file, mode); + + loop { + let (n, from) = match socket.recv_from(&mut recv_buf) { + Ok(v) => v, + Err(e) if util::is_timeout(&e) => { + attempts += 1; + if attempts > retries { + return Ok(()); + } + socket.send_to(&last_ack, client)?; + continue; + } + Err(e) => return Err(e), + }; + if from != client { + util::send_error( + &socket, + from, + ErrorCode::UnknownTransferId, + "unknown transfer id", + ); + continue; + } + + attempts = 0; + let Ok(pkt) = Packet::decode(&recv_buf[..n]) else { + util::send_error( + &socket, + client, + ErrorCode::IllegalOperation, + "malformed packet", + ); + return Ok(()); + }; + + match pkt { + Packet::Error { .. } => return Ok(()), + Packet::Data { block, data } => { + if block == expected_block { + if let Err(e) = sink.write_data(&data) { + util::send_error( + &socket, + client, + map_sink_error_code(&e), + "failed to write data", + ); + return Ok(()); + } + + let ack = Packet::Ack { block }.encode(); + socket.send_to(&ack, client)?; + last_ack = ack; + + if data.len() < BLOCK_SIZE { + if let Err(e) = sink.finish() { + util::send_error( + &socket, + client, + map_sink_error_code(&e), + "failed to finish transfer", + ); + return Ok(()); + } + dally_final_ack( + dally_timeout, + dally_retries, + &socket, + client, + block, + &last_ack, + )?; + return Ok(()); + } + + expected_block = expected_block.wrapping_add(1); + } else if block == expected_block.wrapping_sub(1) { + // Duplicate DATA (ACK lost); re-ACK. + socket.send_to(&last_ack, client)?; + } else { + util::send_error( + &socket, + client, + ErrorCode::IllegalOperation, + "unexpected block number", + ); + return Ok(()); + } + } + Packet::Ack { .. } | Packet::Rrq(_) | Packet::Wrq(_) => { + util::send_error( + &socket, + client, + ErrorCode::IllegalOperation, + "unexpected packet type", + ); + return Ok(()); + } + } + } +} + +fn resolve_path(root: &Path, filename: &str) -> Option { + let rel = Path::new(filename); + if rel.is_absolute() { + return None; + } + for c in rel.components() { + match c { + std::path::Component::Normal(_) => {} + _ => return None, + } + } + Some(root.join(rel)) +} + +#[must_use] +fn map_io_error_code(err: &std::io::Error) -> ErrorCode { + match err.kind() { + std::io::ErrorKind::NotFound => ErrorCode::FileNotFound, + std::io::ErrorKind::PermissionDenied => ErrorCode::AccessViolation, + std::io::ErrorKind::AlreadyExists => ErrorCode::FileAlreadyExists, + std::io::ErrorKind::StorageFull => ErrorCode::DiskFull, + _ => ErrorCode::NotDefined, + } +} + +fn open_for_write(path: &Path, overwrite: bool) -> std::io::Result { + let mut opts = OpenOptions::new(); + opts.write(true); + if overwrite { + opts.create(true).truncate(true); + } else { + opts.create_new(true); + } + opts.open(path) +} + +#[must_use] +fn map_sink_error_code(err: &util::DataSinkError) -> ErrorCode { + match err { + util::DataSinkError::Io(e) => map_io_error_code(e), + // Invalid netascii is a protocol violation. + util::DataSinkError::Netascii(_) => ErrorCode::IllegalOperation, + } +} + +fn dally_final_ack( + dally_timeout: Duration, + dally_retries: u32, + socket: &std::net::UdpSocket, + peer: SocketAddr, + final_block: u16, + ack_bytes: &[u8], +) -> std::io::Result<()> { + if dally_retries == 0 { + return Ok(()); + } + + socket.set_read_timeout(Some(dally_timeout))?; + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let mut timeouts = 0u32; + + // RFC 1350, Section 6: dallying for final ACK retransmission. + while timeouts < dally_retries { + match socket.recv_from(&mut recv_buf) { + Ok((n, from)) => { + if from != peer { + util::send_error( + socket, + from, + ErrorCode::UnknownTransferId, + "unknown transfer id", + ); + continue; + } + if let Ok(Packet::Data { block, .. }) = Packet::decode(&recv_buf[..n]) + && block == final_block + { + let _ignored = socket.send_to(ack_bytes, peer); + } + } + Err(e) if util::is_timeout(&e) => timeouts += 1, + Err(e) => return Err(e), + } + } + + Ok(()) } diff --git a/crates/pfs-tftp-sync/src/util.rs b/crates/pfs-tftp-sync/src/util.rs index 13152f4..f2782c2 100644 --- a/crates/pfs-tftp-sync/src/util.rs +++ b/crates/pfs-tftp-sync/src/util.rs @@ -1,4 +1,204 @@ #![forbid(unsafe_code)] -// Misc sync helpers live here. +use std::{ + collections::VecDeque, + io::{Read, Write}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, +}; +use pfs_tftp_proto::{ + netascii::{NetasciiDecoder, NetasciiEncoder}, + packet::{BLOCK_SIZE, ErrorCode, Mode, Packet}, +}; + +pub(crate) const MAX_REQUEST_SIZE: usize = 2048; + +#[must_use] +pub(crate) fn wildcard_addr_for(peer: SocketAddr) -> SocketAddr { + match peer.ip() { + IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), + } +} + +pub(crate) fn bind_ephemeral_for(peer: SocketAddr) -> std::io::Result { + UdpSocket::bind(wildcard_addr_for(peer)) +} + +pub(crate) fn bind_ephemeral_on_ip(ip: IpAddr) -> std::io::Result { + UdpSocket::bind(SocketAddr::new(ip, 0)) +} + +pub(crate) fn is_timeout(err: &std::io::Error) -> bool { + matches!( + err.kind(), + std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock + ) +} + +pub(crate) fn send_error(socket: &UdpSocket, peer: SocketAddr, code: ErrorCode, message: &str) { + // RFC 1350, Section 7: ERROR packets are not acknowledged nor retransmitted. + // Best effort. + let pkt = Packet::Error { + code, + message: message.to_string(), + }; + let _ignored = socket.send_to(&pkt.encode(), peer); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SourcePhase { + Start, + AfterFullBlock, + Done, +} + +/// Produces 512-byte DATA payload blocks according to RFC 1350, Section 2 and Section 6. +/// +/// RFC 1350, Section 6: The end of a transfer is signaled by a DATA packet with +/// 0..511 bytes. If the payload stream ends exactly on a 512-byte boundary, an +/// additional zero-length DATA packet must be sent. +pub(crate) struct DataSource<'a, R: Read> { + reader: &'a mut R, + mode: Mode, + encoder: NetasciiEncoder, + pending: VecDeque, + source_eof: bool, + phase: SourcePhase, +} + +impl<'a, R: Read> DataSource<'a, R> { + #[must_use] + pub(crate) fn new(reader: &'a mut R, mode: Mode) -> Self { + Self { + reader, + mode, + encoder: NetasciiEncoder::new(), + pending: VecDeque::new(), + source_eof: false, + phase: SourcePhase::Start, + } + } + + pub(crate) fn next_block(&mut self) -> std::io::Result>> { + if self.phase == SourcePhase::Done { + return Ok(None); + } + + if self.mode == Mode::Octet { + return self.next_block_octet(); + } + + self.next_block_netascii() + } + + fn next_block_octet(&mut self) -> std::io::Result>> { + let mut buf = [0u8; BLOCK_SIZE]; + let n = self.reader.read(&mut buf)?; + if n == 0 { + match self.phase { + SourcePhase::Start | SourcePhase::AfterFullBlock => { + self.phase = SourcePhase::Done; + return Ok(Some(Vec::new())); + } + SourcePhase::Done => return Ok(None), + } + } + + if n < BLOCK_SIZE { + self.phase = SourcePhase::Done; + Ok(Some(buf[..n].to_vec())) + } else { + self.phase = SourcePhase::AfterFullBlock; + Ok(Some(buf.to_vec())) + } + } + + fn next_block_netascii(&mut self) -> std::io::Result>> { + while !self.source_eof && self.pending.len() < BLOCK_SIZE { + let mut in_buf = [0u8; 4096]; + let n = self.reader.read(&mut in_buf)?; + if n == 0 { + self.source_eof = true; + break; + } + let mut encoded = Vec::with_capacity(n * 2); + self.encoder.encode_chunk(&in_buf[..n], &mut encoded); + self.pending.extend(encoded); + } + + if self.pending.is_empty() { + match self.phase { + SourcePhase::Start | SourcePhase::AfterFullBlock => { + self.phase = SourcePhase::Done; + return Ok(Some(Vec::new())); + } + SourcePhase::Done => return Ok(None), + } + } + + if self.pending.len() < BLOCK_SIZE { + self.phase = SourcePhase::Done; + let block: Vec = self.pending.drain(..).collect(); + return Ok(Some(block)); + } + + self.phase = SourcePhase::AfterFullBlock; + let block: Vec = self.pending.drain(..BLOCK_SIZE).collect(); + Ok(Some(block)) + } +} + +/// Writes received DATA payload blocks to an output, applying netascii translation. +pub(crate) enum DataSink<'a, W: Write> { + Octet(&'a mut W), + Netascii { + decoder: NetasciiDecoder, + output: &'a mut W, + scratch: Vec, + }, +} + +impl<'a, W: Write> DataSink<'a, W> { + #[must_use] + pub(crate) fn new(output: &'a mut W, mode: Mode) -> Self { + match mode { + Mode::Octet => Self::Octet(output), + Mode::NetAscii | Mode::Mail => Self::Netascii { + decoder: NetasciiDecoder::new(), + output, + scratch: Vec::new(), + }, + } + } + + pub(crate) fn write_data(&mut self, data: &[u8]) -> Result<(), DataSinkError> { + match self { + Self::Octet(out) => out.write_all(data).map_err(DataSinkError::Io), + Self::Netascii { + decoder, + output, + scratch, + } => { + scratch.clear(); + decoder + .decode_chunk(data, scratch) + .map_err(DataSinkError::Netascii)?; + output.write_all(scratch).map_err(DataSinkError::Io) + } + } + } + + pub(crate) fn finish(self) -> Result<(), DataSinkError> { + match self { + Self::Octet(_) => Ok(()), + Self::Netascii { decoder, .. } => decoder.finish().map_err(DataSinkError::Netascii), + } + } +} + +#[derive(Debug)] +pub(crate) enum DataSinkError { + Io(std::io::Error), + Netascii(pfs_tftp_proto::netascii::DecodeError), +}