feat: add sync TFTP client/server library

This commit is contained in:
2025-12-21 11:36:07 +01:00
parent 890443fdc2
commit 9ebfdb8cb2
6 changed files with 1206 additions and 12 deletions

View File

@@ -294,8 +294,15 @@ impl fmt::Display for DecodeError {
Self::UnknownMode(mode) => write!(f, "unknown transfer mode {mode:?}"), Self::UnknownMode(mode) => write!(f, "unknown transfer mode {mode:?}"),
Self::UnknownErrorCode(code) => write!(f, "unknown error code {code}"), Self::UnknownErrorCode(code) => write!(f, "unknown error code {code}"),
Self::OversizeData(n) => write!(f, "DATA payload too large ({n} bytes)"), Self::OversizeData(n) => write!(f, "DATA payload too large ({n} bytes)"),
Self::InvalidLength { kind, expected, got } => { Self::InvalidLength {
write!(f, "{kind} packet has invalid length (expected {expected}, got {got})") kind,
expected,
got,
} => {
write!(
f,
"{kind} packet has invalid length (expected {expected}, got {got})"
)
} }
Self::TrailingBytes => write!(f, "trailing bytes after packet"), Self::TrailingBytes => write!(f, "trailing bytes after packet"),
} }
@@ -339,7 +346,10 @@ mod tests {
Packet::Ack { block: 1 }.encode_into(&mut bytes); Packet::Ack { block: 1 }.encode_into(&mut bytes);
bytes.push(0); bytes.push(0);
let err = Packet::decode(&bytes).unwrap_err(); let err = Packet::decode(&bytes).unwrap_err();
assert!(matches!(err, DecodeError::InvalidLength { kind: "ACK", .. })); assert!(matches!(
err,
DecodeError::InvalidLength { kind: "ACK", .. }
));
} }
#[allow(clippy::unwrap_used)] #[allow(clippy::unwrap_used)]

View File

@@ -1,20 +1,38 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use std::net::SocketAddr; use std::{
use std::time::Duration; fs::File,
io::{Read, Write},
net::SocketAddr,
path::Path,
time::Duration,
};
use pfs_tftp_proto::packet::{BLOCK_SIZE, ErrorCode, MAX_PACKET_SIZE, Packet, Request};
use crate::{
Mode,
error::{Error, Result},
util,
};
/// Configuration for a synchronous TFTP client. /// Configuration for a synchronous TFTP client.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ClientConfig { pub struct ClientConfig {
pub timeout: Duration, pub timeout: Duration,
pub retries: u32, pub retries: u32,
pub dally_timeout: Duration,
pub dally_retries: u32,
} }
impl Default for ClientConfig { impl Default for ClientConfig {
fn default() -> Self { fn default() -> Self {
let timeout = Duration::from_secs(5);
Self { Self {
timeout: Duration::from_secs(5), timeout,
retries: 5, retries: 5,
dally_timeout: timeout,
dally_retries: 2,
} }
} }
} }
@@ -41,4 +59,428 @@ impl Client {
pub fn config(&self) -> &ClientConfig { pub fn config(&self) -> &ClientConfig {
&self.config &self.config
} }
/// Downloads `remote_filename` from the configured server into `output`.
///
/// RFC reference:
/// - Packet formats: RFC 1350, Section 5 (Figures 5-1..5-4)
/// - Lock-step ACK/DATA: RFC 1350, Section 2
/// - Termination (final <512 DATA): RFC 1350, Section 6
/// - Transfer identifiers (TID): RFC 1350, Section 4 (page 4)
///
/// # Errors
/// Returns an error on I/O failures, protocol decode errors, timeouts after
/// exhausting retries, or if the peer sends an ERROR packet.
pub fn download_to_writer(
&self,
remote_filename: &str,
mode: Mode,
output: &mut impl Write,
) -> Result<()> {
let socket = util::bind_ephemeral_for(self.server)?;
socket.set_read_timeout(Some(self.config.timeout))?;
Self::validate_mode(mode)?;
let mut sink = util::DataSink::new(output, mode);
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let (peer, first) = self.rrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?;
let Packet::Data { block, data } = first else {
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"unexpected packet type",
)));
};
if block != 1 {
util::send_error(
&socket,
peer,
ErrorCode::IllegalOperation,
"unexpected block number",
);
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"unexpected block number",
)));
}
sink.write_data(&data).map_err(map_sink_error)?;
let ack_bytes = Packet::Ack { block }.encode();
socket.send_to(&ack_bytes, peer)?;
let mut last_ack = Some(ack_bytes.clone());
let mut last_sent = ack_bytes;
let mut expected_block = 2u16;
if data.len() < BLOCK_SIZE {
sink.finish().map_err(map_sink_error)?;
self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?;
return Ok(());
}
let mut attempts = 0u32;
loop {
let n = self.recv_from_peer_with_retries(
&socket,
&mut recv_buf,
&mut attempts,
"ACK",
&last_sent,
peer,
)?;
let pkt = Packet::decode(&recv_buf[..n])?;
match pkt {
Packet::Error { code, message } => {
return Err(Error::RemoteError { code, message });
}
Packet::Data { block, data } if block == expected_block => {
sink.write_data(&data).map_err(map_sink_error)?;
let ack = Packet::Ack { block }.encode();
socket.send_to(&ack, peer)?;
last_ack = Some(ack.clone());
last_sent = ack;
if data.len() < BLOCK_SIZE {
sink.finish().map_err(map_sink_error)?;
self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?;
return Ok(());
}
expected_block = expected_block.wrapping_add(1);
}
Packet::Data { block, .. } if block == expected_block.wrapping_sub(1) => {
if let Some(ack) = &last_ack {
socket.send_to(ack, peer)?;
}
}
_ => {
util::send_error(
&socket,
peer,
ErrorCode::IllegalOperation,
"unexpected packet type",
);
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"unexpected packet type",
)));
}
}
}
}
/// Downloads `remote_filename` from the configured server into `local_path`.
///
/// # Errors
/// Returns any error from [`Self::download_to_writer`] or filesystem I/O.
pub fn get(&self, remote_filename: &str, local_path: &Path, mode: Mode) -> Result<()> {
let mut file = File::create(local_path)?;
self.download_to_writer(remote_filename, mode, &mut file)
}
/// Uploads `input` to the configured server as `remote_filename`.
///
/// RFC reference:
/// - WRQ/ACK(0) handshake: RFC 1350, Section 4 (page 4) and Section 5 (ACK block=0)
/// - Lock-step ACK/DATA: RFC 1350, Section 2
/// - Termination (final <512 DATA): RFC 1350, Section 6
///
/// # Errors
/// Returns an error on I/O failures, protocol decode errors, timeouts after
/// exhausting retries, or if the peer sends an ERROR packet.
pub fn upload_from_reader(
&self,
remote_filename: &str,
mode: Mode,
input: &mut impl Read,
) -> Result<()> {
let socket = util::bind_ephemeral_for(self.server)?;
socket.set_read_timeout(Some(self.config.timeout))?;
Self::validate_mode(mode)?;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let peer = self.wrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?;
let mut source = util::DataSource::new(input, mode);
let mut next_block: u16 = 1;
let mut attempts = 0u32;
while let Some(data) = source.next_block()? {
let data_bytes = Packet::Data {
block: next_block,
data,
}
.encode();
socket.send_to(&data_bytes, peer)?;
let expected_ack = next_block;
let last_sent = data_bytes;
let last_sent_kind = "DATA";
loop {
let n = self.recv_from_peer_with_retries(
&socket,
&mut recv_buf,
&mut attempts,
last_sent_kind,
&last_sent,
peer,
)?;
match Packet::decode(&recv_buf[..n])? {
Packet::Ack { block } if block == expected_ack => break,
Packet::Ack { block } if block == expected_ack.wrapping_sub(1) => {
socket.send_to(&last_sent, peer)?;
}
Packet::Error { code, message } => {
return Err(Error::RemoteError { code, message });
}
_ => {
util::send_error(
&socket,
peer,
ErrorCode::IllegalOperation,
"unexpected packet type",
);
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"unexpected packet type",
)));
}
}
}
attempts = 0;
next_block = next_block.wrapping_add(1);
}
Ok(())
}
/// Uploads `local_path` to the configured server as `remote_filename`.
///
/// # Errors
/// Returns any error from [`Self::upload_from_reader`] or filesystem I/O.
pub fn put(&self, local_path: &Path, remote_filename: &str, mode: Mode) -> Result<()> {
let mut file = File::open(local_path)?;
self.upload_from_reader(remote_filename, mode, &mut file)
}
fn dally_final_ack(
&self,
socket: &std::net::UdpSocket,
peer: SocketAddr,
final_block: u16,
ack_bytes: Option<&[u8]>,
) -> Result<()> {
if self.config.dally_retries == 0 {
return Ok(());
}
let Some(ack_bytes) = ack_bytes else {
return Ok(());
};
socket.set_read_timeout(Some(self.config.dally_timeout))?;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut timeouts = 0u32;
// RFC 1350, Section 6: "dallying is encouraged" to retransmit the final ACK
// if the last DATA is retransmitted.
while timeouts < self.config.dally_retries {
match socket.recv_from(&mut recv_buf) {
Ok((n, from)) => {
if from != peer {
util::send_error(
socket,
from,
ErrorCode::UnknownTransferId,
"unknown transfer id",
);
continue;
}
if let Ok(Packet::Data { block, .. }) = Packet::decode(&recv_buf[..n])
&& block == final_block
{
let _ignored = socket.send_to(ack_bytes, peer);
}
}
Err(e) if util::is_timeout(&e) => timeouts += 1,
Err(e) => return Err(e.into()),
}
}
Ok(())
}
fn validate_mode(mode: Mode) -> Result<()> {
if mode == Mode::Mail {
return Err(Error::UnsupportedMode(mode));
}
Ok(())
}
fn rrq_handshake(
&self,
socket: &std::net::UdpSocket,
recv_buf: &mut [u8],
remote_filename: &str,
mode: Mode,
) -> Result<(SocketAddr, Packet)> {
let rrq_bytes = Packet::Rrq(Request {
filename: remote_filename.to_string(),
mode,
})
.encode();
socket.send_to(&rrq_bytes, self.server)?;
let mut attempts = 0u32;
loop {
let (n, from) = self.recv_from_with_retries(
socket,
recv_buf,
&mut attempts,
"RRQ",
&rrq_bytes,
self.server,
)?;
if from.ip() != self.server.ip() {
continue;
}
let pkt = Packet::decode(&recv_buf[..n])?;
match pkt {
Packet::Data { .. } => return Ok((from, pkt)),
Packet::Error { code, message } => {
return Err(Error::RemoteError { code, message });
}
_ => {
util::send_error(
socket,
from,
ErrorCode::IllegalOperation,
"unexpected packet",
);
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"unexpected packet",
)));
}
}
}
}
fn wrq_handshake(
&self,
socket: &std::net::UdpSocket,
recv_buf: &mut [u8],
remote_filename: &str,
mode: Mode,
) -> Result<SocketAddr> {
let wrq_bytes = Packet::Wrq(Request {
filename: remote_filename.to_string(),
mode,
})
.encode();
socket.send_to(&wrq_bytes, self.server)?;
let mut attempts = 0u32;
loop {
let (n, from) = self.recv_from_with_retries(
socket,
recv_buf,
&mut attempts,
"WRQ",
&wrq_bytes,
self.server,
)?;
if from.ip() != self.server.ip() {
continue;
}
match Packet::decode(&recv_buf[..n])? {
Packet::Ack { block: 0 } => return Ok(from),
Packet::Error { code, message } => {
return Err(Error::RemoteError { code, message });
}
_ => {
util::send_error(
socket,
from,
ErrorCode::IllegalOperation,
"unexpected packet",
);
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"unexpected packet",
)));
}
}
}
}
fn recv_from_with_retries(
&self,
socket: &std::net::UdpSocket,
recv_buf: &mut [u8],
attempts: &mut u32,
last_sent_kind: &'static str,
last_sent: &[u8],
resend_to: SocketAddr,
) -> Result<(usize, SocketAddr)> {
loop {
match socket.recv_from(recv_buf) {
Ok(v) => {
*attempts = 0;
return Ok(v);
}
Err(e) if util::is_timeout(&e) => {
*attempts += 1;
if *attempts > self.config.retries {
return Err(Error::Timeout {
last_packet: last_sent_kind,
attempts: *attempts,
});
}
socket.send_to(last_sent, resend_to)?;
}
Err(e) => return Err(e.into()),
}
}
}
fn recv_from_peer_with_retries(
&self,
socket: &std::net::UdpSocket,
recv_buf: &mut [u8],
attempts: &mut u32,
last_sent_kind: &'static str,
last_sent: &[u8],
peer: SocketAddr,
) -> Result<usize> {
loop {
let (n, from) = self.recv_from_with_retries(
socket,
recv_buf,
attempts,
last_sent_kind,
last_sent,
peer,
)?;
if from == peer {
return Ok(n);
}
// RFC 1350, Section 4 (page 4): incorrect source port => Unknown transfer ID.
if from.ip() == peer.ip() {
util::send_error(
socket,
from,
ErrorCode::UnknownTransferId,
"unknown transfer id",
);
}
}
}
}
fn map_sink_error(err: util::DataSinkError) -> Error {
match err {
util::DataSinkError::Io(e) => Error::Io(e),
util::DataSinkError::Netascii(e) => Error::Netascii(e),
}
} }

View File

@@ -0,0 +1,76 @@
#![forbid(unsafe_code)]
use core::fmt;
use pfs_tftp_proto::{netascii, packet};
/// Errors returned by the synchronous client APIs.
#[derive(Debug)]
pub enum Error {
Io(std::io::Error),
Protocol(packet::DecodeError),
Netascii(netascii::DecodeError),
UnsupportedMode(packet::Mode),
RemoteError {
code: packet::ErrorCode,
message: String,
},
Timeout {
last_packet: &'static str,
attempts: u32,
},
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Io(e) => write!(f, "{e}"),
Self::Protocol(e) => write!(f, "{e}"),
Self::Netascii(e) => write!(f, "{e}"),
Self::UnsupportedMode(mode) => {
write!(f, "unsupported transfer mode: {}", mode.as_str())
}
Self::RemoteError { code, message } => {
write!(f, "remote error {code:?}: {message}")
}
Self::Timeout {
last_packet,
attempts,
} => write!(
f,
"timed out waiting for response after {attempts} attempts (last sent: {last_packet})"
),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::Protocol(e) => Some(e),
Self::Netascii(e) => Some(e),
Self::UnsupportedMode(_) | Self::RemoteError { .. } | Self::Timeout { .. } => None,
}
}
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Self::Io(value)
}
}
impl From<packet::DecodeError> for Error {
fn from(value: packet::DecodeError) -> Self {
Self::Protocol(value)
}
}
impl From<netascii::DecodeError> for Error {
fn from(value: netascii::DecodeError) -> Self {
Self::Netascii(value)
}
}
pub type Result<T> = core::result::Result<T, Error>;

View File

@@ -6,8 +6,11 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
pub mod client; pub mod client;
pub mod error;
pub mod server; pub mod server;
pub mod util; mod util;
pub use client::{Client, ClientConfig}; pub use client::{Client, ClientConfig};
pub use error::{Error, Result};
pub use pfs_tftp_proto::packet::Mode;
pub use server::{Server, ServerConfig}; pub use server::{Server, ServerConfig};

View File

@@ -1,8 +1,16 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
use std::net::SocketAddr; use std::{
use std::path::PathBuf; fs::{File, OpenOptions},
use std::time::Duration; net::SocketAddr,
path::{Path, PathBuf},
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
use pfs_tftp_proto::packet::{BLOCK_SIZE, ErrorCode, MAX_PACKET_SIZE, Packet};
use crate::{Mode, util};
/// Configuration for a synchronous TFTP server. /// Configuration for a synchronous TFTP server.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -13,17 +21,24 @@ pub struct ServerConfig {
pub overwrite: bool, pub overwrite: bool,
pub timeout: Duration, pub timeout: Duration,
pub retries: u32, pub retries: u32,
pub dally_timeout: Duration,
pub dally_retries: u32,
pub max_request_size: usize,
} }
impl Default for ServerConfig { impl Default for ServerConfig {
fn default() -> Self { fn default() -> Self {
let timeout = Duration::from_secs(5);
Self { Self {
bind: SocketAddr::from(([0, 0, 0, 0], 6969)), bind: SocketAddr::from(([0, 0, 0, 0], 6969)),
root: PathBuf::from("."), root: PathBuf::from("."),
allow_write: false, allow_write: false,
overwrite: false, overwrite: false,
timeout: Duration::from_secs(5), timeout,
retries: 5, retries: 5,
dally_timeout: timeout,
dally_retries: 2,
max_request_size: util::MAX_REQUEST_SIZE,
} }
} }
} }
@@ -44,4 +59,452 @@ impl Server {
pub fn config(&self) -> &ServerConfig { pub fn config(&self) -> &ServerConfig {
&self.config &self.config
} }
/// Serves requests forever.
///
/// # Errors
/// Returns an error if the server socket cannot be bound or configured.
pub fn serve(&self) -> std::io::Result<()> {
self.serve_inner(None)
}
/// Serves requests until `shutdown` becomes `true`.
///
/// This is mainly intended for tests.
///
/// # Errors
/// Returns an error if the server socket cannot be bound or configured.
pub fn serve_until(&self, shutdown: &AtomicBool) -> std::io::Result<()> {
self.serve_inner(Some(shutdown))
}
fn serve_inner(&self, shutdown: Option<&AtomicBool>) -> std::io::Result<()> {
let socket = std::net::UdpSocket::bind(self.config.bind)?;
if shutdown.is_some() {
socket.set_read_timeout(Some(Duration::from_millis(200)))?;
}
let mut buf = vec![0u8; self.config.max_request_size];
loop {
if let Some(flag) = shutdown
&& flag.load(Ordering::Relaxed)
{
return Ok(());
}
let (n, peer) = match socket.recv_from(&mut buf) {
Ok(v) => v,
Err(e) if shutdown.is_some() && util::is_timeout(&e) => continue,
Err(e) => return Err(e),
};
let Ok(pkt) = Packet::decode(&buf[..n]) else {
util::send_error(
&socket,
peer,
ErrorCode::IllegalOperation,
"malformed packet",
);
continue;
};
match pkt {
Packet::Rrq(req) => {
let cfg = self.config.clone();
std::thread::spawn(move || {
let _ignored = handle_rrq(cfg, peer, req);
});
}
Packet::Wrq(req) => {
let cfg = self.config.clone();
std::thread::spawn(move || {
let _ignored = handle_wrq(cfg, peer, req);
});
}
Packet::Data { .. } | Packet::Ack { .. } | Packet::Error { .. } => {
// RFC 1350, Section 4 (page 4): incorrect source port => Unknown transfer ID.
util::send_error(
&socket,
peer,
ErrorCode::UnknownTransferId,
"unknown transfer id",
);
}
}
}
}
}
fn handle_rrq(
cfg: ServerConfig,
client: SocketAddr,
req: pfs_tftp_proto::packet::Request,
) -> std::io::Result<()> {
let pfs_tftp_proto::packet::Request { filename, mode } = req;
let ServerConfig {
bind,
root,
timeout,
retries,
..
} = cfg;
if mode == Mode::Mail {
let socket = util::bind_ephemeral_on_ip(bind.ip())?;
util::send_error(
&socket,
client,
ErrorCode::NotDefined,
"mail mode is obsolete",
);
return Ok(());
}
let socket = util::bind_ephemeral_on_ip(bind.ip())?;
socket.set_read_timeout(Some(timeout))?;
let Some(path) = resolve_path(&root, &filename) else {
util::send_error(&socket, client, ErrorCode::AccessViolation, "invalid path");
return Ok(());
};
let mut file = match File::open(&path) {
Ok(f) => f,
Err(e) => {
util::send_error(
&socket,
client,
map_io_error_code(&e),
"failed to open file",
);
return Ok(());
}
};
let mut source = util::DataSource::new(&mut file, mode);
let mut block: u16 = 1;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut attempts = 0u32;
while let Some(data) = source.next_block()? {
let pkt = Packet::Data { block, data };
let bytes = pkt.encode();
socket.send_to(&bytes, client)?;
let last_sent = bytes;
let expected_ack = block;
loop {
let (n, from) = match socket.recv_from(&mut recv_buf) {
Ok(v) => v,
Err(e) if util::is_timeout(&e) => {
attempts += 1;
if attempts > retries {
return Ok(());
}
socket.send_to(&last_sent, client)?;
continue;
}
Err(e) => return Err(e),
};
if from != client {
util::send_error(
&socket,
from,
ErrorCode::UnknownTransferId,
"unknown transfer id",
);
continue;
}
attempts = 0;
match Packet::decode(&recv_buf[..n]) {
Ok(Packet::Ack { block }) if block == expected_ack => break,
Ok(Packet::Ack { block }) if block == expected_ack.wrapping_sub(1) => {
// Duplicate ACK, retransmit last DATA.
socket.send_to(&last_sent, client)?;
}
Ok(Packet::Error { .. }) => return Ok(()),
Ok(_) | Err(_) => {
util::send_error(
&socket,
client,
ErrorCode::IllegalOperation,
"unexpected packet",
);
return Ok(());
}
}
}
// RFC 1350, Section 6: final DATA has <512 bytes.
if last_sent.len() < 4 + BLOCK_SIZE {
return Ok(());
}
block = block.wrapping_add(1);
}
Ok(())
}
#[allow(clippy::too_many_lines)]
fn handle_wrq(
cfg: ServerConfig,
client: SocketAddr,
req: pfs_tftp_proto::packet::Request,
) -> std::io::Result<()> {
let pfs_tftp_proto::packet::Request { filename, mode } = req;
let ServerConfig {
bind,
root,
allow_write,
overwrite,
timeout,
retries,
dally_timeout,
dally_retries,
..
} = cfg;
let socket = util::bind_ephemeral_on_ip(bind.ip())?;
socket.set_read_timeout(Some(timeout))?;
if mode == Mode::Mail {
util::send_error(
&socket,
client,
ErrorCode::NotDefined,
"mail mode is obsolete",
);
return Ok(());
}
if !allow_write {
util::send_error(
&socket,
client,
ErrorCode::AccessViolation,
"writes disabled",
);
return Ok(());
}
let Some(path) = resolve_path(&root, &filename) else {
util::send_error(&socket, client, ErrorCode::AccessViolation, "invalid path");
return Ok(());
};
let mut file = match open_for_write(&path, overwrite) {
Ok(f) => f,
Err(e) => {
util::send_error(
&socket,
client,
map_io_error_code(&e),
"failed to create file",
);
return Ok(());
}
};
// RFC 1350, Section 5: WRQ is acknowledged by ACK(block=0).
let ack0 = Packet::Ack { block: 0 }.encode();
socket.send_to(&ack0, client)?;
let mut last_ack = ack0;
let mut expected_block: u16 = 1;
let mut attempts = 0u32;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut sink = util::DataSink::new(&mut file, mode);
loop {
let (n, from) = match socket.recv_from(&mut recv_buf) {
Ok(v) => v,
Err(e) if util::is_timeout(&e) => {
attempts += 1;
if attempts > retries {
return Ok(());
}
socket.send_to(&last_ack, client)?;
continue;
}
Err(e) => return Err(e),
};
if from != client {
util::send_error(
&socket,
from,
ErrorCode::UnknownTransferId,
"unknown transfer id",
);
continue;
}
attempts = 0;
let Ok(pkt) = Packet::decode(&recv_buf[..n]) else {
util::send_error(
&socket,
client,
ErrorCode::IllegalOperation,
"malformed packet",
);
return Ok(());
};
match pkt {
Packet::Error { .. } => return Ok(()),
Packet::Data { block, data } => {
if block == expected_block {
if let Err(e) = sink.write_data(&data) {
util::send_error(
&socket,
client,
map_sink_error_code(&e),
"failed to write data",
);
return Ok(());
}
let ack = Packet::Ack { block }.encode();
socket.send_to(&ack, client)?;
last_ack = ack;
if data.len() < BLOCK_SIZE {
if let Err(e) = sink.finish() {
util::send_error(
&socket,
client,
map_sink_error_code(&e),
"failed to finish transfer",
);
return Ok(());
}
dally_final_ack(
dally_timeout,
dally_retries,
&socket,
client,
block,
&last_ack,
)?;
return Ok(());
}
expected_block = expected_block.wrapping_add(1);
} else if block == expected_block.wrapping_sub(1) {
// Duplicate DATA (ACK lost); re-ACK.
socket.send_to(&last_ack, client)?;
} else {
util::send_error(
&socket,
client,
ErrorCode::IllegalOperation,
"unexpected block number",
);
return Ok(());
}
}
Packet::Ack { .. } | Packet::Rrq(_) | Packet::Wrq(_) => {
util::send_error(
&socket,
client,
ErrorCode::IllegalOperation,
"unexpected packet type",
);
return Ok(());
}
}
}
}
fn resolve_path(root: &Path, filename: &str) -> Option<PathBuf> {
let rel = Path::new(filename);
if rel.is_absolute() {
return None;
}
for c in rel.components() {
match c {
std::path::Component::Normal(_) => {}
_ => return None,
}
}
Some(root.join(rel))
}
#[must_use]
fn map_io_error_code(err: &std::io::Error) -> ErrorCode {
match err.kind() {
std::io::ErrorKind::NotFound => ErrorCode::FileNotFound,
std::io::ErrorKind::PermissionDenied => ErrorCode::AccessViolation,
std::io::ErrorKind::AlreadyExists => ErrorCode::FileAlreadyExists,
std::io::ErrorKind::StorageFull => ErrorCode::DiskFull,
_ => ErrorCode::NotDefined,
}
}
fn open_for_write(path: &Path, overwrite: bool) -> std::io::Result<File> {
let mut opts = OpenOptions::new();
opts.write(true);
if overwrite {
opts.create(true).truncate(true);
} else {
opts.create_new(true);
}
opts.open(path)
}
#[must_use]
fn map_sink_error_code(err: &util::DataSinkError) -> ErrorCode {
match err {
util::DataSinkError::Io(e) => map_io_error_code(e),
// Invalid netascii is a protocol violation.
util::DataSinkError::Netascii(_) => ErrorCode::IllegalOperation,
}
}
fn dally_final_ack(
dally_timeout: Duration,
dally_retries: u32,
socket: &std::net::UdpSocket,
peer: SocketAddr,
final_block: u16,
ack_bytes: &[u8],
) -> std::io::Result<()> {
if dally_retries == 0 {
return Ok(());
}
socket.set_read_timeout(Some(dally_timeout))?;
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
let mut timeouts = 0u32;
// RFC 1350, Section 6: dallying for final ACK retransmission.
while timeouts < dally_retries {
match socket.recv_from(&mut recv_buf) {
Ok((n, from)) => {
if from != peer {
util::send_error(
socket,
from,
ErrorCode::UnknownTransferId,
"unknown transfer id",
);
continue;
}
if let Ok(Packet::Data { block, .. }) = Packet::decode(&recv_buf[..n])
&& block == final_block
{
let _ignored = socket.send_to(ack_bytes, peer);
}
}
Err(e) if util::is_timeout(&e) => timeouts += 1,
Err(e) => return Err(e),
}
}
Ok(())
} }

View File

@@ -1,4 +1,204 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
// Misc sync helpers live here. use std::{
collections::VecDeque,
io::{Read, Write},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
};
use pfs_tftp_proto::{
netascii::{NetasciiDecoder, NetasciiEncoder},
packet::{BLOCK_SIZE, ErrorCode, Mode, Packet},
};
pub(crate) const MAX_REQUEST_SIZE: usize = 2048;
#[must_use]
pub(crate) fn wildcard_addr_for(peer: SocketAddr) -> SocketAddr {
match peer.ip() {
IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
}
}
pub(crate) fn bind_ephemeral_for(peer: SocketAddr) -> std::io::Result<UdpSocket> {
UdpSocket::bind(wildcard_addr_for(peer))
}
pub(crate) fn bind_ephemeral_on_ip(ip: IpAddr) -> std::io::Result<UdpSocket> {
UdpSocket::bind(SocketAddr::new(ip, 0))
}
pub(crate) fn is_timeout(err: &std::io::Error) -> bool {
matches!(
err.kind(),
std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock
)
}
pub(crate) fn send_error(socket: &UdpSocket, peer: SocketAddr, code: ErrorCode, message: &str) {
// RFC 1350, Section 7: ERROR packets are not acknowledged nor retransmitted.
// Best effort.
let pkt = Packet::Error {
code,
message: message.to_string(),
};
let _ignored = socket.send_to(&pkt.encode(), peer);
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SourcePhase {
Start,
AfterFullBlock,
Done,
}
/// Produces 512-byte DATA payload blocks according to RFC 1350, Section 2 and Section 6.
///
/// RFC 1350, Section 6: The end of a transfer is signaled by a DATA packet with
/// 0..511 bytes. If the payload stream ends exactly on a 512-byte boundary, an
/// additional zero-length DATA packet must be sent.
pub(crate) struct DataSource<'a, R: Read> {
reader: &'a mut R,
mode: Mode,
encoder: NetasciiEncoder,
pending: VecDeque<u8>,
source_eof: bool,
phase: SourcePhase,
}
impl<'a, R: Read> DataSource<'a, R> {
#[must_use]
pub(crate) fn new(reader: &'a mut R, mode: Mode) -> Self {
Self {
reader,
mode,
encoder: NetasciiEncoder::new(),
pending: VecDeque::new(),
source_eof: false,
phase: SourcePhase::Start,
}
}
pub(crate) fn next_block(&mut self) -> std::io::Result<Option<Vec<u8>>> {
if self.phase == SourcePhase::Done {
return Ok(None);
}
if self.mode == Mode::Octet {
return self.next_block_octet();
}
self.next_block_netascii()
}
fn next_block_octet(&mut self) -> std::io::Result<Option<Vec<u8>>> {
let mut buf = [0u8; BLOCK_SIZE];
let n = self.reader.read(&mut buf)?;
if n == 0 {
match self.phase {
SourcePhase::Start | SourcePhase::AfterFullBlock => {
self.phase = SourcePhase::Done;
return Ok(Some(Vec::new()));
}
SourcePhase::Done => return Ok(None),
}
}
if n < BLOCK_SIZE {
self.phase = SourcePhase::Done;
Ok(Some(buf[..n].to_vec()))
} else {
self.phase = SourcePhase::AfterFullBlock;
Ok(Some(buf.to_vec()))
}
}
fn next_block_netascii(&mut self) -> std::io::Result<Option<Vec<u8>>> {
while !self.source_eof && self.pending.len() < BLOCK_SIZE {
let mut in_buf = [0u8; 4096];
let n = self.reader.read(&mut in_buf)?;
if n == 0 {
self.source_eof = true;
break;
}
let mut encoded = Vec::with_capacity(n * 2);
self.encoder.encode_chunk(&in_buf[..n], &mut encoded);
self.pending.extend(encoded);
}
if self.pending.is_empty() {
match self.phase {
SourcePhase::Start | SourcePhase::AfterFullBlock => {
self.phase = SourcePhase::Done;
return Ok(Some(Vec::new()));
}
SourcePhase::Done => return Ok(None),
}
}
if self.pending.len() < BLOCK_SIZE {
self.phase = SourcePhase::Done;
let block: Vec<u8> = self.pending.drain(..).collect();
return Ok(Some(block));
}
self.phase = SourcePhase::AfterFullBlock;
let block: Vec<u8> = self.pending.drain(..BLOCK_SIZE).collect();
Ok(Some(block))
}
}
/// Writes received DATA payload blocks to an output, applying netascii translation.
pub(crate) enum DataSink<'a, W: Write> {
Octet(&'a mut W),
Netascii {
decoder: NetasciiDecoder,
output: &'a mut W,
scratch: Vec<u8>,
},
}
impl<'a, W: Write> DataSink<'a, W> {
#[must_use]
pub(crate) fn new(output: &'a mut W, mode: Mode) -> Self {
match mode {
Mode::Octet => Self::Octet(output),
Mode::NetAscii | Mode::Mail => Self::Netascii {
decoder: NetasciiDecoder::new(),
output,
scratch: Vec::new(),
},
}
}
pub(crate) fn write_data(&mut self, data: &[u8]) -> Result<(), DataSinkError> {
match self {
Self::Octet(out) => out.write_all(data).map_err(DataSinkError::Io),
Self::Netascii {
decoder,
output,
scratch,
} => {
scratch.clear();
decoder
.decode_chunk(data, scratch)
.map_err(DataSinkError::Netascii)?;
output.write_all(scratch).map_err(DataSinkError::Io)
}
}
}
pub(crate) fn finish(self) -> Result<(), DataSinkError> {
match self {
Self::Octet(_) => Ok(()),
Self::Netascii { decoder, .. } => decoder.finish().map_err(DataSinkError::Netascii),
}
}
}
#[derive(Debug)]
pub(crate) enum DataSinkError {
Io(std::io::Error),
Netascii(pfs_tftp_proto::netascii::DecodeError),
}