feat: parallelize unitig extraction and FASTA output

Replace the non-atomic `set_visited` with atomic `fetch_or` bitmask operations to enable thread-safe node claiming. Introduce a two-phase extraction pipeline where `par_for_each_chain_unitig` builds chains in parallel and `for_each_remaining_unitig` sequentially handles residual cycles and junctions. Add `is_start` and `collect_from_start` to explicitly define unitig boundaries. Wrap `BufWriter` in a `Mutex` and use an `AtomicUsize` counter to ensure thread-safe concurrent FASTA output, refactoring the write logic into a shared closure for safe multi-threaded execution.
This commit is contained in:
Eric Coissac
2026-06-05 10:27:19 +02:00
parent 249998beed
commit d202ead385
2 changed files with 138 additions and 17 deletions
+117 -6
View File
@@ -222,9 +222,7 @@ impl GraphDeBruijn {
pub fn set_visited(&self, kmer: CanonicalKmer) {
if let Some(a) = self.nodes.get(&kmer) {
let mut node = Node(a.load(Ordering::Relaxed));
node.set_visited();
a.store(node.0, Ordering::Relaxed);
a.fetch_or(0b0000_0100, Ordering::AcqRel);
}
}
@@ -296,9 +294,9 @@ impl GraphDeBruijn {
return None;
}
let mut updated = next_node;
updated.set_visited();
atomic.store(updated.0, Ordering::Relaxed);
if !try_claim(atomic) {
return None;
}
Some(oriented)
}
@@ -333,6 +331,113 @@ impl GraphDeBruijn {
self.nodes.extend(other.nodes);
}
/// Assemble the unitig that begins at `start`.
///
/// `node` carries the flags of `start`; only its right-extension information
/// (`can_extend_right`, `right_nuc`) is consulted here. The caller must already have
/// claimed `start`; every further node is claimed by this method (first successor)
/// or by `iter_unitig_kmers` (the rest of the chain).
///
/// The sequence starts with the `k` nucleotides of `start`. If `node` can extend
/// right and this call wins the visited bit of the successor, the successor's last
/// nucleotide is appended, followed by the last nucleotide of every k-mer yielded by
/// `iter_unitig_kmers`. If the successor is missing from `nodes` or was already
/// claimed, the unitig consists of `start` alone.
fn collect_from_start(&self, start: CanonicalKmer, node: Node, k: usize) -> Unitig {
    const VISITED: u8 = 0b0000_0100;

    // Seed the sequence with the start k-mer itself.
    let mut seq: Vec<u8> = Vec::new();
    for pos in 0..k {
        seq.push(start.nucleotide(pos));
    }

    // Resolve the first right neighbour; it is kept only if this call is the one
    // that flips its visited bit.
    let first_step = if node.can_extend_right() {
        let successor = start.into_kmer().push_right(node.right_nuc()).canonical();
        self.nodes
            .get(&successor)
            .filter(|flags| flags.fetch_or(VISITED, Ordering::AcqRel) & VISITED == 0)
            .map(|_| oriented_next(start.into_kmer(), successor))
    } else {
        None
    };

    if let Some(first) = first_step {
        // Each additional k-mer contributes exactly one new nucleotide: its last.
        seq.push(first.nucleotide(k - 1));
        seq.extend(
            self.iter_unitig_kmers(first)
                .map(|step| step.nucleotide(k - 1)),
        );
    }

    Unitig::from_nucleotides(&seq)
}
/// Decide whether `start` (with flags `node`) is the first node of a unitig.
///
/// The answer is `true` when:
/// - `node` has no unique left predecessor (`!can_extend_left`), or
/// - the left predecessor is present in `nodes` but cannot extend right, so no
///   chain traversal coming from the left can reach `start`.
///
/// If the computed predecessor has no entry in `nodes`, the result is `false`.
fn is_start(&self, start: CanonicalKmer, node: Node) -> bool {
    if node.can_extend_left() {
        let predecessor = start.into_kmer().push_left(node.left_nuc()).canonical();
        match self.nodes.get(&predecessor) {
            Some(flags) => !Node(flags.load(Ordering::Acquire)).can_extend_right(),
            None => false,
        }
    } else {
        true
    }
}
/// Call `f` once per unitig.
///
/// Uses the extended start definition: a node is a start if it has no unique
/// left predecessor, or if its left predecessor cannot extend right (so no chain
/// traversal from the left could ever claim it).
///
/// Two-step execution to avoid races between start claiming and chain extension:
/// 1. Claim all starts atomically (parallel).
/// 2. Build and emit chains from claimed starts (parallel).
///
/// Parallel in production builds; sequential in test builds.
/// Must be called before [`for_each_remaining_unitig`].
pub fn par_for_each_chain_unitig(&self, f: impl Fn(Unitig) + Sync) {
let k = k();
#[cfg(not(any(test, feature = "test-utils")))]
{
// Step 1 — claim all starts in parallel.
let starts: Vec<CanonicalKmer> = self.nodes
.par_iter()
.filter_map(|(&start, atomic)| {
let node = Node(atomic.load(Ordering::Acquire));
if node.is_visited() || !self.is_start(start, node) {
return None;
}
let old = atomic.fetch_or(0b0000_0100, Ordering::AcqRel);
(old & 0b0000_0100 == 0).then_some(start)
})
.collect();
// Step 2 — build chains in parallel.
starts.into_par_iter().for_each(|start| {
let node = Node(self.nodes[&start].load(Ordering::Acquire));
f(self.collect_from_start(start, node, k));
});
}
#[cfg(any(test, feature = "test-utils"))]
{
let starts: Vec<CanonicalKmer> = self.nodes
.iter()
.filter_map(|(&start, atomic)| {
let node = Node(atomic.load(Ordering::Acquire));
if node.is_visited() || !self.is_start(start, node) {
return None;
}
let old = atomic.fetch_or(0b0000_0100, Ordering::AcqRel);
(old & 0b0000_0100 == 0).then_some(start)
})
.collect();
for start in starts {
let node = Node(self.nodes[&start].load(Ordering::Acquire));
f(self.collect_from_start(start, node, k));
}
}
}
/// Call `f` for every node that is still unvisited once
/// [`par_for_each_chain_unitig`] has run.
///
/// Covers true cycles and rare deeply-nested junctions. Always sequential: nodes are
/// claimed one at a time, and the unitig grown from a claimed node is emitted before
/// the next node is examined, so nodes absorbed into that unitig are skipped later.
pub fn for_each_remaining_unitig(&self, f: impl Fn(Unitig)) {
    const VISITED: u8 = 0b0000_0100;
    let k = k();

    self.nodes
        .iter()
        .filter_map(|(&kmer, flags)| {
            // Claim unconditionally; the pre-claim value tells whether we won and
            // doubles as the node's flags for the extension.
            let before = flags.fetch_or(VISITED, Ordering::AcqRel);
            (before & VISITED == 0).then(|| (kmer, Node(before)))
        })
        .for_each(|(kmer, node)| f(self.collect_from_start(kmer, node, k)));
}
/// Returns the number of entries currently stored in `nodes`.
pub fn len(&self) -> usize {
self.nodes.len()
}
@@ -448,6 +553,12 @@ impl Iterator for UnitigIter<'_> {
// ── helpers ───────────────────────────────────────────────────────────────────
/// Claim a node by atomically OR-ing the visited bit into its flag byte.
///
/// Returns `true` iff the bit was clear before this call, i.e. this caller is the
/// one that claimed the node. All other bits of the byte are left untouched.
#[inline]
fn try_claim(atomic: &AtomicU8) -> bool {
    const VISITED_BIT: u8 = 0b0000_0100;
    let previous = atomic.fetch_or(VISITED_BIT, Ordering::AcqRel);
    previous & VISITED_BIT == 0
}
fn oriented_next(from: Kmer, to: CanonicalKmer) -> Kmer {
if from.is_overlapping(to.into_kmer()) {
to.into_kmer()
+21 -11
View File
@@ -1,5 +1,7 @@
use std::io::{self, BufWriter, Write};
use std::path::PathBuf;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use clap::Args;
use obidebruinj::GraphDeBruijn;
@@ -68,24 +70,32 @@ pub fn run(args: UnitigArgs) {
g.compute_degrees();
rep.push(stage.stop());
// ── Phase 3 : enumerate unitigs and write as FASTA ───────────────────────
let pb = spinner("unitig");
let stdout = io::stdout();
let mut out = BufWriter::new(stdout.lock());
// ── Phase 3 : enumerate unitigs and write as FASTA ───────────────────────
let pb = spinner("unitig");
let out = Mutex::new(BufWriter::new(io::stdout()));
let j = AtomicUsize::new(0);
let stage = Stage::start("enumerate unitigs");
for (j, unitig) in g.iter_unitig().enumerate() {
write_unitig(&unitig, k, 0, j, &mut out).unwrap_or_else(|e| {
let write_unitig_fn = |unitig: obikseq::unitig::Unitig| {
let idx = j.fetch_add(1, Ordering::Relaxed);
let mut w = out.lock().unwrap();
write_unitig(&unitig, k, 0, idx, &mut *w).unwrap_or_else(|e| {
eprintln!("write error: {e}");
std::process::exit(1);
});
if j % 10_000 == 0 {
pb.set_message(format!("{j} unitigs written"));
if idx % 10_000 == 0 {
pb.set_message(format!("{idx} unitigs written"));
}
}
};
let stage = Stage::start("chain unitigs");
g.par_for_each_chain_unitig(&write_unitig_fn);
rep.push(stage.stop());
let stage = Stage::start("remaining unitigs");
g.for_each_remaining_unitig(&write_unitig_fn);
pb.finish_and_clear();
rep.push(stage.stop());
out.flush().expect("flush error");
out.into_inner().unwrap().flush().expect("flush error");
rep.print();
}