#![forbid(unsafe_code)] use std::{ collections::VecDeque, io::{Read, Write}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, }; use pfs_tftp_proto::{ netascii::{NetasciiDecoder, NetasciiEncoder}, packet::{BLOCK_SIZE, ErrorCode, Mode, Packet}, }; pub(crate) const MAX_REQUEST_SIZE: usize = 2048; #[must_use] pub(crate) fn wildcard_addr_for(peer: SocketAddr) -> SocketAddr { match peer.ip() { IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0), } } pub(crate) fn bind_ephemeral_for(peer: SocketAddr) -> std::io::Result { UdpSocket::bind(wildcard_addr_for(peer)) } pub(crate) fn bind_ephemeral_on_ip(ip: IpAddr) -> std::io::Result { UdpSocket::bind(SocketAddr::new(ip, 0)) } pub(crate) fn is_timeout(err: &std::io::Error) -> bool { matches!( err.kind(), std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock ) } pub(crate) fn send_error(socket: &UdpSocket, peer: SocketAddr, code: ErrorCode, message: &str) { // RFC 1350, Section 7: ERROR packets are not acknowledged nor retransmitted. // Best effort. let pkt = Packet::Error { code, message: message.to_string(), }; let _ignored = socket.send_to(&pkt.encode(), peer); } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SourcePhase { Start, AfterFullBlock, Done, } /// Produces 512-byte DATA payload blocks according to RFC 1350, Section 2 and Section 6. /// /// RFC 1350, Section 6: The end of a transfer is signaled by a DATA packet with /// 0..511 bytes. If the payload stream ends exactly on a 512-byte boundary, an /// additional zero-length DATA packet must be sent. pub(crate) struct DataSource<'a, R: Read> { reader: &'a mut R, mode: Mode, encoder: NetasciiEncoder, pending: VecDeque, source_eof: bool, phase: SourcePhase, } impl<'a, R: Read> DataSource<'a, R> { #[must_use] pub(crate) fn new(reader: &'a mut R, mode: Mode) -> Self { Self { reader, mode, encoder: NetasciiEncoder::new(), pending: VecDeque::new(), source_eof: false, phase: SourcePhase::Start, } } pub(crate) fn next_block(&mut self) -> std::io::Result>> { if self.phase == SourcePhase::Done { return Ok(None); } if self.mode == Mode::Octet { return self.next_block_octet(); } self.next_block_netascii() } fn next_block_octet(&mut self) -> std::io::Result>> { let mut buf = [0u8; BLOCK_SIZE]; let n = self.reader.read(&mut buf)?; if n == 0 { match self.phase { SourcePhase::Start | SourcePhase::AfterFullBlock => { self.phase = SourcePhase::Done; return Ok(Some(Vec::new())); } SourcePhase::Done => return Ok(None), } } if n < BLOCK_SIZE { self.phase = SourcePhase::Done; Ok(Some(buf[..n].to_vec())) } else { self.phase = SourcePhase::AfterFullBlock; Ok(Some(buf.to_vec())) } } fn next_block_netascii(&mut self) -> std::io::Result>> { while !self.source_eof && self.pending.len() < BLOCK_SIZE { let mut in_buf = [0u8; 4096]; let n = self.reader.read(&mut in_buf)?; if n == 0 { self.source_eof = true; break; } let mut encoded = Vec::with_capacity(n * 2); self.encoder.encode_chunk(&in_buf[..n], &mut encoded); self.pending.extend(encoded); } if self.pending.is_empty() { match self.phase { SourcePhase::Start | SourcePhase::AfterFullBlock => { self.phase = SourcePhase::Done; return Ok(Some(Vec::new())); } SourcePhase::Done => return Ok(None), } } if self.pending.len() < BLOCK_SIZE { self.phase = SourcePhase::Done; let block: Vec = self.pending.drain(..).collect(); return Ok(Some(block)); } self.phase = SourcePhase::AfterFullBlock; let block: Vec = self.pending.drain(..BLOCK_SIZE).collect(); Ok(Some(block)) } } /// Writes received DATA payload blocks to an output, applying netascii translation. pub(crate) enum DataSink<'a, W: Write> { Octet(&'a mut W), Netascii { decoder: NetasciiDecoder, output: &'a mut W, scratch: Vec, }, } impl<'a, W: Write> DataSink<'a, W> { #[must_use] pub(crate) fn new(output: &'a mut W, mode: Mode) -> Self { match mode { Mode::Octet => Self::Octet(output), Mode::NetAscii | Mode::Mail => Self::Netascii { decoder: NetasciiDecoder::new(), output, scratch: Vec::new(), }, } } pub(crate) fn write_data(&mut self, data: &[u8]) -> Result<(), DataSinkError> { match self { Self::Octet(out) => out.write_all(data).map_err(DataSinkError::Io), Self::Netascii { decoder, output, scratch, } => { scratch.clear(); decoder .decode_chunk(data, scratch) .map_err(DataSinkError::Netascii)?; output.write_all(scratch).map_err(DataSinkError::Io) } } } pub(crate) fn finish(self) -> Result<(), DataSinkError> { match self { Self::Octet(_) => Ok(()), Self::Netascii { decoder, .. } => decoder.finish().map_err(DataSinkError::Netascii), } } } #[derive(Debug)] pub(crate) enum DataSinkError { Io(std::io::Error), Netascii(pfs_tftp_proto::netascii::DecodeError), }