diff --git a/src/Cargo.lock b/src/Cargo.lock index 3108163..39706b2 100644 --- a/src/Cargo.lock +++ b/src/Cargo.lock @@ -1486,6 +1486,7 @@ name = "obidebruinj" version = "0.1.0" dependencies = [ "ahash", + "crossbeam-channel", "hashbrown 0.14.5", "obifastwrite", "obikseq", diff --git a/src/obidebruinj/Cargo.toml b/src/obidebruinj/Cargo.toml index acbc742..feb5848 100644 --- a/src/obidebruinj/Cargo.toml +++ b/src/obidebruinj/Cargo.toml @@ -8,7 +8,8 @@ obikseq = { path = "../obikseq" } obifastwrite = { path = "../obifastwrite" } ahash = "0.8" hashbrown = { version = "0.14", features = ["rayon"] } -rayon = "1" +rayon = "1" +crossbeam-channel = "0.5" xxhash-rust = { version = "0.8.15", features = ["xxh3", "const_xxh3"] } tracing = "0.1" diff --git a/src/obidebruinj/src/debruijn.rs b/src/obidebruinj/src/debruijn.rs index 865567e..ed52c4d 100644 --- a/src/obidebruinj/src/debruijn.rs +++ b/src/obidebruinj/src/debruijn.rs @@ -5,6 +5,7 @@ use obikseq::{CanonicalKmer, Sequence}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use std::cell::RefCell; use std::fmt; +use crossbeam_channel; use std::sync::atomic::{AtomicU8, Ordering}; use xxhash_rust::xxh3::Xxh3Builder; use std::time::Instant; @@ -454,24 +455,30 @@ impl GraphDeBruijn { F: FnMut(&[u8]) -> Result<(), E> + Send, { thread_local! { - static BUF: std::cell::RefCell> = RefCell::new(Vec::with_capacity(4096)); + static BUF: RefCell> = RefCell::new(Vec::with_capacity(4096)); } - let error = std::sync::Mutex::new(None::); - let f = std::sync::Mutex::new(f); - self.for_each_unitig(|iter| { - if error.lock().unwrap().is_some() { - return; - } - BUF.with(|buf| { - let mut buf = buf.borrow_mut(); - buf.clear(); - buf.extend(iter); - if let Err(e) = f.lock().unwrap()(&buf) { - *error.lock().unwrap() = Some(e); + let (tx, rx) = crossbeam_channel::bounded::>(rayon::current_num_threads() * 256); + std::thread::scope(|s| { + let writer = s.spawn(move || -> Result<(), E> { + let mut f = f; + for nucs in rx { + f(&nucs)?; } + Ok(()) }); - }); - error.into_inner().unwrap().map_or(Ok(()), Err) + self.for_each_unitig(|iter| { + BUF.with(|buf| { + let mut buf = buf.borrow_mut(); + buf.clear(); + buf.extend(iter); + let to_send = buf.clone(); + buf.clear(); + tx.send(to_send).ok(); + }); + }); + drop(tx); + writer.join().expect("writer thread panicked") + }) } pub fn len(&self) -> usize {