#![forbid(unsafe_code)] use std::{ fs::File, io::{Read, Write}, net::SocketAddr, path::Path, time::Duration, }; use pfs_tftp_proto::packet::{BLOCK_SIZE, ErrorCode, MAX_PACKET_SIZE, Packet, Request}; use crate::{ Mode, error::{Error, Result}, util, }; /// Configuration for a synchronous TFTP client. #[derive(Debug, Clone)] pub struct ClientConfig { pub timeout: Duration, pub retries: u32, pub dally_timeout: Duration, pub dally_retries: u32, } impl Default for ClientConfig { fn default() -> Self { let timeout = Duration::from_secs(5); Self { timeout, retries: 5, dally_timeout: timeout, // RFC 1350, Section 6: The final ACK sender *may* terminate immediately, but // "dallying is encouraged". For CLI UX, default to no dallying and allow // opt-in via flags. dally_retries: 0, } } } /// A synchronous TFTP client. #[derive(Debug)] pub struct Client { pub(crate) server: SocketAddr, pub(crate) config: ClientConfig, } impl Client { #[must_use] pub fn new(server: SocketAddr, config: ClientConfig) -> Self { Self { server, config } } #[must_use] pub fn server(&self) -> SocketAddr { self.server } #[must_use] pub fn config(&self) -> &ClientConfig { &self.config } /// Downloads `remote_filename` from the configured server into `output`. /// /// RFC reference: /// - Packet formats: RFC 1350, Section 5 (Figures 5-1..5-4) /// - Lock-step ACK/DATA: RFC 1350, Section 2 /// - Termination (final <512 DATA): RFC 1350, Section 6 /// - Transfer identifiers (TID): RFC 1350, Section 4 (page 4) /// /// # Errors /// Returns an error on I/O failures, protocol decode errors, timeouts after /// exhausting retries, or if the peer sends an ERROR packet. pub fn download_to_writer( &self, remote_filename: &str, mode: Mode, output: &mut impl Write, ) -> Result<()> { let socket = util::bind_ephemeral_for(self.server)?; socket.set_read_timeout(Some(self.config.timeout))?; Self::validate_mode(mode)?; let mut recv_buf = [0u8; MAX_PACKET_SIZE]; let (peer, first) = self.rrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?; let Packet::Data { block, data } = first else { return Err(Error::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "unexpected packet type", ))); }; self.download_from_first_data(&socket, peer, block, &data, mode, output) } /// Downloads `remote_filename` from the configured server into `local_path`. /// /// # Errors /// Returns any error from [`Self::download_to_writer`] or filesystem I/O. pub fn get(&self, remote_filename: &str, local_path: &Path, mode: Mode) -> Result<()> { // Don't create the output file before we know the RRQ will succeed. Otherwise, a missing // remote file (RFC 1350 error code 1) leaves behind an empty local file. let socket = util::bind_ephemeral_for(self.server)?; socket.set_read_timeout(Some(self.config.timeout))?; Self::validate_mode(mode)?; let mut recv_buf = [0u8; MAX_PACKET_SIZE]; let (peer, first) = self.rrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?; let Packet::Data { block, data } = first else { return Err(Error::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "unexpected packet type", ))); }; let mut file = File::create(local_path)?; self.download_from_first_data(&socket, peer, block, &data, mode, &mut file) } /// Uploads `input` to the configured server as `remote_filename`. /// /// RFC reference: /// - WRQ/ACK(0) handshake: RFC 1350, Section 4 (page 4) and Section 5 (ACK block=0) /// - Lock-step ACK/DATA: RFC 1350, Section 2 /// - Termination (final <512 DATA): RFC 1350, Section 6 /// /// # Errors /// Returns an error on I/O failures, protocol decode errors, timeouts after /// exhausting retries, or if the peer sends an ERROR packet. pub fn upload_from_reader( &self, remote_filename: &str, mode: Mode, input: &mut impl Read, ) -> Result<()> { let socket = util::bind_ephemeral_for(self.server)?; socket.set_read_timeout(Some(self.config.timeout))?; Self::validate_mode(mode)?; let mut recv_buf = [0u8; MAX_PACKET_SIZE]; let peer = self.wrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?; let mut source = util::DataSource::new(input, mode); let mut next_block: u16 = 1; let mut attempts = 0u32; while let Some(data) = source.next_block()? { let data_bytes = Packet::Data { block: next_block, data, } .encode(); socket.send_to(&data_bytes, peer)?; let expected_ack = next_block; let last_sent = data_bytes; let last_sent_kind = "DATA"; loop { let n = self.recv_from_peer_with_retries( &socket, &mut recv_buf, &mut attempts, last_sent_kind, &last_sent, peer, )?; match Packet::decode(&recv_buf[..n])? { Packet::Ack { block } if block == expected_ack => break, Packet::Ack { block } if block == expected_ack.wrapping_sub(1) => { socket.send_to(&last_sent, peer)?; } Packet::Error { code, message } => { return Err(Error::RemoteError { code, message }); } _ => { util::send_error( &socket, peer, ErrorCode::IllegalOperation, "unexpected packet type", ); return Err(Error::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "unexpected packet type", ))); } } } attempts = 0; next_block = next_block.wrapping_add(1); } Ok(()) } /// Uploads `local_path` to the configured server as `remote_filename`. /// /// # Errors /// Returns any error from [`Self::upload_from_reader`] or filesystem I/O. pub fn put(&self, local_path: &Path, remote_filename: &str, mode: Mode) -> Result<()> { let mut file = File::open(local_path)?; self.upload_from_reader(remote_filename, mode, &mut file) } fn dally_final_ack( &self, socket: &std::net::UdpSocket, peer: SocketAddr, final_block: u16, ack_bytes: Option<&[u8]>, ) -> Result<()> { if self.config.dally_retries == 0 { return Ok(()); } let Some(ack_bytes) = ack_bytes else { return Ok(()); }; socket.set_read_timeout(Some(self.config.dally_timeout))?; let mut recv_buf = [0u8; MAX_PACKET_SIZE]; let mut timeouts = 0u32; // RFC 1350, Section 6: "dallying is encouraged" to retransmit the final ACK // if the last DATA is retransmitted. while timeouts < self.config.dally_retries { match socket.recv_from(&mut recv_buf) { Ok((n, from)) => { if from != peer { util::send_error( socket, from, ErrorCode::UnknownTransferId, "unknown transfer id", ); continue; } if let Ok(Packet::Data { block, .. }) = Packet::decode(&recv_buf[..n]) && block == final_block { let _ignored = socket.send_to(ack_bytes, peer); } } Err(e) if util::is_timeout(&e) => timeouts += 1, Err(e) => return Err(e.into()), } } Ok(()) } fn validate_mode(mode: Mode) -> Result<()> { if mode == Mode::Mail { return Err(Error::UnsupportedMode(mode)); } Ok(()) } fn rrq_handshake( &self, socket: &std::net::UdpSocket, recv_buf: &mut [u8], remote_filename: &str, mode: Mode, ) -> Result<(SocketAddr, Packet)> { let rrq_bytes = Packet::Rrq(Request { filename: remote_filename.to_string(), mode, }) .encode(); socket.send_to(&rrq_bytes, self.server)?; let mut attempts = 0u32; loop { let (n, from) = self.recv_from_with_retries( socket, recv_buf, &mut attempts, "RRQ", &rrq_bytes, self.server, )?; if from.ip() != self.server.ip() { continue; } let pkt = Packet::decode(&recv_buf[..n])?; match pkt { Packet::Data { .. } => return Ok((from, pkt)), Packet::Error { code, message } => { return Err(Error::RemoteError { code, message }); } _ => { util::send_error( socket, from, ErrorCode::IllegalOperation, "unexpected packet", ); return Err(Error::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "unexpected packet", ))); } } } } fn wrq_handshake( &self, socket: &std::net::UdpSocket, recv_buf: &mut [u8], remote_filename: &str, mode: Mode, ) -> Result { let wrq_bytes = Packet::Wrq(Request { filename: remote_filename.to_string(), mode, }) .encode(); socket.send_to(&wrq_bytes, self.server)?; let mut attempts = 0u32; loop { let (n, from) = self.recv_from_with_retries( socket, recv_buf, &mut attempts, "WRQ", &wrq_bytes, self.server, )?; if from.ip() != self.server.ip() { continue; } match Packet::decode(&recv_buf[..n])? { Packet::Ack { block: 0 } => return Ok(from), Packet::Error { code, message } => { return Err(Error::RemoteError { code, message }); } _ => { util::send_error( socket, from, ErrorCode::IllegalOperation, "unexpected packet", ); return Err(Error::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "unexpected packet", ))); } } } } fn recv_from_with_retries( &self, socket: &std::net::UdpSocket, recv_buf: &mut [u8], attempts: &mut u32, last_sent_kind: &'static str, last_sent: &[u8], resend_to: SocketAddr, ) -> Result<(usize, SocketAddr)> { loop { match socket.recv_from(recv_buf) { Ok(v) => { *attempts = 0; return Ok(v); } Err(e) if util::is_timeout(&e) => { *attempts += 1; if *attempts > self.config.retries { return Err(Error::Timeout { last_packet: last_sent_kind, attempts: *attempts, }); } socket.send_to(last_sent, resend_to)?; } Err(e) => return Err(e.into()), } } } fn recv_from_peer_with_retries( &self, socket: &std::net::UdpSocket, recv_buf: &mut [u8], attempts: &mut u32, last_sent_kind: &'static str, last_sent: &[u8], peer: SocketAddr, ) -> Result { loop { let (n, from) = self.recv_from_with_retries( socket, recv_buf, attempts, last_sent_kind, last_sent, peer, )?; if from == peer { return Ok(n); } // RFC 1350, Section 4 (page 4): incorrect source port => Unknown transfer ID. if from.ip() == peer.ip() { util::send_error( socket, from, ErrorCode::UnknownTransferId, "unknown transfer id", ); } } } fn download_from_first_data( &self, socket: &std::net::UdpSocket, peer: SocketAddr, first_block: u16, first_data: &[u8], mode: Mode, output: &mut impl Write, ) -> Result<()> { if first_block != 1 { util::send_error( socket, peer, ErrorCode::IllegalOperation, "unexpected block number", ); return Err(Error::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "unexpected block number", ))); } let mut sink = util::DataSink::new(output, mode); sink.write_data(first_data).map_err(map_sink_error)?; let ack_bytes = Packet::Ack { block: first_block }.encode(); socket.send_to(&ack_bytes, peer)?; let mut last_ack = Some(ack_bytes.clone()); let mut last_sent = ack_bytes; let mut expected_block = 2u16; if first_data.len() < BLOCK_SIZE { sink.finish().map_err(map_sink_error)?; self.dally_final_ack(socket, peer, first_block, last_ack.as_deref())?; return Ok(()); } let mut attempts = 0u32; let mut recv_buf = [0u8; MAX_PACKET_SIZE]; loop { let n = self.recv_from_peer_with_retries( socket, &mut recv_buf, &mut attempts, "ACK", &last_sent, peer, )?; let pkt = Packet::decode(&recv_buf[..n])?; match pkt { Packet::Error { code, message } => { return Err(Error::RemoteError { code, message }); } Packet::Data { block, data } if block == expected_block => { sink.write_data(&data).map_err(map_sink_error)?; let ack = Packet::Ack { block }.encode(); socket.send_to(&ack, peer)?; last_ack = Some(ack.clone()); last_sent = ack; if data.len() < BLOCK_SIZE { sink.finish().map_err(map_sink_error)?; self.dally_final_ack(socket, peer, block, last_ack.as_deref())?; return Ok(()); } expected_block = expected_block.wrapping_add(1); } Packet::Data { block, .. } if block == expected_block.wrapping_sub(1) => { if let Some(ack) = &last_ack { socket.send_to(ack, peer)?; } } _ => { util::send_error( socket, peer, ErrorCode::IllegalOperation, "unexpected packet type", ); return Err(Error::Io(std::io::Error::new( std::io::ErrorKind::InvalidData, "unexpected packet type", ))); } } } } } fn map_sink_error(err: util::DataSinkError) -> Error { match err { util::DataSinkError::Io(e) => Error::Io(e), util::DataSinkError::Netascii(e) => Error::Netascii(e), } }