diff --git a/src/pipeline.rs b/src/pipeline.rs index e452606..c63f675 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -11,21 +11,33 @@ //! //! The reader is sequential (one input handle, lookahead detects last chunk), //! workers parallelize the AEAD step (independent per chunk), and the writer -//! reorders results by counter before writing them to `OutSink`. The job -//! channel is bounded to give backpressure; the writer's reorder buffer is -//! also bounded so a slow worker can stall the pipeline rather than leaking -//! memory. +//! reorders results by counter before writing them to `OutSink`. //! -//! Peak memory ≈ chunk_size × (jobs_capacity + workers + reorder_capacity + 2) -//! — for 1 MiB chunks and 8 cores that's ~32 MiB, which we accept. +//! Bounded memory: a permit channel caps the total number of in-flight chunks +//! (queued jobs + in-progress at workers + pending in the writer's reorder +//! buffer). The reader acquires a permit before sending each job; the writer +//! releases a permit after flushing the chunk in order. A slow or stuck worker +//! therefore stalls the reader rather than letting the writer's reorder buffer +//! grow without bound. +//! +//! Fail-fast: a shared `cancel` flag lets workers signal an authentication or +//! AEAD error to the reader. The reader checks it each iteration and exits +//! early, so a tampered chunk doesn't waste full-file I/O on top of the +//! detection. +//! +//! Peak memory ≈ chunk_size × (in_flight_cap + 2). For 1 MiB chunks and 8 +//! cores (cap = 32) that's ~34 MiB. Adjust `in_flight_capacity` if you need +//! a different memory/throughput tradeoff. 
use std::collections::BTreeMap; use std::io::Write; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::thread::{self, JoinHandle}; +use std::time::Duration; use chacha20poly1305::{XChaCha20Poly1305, aead::AeadInPlace}; -use crossbeam_channel::{Receiver, Sender, bounded}; +use crossbeam_channel::{Receiver, RecvTimeoutError, Sender, bounded}; use crate::crypto::{bump_counter, make_nonce}; use crate::error::FcryError; @@ -44,21 +56,22 @@ struct Done { buf: Vec, } -/// Channel sizing: small multiples of worker count, enough to keep workers -/// fed without unbounded memory. +/// Job-channel capacity: a small multiple of the worker count, enough to keep +/// workers fed without unbounded memory. fn channel_capacity(threads: usize) -> usize { (threads * 2).max(2) } -/// Reorder buffer cap: drop into back-pressure once we've held this many -/// out-of-order chunks. With uniform AEAD work this rarely fills. -fn reorder_capacity(threads: usize) -> usize { - (threads * 2).max(2) +/// Total in-flight chunk cap (jobs queued + at workers + in writer's reorder +/// buffer). Permit count; sized at twice the job-channel capacity to absorb +/// reordering without blocking workers unnecessarily. +fn in_flight_capacity(threads: usize) -> usize { + (threads * 4).max(4) } #[allow(clippy::too_many_arguments)] pub(crate) fn encrypt_parallel( - mut input: AheadReader, + input: AheadReader, output: OutSink, aead: Arc, aad: Arc>, @@ -67,137 +80,32 @@ pub(crate) fn encrypt_parallel( threads: usize, expected_length: Option, ) -> Result<(), FcryError> { - let cap = channel_capacity(threads); - let (jobs_tx, jobs_rx) = bounded::(cap); - let (done_tx, done_rx) = bounded::(cap); + let (sink, bytes_seen) = run_pipeline( + input, + output, + aead, + aad, + nonce_prefix, + chunk_sz, + threads, + true, + )?; - // Reader thread: drives AheadReader, dispatches jobs in counter order.
- let reader_handle: JoinHandle> = thread::spawn(move || { - let mut counter: u32 = 0; - let mut bytes_seen: u64 = 0; - loop { - let mut buf = vec![0u8; chunk_sz]; - match input.read_ahead(&mut buf)? { - ReadInfoChunk::Normal(_) => { - if jobs_tx - .send(Job { - counter, - last: false, - buf, - }) - .is_err() - { - return Ok(bytes_seen); - } - bytes_seen = bytes_seen.saturating_add(chunk_sz as u64); - counter = bump_counter(counter)?; - } - ReadInfoChunk::Last(n) => { - buf.truncate(n); - let _ = jobs_tx.send(Job { - counter, - last: true, - buf, - }); - bytes_seen = bytes_seen.saturating_add(n as u64); - return Ok(bytes_seen); - } - ReadInfoChunk::Empty => { - buf.clear(); - let _ = jobs_tx.send(Job { - counter, - last: true, - buf, - }); - return Ok(bytes_seen); - } - } - } - }); - - // Worker threads: AEAD encrypt in place, ship to writer. - let mut worker_handles: Vec>> = Vec::with_capacity(threads); - for _ in 0..threads { - let jobs_rx = jobs_rx.clone(); - let done_tx = done_tx.clone(); - let aead = aead.clone(); - let aad = aad.clone(); - worker_handles.push(thread::spawn(move || { - for mut job in jobs_rx.iter() { - let nonce = make_nonce(&nonce_prefix, job.counter, job.last); - aead.encrypt_in_place(&nonce, aad.as_slice(), &mut job.buf)?; - if done_tx - .send(Done { - counter: job.counter, - buf: job.buf, - }) - .is_err() - { - break; - } - } - Ok(()) - })); - } - drop(jobs_rx); - drop(done_tx); - - // Writer thread: ordered writeback. Commit is deferred to the main - // thread so a failure anywhere in the pipeline leaves the temp file to - // be unlinked by `OutSink::drop` instead of being renamed into place. - let writer_handle: JoinHandle> = { - let cap = reorder_capacity(threads); - thread::spawn(move || ordered_writer(done_rx, output, cap)) - }; - - // Join everything; surface the first error. 
- let reader_res = reader_handle.join().expect("reader thread panicked"); - let mut first_err: Option = None; - let bytes_seen = match reader_res { - Ok(n) => Some(n), - Err(e) => { - first_err.get_or_insert(e); - None - } - }; - for h in worker_handles { - if let Err(e) = h.join().expect("worker thread panicked") - && first_err.is_none() - { - first_err = Some(e); - } - } - let writer_res = writer_handle.join().expect("writer thread panicked"); - let sink = match writer_res { - Ok(s) => Some(s), - Err(e) => { - if first_err.is_none() { - first_err = Some(e); - } - None - } - }; - - if let Some(e) = first_err { - // Drop `sink` here without committing so the temp file is unlinked. - return Err(e); - } - - if let (Some(committed), Some(seen)) = (expected_length, bytes_seen) - && committed != seen + if let Some(committed) = expected_length + && committed != bytes_seen { return Err(FcryError::Format(format!( - "input length changed during encryption: committed {committed}, read {seen}" + "input length changed during encryption: committed {committed}, read {bytes_seen}" ))); } - sink.expect("no error but no sink").commit()?; + sink.commit()?; Ok(()) } #[allow(clippy::too_many_arguments)] pub(crate) fn decrypt_parallel( - mut input: AheadReader, + input: AheadReader, output: OutSink, aead: Arc, aad: Arc>, @@ -206,56 +114,161 @@ pub(crate) fn decrypt_parallel( threads: usize, expected_length: Option, ) -> Result<(), FcryError> { + let (sink, written) = run_pipeline( + input, + output, + aead, + aad, + nonce_prefix, + cipher_chunk, + threads, + false, + )?; + + if let Some(committed) = expected_length + && committed != written + { + return Err(FcryError::Format(format!( + "decrypted length {written} disagrees with committed {committed}" + ))); + } + + sink.commit()?; + Ok(()) +} + +/// Drives the reader/worker/writer pipeline. `is_encrypt = true` performs +/// `encrypt_in_place` and tracks bytes-read; `false` performs +/// `decrypt_in_place` and tracks bytes-written. 
The single shared topology +/// keeps backpressure, reorder, and fail-fast logic in one place. +#[allow(clippy::too_many_arguments)] +fn run_pipeline( + mut input: AheadReader, + output: OutSink, + aead: Arc, + aad: Arc>, + nonce_prefix: [u8; NONCE_PREFIX_LEN], + chunk_sz: usize, + threads: usize, + is_encrypt: bool, +) -> Result<(OutSink, u64), FcryError> { let cap = channel_capacity(threads); + let in_flight = in_flight_capacity(threads); let (jobs_tx, jobs_rx) = bounded::(cap); let (done_tx, done_rx) = bounded::(cap); - let reader_handle: JoinHandle> = thread::spawn(move || { - let mut counter: u32 = 0; - loop { - let mut buf = vec![0u8; cipher_chunk]; - match input.read_ahead(&mut buf)? { - ReadInfoChunk::Normal(_) => { - if jobs_tx - .send(Job { - counter, - last: false, - buf, - }) - .is_err() - { - return Ok(()); + // Pre-fill the permit channel. Each permit represents one in-flight chunk + // slot. The reader consumes a permit before sending a job; the writer + // returns a permit after flushing in order. + let (permit_tx, permit_rx) = bounded::<()>(in_flight); + for _ in 0..in_flight { + permit_tx + .send(()) + .expect("pre-fill of permit channel cannot fail"); + } + + let cancel = Arc::new(AtomicBool::new(false)); + + // Reader thread: dispatches jobs in counter order and tracks bytes read + // (used for the encrypt-side length cross-check). On decrypt the count is + // ignored — the writer's count is authoritative there. + let reader_handle: JoinHandle> = { + let cancel = cancel.clone(); + thread::spawn(move || { + let mut counter: u32 = 0; + let mut bytes_seen: u64 = 0; + loop { + // Acquire an in-flight slot. We recv with a short timeout so + // a worker error (which sets `cancel`) is observed even if + // the rest of the pipeline has quiesced and is no longer + // releasing permits — this avoids a 3-way deadlock between + // reader, idle workers, and a stalled writer. 
+ loop { + if cancel.load(Ordering::Acquire) { + return Ok(bytes_seen); + } + match permit_rx.recv_timeout(Duration::from_millis(50)) { + Ok(()) => break, + Err(RecvTimeoutError::Timeout) => continue, + Err(RecvTimeoutError::Disconnected) => return Ok(bytes_seen), } - counter = bump_counter(counter)?; } - ReadInfoChunk::Last(n) => { - buf.truncate(n); - let _ = jobs_tx.send(Job { - counter, - last: true, - buf, - }); - return Ok(()); - } - ReadInfoChunk::Empty => { - return Err(FcryError::Format( - "truncated ciphertext: missing final chunk".into(), - )); + let mut buf = vec![0u8; chunk_sz]; + match input.read_ahead(&mut buf)? { + ReadInfoChunk::Normal(_) => { + if jobs_tx + .send(Job { + counter, + last: false, + buf, + }) + .is_err() + { + return Ok(bytes_seen); + } + bytes_seen = bytes_seen.saturating_add(chunk_sz as u64); + counter = bump_counter(counter)?; + } + ReadInfoChunk::Last(n) => { + buf.truncate(n); + let _ = jobs_tx.send(Job { + counter, + last: true, + buf, + }); + bytes_seen = bytes_seen.saturating_add(n as u64); + return Ok(bytes_seen); + } + ReadInfoChunk::Empty => { + if is_encrypt { + buf.clear(); + let _ = jobs_tx.send(Job { + counter, + last: true, + buf, + }); + return Ok(bytes_seen); + } + // On decrypt an unexpected EOF means the ciphertext is + // truncated. Surface it as an error so the writer + // doesn't commit a partial output. + return Err(FcryError::Format( + "truncated ciphertext: missing final chunk".into(), + )); + } } } - } - }); + }) + }; + // Worker threads: AEAD encrypt/decrypt in place, ship to writer. On error + // we set the cancel flag so the reader exits early, and drop the senders + // so the writer drains and exits. 
let mut worker_handles: Vec>> = Vec::with_capacity(threads); for _ in 0..threads { let jobs_rx = jobs_rx.clone(); let done_tx = done_tx.clone(); let aead = aead.clone(); let aad = aad.clone(); + let cancel = cancel.clone(); worker_handles.push(thread::spawn(move || { for mut job in jobs_rx.iter() { + if cancel.load(Ordering::Acquire) { + // Another stage already failed: stop without doing AEAD work and + // abandon any jobs still queued. This worker returns Ok(()) so + // the failing stage's error is the one reported at join time. + break; + } let nonce = make_nonce(&nonce_prefix, job.counter, job.last); - aead.decrypt_in_place(&nonce, aad.as_slice(), &mut job.buf)?; + let res = if is_encrypt { + aead.encrypt_in_place(&nonce, aad.as_slice(), &mut job.buf) + } else { + aead.decrypt_in_place(&nonce, aad.as_slice(), &mut job.buf) + }; + if let Err(e) = res { + cancel.store(true, Ordering::Release); + return Err(e.into()); + } if done_tx .send(Done { counter: job.counter, @@ -272,16 +285,24 @@ pub(crate) fn decrypt_parallel( drop(jobs_rx); drop(done_tx); - let writer_handle: JoinHandle> = { - let cap = reorder_capacity(threads); - thread::spawn(move || ordered_writer_counted(done_rx, output, cap)) - }; + // Writer thread: ordered writeback. Hands ownership of the `OutSink` back + // without committing; the caller commits only after every other thread + // has joined cleanly so a failure anywhere drops the sink and unlinks the + // temp file. Releases one permit per chunk flushed so the reader can make + // forward progress in lockstep with the actual disk write.
+ let writer_handle: JoinHandle> = + thread::spawn(move || ordered_writer(done_rx, output, permit_tx)); let reader_res = reader_handle.join().expect("reader thread panicked"); let mut first_err: Option = None; - if let Err(e) = reader_res { - first_err = Some(e); - } + let bytes_seen = match reader_res { + Ok(n) => Some(n), + Err(e) => { + cancel.store(true, Ordering::Release); + first_err.get_or_insert(e); + None + } + }; for h in worker_handles { if let Err(e) = h.join().expect("worker thread panicked") && first_err.is_none() @@ -305,51 +326,21 @@ pub(crate) fn decrypt_parallel( } let (sink, n) = written.expect("no error but no sink"); - - if let Some(committed) = expected_length - && committed != n - { - return Err(FcryError::Format(format!( - "decrypted length {n} disagrees with committed {committed}" - ))); - } - - sink.commit()?; - Ok(()) + let count = if is_encrypt { + bytes_seen.expect("no error but no reader count") + } else { + n + }; + Ok((sink, count)) } -/// Drain `done_rx` in counter order and stream into `output`. Returns the -/// `OutSink` ownership back to the caller without committing — the caller -/// commits only after every other thread has joined cleanly, so a failure -/// anywhere in the pipeline drops the sink and unlinks the temp file. +/// Drain `done_rx` in counter order, writing each chunk to `output` and +/// returning a permit to `permit_tx` after every flush so the reader is held +/// in lockstep with disk writes (bounded reorder buffer). 
fn ordered_writer( done_rx: Receiver, mut output: OutSink, - _cap: usize, -) -> Result { - let mut next: u32 = 0; - let mut pending: BTreeMap> = BTreeMap::new(); - for done in done_rx.iter() { - pending.insert(done.counter, done.buf); - while let Some(buf) = pending.remove(&next) { - output.write_all(&buf)?; - next = next.wrapping_add(1); - } - } - if !pending.is_empty() { - return Err(FcryError::Format( - "internal: ordered writer left chunks unflushed".into(), - )); - } - Ok(output) -} - -/// Same as `ordered_writer` but also returns total bytes written for the -/// length-committed cross-check on decrypt. -fn ordered_writer_counted( - done_rx: Receiver, - mut output: OutSink, - _cap: usize, + permit_tx: Sender<()>, ) -> Result<(OutSink, u64), FcryError> { let mut next: u32 = 0; let mut pending: BTreeMap> = BTreeMap::new(); @@ -359,7 +350,12 @@ fn ordered_writer_counted( while let Some(buf) = pending.remove(&next) { output.write_all(&buf)?; total = total.saturating_add(buf.len() as u64); - next = next.wrapping_add(1); + // `bump_counter` is expected to reject overflow upstream, so a wrap here is a + // bug; `+=` panics in debug. NOTE(review): confirm `last` never has u32::MAX. + next += 1; + // Release one in-flight slot. If the reader is gone the channel + // is closed; we don't care about the send result. + let _ = permit_tx.send(()); } } if !pending.is_empty() { @@ -370,8 +366,8 @@ fn ordered_writer_counted( Ok((output, total)) } -// Suppress unused warnings on Sender clones inside the worker loop; the -// channel is closed when its last sender is dropped which is what we want. +// Compile-time check that the job type is Send+Sync (channel sends across +// threads). Kept as a guard against future struct edits that break this. #[allow(dead_code)] fn _assert_send_sync() {} const _: fn() = || _assert_send_sync::>();