From 6d1fbef9577b43db365a56a312cdb8e448f04dc8 Mon Sep 17 00:00:00 2001 From: ddidderr Date: Sun, 21 Dec 2025 13:26:28 +0100 Subject: [PATCH] fix: avoid creating output file on missing RRQ --- crates/pfs-tftp-sync/src/client.rs | 181 +++++++++++++---------- crates/pfs-tftp-sync/tests/end_to_end.rs | 19 +++ 2 files changed, 124 insertions(+), 76 deletions(-) diff --git a/crates/pfs-tftp-sync/src/client.rs b/crates/pfs-tftp-sync/src/client.rs index cdd0baa..d7a809d 100644 --- a/crates/pfs-tftp-sync/src/client.rs +++ b/crates/pfs-tftp-sync/src/client.rs @@ -84,7 +84,6 @@ impl Client { socket.set_read_timeout(Some(self.config.timeout))?; Self::validate_mode(mode)?; - let mut sink = util::DataSink::new(output, mode); let mut recv_buf = [0u8; MAX_PACKET_SIZE]; let (peer, first) = self.rrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?; @@ -94,80 +93,7 @@ impl Client { "unexpected packet type", ))); }; - if block != 1 { - util::send_error( - &socket, - peer, - ErrorCode::IllegalOperation, - "unexpected block number", - ); - return Err(Error::Io(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "unexpected block number", - ))); - } - - sink.write_data(&data).map_err(map_sink_error)?; - let ack_bytes = Packet::Ack { block }.encode(); - socket.send_to(&ack_bytes, peer)?; - let mut last_ack = Some(ack_bytes.clone()); - - let mut last_sent = ack_bytes; - let mut expected_block = 2u16; - - if data.len() < BLOCK_SIZE { - sink.finish().map_err(map_sink_error)?; - self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?; - return Ok(()); - } - - let mut attempts = 0u32; - loop { - let n = self.recv_from_peer_with_retries( - &socket, - &mut recv_buf, - &mut attempts, - "ACK", - &last_sent, - peer, - )?; - let pkt = Packet::decode(&recv_buf[..n])?; - match pkt { - Packet::Error { code, message } => { - return Err(Error::RemoteError { code, message }); - } - Packet::Data { block, data } if block == expected_block => { - sink.write_data(&data).map_err(map_sink_error)?; - let ack = Packet::Ack { block }.encode(); - socket.send_to(&ack, peer)?; - last_ack = Some(ack.clone()); - last_sent = ack; - if data.len() < BLOCK_SIZE { - sink.finish().map_err(map_sink_error)?; - self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?; - return Ok(()); - } - expected_block = expected_block.wrapping_add(1); - } - Packet::Data { block, .. } if block == expected_block.wrapping_sub(1) => { - if let Some(ack) = &last_ack { - socket.send_to(ack, peer)?; - } - } - _ => { - util::send_error( - &socket, - peer, - ErrorCode::IllegalOperation, - "unexpected packet type", - ); - return Err(Error::Io(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "unexpected packet type", - ))); - } - } - } + self.download_from_first_data(&socket, peer, block, &data, mode, output) } /// Downloads `remote_filename` from the configured server into `local_path`. @@ -175,8 +101,23 @@ impl Client { /// # Errors /// Returns any error from [`Self::download_to_writer`] or filesystem I/O. pub fn get(&self, remote_filename: &str, local_path: &Path, mode: Mode) -> Result<()> { + // Don't create the output file before we know the RRQ will succeed. Otherwise, a missing + // remote file (RFC 1350 error code 1) leaves behind an empty local file. + let socket = util::bind_ephemeral_for(self.server)?; + socket.set_read_timeout(Some(self.config.timeout))?; + Self::validate_mode(mode)?; + + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + let (peer, first) = self.rrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?; + let Packet::Data { block, data } = first else { + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected packet type", + ))); + }; + let mut file = File::create(local_path)?; - self.download_to_writer(remote_filename, mode, &mut file) + self.download_from_first_data(&socket, peer, block, &data, mode, &mut file) } /// Uploads `input` to the configured server as `remote_filename`. @@ -479,6 +420,94 @@ impl Client { } } } + + fn download_from_first_data( + &self, + socket: &std::net::UdpSocket, + peer: SocketAddr, + first_block: u16, + first_data: &[u8], + mode: Mode, + output: &mut impl Write, + ) -> Result<()> { + if first_block != 1 { + util::send_error( + socket, + peer, + ErrorCode::IllegalOperation, + "unexpected block number", + ); + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected block number", + ))); + } + + let mut sink = util::DataSink::new(output, mode); + sink.write_data(first_data).map_err(map_sink_error)?; + + let ack_bytes = Packet::Ack { block: first_block }.encode(); + socket.send_to(&ack_bytes, peer)?; + let mut last_ack = Some(ack_bytes.clone()); + + let mut last_sent = ack_bytes; + let mut expected_block = 2u16; + + if first_data.len() < BLOCK_SIZE { + sink.finish().map_err(map_sink_error)?; + self.dally_final_ack(socket, peer, first_block, last_ack.as_deref())?; + return Ok(()); + } + + let mut attempts = 0u32; + let mut recv_buf = [0u8; MAX_PACKET_SIZE]; + loop { + let n = self.recv_from_peer_with_retries( + socket, + &mut recv_buf, + &mut attempts, + "ACK", + &last_sent, + peer, + )?; + let pkt = Packet::decode(&recv_buf[..n])?; + match pkt { + Packet::Error { code, message } => { + return Err(Error::RemoteError { code, message }); + } + Packet::Data { block, data } if block == expected_block => { + sink.write_data(&data).map_err(map_sink_error)?; + let ack = Packet::Ack { block }.encode(); + socket.send_to(&ack, peer)?; + last_ack = Some(ack.clone()); + last_sent = ack; + if data.len() < BLOCK_SIZE { + sink.finish().map_err(map_sink_error)?; + self.dally_final_ack(socket, peer, block, last_ack.as_deref())?; + return Ok(()); + } + expected_block = expected_block.wrapping_add(1); + } + Packet::Data { block, .. } if block == expected_block.wrapping_sub(1) => { + if let Some(ack) = &last_ack { + socket.send_to(ack, peer)?; + } + } + _ => { + util::send_error( + socket, + peer, + ErrorCode::IllegalOperation, + "unexpected packet type", + ); + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "unexpected packet type", + ))); + } + } + } + } } fn map_sink_error(err: util::DataSinkError) -> Error { diff --git a/crates/pfs-tftp-sync/tests/end_to_end.rs b/crates/pfs-tftp-sync/tests/end_to_end.rs index b4a50a9..66e509a 100644 --- a/crates/pfs-tftp-sync/tests/end_to_end.rs +++ b/crates/pfs-tftp-sync/tests/end_to_end.rs @@ -98,6 +98,25 @@ fn netascii_roundtrip_preserves_newlines_and_cr() { stop_server(&shutdown, handle); } +#[test] +fn missing_remote_file_does_not_create_local_file() { + let server_root = TempDir::new("pfs_tftp_server_root_missing"); + let local_root = TempDir::new("pfs_tftp_local_root_missing"); + + let (addr, shutdown, handle) = start_server(server_root.path(), true); + let client = Client::new(addr, test_client_config()); + + let out = local_root.path().join("should_not_exist.bin"); + assert!(!out.exists()); + + let _err = client + .get("does_not_exist.bin", &out, Mode::Octet) + .unwrap_err(); + assert!(!out.exists()); + + stop_server(&shutdown, handle); +} + fn test_client_config() -> ClientConfig { ClientConfig { timeout: Duration::from_millis(200),