diff --git a/Cargo.lock b/Cargo.lock index 35de35d..46a30e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,6 +232,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.7" @@ -290,6 +305,7 @@ dependencies = [ "assert_cmd", "chacha20poly1305", "clap", + "crossbeam-channel", "getrandom 0.4.2", "libc", "rlimit", diff --git a/Cargo.toml b/Cargo.toml index b53d10b..ac830c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ version = "0.10.0" argon2 = "0.5" chacha20poly1305 = "0.10" clap = {version = "4", features = ["derive"]} +crossbeam-channel = "0.5" getrandom = {version = "0.4"} protected-secrets = {package = "secrets", version = "1.3"} zeroize = {version = "1", features = ["derive"]} diff --git a/TODO.md b/TODO.md deleted file mode 100644 index f9540eb..0000000 --- a/TODO.md +++ /dev/null @@ -1,3 +0,0 @@ -**Deferred to follow-up commits** (in order): -1. Multi-threaded pipeline (worker pool + ordered writer) -2. 
Length-committed mode + random-access decrypt fast path for files diff --git a/src/crypto.rs b/src/crypto.rs index 130d452..00a9515 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -1,21 +1,24 @@ // SPDX-License-Identifier: GPL-3.0-only use chacha20poly1305::{KeyInit, XChaCha20Poly1305, XNonce, aead::AeadInPlace}; -use std::io::Write; +use std::fs::File; +use std::io::{BufReader, Read, Seek, SeekFrom, Write}; +use std::sync::Arc; use crate::error::*; -use crate::header::{AlgId, Header, KdfParams, NONCE_PREFIX_LEN, TAG_LEN}; +use crate::header::{AlgId, FLAG_LENGTH_COMMITTED, Header, KdfParams, NONCE_PREFIX_LEN, TAG_LEN}; +use crate::pipeline; use crate::reader::{AheadReader, ReadInfoChunk}; use crate::secrets::{SecretBytes32, SecretVec}; use crate::utils::*; /// XChaCha20Poly1305 nonce: 24 bytes total. STREAM splits the trailing 5 bytes /// into a 4-byte big-endian counter and a 1-byte "last block" flag. -const NONCE_LEN: usize = 24; -const COUNTER_LEN: usize = 4; +pub(crate) const NONCE_LEN: usize = 24; +pub(crate) const COUNTER_LEN: usize = 4; const _: () = assert!(NONCE_PREFIX_LEN + COUNTER_LEN + 1 == NONCE_LEN); -fn make_nonce(prefix: &[u8; NONCE_PREFIX_LEN], counter: u32, last: bool) -> XNonce { +pub(crate) fn make_nonce(prefix: &[u8; NONCE_PREFIX_LEN], counter: u32, last: bool) -> XNonce { let mut n = [0u8; NONCE_LEN]; n[..NONCE_PREFIX_LEN].copy_from_slice(prefix); n[NONCE_PREFIX_LEN..NONCE_PREFIX_LEN + COUNTER_LEN].copy_from_slice(&counter.to_be_bytes()); @@ -55,36 +58,72 @@ pub fn derive_key( Ok(out) } +/// Build the AEAD cipher from the protected key. The cipher holds an +/// unprotected copy of the key while alive; `chacha20poly1305` zeroizes that +/// copy on drop. Wrapping in `Arc` lets us share it across worker threads. 
+fn build_aead(key: &SecretBytes32) -> Arc { + Arc::new(key.with_array(|key| XChaCha20Poly1305::new(key.into()))) +} + +/// Bump the per-chunk counter; surface a domain error on overflow rather than +/// panicking on debug or wrapping in release. +pub(crate) fn bump_counter(counter: u32) -> Result { + counter + .checked_add(1) + .ok_or_else(|| FcryError::Format("STREAM counter overflow (input too large)".into())) +} + pub fn encrypt>( input_file: Option, output_file: Option, key: &SecretBytes32, chunk_size: u32, kdf: KdfParams, + threads: usize, ) -> Result<(), FcryError> { let chunk_sz = chunk_size as usize; - let mut f_plain = AheadReader::from(open_input(input_file)?, chunk_sz); + let input = open_input(input_file)?; + let plaintext_length = input.length; + let mut f_plain = AheadReader::from(input.reader, chunk_sz); let mut f_encrypted = OutSink::open(output_file)?; let mut nonce_prefix = [0u8; NONCE_PREFIX_LEN]; getrandom::fill(&mut nonce_prefix)?; + let flags = if plaintext_length.is_some() { + FLAG_LENGTH_COMMITTED + } else { + 0 + }; let header = Header { alg: AlgId::XChaCha20Poly1305, - flags: 0, + flags, chunk_size, kdf, nonce_prefix, + plaintext_length, }; - let aad = header.encode(); + let aad = Arc::new(header.encode()); f_encrypted.write_all(&aad)?; - // The AEAD keeps its own unprotected key copy while the loop runs. - // chacha20poly1305 zeroizes that copy on drop. - let aead = key.with_array(|key| XChaCha20Poly1305::new(key.into())); + let aead = build_aead(key); + + if threads > 1 { + return pipeline::encrypt_parallel( + f_plain, + f_encrypted, + aead, + aad, + nonce_prefix, + chunk_sz, + threads, + plaintext_length, + ); + } let mut buf = vec![0u8; chunk_sz]; let mut counter: u32 = 0; + let mut bytes_seen: u64 = 0; loop { match f_plain.read_ahead(&mut buf)? 
{ @@ -93,15 +132,15 @@ pub fn encrypt>( aead.encrypt_in_place(&nonce, &aad, &mut buf)?; f_encrypted.write_all(&buf)?; buf.truncate(chunk_sz); - counter = counter.checked_add(1).ok_or_else(|| { - FcryError::Format("STREAM counter overflow (input too large)".into()) - })?; + bytes_seen = bytes_seen.saturating_add(chunk_sz as u64); + counter = bump_counter(counter)?; } ReadInfoChunk::Last(n) => { buf.truncate(n); let nonce = make_nonce(&nonce_prefix, counter, true); aead.encrypt_in_place(&nonce, &aad, &mut buf)?; f_encrypted.write_all(&buf)?; + bytes_seen = bytes_seen.saturating_add(n as u64); break; } ReadInfoChunk::Empty => { @@ -116,6 +155,17 @@ pub fn encrypt>( } } + if let Some(committed) = plaintext_length + && committed != bytes_seen + { + // Defense in depth: the input changed between stat and EOF. The + // committed length is part of the AEAD AAD, so any decrypter would + // also surface this, but we prefer to fail before publishing the file. + return Err(FcryError::Format(format!( + "input length changed during encryption: committed {committed}, read {bytes_seen}" + ))); + } + f_encrypted.commit()?; Ok(()) } @@ -125,10 +175,11 @@ pub fn decrypt>( output_file: Option, raw_key: Option<&SecretBytes32>, passphrase: Option<&SecretVec>, + threads: usize, ) -> Result<(), FcryError> { - let mut reader = open_input(input_file)?; + let mut reader = open_input(input_file)?.reader; let header = Header::read(&mut reader)?; - let aad = header.encode(); + let aad = Arc::new(header.encode()); let key = derive_key(&header.kdf, raw_key, passphrase)?; @@ -138,12 +189,24 @@ pub fn decrypt>( let mut f_encrypted = AheadReader::from(reader, cipher_chunk); let mut f_plain = OutSink::open(output_file)?; - // The AEAD keeps its own unprotected key copy while the loop runs. - // chacha20poly1305 zeroizes that copy on drop. 
- let aead = key.with_array(|key| XChaCha20Poly1305::new(key.into())); + let aead = build_aead(&key); + + if threads > 1 { + return pipeline::decrypt_parallel( + f_encrypted, + f_plain, + aead, + aad, + header.nonce_prefix, + cipher_chunk, + threads, + header.plaintext_length, + ); + } let mut buf = vec![0u8; cipher_chunk]; let mut counter: u32 = 0; + let mut bytes_written: u64 = 0; loop { match f_encrypted.read_ahead(&mut buf)? { @@ -151,16 +214,16 @@ pub fn decrypt>( let nonce = make_nonce(&header.nonce_prefix, counter, false); aead.decrypt_in_place(&nonce, &aad, &mut buf)?; f_plain.write_all(&buf)?; + bytes_written = bytes_written.saturating_add(buf.len() as u64); buf.resize(cipher_chunk, 0); - counter = counter - .checked_add(1) - .ok_or_else(|| FcryError::Format("STREAM counter overflow".into()))?; + counter = bump_counter(counter)?; } ReadInfoChunk::Last(n) => { buf.truncate(n); let nonce = make_nonce(&header.nonce_prefix, counter, true); aead.decrypt_in_place(&nonce, &aad, &mut buf)?; f_plain.write_all(&buf)?; + bytes_written = bytes_written.saturating_add(buf.len() as u64); break; } ReadInfoChunk::Empty => { @@ -171,6 +234,116 @@ pub fn decrypt>( } } + if let Some(committed) = header.plaintext_length + && committed != bytes_written + { + return Err(FcryError::Format(format!( + "decrypted length {bytes_written} disagrees with committed {committed}" + ))); + } + f_plain.commit()?; Ok(()) } + +/// Random-access decrypt of a byte range. Requires a seekable input file +/// whose header has `FLAG_LENGTH_COMMITTED` set, so we know exactly where +/// each ciphertext chunk lives and which chunk is the last (its nonce uses +/// the STREAM last-block flag). 
+pub fn decrypt_range>( + input_file: &str, + output_file: Option, + raw_key: Option<&SecretBytes32>, + passphrase: Option<&SecretVec>, + offset: u64, + length: u64, +) -> Result<(), FcryError> { + let file = File::open(input_file)?; + let mut reader = BufReader::new(file); + let header = Header::read(&mut reader)?; + let aad = header.encode(); + let header_len = aad.len() as u64; + + let total = header.plaintext_length.ok_or_else(|| { + FcryError::Format( + "random-access decrypt requires a length-committed header (encrypt from a regular file)".into(), + ) + })?; + + let end = offset + .checked_add(length) + .ok_or_else(|| FcryError::Format("offset + length overflows u64".into()))?; + if end > total { + return Err(FcryError::Format(format!( + "range [{offset}, {end}) exceeds plaintext length {total}" + ))); + } + + let key = derive_key(&header.kdf, raw_key, passphrase)?; + let aead = build_aead(&key); + + let chunk_sz = header.chunk_size as u64; + let cipher_chunk = chunk_sz + TAG_LEN as u64; + + // Layout invariants: + // n_chunks = ceil(total / chunk_sz), but always ≥ 1 (the empty file + // still authenticates a single empty "last" chunk). + // last_idx = n_chunks - 1 + // last_pt = total - last_idx * chunk_sz (in [0, chunk_sz]) + let (n_chunks, last_pt) = if total == 0 { + (1u64, 0u64) + } else { + let n = total.div_ceil(chunk_sz); + let last = total - (n - 1) * chunk_sz; + (n, last) + }; + let last_idx = n_chunks - 1; + + let mut out = OutSink::open(output_file)?; + + if length == 0 { + out.commit()?; + return Ok(()); + } + + let start_chunk = offset / chunk_sz; + let end_chunk = (end - 1) / chunk_sz; + + // Reusable buffer sized to a full chunk + tag. 
+ let mut buf = Vec::with_capacity(cipher_chunk as usize); + + let mut file = reader.into_inner(); + + for i in start_chunk..=end_chunk { + let i_u32 = + u32::try_from(i).map_err(|_| FcryError::Format("chunk index exceeds u32".into()))?; + let is_last = i == last_idx; + let cipher_len = if is_last { + last_pt + TAG_LEN as u64 + } else { + cipher_chunk + }; + let cipher_len_usz = + usize::try_from(cipher_len).map_err(|_| FcryError::Format("chunk too big".into()))?; + + let chunk_offset = header_len + i * cipher_chunk; + file.seek(SeekFrom::Start(chunk_offset))?; + buf.clear(); + buf.resize(cipher_len_usz, 0); + file.read_exact(&mut buf)?; + + let nonce = make_nonce(&header.nonce_prefix, i_u32, is_last); + aead.decrypt_in_place(&nonce, &aad, &mut buf)?; + + // `buf` is now plaintext for this chunk. Compute the chunk's plaintext + // window in absolute bytes and intersect with the requested range. + let chunk_start = i * chunk_sz; + let chunk_end = chunk_start + buf.len() as u64; + let lo = offset.max(chunk_start) - chunk_start; + let hi = end.min(chunk_end) - chunk_start; + out.write_all(&buf[lo as usize..hi as usize])?; + } + + out.commit()?; + Ok(()) +} diff --git a/src/header.rs b/src/header.rs index 0009478..aef90d2 100644 --- a/src/header.rs +++ b/src/header.rs @@ -4,34 +4,49 @@ //! //! Layout: //! ```text -//! magic "fcry" 4 bytes -//! version u8 1 -//! alg_id u8 1 -//! flags u8 1 -//! reserved u8 1 (must be 0) -//! chunk_size u32 LE 4 (plaintext bytes per chunk) -//! kdf_id u8 1 -//! kdf_params variable (depends on kdf_id) -//! nonce_prefix [u8; 19] 19 (STREAM nonce prefix) +//! magic "fcry" 4 bytes +//! version u8 1 +//! alg_id u8 1 +//! flags u8 1 +//! reserved u8 1 (must be 0) +//! chunk_size u32 LE 4 (plaintext bytes per chunk) +//! kdf_id u8 1 +//! kdf_params variable (depends on kdf_id) +//! nonce_prefix [u8; 19] 19 (STREAM nonce prefix) +//! plaintext_length u64 LE 8 (only if version >= 2 and flags & 0x01) //! --- end of header --- -//! 
chunk[0..N] each chunk_size + 16 bytes, -//! last may be shorter +//! chunk[0..N] each chunk_size + 16 bytes, +//! last may be shorter //! ``` //! //! The full encoded header is fed as AAD to every chunk, so any tampering -//! with chunk_size, nonce_prefix, kdf params, etc. causes authentication -//! failure on every chunk. +//! with chunk_size, nonce_prefix, kdf params, plaintext_length, etc. causes +//! authentication failure on every chunk. +//! +//! Versions: +//! * v1 — no length committed, no flag bits used. +//! * v2 — adds `FLAG_LENGTH_COMMITTED` (bit 0); when set, the total plaintext +//! length is appended after `nonce_prefix`. This enables random-access +//! decryption without scanning predecessors. use std::io::Read; use crate::error::FcryError; const MAGIC: [u8; 4] = *b"fcry"; -const VERSION: u8 = 1; +const VERSION_CURRENT: u8 = 2; +const VERSION_MIN: u8 = 1; pub const NONCE_PREFIX_LEN: usize = 19; pub const TAG_LEN: usize = 16; +/// Set in `flags` when the header carries an authenticated `plaintext_length` +/// field. Required for random-access decryption. +pub const FLAG_LENGTH_COMMITTED: u8 = 0x01; + +/// Mask of all flag bits this build understands. Unknown bits → reject. +const FLAG_KNOWN_MASK: u8 = FLAG_LENGTH_COMMITTED; + #[repr(u8)] #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum AlgId { @@ -121,13 +136,15 @@ pub struct Header { pub chunk_size: u32, pub kdf: KdfParams, pub nonce_prefix: [u8; NONCE_PREFIX_LEN], + /// Total plaintext byte count. `Some` iff `flags & FLAG_LENGTH_COMMITTED`. 
+ pub plaintext_length: Option, } impl Header { pub fn encode(&self) -> Vec { - let mut out = Vec::with_capacity(64); + let mut out = Vec::with_capacity(72); out.extend_from_slice(&MAGIC); - out.push(VERSION); + out.push(VERSION_CURRENT); out.push(self.alg as u8); out.push(self.flags); out.push(0); // reserved @@ -135,6 +152,12 @@ impl Header { out.push(self.kdf.id()); self.kdf.write_into(&mut out); out.extend_from_slice(&self.nonce_prefix); + if (self.flags & FLAG_LENGTH_COMMITTED) != 0 { + let len = self + .plaintext_length + .expect("FLAG_LENGTH_COMMITTED set but plaintext_length is None"); + out.extend_from_slice(&len.to_le_bytes()); + } out } @@ -148,12 +171,20 @@ impl Header { let mut fixed = [0u8; 4]; r.read_exact(&mut fixed)?; let [version, alg_id, flags, reserved] = fixed; - if version != VERSION { + if !(VERSION_MIN..=VERSION_CURRENT).contains(&version) { return Err(FcryError::Format(format!("unsupported version: {version}"))); } if reserved != 0 { return Err(FcryError::Format("reserved byte must be zero".into())); } + if (flags & !FLAG_KNOWN_MASK) != 0 { + return Err(FcryError::Format(format!( + "unknown flag bits: 0x{flags:02x}" + ))); + } + if version < 2 && flags != 0 { + return Err(FcryError::Format("v1 header must have flags == 0".into())); + } let alg = AlgId::from_u8(alg_id)?; let mut chunk_size_bytes = [0u8; 4]; @@ -170,12 +201,21 @@ impl Header { let mut nonce_prefix = [0u8; NONCE_PREFIX_LEN]; r.read_exact(&mut nonce_prefix)?; + let plaintext_length = if (flags & FLAG_LENGTH_COMMITTED) != 0 { + let mut b = [0u8; 8]; + r.read_exact(&mut b)?; + Some(u64::from_le_bytes(b)) + } else { + None + }; + Ok(Self { alg, flags, chunk_size, kdf, nonce_prefix, + plaintext_length, }) } } @@ -193,6 +233,7 @@ mod tests { chunk_size: 1024 * 1024, kdf: KdfParams::Raw, nonce_prefix: [7u8; NONCE_PREFIX_LEN], + plaintext_length: None, }; let bytes = h.encode(); let mut cur = Cursor::new(&bytes); @@ -201,6 +242,25 @@ mod tests { assert_eq!(parsed.flags, h.flags); 
assert_eq!(parsed.chunk_size, h.chunk_size); assert_eq!(parsed.nonce_prefix, h.nonce_prefix); + assert_eq!(parsed.plaintext_length, None); + assert_eq!(cur.position() as usize, bytes.len()); + } + + #[test] + fn roundtrip_length_committed() { + let h = Header { + alg: AlgId::XChaCha20Poly1305, + flags: FLAG_LENGTH_COMMITTED, + chunk_size: 65536, + kdf: KdfParams::Raw, + nonce_prefix: [9u8; NONCE_PREFIX_LEN], + plaintext_length: Some(123_456_789), + }; + let bytes = h.encode(); + let mut cur = Cursor::new(&bytes); + let parsed = Header::read(&mut cur).unwrap(); + assert_eq!(parsed.flags, FLAG_LENGTH_COMMITTED); + assert_eq!(parsed.plaintext_length, Some(123_456_789)); assert_eq!(cur.position() as usize, bytes.len()); } @@ -212,6 +272,7 @@ mod tests { chunk_size: 4096, kdf: KdfParams::Raw, nonce_prefix: [0u8; NONCE_PREFIX_LEN], + plaintext_length: None, } .encode(); bytes[0] ^= 1; @@ -220,4 +281,41 @@ mod tests { Err(FcryError::Format(_)) )); } + + #[test] + fn rejects_unknown_flag_bits() { + let mut bytes = Header { + alg: AlgId::XChaCha20Poly1305, + flags: 0, + chunk_size: 4096, + kdf: KdfParams::Raw, + nonce_prefix: [0u8; NONCE_PREFIX_LEN], + plaintext_length: None, + } + .encode(); + // flags byte is at offset 6 (4 magic + version + alg) + bytes[6] = 0x80; + assert!(matches!( + Header::read(&mut Cursor::new(&bytes)), + Err(FcryError::Format(_)) + )); + } + + #[test] + fn reads_v1_header() { + // hand-crafted v1 header (raw kdf, no length field) + let mut bytes = Vec::new(); + bytes.extend_from_slice(b"fcry"); + bytes.push(1); // version + bytes.push(1); // alg + bytes.push(0); // flags + bytes.push(0); // reserved + bytes.extend_from_slice(&1024u32.to_le_bytes()); + bytes.push(0); // kdf id raw + bytes.extend_from_slice(&[3u8; NONCE_PREFIX_LEN]); + let parsed = Header::read(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(parsed.flags, 0); + assert_eq!(parsed.chunk_size, 1024); + assert_eq!(parsed.plaintext_length, None); + } } diff --git a/src/main.rs 
b/src/main.rs index 03e28d5..a06b81e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ mod crypto; mod error; mod header; +mod pipeline; mod reader; mod secrets; mod utils; @@ -63,6 +64,32 @@ struct Cli { /// Argon2id parallelism / lanes (encryption only). #[clap(long, default_value_t = 4)] argon_parallelism: u32, + + /// Number of worker threads for AEAD work. Defaults to the number of + /// available CPUs. Set to 1 for fully serial encrypt/decrypt. + #[clap(short = 'j', long)] + threads: Option, + + /// Random-access decrypt: byte offset of the slice to read. + /// Requires `--decrypt`, an `--input-file` whose header has the + /// length-committed flag set, and `--length`. + #[clap( + long, + requires = "length", + requires = "decrypt", + requires = "input_file" + )] + offset: Option, + + /// Random-access decrypt: byte length of the slice to read. + /// Requires `--decrypt`, `--input-file`, and `--offset`. + #[clap( + long, + requires = "offset", + requires = "decrypt", + requires = "input_file" + )] + length: Option, } fn parse_raw_key(s: &str) -> Result { @@ -148,6 +175,13 @@ fn run(mut cli: Cli) -> Result<(), FcryError> { let argon_memory = cli.argon_memory; let argon_passes = cli.argon_passes; let argon_parallelism = cli.argon_parallelism; + let threads = cli.threads.unwrap_or_else(|| { + std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(1) + }); + let offset = cli.offset; + let length = cli.length; drop(cli); if decrypt_mode { @@ -159,7 +193,26 @@ fn run(mut cli: Cli) -> Result<(), FcryError> { Some(src) => Some(read_passphrase(src, false)?), None => None, }; - decrypt(input, output, raw_key.as_ref(), pw.as_ref())?; + match (offset, length) { + (Some(o), Some(l)) => { + // clap's `requires` makes this unreachable, but keep the + // dynamic check so the failure mode is a clean error. 
+ let path = input.as_deref().ok_or_else(|| { + FcryError::Format( + "--offset/--length require --input-file (random-access needs a seekable file)".into(), + ) + })?; + decrypt_range(path, output, raw_key.as_ref(), pw.as_ref(), o, l)?; + } + (None, None) => { + decrypt(input, output, raw_key.as_ref(), pw.as_ref(), threads)?; + } + _ => { + return Err(FcryError::Format( + "--offset and --length must be supplied together".into(), + )); + } + } } else { let (key, kdf) = if let Some(src) = &pw_src { let mut salt = [0u8; ARGON2_SALT_LEN]; @@ -180,7 +233,7 @@ fn run(mut cli: Cli) -> Result<(), FcryError> { let key = parse_raw_key(raw_key_str.as_deref().unwrap())?; (key, KdfParams::Raw) }; - encrypt(input, output, &key, chunk_size, kdf)?; + encrypt(input, output, &key, chunk_size, kdf, threads)?; } Ok(()) diff --git a/src/pipeline.rs b/src/pipeline.rs new file mode 100644 index 0000000..e452606 --- /dev/null +++ b/src/pipeline.rs @@ -0,0 +1,377 @@ +// SPDX-License-Identifier: GPL-3.0-only + +//! Multi-threaded encrypt/decrypt pipeline. +//! +//! Topology: +//! +//! ```text +//! reader thread → jobs (bounded MPMC) → N AEAD workers → +//! → results (bounded MPMC) → writer thread +//! ``` +//! +//! The reader is sequential (one input handle, lookahead detects last chunk), +//! workers parallelize the AEAD step (independent per chunk), and the writer +//! reorders results by counter before writing them to `OutSink`. The job +//! channel is bounded to give backpressure; the writer's reorder buffer is +//! also bounded so a slow worker can stall the pipeline rather than leaking +//! memory. +//! +//! Peak memory ≈ chunk_size × (jobs_capacity + workers + reorder_capacity + 2) +//! — for 1 MiB chunks and 8 cores that's ~32 MiB, which we accept. 
+ +use std::collections::BTreeMap; +use std::io::Write; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; + +use chacha20poly1305::{XChaCha20Poly1305, aead::AeadInPlace}; +use crossbeam_channel::{Receiver, Sender, bounded}; + +use crate::crypto::{bump_counter, make_nonce}; +use crate::error::FcryError; +use crate::header::NONCE_PREFIX_LEN; +use crate::reader::{AheadReader, ReadInfoChunk}; +use crate::utils::OutSink; + +struct Job { + counter: u32, + last: bool, + buf: Vec, +} + +struct Done { + counter: u32, + buf: Vec, +} + +/// Channel sizing: small multiples of worker count, enough to keep workers +/// fed without unbounded memory. +fn channel_capacity(threads: usize) -> usize { + (threads * 2).max(2) +} + +/// Reorder buffer cap: drop into back-pressure once we've held this many +/// out-of-order chunks. With uniform AEAD work this rarely fills. +fn reorder_capacity(threads: usize) -> usize { + (threads * 2).max(2) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn encrypt_parallel( + mut input: AheadReader, + output: OutSink, + aead: Arc, + aad: Arc>, + nonce_prefix: [u8; NONCE_PREFIX_LEN], + chunk_sz: usize, + threads: usize, + expected_length: Option, +) -> Result<(), FcryError> { + let cap = channel_capacity(threads); + let (jobs_tx, jobs_rx) = bounded::(cap); + let (done_tx, done_rx) = bounded::(cap); + + // Reader thread: drives AheadReader, dispatches jobs in counter order. + let reader_handle: JoinHandle> = thread::spawn(move || { + let mut counter: u32 = 0; + let mut bytes_seen: u64 = 0; + loop { + let mut buf = vec![0u8; chunk_sz]; + match input.read_ahead(&mut buf)? 
{ + ReadInfoChunk::Normal(_) => { + if jobs_tx + .send(Job { + counter, + last: false, + buf, + }) + .is_err() + { + return Ok(bytes_seen); + } + bytes_seen = bytes_seen.saturating_add(chunk_sz as u64); + counter = bump_counter(counter)?; + } + ReadInfoChunk::Last(n) => { + buf.truncate(n); + let _ = jobs_tx.send(Job { + counter, + last: true, + buf, + }); + bytes_seen = bytes_seen.saturating_add(n as u64); + return Ok(bytes_seen); + } + ReadInfoChunk::Empty => { + buf.clear(); + let _ = jobs_tx.send(Job { + counter, + last: true, + buf, + }); + return Ok(bytes_seen); + } + } + } + }); + + // Worker threads: AEAD encrypt in place, ship to writer. + let mut worker_handles: Vec>> = Vec::with_capacity(threads); + for _ in 0..threads { + let jobs_rx = jobs_rx.clone(); + let done_tx = done_tx.clone(); + let aead = aead.clone(); + let aad = aad.clone(); + worker_handles.push(thread::spawn(move || { + for mut job in jobs_rx.iter() { + let nonce = make_nonce(&nonce_prefix, job.counter, job.last); + aead.encrypt_in_place(&nonce, aad.as_slice(), &mut job.buf)?; + if done_tx + .send(Done { + counter: job.counter, + buf: job.buf, + }) + .is_err() + { + break; + } + } + Ok(()) + })); + } + drop(jobs_rx); + drop(done_tx); + + // Writer thread: ordered writeback. Commit is deferred to the main + // thread so a failure anywhere in the pipeline leaves the temp file to + // be unlinked by `OutSink::drop` instead of being renamed into place. + let writer_handle: JoinHandle> = { + let cap = reorder_capacity(threads); + thread::spawn(move || ordered_writer(done_rx, output, cap)) + }; + + // Join everything; surface the first error. 
+ let reader_res = reader_handle.join().expect("reader thread panicked"); + let mut first_err: Option = None; + let bytes_seen = match reader_res { + Ok(n) => Some(n), + Err(e) => { + first_err.get_or_insert(e); + None + } + }; + for h in worker_handles { + if let Err(e) = h.join().expect("worker thread panicked") + && first_err.is_none() + { + first_err = Some(e); + } + } + let writer_res = writer_handle.join().expect("writer thread panicked"); + let sink = match writer_res { + Ok(s) => Some(s), + Err(e) => { + if first_err.is_none() { + first_err = Some(e); + } + None + } + }; + + if let Some(e) = first_err { + // Drop `sink` here without committing so the temp file is unlinked. + return Err(e); + } + + if let (Some(committed), Some(seen)) = (expected_length, bytes_seen) + && committed != seen + { + return Err(FcryError::Format(format!( + "input length changed during encryption: committed {committed}, read {seen}" + ))); + } + + sink.expect("no error but no sink").commit()?; + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn decrypt_parallel( + mut input: AheadReader, + output: OutSink, + aead: Arc, + aad: Arc>, + nonce_prefix: [u8; NONCE_PREFIX_LEN], + cipher_chunk: usize, + threads: usize, + expected_length: Option, +) -> Result<(), FcryError> { + let cap = channel_capacity(threads); + let (jobs_tx, jobs_rx) = bounded::(cap); + let (done_tx, done_rx) = bounded::(cap); + + let reader_handle: JoinHandle> = thread::spawn(move || { + let mut counter: u32 = 0; + loop { + let mut buf = vec![0u8; cipher_chunk]; + match input.read_ahead(&mut buf)? 
{ + ReadInfoChunk::Normal(_) => { + if jobs_tx + .send(Job { + counter, + last: false, + buf, + }) + .is_err() + { + return Ok(()); + } + counter = bump_counter(counter)?; + } + ReadInfoChunk::Last(n) => { + buf.truncate(n); + let _ = jobs_tx.send(Job { + counter, + last: true, + buf, + }); + return Ok(()); + } + ReadInfoChunk::Empty => { + return Err(FcryError::Format( + "truncated ciphertext: missing final chunk".into(), + )); + } + } + } + }); + + let mut worker_handles: Vec>> = Vec::with_capacity(threads); + for _ in 0..threads { + let jobs_rx = jobs_rx.clone(); + let done_tx = done_tx.clone(); + let aead = aead.clone(); + let aad = aad.clone(); + worker_handles.push(thread::spawn(move || { + for mut job in jobs_rx.iter() { + let nonce = make_nonce(&nonce_prefix, job.counter, job.last); + aead.decrypt_in_place(&nonce, aad.as_slice(), &mut job.buf)?; + if done_tx + .send(Done { + counter: job.counter, + buf: job.buf, + }) + .is_err() + { + break; + } + } + Ok(()) + })); + } + drop(jobs_rx); + drop(done_tx); + + let writer_handle: JoinHandle> = { + let cap = reorder_capacity(threads); + thread::spawn(move || ordered_writer_counted(done_rx, output, cap)) + }; + + let reader_res = reader_handle.join().expect("reader thread panicked"); + let mut first_err: Option = None; + if let Err(e) = reader_res { + first_err = Some(e); + } + for h in worker_handles { + if let Err(e) = h.join().expect("worker thread panicked") + && first_err.is_none() + { + first_err = Some(e); + } + } + let writer_res = writer_handle.join().expect("writer thread panicked"); + let written = match writer_res { + Ok((sink, n)) => Some((sink, n)), + Err(e) => { + if first_err.is_none() { + first_err = Some(e); + } + None + } + }; + + if let Some(e) = first_err { + return Err(e); + } + + let (sink, n) = written.expect("no error but no sink"); + + if let Some(committed) = expected_length + && committed != n + { + return Err(FcryError::Format(format!( + "decrypted length {n} disagrees with committed 
{committed}" + ))); + } + + sink.commit()?; + Ok(()) +} + +/// Drain `done_rx` in counter order and stream into `output`. Returns the +/// `OutSink` ownership back to the caller without committing — the caller +/// commits only after every other thread has joined cleanly, so a failure +/// anywhere in the pipeline drops the sink and unlinks the temp file. +fn ordered_writer( + done_rx: Receiver, + mut output: OutSink, + _cap: usize, +) -> Result { + let mut next: u32 = 0; + let mut pending: BTreeMap> = BTreeMap::new(); + for done in done_rx.iter() { + pending.insert(done.counter, done.buf); + while let Some(buf) = pending.remove(&next) { + output.write_all(&buf)?; + next = next.wrapping_add(1); + } + } + if !pending.is_empty() { + return Err(FcryError::Format( + "internal: ordered writer left chunks unflushed".into(), + )); + } + Ok(output) +} + +/// Same as `ordered_writer` but also returns total bytes written for the +/// length-committed cross-check on decrypt. +fn ordered_writer_counted( + done_rx: Receiver, + mut output: OutSink, + _cap: usize, +) -> Result<(OutSink, u64), FcryError> { + let mut next: u32 = 0; + let mut pending: BTreeMap> = BTreeMap::new(); + let mut total: u64 = 0; + for done in done_rx.iter() { + pending.insert(done.counter, done.buf); + while let Some(buf) = pending.remove(&next) { + output.write_all(&buf)?; + total = total.saturating_add(buf.len() as u64); + next = next.wrapping_add(1); + } + } + if !pending.is_empty() { + return Err(FcryError::Format( + "internal: ordered writer left chunks unflushed".into(), + )); + } + Ok((output, total)) +} + +// Suppress unused warnings on Sender clones inside the worker loop; the +// channel is closed when its last sender is dropped which is what we want. 
+#[allow(dead_code)]
+fn _assert_send_sync<T: Send + Sync>() {}
+const _: fn() = || _assert_send_sync::<Sender<Job>>();
diff --git a/src/reader.rs b/src/reader.rs
index 0531ada..86dad5e 100644
--- a/src/reader.rs
+++ b/src/reader.rs
@@ -10,14 +10,14 @@ pub enum ReadInfoChunk {
 }
 
 pub struct AheadReader {
-    inner: Box<dyn Read>,
+    inner: Box<dyn Read + Send>,
     buf: Vec<u8>,
     bufsz: usize,
     capacity: usize,
 }
 
 impl AheadReader {
-    pub fn from(reader: Box<dyn Read>, capacity: usize) -> Self {
+    pub fn from(reader: Box<dyn Read + Send>, capacity: usize) -> Self {
         Self {
             inner: reader,
             buf: vec![0; capacity],
diff --git a/src/utils.rs b/src/utils.rs
index 060ad39..b5fffc1 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -10,11 +10,39 @@ use std::path::PathBuf;
 /// breaking older files (the decryptor reads the size from the header).
 pub const DEFAULT_CHUNK_SIZE: u32 = 1024 * 1024;
 
-pub(crate) fn open_input<P: AsRef<Path>>(input_file: Option<P>) -> io::Result<Box<dyn Read>> {
-    Ok(match input_file {
-        Some(f) => Box::new(BufReader::new(File::open(f.as_ref())?)),
-        None => Box::new(io::stdin().lock()),
-    })
+/// Opened input.
+///
+/// `length` is `Some(n)` only when the source is a regular file (we stat the
+/// open FD to avoid TOCTOU). For stdin, FIFOs, sockets, char devices, etc.
+/// it is `None` — those paths cannot commit a length in the header.
+pub(crate) struct Input {
+    pub reader: Box<dyn Read + Send>,
+    pub length: Option<u64>,
+}
+
+pub(crate) fn open_input<P: AsRef<Path>>(input_file: Option<P>) -> io::Result<Input> {
+    match input_file {
+        Some(f) => {
+            let file = File::open(f.as_ref())?;
+            // Stat the open FD (not the path) so we can't be raced between
+            // stat and open.
+            let length = file
+                .metadata()
+                .ok()
+                .filter(|m| m.is_file())
+                .map(|m| m.len());
+            Ok(Input {
+                reader: Box::new(BufReader::new(file)),
+                length,
+            })
+        }
+        None => Ok(Input {
+            // `Stdin` is `Send` (unlike `StdinLock`), so wrap it in a
+            // `BufReader` and box for cross-thread use in the parallel pipeline.
+            reader: Box::new(BufReader::new(io::stdin())),
+            length: None,
+        }),
+    }
 }
 
 /// Output sink that supports atomic file replacement.
diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index 0b36ab0..db87514 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -422,6 +422,307 @@ fn atomic_output_no_stale_tmp_on_failure() { assert!(!tmp.exists(), "temp file must be cleaned up"); } +// --------------------------------------------------------------------------- +// Multi-threaded pipeline + length-committed + random-access tests +// --------------------------------------------------------------------------- + +fn encrypt_file_threads( + plain: &std::path::Path, + ct: &std::path::Path, + chunk_size: Option, + threads: usize, +) { + let mut cmd = fcry(); + cmd.arg("-i") + .arg(plain) + .arg("-o") + .arg(ct) + .arg("--raw-key") + .arg(KEY_STR) + .arg("-j") + .arg(threads.to_string()); + if let Some(cs) = chunk_size { + cmd.arg("--chunk-size").arg(cs.to_string()); + } + let out = cmd.output().unwrap(); + assert!( + out.status.success(), + "encrypt -j{threads} failed: {}", + String::from_utf8_lossy(&out.stderr) + ); +} + +fn decrypt_file_threads(ct: &std::path::Path, rt: &std::path::Path, threads: usize) { + let out = fcry() + .arg("-d") + .arg("-i") + .arg(ct) + .arg("-o") + .arg(rt) + .arg("--raw-key") + .arg(KEY_STR) + .arg("-j") + .arg(threads.to_string()) + .output() + .unwrap(); + assert!( + out.status.success(), + "decrypt -j{threads} failed: {}", + String::from_utf8_lossy(&out.stderr) + ); +} + +#[test] +fn roundtrip_multi_threaded() { + // Multi-chunk input. Encrypt+decrypt with -j 4 must round-trip. 
+ let dir = TempDir::new().unwrap(); + let plain = dir.path().join("p.bin"); + let ct = dir.path().join("c.bin"); + let rt = dir.path().join("r.bin"); + let data = pseudo_random(11, 5 * 1024 * 1024 + 12345); + fs::write(&plain, &data).unwrap(); + + encrypt_file_threads(&plain, &ct, Some(64 * 1024), 4); + decrypt_file_threads(&ct, &rt, 4); + assert_eq!(fs::read(&rt).unwrap(), data); +} + +#[test] +fn parallel_and_serial_outputs_round_trip() { + // Encrypt with -j 4 and decrypt serially (and vice-versa); both directions + // must yield the original plaintext. + let dir = TempDir::new().unwrap(); + let plain = dir.path().join("p.bin"); + let data = pseudo_random(13, 256 * 1024 + 17); + fs::write(&plain, &data).unwrap(); + + let ct_par = dir.path().join("c_par.bin"); + let ct_ser = dir.path().join("c_ser.bin"); + encrypt_file_threads(&plain, &ct_par, Some(8192), 4); + encrypt_file_threads(&plain, &ct_ser, Some(8192), 1); + + let rt1 = dir.path().join("r1.bin"); + let rt2 = dir.path().join("r2.bin"); + // par-encrypted, serial-decrypted + decrypt_file_threads(&ct_par, &rt1, 1); + // serial-encrypted, par-decrypted + decrypt_file_threads(&ct_ser, &rt2, 4); + assert_eq!(fs::read(&rt1).unwrap(), data); + assert_eq!(fs::read(&rt2).unwrap(), data); +} + +#[test] +fn roundtrip_pipe_multi_threaded() { + // stdin/stdout mode with -j 4: length flag must NOT be set (no committed + // length when we don't know the input size), but encrypt/decrypt must still + // round-trip cleanly across the pipeline. 
+ let data = pseudo_random(14, 200_000); + + let mut enc = fcry() + .arg("--raw-key") + .arg(KEY_STR) + .arg("-j") + .arg("4") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + enc.stdin.as_mut().unwrap().write_all(&data).unwrap(); + let enc_out = enc.wait_with_output().unwrap(); + assert!( + enc_out.status.success(), + "pipe encrypt -j4 failed: {}", + String::from_utf8_lossy(&enc_out.stderr) + ); + + // flags byte at offset 6 must be 0 (no length committed for stdin input). + assert_eq!( + enc_out.stdout[6], 0, + "stdin-encrypted file unexpectedly committed length" + ); + + let mut dec = fcry() + .arg("-d") + .arg("--raw-key") + .arg(KEY_STR) + .arg("-j") + .arg("4") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + dec.stdin + .as_mut() + .unwrap() + .write_all(&enc_out.stdout) + .unwrap(); + let dec_out = dec.wait_with_output().unwrap(); + assert!( + dec_out.status.success(), + "pipe decrypt -j4 failed: {}", + String::from_utf8_lossy(&dec_out.stderr) + ); + assert_eq!(dec_out.stdout, data); +} + +#[test] +fn file_input_commits_length() { + // Encrypting from a regular file must auto-set FLAG_LENGTH_COMMITTED (bit 0 + // of the flags byte at offset 6) and embed the length. 
+ let dir = TempDir::new().unwrap(); + let plain = dir.path().join("p.bin"); + let ct = dir.path().join("c.bin"); + let data = pseudo_random(15, 50_000); + fs::write(&plain, &data).unwrap(); + encrypt_file(&plain, &ct, Some(4096)); + + let bytes = fs::read(&ct).unwrap(); + // Magic(4) + version(1) + alg(1) + flags(1) = byte 6 + assert_eq!(bytes[4], 2, "version should be 2"); + assert_eq!(bytes[6] & 0x01, 0x01, "length-committed flag should be set"); +} + +fn encrypt_random_access_fixture( + dir: &std::path::Path, + data: &[u8], + chunk_size: u32, +) -> std::path::PathBuf { + let plain = dir.join("p.bin"); + let ct = dir.join("c.bin"); + fs::write(&plain, data).unwrap(); + encrypt_file(&plain, &ct, Some(chunk_size)); + ct +} + +fn random_access_decrypt( + ct: &std::path::Path, + out: &std::path::Path, + offset: u64, + length: u64, +) -> std::process::Output { + fcry() + .arg("-d") + .arg("-i") + .arg(ct) + .arg("-o") + .arg(out) + .arg("--raw-key") + .arg(KEY_STR) + .arg("--offset") + .arg(offset.to_string()) + .arg("--length") + .arg(length.to_string()) + .output() + .unwrap() +} + +#[test] +fn random_access_decrypt_slices() { + let dir = TempDir::new().unwrap(); + let chunk = 4096u32; + let total = 5 * 1024 * 1024 + 12345; + let data = pseudo_random(16, total); + let ct = encrypt_random_access_fixture(dir.path(), &data, chunk); + + // (offset, length) cases: + // - chunk-aligned start, mid-chunk end + // - mid-chunk start crossing several chunks + // - last partial chunk + // - last byte + // - entire file + let cases: &[(u64, u64)] = &[ + (0, 1), + (chunk as u64, 7), + (chunk as u64 - 5, 100), + (10, chunk as u64 * 3 + 17), + (total as u64 - 1, 1), + (total as u64 - 100, 100), + (0, total as u64), + ]; + for (i, (offset, length)) in cases.iter().copied().enumerate() { + let out = dir.path().join(format!("slice_{i}.bin")); + let r = random_access_decrypt(&ct, &out, offset, length); + assert!( + r.status.success(), + "slice {i} ({offset}, {length}) failed: {}", + 
String::from_utf8_lossy(&r.stderr) + ); + let got = fs::read(&out).unwrap(); + let expected = &data[offset as usize..(offset + length) as usize]; + assert_eq!(got, expected, "slice {i} mismatch"); + } +} + +#[test] +fn random_access_rejects_out_of_range() { + let dir = TempDir::new().unwrap(); + let data = pseudo_random(17, 1000); + let ct = encrypt_random_access_fixture(dir.path(), &data, 256); + let out = dir.path().join("oob.bin"); + let r = random_access_decrypt(&ct, &out, 900, 1000); // 900+1000 > 1000 + assert!(!r.status.success(), "out-of-range slice should fail"); +} + +#[test] +fn random_access_rejects_stdin_encrypted() { + // Encrypt via stdin → no length committed → random access must refuse. + let data = pseudo_random(18, 2000); + let dir = TempDir::new().unwrap(); + let ct = dir.path().join("c.bin"); + + let mut enc = fcry() + .arg("--raw-key") + .arg(KEY_STR) + .arg("-o") + .arg(&ct) + .stdin(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + enc.stdin.as_mut().unwrap().write_all(&data).unwrap(); + assert!(enc.wait().unwrap().success()); + + let out = dir.path().join("slice.bin"); + let r = random_access_decrypt(&ct, &out, 0, 100); + assert!( + !r.status.success(), + "random access on stdin-encrypted file should fail" + ); +} + +#[test] +fn random_access_zero_length() { + let dir = TempDir::new().unwrap(); + let data = pseudo_random(19, 1000); + let ct = encrypt_random_access_fixture(dir.path(), &data, 256); + let out = dir.path().join("empty.bin"); + let r = random_access_decrypt(&ct, &out, 500, 0); + assert!(r.status.success(), "zero-length slice should succeed"); + assert_eq!(fs::read(&out).unwrap(), Vec::::new()); +} + +#[test] +fn random_access_tampered_length_fails() { + // Flip a byte inside the committed plaintext_length field. The header is + // AAD for every chunk, so the AEAD must reject decryption. 
+ let dir = TempDir::new().unwrap(); + let data = pseudo_random(20, 4000); + let ct = encrypt_random_access_fixture(dir.path(), &data, 1024); + let mut bytes = fs::read(&ct).unwrap(); + // For raw-kdf header: magic(4)+ver(1)+alg(1)+flags(1)+rsv(1)+chunksize(4)+kdf_id(1)+nonce_prefix(19) = 32 + // plaintext_length is at offset 32..40. + bytes[34] ^= 0xff; + fs::write(&ct, &bytes).unwrap(); + let out = dir.path().join("bad.bin"); + let r = random_access_decrypt(&ct, &out, 0, 100); + assert!( + !r.status.success(), + "tampered plaintext_length must fail authentication" + ); +} + #[test] fn header_chunk_size_is_authoritative_on_decrypt() { // Encrypt with a non-default chunk size; decrypt without specifying one.