diff --git a/src/obidebruinj/src/debruijn.rs b/src/obidebruinj/src/debruijn.rs index 1bef858..f96958c 100644 --- a/src/obidebruinj/src/debruijn.rs +++ b/src/obidebruinj/src/debruijn.rs @@ -222,9 +222,7 @@ impl GraphDeBruijn { pub fn set_visited(&self, kmer: CanonicalKmer) { if let Some(a) = self.nodes.get(&kmer) { - let mut node = Node(a.load(Ordering::Relaxed)); - node.set_visited(); - a.store(node.0, Ordering::Relaxed); + a.fetch_or(0b0000_0100, Ordering::AcqRel); } } @@ -296,9 +294,9 @@ impl GraphDeBruijn { return None; } - let mut updated = next_node; - updated.set_visited(); - atomic.store(updated.0, Ordering::Relaxed); + if !try_claim(atomic) { + return None; + } Some(oriented) } @@ -333,6 +331,113 @@ impl GraphDeBruijn { self.nodes.extend(other.nodes); } + /// Build one unitig starting at `start`, whose node flags are `node` (pre-claim value). + /// The caller has already atomically claimed `start`; this method claims subsequent nodes. + /// + /// The first right extension is always included (mirroring the original `start_iter` + /// behaviour). Branching checks for subsequent extensions are handled inside + /// `iter_unitig_kmers` → `next_unitig_kmer`. + fn collect_from_start(&self, start: CanonicalKmer, node: Node, k: usize) -> Unitig { + let mut nucs: Vec = (0..k).map(|i| start.nucleotide(i)).collect(); + if node.can_extend_right() { + let next_c = start.into_kmer().push_right(node.right_nuc()).canonical(); + if let Some(next_a) = self.nodes.get(&next_c) { + let old = next_a.fetch_or(0b0000_0100, Ordering::AcqRel); + if old & 0b0000_0100 == 0 { + let oriented = oriented_next(start.into_kmer(), next_c); + nucs.push(oriented.nucleotide(k - 1)); + for kmer in self.iter_unitig_kmers(oriented) { + nucs.push(kmer.nucleotide(k - 1)); + } + } + } + } + Unitig::from_nucleotides(&nucs) + } + + /// Returns `true` if `start` is a unitig start node: + /// - no unique left predecessor (`!can_extend_left`), or + /// - unique left predecessor exists but cannot extend right + /// (i.e., no chain traversal from the left can reach `start`). + fn is_start(&self, start: CanonicalKmer, node: Node) -> bool { + if !node.can_extend_left() { + return true; + } + let pred = start.into_kmer().push_left(node.left_nuc()).canonical(); + self.nodes.get(&pred) + .map(|a| !Node(a.load(Ordering::Acquire)).can_extend_right()) + .unwrap_or(false) + } + + /// Call `f` once per unitig. + /// + /// Uses the extended start definition: a node is a start if it has no unique + /// left predecessor, or if its left predecessor cannot extend right (so no chain + /// traversal from the left could ever claim it). + /// + /// Two-step execution to avoid races between start claiming and chain extension: + /// 1. Claim all starts atomically (parallel). + /// 2. Build and emit chains from claimed starts (parallel). + /// + /// Parallel in production builds; sequential in test builds. + /// Must be called before [`for_each_remaining_unitig`]. + pub fn par_for_each_chain_unitig(&self, f: impl Fn(Unitig) + Sync) { + let k = k(); + + #[cfg(not(any(test, feature = "test-utils")))] + { + // Step 1 — claim all starts in parallel. + let starts: Vec = self.nodes + .par_iter() + .filter_map(|(&start, atomic)| { + let node = Node(atomic.load(Ordering::Acquire)); + if node.is_visited() || !self.is_start(start, node) { + return None; + } + let old = atomic.fetch_or(0b0000_0100, Ordering::AcqRel); + (old & 0b0000_0100 == 0).then_some(start) + }) + .collect(); + + // Step 2 — build chains in parallel. + starts.into_par_iter().for_each(|start| { + let node = Node(self.nodes[&start].load(Ordering::Acquire)); + f(self.collect_from_start(start, node, k)); + }); + } + + #[cfg(any(test, feature = "test-utils"))] + { + let starts: Vec = self.nodes + .iter() + .filter_map(|(&start, atomic)| { + let node = Node(atomic.load(Ordering::Acquire)); + if node.is_visited() || !self.is_start(start, node) { + return None; + } + let old = atomic.fetch_or(0b0000_0100, Ordering::AcqRel); + (old & 0b0000_0100 == 0).then_some(start) + }) + .collect(); + + for start in starts { + let node = Node(self.nodes[&start].load(Ordering::Acquire)); + f(self.collect_from_start(start, node, k)); + } + } + } + + /// Call `f` for each node still unvisited after [`par_for_each_chain_unitig`]. + /// Handles true cycles and rare deeply-nested junctions. Always sequential. + pub fn for_each_remaining_unitig(&self, f: impl Fn(Unitig)) { + let k = k(); + for (&kmer, atomic) in &self.nodes { + let old = atomic.fetch_or(0b0000_0100, Ordering::AcqRel); + if old & 0b0000_0100 != 0 { continue; } + f(self.collect_from_start(kmer, Node(old), k)); + } + } + pub fn len(&self) -> usize { self.nodes.len() } @@ -448,6 +553,12 @@ impl Iterator for UnitigIter<'_> { // ── helpers ─────────────────────────────────────────────────────────────────── +/// Atomically set the visited bit. Returns `true` iff this call claimed the node. +#[inline] +fn try_claim(atomic: &AtomicU8) -> bool { + atomic.fetch_or(0b0000_0100, Ordering::AcqRel) & 0b0000_0100 == 0 +} + fn oriented_next(from: Kmer, to: CanonicalKmer) -> Kmer { if from.is_overlapping(to.into_kmer()) { to.into_kmer() diff --git a/src/obikmer/src/cmd/unitig.rs b/src/obikmer/src/cmd/unitig.rs index 4a81703..79f6ed2 100644 --- a/src/obikmer/src/cmd/unitig.rs +++ b/src/obikmer/src/cmd/unitig.rs @@ -1,5 +1,7 @@ use std::io::{self, BufWriter, Write}; use std::path::PathBuf; +use std::sync::Mutex; +use std::sync::atomic::{AtomicUsize, Ordering}; use clap::Args; use obidebruinj::GraphDeBruijn; @@ -68,24 +70,32 @@ pub fn run(args: UnitigArgs) { g.compute_degrees(); rep.push(stage.stop()); - // ── Phase 3 : enumerate unitigs and write as FASTA ──────────────────────── - let pb = spinner("unitig"); - let stdout = io::stdout(); - let mut out = BufWriter::new(stdout.lock()); + // ── Phase 3 : enumerate unitigs and write as FASTA ─────────────────────── + let pb = spinner("unitig"); + let out = Mutex::new(BufWriter::new(io::stdout())); + let j = AtomicUsize::new(0); - let stage = Stage::start("enumerate unitigs"); - for (j, unitig) in g.iter_unitig().enumerate() { - write_unitig(&unitig, k, 0, j, &mut out).unwrap_or_else(|e| { + let write_unitig_fn = |unitig: obikseq::unitig::Unitig| { + let idx = j.fetch_add(1, Ordering::Relaxed); + let mut w = out.lock().unwrap(); + write_unitig(&unitig, k, 0, idx, &mut *w).unwrap_or_else(|e| { eprintln!("write error: {e}"); std::process::exit(1); }); - if j % 10_000 == 0 { - pb.set_message(format!("{j} unitigs written")); + if idx % 10_000 == 0 { + pb.set_message(format!("{idx} unitigs written")); } - } + }; + + let stage = Stage::start("chain unitigs"); + g.par_for_each_chain_unitig(&write_unitig_fn); + rep.push(stage.stop()); + + let stage = Stage::start("remaining unitigs"); + g.for_each_remaining_unitig(&write_unitig_fn); pb.finish_and_clear(); rep.push(stage.stop()); - out.flush().expect("flush error"); + out.into_inner().unwrap().flush().expect("flush error"); rep.print(); }