refactor(debruijn): unify graph traversal with WalkState iterator

Replaces deeply nested branching with early returns and `then_some`. Introduces a cycle-detecting `find_chain_start` method and updates `UnitigNucIter` to use step-based iteration with atomic node claiming. This eliminates nested iterators and redundant state management, improving code readability and maintainability.
This commit is contained in:
Eric Coissac
2026-06-06 06:35:20 +02:00
parent 95b3461405
commit b39eee688a
2 changed files with 82 additions and 282 deletions
+78 -265
View File
@@ -1,11 +1,9 @@
//use ahash::RandomState;
use hashbrown::HashMap;
use obikseq::k;
use obikseq::{CanonicalKmer, Kmer, Sequence};
use obikseq::{CanonicalKmer, Sequence};
use rayon::prelude::*;
use std::fmt;
use std::mem::needs_drop;
use std::os::unix::raw::gid_t;
use std::sync::atomic::{AtomicU8, Ordering};
use xxhash_rust::xxh3::Xxh3Builder;
@@ -197,86 +195,36 @@ impl WalkState {
.leavable(graph)
}
pub fn walk(&self, graph: &GraphDeBruijn) -> Option<WalkState> {
pub fn walk(&self, graph: &GraphDeBruijn) -> Option<(WalkState, u8)> {
if self.direct {
if self.node.can_extend_right() {
let next = self.kmer.into_kmer().push_right(self.node.right_nuc());
let cnext = next.canonical();
let dnext = next.raw() == cnext.raw();
let next_node = Node(
graph
.nodes
.get(&cnext)
.unwrap()
.load(std::sync::atomic::Ordering::Relaxed),
);
if next_node.is_visited() {
None
} else {
if dnext {
if next_node.can_extend_left() {
Some(WalkState {
kmer: cnext,
node: next_node,
direct: dnext,
})
} else {
None
}
} else {
if next_node.can_extend_right() {
Some(WalkState {
kmer: cnext,
node: next_node,
direct: dnext,
})
} else {
None
}
}
} else {
None
if !self.node.can_extend_right() {
return None;
}
}
let nuc = self.node.right_nuc();
let next = self.kmer.into_kmer().push_right(nuc);
let cnext = next.canonical();
let dnext = next.raw() == cnext.raw();
let next_node = Node(graph.nodes.get(&cnext).unwrap().load(Ordering::Relaxed));
if next_node.is_visited() {
return None;
}
let reachable = if dnext { next_node.can_extend_left() } else { next_node.can_extend_right() };
reachable.then_some((WalkState { kmer: cnext, node: next_node, direct: dnext }, nuc))
} else {
if self.node.can_extend_left() {
let next = self.kmer.into_kmer().push_left(self.node.left_nuc());
let cnext = next.canonical();
let dnext = next.raw() == cnext.raw();
let next_node = Node(
graph
.nodes
.get(&cnext)
.unwrap()
.load(std::sync::atomic::Ordering::Relaxed),
);
if next_node.is_visited() {
None
} else {
if dnext {
if next_node.can_extend_right() {
Some(WalkState {
kmer: cnext,
node: next_node,
direct: dnext,
})
} else {
None
}
} else {
if next_node.can_extend_left() {
Some(WalkState {
kmer: cnext,
node: next_node,
direct: dnext,
})
} else {
None
}
}
} else {
None
if !self.node.can_extend_left() {
return None;
}
let nuc = self.node.left_nuc();
let next = self.kmer.into_kmer().push_left(nuc);
let cnext = next.canonical();
let dnext = next.raw() != cnext.raw();
let next_node = Node(graph.nodes.get(&cnext).unwrap().load(Ordering::Relaxed));
if next_node.is_visited() {
return None;
}
let reachable = if dnext { next_node.can_extend_right() } else { next_node.can_extend_left() };
reachable.then_some((WalkState { kmer: cnext, node: next_node, direct: dnext }, 3 - nuc))
}
}
}
@@ -347,7 +295,6 @@ impl GraphDeBruijn {
}
if self.is_start(*kmer, node) {
node.set_start();
node.set_visited();
atomic.store(node.0, Ordering::Relaxed);
}
});
@@ -387,98 +334,15 @@ impl GraphDeBruijn {
Some(WalkState::new(kmer, node, true))
}
pub fn walk(&self, step: WalkState) -> Option<WalkState> {
if !step.leavable(self) {
return None;
}
let node = step.node;
let kmer = step.kmer.into_kmer();
let n_kmer = if step.direct {
kmer.push_right(node.right_nuc())
} else {
kmer.push_left(node.left_nuc())
};
let n_ckmer = n_kmer.canonical();
let n_direct = n_ckmer.raw() == n_kmer.raw();
let n_node = if let Some(node_val) = self.nodes.get(&n_ckmer) {
Node(node_val.load(Ordering::Relaxed))
} else {
unreachable!()
};
if n_node.is_visited() {
return None;
}
Some(WalkState::new(n_ckmer, n_node, n_direct))
}
fn next_unitig_kmer(&self, kmer: Kmer) -> Option<Kmer> {
let canonical = kmer.canonical();
let node = Node(self.nodes.get(&canonical)?.load(Ordering::Relaxed));
let direct = kmer.raw() == canonical.raw();
if (direct && !node.can_extend_right()) || (!direct && !node.can_extend_left()) {
return None;
}
let next_c: CanonicalKmer = if direct {
canonical
.into_kmer()
.push_right(node.right_nuc())
.canonical()
} else {
canonical.into_kmer().push_left(node.left_nuc()).canonical()
};
let atomic = self.nodes.get(&next_c)?;
let next_node = Node(atomic.load(Ordering::Relaxed));
if next_node.is_visited() {
return None;
}
let oriented = oriented_next(kmer, next_c);
let ndirect = oriented.raw() == next_c.raw();
if (ndirect && next_node.n_right_neighbours() > 1)
|| (!ndirect && next_node.n_left_neighbours() > 1)
{
return None;
}
if !try_claim(atomic) {
return None;
}
Some(oriented)
}
fn iter_unitig_kmers(&self, start: Kmer) -> UnitigIter<'_> {
UnitigIter {
graph: self,
current: Some(start),
}
}
fn unitig_nucleotides(&self, start: CanonicalKmer, node: Node, k: usize) -> UnitigNucIter<'_> {
let chain = if node.can_extend_right() {
let next_c = start.into_kmer().push_right(node.right_nuc()).canonical();
self.nodes.get(&next_c).and_then(|next_a| {
let old = next_a.fetch_or(IS_VISITED_MASK, Ordering::AcqRel);
if old & IS_VISITED_MASK == 0 {
let oriented = oriented_next(start.into_kmer(), next_c);
Some(self.iter_unitig_kmers(oriented))
} else {
None
}
})
} else {
None
};
UnitigNucIter {
start,
pos: 0,
k,
chain,
}
/// Claims `kmer` and builds the nucleotide iterator that starts from it.
///
/// Returns `None` when `kmer` is not in `self.nodes` or when its visited bit
/// was already set before this call. Otherwise the visited bit is set
/// atomically and the returned iterator starts at position 0 of `kmer`,
/// with its first extension step (if any) already computed and claimed.
fn unitig_nucleotides(&self, kmer: CanonicalKmer, k: usize) -> Option<UnitigNucIter<'_>> {
// `fetch_or` returns the value held before the bit was set, so exactly one
// caller observes the bit as clear and wins the claim.
let old = self.nodes.get(&kmer)?.fetch_or(IS_VISITED_MASK, Ordering::AcqRel);
if old & IS_VISITED_MASK != 0 { return None; }
// The walk state is built from the pre-claim node value and walks in the
// direct orientation (`direct == true`).
let start = WalkState::new(kmer, Node(old), true);
// Pre-compute the first step and claim its target node; if the target was
// claimed in the meantime the iterator yields only the start k-mer.
let next_step = start.walk(self).and_then(|(next_state, nuc)| {
let ext_old = self.nodes.get(&next_state.kmer)?.fetch_or(IS_VISITED_MASK, Ordering::AcqRel);
(ext_old & IS_VISITED_MASK == 0).then_some((next_state, nuc))
});
Some(UnitigNucIter { graph: self, start: kmer, pos: 0, k, next_step })
}
pub fn for_each_unitig(&self, f: impl Fn(UnitigNucIter<'_>) + Sync) {
@@ -495,20 +359,22 @@ impl GraphDeBruijn {
self.nodes
.par_iter()
.filter_map(|(&kmer, atomic)| {
let node = Node(atomic.load(Ordering::Acquire));
node.is_start().then_some((kmer, node))
Node(atomic.load(Ordering::Acquire)).is_start().then_some(kmer)
})
.for_each(|(start, node)| {
n_new.fetch_add(1, Ordering::Relaxed);
f(self.unitig_nucleotides(start, node, k));
.for_each(|kmer| {
if let Some(iter) = self.unitig_nucleotides(kmer, k) {
n_new.fetch_add(1, Ordering::Relaxed);
f(iter);
}
});
#[cfg(any(test, feature = "test-utils"))]
self.nodes.iter().for_each(|(&kmer, atomic)| {
let node = Node(atomic.load(Ordering::Acquire));
if node.is_start() {
n_new.fetch_add(1, Ordering::Relaxed);
f(self.unitig_nucleotides(kmer, node, k));
if Node(atomic.load(Ordering::Acquire)).is_start() {
if let Some(iter) = self.unitig_nucleotides(kmer, k) {
n_new.fetch_add(1, Ordering::Relaxed);
f(iter);
}
}
});
@@ -528,25 +394,18 @@ impl GraphDeBruijn {
if node.is_visited() {
continue;
}
let start = if !node.can_extend_right() && node.can_extend_left() {
self.find_left_chain_start(kmer)
} else {
kmer
};
let start_atomic = &self.nodes[&start];
let old = start_atomic.fetch_or(IS_VISITED_MASK, Ordering::AcqRel);
if old & IS_VISITED_MASK == 0 {
let chain_start = self.find_chain_start(kmer);
if let Some(iter) = self.unitig_nucleotides(chain_start, k) {
n2.fetch_add(1, Ordering::Relaxed);
f(self.unitig_nucleotides(start, Node(old), k));
f(iter);
}
// Fallback: if kmer was not reached by start's chain, claim it directly.
// Safe because unitig_nucleotides(start, ...) may have visited kmer in the
// meantime — in that case fetch_or returns IS_VISITED_MASK set and we skip.
if start != kmer {
let kmer_old = atomic.fetch_or(IS_VISITED_MASK, Ordering::AcqRel);
if kmer_old & IS_VISITED_MASK == 0 {
// Safe because unitig_nucleotides may have visited kmer in the
// meantime — in that case it returns None and we skip.
if chain_start != kmer {
if let Some(iter) = self.unitig_nucleotides(kmer, k) {
n2.fetch_add(1, Ordering::Relaxed);
f(self.unitig_nucleotides(kmer, Node(kmer_old), k));
f(iter);
}
}
}
@@ -570,37 +429,21 @@ impl GraphDeBruijn {
self.nodes.extend(other.nodes);
}
/// Walks leftwards from `kmer` along its chain and returns the canonical
/// k-mer at which the walk stops, i.e. the start of the chain.
/// If no step can be taken, `kmer` itself is returned.
fn find_left_chain_start(&self, kmer: CanonicalKmer) -> CanonicalKmer {
let mut current = kmer;
fn find_chain_start(&self, kmer: CanonicalKmer) -> CanonicalKmer {
let node = Node(self.nodes[&kmer].load(Ordering::Acquire));
let mut state = WalkState::new(kmer, node, false);
let mut seen = std::collections::HashSet::new();
seen.insert((state.kmer.raw(), state.direct));
loop {
let node = Node(self.nodes[&current].load(Ordering::Acquire));
if !node.can_extend_left() {
return current;
match state.walk(self) {
None => return state.kmer,
Some((next, _)) => {
if !seen.insert((next.kmer.raw(), next.direct)) {
return kmer;
}
state = next;
}
}
let pred = current.into_kmer().push_left(node.left_nuc()).canonical();
let Some(pred_a) = self.nodes.get(&pred) else {
return current;
};
let pred_node = Node(pred_a.load(Ordering::Acquire));
if pred_node.is_visited() {
return current;
}
if !pred_node.can_extend_right() {
return current;
}
// Stop if asymmetry: pred's right canonical neighbor is not current
let pred_right = pred
.into_kmer()
.push_right(pred_node.right_nuc())
.canonical();
if pred_right != current {
return current;
}
current = pred;
}
}
@@ -640,30 +483,14 @@ impl GraphDeBruijn {
}
}
// ── UnitigIter ────────────────────────────────────────────────────────────────
struct UnitigIter<'a> {
graph: &'a GraphDeBruijn,
current: Option<Kmer>,
}
impl Iterator for UnitigIter<'_> {
type Item = Kmer;
fn next(&mut self) -> Option<Kmer> {
let current = self.current?;
self.current = self.graph.next_unitig_kmer(current);
Some(current)
}
}
// ── UnitigNucIter ─────────────────────────────────────────────────────────────
pub struct UnitigNucIter<'a> {
graph: &'a GraphDeBruijn,
start: CanonicalKmer,
pos: usize,
k: usize,
chain: Option<UnitigIter<'a>>,
next_step: Option<(WalkState, u8)>,
}
impl Iterator for UnitigNucIter<'_> {
@@ -674,33 +501,19 @@ impl Iterator for UnitigNucIter<'_> {
let nuc = self.start.nucleotide(self.pos);
self.pos += 1;
Some(nuc)
} else if let Some((state, nuc)) = self.next_step.take() {
self.next_step = state.walk(self.graph).and_then(|(next_state, next_nuc)| {
let old = self.graph.nodes.get(&next_state.kmer)?.fetch_or(IS_VISITED_MASK, Ordering::AcqRel);
(old & IS_VISITED_MASK == 0).then_some((next_state, next_nuc))
});
Some(nuc)
} else {
self.chain
.as_mut()?
.next()
.map(|kmer| kmer.nucleotide(self.k - 1))
None
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.k - self.pos, None)
}
}
// ── helpers ───────────────────────────────────────────────────────────────────
/// Atomically set the visited bit. Returns `true` iff this call claimed the node.
#[inline]
fn try_claim(atomic: &AtomicU8) -> bool {
atomic.fetch_or(IS_VISITED_MASK, Ordering::AcqRel) & IS_VISITED_MASK == 0
}
fn oriented_next(from: Kmer, to: CanonicalKmer) -> Kmer {
let direct = to.into_kmer();
if from.is_overlapping(direct) {
direct
} else {
to.revcomp()
(self.k - self.pos.min(self.k), None)
}
}
+4 -17
View File
@@ -1,5 +1,5 @@
use super::*;
use obikseq::{k, set_k, unitig::Unitig};
use obikseq::{k, set_k, unitig::Unitig, Kmer};
// Build a graph from an ASCII sequence, inserting all canonical k-mers.
fn graph_from_ascii(seq: &[u8]) -> GraphDeBruijn {
@@ -116,27 +116,14 @@ fn kmers_from_unitigs(unitigs: &[Unitig]) -> Vec<CanonicalKmer> {
#[test]
fn unitig_roundtrip_linear() {
// Non-repetitive sequence: no k-mer appears twice. Note that the leading AAAAA is a homopolymer run of length k.
// ACGTGGCTA with k=5 → 5 distinct k-mers forming a clean linear chain.
// AAAAAGGGC with k=5 → 5 distinct k-mers, all in direct canonical form,
// forming a clean linear chain with no orientation flips.
let k = 5;
set_k(k);
let seq = b"ACCTGGCTA";
let seq = b"AAAAAGGGC";
let g = graph_from_ascii(seq);
g.compute_degrees_and_mark_starts();
println!("Les kmers:");
for (kmer, v) in g.nodes.iter() {
println!(
"{}: {}",
String::from_utf8_lossy(&kmer.to_ascii()),
v.load(std::sync::atomic::Ordering::Relaxed)
);
}
println!("Les unitig:");
let unitigs: Vec<Unitig> = collect_unitigs(&g);
for unitig in &unitigs {
println!("{}", String::from_utf8_lossy(&unitig.to_ascii()));
}
assert_eq!(
unitigs.len(),
1,