feat: add sync TFTP client/server library
This commit is contained in:
@@ -1,20 +1,38 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{Read, Write},
|
||||
net::SocketAddr,
|
||||
path::Path,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use pfs_tftp_proto::packet::{BLOCK_SIZE, ErrorCode, MAX_PACKET_SIZE, Packet, Request};
|
||||
|
||||
use crate::{
|
||||
Mode,
|
||||
error::{Error, Result},
|
||||
util,
|
||||
};
|
||||
|
||||
/// Configuration for a synchronous TFTP client.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClientConfig {
|
||||
pub timeout: Duration,
|
||||
pub retries: u32,
|
||||
pub dally_timeout: Duration,
|
||||
pub dally_retries: u32,
|
||||
}
|
||||
|
||||
impl Default for ClientConfig {
|
||||
fn default() -> Self {
|
||||
let timeout = Duration::from_secs(5);
|
||||
Self {
|
||||
timeout: Duration::from_secs(5),
|
||||
timeout,
|
||||
retries: 5,
|
||||
dally_timeout: timeout,
|
||||
dally_retries: 2,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -41,4 +59,428 @@ impl Client {
|
||||
pub fn config(&self) -> &ClientConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Downloads `remote_filename` from the configured server into `output`.
|
||||
///
|
||||
/// RFC reference:
|
||||
/// - Packet formats: RFC 1350, Section 5 (Figures 5-1..5-4)
|
||||
/// - Lock-step ACK/DATA: RFC 1350, Section 2
|
||||
/// - Termination (final <512 DATA): RFC 1350, Section 6
|
||||
/// - Transfer identifiers (TID): RFC 1350, Section 4 (page 4)
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error on I/O failures, protocol decode errors, timeouts after
|
||||
/// exhausting retries, or if the peer sends an ERROR packet.
|
||||
pub fn download_to_writer(
|
||||
&self,
|
||||
remote_filename: &str,
|
||||
mode: Mode,
|
||||
output: &mut impl Write,
|
||||
) -> Result<()> {
|
||||
let socket = util::bind_ephemeral_for(self.server)?;
|
||||
socket.set_read_timeout(Some(self.config.timeout))?;
|
||||
Self::validate_mode(mode)?;
|
||||
|
||||
let mut sink = util::DataSink::new(output, mode);
|
||||
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||
|
||||
let (peer, first) = self.rrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?;
|
||||
let Packet::Data { block, data } = first else {
|
||||
return Err(Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"unexpected packet type",
|
||||
)));
|
||||
};
|
||||
if block != 1 {
|
||||
util::send_error(
|
||||
&socket,
|
||||
peer,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected block number",
|
||||
);
|
||||
return Err(Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"unexpected block number",
|
||||
)));
|
||||
}
|
||||
|
||||
sink.write_data(&data).map_err(map_sink_error)?;
|
||||
let ack_bytes = Packet::Ack { block }.encode();
|
||||
socket.send_to(&ack_bytes, peer)?;
|
||||
let mut last_ack = Some(ack_bytes.clone());
|
||||
|
||||
let mut last_sent = ack_bytes;
|
||||
let mut expected_block = 2u16;
|
||||
|
||||
if data.len() < BLOCK_SIZE {
|
||||
sink.finish().map_err(map_sink_error)?;
|
||||
self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut attempts = 0u32;
|
||||
loop {
|
||||
let n = self.recv_from_peer_with_retries(
|
||||
&socket,
|
||||
&mut recv_buf,
|
||||
&mut attempts,
|
||||
"ACK",
|
||||
&last_sent,
|
||||
peer,
|
||||
)?;
|
||||
let pkt = Packet::decode(&recv_buf[..n])?;
|
||||
match pkt {
|
||||
Packet::Error { code, message } => {
|
||||
return Err(Error::RemoteError { code, message });
|
||||
}
|
||||
Packet::Data { block, data } if block == expected_block => {
|
||||
sink.write_data(&data).map_err(map_sink_error)?;
|
||||
let ack = Packet::Ack { block }.encode();
|
||||
socket.send_to(&ack, peer)?;
|
||||
last_ack = Some(ack.clone());
|
||||
last_sent = ack;
|
||||
if data.len() < BLOCK_SIZE {
|
||||
sink.finish().map_err(map_sink_error)?;
|
||||
self.dally_final_ack(&socket, peer, block, last_ack.as_deref())?;
|
||||
return Ok(());
|
||||
}
|
||||
expected_block = expected_block.wrapping_add(1);
|
||||
}
|
||||
Packet::Data { block, .. } if block == expected_block.wrapping_sub(1) => {
|
||||
if let Some(ack) = &last_ack {
|
||||
socket.send_to(ack, peer)?;
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
util::send_error(
|
||||
&socket,
|
||||
peer,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected packet type",
|
||||
);
|
||||
return Err(Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"unexpected packet type",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads `remote_filename` from the configured server into `local_path`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns any error from [`Self::download_to_writer`] or filesystem I/O.
|
||||
pub fn get(&self, remote_filename: &str, local_path: &Path, mode: Mode) -> Result<()> {
|
||||
let mut file = File::create(local_path)?;
|
||||
self.download_to_writer(remote_filename, mode, &mut file)
|
||||
}
|
||||
|
||||
/// Uploads `input` to the configured server as `remote_filename`.
|
||||
///
|
||||
/// RFC reference:
|
||||
/// - WRQ/ACK(0) handshake: RFC 1350, Section 4 (page 4) and Section 5 (ACK block=0)
|
||||
/// - Lock-step ACK/DATA: RFC 1350, Section 2
|
||||
/// - Termination (final <512 DATA): RFC 1350, Section 6
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error on I/O failures, protocol decode errors, timeouts after
|
||||
/// exhausting retries, or if the peer sends an ERROR packet.
|
||||
pub fn upload_from_reader(
|
||||
&self,
|
||||
remote_filename: &str,
|
||||
mode: Mode,
|
||||
input: &mut impl Read,
|
||||
) -> Result<()> {
|
||||
let socket = util::bind_ephemeral_for(self.server)?;
|
||||
socket.set_read_timeout(Some(self.config.timeout))?;
|
||||
Self::validate_mode(mode)?;
|
||||
|
||||
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||
let peer = self.wrq_handshake(&socket, &mut recv_buf, remote_filename, mode)?;
|
||||
|
||||
let mut source = util::DataSource::new(input, mode);
|
||||
let mut next_block: u16 = 1;
|
||||
let mut attempts = 0u32;
|
||||
while let Some(data) = source.next_block()? {
|
||||
let data_bytes = Packet::Data {
|
||||
block: next_block,
|
||||
data,
|
||||
}
|
||||
.encode();
|
||||
socket.send_to(&data_bytes, peer)?;
|
||||
|
||||
let expected_ack = next_block;
|
||||
let last_sent = data_bytes;
|
||||
let last_sent_kind = "DATA";
|
||||
|
||||
loop {
|
||||
let n = self.recv_from_peer_with_retries(
|
||||
&socket,
|
||||
&mut recv_buf,
|
||||
&mut attempts,
|
||||
last_sent_kind,
|
||||
&last_sent,
|
||||
peer,
|
||||
)?;
|
||||
match Packet::decode(&recv_buf[..n])? {
|
||||
Packet::Ack { block } if block == expected_ack => break,
|
||||
Packet::Ack { block } if block == expected_ack.wrapping_sub(1) => {
|
||||
socket.send_to(&last_sent, peer)?;
|
||||
}
|
||||
Packet::Error { code, message } => {
|
||||
return Err(Error::RemoteError { code, message });
|
||||
}
|
||||
_ => {
|
||||
util::send_error(
|
||||
&socket,
|
||||
peer,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected packet type",
|
||||
);
|
||||
return Err(Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"unexpected packet type",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
attempts = 0;
|
||||
next_block = next_block.wrapping_add(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads `local_path` to the configured server as `remote_filename`.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns any error from [`Self::upload_from_reader`] or filesystem I/O.
|
||||
pub fn put(&self, local_path: &Path, remote_filename: &str, mode: Mode) -> Result<()> {
|
||||
let mut file = File::open(local_path)?;
|
||||
self.upload_from_reader(remote_filename, mode, &mut file)
|
||||
}
|
||||
|
||||
fn dally_final_ack(
|
||||
&self,
|
||||
socket: &std::net::UdpSocket,
|
||||
peer: SocketAddr,
|
||||
final_block: u16,
|
||||
ack_bytes: Option<&[u8]>,
|
||||
) -> Result<()> {
|
||||
if self.config.dally_retries == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let Some(ack_bytes) = ack_bytes else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
socket.set_read_timeout(Some(self.config.dally_timeout))?;
|
||||
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||
let mut timeouts = 0u32;
|
||||
|
||||
// RFC 1350, Section 6: "dallying is encouraged" to retransmit the final ACK
|
||||
// if the last DATA is retransmitted.
|
||||
while timeouts < self.config.dally_retries {
|
||||
match socket.recv_from(&mut recv_buf) {
|
||||
Ok((n, from)) => {
|
||||
if from != peer {
|
||||
util::send_error(
|
||||
socket,
|
||||
from,
|
||||
ErrorCode::UnknownTransferId,
|
||||
"unknown transfer id",
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if let Ok(Packet::Data { block, .. }) = Packet::decode(&recv_buf[..n])
|
||||
&& block == final_block
|
||||
{
|
||||
let _ignored = socket.send_to(ack_bytes, peer);
|
||||
}
|
||||
}
|
||||
Err(e) if util::is_timeout(&e) => timeouts += 1,
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_mode(mode: Mode) -> Result<()> {
|
||||
if mode == Mode::Mail {
|
||||
return Err(Error::UnsupportedMode(mode));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn rrq_handshake(
|
||||
&self,
|
||||
socket: &std::net::UdpSocket,
|
||||
recv_buf: &mut [u8],
|
||||
remote_filename: &str,
|
||||
mode: Mode,
|
||||
) -> Result<(SocketAddr, Packet)> {
|
||||
let rrq_bytes = Packet::Rrq(Request {
|
||||
filename: remote_filename.to_string(),
|
||||
mode,
|
||||
})
|
||||
.encode();
|
||||
socket.send_to(&rrq_bytes, self.server)?;
|
||||
|
||||
let mut attempts = 0u32;
|
||||
loop {
|
||||
let (n, from) = self.recv_from_with_retries(
|
||||
socket,
|
||||
recv_buf,
|
||||
&mut attempts,
|
||||
"RRQ",
|
||||
&rrq_bytes,
|
||||
self.server,
|
||||
)?;
|
||||
if from.ip() != self.server.ip() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let pkt = Packet::decode(&recv_buf[..n])?;
|
||||
match pkt {
|
||||
Packet::Data { .. } => return Ok((from, pkt)),
|
||||
Packet::Error { code, message } => {
|
||||
return Err(Error::RemoteError { code, message });
|
||||
}
|
||||
_ => {
|
||||
util::send_error(
|
||||
socket,
|
||||
from,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected packet",
|
||||
);
|
||||
return Err(Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"unexpected packet",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wrq_handshake(
|
||||
&self,
|
||||
socket: &std::net::UdpSocket,
|
||||
recv_buf: &mut [u8],
|
||||
remote_filename: &str,
|
||||
mode: Mode,
|
||||
) -> Result<SocketAddr> {
|
||||
let wrq_bytes = Packet::Wrq(Request {
|
||||
filename: remote_filename.to_string(),
|
||||
mode,
|
||||
})
|
||||
.encode();
|
||||
socket.send_to(&wrq_bytes, self.server)?;
|
||||
|
||||
let mut attempts = 0u32;
|
||||
loop {
|
||||
let (n, from) = self.recv_from_with_retries(
|
||||
socket,
|
||||
recv_buf,
|
||||
&mut attempts,
|
||||
"WRQ",
|
||||
&wrq_bytes,
|
||||
self.server,
|
||||
)?;
|
||||
if from.ip() != self.server.ip() {
|
||||
continue;
|
||||
}
|
||||
match Packet::decode(&recv_buf[..n])? {
|
||||
Packet::Ack { block: 0 } => return Ok(from),
|
||||
Packet::Error { code, message } => {
|
||||
return Err(Error::RemoteError { code, message });
|
||||
}
|
||||
_ => {
|
||||
util::send_error(
|
||||
socket,
|
||||
from,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected packet",
|
||||
);
|
||||
return Err(Error::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"unexpected packet",
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn recv_from_with_retries(
|
||||
&self,
|
||||
socket: &std::net::UdpSocket,
|
||||
recv_buf: &mut [u8],
|
||||
attempts: &mut u32,
|
||||
last_sent_kind: &'static str,
|
||||
last_sent: &[u8],
|
||||
resend_to: SocketAddr,
|
||||
) -> Result<(usize, SocketAddr)> {
|
||||
loop {
|
||||
match socket.recv_from(recv_buf) {
|
||||
Ok(v) => {
|
||||
*attempts = 0;
|
||||
return Ok(v);
|
||||
}
|
||||
Err(e) if util::is_timeout(&e) => {
|
||||
*attempts += 1;
|
||||
if *attempts > self.config.retries {
|
||||
return Err(Error::Timeout {
|
||||
last_packet: last_sent_kind,
|
||||
attempts: *attempts,
|
||||
});
|
||||
}
|
||||
socket.send_to(last_sent, resend_to)?;
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn recv_from_peer_with_retries(
|
||||
&self,
|
||||
socket: &std::net::UdpSocket,
|
||||
recv_buf: &mut [u8],
|
||||
attempts: &mut u32,
|
||||
last_sent_kind: &'static str,
|
||||
last_sent: &[u8],
|
||||
peer: SocketAddr,
|
||||
) -> Result<usize> {
|
||||
loop {
|
||||
let (n, from) = self.recv_from_with_retries(
|
||||
socket,
|
||||
recv_buf,
|
||||
attempts,
|
||||
last_sent_kind,
|
||||
last_sent,
|
||||
peer,
|
||||
)?;
|
||||
if from == peer {
|
||||
return Ok(n);
|
||||
}
|
||||
|
||||
// RFC 1350, Section 4 (page 4): incorrect source port => Unknown transfer ID.
|
||||
if from.ip() == peer.ip() {
|
||||
util::send_error(
|
||||
socket,
|
||||
from,
|
||||
ErrorCode::UnknownTransferId,
|
||||
"unknown transfer id",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn map_sink_error(err: util::DataSinkError) -> Error {
|
||||
match err {
|
||||
util::DataSinkError::Io(e) => Error::Io(e),
|
||||
util::DataSinkError::Netascii(e) => Error::Netascii(e),
|
||||
}
|
||||
}
|
||||
|
||||
76
crates/pfs-tftp-sync/src/error.rs
Normal file
76
crates/pfs-tftp-sync/src/error.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
use core::fmt;
|
||||
|
||||
use pfs_tftp_proto::{netascii, packet};
|
||||
|
||||
/// Errors returned by the synchronous client APIs.
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Io(std::io::Error),
|
||||
Protocol(packet::DecodeError),
|
||||
Netascii(netascii::DecodeError),
|
||||
UnsupportedMode(packet::Mode),
|
||||
RemoteError {
|
||||
code: packet::ErrorCode,
|
||||
message: String,
|
||||
},
|
||||
Timeout {
|
||||
last_packet: &'static str,
|
||||
attempts: u32,
|
||||
},
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Io(e) => write!(f, "{e}"),
|
||||
Self::Protocol(e) => write!(f, "{e}"),
|
||||
Self::Netascii(e) => write!(f, "{e}"),
|
||||
Self::UnsupportedMode(mode) => {
|
||||
write!(f, "unsupported transfer mode: {}", mode.as_str())
|
||||
}
|
||||
Self::RemoteError { code, message } => {
|
||||
write!(f, "remote error {code:?}: {message}")
|
||||
}
|
||||
Self::Timeout {
|
||||
last_packet,
|
||||
attempts,
|
||||
} => write!(
|
||||
f,
|
||||
"timed out waiting for response after {attempts} attempts (last sent: {last_packet})"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Io(e) => Some(e),
|
||||
Self::Protocol(e) => Some(e),
|
||||
Self::Netascii(e) => Some(e),
|
||||
Self::UnsupportedMode(_) | Self::RemoteError { .. } | Self::Timeout { .. } => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::io::Error> for Error {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self::Io(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<packet::DecodeError> for Error {
|
||||
fn from(value: packet::DecodeError) -> Self {
|
||||
Self::Protocol(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<netascii::DecodeError> for Error {
|
||||
fn from(value: netascii::DecodeError) -> Self {
|
||||
Self::Netascii(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = core::result::Result<T, Error>;
|
||||
@@ -6,8 +6,11 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
pub mod client;
|
||||
pub mod error;
|
||||
pub mod server;
|
||||
pub mod util;
|
||||
mod util;
|
||||
|
||||
pub use client::{Client, ClientConfig};
|
||||
pub use error::{Error, Result};
|
||||
pub use pfs_tftp_proto::packet::Mode;
|
||||
pub use server::{Server, ServerConfig};
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
fs::{File, OpenOptions},
|
||||
net::SocketAddr,
|
||||
path::{Path, PathBuf},
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use pfs_tftp_proto::packet::{BLOCK_SIZE, ErrorCode, MAX_PACKET_SIZE, Packet};
|
||||
|
||||
use crate::{Mode, util};
|
||||
|
||||
/// Configuration for a synchronous TFTP server.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -13,17 +21,24 @@ pub struct ServerConfig {
|
||||
pub overwrite: bool,
|
||||
pub timeout: Duration,
|
||||
pub retries: u32,
|
||||
pub dally_timeout: Duration,
|
||||
pub dally_retries: u32,
|
||||
pub max_request_size: usize,
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
let timeout = Duration::from_secs(5);
|
||||
Self {
|
||||
bind: SocketAddr::from(([0, 0, 0, 0], 6969)),
|
||||
root: PathBuf::from("."),
|
||||
allow_write: false,
|
||||
overwrite: false,
|
||||
timeout: Duration::from_secs(5),
|
||||
timeout,
|
||||
retries: 5,
|
||||
dally_timeout: timeout,
|
||||
dally_retries: 2,
|
||||
max_request_size: util::MAX_REQUEST_SIZE,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,4 +59,452 @@ impl Server {
|
||||
pub fn config(&self) -> &ServerConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Serves requests forever.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the server socket cannot be bound or configured.
|
||||
pub fn serve(&self) -> std::io::Result<()> {
|
||||
self.serve_inner(None)
|
||||
}
|
||||
|
||||
/// Serves requests until `shutdown` becomes `true`.
|
||||
///
|
||||
/// This is mainly intended for tests.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the server socket cannot be bound or configured.
|
||||
pub fn serve_until(&self, shutdown: &AtomicBool) -> std::io::Result<()> {
|
||||
self.serve_inner(Some(shutdown))
|
||||
}
|
||||
|
||||
fn serve_inner(&self, shutdown: Option<&AtomicBool>) -> std::io::Result<()> {
|
||||
let socket = std::net::UdpSocket::bind(self.config.bind)?;
|
||||
if shutdown.is_some() {
|
||||
socket.set_read_timeout(Some(Duration::from_millis(200)))?;
|
||||
}
|
||||
|
||||
let mut buf = vec![0u8; self.config.max_request_size];
|
||||
|
||||
loop {
|
||||
if let Some(flag) = shutdown
|
||||
&& flag.load(Ordering::Relaxed)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (n, peer) = match socket.recv_from(&mut buf) {
|
||||
Ok(v) => v,
|
||||
Err(e) if shutdown.is_some() && util::is_timeout(&e) => continue,
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
let Ok(pkt) = Packet::decode(&buf[..n]) else {
|
||||
util::send_error(
|
||||
&socket,
|
||||
peer,
|
||||
ErrorCode::IllegalOperation,
|
||||
"malformed packet",
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
match pkt {
|
||||
Packet::Rrq(req) => {
|
||||
let cfg = self.config.clone();
|
||||
std::thread::spawn(move || {
|
||||
let _ignored = handle_rrq(cfg, peer, req);
|
||||
});
|
||||
}
|
||||
Packet::Wrq(req) => {
|
||||
let cfg = self.config.clone();
|
||||
std::thread::spawn(move || {
|
||||
let _ignored = handle_wrq(cfg, peer, req);
|
||||
});
|
||||
}
|
||||
Packet::Data { .. } | Packet::Ack { .. } | Packet::Error { .. } => {
|
||||
// RFC 1350, Section 4 (page 4): incorrect source port => Unknown transfer ID.
|
||||
util::send_error(
|
||||
&socket,
|
||||
peer,
|
||||
ErrorCode::UnknownTransferId,
|
||||
"unknown transfer id",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_rrq(
|
||||
cfg: ServerConfig,
|
||||
client: SocketAddr,
|
||||
req: pfs_tftp_proto::packet::Request,
|
||||
) -> std::io::Result<()> {
|
||||
let pfs_tftp_proto::packet::Request { filename, mode } = req;
|
||||
let ServerConfig {
|
||||
bind,
|
||||
root,
|
||||
timeout,
|
||||
retries,
|
||||
..
|
||||
} = cfg;
|
||||
|
||||
if mode == Mode::Mail {
|
||||
let socket = util::bind_ephemeral_on_ip(bind.ip())?;
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
ErrorCode::NotDefined,
|
||||
"mail mode is obsolete",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let socket = util::bind_ephemeral_on_ip(bind.ip())?;
|
||||
socket.set_read_timeout(Some(timeout))?;
|
||||
|
||||
let Some(path) = resolve_path(&root, &filename) else {
|
||||
util::send_error(&socket, client, ErrorCode::AccessViolation, "invalid path");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut file = match File::open(&path) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
map_io_error_code(&e),
|
||||
"failed to open file",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let mut source = util::DataSource::new(&mut file, mode);
|
||||
let mut block: u16 = 1;
|
||||
|
||||
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||
let mut attempts = 0u32;
|
||||
|
||||
while let Some(data) = source.next_block()? {
|
||||
let pkt = Packet::Data { block, data };
|
||||
let bytes = pkt.encode();
|
||||
socket.send_to(&bytes, client)?;
|
||||
|
||||
let last_sent = bytes;
|
||||
let expected_ack = block;
|
||||
|
||||
loop {
|
||||
let (n, from) = match socket.recv_from(&mut recv_buf) {
|
||||
Ok(v) => v,
|
||||
Err(e) if util::is_timeout(&e) => {
|
||||
attempts += 1;
|
||||
if attempts > retries {
|
||||
return Ok(());
|
||||
}
|
||||
socket.send_to(&last_sent, client)?;
|
||||
continue;
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
if from != client {
|
||||
util::send_error(
|
||||
&socket,
|
||||
from,
|
||||
ErrorCode::UnknownTransferId,
|
||||
"unknown transfer id",
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
attempts = 0;
|
||||
match Packet::decode(&recv_buf[..n]) {
|
||||
Ok(Packet::Ack { block }) if block == expected_ack => break,
|
||||
Ok(Packet::Ack { block }) if block == expected_ack.wrapping_sub(1) => {
|
||||
// Duplicate ACK, retransmit last DATA.
|
||||
socket.send_to(&last_sent, client)?;
|
||||
}
|
||||
Ok(Packet::Error { .. }) => return Ok(()),
|
||||
Ok(_) | Err(_) => {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected packet",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RFC 1350, Section 6: final DATA has <512 bytes.
|
||||
if last_sent.len() < 4 + BLOCK_SIZE {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
block = block.wrapping_add(1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn handle_wrq(
|
||||
cfg: ServerConfig,
|
||||
client: SocketAddr,
|
||||
req: pfs_tftp_proto::packet::Request,
|
||||
) -> std::io::Result<()> {
|
||||
let pfs_tftp_proto::packet::Request { filename, mode } = req;
|
||||
let ServerConfig {
|
||||
bind,
|
||||
root,
|
||||
allow_write,
|
||||
overwrite,
|
||||
timeout,
|
||||
retries,
|
||||
dally_timeout,
|
||||
dally_retries,
|
||||
..
|
||||
} = cfg;
|
||||
|
||||
let socket = util::bind_ephemeral_on_ip(bind.ip())?;
|
||||
socket.set_read_timeout(Some(timeout))?;
|
||||
|
||||
if mode == Mode::Mail {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
ErrorCode::NotDefined,
|
||||
"mail mode is obsolete",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !allow_write {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
ErrorCode::AccessViolation,
|
||||
"writes disabled",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(path) = resolve_path(&root, &filename) else {
|
||||
util::send_error(&socket, client, ErrorCode::AccessViolation, "invalid path");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut file = match open_for_write(&path, overwrite) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
map_io_error_code(&e),
|
||||
"failed to create file",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// RFC 1350, Section 5: WRQ is acknowledged by ACK(block=0).
|
||||
let ack0 = Packet::Ack { block: 0 }.encode();
|
||||
socket.send_to(&ack0, client)?;
|
||||
|
||||
let mut last_ack = ack0;
|
||||
let mut expected_block: u16 = 1;
|
||||
let mut attempts = 0u32;
|
||||
|
||||
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||
let mut sink = util::DataSink::new(&mut file, mode);
|
||||
|
||||
loop {
|
||||
let (n, from) = match socket.recv_from(&mut recv_buf) {
|
||||
Ok(v) => v,
|
||||
Err(e) if util::is_timeout(&e) => {
|
||||
attempts += 1;
|
||||
if attempts > retries {
|
||||
return Ok(());
|
||||
}
|
||||
socket.send_to(&last_ack, client)?;
|
||||
continue;
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
if from != client {
|
||||
util::send_error(
|
||||
&socket,
|
||||
from,
|
||||
ErrorCode::UnknownTransferId,
|
||||
"unknown transfer id",
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
attempts = 0;
|
||||
let Ok(pkt) = Packet::decode(&recv_buf[..n]) else {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
ErrorCode::IllegalOperation,
|
||||
"malformed packet",
|
||||
);
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
match pkt {
|
||||
Packet::Error { .. } => return Ok(()),
|
||||
Packet::Data { block, data } => {
|
||||
if block == expected_block {
|
||||
if let Err(e) = sink.write_data(&data) {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
map_sink_error_code(&e),
|
||||
"failed to write data",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let ack = Packet::Ack { block }.encode();
|
||||
socket.send_to(&ack, client)?;
|
||||
last_ack = ack;
|
||||
|
||||
if data.len() < BLOCK_SIZE {
|
||||
if let Err(e) = sink.finish() {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
map_sink_error_code(&e),
|
||||
"failed to finish transfer",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
dally_final_ack(
|
||||
dally_timeout,
|
||||
dally_retries,
|
||||
&socket,
|
||||
client,
|
||||
block,
|
||||
&last_ack,
|
||||
)?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
expected_block = expected_block.wrapping_add(1);
|
||||
} else if block == expected_block.wrapping_sub(1) {
|
||||
// Duplicate DATA (ACK lost); re-ACK.
|
||||
socket.send_to(&last_ack, client)?;
|
||||
} else {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected block number",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Packet::Ack { .. } | Packet::Rrq(_) | Packet::Wrq(_) => {
|
||||
util::send_error(
|
||||
&socket,
|
||||
client,
|
||||
ErrorCode::IllegalOperation,
|
||||
"unexpected packet type",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_path(root: &Path, filename: &str) -> Option<PathBuf> {
|
||||
let rel = Path::new(filename);
|
||||
if rel.is_absolute() {
|
||||
return None;
|
||||
}
|
||||
for c in rel.components() {
|
||||
match c {
|
||||
std::path::Component::Normal(_) => {}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
Some(root.join(rel))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn map_io_error_code(err: &std::io::Error) -> ErrorCode {
|
||||
match err.kind() {
|
||||
std::io::ErrorKind::NotFound => ErrorCode::FileNotFound,
|
||||
std::io::ErrorKind::PermissionDenied => ErrorCode::AccessViolation,
|
||||
std::io::ErrorKind::AlreadyExists => ErrorCode::FileAlreadyExists,
|
||||
std::io::ErrorKind::StorageFull => ErrorCode::DiskFull,
|
||||
_ => ErrorCode::NotDefined,
|
||||
}
|
||||
}
|
||||
|
||||
fn open_for_write(path: &Path, overwrite: bool) -> std::io::Result<File> {
|
||||
let mut opts = OpenOptions::new();
|
||||
opts.write(true);
|
||||
if overwrite {
|
||||
opts.create(true).truncate(true);
|
||||
} else {
|
||||
opts.create_new(true);
|
||||
}
|
||||
opts.open(path)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn map_sink_error_code(err: &util::DataSinkError) -> ErrorCode {
|
||||
match err {
|
||||
util::DataSinkError::Io(e) => map_io_error_code(e),
|
||||
// Invalid netascii is a protocol violation.
|
||||
util::DataSinkError::Netascii(_) => ErrorCode::IllegalOperation,
|
||||
}
|
||||
}
|
||||
|
||||
fn dally_final_ack(
|
||||
dally_timeout: Duration,
|
||||
dally_retries: u32,
|
||||
socket: &std::net::UdpSocket,
|
||||
peer: SocketAddr,
|
||||
final_block: u16,
|
||||
ack_bytes: &[u8],
|
||||
) -> std::io::Result<()> {
|
||||
if dally_retries == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
socket.set_read_timeout(Some(dally_timeout))?;
|
||||
let mut recv_buf = [0u8; MAX_PACKET_SIZE];
|
||||
let mut timeouts = 0u32;
|
||||
|
||||
// RFC 1350, Section 6: dallying for final ACK retransmission.
|
||||
while timeouts < dally_retries {
|
||||
match socket.recv_from(&mut recv_buf) {
|
||||
Ok((n, from)) => {
|
||||
if from != peer {
|
||||
util::send_error(
|
||||
socket,
|
||||
from,
|
||||
ErrorCode::UnknownTransferId,
|
||||
"unknown transfer id",
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if let Ok(Packet::Data { block, .. }) = Packet::decode(&recv_buf[..n])
|
||||
&& block == final_block
|
||||
{
|
||||
let _ignored = socket.send_to(ack_bytes, peer);
|
||||
}
|
||||
}
|
||||
Err(e) if util::is_timeout(&e) => timeouts += 1,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,4 +1,204 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
// Misc sync helpers live here.
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
io::{Read, Write},
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
|
||||
};
|
||||
|
||||
use pfs_tftp_proto::{
|
||||
netascii::{NetasciiDecoder, NetasciiEncoder},
|
||||
packet::{BLOCK_SIZE, ErrorCode, Mode, Packet},
|
||||
};
|
||||
|
||||
pub(crate) const MAX_REQUEST_SIZE: usize = 2048;
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn wildcard_addr_for(peer: SocketAddr) -> SocketAddr {
|
||||
match peer.ip() {
|
||||
IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn bind_ephemeral_for(peer: SocketAddr) -> std::io::Result<UdpSocket> {
|
||||
UdpSocket::bind(wildcard_addr_for(peer))
|
||||
}
|
||||
|
||||
pub(crate) fn bind_ephemeral_on_ip(ip: IpAddr) -> std::io::Result<UdpSocket> {
|
||||
UdpSocket::bind(SocketAddr::new(ip, 0))
|
||||
}
|
||||
|
||||
pub(crate) fn is_timeout(err: &std::io::Error) -> bool {
|
||||
matches!(
|
||||
err.kind(),
|
||||
std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn send_error(socket: &UdpSocket, peer: SocketAddr, code: ErrorCode, message: &str) {
|
||||
// RFC 1350, Section 7: ERROR packets are not acknowledged nor retransmitted.
|
||||
// Best effort.
|
||||
let pkt = Packet::Error {
|
||||
code,
|
||||
message: message.to_string(),
|
||||
};
|
||||
let _ignored = socket.send_to(&pkt.encode(), peer);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum SourcePhase {
|
||||
Start,
|
||||
AfterFullBlock,
|
||||
Done,
|
||||
}
|
||||
|
||||
/// Produces 512-byte DATA payload blocks according to RFC 1350, Section 2 and Section 6.
|
||||
///
|
||||
/// RFC 1350, Section 6: The end of a transfer is signaled by a DATA packet with
|
||||
/// 0..511 bytes. If the payload stream ends exactly on a 512-byte boundary, an
|
||||
/// additional zero-length DATA packet must be sent.
|
||||
pub(crate) struct DataSource<'a, R: Read> {
|
||||
reader: &'a mut R,
|
||||
mode: Mode,
|
||||
encoder: NetasciiEncoder,
|
||||
pending: VecDeque<u8>,
|
||||
source_eof: bool,
|
||||
phase: SourcePhase,
|
||||
}
|
||||
|
||||
impl<'a, R: Read> DataSource<'a, R> {
|
||||
#[must_use]
|
||||
pub(crate) fn new(reader: &'a mut R, mode: Mode) -> Self {
|
||||
Self {
|
||||
reader,
|
||||
mode,
|
||||
encoder: NetasciiEncoder::new(),
|
||||
pending: VecDeque::new(),
|
||||
source_eof: false,
|
||||
phase: SourcePhase::Start,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn next_block(&mut self) -> std::io::Result<Option<Vec<u8>>> {
|
||||
if self.phase == SourcePhase::Done {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if self.mode == Mode::Octet {
|
||||
return self.next_block_octet();
|
||||
}
|
||||
|
||||
self.next_block_netascii()
|
||||
}
|
||||
|
||||
fn next_block_octet(&mut self) -> std::io::Result<Option<Vec<u8>>> {
|
||||
let mut buf = [0u8; BLOCK_SIZE];
|
||||
let n = self.reader.read(&mut buf)?;
|
||||
if n == 0 {
|
||||
match self.phase {
|
||||
SourcePhase::Start | SourcePhase::AfterFullBlock => {
|
||||
self.phase = SourcePhase::Done;
|
||||
return Ok(Some(Vec::new()));
|
||||
}
|
||||
SourcePhase::Done => return Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
if n < BLOCK_SIZE {
|
||||
self.phase = SourcePhase::Done;
|
||||
Ok(Some(buf[..n].to_vec()))
|
||||
} else {
|
||||
self.phase = SourcePhase::AfterFullBlock;
|
||||
Ok(Some(buf.to_vec()))
|
||||
}
|
||||
}
|
||||
|
||||
fn next_block_netascii(&mut self) -> std::io::Result<Option<Vec<u8>>> {
|
||||
while !self.source_eof && self.pending.len() < BLOCK_SIZE {
|
||||
let mut in_buf = [0u8; 4096];
|
||||
let n = self.reader.read(&mut in_buf)?;
|
||||
if n == 0 {
|
||||
self.source_eof = true;
|
||||
break;
|
||||
}
|
||||
let mut encoded = Vec::with_capacity(n * 2);
|
||||
self.encoder.encode_chunk(&in_buf[..n], &mut encoded);
|
||||
self.pending.extend(encoded);
|
||||
}
|
||||
|
||||
if self.pending.is_empty() {
|
||||
match self.phase {
|
||||
SourcePhase::Start | SourcePhase::AfterFullBlock => {
|
||||
self.phase = SourcePhase::Done;
|
||||
return Ok(Some(Vec::new()));
|
||||
}
|
||||
SourcePhase::Done => return Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
if self.pending.len() < BLOCK_SIZE {
|
||||
self.phase = SourcePhase::Done;
|
||||
let block: Vec<u8> = self.pending.drain(..).collect();
|
||||
return Ok(Some(block));
|
||||
}
|
||||
|
||||
self.phase = SourcePhase::AfterFullBlock;
|
||||
let block: Vec<u8> = self.pending.drain(..BLOCK_SIZE).collect();
|
||||
Ok(Some(block))
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes received DATA payload blocks to an output, applying netascii translation.
|
||||
pub(crate) enum DataSink<'a, W: Write> {
|
||||
Octet(&'a mut W),
|
||||
Netascii {
|
||||
decoder: NetasciiDecoder,
|
||||
output: &'a mut W,
|
||||
scratch: Vec<u8>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'a, W: Write> DataSink<'a, W> {
|
||||
#[must_use]
|
||||
pub(crate) fn new(output: &'a mut W, mode: Mode) -> Self {
|
||||
match mode {
|
||||
Mode::Octet => Self::Octet(output),
|
||||
Mode::NetAscii | Mode::Mail => Self::Netascii {
|
||||
decoder: NetasciiDecoder::new(),
|
||||
output,
|
||||
scratch: Vec::new(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn write_data(&mut self, data: &[u8]) -> Result<(), DataSinkError> {
|
||||
match self {
|
||||
Self::Octet(out) => out.write_all(data).map_err(DataSinkError::Io),
|
||||
Self::Netascii {
|
||||
decoder,
|
||||
output,
|
||||
scratch,
|
||||
} => {
|
||||
scratch.clear();
|
||||
decoder
|
||||
.decode_chunk(data, scratch)
|
||||
.map_err(DataSinkError::Netascii)?;
|
||||
output.write_all(scratch).map_err(DataSinkError::Io)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn finish(self) -> Result<(), DataSinkError> {
|
||||
match self {
|
||||
Self::Octet(_) => Ok(()),
|
||||
Self::Netascii { decoder, .. } => decoder.finish().map_err(DataSinkError::Netascii),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum DataSinkError {
|
||||
Io(std::io::Error),
|
||||
Netascii(pfs_tftp_proto::netascii::DecodeError),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user