diff --git a/crates/pfs-tftp/src/bin/tftp.rs b/crates/pfs-tftp/src/bin/tftp.rs index 1f0f082..757a88e 100644 --- a/crates/pfs-tftp/src/bin/tftp.rs +++ b/crates/pfs-tftp/src/bin/tftp.rs @@ -34,10 +34,61 @@ //! tftp -m netascii put 192.168.1.1 readme.txt //! ``` -use std::{env, fs::File, path::Path, process::ExitCode}; +use std::{ + env, + fs::File, + io::{self, Write}, + path::{Path, PathBuf}, + process::ExitCode, +}; use pfs_tftp::{Client, Mode}; +/// A writer that defers file creation until the first write. +/// +/// This ensures the local file is only created after we've confirmed the remote +/// file exists (i.e., after receiving the first DATA packet, not an ERROR). +struct DeferredFileWriter { + path: PathBuf, + file: Option, + bytes_written: u64, +} + +impl DeferredFileWriter { + fn new(path: PathBuf) -> Self { + Self { + path, + file: None, + bytes_written: 0, + } + } + + fn bytes_written(&self) -> u64 { + self.bytes_written + } +} + +impl Write for DeferredFileWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.file.is_none() { + self.file = Some(File::create(&self.path)?); + } + // SAFETY: We just ensured self.file is Some above + let file = self.file.as_mut().expect("file is Some"); + let n = file.write(buf)?; + self.bytes_written += n as u64; + Ok(n) + } + + fn flush(&mut self) -> io::Result<()> { + if let Some(ref mut file) = self.file { + file.flush() + } else { + Ok(()) + } + } +} + /// Print usage information. fn print_usage(program: &str) { eprintln!("TFTP Client (RFC 1350)"); @@ -248,28 +299,19 @@ fn main() -> ExitCode { } }; - // Download to memory first - only create local file on success - let data = match client.get(&remote_file, args.mode) { - Ok(data) => data, - Err(e) => { - eprintln!("Error: {e}"); - return ExitCode::FAILURE; - } - }; + // Use deferred writer - file is only created after first DATA packet + let mut writer = DeferredFileWriter::new(PathBuf::from(&local_file)); - // Write to local file - if let Err(e) = std::fs::write(&local_file, &data) { - eprintln!("Error writing file '{local_file}': {e}"); + if let Err(e) = client.get_to_writer(&remote_file, args.mode, &mut writer) { + eprintln!("Error: {e}"); return ExitCode::FAILURE; } + let bytes = writer.bytes_written(); if args.verbose { - eprintln!("Received {} bytes", data.len()); + eprintln!("Received {bytes} bytes"); } - println!( - "Downloaded '{remote_file}' -> '{local_file}' ({} bytes)", - data.len() - ); + println!("Downloaded '{remote_file}' -> '{local_file}' ({bytes} bytes)"); } Command::Put {