From 8c17bf958b643b87fc4677c634382961f4157d85 Mon Sep 17 00:00:00 2001 From: Eric Coissac Date: Tue, 5 May 2026 18:08:19 +0200 Subject: [PATCH] refactor: centralize k-mer config and introduce packed sequences Centralize k-mer and minimizer configuration using a thread-safe global module, and replace manual bit-packing with a memory-efficient `PackedSeq` type. Refactor core sequence and k-mer types to use compile-time length enforcement and centralized hashing. Introduce a new De Bruijn graph implementation with compact node encoding and traversal iterators. Update I/O, partitioning, and builder modules to align with the new architecture, and add the `xxhash-rust` dependency. --- CLAUDE.md | 3 + src/Cargo.lock | 1 + src/obidebruinj/Cargo.toml | 4 + src/obidebruinj/src/debruijn.rs | 573 +++++++++++++++++ src/obidebruinj/src/lib.rs | 891 +------------------------- src/obidebruinj/src/tests/debruijn.rs | 301 +++++++++ src/obifastwrite/Cargo.toml | 3 + src/obifastwrite/src/lib.rs | 31 +- src/obikmer/src/cmd/fasta.rs | 2 +- src/obikmer/src/cmd/longtig.rs | 8 +- src/obikmer/src/cmd/partition.rs | 6 +- src/obikmer/src/cmd/superkmer.rs | 4 +- src/obikmer/src/cmd/unitig.rs | 8 +- src/obikpartitionner/src/partition.rs | 39 +- src/obikseq/Cargo.toml | 5 + src/obikseq/src/annotations.rs | 8 + src/obikseq/src/kmer.rs | 494 +++++++------- src/obikseq/src/lib.rs | 6 +- src/obikseq/src/packed_seq.rs | 361 +++++++++++ src/obikseq/src/params.rs | 102 +++ src/obikseq/src/routable.rs | 82 ++- src/obikseq/src/sequence.rs | 71 +- src/obikseq/src/superkmer.rs | 445 ++++--------- src/obikseq/src/tests/kmer.rs | 213 ++++++ src/obikseq/src/tests/superkmer.rs | 410 ++++-------- src/obikseq/src/tests/unitig.rs | 171 +++++ src/obikseq/src/unitig.rs | 426 +----------- src/obiskbuilder/Cargo.toml | 3 + src/obiskbuilder/src/entropy_table.rs | 9 +- src/obiskbuilder/src/iter.rs | 45 +- src/obiskbuilder/src/lib.rs | 9 +- src/obiskbuilder/src/rolling_stat.rs | 172 ++--- src/obiskbuilder/src/scratch.rs 
| 4 +- src/obiskio/Cargo.toml | 1 + src/obiskio/src/codec.rs | 73 +-- src/obiskio/src/pool.rs | 79 +-- src/obiskio/src/reader.rs | 34 +- 37 files changed, 2641 insertions(+), 2456 deletions(-) create mode 100644 src/obidebruinj/src/debruijn.rs create mode 100644 src/obidebruinj/src/tests/debruijn.rs create mode 100644 src/obikseq/src/packed_seq.rs create mode 100644 src/obikseq/src/params.rs create mode 100644 src/obikseq/src/tests/kmer.rs create mode 100644 src/obikseq/src/tests/unitig.rs diff --git a/CLAUDE.md b/CLAUDE.md index 9c869aa..de49776 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -2,6 +2,9 @@ Tu es ma base de connaissance et mon bloc-notes intelligent sur le projet **obikmer**. Tu ne proposes pas, tu ne codes pas spontanément — tu réponds à mes questions et tu structures mes idées au fur et à mesure que je les exprime. +**Règle absolue : une question appelle une réponse, pas une action.** +Ne modifier aucun fichier à moins d'une demande explicite de modification. En particulier : observer un bug ou une incohérence dans le code montré ne constitue pas un mandat pour le corriger. Le code montré peut refléter une intention en cours — modifier sans mandat risque d'introduire un vrai bug là où tu croyais corriger. 
+ Tu maintiens en **anglais**, dense et sans remplissage, les documents suivants : - `docmd/index.md` — document de discussion de base, enrichi progressivement au fil de nos échanges ; il reflète l'état courant de la réflexion sur le projet - les autres fichiers Markdown dans `docmd/` selon leur thème respectif diff --git a/src/Cargo.lock b/src/Cargo.lock index 565f931..19eb24f 100644 --- a/src/Cargo.lock +++ b/src/Cargo.lock @@ -1575,6 +1575,7 @@ dependencies = [ "hashbrown 0.14.5", "obifastwrite", "obikseq", + "xxhash-rust", ] [[package]] diff --git a/src/obidebruinj/Cargo.toml b/src/obidebruinj/Cargo.toml index 40c7da6..5fb8273 100644 --- a/src/obidebruinj/Cargo.toml +++ b/src/obidebruinj/Cargo.toml @@ -8,3 +8,7 @@ obikseq = { path = "../obikseq" } obifastwrite = { path = "../obifastwrite" } ahash = "0.8" hashbrown = "0.14" +xxhash-rust = { version = "0.8.15", features = ["xxh3", "const_xxh3"] } + +[dev-dependencies] +obikseq = { path = "../obikseq", features = ["test-utils"] } diff --git a/src/obidebruinj/src/debruijn.rs b/src/obidebruinj/src/debruijn.rs new file mode 100644 index 0000000..c76af41 --- /dev/null +++ b/src/obidebruinj/src/debruijn.rs @@ -0,0 +1,573 @@ +//use ahash::RandomState; +use hashbrown::HashMap; +use obifastwrite::write_unitig; +use obikseq::k; +use obikseq::unitig::Unitig; +use obikseq::{CanonicalKmer, Kmer, Sequence}; +use std::cell::Cell; +use std::fmt; +use std::io; +use xxhash_rust::xxh3::Xxh3Builder; + +// ── Types ───────────────────────────────────────────────────────────────────── + +type FastHashMap = HashMap; + +// ── Node ────────────────────────────────────────────────────────────────────── +// +// bit layout (LSB first): +// bit 0 : can_extend_right — exactly one right canonical neighbour exists +// bit 1 : can_extend_left — exactly one left canonical neighbour exists +// bit 2 : visited +// bits 3–4 : right_nuc — index 0–3 (A/C/G/T) of that neighbour; valid iff bit 0 = 1 +// bits 5–6 : left_nuc — index 0–3 (A/C/G/T) of that 
neighbour; valid iff bit 1 = 1 +// bit 7 : reserved (0) +// +// "can_extend" = false covers both 0 neighbours and ≥2 neighbours; the only +// information needed for traversal is "exactly one". + +#[repr(transparent)] +#[derive(Debug, Clone, Copy, Default)] +pub struct Node(u8); + +impl Node { + /// Returns `true` if the node can be extended to the right. + /// + /// A single right neighbour exists. + pub fn can_extend_right(self) -> bool { + self.0 & 0b0000_0001 != 0 + } + + /// Returns `true` if the node can be extended to the left. + /// + /// A single left neighbour exists. + pub fn can_extend_left(self) -> bool { + self.0 & 0b0000_0010 != 0 + } + + /// Returns `true` if the node has been visited. + pub fn is_visited(self) -> bool { + self.0 & 0b0000_0100 != 0 + } + + /// Index of the unique right neighbour (0=A, 1=C, 2=G, 3=T). + /// Only meaningful when `can_extend_right()` is true. + pub fn right_nuc(self) -> u8 { + (self.0 >> 3) & 0b11 + } + + /// Index of the unique left neighbour (0=A, 1=C, 2=G, 3=T). + /// Only meaningful when `can_extend_left()` is true. + pub fn left_nuc(self) -> u8 { + (self.0 >> 5) & 0b11 + } + + /// Marks the node as visited. + pub fn set_visited(&mut self) { + if self.is_visited() { + unreachable!("from: is_visited -> The node has already been visited") + } + self.0 |= 0b0000_0100; + } + + /// Number of left neighbours. + pub fn n_left_neighbours(self) -> u8 { + if self.can_extend_left() { + 1 + } else { + let v = (self.0 >> 5) & 0b11; + v + (v != 0) as u8 + } + } + + /// Number of right neighbours. + pub fn n_right_neighbours(self) -> u8 { + if self.can_extend_right() { + 1 + } else { + let v = (self.0 >> 3) & 0b11; + v + (v != 0) as u8 + } + } + + /// `nuc` = Some(i) → exactly one neighbour (bit 0 set, bits 3–4 = nucleotide index). + /// `nuc` = None → 0 or ≥2 neighbours; `count` encoded in bits 3–4 as count.sat_sub(1). 
+ pub fn set_right(&mut self, count: u8, nuc: Option) { + self.0 &= !(0b0000_0001 | 0b001_1000); + if count == 1 { + self.0 |= 0b0000_0001; + if let Some(n) = nuc { + self.0 |= (n & 0b11) << 3; + return; + } + unreachable!("nuc must be Some when count is 1"); + } + self.0 |= (count.saturating_sub(1).min(3)) << 3; + } + + /// `nuc` = Some(i) → exactly one neighbour (bit 0 set, bits 3–4 = nucleotide index). + /// `nuc` = None → 0 or ≥2 neighbours; `count` encoded in bits 3–4 as count.sat_sub(1). + pub fn set_left(&mut self, count: u8, nuc: Option) { + self.0 &= !(0b0000_0010 | 0b0110_0000); + if count == 1 { + self.0 |= 0b0000_0010; + if let Some(n) = nuc { + self.0 |= (n & 0b11) << 5; + return; + } + unreachable!("nuc must be Some when count is 1"); + } + self.0 |= (count.saturating_sub(1).min(3)) << 5; + } +} + +impl fmt::Display for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + const NUC: [char; 4] = ['A', 'C', 'G', 'T']; + let r = if self.can_extend_right() { + format!("→{}", NUC[self.right_nuc() as usize]) + } else { + format!("→{}", self.n_right_neighbours()) + }; + let l = if self.can_extend_left() { + format!("←{}", NUC[self.left_nuc() as usize]) + } else { + format!("←{}", self.n_left_neighbours()) + }; + let v = if self.is_visited() { "V" } else { "." }; + write!(f, "Node({r} {l} {v})") + } +} + +// ── GraphDeBruijn ───────────────────────────────────────────────────────────── + +pub struct GraphDeBruijn { + nodes: FastHashMap>, +} + +impl GraphDeBruijn { + pub fn new() -> Self { + Self { + nodes: FastHashMap::with_hasher(Xxh3Builder::new()), + } + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + nodes: FastHashMap::with_capacity_and_hasher(capacity, Xxh3Builder::new()), + } + } + + /// Insert a canonical kmer into the graph. No-op if already present. 
+ pub fn push(&mut self, kmer: CanonicalKmer) { + self.nodes + .entry(kmer) + .or_insert_with(|| Cell::new(Node::default())); + } + + /// For every node, find its unique right/left canonical neighbour (if any) + /// and store the nucleotide index in the Node flags. + /// + /// Single pass thanks to Cell interior mutability. + pub fn compute_degrees(&self) { + for (&kmer, cell) in &self.nodes { + let (rc, rn) = count_neighbors(kmer.right_canonical_neighbors(), &self.nodes); + let (lc, ln) = count_neighbors(kmer.left_canonical_neighbors(), &self.nodes); + + let mut node = cell.get(); + node.set_right(rc, rn); + node.set_left(lc, ln); + cell.set(node); + } + } + + /// Iterates over the right neighbors of `kmer`. + pub fn iter_right_neighbors( + &self, + kmer: CanonicalKmer, + ) -> impl Iterator + '_ { + kmer.right_canonical_neighbors() + .into_iter() + .filter_map(|kmer| { + self.nodes.get(&kmer)?; + Some(kmer) + }) + } + + /// Iterates over the left neighbors of `kmer`. + pub fn iter_left_neighbors( + &self, + kmer: CanonicalKmer, + ) -> impl Iterator + '_ { + kmer.left_canonical_neighbors() + .into_iter() + .filter_map(|kmer| { + self.nodes.get(&kmer)?; + Some(kmer) + }) + } + + pub fn is_visited(&self, kmer: &CanonicalKmer) -> Option { + self.nodes.get(kmer).map(|cell| cell.get().is_visited()) + } + + pub fn set_visited(&self, kmer: CanonicalKmer) { + if let Some(cell) = self.nodes.get(&kmer) { + let mut node = cell.get(); + node.set_visited(); + cell.set(node); + } + } + + /// Returns the single right neighbor of `kmer`, if it exists. + pub fn the_single_right_neighbor(&self, kmer: CanonicalKmer) -> Option { + let node = self.nodes.get(&kmer)?.get(); + if !node.can_extend_right() { + return None; + } + let next = kmer.into_kmer().push_right(node.right_nuc()).canonical(); + self.nodes.contains_key(&next).then_some(next) + } + + /// Returns the single left neighbor of `kmer`, if it exists. 
+ pub fn the_single_left_neighbor(&self, kmer: CanonicalKmer) -> Option { + let node = self.nodes.get(&kmer)?.get(); + if !node.can_extend_left() { + return None; + } + let next = kmer.into_kmer().push_left(node.left_nuc()).canonical(); + self.nodes.contains_key(&next).then_some(next) + } + + /// Internal iterator over unitig-start nodes; drives `iter_unitig`. + /// + /// MUST NOT be consumed standalone: the second pass finds cycle nodes only + /// because `iter_unitig` lazily interleaves chain traversal between the two passes. + /// + /// Two passes: + /// 1. Chain ends / isolated nodes (at most one extension missing): + /// - `!can_extend_left` → yield canonical form + /// - `!can_extend_right` → yield reverse complement + /// 2. Nodes still unvisited → part of a cycle; yield canonical form. + fn start_iter(&self) -> impl Iterator)> + '_ { + StartIter::new(self) + } + + fn next_unitig_kmer(&self, kmer: Kmer) -> Option { + let canonical = kmer.canonical(); + let node = self.nodes.get(&canonical)?.get(); + + let direct = kmer.raw() == canonical.raw(); + + if (direct && !node.can_extend_right()) || (!direct && !node.can_extend_left()) { + return None; + } + + let next_c: CanonicalKmer = if direct { + canonical + .into_kmer() + .push_right(node.right_nuc()) + .canonical() + } else { + canonical.into_kmer().push_left(node.left_nuc()).canonical() + }; + + let cell = self.nodes.get(&next_c)?; + let next_node = cell.get(); + if next_node.is_visited() { + return None; + } + + let oriented = oriented_next(kmer, next_c); + let ndirect = oriented.raw() == next_c.raw(); + + if (ndirect && next_node.n_right_neighbours() > 1) + || (!ndirect && next_node.n_left_neighbours() > 1) + { + return None; + } + + let mut updated = next_node; + updated.set_visited(); + cell.set(updated); + Some(oriented) + } + + fn next_longtig_kmer(&self, kmer: Kmer) -> Option { + let canonical = kmer.canonical(); + let node = self.nodes.get(&canonical)?.get(); + + let direct = kmer.raw() == 
canonical.raw(); + + if (direct && node.n_right_neighbours() == 0) || (!direct && node.n_left_neighbours() == 0) + { + return None; + } + + let next_c: CanonicalKmer = if direct { + if node.can_extend_right() { + canonical + .into_kmer() + .push_right(node.right_nuc()) + .canonical() + } else { + self.iter_right_neighbors(canonical) + .filter(|n| !self.is_visited(n).unwrap_or(true)) + .next()? + } + } else { + if node.can_extend_left() { + canonical.into_kmer().push_left(node.left_nuc()).canonical() + } else { + self.iter_left_neighbors(canonical) + .filter(|n| !self.is_visited(n).unwrap_or(true)) + .next()? + } + }; + + let cell = self.nodes.get(&next_c)?; + let next_node = cell.get(); + if next_node.is_visited() { + return None; + } + + let oriented = oriented_next(kmer, next_c); + let ndirect = oriented.raw() == next_c.raw(); + + if (ndirect && next_node.n_right_neighbours() > 1) + || (!ndirect && next_node.n_left_neighbours() > 1) + { + return None; + } + + let mut updated = next_node; + updated.set_visited(); + cell.set(updated); + Some(oriented) + } + + fn iter_unitig_kmers(&self, start: Kmer) -> UnitigIter<'_> { + UnitigIter { + graph: self, + current: Some(start), + } + } + + fn iter_longtig_kmers(&self, start: Kmer) -> LongtigIter<'_> { + LongtigIter { + graph: self, + current: Some(start), + } + } + + pub fn iter_unitig(&self) -> impl Iterator + '_ { + let k = k(); + self.start_iter().map(move |(start, first_next)| { + let mut nucs: Vec = (0..k).map(|i| start.nucleotide(i)).collect(); + if let Some(next_c) = first_next { + for kmer in self.iter_unitig_kmers(next_c) { + nucs.push(kmer.nucleotide(k - 1)); + } + } + Unitig::from_nucleotides(&nucs) + }) + } + + pub fn iter_longtig(&self) -> impl Iterator + '_ { + let k = k(); + self.start_iter().map(move |(start, first_next)| { + let mut nucs: Vec = (0..k).map(|i| start.nucleotide(i)).collect(); + if let Some(next_c) = first_next { + for kmer in self.iter_longtig_kmers(next_c) { + nucs.push(kmer.nucleotide(k 
- 1)); + } + } + Unitig::from_nucleotides(&nucs) + }) + } + + /// Write all unitigs to `out` in FASTA format. + /// + /// Calls [`obifastwrite::write_unitig`] for each unitig produced by + /// [`iter_unitig`]. Stops and returns the first I/O error encountered. + pub fn write_fasta(&self, out: &mut W, unitig: bool) -> io::Result<()> { + if unitig { + for unitig in self.iter_unitig() { + write_unitig(&unitig, k(), out)?; + } + } else { + for unitig in self.iter_longtig() { + write_unitig(&unitig, k(), out)?; + } + } + Ok(()) + } + + pub fn len(&self) -> usize { + self.nodes.len() + } + + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } +} + +// --- StartIter ----------------------------------------------------------------- +struct StartIter<'a> { + graph: &'a GraphDeBruijn, + nodes: hashbrown::hash_map::Iter<'a, CanonicalKmer, Cell>, + suspended: Vec, + in_cycle_pass: bool, +} + +impl<'a> StartIter<'a> { + fn new(graph: &'a GraphDeBruijn) -> Self { + Self { + graph, + nodes: graph.nodes.iter(), + suspended: Vec::new(), + in_cycle_pass: false, + } + } +} + +impl<'a> Iterator for StartIter<'a> { + type Item = (CanonicalKmer, Option); + + fn next(&mut self) -> Option<(CanonicalKmer, Option)> { + loop { + let current = if let Some(k) = self.suspended.pop() { + k + } else { + match self.nodes.next() { + Some((&k, _)) => k, + None => { + if self.in_cycle_pass { + return None; + } + self.in_cycle_pass = true; + self.nodes = self.graph.nodes.iter(); + match self.nodes.next() { + Some((&k, _)) => k, + None => return None, + } + } + } + }; + + let node = match self.graph.nodes.get(¤t) { + Some(c) => c.get(), + None => continue, + }; + if node.is_visited() { + continue; + } + if !self.in_cycle_pass && node.can_extend_left() { + continue; + } + + self.graph.set_visited(current); + + if let Some(next) = self.graph.the_single_right_neighbor(current) { + if self.graph.is_visited(&next).unwrap_or(true) { + return Some((current, None)); + } + self.graph.set_visited(next); 
+ let oriented = oriented_next(current.into_kmer(), next); + return Some((current, Some(oriented))); + } + + let mut first_neighbor: Option = None; + for neighbor in self.graph.iter_right_neighbors(current) { + if self.graph.is_visited(&neighbor).unwrap_or(true) { + continue; + } + if first_neighbor.is_none() { + self.graph.set_visited(neighbor); + first_neighbor = Some(neighbor); + } else { + self.suspended.push(neighbor); + } + } + + let oriented = match first_neighbor { + Some(neighbor) => Some(oriented_next(current.into_kmer(), neighbor)), + None => None, + }; + return Some((current, oriented)); + } + } +} + +// ── UnitigIter ──────────────────────────────────────────────────────────────── + +struct UnitigIter<'a> { + graph: &'a GraphDeBruijn, + current: Option, +} + +impl Iterator for UnitigIter<'_> { + type Item = Kmer; + + fn next(&mut self) -> Option { + let current = self.current?; + self.current = self.graph.next_unitig_kmer(current); + Some(current) + } +} + +// ── UnitigIter ──────────────────────────────────────────────────────────────── + +struct LongtigIter<'a> { + graph: &'a GraphDeBruijn, + current: Option, +} + +impl Iterator for LongtigIter<'_> { + type Item = Kmer; + + fn next(&mut self) -> Option { + let current = self.current?; + self.current = self.graph.next_longtig_kmer(current); + Some(current) + } +} + +// ── helpers ─────────────────────────────────────────────────────────────────── + +fn oriented_next(from: Kmer, to: CanonicalKmer) -> Kmer { + if from.is_overlapping(to.into_kmer()) { + to.into_kmer() + } else { + to.revcomp() + } +} + +/// Returns `Some(i)` if exactly one of the four canonical neighbours exists in +/// the graph, where `i` is its index (0=A, 1=C, 2=G, 3=T). Returns `None` for +/// zero or ≥2 existing neighbours. 
+fn count_neighbors( + neighbors: [CanonicalKmer; 4], + nodes: &FastHashMap>, +) -> (u8, Option) { + let mut count = 0u8; + let mut first = None; + for (i, neighbour) in neighbors.iter().enumerate() { + if nodes.contains_key(neighbour) { + count += 1; + if first.is_none() { + first = Some(i as u8); + } + } + } + if count == 1 { + (1, first) + } else { + (count, None) + } +} + +// ── tests ───────────────────────────────────────────────────────────────────── +#[cfg(test)] +#[path = "tests/debruijn.rs"] +mod tests; diff --git a/src/obidebruinj/src/lib.rs b/src/obidebruinj/src/lib.rs index 58eed58..7652e5e 100644 --- a/src/obidebruinj/src/lib.rs +++ b/src/obidebruinj/src/lib.rs @@ -1,890 +1,3 @@ -use ahash::RandomState; -use hashbrown::HashMap; -use obifastwrite::write_unitig; -use obikseq::kmer::{self, CanonicalKmer, Kmer}; -use obikseq::unitig::Unitig; -use std::cell::Cell; -use std::fmt; -use std::io; +mod debruijn; -// ── Types ───────────────────────────────────────────────────────────────────── - -type FastHashMap = HashMap; - -// ── Node ────────────────────────────────────────────────────────────────────── -// -// bit layout (LSB first): -// bit 0 : can_extend_right — exactly one right canonical neighbour exists -// bit 1 : can_extend_left — exactly one left canonical neighbour exists -// bit 2 : visited -// bits 3–4 : right_nuc — index 0–3 (A/C/G/T) of that neighbour; valid iff bit 0 = 1 -// bits 5–6 : left_nuc — index 0–3 (A/C/G/T) of that neighbour; valid iff bit 1 = 1 -// bit 7 : reserved (0) -// -// "can_extend" = false covers both 0 neighbours and ≥2 neighbours; the only -// information needed for traversal is "exactly one". - -#[repr(transparent)] -#[derive(Debug, Clone, Copy, Default)] -pub struct Node(u8); - -impl Node { - /// Returns `true` if the node can be extended to the right. - /// - /// A single right neighbour exists. 
- pub fn can_extend_right(self) -> bool { - self.0 & 0b0000_0001 != 0 - } - - /// Returns `true` if the node can be extended to the left. - /// - /// A single left neighbour exists. - pub fn can_extend_left(self) -> bool { - self.0 & 0b0000_0010 != 0 - } - - /// Returns `true` if the node has been visited. - pub fn is_visited(self) -> bool { - self.0 & 0b0000_0100 != 0 - } - - /// Index of the unique right neighbour (0=A, 1=C, 2=G, 3=T). - /// Only meaningful when `can_extend_right()` is true. - pub fn right_nuc(self) -> u8 { - (self.0 >> 3) & 0b11 - } - - /// Index of the unique left neighbour (0=A, 1=C, 2=G, 3=T). - /// Only meaningful when `can_extend_left()` is true. - pub fn left_nuc(self) -> u8 { - (self.0 >> 5) & 0b11 - } - - /// Marks the node as visited. - pub fn set_visited(&mut self) { - if self.is_visited() { - unreachable!("from: is_visited -> The node has already been visited") - } - self.0 |= 0b0000_0100; - } - - /// Number of left neighbours. - pub fn n_left_neighbours(self) -> u8 { - if self.can_extend_left() { - 1 - } else { - let v = (self.0 >> 5) & 0b11; - v + (v != 0) as u8 - } - } - - /// Number of right neighbours. - pub fn n_right_neighbours(self) -> u8 { - if self.can_extend_right() { - 1 - } else { - let v = (self.0 >> 3) & 0b11; - v + (v != 0) as u8 - } - } - - /// `nuc` = Some(i) → exactly one neighbour (bit 0 set, bits 3–4 = nucleotide index). - /// `nuc` = None → 0 or ≥2 neighbours; `count` encoded in bits 3–4 as count.sat_sub(1). - pub fn set_right(&mut self, count: u8, nuc: Option) { - self.0 &= !(0b0000_0001 | 0b001_1000); - if count == 1 { - self.0 |= 0b0000_0001; - if let Some(n) = nuc { - self.0 |= (n & 0b11) << 3; - return; - } - unreachable!("nuc must be Some when count is 1"); - } - self.0 |= (count.saturating_sub(1).min(3)) << 3; - } - - /// `nuc` = Some(i) → exactly one neighbour (bit 0 set, bits 3–4 = nucleotide index). - /// `nuc` = None → 0 or ≥2 neighbours; `count` encoded in bits 3–4 as count.sat_sub(1). 
- pub fn set_left(&mut self, count: u8, nuc: Option) { - self.0 &= !(0b0000_0010 | 0b0110_0000); - if count == 1 { - self.0 |= 0b0000_0010; - if let Some(n) = nuc { - self.0 |= (n & 0b11) << 5; - return; - } - unreachable!("nuc must be Some when count is 1"); - } - self.0 |= (count.saturating_sub(1).min(3)) << 5; - } -} - -impl fmt::Display for Node { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - const NUC: [char; 4] = ['A', 'C', 'G', 'T']; - let r = if self.can_extend_right() { - format!("→{}", NUC[self.right_nuc() as usize]) - } else { - format!("→{}", self.n_right_neighbours()) - }; - let l = if self.can_extend_left() { - format!("←{}", NUC[self.left_nuc() as usize]) - } else { - format!("←{}", self.n_left_neighbours()) - }; - let v = if self.is_visited() { "V" } else { "." }; - write!(f, "Node({r} {l} {v})") - } -} - -// ── GraphDeBruijn ───────────────────────────────────────────────────────────── - -pub struct GraphDeBruijn { - nodes: FastHashMap>, - k: usize, -} - -impl GraphDeBruijn { - pub fn new(k: usize) -> Self { - Self { - nodes: FastHashMap::with_hasher(RandomState::new()), - k, - } - } - - pub fn with_capacity(k: usize, capacity: usize) -> Self { - Self { - nodes: FastHashMap::with_capacity_and_hasher(capacity, RandomState::new()), - k, - } - } - - /// Insert a canonical kmer into the graph. No-op if already present. - pub fn push(&mut self, kmer: CanonicalKmer) { - self.nodes - .entry(kmer) - .or_insert_with(|| Cell::new(Node::default())); - } - - /// For every node, find its unique right/left canonical neighbour (if any) - /// and store the nucleotide index in the Node flags. - /// - /// Single pass thanks to Cell interior mutability. 
- pub fn compute_degrees(&self) { - for (&kmer, cell) in &self.nodes { - let (rc, rn) = count_neighbors(kmer.right_canonical_neighbors(self.k), &self.nodes); - let (lc, ln) = count_neighbors(kmer.left_canonical_neighbors(self.k), &self.nodes); - - let mut node = cell.get(); - node.set_right(rc, rn); - node.set_left(lc, ln); - cell.set(node); - } - } - - /// Iterates over the right neighbors of `kmer`. - pub fn iter_right_neighbors( - &self, - kmer: CanonicalKmer, - ) -> impl Iterator + '_ { - kmer.right_canonical_neighbors(self.k) - .into_iter() - .filter_map(|kmer| { - self.nodes.get(&kmer)?; - Some(kmer) - }) - } - - /// Iterates over the left neighbors of `kmer`. - pub fn iter_left_neighbors( - &self, - kmer: CanonicalKmer, - ) -> impl Iterator + '_ { - kmer.left_canonical_neighbors(self.k) - .into_iter() - .filter_map(|kmer| { - self.nodes.get(&kmer)?; - Some(kmer) - }) - } - - pub fn is_visited(&self, kmer: &CanonicalKmer) -> Option { - self.nodes.get(kmer).map(|cell| cell.get().is_visited()) - } - - pub fn set_visited(&self, kmer: CanonicalKmer) { - if let Some(cell) = self.nodes.get(&kmer) { - let mut node = cell.get(); - node.set_visited(); - cell.set(node); - } - } - - /// Returns the single right neighbor of `kmer`, if it exists. - pub fn the_single_right_neighbor(&self, kmer: CanonicalKmer) -> Option { - let node = self.nodes.get(&kmer)?.get(); - if !node.can_extend_right() { - return None; - } - let next = kmer - .into_kmer() - .push_right(node.right_nuc(), self.k) - .canonical(self.k); - self.nodes.contains_key(&next).then_some(next) - } - - /// Returns the single left neighbor of `kmer`, if it exists. 
- pub fn the_single_left_neighbor(&self, kmer: CanonicalKmer) -> Option { - let node = self.nodes.get(&kmer)?.get(); - if !node.can_extend_left() { - return None; - } - let next = kmer - .into_kmer() - .push_left(node.left_nuc(), self.k) - .canonical(self.k); - self.nodes.contains_key(&next).then_some(next) - } - - /// Internal iterator over unitig-start nodes; drives `iter_unitig`. - /// - /// MUST NOT be consumed standalone: the second pass finds cycle nodes only - /// because `iter_unitig` lazily interleaves chain traversal between the two passes. - /// - /// Two passes: - /// 1. Chain ends / isolated nodes (at most one extension missing): - /// - `!can_extend_left` → yield canonical form - /// - `!can_extend_right` → yield reverse complement - /// 2. Nodes still unvisited → part of a cycle; yield canonical form. - fn start_iter(&self) -> impl Iterator)> + '_ { - StartIter::new(self) - } - - fn next_unitig_kmer(&self, kmer: Kmer) -> Option { - let canonical = kmer.canonical(self.k); - let node = self.nodes.get(&canonical)?.get(); - - let direct = kmer.raw() == canonical.raw(); - - if (direct && !node.can_extend_right()) || (!direct && !node.can_extend_left()) { - return None; - } - - let next_c: CanonicalKmer = if direct { - canonical - .into_kmer() - .push_right(node.right_nuc(), self.k) - .canonical(self.k) - } else { - canonical - .into_kmer() - .push_left(node.left_nuc(), self.k) - .canonical(self.k) - }; - - let cell = self.nodes.get(&next_c)?; - let next_node = cell.get(); - if next_node.is_visited() { - return None; - } - - let oriented = oriented_next(kmer, next_c, self.k); - let ndirect = oriented.raw() == next_c.raw(); - - if (ndirect && next_node.n_right_neighbours() > 1) - || (!ndirect && next_node.n_left_neighbours() > 1) - { - return None; - } - - let mut updated = next_node; - updated.set_visited(); - cell.set(updated); - Some(oriented) - } - - fn next_longtig_kmer(&self, kmer: Kmer) -> Option { - let k = self.k; - let canonical = 
kmer.canonical(k); - let node = self.nodes.get(&canonical)?.get(); - - let direct = kmer.raw() == canonical.raw(); - - if (direct && node.n_right_neighbours() == 0) || (!direct && node.n_left_neighbours() == 0) - { - return None; - } - - let next_c: CanonicalKmer = if direct { - if node.can_extend_right() { - canonical - .into_kmer() - .push_right(node.right_nuc(), k) - .canonical(k) - } else { - self.iter_right_neighbors(canonical) - .filter(|n| !self.is_visited(n).unwrap_or(true)) - .next()? - } - } else { - if node.can_extend_left() { - canonical - .into_kmer() - .push_left(node.left_nuc(), k) - .canonical(k) - } else { - self.iter_left_neighbors(canonical) - .filter(|n| !self.is_visited(n).unwrap_or(true)) - .next()? - } - }; - - let cell = self.nodes.get(&next_c)?; - let next_node = cell.get(); - if next_node.is_visited() { - return None; - } - - let oriented = oriented_next(kmer, next_c, self.k); - let ndirect = oriented.raw() == next_c.raw(); - - if (ndirect && next_node.n_right_neighbours() > 1) - || (!ndirect && next_node.n_left_neighbours() > 1) - { - return None; - } - - let mut updated = next_node; - updated.set_visited(); - cell.set(updated); - Some(oriented) - } - - fn iter_unitig_kmers(&self, start: Kmer) -> UnitigIter<'_> { - UnitigIter { - graph: self, - current: Some(start), - } - } - - fn iter_longtig_kmers(&self, start: Kmer) -> LongtigIter<'_> { - LongtigIter { - graph: self, - current: Some(start), - } - } - - pub fn iter_unitig(&self) -> impl Iterator + '_ { - let k = self.k; - self.start_iter().map(move |(start, first_next)| { - let mut nucs: Vec = (0..k).map(|i| start.nucleotide(i)).collect(); - if let Some(next_c) = first_next { - for kmer in self.iter_unitig_kmers(next_c) { - nucs.push(kmer.nucleotide(k - 1)); - } - } - Unitig::from_nucleotides(&nucs) - }) - } - - pub fn iter_longtig(&self) -> impl Iterator + '_ { - let k = self.k; - self.start_iter().map(move |(start, first_next)| { - let mut nucs: Vec = (0..k).map(|i| 
start.nucleotide(i)).collect(); - if let Some(next_c) = first_next { - for kmer in self.iter_longtig_kmers(next_c) { - nucs.push(kmer.nucleotide(k - 1)); - } - } - Unitig::from_nucleotides(&nucs) - }) - } - - /// Write all unitigs to `out` in FASTA format. - /// - /// Calls [`obifastwrite::write_unitig`] for each unitig produced by - /// [`iter_unitig`]. Stops and returns the first I/O error encountered. - pub fn write_fasta(&self, out: &mut W, unitig: bool) -> io::Result<()> { - if unitig { - for unitig in self.iter_unitig() { - write_unitig(&unitig, self.k, out)?; - } - } else { - for unitig in self.iter_longtig() { - write_unitig(&unitig, self.k, out)?; - } - } - Ok(()) - } - - pub fn len(&self) -> usize { - self.nodes.len() - } - - pub fn is_empty(&self) -> bool { - self.nodes.is_empty() - } -} - -// --- StartIter ----------------------------------------------------------------- -struct StartIter<'a> { - graph: &'a GraphDeBruijn, - nodes: hashbrown::hash_map::Iter<'a, CanonicalKmer, Cell>, - suspended: Vec, - in_cycle_pass: bool, -} - -impl<'a> StartIter<'a> { - fn new(graph: &'a GraphDeBruijn) -> Self { - Self { - graph, - nodes: graph.nodes.iter(), - suspended: Vec::new(), - in_cycle_pass: false, - } - } -} - -impl<'a> Iterator for StartIter<'a> { - type Item = (CanonicalKmer, Option); - - fn next(&mut self) -> Option<(CanonicalKmer, Option)> { - loop { - let current = if let Some(k) = self.suspended.pop() { - k - } else { - match self.nodes.next() { - Some((&k, _)) => k, - None => { - if self.in_cycle_pass { - return None; - } - self.in_cycle_pass = true; - self.nodes = self.graph.nodes.iter(); - match self.nodes.next() { - Some((&k, _)) => k, - None => return None, - } - } - } - }; - - let node = match self.graph.nodes.get(¤t) { - Some(c) => c.get(), - None => continue, - }; - if node.is_visited() { - continue; - } - if !self.in_cycle_pass && node.can_extend_left() { - continue; - } - - self.graph.set_visited(current); - - if let Some(next) = 
self.graph.the_single_right_neighbor(current) { - if self.graph.is_visited(&next).unwrap_or(true) { - return Some((current, None)); - } - self.graph.set_visited(next); - let oriented = oriented_next(current.into_kmer(), next, self.graph.k); - return Some((current, Some(oriented))); - } - - let mut first_neighbor: Option = None; - for neighbor in self.graph.iter_right_neighbors(current) { - if self.graph.is_visited(&neighbor).unwrap_or(true) { - continue; - } - if first_neighbor.is_none() { - self.graph.set_visited(neighbor); - first_neighbor = Some(neighbor); - } else { - self.suspended.push(neighbor); - } - } - - let oriented = match first_neighbor { - Some(neighbor) => Some(oriented_next(current.into_kmer(), neighbor, self.graph.k)), - None => None, - }; - return Some((current, oriented)); - } - } -} - -// ── UnitigIter ──────────────────────────────────────────────────────────────── - -struct UnitigIter<'a> { - graph: &'a GraphDeBruijn, - current: Option, -} - -impl Iterator for UnitigIter<'_> { - type Item = Kmer; - - fn next(&mut self) -> Option { - let current = self.current?; - self.current = self.graph.next_unitig_kmer(current); - Some(current) - } -} - -// ── UnitigIter ──────────────────────────────────────────────────────────────── - -struct LongtigIter<'a> { - graph: &'a GraphDeBruijn, - current: Option, -} - -impl Iterator for LongtigIter<'_> { - type Item = Kmer; - - fn next(&mut self) -> Option { - let current = self.current?; - self.current = self.graph.next_longtig_kmer(current); - Some(current) - } -} - -// ── helpers ─────────────────────────────────────────────────────────────────── - -fn oriented_next(from: Kmer, to: CanonicalKmer, k: usize) -> Kmer { - if from.is_overlapping(to.into_kmer(), k) { - to.into_kmer() - } else { - to.revcomp(k) - } -} - -/// Returns `Some(i)` if exactly one of the four canonical neighbours exists in -/// the graph, where `i` is its index (0=A, 1=C, 2=G, 3=T). Returns `None` for -/// zero or ≥2 existing neighbours. 
-fn count_neighbors( - neighbors: [CanonicalKmer; 4], - nodes: &FastHashMap>, -) -> (u8, Option) { - let mut count = 0u8; - let mut first = None; - for (i, neighbour) in neighbors.iter().enumerate() { - if nodes.contains_key(neighbour) { - count += 1; - if first.is_none() { - first = Some(i as u8); - } - } - } - if count == 1 { - (1, first) - } else { - (count, None) - } -} - -// ── tests ───────────────────────────────────────────────────────────────────── - -#[cfg(test)] -mod tests { - use super::*; - - // Build a graph from an ASCII sequence, inserting all canonical k-mers. - fn graph_from_ascii(seq: &[u8], k: usize) -> GraphDeBruijn { - let mut g = GraphDeBruijn::new(k); - for i in 0..=seq.len().saturating_sub(k) { - g.push(Kmer::from_ascii(&seq[i..i + k], k).unwrap().canonical(k)); - } - g - } - - // Collect all canonical k-mers from an ASCII sequence into a sorted vec. - fn canonical_kmers(seq: &[u8], k: usize) -> Vec { - let mut v: Vec = (0..=seq.len().saturating_sub(k)) - .map(|i| Kmer::from_ascii(&seq[i..i + k], k).unwrap().canonical(k)) - .collect(); - v.sort_unstable(); - v.dedup(); - v - } - - // ── push / canonicalisation ─────────────────────────────────────────────── - - #[test] - fn push_deduplicates_revcomp() { - let k = 5; - let kmer = Kmer::from_ascii(b"ACGTA", k).unwrap(); - let mut g = GraphDeBruijn::new(k); - g.push(kmer.canonical(k)); - g.push(kmer.revcomp(k).canonical(k)); - assert_eq!(g.len(), 1, "kmer and its revcomp must map to the same node"); - } - - #[test] - fn push_palindrome_single_node() { - // ACGT is its own revcomp - let k = 4; - let kmer = Kmer::from_ascii(b"ACGT", k).unwrap(); - assert_eq!(kmer, kmer.revcomp(k), "test requires a palindrome"); - let mut g = GraphDeBruijn::new(k); - g.push(kmer.canonical(k)); - assert_eq!(g.len(), 1); - } - - // ── compute_degrees on a linear chain ──────────────────────────────────── - - // AAAAGGGG with k=5 → 4 distinct k-mers (AAAAG, AAAGG, AAGGG, AGGGG), - // clean linear chain, no 
Watson-Crick palindrome in first k-1 bases. - fn linear_chain_graph(k: usize) -> (GraphDeBruijn, Vec) { - let seq = b"AAAAGGGG"; - let g = graph_from_ascii(seq, k); - let kmers = canonical_kmers(seq, k); - (g, kmers) - } - - #[test] - fn degrees_linear_chain_node_count() { - let k = 5; - let (g, kmers) = linear_chain_graph(k); - assert_eq!(g.len(), kmers.len()); - } - - #[test] - fn degrees_linear_chain_extensions() { - // A linear chain yields exactly 1 unitig covering all k-mers. - // Note: start_iter must not be consumed standalone — its second pass only - // finds true cycle nodes when interleaved with chain traversal (iter_unitig). - let k = 5; - let seq = b"AAAAGGGG"; - let g = graph_from_ascii(seq, k); - g.compute_degrees(); - let unitigs: Vec = g.iter_unitig().collect(); - assert_eq!(unitigs.len(), 1, "linear chain → exactly one unitig"); - // seql = k + (n_kmers - 1) = 5 + 3 = 8 = seq.len() - assert_eq!( - unitigs[0].seql(), - seq.len(), - "unitig spans the full sequence" - ); - assert_eq!( - kmers_from_unitigs(&unitigs, k), - canonical_kmers(seq, k), - "unitig k-mers must equal inserted k-mers" - ); - } - - // ── unitig reconstruction ───────────────────────────────────────────────── - - // Round-trip: all canonical k-mers in the unitigs == all canonical k-mers inserted. - fn kmers_from_unitigs(unitigs: &[Unitig], k: usize) -> Vec { - let mut v: Vec = unitigs - .iter() - .flat_map(|u| u.iter_canonical_kmers(k)) - .collect(); - v.sort_unstable(); - v.dedup(); - v - } - - #[test] - fn unitig_roundtrip_linear() { - // Non-repetitive sequence: no k-mer appears twice, no homopolymer run of length k. - // ACGTGGCTA with k=5 → 5 distinct k-mers forming a clean linear chain. 
- let k = 5; - let seq = b"ACCTGGCTA"; - let g = graph_from_ascii(seq, k); - g.compute_degrees(); - println!("Les kmers:"); - for (kmer, v) in g.nodes.iter() { - println!( - "{}: {}", - String::from_utf8_lossy(&kmer.to_ascii(k)), - v.get() - ); - } - // println!("Les starts:"); - // for (start, first_next) in g.start_iter() { - // if let Some(next) = first_next { - // println!( - // "{}->{}", - // String::from_utf8_lossy(&start.to_ascii(k)), - // String::from_utf8_lossy(&next.to_ascii(k)) - // ) - // } else { - // println!("{}->None", String::from_utf8_lossy(&start.to_ascii(k))) - // } - // } - - println!("Les unitig:"); - let unitigs: Vec = g.iter_unitig().collect(); - for unitig in &unitigs { - println!("{}", String::from_utf8_lossy(&unitig.to_ascii())); - } - assert_eq!( - unitigs.len(), - 1, - "linear chain → exactly one unitig {:?}", - unitigs - ); - assert_eq!( - kmers_from_unitigs(&unitigs, k), - canonical_kmers(seq, k), - "unitig must contain exactly the inserted k-mers" - ); - } - - #[test] - fn unitig_roundtrip_longer_sequence() { - // Longer non-repetitive sequence with no repeated k-mer of length k. - // ACGTGGCTATCGAC with k=5 → 10 distinct k-mers, one linear chain. 
- let k = 5; - let seq = b"ACGTGGCTATCGAC"; - let g = graph_from_ascii(seq, k); - g.compute_degrees(); - let unitigs: Vec = g.iter_unitig().collect(); - let mut got = kmers_from_unitigs(&unitigs, k); - let mut expected = canonical_kmers(seq, k); - got.sort_unstable(); - expected.sort_unstable(); - assert_eq!(got, expected); - } - - #[test] - fn unitig_isolated_node() { - // Single k-mer with no neighbours - let k = 5; - let kmer = Kmer::from_ascii(b"ACGTA", k).unwrap(); - let mut g = GraphDeBruijn::new(k); - g.push(kmer.canonical(k)); - g.compute_degrees(); - let unitigs: Vec = g.iter_unitig().collect(); - assert_eq!(unitigs.len(), 1); - assert_eq!(unitigs[0].seql(), k); - } - - #[test] - fn unitig_two_isolated_nodes() { - let k = 5; - let mut g = GraphDeBruijn::new(k); - // Two k-mers that share no (k-1)-overlap - g.push(Kmer::from_ascii(b"AAAAA", k).unwrap().canonical(k)); - g.push(Kmer::from_ascii(b"TTTTT", k).unwrap().canonical(k)); // same canonical as AAAAA — dedup - // They collapse to one canonical node - assert_eq!(g.len(), 1); - } - - #[test] - fn unitig_two_truly_distinct_isolated_nodes() { - let k = 5; - let mut g = GraphDeBruijn::new(k); - g.push(Kmer::from_ascii(b"AAAAC", k).unwrap().canonical(k)); - g.push(Kmer::from_ascii(b"GGGGT", k).unwrap().canonical(k)); - g.compute_degrees(); - let unitigs: Vec = g.iter_unitig().collect(); - // Each isolated node → one unitig of length k - assert_eq!(unitigs.len(), 2); - assert!(unitigs.iter().all(|u| u.seql() == k)); - } - - // ── all k-mers covered, none duplicated ─────────────────────────────────── - - #[test] - fn no_kmer_lost_or_duplicated() { - let k = 7; - let seq = b"ACGTACGTACGTTTTTACGTACGT"; - let g = graph_from_ascii(seq, k); - g.compute_degrees(); - let unitigs: Vec = g.iter_unitig().collect(); - let got = kmers_from_unitigs(&unitigs, k); - let expected = canonical_kmers(seq, k); - assert_eq!( - got.len(), - expected.len(), - "kmer count mismatch: got {}, expected {}", - got.len(), - expected.len() 
- ); - assert_eq!(got, expected, "kmer sets differ"); - } - - // ── cycle coverage ──────────────────────────────────────────────────────── - - #[test] - fn cycle_kmers_not_lost() { - // ACGTACGT with k=5 forms a pure cycle: ACGTA→CGTAC→GTACG→TACGT→ACGTA. - // start_iter first pass yields nothing (all nodes internal); second pass - // picks up cycle entries. All 4 k-mers must appear in the unitigs. - let k = 5; - let seq = b"ACGTACGT"; - let g = graph_from_ascii(seq, k); - g.compute_degrees(); - let unitigs: Vec = g.iter_unitig().collect(); - let got = kmers_from_unitigs(&unitigs, k); - let expected = canonical_kmers(seq, k); - assert_eq!(got.len(), expected.len(), "cycle k-mers lost"); - assert_eq!(got, expected); - } - - // ── branching graph ─────────────────────────────────────────────────────── - // - // Topology (k=5): two sources A,B converge at C; chain C-D-E-F; - // F branches to G and H; H continues H-M-N; second source J feeds I-F. - // Every k-mer must appear in exactly one unitig (no duplication, no loss). - #[test] - fn branching_graph_no_kmer_lost_or_duplicated() { - // Build sequences that realise the topology without accidental overlaps. - // Each "node" is a distinct 5-mer; edges share a 4-mer suffix/prefix. - // We use long non-repetitive sequences and extract only the required kmers. - let k: usize = 5; - let mut g = GraphDeBruijn::new(k); - - // Helper: insert all k-mers of a sequence. 
- let mut insert = |seq: &[u8]| { - for i in 0..=seq.len().saturating_sub(k) { - g.push(Kmer::from_ascii(&seq[i..i + k], k).unwrap().canonical(k)); - } - }; - - // Chains that realise the topology: - // A-C (A→C share 4-mer overlap) - // B-C (B→C share 4-mer overlap, different prefix) - // C-D-E-F - // F-G (F→G) - // F-H-M-N (F→H→M→N) - // J-I-F (J→I→F) - insert(b"AACGTGGCTA"); // A-C-D … part of the right branch - insert(b"TACGTGGCTA"); // B-C-D … merges at C (same C-suffix) - insert(b"CGTGGCTACG"); // continues D-E-F-G - insert(b"CGTGGCTACC"); // F-H branch (different last base) - insert(b"GTGGCTACCGT"); // H-M-N continuation - insert(b"TTCGTGGCTA"); // J-I-F (different J prefix) - - g.compute_degrees(); - let unitigs: Vec = g.iter_unitig().collect(); - - // Collect all k-mers from unitigs. - let got = kmers_from_unitigs(&unitigs, k); - - // Collect all distinct canonical k-mers inserted. - let mut expected: Vec = Vec::new(); - for seq in &[ - b"AACGTGGCTA".as_slice(), - b"TACGTGGCTA", - b"CGTGGCTACG", - b"CGTGGCTACC", - b"GTGGCTACCGT", - b"TTCGTGGCTA", - ] { - expected.extend(canonical_kmers(seq, k)); - } - expected.sort_unstable(); - expected.dedup(); - - assert_eq!( - got.len(), - expected.len(), - "k-mer count mismatch: got {}, expected {}", - got.len(), - expected.len() - ); - assert_eq!(got, expected, "k-mer sets differ"); - } -} +pub use debruijn::GraphDeBruijn; diff --git a/src/obidebruinj/src/tests/debruijn.rs b/src/obidebruinj/src/tests/debruijn.rs new file mode 100644 index 0000000..1ba2459 --- /dev/null +++ b/src/obidebruinj/src/tests/debruijn.rs @@ -0,0 +1,301 @@ +use super::*; +use obikseq::{k, set_k}; + +// Build a graph from an ASCII sequence, inserting all canonical k-mers. 
+fn graph_from_ascii(seq: &[u8]) -> GraphDeBruijn { + let mut g = GraphDeBruijn::new(); + let k = k(); + for i in 0..=seq.len().saturating_sub(k) { + g.push(Kmer::from_ascii(&seq[i..i + k]).unwrap().canonical()); + } + g +} + +// Collect all canonical k-mers from an ASCII sequence into a sorted vec. +fn canonical_kmers(seq: &[u8]) -> Vec { + let k = k(); + let mut v: Vec = (0..=seq.len().saturating_sub(k)) + .map(|i| Kmer::from_ascii(&seq[i..i + k]).unwrap().canonical()) + .collect(); + v.sort_unstable(); + v.dedup(); + v +} + +// ── push / canonicalisation ─────────────────────────────────────────────── + +#[test] +fn push_deduplicates_revcomp() { + let k = 5; + set_k(k); + let kmer = Kmer::from_ascii(b"ACGTA").unwrap(); + let mut g = GraphDeBruijn::new(); + g.push(kmer.canonical()); + g.push(kmer.revcomp().canonical()); + assert_eq!(g.len(), 1, "kmer and its revcomp must map to the same node"); +} + +#[test] +fn push_palindrome_single_node() { + // ACGT is its own revcomp + let k = 4; + set_k(k); + let kmer = Kmer::from_ascii(b"ACGT").unwrap(); + assert_eq!(kmer, kmer.revcomp(), "test requires a palindrome"); + let mut g = GraphDeBruijn::new(); + g.push(kmer.canonical()); + assert_eq!(g.len(), 1); +} + +// ── compute_degrees on a linear chain ──────────────────────────────────── + +// AAAAGGGG with k=5 → 4 distinct k-mers (AAAAG, AAAGG, AAGGG, AGGGG), +// clean linear chain, no Watson-Crick palindrome in first k-1 bases. +fn linear_chain_graph() -> (GraphDeBruijn, Vec) { + let seq = b"AAAAGGGG"; + let g = graph_from_ascii(seq); + let kmers = canonical_kmers(seq); + (g, kmers) +} + +#[test] +fn degrees_linear_chain_node_count() { + let k = 5; + set_k(k); + let (g, kmers) = linear_chain_graph(); + assert_eq!(g.len(), kmers.len()); +} + +#[test] +fn degrees_linear_chain_extensions() { + // A linear chain yields exactly 1 unitig covering all k-mers. 
+ // Note: start_iter must not be consumed standalone — its second pass only + // finds true cycle nodes when interleaved with chain traversal (iter_unitig). + let k = 5; + set_k(k); + let seq = b"AAAAGGGG"; + let g = graph_from_ascii(seq); + g.compute_degrees(); + let unitigs: Vec = g.iter_unitig().collect(); + assert_eq!(unitigs.len(), 1, "linear chain → exactly one unitig"); + // seql = k + (n_kmers - 1) = 5 + 3 = 8 = seq.len() + assert_eq!( + unitigs[0].seql(), + seq.len(), + "unitig spans the full sequence" + ); + assert_eq!( + kmers_from_unitigs(&unitigs), + canonical_kmers(seq), + "unitig k-mers must equal inserted k-mers" + ); +} + +// ── unitig reconstruction ───────────────────────────────────────────────── + +// Round-trip: all canonical k-mers in the unitigs == all canonical k-mers inserted. +fn kmers_from_unitigs(unitigs: &[Unitig]) -> Vec { + let mut v: Vec = unitigs + .iter() + .flat_map(|u| u.iter_canonical_kmers()) + .collect(); + v.sort_unstable(); + v.dedup(); + v +} + +#[test] +fn unitig_roundtrip_linear() { + // Non-repetitive sequence: no k-mer appears twice, no homopolymer run of length k. + // ACGTGGCTA with k=5 → 5 distinct k-mers forming a clean linear chain. + let k = 5; + set_k(k); + let seq = b"ACCTGGCTA"; + let g = graph_from_ascii(seq); + g.compute_degrees(); + println!("Les kmers:"); + for (kmer, v) in g.nodes.iter() { + println!("{}: {}", String::from_utf8_lossy(&kmer.to_ascii()), v.get()); + } + + println!("Les unitig:"); + let unitigs: Vec = g.iter_unitig().collect(); + for unitig in &unitigs { + println!("{}", String::from_utf8_lossy(&unitig.to_ascii())); + } + assert_eq!( + unitigs.len(), + 1, + "linear chain → exactly one unitig {:?}", + unitigs + ); + assert_eq!( + kmers_from_unitigs(&unitigs), + canonical_kmers(seq), + "unitig must contain exactly the inserted k-mers" + ); +} + +#[test] +fn unitig_roundtrip_longer_sequence() { + // Longer non-repetitive sequence with no repeated k-mer of length k. 
+ // ACGTGGCTATCGAC with k=5 → 10 distinct k-mers, one linear chain. + let k = 5; + set_k(k); + let seq = b"ACGTGGCTATCGAC"; + let g = graph_from_ascii(seq); + g.compute_degrees(); + let unitigs: Vec = g.iter_unitig().collect(); + let mut got = kmers_from_unitigs(&unitigs); + let mut expected = canonical_kmers(seq); + got.sort_unstable(); + expected.sort_unstable(); + assert_eq!(got, expected); +} + +#[test] +fn unitig_isolated_node() { + // Single k-mer with no neighbours + let k = 5; + set_k(k); + let kmer = Kmer::from_ascii(b"ACGTA").unwrap(); + let mut g = GraphDeBruijn::new(); + g.push(kmer.canonical()); + g.compute_degrees(); + let unitigs: Vec = g.iter_unitig().collect(); + assert_eq!(unitigs.len(), 1); + assert_eq!(unitigs[0].seql(), k); +} + +#[test] +fn unitig_two_isolated_nodes() { + let k = 5; + set_k(k); + let mut g = GraphDeBruijn::new(); + // Two k-mers that share no (k-1)-overlap + g.push(Kmer::from_ascii(b"AAAAA").unwrap().canonical()); + g.push(Kmer::from_ascii(b"TTTTT").unwrap().canonical()); // same canonical as AAAAA — dedup + // They collapse to one canonical node + assert_eq!(g.len(), 1); +} + +#[test] +fn unitig_two_truly_distinct_isolated_nodes() { + let k = 5; + set_k(k); + let mut g = GraphDeBruijn::new(); + g.push(Kmer::from_ascii(b"AAAAC").unwrap().canonical()); + g.push(Kmer::from_ascii(b"GGGGT").unwrap().canonical()); + g.compute_degrees(); + let unitigs: Vec = g.iter_unitig().collect(); + // Each isolated node → one unitig of length k + assert_eq!(unitigs.len(), 2); + assert!(unitigs.iter().all(|u| u.seql() == k)); +} + +// ── all k-mers covered, none duplicated ─────────────────────────────────── + +#[test] +fn no_kmer_lost_or_duplicated() { + let k = 7; + set_k(k); + let seq = b"ACGTACGTACGTTTTTACGTACGT"; + let g = graph_from_ascii(seq); + g.compute_degrees(); + let unitigs: Vec = g.iter_unitig().collect(); + let got = kmers_from_unitigs(&unitigs); + let expected = canonical_kmers(seq); + assert_eq!( + got.len(), + expected.len(), 
+ "kmer count mismatch: got {}, expected {}", + got.len(), + expected.len() + ); + assert_eq!(got, expected, "kmer sets differ"); +} + +// ── cycle coverage ──────────────────────────────────────────────────────── + +#[test] +fn cycle_kmers_not_lost() { + // ACGTACGT with k=5 forms a pure cycle: ACGTA→CGTAC→GTACG→TACGT→ACGTA. + // start_iter first pass yields nothing (all nodes internal); second pass + // picks up cycle entries. All 4 k-mers must appear in the unitigs. + let k = 5; + set_k(k); + let seq = b"ACGTACGT"; + let g = graph_from_ascii(seq); + g.compute_degrees(); + let unitigs: Vec = g.iter_unitig().collect(); + let got = kmers_from_unitigs(&unitigs); + let expected = canonical_kmers(seq); + assert_eq!(got.len(), expected.len(), "cycle k-mers lost"); + assert_eq!(got, expected); +} + +// ── branching graph ─────────────────────────────────────────────────────── +// +// Topology (k=5): two sources A,B converge at C; chain C-D-E-F; +// F branches to G and H; H continues H-M-N; second source J feeds I-F. +// Every k-mer must appear in exactly one unitig (no duplication, no loss). +#[test] +fn branching_graph_no_kmer_lost_or_duplicated() { + // Build sequences that realise the topology without accidental overlaps. + // Each "node" is a distinct 5-mer; edges share a 4-mer suffix/prefix. + // We use long non-repetitive sequences and extract only the required kmers. + let k: usize = 5; + set_k(k); + let mut g = GraphDeBruijn::new(); + + // Helper: insert all k-mers of a sequence. 
+ let mut insert = |seq: &[u8]| { + for i in 0..=seq.len().saturating_sub(k) { + g.push(Kmer::from_ascii(&seq[i..i + k]).unwrap().canonical()); + } + }; + + // Chains that realise the topology: + // A-C (A→C share 4-mer overlap) + // B-C (B→C share 4-mer overlap, different prefix) + // C-D-E-F + // F-G (F→G) + // F-H-M-N (F→H→M→N) + // J-I-F (J→I→F) + insert(b"AACGTGGCTA"); // A-C-D … part of the right branch + insert(b"TACGTGGCTA"); // B-C-D … merges at C (same C-suffix) + insert(b"CGTGGCTACG"); // continues D-E-F-G + insert(b"CGTGGCTACC"); // F-H branch (different last base) + insert(b"GTGGCTACCGT"); // H-M-N continuation + insert(b"TTCGTGGCTA"); // J-I-F (different J prefix) + + g.compute_degrees(); + let unitigs: Vec = g.iter_unitig().collect(); + + // Collect all k-mers from unitigs. + let got = kmers_from_unitigs(&unitigs); + + // Collect all distinct canonical k-mers inserted. + let mut expected: Vec = Vec::new(); + for seq in &[ + b"AACGTGGCTA".as_slice(), + b"TACGTGGCTA", + b"CGTGGCTACG", + b"CGTGGCTACC", + b"GTGGCTACCGT", + b"TTCGTGGCTA", + ] { + expected.extend(canonical_kmers(seq)); + } + expected.sort_unstable(); + expected.dedup(); + + assert_eq!( + got.len(), + expected.len(), + "k-mer count mismatch: got {}, expected {}", + got.len(), + expected.len() + ); + assert_eq!(got, expected, "k-mer sets differ"); +} diff --git a/src/obifastwrite/Cargo.toml b/src/obifastwrite/Cargo.toml index 1c49380..2615305 100644 --- a/src/obifastwrite/Cargo.toml +++ b/src/obifastwrite/Cargo.toml @@ -6,3 +6,6 @@ edition = "2024" [dependencies] obikseq = { path = "../obikseq" } xxhash-rust = { version = "0.8", features = ["xxh64"] } + +[dev-dependencies] +obikseq = { path = "../obikseq", features = ["test-utils"] } diff --git a/src/obifastwrite/src/lib.rs b/src/obifastwrite/src/lib.rs index 6855a2c..0b1f3d9 100644 --- a/src/obifastwrite/src/lib.rs +++ b/src/obifastwrite/src/lib.rs @@ -34,7 +34,7 @@ mod fasta; use std::io::{self, Write}; -use obikseq::{kmer::CanonicalKmer, 
superkmer::SuperKmer, unitig::Unitig}; +use obikseq::{Minimizer, SuperKmer, Unitig}; use xxhash_rust::xxh64::xxh64; // ── public API ──────────────────────────────────────────────────────────────── @@ -57,12 +57,12 @@ pub fn write_scatter( k: usize, m: usize, partition: usize, - minimizer: CanonicalKmer, + minimizer: Minimizer, ) -> io::Result<()> { let ascii = sk.to_ascii(); let id = seq_id(&ascii); let seq_len = ascii.len(); - let min_seq = minimizer.to_ascii(m); + let min_seq = minimizer.to_ascii(); writeln!( out, @@ -154,7 +154,6 @@ fn seq_id(ascii: &[u8]) -> String { #[cfg(test)] mod tests { use super::*; - use obikseq::kmer::Kmer; use obikseq::superkmer::SuperKmer; fn make(seq: &[u8]) -> SuperKmer { @@ -172,23 +171,27 @@ mod tests { #[test] fn scatter_header_contains_minimizer_field() { let sk = make(b"ACGTACGTACGT"); - let out = capture(|w| write_scatter(&sk, w, 4, 3, 7, CanonicalKmer::from_raw_unchecked(0))); + let out = capture(|w| write_scatter(&sk, w, 4, 3, 7, Minimizer::from_raw_unchecked(0))); assert!(out.contains("\"minimizer\":\"")); assert!(!out.contains("\"count\":")); } #[test] fn scatter_minimizer_decoded_from_hash() { - // min_hash for "ACG" (A=0,C=1,G=2, m=3): 0*16 + 1*4 + 2 = 6 + // "ACG" right-aligned: A=00, C=01, G=10 → 0b000110 = 6 + // Left-aligned for m=3: shift by 64 − 2·3 = 58. + // set_m(3) so that Minimizer::to_ascii() decodes exactly 3 bases. 
+ obikseq::params::set_m(3); let sk = make(b"ACGTACGTACGT"); - let out = capture(|w| write_scatter(&sk, w, 4, 3, 0, CanonicalKmer::from_raw_unchecked(Kmer::from_raw_right(6, 3).raw()))); + let minimizer = Minimizer::from_raw_unchecked(6u64 << (64 - 2 * 3)); + let out = capture(|w| write_scatter(&sk, w, 4, 3, 0, minimizer)); assert!(out.contains("\"minimizer\":\"ACG\""), "got: {out}"); } #[test] fn scatter_fields_present() { let sk = make(b"ACGTACGTACGT"); - let out = capture(|w| write_scatter(&sk, w, 4, 3, 5, CanonicalKmer::from_raw_unchecked(0))); + let out = capture(|w| write_scatter(&sk, w, 4, 3, 5, Minimizer::from_raw_unchecked(0))); assert!(out.contains("\"seq_length\":12")); assert!(out.contains("\"kmer_size\":4")); assert!(out.contains("\"minimizer_size\":3")); @@ -198,7 +201,7 @@ mod tests { #[test] fn scatter_sequence_line_correct() { let sk = make(b"ACGTACGT"); - let out = capture(|w| write_scatter(&sk, w, 4, 2, 0, CanonicalKmer::from_raw_unchecked(0))); + let out = capture(|w| write_scatter(&sk, w, 4, 2, 0, Minimizer::from_raw_unchecked(0))); let lines: Vec<&str> = out.lines().collect(); assert_eq!(lines[1], "ACGTACGT"); } @@ -241,7 +244,7 @@ mod tests { let sk1 = make(b"ACGTACGT"); let sk2 = make(b"ACGTACGT"); - let id1 = capture(|w| write_scatter(&sk1, w, 4, 2, 0, CanonicalKmer::from_raw_unchecked(0))) + let id1 = capture(|w| write_scatter(&sk1, w, 4, 2, 0, Minimizer::from_raw_unchecked(0))) .lines() .next() .unwrap() @@ -249,7 +252,7 @@ mod tests { .next() .unwrap()[1..] 
.to_string(); - let id2 = capture(|w| write_scatter(&sk2, w, 4, 2, 0, CanonicalKmer::from_raw_unchecked(0))) + let id2 = capture(|w| write_scatter(&sk2, w, 4, 2, 0, Minimizer::from_raw_unchecked(0))) .lines() .next() .unwrap() @@ -265,7 +268,7 @@ mod tests { let sk1 = make(b"ACGTACGT"); let sk2 = make(b"TTTTTTTT"); - let id1 = capture(|w| write_scatter(&sk1, w, 4, 2, 0, CanonicalKmer::from_raw_unchecked(0))) + let id1 = capture(|w| write_scatter(&sk1, w, 4, 2, 0, Minimizer::from_raw_unchecked(0))) .lines() .next() .unwrap() @@ -273,7 +276,7 @@ mod tests { .next() .unwrap()[1..] .to_string(); - let id2 = capture(|w| write_scatter(&sk2, w, 4, 2, 0, CanonicalKmer::from_raw_unchecked(0))) + let id2 = capture(|w| write_scatter(&sk2, w, 4, 2, 0, Minimizer::from_raw_unchecked(0))) .lines() .next() .unwrap() @@ -287,7 +290,7 @@ mod tests { #[test] fn id_is_16_hex_digits() { let sk = make(b"ACGTACGT"); - let out = capture(|w| write_scatter(&sk, w, 4, 2, 0, CanonicalKmer::from_raw_unchecked(0))); + let out = capture(|w| write_scatter(&sk, w, 4, 2, 0, Minimizer::from_raw_unchecked(0))); let id = &out.lines().next().unwrap()[1..17]; // skip '>' assert_eq!(id.len(), 16); assert!(id.chars().all(|c| c.is_ascii_hexdigit())); diff --git a/src/obikmer/src/cmd/fasta.rs b/src/obikmer/src/cmd/fasta.rs index bd6bf59..4fd1e43 100644 --- a/src/obikmer/src/cmd/fasta.rs +++ b/src/obikmer/src/cmd/fasta.rs @@ -64,7 +64,7 @@ fn dump_super_kmers(kp: &KmerPartition, partition_dir: &PathBuf) { std::process::exit(1) }); - let mut reader = SKFileReader::open(&in_path, k).unwrap_or_else(|e| { + let mut reader = SKFileReader::open(&in_path).unwrap_or_else(|e| { eprintln!("error opening {}: {e}", in_path.display()); std::process::exit(1) }); diff --git a/src/obikmer/src/cmd/longtig.rs b/src/obikmer/src/cmd/longtig.rs index 5a54dc8..fa8d1bf 100644 --- a/src/obikmer/src/cmd/longtig.rs +++ b/src/obikmer/src/cmd/longtig.rs @@ -7,6 +7,7 @@ use niffler::Level; use niffler::send::compression::Format; use 
obidebruinj::GraphDeBruijn; use obikpartitionner::KmerPartition; +use obikseq::set_k; use obiskio::SKFileReader; use ph::fmph::GOFunction; use rayon::prelude::*; @@ -33,6 +34,7 @@ pub fn run(args: LongtigArgs) { }); let k = kp.kmer_size(); + set_k(k); let n = kp.n_partitions(); info!("building longtigs from {n} partitions (k={k}, parallel)"); @@ -46,7 +48,7 @@ pub fn run(args: LongtigArgs) { } let out_path = part_dir.join("longtig.fasta.gz"); - let mut g = GraphDeBruijn::new(k); + let mut g = GraphDeBruijn::new(); let mphf_path = part_dir.join("mphf1.bin"); let counts_path = part_dir.join("counts1.bin"); @@ -86,12 +88,12 @@ pub fn run(args: LongtigArgs) { .as_ref() .map(|m| unsafe { std::slice::from_raw_parts(m.as_ptr() as *const u32, m.len() / 4) }); - let mut reader = SKFileReader::open(&in_path, k).unwrap_or_else(|e| { + let mut reader = SKFileReader::open(&in_path).unwrap_or_else(|e| { eprintln!("error opening {}: {e}", in_path.display()); std::process::exit(1) }); for sk in reader.iter() { - for kmer in sk.iter_canonical_kmers(k) { + for kmer in sk.iter_canonical_kmers() { let accept = match (&mphf_opt, counts_slice) { (Some(mphf), Some(counts)) => { if let Some(slot) = mphf.get(&kmer) { diff --git a/src/obikmer/src/cmd/partition.rs b/src/obikmer/src/cmd/partition.rs index 138bc43..03cf3cd 100644 --- a/src/obikmer/src/cmd/partition.rs +++ b/src/obikmer/src/cmd/partition.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use clap::Args; use obikpartitionner::KmerPartition; -use obikseq::RoutableSuperKmer; +use obikseq::{RoutableSuperKmer, set_k, set_m}; use tracing::info; use crate::cli::{CommonArgs, PipelineData, open_chunks}; @@ -25,7 +25,9 @@ pub struct PartitionArgs { pub fn run(args: PartitionArgs) { let k = args.common.kmer_size; + set_k(k); let m = args.common.minimizer_size; + set_m(m); let theta = args.common.theta; let level_max = args.common.level_max; let n_workers = args.common.threads.max(1); @@ -42,7 +44,7 @@ pub fn run(args: PartitionArgs) { 
PipelineData : PathBuf => Vec, ||? { |path| open_chunks(path) } : Path => RawChunk, |? { move |rope| obiread::normalize_sequence_chunk(rope, k) } : RawChunk => NormChunk, - | { move |rope| obiskbuilder::build_superkmers(rope, k, m, level_max, theta) }: NormChunk => Batch, + | { move |rope| obiskbuilder::build_superkmers(rope, k, level_max, theta) }: NormChunk => Batch, }; for batch in pipe.apply(path_source, n_workers, 1) { diff --git a/src/obikmer/src/cmd/superkmer.rs b/src/obikmer/src/cmd/superkmer.rs index 473c67a..8dc938a 100644 --- a/src/obikmer/src/cmd/superkmer.rs +++ b/src/obikmer/src/cmd/superkmer.rs @@ -25,7 +25,7 @@ fn write_batch( let partition_mask = (1u64 << partition_bits) - 1; for rsk in batch { let minimizer = *rsk.minimizer(); - let partition = (minimizer.seq_hash(m) & partition_mask) as usize; + let partition = (minimizer.seq_hash() & partition_mask) as usize; write_scatter(rsk.superkmer(), out, k, m, partition, minimizer)?; } Ok(()) @@ -47,7 +47,7 @@ pub fn run(args: SuperkmerArgs) { PipelineData : PathBuf => Vec, ||? { |path| open_chunks(path) } : Path => RawChunk, |? 
{ move |rope| obiread::normalize_sequence_chunk(rope, k) } : RawChunk => NormChunk, - | { move |rope| obiskbuilder::build_superkmers(rope, k, m, level_max, theta) }: NormChunk => Batch, + | { move |rope| obiskbuilder::build_superkmers(rope, k, level_max, theta) }: NormChunk => Batch, }; let mut out = BufWriter::new(io::stdout()); diff --git a/src/obikmer/src/cmd/unitig.rs b/src/obikmer/src/cmd/unitig.rs index 0fa3abd..66fd583 100644 --- a/src/obikmer/src/cmd/unitig.rs +++ b/src/obikmer/src/cmd/unitig.rs @@ -7,6 +7,7 @@ use niffler::Level; use niffler::send::compression::Format; use obidebruinj::GraphDeBruijn; use obikpartitionner::KmerPartition; +use obikseq::set_k; use obiskio::SKFileReader; use ph::fmph::GOFunction; use rayon::prelude::*; @@ -33,6 +34,7 @@ pub fn run(args: UnitigArgs) { }); let k = kp.kmer_size(); + set_k(k); let n = kp.n_partitions(); info!("building unitigs from {n} partitions (k={k}, parallel)"); @@ -46,7 +48,7 @@ pub fn run(args: UnitigArgs) { } let out_path = part_dir.join("unitig.fasta.gz"); - let mut g = GraphDeBruijn::new(k); + let mut g = GraphDeBruijn::new(); let mphf_path = part_dir.join("mphf1.bin"); let counts_path = part_dir.join("counts1.bin"); @@ -86,12 +88,12 @@ pub fn run(args: UnitigArgs) { .as_ref() .map(|m| unsafe { std::slice::from_raw_parts(m.as_ptr() as *const u32, m.len() / 4) }); - let mut reader = SKFileReader::open(&in_path, k).unwrap_or_else(|e| { + let mut reader = SKFileReader::open(&in_path).unwrap_or_else(|e| { eprintln!("error opening {}: {e}", in_path.display()); std::process::exit(1) }); for sk in reader.iter() { - for kmer in sk.iter_canonical_kmers(k) { + for kmer in sk.iter_canonical_kmers() { let accept = match (&mphf_opt, counts_slice) { (Some(mphf), Some(counts)) => { if let Some(slot) = mphf.get(&kmer) { diff --git a/src/obikpartitionner/src/partition.rs b/src/obikpartitionner/src/partition.rs index 7b93530..d1bf7ea 100644 --- a/src/obikpartitionner/src/partition.rs +++ 
b/src/obikpartitionner/src/partition.rs @@ -15,6 +15,7 @@ use remove_dir_all::remove_dir_all; use niffler::Level; use niffler::send::compression::Format; use obikseq::RoutableSuperKmer; +use obikseq::Sequence; use obikseq::superkmer::SuperKmer; use obiskio::{SKFileMeta, SKFileReader, SKFileWriter, SKResult}; use rayon::prelude::*; @@ -124,8 +125,7 @@ impl KmerPartition { /// Route and write one super-kmer to its partition file. pub fn write(&mut self, rsk: RoutableSuperKmer) -> SKResult<()> { self.check_not_closed()?; - let partition = - (rsk.minimizer().seq_hash(self.minimizer_size) & self.partitions_mask) as usize; + let partition = (rsk.minimizer().seq_hash() & self.partitions_mask) as usize; let sk = rsk.into_superkmer(); self.ensure_writer(partition)?.write(&sk) } @@ -134,8 +134,7 @@ impl KmerPartition { pub fn write_batch(&mut self, rsks: Vec) -> SKResult<()> { self.check_not_closed()?; for rsk in rsks { - let partition = - (rsk.minimizer().seq_hash(self.minimizer_size) & self.partitions_mask) as usize; + let partition = (rsk.minimizer().seq_hash() & self.partitions_mask) as usize; let sk = rsk.into_superkmer(); self.ensure_writer(partition)?.write(&sk)?; } @@ -202,7 +201,6 @@ impl KmerPartition { /// more temporary file descriptors — all managed by the global fd pool. 
pub fn dereplicate(&self) -> SKResult<()> { let level = self.level; - let k = self.kmer_size; let root = &self.root_path; let sys = System::new_all(); // available_memory() can return 0 on macOS when the compressor page count exceeds @@ -223,7 +221,7 @@ impl KmerPartition { } let raw_path = dir.join(format!("raw.{SK_EXT}")); let n_buckets = optimal_buckets(&raw_path, available_per_thread); - dereplicate_partition(&dir, level, n_buckets, k) + dereplicate_partition(&dir, level, n_buckets) }) .collect(); @@ -328,8 +326,7 @@ impl KmerPartition { let dir = self.root_path.join(format!("part_{:05}", partition)); fs::create_dir_all(&dir)?; let file_path = dir.join(format!("raw.{SK_EXT}")); - let writer = - SKFileWriter::create_with(file_path, self.kmer_size, Format::Zstd, self.level)?; + let writer = SKFileWriter::create_with(file_path, Format::Zstd, self.level)?; self.writers[partition] = Some(writer); } Ok(self.writers[partition].as_mut().unwrap()) @@ -415,18 +412,18 @@ fn level_from_u32(n: u32) -> Level { const MAX_SK_COUNT: u64 = (1 << 24) - 1; /// Deduplicate one partition directory in place (two-phase split + merge). 
-fn dereplicate_partition(dir: &Path, level: Level, n_temp: usize, k: usize) -> SKResult<()> { +fn dereplicate_partition(dir: &Path, level: Level, n_temp: usize) -> SKResult<()> { let raw_path = dir.join(format!("raw.{SK_EXT}")); if !raw_path.exists() { return Ok(()); } let out_path = dir.join(format!("dereplicated.{SK_EXT}")); - let mut writer = SKFileWriter::create_with(&out_path, k, Format::Zstd, level)?; + let mut writer = SKFileWriter::create_with(&out_path, Format::Zstd, level)?; if n_temp == 1 { // ── Direct path: partition fits in memory, no split needed ──────────── - let map = load_bucket(&raw_path, k)?; + let map = load_bucket(&raw_path)?; remove_skmer_file(&raw_path)?; flush_map(map, &mut writer)?; } else { @@ -439,10 +436,10 @@ fn dereplicate_partition(dir: &Path, level: Level, n_temp: usize, k: usize) -> S { let mut writers: Vec = temp_paths .iter() - .map(|p| SKFileWriter::create_with(p, k, Format::Zstd, level)) + .map(|p| SKFileWriter::create_with(p, Format::Zstd, level)) .collect::>()?; - let mut reader = SKFileReader::open(&raw_path, k)?; + let mut reader = SKFileReader::open(&raw_path)?; while let Some(sk) = reader.read()? { let bucket = (sk.seq_hash() & temp_mask) as usize; writers[bucket].write(&sk)?; @@ -455,7 +452,7 @@ fn dereplicate_partition(dir: &Path, level: Level, n_temp: usize, k: usize) -> S // ── Phase 2: merge each temp bucket into the output ─────────────────── for temp_path in &temp_paths { - let map = load_bucket(temp_path, k)?; + let map = load_bucket(temp_path)?; remove_skmer_file(temp_path)?; flush_map(map, &mut writer)?; } @@ -466,14 +463,14 @@ fn dereplicate_partition(dir: &Path, level: Level, n_temp: usize, k: usize) -> S } /// Read a SuperKmer file into a deduplication map (already canonical). 
-fn load_bucket(path: &Path, k: usize) -> SKResult> { +fn load_bucket(path: &Path) -> SKResult> { let capacity = SKFileMeta::read(path) .ok() .flatten() .map(|m| m.instances as usize) .unwrap_or(0); let mut map: HashMap = HashMap::with_capacity(capacity); - let mut reader = SKFileReader::open(path, k)?; + let mut reader = SKFileReader::open(path)?; while let Some(sk) = reader.read()? { let count = sk.count() as u64; *map.entry(sk).or_insert(0) += count; @@ -512,10 +509,10 @@ fn count_partition(dir: &Path, dedup_path: &Path, k: usize) -> SKResult<()> { let mut seen: HashSet = HashSet::with_capacity(capacity); let mut pass1_superkmers: u64 = 0; { - let mut reader = SKFileReader::open(dedup_path, k)?; + let mut reader = SKFileReader::open(dedup_path)?; while let Some(sk) = reader.read()? { pass1_superkmers += 1; - for kmer in sk.iter_canonical_kmers(k) { + for kmer in sk.iter_canonical_kmers() { seen.insert(kmer); } } @@ -555,10 +552,10 @@ fn count_partition(dir: &Path, dedup_path: &Path, k: usize) -> SKResult<()> { { let counts = unsafe { std::slice::from_raw_parts_mut(mmap.as_mut_ptr() as *mut u32, n_kmers) }; - let mut reader = SKFileReader::open(dedup_path, k)?; + let mut reader = SKFileReader::open(dedup_path)?; while let Some(sk) = reader.read()? 
{ pass2_superkmers += 1; - let seql = sk.len(); + let seql = sk.seql(); let sk_count = sk.count(); if pass2_superkmers <= 3 { debug!( @@ -570,7 +567,7 @@ fn count_partition(dir: &Path, dedup_path: &Path, k: usize) -> SKResult<()> { continue; } pass2_count_sum += sk_count as u64; - for kmer in sk.iter_canonical_kmers(k) { + for kmer in sk.iter_canonical_kmers() { if let Some(idx) = mphf.get(&kmer) { counts[idx as usize] = counts[idx as usize].saturating_add(sk_count); pass2_kmer_hits += 1; diff --git a/src/obikseq/Cargo.toml b/src/obikseq/Cargo.toml index 426b741..97be974 100644 --- a/src/obikseq/Cargo.toml +++ b/src/obikseq/Cargo.toml @@ -3,6 +3,11 @@ name = "obikseq" version = "0.1.0" edition = "2024" +[features] +# Replaces the OnceLock-based params with thread-local storage so that +# tests in dependent crates can call set_k / set_m freely without conflicts. +test-utils = [] + [dependencies] bitvec = "1" serde = { version = "1.0", features = ["derive"] } diff --git a/src/obikseq/src/annotations.rs b/src/obikseq/src/annotations.rs index cf359f7..ba8a7bd 100644 --- a/src/obikseq/src/annotations.rs +++ b/src/obikseq/src/annotations.rs @@ -2,6 +2,14 @@ use serde::Serialize; use serde_json; use std::io::{self, Write}; +/// Minimal annotation carrying only the sequence length. +#[derive(Serialize)] +pub struct BasicAnnotation { + pub seq_length: usize, +} + +impl Annotation for BasicAnnotation {} + /// Serialize `self` as a single-line JSON object into a writer. pub trait Annotation: Serialize { /// Write the annotation as compact JSON into `writer`. diff --git a/src/obikseq/src/kmer.rs b/src/obikseq/src/kmer.rs index f83eb4f..70e49bb 100644 --- a/src/obikseq/src/kmer.rs +++ b/src/obikseq/src/kmer.rs @@ -1,12 +1,70 @@ //! Compact 2-bit kmer stored as a left-aligned u64. //! //! Nucleotide 0 occupies bits 63–62, nucleotide i occupies bits 63−2i and 62−2i. -//! The low 64−2k bits are always zero. k is not stored — it is a parameter of -//! 
every operation that needs it, and will be owned by the collection-level indexer. +//! The low 64−2·len bits are always zero. +//! +//! The length is not stored in the struct — it is supplied by the [`KmerLength`] +//! type parameter. Two public marker types cover the common cases: +//! +//! | Alias | Marker | Length source | +//! |-----------------|----------|----------------| +//! | [`Kmer`] | [`KLen`] | `params::k()` | +//! | [`CanonicalKmer`]| [`KLen`]| `params::k()` | +//! | [`Minimizer`] | [`MLen`] | `params::m()` | +//! +//! Tests that need a fixed length independent of the global singletons can use +//! [`ConstLen`]. +use serde::Serialize; use std::io::{self, Write}; +use std::marker::PhantomData; +use crate::Annotation; +use crate::Sequence; use crate::encoding::{DEC4, encode_base}; +use crate::params::{k, m}; +use crate::sequence::mix64; + +// ── KmerLength ──────────────────────────────────────────────────────────────── + +/// Marker trait that supplies a kmer length at runtime. +pub trait KmerLength: Copy + std::fmt::Debug + 'static { + /// Returns the length this marker represents. + fn len() -> usize; +} + +/// Marker for the k-mer length (`params::k()`). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct KLen; + +/// Marker for the minimizer length (`params::m()`). +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MLen; + +/// Marker for a compile-time-constant length — useful for tests. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ConstLen; + +impl KmerLength for KLen { + #[inline] + fn len() -> usize { + k() + } +} + +impl KmerLength for MLen { + #[inline] + fn len() -> usize { + m() + } +} + +impl KmerLength for ConstLen { + #[inline] + fn len() -> usize { + N + } +} // ── KmerError ───────────────────────────────────────────────────────────────── @@ -43,35 +101,31 @@ impl std::fmt::Display for KmerError { impl std::error::Error for KmerError {} -// ── Kmer ────────────────────────────────────────────────────────────────────── +#[derive(Serialize)] +struct KmerAnnotation { + seq_length: usize, +} +impl Annotation for KmerAnnotation {} -/// A DNA kmer of length k encoded as a left-aligned u64 (2 bits/nucleotide, MSB-first). -/// k is not stored in the struct — it must be supplied by the caller. +// ── KmerOf ──────────────────────────────────────────────────────────────────── + +/// A DNA kmer of length `L::len()` encoded as a left-aligned u64 (2 bits/nucleotide, MSB-first). +/// The low `64 − 2·L::len()` bits are always zero. #[repr(transparent)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Kmer(u64); +pub struct KmerOf(u64, PhantomData); -#[inline] -fn mix64(x: u64) -> u64 { - let x = x ^ (x >> 30); - let x = x.wrapping_mul(0xbf58476d1ce4e5b9); - let x = x ^ (x >> 27); - let x = x.wrapping_mul(0x94d049bb133111eb); - x ^ (x >> 31) -} - -impl Kmer { - /// Wrap a raw left-aligned u64 value as a Kmer. +impl KmerOf { + /// Wrap a raw left-aligned u64 value as a kmer. #[inline] pub fn from_raw(raw: u64) -> Self { - Kmer(raw) + KmerOf(raw, PhantomData) } - /// Wrap a raw right-aligned u64 value as a Kmer. - /// The raw value is shifted left by `2 * k` bits to align it with the leftmost position. + /// Wrap a raw right-aligned u64 value, shifting it into left-aligned position. 
#[inline] - pub fn from_raw_right(raw: u64, k: usize) -> Self { - Kmer(raw << (64 - 2 * k)) + pub fn from_raw_right(raw: u64) -> Self { + KmerOf(raw << (64 - 2 * L::len()), PhantomData) } /// Return the raw left-aligned u64 value. @@ -82,14 +136,13 @@ impl Kmer { /// Return the raw right-aligned u64 value. #[inline] - pub fn raw_right(&self, k: usize) -> u64 { - self.0 >> (64 - 2 * k) + pub fn raw_right(&self) -> u64 { + self.0 >> (64 - 2 * L::len()) } - /// Encode the first k nucleotides of an ASCII slice into a Kmer. - /// Zero allocation — result lives on the stack. - #[inline] - pub fn from_ascii(ascii: &[u8], k: usize) -> Result { + /// Encode the first `L::len()` nucleotides of an ASCII slice into a kmer. + pub fn from_ascii(ascii: &[u8]) -> Result { + let k = L::len(); if k == 0 || k > 32 { return Err(KmerError::InvalidK { k }); } @@ -104,26 +157,21 @@ impl Kmer { for i in 0..k { val = (val << 2) | encode_base(ascii[i]) as u64; } - Ok(Kmer(val << (64 - 2 * k))) + Ok(KmerOf(val << (64 - 2 * k), PhantomData)) } - /// Extract nucleotide i (0-based from 5′ end) as a 2-bit value. + /// Decode into a freshly allocated ASCII `Vec`. #[inline] - pub fn nucleotide(&self, i: usize) -> u8 { - ((self.0 >> (62 - 2 * i)) & 0b11) as u8 - } - - /// Decode this kmer into a freshly allocated ASCII `Vec`. - #[inline] - pub fn to_ascii(&self, k: usize) -> Vec { - let mut buf = Vec::with_capacity(k); - self.write_ascii(k, &mut buf).unwrap(); + pub fn to_ascii(&self) -> Vec { + let mut buf = Vec::with_capacity(L::len()); + self.write_ascii(&mut buf).unwrap(); buf } - /// Decode this kmer into ASCII nucleotides, writing into `writer`. + /// Decode into ASCII nucleotides, writing into `writer`. 
#[inline] - pub fn write_ascii(&self, k: usize, writer: &mut W) -> io::Result<()> { + pub fn write_ascii(&self, writer: &mut W) -> io::Result<()> { + let k = L::len(); let bytes = self.0.to_be_bytes(); let full = k / 4; let rem = k % 4; @@ -137,296 +185,212 @@ impl Kmer { Ok(()) } - /// Compute the reverse complement of this kmer. - /// Zero allocation — result lives on the stack. + /// Compute the reverse complement. #[inline] - pub fn revcomp(&self, k: usize) -> Self { - let x = !self.0; // complement - let x = x.swap_bytes(); // reverse bytes - let x = ((x >> 4) & 0x0F0F0F0F0F0F0F0F) | ((x & 0x0F0F0F0F0F0F0F0F) << 4); // swap nibbles - let x = ((x >> 2) & 0x3333333333333333) | ((x & 0x3333333333333333) << 2); // swap 2-bit groups - Kmer(x << (64 - 2 * k)) + pub fn revcomp(&self) -> Self { + let k = L::len(); + let x = !self.0; + let x = x.swap_bytes(); + let x = ((x >> 4) & 0x0F0F0F0F0F0F0F0F) | ((x & 0x0F0F0F0F0F0F0F0F) << 4); + let x = ((x >> 2) & 0x3333333333333333) | ((x & 0x3333333333333333) << 2); + KmerOf(x << (64 - 2 * k), PhantomData) } - /// Return the canonical form: lexicographic minimum of forward and reverse complement. - /// Zero allocation — result lives on the stack. - #[inline] - pub fn canonical(&self, k: usize) -> CanonicalKmer { - let rc = self.revcomp(k); - CanonicalKmer(if self.0 <= rc.0 { *self } else { rc }) - } - - /// Slide the window one base to the right: drop the first nucleotide, append `nuc` at position k-1. - pub fn push_right(self, nuc: u8, k: usize) -> Self { + /// Slide the window one base to the right: drop nucleotide 0, append `nuc` at position `L::len()-1`. + pub fn push_right(self, nuc: u8) -> Self { + let k = L::len(); let shifted = self.0 << 2 & (!0u64 << (64 - 2 * (k - 1))); - let shift = 64 - 2 * k; - Kmer(shifted | ((nuc as u64 & 3) << shift)) + KmerOf(shifted | ((nuc as u64 & 3) << (64 - 2 * k)), PhantomData) } - /// Slide the window one base to the left: drop the last nucleotide, prepend `nuc` at position 0. 
- pub fn push_left(self, nuc: u8, k: usize) -> Self { + /// Slide the window one base to the left: drop nucleotide `L::len()-1`, prepend `nuc` at position 0. + pub fn push_left(self, nuc: u8) -> Self { + let k = L::len(); let shifted = (self.0 >> 2) & (!0u64 << (64 - 2 * k)); - Kmer(shifted | ((nuc as u64 & 3) << 62)) + KmerOf(shifted | ((nuc as u64 & 3) << 62), PhantomData) } - /// Returns `true` if `self` and `other` overlap by `k` - 1 bases. - /// - /// The last K-1 nucleotides of `self` and the first K-1 nucleotides - /// of `other` must be equal. - pub fn is_overlapping(self, other: Self, k: usize) -> bool { - let left = self.0 << 2 & (!0u64 << (64 - 2 * (k - 1))); - let right = other.0 & (!0u64 << (64 - 2 * (k - 1))); - left == right + /// Returns `true` if the last `L::len()-1` nucleotides of `self` equal the first `L::len()-1` of `other`. + pub fn is_overlapping(self, other: Self) -> bool { + let k = L::len(); + let mask = !0u64 << (64 - 2 * (k - 1)); + (self.0 << 2 & mask) == (other.0 & mask) } } -// ── CanonicalKmer ───────────────────────────────────────────────────────────── +impl Sequence for KmerOf { + type Canonical = CanonicalKmerOf; -/// A [`Kmer`] guaranteed to be in canonical form (lexicographic minimum of + fn seql(&self) -> usize { + L::len() + } + + fn seq_hash(&self) -> u64 { + self.canonical().seq_hash() + } + + #[inline] + fn nucleotide(&self, i: usize) -> u8 { + ((self.0 >> (62 - 2 * i)) & 0b11) as u8 + } + + #[inline] + fn canonical(&self) -> Self::Canonical { + let rc = self.revcomp(); + CanonicalKmerOf(if self.0 <= rc.0 { self.0 } else { rc.0 }, PhantomData) + } + + fn annotation(&self) -> impl Annotation { + KmerAnnotation { + seq_length: L::len(), + } + } +} +// ── CanonicalKmerOf ─────────────────────────────────────────────────────────── + +/// A [`KmerOf`] guaranteed to be in canonical form (lexicographic minimum of /// forward and reverse complement). 
/// -/// The only public constructors are [`Kmer::canonical`] (checked) and -/// [`CanonicalKmer::from_raw_unchecked`] (for trusted paths such as -/// deserialisation or rolling-window minimizer extraction). +/// The only public constructors are [`KmerOf::canonical`] (verified) and +/// [`CanonicalKmerOf::from_raw_unchecked`] (trusted paths such as deserialisation). #[repr(transparent)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct CanonicalKmer(Kmer); +pub struct CanonicalKmerOf(u64, PhantomData); -impl CanonicalKmer { +impl CanonicalKmerOf { /// Wrap a raw left-aligned u64 without verifying the canonical invariant. /// /// # Safety (logical) - /// The caller must guarantee that `raw == min(raw, revcomp(raw, k))`. - /// Violations cause silently wrong results in MPHF lookup and graph traversal. + /// The caller must guarantee `raw == min(raw, revcomp(raw))`. #[inline] pub fn from_raw_unchecked(raw: u64) -> Self { - CanonicalKmer(Kmer(raw)) + CanonicalKmerOf(raw, PhantomData) } /// Return the raw left-aligned u64 value. #[inline] pub fn raw(&self) -> u64 { - self.0.0 + self.0 } /// Decode into a freshly allocated ASCII `Vec`. #[inline] - pub fn to_ascii(&self, k: usize) -> Vec { - self.0.to_ascii(k) + pub fn to_ascii(&self) -> Vec { + self.into_kmer().to_ascii() } /// Decode into ASCII nucleotides, writing into `writer`. #[inline] - pub fn write_ascii(&self, k: usize, writer: &mut W) -> io::Result<()> { - self.0.write_ascii(k, writer) + pub fn write_ascii(&self, writer: &mut W) -> io::Result<()> { + self.into_kmer().write_ascii(writer) } - /// Compute the reverse complement. The result is a raw [`Kmer`] — the - /// revcomp of a canonical kmer is not necessarily canonical itself. + /// Compute the reverse complement. The result is a raw [`KmerOf`] — not + /// necessarily canonical itself. 
#[inline] - pub fn revcomp(&self, k: usize) -> Kmer { - self.0.revcomp(k) + pub fn revcomp(&self) -> KmerOf { + self.into_kmer().revcomp() } - /// Hash via `mix64`. No re-canonicalisation needed. + /// Hash via `mix64`. #[inline] - pub fn seq_hash(&self, _k: usize) -> u64 { - mix64(self.0.0) + pub fn seq_hash(&self) -> u64 { + hash_kmer(self.0) } /// Extract nucleotide i (0-based from 5′ end) as a 2-bit value. #[inline] pub fn nucleotide(&self, i: usize) -> u8 { - self.0.nucleotide(i) + self.into_kmer().nucleotide(i) } /// Return the four left canonical neighbours (each already canonical). - /// Zero allocation — result lives on the stack. - pub fn left_canonical_neighbors(&self, k: usize) -> [CanonicalKmer; 4] { - let shifted = (self.raw() >> 2) & (!0u64 << (64 - 2 * k)); + pub fn left_canonical_neighbors(&self) -> [CanonicalKmerOf; 4] { + let k = L::len(); + let shifted = (self.0 >> 2) & (!0u64 << (64 - 2 * k)); [ - Kmer(shifted).canonical(k), - Kmer(shifted | (1u64 << 62)).canonical(k), - Kmer(shifted | (2u64 << 62)).canonical(k), - Kmer(shifted | (3u64 << 62)).canonical(k), + KmerOf::(shifted, PhantomData).canonical(), + KmerOf::(shifted | (1u64 << 62), PhantomData).canonical(), + KmerOf::(shifted | (2u64 << 62), PhantomData).canonical(), + KmerOf::(shifted | (3u64 << 62), PhantomData).canonical(), ] } /// Return the four right canonical neighbours (each already canonical). - /// Zero allocation — result lives on the stack. 
- pub fn right_canonical_neighbors(&self, k: usize) -> [CanonicalKmer; 4] { - let shifted = self.raw() << 2 & (!0u64 << (64 - 2 * (k - 1))); + pub fn right_canonical_neighbors(&self) -> [CanonicalKmerOf; 4] { + let k = L::len(); + let shifted = self.0 << 2 & (!0u64 << (64 - 2 * (k - 1))); let shift = 64 - 2 * k; [ - Kmer(shifted).canonical(k), - Kmer(shifted | (1u64 << shift)).canonical(k), - Kmer(shifted | (2u64 << shift)).canonical(k), - Kmer(shifted | (3u64 << shift)).canonical(k), + KmerOf::(shifted, PhantomData).canonical(), + KmerOf::(shifted | (1u64 << shift), PhantomData).canonical(), + KmerOf::(shifted | (2u64 << shift), PhantomData).canonical(), + KmerOf::(shifted | (3u64 << shift), PhantomData).canonical(), ] } - /// Consume this wrapper and return the inner raw [`Kmer`]. + /// Return the inner value as a raw [`KmerOf`]. #[inline] - pub fn into_kmer(self) -> Kmer { - self.0 + pub fn into_kmer(self) -> KmerOf { + KmerOf(self.0, PhantomData) } } -impl From for Kmer { - #[inline] - fn from(ck: CanonicalKmer) -> Self { - ck.0 +impl Sequence for CanonicalKmerOf { + type Canonical = CanonicalKmerOf; + + fn seql(&self) -> usize { + L::len() } + + fn seq_hash(&self) -> u64 { + hash_kmer(self.0) + } + + #[inline] + fn nucleotide(&self, i: usize) -> u8 { + ((self.0 >> (62 - 2 * i)) & 0b11) as u8 + } + + fn canonical(&self) -> Self::Canonical { + *self + } + + fn annotation(&self) -> impl Annotation { + KmerAnnotation { + seq_length: L::len(), + } + } +} +impl From> for KmerOf { + #[inline] + fn from(ck: CanonicalKmerOf) -> Self { + ck.into_kmer() + } +} + +// ── Public type aliases ─────────────────────────────────────────────────────── + +/// A DNA k-mer using the global `params::k()` length. +pub type Kmer = KmerOf; + +/// A canonical k-mer using the global `params::k()` length. +pub type CanonicalKmer = CanonicalKmerOf; + +/// A minimizer: a canonical k-mer using the global `params::m()` length. 
+pub type Minimizer = CanonicalKmerOf; + +/// Compute a hash for a raw (left-aligned) kmer value. +/// +/// XORs the raw 64-bit representation with the constant `0x9e3779b97f4a7c15` and +/// passes the result to [`mix64`], so the output is not equal to `mix64(raw)`. Useful when +/// the canonical invariant is not required or has already been handled. +#[inline] +pub fn hash_kmer(raw: u64) -> u64 { + mix64(raw ^ 0x9e3779b97f4a7c15) } // ── tests ───────────────────────────────────────────────────────────────────── - #[cfg(test)] -mod tests { - use super::*; - - fn ascii_revcomp(seq: &[u8]) -> Vec { - seq.iter() - .rev() - .map(|&b| match b { - b'A' => b'T', - b'T' => b'A', - b'C' => b'G', - b'G' => b'C', - _ => b'A', - }) - .collect() - } - - const K_VALUES: &[usize] = &[1, 2, 3, 4, 8, 11, 16, 31, 32]; - - fn make_seq(k: usize) -> Vec { - (0..k).map(|i| b"ACGT"[i % 4]).collect() - } - - // ── from_ascii / to_ascii ───────────────────────────────────────────────── - - #[test] - fn ascii_roundtrip() { - for &k in K_VALUES { - let ascii = make_seq(k); - let kmer = Kmer::from_ascii(&ascii, k).unwrap(); - assert_eq!(kmer.to_ascii(k), ascii, "roundtrip failed for k={k}"); - } - } - - #[test] - fn from_ascii_all_bases() { - for (base, expected) in [(b'A', b'A'), (b'C', b'C'), (b'G', b'G'), (b'T', b'T')] { - let kmer = Kmer::from_ascii(&[base], 1).unwrap(); - assert_eq!(kmer.to_ascii(1), vec![expected]); - } - } - - #[test] - fn from_ascii_invalid_k() { - assert!(Kmer::from_ascii(b"A", 0).is_err()); - assert!(Kmer::from_ascii(b"ACGT", 33).is_err()); - } - - #[test] - fn from_ascii_too_short() { - assert!(Kmer::from_ascii(b"ACG", 4).is_err()); - } - - // ── nucleotide ──────────────────────────────────────────────────────────── - - #[test] - fn nucleotide_extraction() { - let kmer = Kmer::from_ascii(b"ACGT", 4).unwrap(); - assert_eq!(kmer.nucleotide(0), 0b00); // A - assert_eq!(kmer.nucleotide(1), 0b01); // C - assert_eq!(kmer.nucleotide(2), 0b10); // G - assert_eq!(kmer.nucleotide(3), 0b11); // T - } - - // ── revcomp 
─────────────────────────────────────────────────────────────── - - #[test] - fn revcomp_known_values() { - let cases: &[(&[u8], &[u8])] = &[ - (b"A", b"T"), - (b"AC", b"GT"), - (b"ACG", b"CGT"), - (b"ACGT", b"ACGT"), // palindrome - (b"AAAA", b"TTTT"), - (b"TTTT", b"AAAA"), - ]; - for (seq, expected) in cases { - let k = seq.len(); - let kmer = Kmer::from_ascii(seq, k).unwrap(); - let rc = kmer.revcomp(k); - assert_eq!( - rc.to_ascii(k), - *expected, - "revcomp wrong for \"{}\"", - std::str::from_utf8(seq).unwrap() - ); - } - } - - #[test] - fn revcomp_vs_reference() { - for &k in K_VALUES { - let ascii = make_seq(k); - let expected = ascii_revcomp(&ascii); - let rc = Kmer::from_ascii(&ascii, k).unwrap().revcomp(k); - assert_eq!(rc.to_ascii(k), expected, "revcomp wrong for k={k}"); - } - } - - #[test] - fn revcomp_involution() { - for &k in K_VALUES { - let ascii = make_seq(k); - let kmer = Kmer::from_ascii(&ascii, k).unwrap(); - assert_eq!( - kmer.revcomp(k).revcomp(k), - kmer, - "revcomp∘revcomp≠id for k={k}" - ); - } - } - - // ── canonical ───────────────────────────────────────────────────────────── - - #[test] - fn canonical_palindrome() { - let kmer = Kmer::from_ascii(b"ACGT", 4).unwrap(); - assert_eq!(kmer.canonical(4).into_kmer(), kmer); - } - - #[test] - fn canonical_chooses_lesser() { - let kmer = Kmer::from_ascii(b"TTTT", 4).unwrap(); - let expected = Kmer::from_ascii(b"AAAA", 4).unwrap(); - assert_eq!(kmer.canonical(4).into_kmer(), expected); - } - - #[test] - fn canonical_is_minimal() { - for &k in K_VALUES { - let ascii = make_seq(k); - let ck = Kmer::from_ascii(&ascii, k).unwrap().canonical(k); - let rc = ck.revcomp(k); - assert!(ck.raw() <= rc.raw(), "canonical not minimal for k={k}"); - } - } - - #[test] - fn canonical_idempotent() { - for &k in K_VALUES { - let ck = Kmer::from_ascii(&make_seq(k), k).unwrap().canonical(k); - assert_eq!( - ck.into_kmer().canonical(k), - ck, - "canonical not idempotent for k={k}" - ); - } - } -} +#[path = 
"tests/kmer.rs"] +mod tests; diff --git a/src/obikseq/src/lib.rs b/src/obikseq/src/lib.rs index 2219b71..aaa7a7b 100644 --- a/src/obikseq/src/lib.rs +++ b/src/obikseq/src/lib.rs @@ -9,6 +9,8 @@ mod annotations; mod encoding; pub mod kmer; +pub mod packed_seq; +pub mod params; mod revcomp_lookup; /// Routable super-kmer: canonical sequence paired with its minimizer for scatter routing. pub mod routable; @@ -18,7 +20,9 @@ pub mod superkmer; pub mod unitig; pub use annotations::Annotation; -pub use kmer::CanonicalKmer; +pub use kmer::{CanonicalKmer, Kmer, Minimizer, hash_kmer}; +pub use params::{k, m, set_k, set_m}; pub use routable::RoutableSuperKmer; pub use sequence::Sequence; pub use superkmer::SuperKmer; +pub use unitig::Unitig; diff --git a/src/obikseq/src/packed_seq.rs b/src/obikseq/src/packed_seq.rs new file mode 100644 index 0000000..73d0ebc --- /dev/null +++ b/src/obikseq/src/packed_seq.rs @@ -0,0 +1,361 @@ +//! Compact 2-bit DNA sequence — shared substrate for [`SuperKmer`] and [`Unitig`]. +//! +//! Encoding: A=00, C=01, G=10, T=11. Nucleotide 0 occupies bits 7–6 of `seq[0]`, +//! nucleotide i occupies bits `7−2*(i%4)` and `6−2*(i%4)` of `seq[i/4]`. +//! Padding bits in the last byte are always 0. +//! +//! The exact nucleotide count is recovered without storing it explicitly: +//! +//! ```text +//! seql = (seq.len() - 1) * 4 + tail_count(tail) +//! ``` +//! +//! where `tail` encodes the number of valid nucleotides in the last byte (0 → 4). + +use std::io::{self, Read, Write}; + +use bitvec::prelude::*; + +use crate::Sequence; +use crate::encoding::{DEC4, encode_base}; +use crate::kmer::{CanonicalKmer, Kmer, KmerError, KLen, KmerLength, KmerOf, MLen, Minimizer}; +use crate::params::k; +use crate::revcomp_lookup::REVCOMP4; + +// ── PackedSeq ───────────────────────────────────────────────────────────────── + +/// 2-bit packed DNA sequence of arbitrary length ≥ 1. 
+/// +/// `tail` encodes the number of valid nucleotides in the last byte: 0 stands for 4, +/// so the range 0–3 covers all four cases. Padding bits are always 0. +#[derive(Debug, Clone)] +pub struct PackedSeq { + pub(crate) tail: u8, + pub(crate) seq: Box<[u8]>, +} + +impl PartialEq for PackedSeq { + fn eq(&self, other: &Self) -> bool { + self.tail == other.tail && self.seq == other.seq + } +} + +impl Eq for PackedSeq {} + +impl std::hash::Hash for PackedSeq { + fn hash(&self, state: &mut H) { + self.tail.hash(state); + self.seq.hash(state); + } +} + +impl PackedSeq { + /// Construct from pre-packed bytes and a `tail` value (0–3, where 0 means 4). + /// Caller must guarantee padding bits in the last byte are zeroed. + #[inline] + pub fn new(tail: u8, seq: Box<[u8]>) -> Self { + debug_assert!(tail <= 3, "tail must be 0–3"); + debug_assert!(!seq.is_empty(), "seq must be non-empty"); + Self { tail, seq } + } + + /// Sequence length in nucleotides. + #[inline] + pub fn seql(&self) -> usize { + (self.seq.len() - 1) * 4 + tail_count(self.tail) + } + + /// Read-only view of the packed 2-bit bytes. + #[inline] + pub fn seq_bytes(&self) -> &[u8] { + &self.seq + } + + /// Extract nucleotide i (0-based from 5′ end) as a 2-bit value. Zero copy. + #[inline] + pub fn nucleotide(&self, i: usize) -> u8 { + (self.seq[i / 4] >> (6 - 2 * (i % 4))) & 0b11 + } + + /// Encode an ASCII nucleotide slice (ACGT, length ≥ 1). Allocates once. 
+ pub fn from_ascii(ascii: &[u8]) -> Self { + let seql = ascii.len(); + debug_assert!(seql >= 1); + let n = byte_len(seql); + let mut seq = vec![0u8; n]; + let full = seql / 4; + for i in 0..full { + seq[i] = encode_base(ascii[i * 4]) << 6 + | encode_base(ascii[i * 4 + 1]) << 4 + | encode_base(ascii[i * 4 + 2]) << 2 + | encode_base(ascii[i * 4 + 3]); + } + let rem = seql % 4; + if rem > 0 { + let mut last = 0u8; + for j in 0..rem { + last |= encode_base(ascii[full * 4 + j]) << (6 - 2 * j); + } + seq[full] = last; + } + Self::new(count_to_tail(seql), seq.into_boxed_slice()) + } + + /// Encode a slice of 2-bit nucleotide values (0=A…3=T, length ≥ 1). Allocates once. + pub fn from_nucleotides(nucs: &[u8]) -> Self { + let seql = nucs.len(); + debug_assert!(seql >= 1); + let n = byte_len(seql); + let mut seq = vec![0u8; n]; + for (i, &nuc) in nucs.iter().enumerate() { + seq[i / 4] |= (nuc & 0b11) << (6 - 2 * (i % 4)); + } + Self::new(count_to_tail(seql), seq.into_boxed_slice()) + } + + /// Write ASCII nucleotides into `writer`. Zero allocation. + pub fn write_ascii(&self, writer: &mut W) -> io::Result<()> { + let seql = self.seql(); + let full = seql / 4; + for i in 0..full { + writer.write_all(&DEC4[self.seq[i] as usize].to_be_bytes())?; + } + let rem = seql % 4; + if rem > 0 { + writer.write_all(&DEC4[self.seq[full] as usize].to_be_bytes()[..rem])?; + } + Ok(()) + } + + /// Decode into a fresh ASCII `Vec`. Allocates. + #[inline] + pub fn to_ascii(&self) -> Vec { + let mut buf = Vec::with_capacity(self.seql()); + self.write_ascii(&mut buf).unwrap(); + buf + } + + /// Reverse-complement in place. Zero allocation. 
+ pub fn revcomp_inplace(&mut self) { + let seql = self.seql(); + let n = self.seq.len(); + { + let bytes = &mut self.seq[..n]; + let (mut lo, mut hi) = (0, n - 1); + while lo < hi { + (bytes[lo], bytes[hi]) = + (REVCOMP4[bytes[hi] as usize], REVCOMP4[bytes[lo] as usize]); + lo += 1; + hi -= 1; + } + if lo == hi { + bytes[lo] = REVCOMP4[bytes[lo] as usize]; + } + } + let shift = n * 8 - seql * 2; + if shift > 0 { + let bits = self.seq[..n].view_bits_mut::(); + bits.rotate_left(shift); + let len = bits.len(); + bits[len - shift..].fill(false); + } + // tail is invariant: seql is unchanged by revcomp + } + + /// Returns `true` if in canonical form (lexicographic minimum of forward and revcomp). + pub fn is_canonical(&self) -> bool { + let seql = self.seql(); + for i in 0..seql { + let fwd = self.nucleotide(i); + let rev = complement(self.nucleotide(seql - 1 - i)); + match fwd.cmp(&rev) { + std::cmp::Ordering::Less => return true, + std::cmp::Ordering::Greater => return false, + std::cmp::Ordering::Equal => {} + } + } + true + } + + /// Put in canonical form in place. Returns `true` if already canonical. Zero allocation. + #[inline] + pub fn canonicalize(&mut self) -> bool { + if self.is_canonical() { + return true; + } + self.revcomp_inplace(); + false + } + + /// Extract a kmer of length `L::len()` at nucleotide position `i`. Zero allocation. + fn extract(&self, i: usize) -> Result, KmerError> { + let len = L::len(); + let seql = self.seql(); + if i + len > seql { + return Err(KmerError::OutOfBounds { position: i, k: len, seql }); + } + let bits = self.seq.view_bits::(); + let raw: u64 = bits[i * 2..(i + len) * 2].load_be(); + Ok(KmerOf::from_raw(raw << (64 - 2 * len))) + } + + /// Extract the kmer of length `params::k()` at nucleotide position `i`. Zero allocation. + #[inline] + pub fn kmer(&self, i: usize) -> Result { + self.extract::(i) + } + + /// Extract the canonical m-mer (minimizer) of length `params::m()` at position `i`. Zero allocation. 
+ #[inline] + pub fn mmer(&self, i: usize) -> Result { + Ok(self.extract::(i)?.canonical()) + } + + /// Extract the canonical kmer of length `params::k()` at position `i`. Zero allocation. + #[inline] + pub fn canonical_kmer(&self, i: usize) -> Result { + Ok(self.kmer(i)?.canonical()) + } + + /// Iterate over all kmers of length `params::k()` in order. Zero allocation. + #[inline] + pub fn iter_kmers(&self) -> PackedSeqKmerIter<'_> { + PackedSeqKmerIter::new(self) + } + + /// Iterate over all canonical kmers of length `params::k()` in order. Zero allocation. + #[inline] + pub fn iter_canonical_kmers(&self) -> impl Iterator + '_ { + self.iter_kmers().map(|km| km.canonical()) + } + + /// Serialise to a compact binary representation. + /// + /// Format: varint(seql) followed by raw packed bytes. + /// `tail` and `byte_len` are both derivable from `seql` and need not be stored. + pub fn write_to_binary(&self, w: &mut W) -> io::Result<()> { + write_varint(w, self.seql() as u64)?; + w.write_all(&self.seq) + } + + /// Deserialise from the compact binary format produced by [`write_to_binary`]. + /// Allocates exactly one `Box<[u8]>` for the packed bytes. + pub fn read_from_binary(r: &mut R) -> io::Result { + let seql = read_varint(r)? as usize; + if seql == 0 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "empty sequence")); + } + let byte_len = (seql + 3) / 4; + let tail = (seql % 4) as u8; + let mut seq = vec![0u8; byte_len]; + r.read_exact(&mut seq)?; + Ok(Self::new(tail, seq.into_boxed_slice())) + } +} + +// ── PackedSeqKmerIter ───────────────────────────────────────────────────────── + +/// Sliding-window kmer iterator over a [`PackedSeq`]. Zero allocation. 
+pub struct PackedSeqKmerIter<'a> { + seq: &'a PackedSeq, + mask: u64, + lshift: usize, + current: u64, + pos: usize, + max_pos: usize, +} + +impl<'a> PackedSeqKmerIter<'a> { + fn new(seq: &'a PackedSeq) -> Self { + let seql = seq.seql(); + let klen = k(); + let lshift = 64 - klen * 2; + let mask = ((!0u128) << (lshift + 2)) as u64; + Self { + seq, + mask, + lshift, + current: if seql >= klen { seq.extract::(0).map(|km| km.raw()).unwrap_or(0) } else { 0 }, + pos: klen, + max_pos: seql, + } + } +} + +impl Iterator for PackedSeqKmerIter<'_> { + type Item = Kmer; + + fn next(&mut self) -> Option { + if self.pos > self.max_pos { + return None; + } + let result = Kmer::from_raw(self.current); + if self.pos < self.max_pos { + let inner_shift = 6 - 2 * (self.pos & 3); + let nuc = ((self.seq.seq[self.pos / 4] >> inner_shift) & 3) as u64; + self.current = ((self.current << 2) & self.mask) | (nuc << self.lshift); + } + self.pos += 1; + Some(result) + } +} + +// ── varint (LEB128) ─────────────────────────────────────────────────────────── + +pub(crate) fn write_varint(w: &mut W, mut val: u64) -> io::Result<()> { + loop { + let mut byte = (val & 0x7F) as u8; + val >>= 7; + if val != 0 { + byte |= 0x80; + } + w.write_all(&[byte])?; + if val == 0 { + break; + } + } + Ok(()) +} + +pub(crate) fn read_varint(r: &mut R) -> io::Result { + let mut val = 0u64; + let mut shift = 0u32; + let mut buf = [0u8; 1]; + loop { + r.read_exact(&mut buf)?; + let byte = buf[0]; + val |= ((byte & 0x7F) as u64) << shift; + if byte & 0x80 == 0 { + break; + } + shift += 7; + if shift >= 64 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "varint overflow")); + } + } + Ok(val) +} + +// ── helpers ─────────────────────────────────────────────────────────────────── + +#[inline] +fn complement(base: u8) -> u8 { + !base & 0b11 +} + +#[inline] +fn byte_len(seql: usize) -> usize { + (seql + 3) / 4 +} + +/// Nucleotide count → `tail` value: 0 encodes 4, 1–3 are identity. 
+#[inline] +pub(crate) fn count_to_tail(seql: usize) -> u8 { + (seql % 4) as u8 +} + +/// `tail` value → nucleotide count in last byte: 0 means 4. +#[inline] +pub(crate) fn tail_count(tail: u8) -> usize { + if tail == 0 { 4 } else { tail as usize } +} diff --git a/src/obikseq/src/params.rs b/src/obikseq/src/params.rs new file mode 100644 index 0000000..6899114 --- /dev/null +++ b/src/obikseq/src/params.rs @@ -0,0 +1,102 @@ +//! Global k-mer and minimizer length parameters, set once at program startup. +//! +//! # Production vs. test behaviour +//! +//! In production (neither `cfg(test)` nor the `test-utils` feature) both `K` and `M` are stored in a +//! [`OnceLock`]: they can be initialised exactly once; any attempt to set a +//! different value panics. This prevents silent divergence between the global +//! parameter and the values used to build data structures. +//! +//! In test builds (`cfg(test)` or the `test-utils` feature) the same public API is backed by +//! `thread_local!` [`Cell`]s instead. Each test thread gets its own +//! independent copies of `K` and `M`, so tests can use arbitrary values +//! without coordinating with one another and without any reset mechanism. +//! The `OnceLock` constraint is deliberately absent: test isolation is +//! provided by thread locality, not by write-once semantics. 
+ +// ── Production implementation ───────────────────────────────────────────────── + +#[cfg(not(any(test, feature = "test-utils")))] +mod state { + use std::sync::OnceLock; + + static K: OnceLock = OnceLock::new(); + static M: OnceLock = OnceLock::new(); + + pub fn set_k(k: usize) { + K.get_or_init(|| k); + assert_eq!(*K.get().unwrap(), k, "K already initialized to a different value"); + } + + pub fn k() -> usize { + *K.get().expect("K not initialized — call params::set_k or params::init first") + } + + pub fn set_m(m: usize) { + M.get_or_init(|| m); + assert_eq!(*M.get().unwrap(), m, "M already initialized to a different value"); + } + + pub fn m() -> usize { + *M.get().expect("M not initialized — call params::set_m or params::init first") + } +} + +// ── Test implementation ─────────────────────────────────────────────────────── +// +// Each test thread owns its private K and M via thread_local!, so tests may +// call set_k / set_m with any value without affecting other tests. + +#[cfg(any(test, feature = "test-utils"))] +mod state { + use std::cell::Cell; + + thread_local! { + static K: Cell = Cell::new(0); + static M: Cell = Cell::new(0); + } + + pub fn set_k(k: usize) { K.with(|c| c.set(k)); } + pub fn k() -> usize { K.with(|c| c.get()) } + pub fn set_m(m: usize) { M.with(|c| c.set(m)); } + pub fn m() -> usize { M.with(|c| c.get()) } +} + +// ── Public API (identical signature in both configurations) ─────────────────── + +/// Initialise both K and M in one call. +/// +/// In production, panics if either value has already been set to a different +/// value. In tests, simply overwrites the thread-local. +pub fn init(k: usize, m: usize) { + state::set_k(k); + state::set_m(m); +} + +/// Set the k-mer length. +/// +/// In production: idempotent for the same value, panics on conflict. +/// In tests: unconditionally updates the calling thread's value. +pub fn set_k(k: usize) { + state::set_k(k); +} + +/// Returns the k-mer length. Panics if not yet initialized. 
+#[inline] +pub fn k() -> usize { + state::k() +} + +/// Set the minimizer length. +/// +/// In production: idempotent for the same value, panics on conflict. +/// In tests: unconditionally updates the calling thread's value. +pub fn set_m(m: usize) { + state::set_m(m); +} + +/// Returns the minimizer length. Panics if not yet initialized. +#[inline] +pub fn m() -> usize { + state::m() +} diff --git a/src/obikseq/src/routable.rs b/src/obikseq/src/routable.rs index b0dc65d..a4b2338 100644 --- a/src/obikseq/src/routable.rs +++ b/src/obikseq/src/routable.rs @@ -1,40 +1,53 @@ //! Super-kmer with routing metadata: canonical sequence + pre-computed minimizer. -use super::kmer::CanonicalKmer; -use super::SuperKmer; +use serde::Serialize; -/// Owned wrapper that pairs a canonical [`SuperKmer`] with its minimizer [`Kmer`]. +use crate::Annotation; +use crate::Sequence; +use crate::SuperKmer; +use crate::kmer::Minimizer; +use crate::packed_seq::{PackedSeq, count_to_tail}; +use crate::params::m; + +/// Owned wrapper that pairs a canonical [`SuperKmer`] with its pre-computed minimizer. /// /// Created at the single point where raw sequence bytes are emitted from the -/// scratch buffer. The minimizer position (given in original orientation) is -/// adjusted for any flip applied during canonicalisation. After routing, call +/// scratch buffer. The minimizer position (given in original orientation) is +/// adjusted for any flip applied during canonicalisation. After routing, call /// [`into_superkmer`] to discard the metadata and continue with the bare sequence. /// /// [`into_superkmer`]: RoutableSuperKmer::into_superkmer +#[derive(Clone)] pub struct RoutableSuperKmer { superkmer: SuperKmer, - minimizer: CanonicalKmer, + minimizer: Minimizer, } +#[derive(Serialize)] +struct SKRAnnotation { + seq_length: usize, + count: u32, +} + +impl Annotation for SKRAnnotation {} + impl RoutableSuperKmer { /// Construct from raw packed bytes. 
/// /// `min_pos` is the 0-based minimizer position in the **original** (pre-flip) - /// orientation. `m` is the minimizer length. `seql` and `seq` are the - /// raw length byte and 2-bit-packed nucleotides as produced by the scratch - /// buffer. - pub fn build(min_pos: usize, m: usize, seql: u8, seq: Box<[u8]>) -> Self { - let (sk, already_canonical) = SuperKmer::build(seql, seq); + /// orientation. `seql` is the nucleotide count. The sequence is canonicalised + /// in place; `min_pos` is adjusted accordingly. + pub fn build(min_pos: usize, seql: usize, seq: Box<[u8]>) -> Self { + let mut inner = PackedSeq::new(count_to_tail(seql), seq); + let already_canonical = inner.canonicalize(); let adjusted_pos = if already_canonical { min_pos } else { - sk.len() - m - min_pos + seql - m() - min_pos }; - let minimizer = sk.kmer(adjusted_pos, m).unwrap().canonical(m); - Self { - superkmer: sk, - minimizer, - } + let minimizer = inner.mmer(adjusted_pos).unwrap(); + let superkmer = SuperKmer { count: 1, inner }; + Self { superkmer, minimizer } } /// Borrow the canonical super-kmer sequence. @@ -42,8 +55,8 @@ impl RoutableSuperKmer { &self.superkmer } - /// Borrow the canonical minimizer kmer. - pub fn minimizer(&self) -> &CanonicalKmer { + /// Borrow the canonical minimizer. + pub fn minimizer(&self) -> &Minimizer { &self.minimizer } @@ -53,7 +66,34 @@ impl RoutableSuperKmer { } /// Sequence length in nucleotides. 
- pub fn len(&self) -> usize { - self.superkmer.len() + pub fn seql(&self) -> usize { + self.superkmer.seql() + } +} + +impl Sequence for RoutableSuperKmer { + type Canonical = RoutableSuperKmer; + + fn seql(&self) -> usize { + self.superkmer.seql() + } + + fn nucleotide(&self, i: usize) -> u8 { + self.superkmer.nucleotide(i) + } + + fn seq_hash(&self) -> u64 { + self.minimizer.seq_hash() + } + + fn canonical(&self) -> Self::Canonical { + self.clone() + } + + fn annotation(&self) -> impl Annotation { + SKRAnnotation { + seq_length: self.superkmer.seql(), + count: self.superkmer.count(), + } } } diff --git a/src/obikseq/src/sequence.rs b/src/obikseq/src/sequence.rs index b23711e..4f8220d 100644 --- a/src/obikseq/src/sequence.rs +++ b/src/obikseq/src/sequence.rs @@ -1,8 +1,71 @@ -use crate::Annotation; +use std::io::{self, Write}; +use crate::Annotation; +use crate::annotations::BasicAnnotation; + +/// Common interface for all 2-bit packed DNA sequences in the pipeline. +/// +/// Required methods: `Canonical`, `seql`, `nucleotide`, `seq_hash`, `canonical`. +/// All other methods have default implementations derived from those five. pub trait Sequence { - fn sequence(&self) -> Box<[u8]>; - fn canonical(&self) -> &Self; + /// The canonical form of this sequence type. + /// + /// For types always stored canonical (`SuperKmer`, `CanonicalKmerOf`), set `Canonical = Self`. + /// For `KmerOf`, set `Canonical = CanonicalKmerOf`. + type Canonical: Sequence; + + /// Sequence length in nucleotides. + fn seql(&self) -> usize; + + /// Extract nucleotide `i` (0-based from 5′ end) as a 2-bit value (A=0, C=1, G=2, T=3). + fn nucleotide(&self, i: usize) -> u8; + + /// Hash of the sequence, used for partitioning and routing. fn seq_hash(&self) -> u64; - fn annotation(&self) -> Annotation; + + /// Return the canonical form. + /// + /// For `Copy` types this is free; for heap-backed types it clones (output/debug paths only). 
+ fn canonical(&self) -> Self::Canonical; + + /// Return an annotation describing this sequence's metadata. + /// + /// Default: `BasicAnnotation { seq_length }`. Override for richer metadata. + fn annotation(&self) -> impl Annotation { + BasicAnnotation { seq_length: self.seql() } + } + + /// Decode into ASCII nucleotides, writing into `w`. + /// + /// Default: one byte per nucleotide via `nucleotide()`. + /// Types with packed byte access should override with the faster DEC4 path. + fn write_ascii(&self, w: &mut W) -> io::Result<()> { + for i in 0..self.seql() { + w.write_all(&[b"ACGT"[self.nucleotide(i) as usize]])?; + } + Ok(()) + } + + /// Decode into a fresh ASCII `Vec`. + fn to_ascii(&self) -> Vec { + let mut buf = Vec::with_capacity(self.seql()); + self.write_ascii(&mut buf).unwrap(); + buf + } + + /// Partition index derived from `seq_hash`. + /// + /// * `part_bits` — number of low bits to use (partition count = `1 << part_bits`). + fn partition(&self, part_bits: usize) -> usize { + (mix64(self.seq_hash()) & ((1 << part_bits) - 1)) as usize + } +} + +#[inline] +pub(crate) fn mix64(x: u64) -> u64 { + let x = x ^ (x >> 30); + let x = x.wrapping_mul(0xbf58476d1ce4e5b9); + let x = x ^ (x >> 27); + let x = x.wrapping_mul(0x94d049bb133111eb); + x ^ (x >> 31) } diff --git a/src/obikseq/src/superkmer.rs b/src/obikseq/src/superkmer.rs index 7f14303..944fbef 100644 --- a/src/obikseq/src/superkmer.rs +++ b/src/obikseq/src/superkmer.rs @@ -1,84 +1,44 @@ -//! Compact 2-bit DNA super-kmer with in-place reverse complement and canonical form. -use std::io::{self, Write}; +//! Canonical 2-bit DNA super-kmer with occurrence count. +//! +//! Delegates all sequence operations to [`PackedSeq`]. +//! +//! On-disk format: varint(count), then the [`PackedSeq`] encoding (varint(seql) + packed bytes). 
+ +use std::io::{self, Read, Write}; -use bitvec::prelude::*; use serde::Serialize; use xxhash_rust::xxh3::xxh3_64; +use crate::Annotation; use crate::Sequence; -use crate::encoding::{DEC4, encode_base}; use crate::kmer::{CanonicalKmer, Kmer, KmerError}; -use crate::revcomp_lookup::REVCOMP4; +use crate::packed_seq::{PackedSeq, read_varint, write_varint}; -// ── SuperKmerHeader ─────────────────────────────────────────────────────────── - -/// 32-bit super-kmer header. -/// -/// Bit layout (MSB → LSB): -/// -/// ```text -/// [31 .......... 8] [7 ...... 0] -/// count (24 b) SEQL (8 b) -/// ``` -/// -/// SEQL encodes the sequence length: 1–255 map directly; 0 encodes 256. -/// The count field starts at 1 and accumulates occurrence counts during -/// deduplication. -#[derive(Debug, Clone, Copy)] -pub(crate) struct SuperKmerHeader(u32); - -impl SuperKmerHeader { - pub(crate) fn new(seql: u8) -> Self { - Self((1 << 8) | seql as u32) - } - - fn seql(&self) -> u8 { - self.0 as u8 - } - - fn count(&self) -> u32 { - self.0 >> 8 - } - - fn increment(&mut self) { - self.0 += 1 << 8; - } - - fn add(&mut self, n: u32) { - self.0 += n << 8; - } - - fn set_count(&mut self, n: u32) { - self.0 = (self.0 & 0xFF) | (n << 8); - } -} +// ── SKAnnotation ────────────────────────────────────────────────────────────── #[derive(Serialize)] struct SKAnnotation { seq_length: usize, - kmer_size: usize, - minimizer_size: usize, - partition: u32, count: u32, } +impl Annotation for SKAnnotation {} + // ── SuperKmer ───────────────────────────────────────────────────────────────── -/// Canonical super-kmer: 32-bit header followed by a byte-aligned 2-bit nucleotide sequence. -/// Nucleotide 0 is at the MSB of `seq[0]`. Always stored in canonical form. +/// Canonical super-kmer: occurrence count + 2-bit packed DNA sequence. 
/// -/// `PartialEq`, `Eq`, and `Hash` compare only sequence content (seql + seq bytes), -/// ignoring the count / minimizer-pos payload — two records with identical sequences -/// but different counts are considered equal. +/// Always stored in canonical form (lex min of forward and revcomp). +/// `PartialEq`/`Hash` compare only sequence content, ignoring count. #[derive(Debug, Clone)] pub struct SuperKmer { - header: SuperKmerHeader, - seq: Box<[u8]>, + pub(crate) count: u32, + pub(crate) inner: PackedSeq, } impl PartialEq for SuperKmer { fn eq(&self, other: &Self) -> bool { - self.header.seql() == other.header.seql() && self.seq == other.seq + self.inner == other.inner } } @@ -86,320 +46,145 @@ impl Eq for SuperKmer {} impl std::hash::Hash for SuperKmer { fn hash(&self, state: &mut H) { - self.header.seql().hash(state); - self.seq.hash(state); + self.inner.hash(state); } } impl Sequence for SuperKmer { - fn sequence(&self) -> Box<[u8]> { - self.seq.clone() + type Canonical = SuperKmer; + + fn seql(&self) -> usize { + self.inner.seql() } - fn canonical(&self) -> &Self { - &self + fn nucleotide(&self, i: usize) -> u8 { + self.inner.nucleotide(i) } - /// Returns the XXH3-64 hash of the packed sequence bytes. fn seq_hash(&self) -> u64 { - xxh3_64(&self.seq) + xxh3_64(self.inner.seq_bytes()) } - fn annotation(&self) -> Annotation {} + fn canonical(&self) -> Self::Canonical { + self.clone() + } + + fn annotation(&self) -> impl Annotation { + SKAnnotation { + seq_length: self.inner.seql(), + count: self.count, + } + } } + impl SuperKmer { - /// `seql` is the raw stored byte: 1–255 for lengths 1–255, 0 for length 256. - pub fn new(seql: u8, seq: Box<[u8]>) -> Self { - Self::build(seql, seq).0 - } - - /// Construct and canonicalise in place, returning `(sk, already_canonical)`. - /// `already_canonical` is `true` when the sequence was not flipped. 
- pub fn build(seql: u8, seq: Box<[u8]>) -> (Self, bool) { - let mut sk = Self { - header: SuperKmerHeader::new(seql), - seq, - }; - let already_canonical = sk.canonical(); // true = pas retourné - (sk, already_canonical) - } - - /// Deserialise from a raw 32-bit header word and packed sequence bytes. - /// Preserves the full header payload (count or minimizer_pos in bits [31:8]). - pub fn from_header_bits(bits: u32, seq: Box<[u8]>) -> Self { - let seql = (bits & 0xFF) as u8; - let len = stored_to_len(seql); - debug_assert_eq!(seq.len(), byte_len(len)); - let sk = Self { - header: SuperKmerHeader(bits), - seq, - }; - debug_assert!( - sk.is_canonical(), - "SuperKmer deserialised from disk is not canonical" - ); - sk - } - - /// Returns the sequence length in nucleotides (1–256). - pub fn len(&self) -> usize { - stored_to_len(self.header.seql()) - } - - /// Returns the occurrence count of this super-kmer. - pub fn count(&self) -> u32 { - self.header.count() - } - - /// Increments the occurrence count by 1. - pub fn increment(&mut self) { - self.header.increment(); - } - - /// Adds `n` to the occurrence count. - pub fn add(&mut self, n: u32) { - self.header.add(n); - } - - /// Sets the occurrence count to an absolute value. - pub fn set_count(&mut self, n: u32) { - self.header.set_count(n); - } - - /// Extract nucleotide i (0-based from 5' end) as a 2-bit value. - pub fn nucleotide(&self, i: usize) -> u8 { - (self.seq[i / 4] >> (6 - 2 * (i % 4))) & 0b11 - } - - /// Reverse-complement this super-kmer in place. - /// - /// This method is only used internally by the build method. - fn revcomp(&mut self) { - let seql = self.len(); - let n = byte_len(seql); - - // Step 1: swap bytes outside-in, applying revcomp4 to each. 
- { - let bytes = &mut self.seq[..n]; - let (mut lo, mut hi) = (0, n - 1); - while lo < hi { - (bytes[lo], bytes[hi]) = - (REVCOMP4[bytes[hi] as usize], REVCOMP4[bytes[lo] as usize]); - lo += 1; - hi -= 1; - } - if lo == hi { - bytes[lo] = REVCOMP4[bytes[lo] as usize]; - } - } - - // Step 2: left-shift to flush padding T's introduced by complementing padding A's. - let shift = n * 8 - seql * 2; - if shift > 0 { - let bits = self.seq[..n].view_bits_mut::(); - bits.rotate_left(shift); - let len = bits.len(); - bits[len - shift..].fill(false); - } - } - - /// Encode an ASCII nucleotide sequence (ACGT, length 1–256) into a canonical SuperKmer. + /// Encode ASCII nucleotides (length ≥ 1) into a canonical SuperKmer. pub fn from_ascii(ascii: &[u8]) -> Self { - let seql = ascii.len(); - debug_assert!( - seql >= 1 && seql <= 256, - "super-kmer length must be 1..=256" - ); - let n = byte_len(seql); - let mut seq = vec![0u8; n]; - - let full = seql / 4; - for i in 0..full { - seq[i] = encode_base(ascii[i * 4]) << 6 - | encode_base(ascii[i * 4 + 1]) << 4 - | encode_base(ascii[i * 4 + 2]) << 2 - | encode_base(ascii[i * 4 + 3]); - } - let rem = seql % 4; - if rem > 0 { - let mut last = 0u8; - for j in 0..rem { - last |= encode_base(ascii[full * 4 + j]) << (6 - 2 * j); - } - seq[full] = last; - } - - Self::new(seql as u8, seq.into_boxed_slice()) // 256usize as u8 == 0, intentional + let mut inner = PackedSeq::from_ascii(ascii); + inner.canonicalize(); + Self { count: 1, inner } } - /// Decode this super-kmer sequence into ASCII nucleotides, writing into `writer`. - pub fn write_ascii(&self, writer: &mut W) -> io::Result<()> { - let seql = self.len(); - let full = seql / 4; - - for i in 0..full { - writer.write_all(&DEC4[self.seq[i] as usize].to_be_bytes())?; - } - let rem = seql % 4; - if rem > 0 { - let bytes = DEC4[self.seq[full] as usize].to_be_bytes(); - writer.write_all(&bytes[..rem])?; - } - Ok(()) + /// Wrap a pre-built [`PackedSeq`], canonicalising in place. 
+ pub fn build(mut inner: PackedSeq) -> Self { + inner.canonicalize(); + Self { count: 1, inner } } - /// Decode this super-kmer sequence into a fresh ASCII `Vec`. - pub fn to_ascii(&self) -> Vec { - let mut buf = Vec::with_capacity(self.len()); - self.write_ascii(&mut buf).unwrap(); - buf + /// Serialise to compact binary. Format: varint(count) + varint(seql) + packed bytes. + pub fn write_to_binary(&self, w: &mut W) -> io::Result<()> { + write_varint(w, self.count as u64)?; + self.inner.write_to_binary(w) } - /// Returns the raw 32-bit header word for binary serialisation. - /// Bits [7:0] = seql encoding (0→256, 1-255 direct). Bits [31:8] = payload. + /// Deserialise from the binary format produced by [`write_to_binary`]. + /// Allocates exactly one `Box<[u8]>` for the packed bytes. + pub fn read_from_binary(r: &mut R) -> io::Result { + let count = read_varint(r)? as u32; + let inner = PackedSeq::read_from_binary(r)?; + debug_assert!(inner.is_canonical(), "SuperKmer from disk is not canonical"); + Ok(Self { count, inner }) + } + + /// Sequence length in nucleotides. #[inline] - pub fn header_bits(&self) -> u32 { - self.header.0 + pub fn seql(&self) -> usize { + self.inner.seql() } - /// Returns a read-only view of the packed 2-bit sequence bytes. - /// Length is always `(seql() + 3) / 4` bytes. + /// Occurrence count. + #[inline] + pub fn count(&self) -> u32 { + self.count + } + + /// Increment occurrence count by 1. + #[inline] + pub fn increment(&mut self) { + self.count += 1; + } + + /// Add `n` to the occurrence count. + #[inline] + pub fn add(&mut self, n: u32) { + self.count += n; + } + + /// Set the occurrence count to `n`. + #[inline] + pub fn set_count(&mut self, n: u32) { + self.count = n; + } + + /// Read-only view of packed 2-bit bytes. #[inline] pub fn seq_bytes(&self) -> &[u8] { - &self.seq + self.inner.seq_bytes() } - /// Extract the kmer of length k starting at nucleotide position i (0-based). 
- /// - /// Returns an error if k is invalid (0 or > 32) or if position i + k exceeds the sequence length. - pub fn kmer(&self, i: usize, k: usize) -> Result { - if k == 0 || k > 32 { - return Err(KmerError::InvalidK { k }); - } - let seql = self.len(); - if i + k > seql { - return Err(KmerError::OutOfBounds { - position: i, - k, - seql, - }); - } - let bits = self.seq.view_bits::(); - let raw: u64 = bits[i * 2..(i + k) * 2].load_be(); - Ok(Kmer::from_raw(raw << (64 - 2 * k))) + /// Extract nucleotide i (0-based from 5′ end) as a 2-bit value. + #[inline] + pub fn nucleotide(&self, i: usize) -> u8 { + self.inner.nucleotide(i) } - /// Extract the canonical kmer of length k starting at nucleotide position i (0-based). - /// - /// Returns an error if k is invalid (0 or > 32) or if position i + k exceeds the sequence length. - pub fn canonical_kmer(&self, i: usize, k: usize) -> Result { - Ok(self.kmer(i, k)?.canonical(k)) + /// Extract the k-mer at position `i` using `params::k()`. + #[inline] + pub fn kmer(&self, i: usize) -> Result { + self.inner.kmer(i) } - /// Put this super-kmer in canonical form (lexicographic minimum of forward and revcomp). - /// - /// Returns `true` if already canonical (no change), `false` if revcomp was applied. - fn canonical(&mut self) -> bool { - if self.is_canonical() { - return true; - } - self.revcomp(); - false + /// Extract the canonical k-mer at position `i`. + #[inline] + pub fn canonical_kmer(&self, i: usize) -> Result { + self.inner.canonical_kmer(i) } - /// Returns `true` if this super-kmer is in canonical form (lexicographic minimum of forward and revcomp). - fn is_canonical(&self) -> bool { - let seql = self.len(); - for i in 0..seql { - let fwd = self.nucleotide(i); - let rev = complement(self.nucleotide(seql - 1 - i)); - if fwd < rev { - return true; - } - if fwd > rev { - return false; - } - } - true + /// Decode into ASCII, writing into `writer`. 
+ #[inline] + pub fn write_ascii(&self, writer: &mut W) -> std::io::Result<()> { + self.inner.write_ascii(writer) } - /// Iterate over all kmers of length `k` in order, yielding each as a left-aligned [`Kmer`]. - pub fn iter_kmers(&self, k: usize) -> impl Iterator + '_ { - SKKmerIter::new(self, k) + /// Decode into a fresh ASCII `Vec`. + #[inline] + pub fn to_ascii(&self) -> Vec { + self.inner.to_ascii() } - /// Iterate over all canonical kmers of length `k` in order. - pub fn iter_canonical_kmers(&self, k: usize) -> impl Iterator + '_ { - self.iter_kmers(k).map(move |km| km.canonical(k)) + /// Iterate over all k-mers of length `params::k()` in order. + #[inline] + pub fn iter_kmers(&self) -> impl Iterator + '_ { + self.inner.iter_kmers() } -} -struct SKKmerIter<'a> { - skmer: &'a SuperKmer, - mask: u64, - lshift: usize, - current: u64, - pos: usize, - max_pos: usize, -} - -impl<'a> SKKmerIter<'a> { - fn new(skmer: &'a SuperKmer, k: usize) -> Self { - let seql = skmer.len(); - let lshift = 64 - k * 2; - let mask = ((!0u128) << (lshift + 2)) as u64; - Self { - skmer, - mask, - lshift, - current: if seql >= k { - skmer.kmer(0, k).unwrap().raw() - } else { - 0 - }, - pos: k, - max_pos: seql, - } + /// Iterate over all canonical k-mers in order. + #[inline] + pub fn iter_canonical_kmers(&self) -> impl Iterator + '_ { + self.inner.iter_canonical_kmers() } } -impl<'a> Iterator for SKKmerIter<'a> { - type Item = Kmer; - - fn next(&mut self) -> Option { - if self.pos > self.max_pos { - return None; - } - // Emit current kmer first, then slide the window forward. - let result = Kmer::from_raw(self.current); - if self.pos < self.max_pos { - let byte_pos = self.pos / 4; - // Nucleotide at position r within its byte occupies bits 7-2r (MSB) and 6-2r (LSB). - // Extract right-aligned, then place at lshift. 
- let inner_shift = 6 - 2 * (self.pos & 3); - let nuc = (((self.skmer.seq[byte_pos] >> inner_shift) & 3) as u64) << self.lshift; - self.current = ((self.current << 2) & self.mask) | nuc; - } - self.pos += 1; - Some(result) - } -} - -// ── helpers ─────────────────────────────────────────────────────────────────── - -fn complement(base: u8) -> u8 { - !base & 0b11 -} - -fn byte_len(seql: usize) -> usize { - (seql + 3) / 4 -} - -/// Stored u8 → actual length: 0 encodes 256, 1–255 are identity. -fn stored_to_len(s: u8) -> usize { - if s == 0 { 256 } else { s as usize } -} - #[cfg(test)] #[path = "tests/superkmer.rs"] mod tests; diff --git a/src/obikseq/src/tests/kmer.rs b/src/obikseq/src/tests/kmer.rs new file mode 100644 index 0000000..2aafdd6 --- /dev/null +++ b/src/obikseq/src/tests/kmer.rs @@ -0,0 +1,213 @@ +use super::*; + +#[cfg(test)] +mod tests { + use super::*; + + // Tests use ConstLen — no dependency on global params singletons. + type K1 = KmerOf>; + type K4 = KmerOf>; + + fn ascii_revcomp(seq: &[u8]) -> Vec { + seq.iter() + .rev() + .map(|&b| match b { + b'A' => b'T', + b'T' => b'A', + b'C' => b'G', + b'G' => b'C', + _ => b'A', + }) + .collect() + } + + fn make_seq() -> Vec { + (0..N).map(|i| b"ACGT"[i % 4]).collect() + } + + // ── from_ascii / to_ascii ───────────────────────────────────────────────── + + #[test] + fn ascii_roundtrip() { + macro_rules! 
check { + ($n:expr) => {{ + let ascii = make_seq::<$n>(); + let kmer = KmerOf::>::from_ascii(&ascii).unwrap(); + assert_eq!(kmer.to_ascii(), ascii, "roundtrip failed for k={}", $n); + }}; + } + check!(1); + check!(2); + check!(3); + check!(4); + check!(8); + check!(11); + check!(16); + check!(31); + check!(32); + } + + #[test] + fn from_ascii_all_bases() { + for (base, expected) in [(b'A', b'A'), (b'C', b'C'), (b'G', b'G'), (b'T', b'T')] { + let kmer = K1::from_ascii(&[base]).unwrap(); + assert_eq!(kmer.to_ascii(), vec![expected]); + } + } + + #[test] + fn from_ascii_invalid_k() { + assert!(KmerOf::>::from_ascii(b"A").is_err()); + assert!(KmerOf::>::from_ascii(b"ACGT").is_err()); + } + + #[test] + fn from_ascii_too_short() { + assert!(KmerOf::>::from_ascii(b"ACG").is_err()); + } + + // ── nucleotide ──────────────────────────────────────────────────────────── + + #[test] + fn nucleotide_extraction() { + let kmer = K4::from_ascii(b"ACGT").unwrap(); + assert_eq!(kmer.nucleotide(0), 0b00); // A + assert_eq!(kmer.nucleotide(1), 0b01); // C + assert_eq!(kmer.nucleotide(2), 0b10); // G + assert_eq!(kmer.nucleotide(3), 0b11); // T + } + + // ── revcomp ─────────────────────────────────────────────────────────────── + + #[test] + fn revcomp_known_values() { + let cases: &[(&[u8], &[u8])] = &[ + (b"A", b"T"), + (b"AC", b"GT"), + (b"ACG", b"CGT"), + (b"ACGT", b"ACGT"), + (b"AAAA", b"TTTT"), + (b"TTTT", b"AAAA"), + ]; + for (seq, expected) in cases { + macro_rules! check_len { + ($n:expr) => { + if seq.len() == $n { + let kmer = KmerOf::>::from_ascii(seq).unwrap(); + assert_eq!( + kmer.revcomp().to_ascii(), + *expected, + "revcomp wrong for \"{}\"", + std::str::from_utf8(seq).unwrap() + ); + } + }; + } + check_len!(1); + check_len!(2); + check_len!(3); + check_len!(4); + } + } + + #[test] + fn revcomp_vs_reference() { + macro_rules! 
check { + ($n:expr) => {{ + let ascii = make_seq::<$n>(); + let expected = ascii_revcomp(&ascii); + let rc = KmerOf::>::from_ascii(&ascii) + .unwrap() + .revcomp(); + assert_eq!(rc.to_ascii(), expected, "revcomp wrong for k={}", $n); + }}; + } + check!(1); + check!(4); + check!(8); + check!(11); + check!(16); + check!(31); + check!(32); + } + + #[test] + fn revcomp_involution() { + macro_rules! check { + ($n:expr) => {{ + let ascii = make_seq::<$n>(); + let kmer = KmerOf::>::from_ascii(&ascii).unwrap(); + assert_eq!( + kmer.revcomp().revcomp(), + kmer, + "revcomp∘revcomp≠id for k={}", + $n + ); + }}; + } + check!(1); + check!(4); + check!(8); + check!(16); + check!(31); + check!(32); + } + + // ── canonical ───────────────────────────────────────────────────────────── + + #[test] + fn canonical_palindrome() { + let kmer = K4::from_ascii(b"ACGT").unwrap(); + assert_eq!(kmer.canonical().into_kmer(), kmer); + } + + #[test] + fn canonical_chooses_lesser() { + let kmer = K4::from_ascii(b"TTTT").unwrap(); + let expected = K4::from_ascii(b"AAAA").unwrap(); + assert_eq!(kmer.canonical().into_kmer(), expected); + } + + #[test] + fn canonical_is_minimal() { + macro_rules! check { + ($n:expr) => {{ + let ascii = make_seq::<$n>(); + let ck = KmerOf::>::from_ascii(&ascii) + .unwrap() + .canonical(); + let rc = ck.revcomp(); + assert!(ck.raw() <= rc.raw(), "canonical not minimal for k={}", $n); + }}; + } + check!(1); + check!(4); + check!(8); + check!(16); + check!(31); + check!(32); + } + + #[test] + fn canonical_idempotent() { + macro_rules! 
check { + ($n:expr) => {{ + let ck = KmerOf::>::from_ascii(&make_seq::<$n>()) + .unwrap() + .canonical(); + assert_eq!( + ck.into_kmer().canonical(), + ck, + "canonical not idempotent for k={}", + $n + ); + }}; + } + check!(1); + check!(4); + check!(8); + check!(16); + check!(31); + check!(32); + } +} diff --git a/src/obikseq/src/tests/superkmer.rs b/src/obikseq/src/tests/superkmer.rs index d5d161f..8d8562c 100644 --- a/src/obikseq/src/tests/superkmer.rs +++ b/src/obikseq/src/tests/superkmer.rs @@ -1,11 +1,10 @@ use super::*; +use crate::set_k; -/// Repeating ACGT pattern of the given length. fn make_seq(len: usize) -> Vec { (0..len).map(|i| b"ACGT"[i % 4]).collect() } -/// Reference revcomp on ASCII bytes. fn ascii_revcomp(seq: &[u8]) -> Vec { seq.iter() .rev() @@ -20,96 +19,93 @@ fn ascii_revcomp(seq: &[u8]) -> Vec { } fn all_lengths() -> impl Iterator { - (1..=9).chain([255, 256]) + (1..=9).chain([255, 256, 257, 1000]) } -// ── kmer extraction ─────────────────────────────────────────────────────── +// ── from_ascii / canonical form ─────────────────────────────────────────────── #[test] -fn kmer_first_matches_from_ascii() { - let ascii = b"ACGTACGT"; - let sk = SuperKmer::from_ascii(ascii); - let k = 4; - let kmer = sk.kmer(0, k).unwrap(); - let expected = crate::kmer::Kmer::from_ascii(&ascii[..k], k).unwrap(); - assert_eq!(kmer, expected); -} - -#[test] -fn kmer_last_position() { - let ascii = b"ACGTACGT"; - let seql = ascii.len(); - let k = 4; - let sk = SuperKmer::from_ascii(ascii); - let kmer = sk.kmer(seql - k, k).unwrap(); - let expected = crate::kmer::Kmer::from_ascii(&ascii[seql - k..], k).unwrap(); - assert_eq!(kmer, expected); -} - -#[test] -fn kmer_all_positions() { - let ascii = b"ACGTACGTACGT"; - let k = 4; - let sk = SuperKmer::from_ascii(ascii); - for i in 0..=ascii.len() - k { - let kmer = sk.kmer(i, k).unwrap(); - let expected = crate::kmer::Kmer::from_ascii(&ascii[i..i + k], k).unwrap(); - assert_eq!(kmer, expected, "mismatch at position 
{i}"); +fn ascii_roundtrip_all_lengths() { + for len in all_lengths() { + let ascii = make_seq(len); + let sk = SuperKmer::from_ascii(&ascii); + // SuperKmer stores in canonical form; ACGT pattern is already canonical. + assert_eq!(sk.to_ascii(), ascii, "roundtrip failed for len={len}"); } } #[test] -fn kmer_out_of_bounds() { - let sk = SuperKmer::from_ascii(b"ACGT"); - assert!(sk.kmer(2, 4).is_err()); // 2 + 4 > 4 - assert!(sk.kmer(4, 1).is_err()); // 4 + 1 > 4 -} - -#[test] -fn kmer_invalid_k() { - let sk = SuperKmer::from_ascii(b"ACGT"); - assert!(sk.kmer(0, 0).is_err()); - assert!(sk.kmer(0, 33).is_err()); -} - -// ── canonical_kmer ──────────────────────────────────────────────────────── - -#[test] -fn canonical_kmer_is_min_of_kmer_and_revcomp() { - let sk = SuperKmer::from_ascii(b"ACGTACGT"); - let k = 4; - for i in 0..=(sk.len() - k) { - let ck = sk.canonical_kmer(i, k).unwrap(); - let fwd = sk.kmer(i, k).unwrap(); - assert_eq!(ck, fwd.canonical(k)); +fn from_ascii_canonical_all_bases() { + // G×4 revcomp is C×4; T×4 revcomp is A×4. 
+ for (base, expected) in [(b'A', b'A'), (b'C', b'C'), (b'G', b'C'), (b'T', b'A')] { + let ascii = vec![base; 4]; + let sk = SuperKmer::from_ascii(&ascii); + assert_eq!(sk.to_ascii(), vec![expected; 4]); } } #[test] -fn canonical_kmer_palindrome_unchanged() { - // ACGT is its own reverse complement - let sk = SuperKmer::from_ascii(b"ACGT"); - let ck = sk.canonical_kmer(0, 4).unwrap(); - let fwd = sk.kmer(0, 4).unwrap(); - assert_eq!(ck.into_kmer(), fwd); +fn from_ascii_is_canonical_all_lengths() { + for len in all_lengths() { + let ascii = make_seq(len); + let sk = SuperKmer::from_ascii(&ascii); + let fwd = sk.to_ascii(); + let rev = ascii_revcomp(&fwd); + assert!(fwd <= rev, "not canonical for len={len}"); + } +} + +// ── seql ────────────────────────────────────────────────────────────────────── + +#[test] +fn seql_roundtrip() { + for len in all_lengths() { + let sk = SuperKmer::from_ascii(&make_seq(len)); + assert_eq!(sk.seql(), len, "seql() wrong for len={len}"); + } +} + +// ── binary serialisation ────────────────────────────────────────────────────── + +#[test] +fn binary_roundtrip() { + for len in all_lengths() { + let mut sk = SuperKmer::from_ascii(&make_seq(len)); + sk.set_count(42); + let mut buf = Vec::new(); + sk.write_to_binary(&mut buf).unwrap(); + let sk2 = SuperKmer::read_from_binary(&mut buf.as_slice()).unwrap(); + assert_eq!( + sk.to_ascii(), + sk2.to_ascii(), + "sequence mismatch for len={len}" + ); + assert_eq!(sk2.count(), 42, "count mismatch for len={len}"); + } } #[test] -fn canonical_kmer_tttt_becomes_aaaa() { - let sk = SuperKmer::from_ascii(b"TTTT"); - let ck = sk.canonical_kmer(0, 4).unwrap(); - let expected = Kmer::from_ascii(b"AAAA", 4).unwrap(); - assert_eq!(ck.into_kmer(), expected); +fn binary_packed_seq_roundtrip() { + use crate::packed_seq::PackedSeq; + for len in all_lengths() { + let ps = PackedSeq::from_ascii(&make_seq(len)); + let mut buf = Vec::new(); + ps.write_to_binary(&mut buf).unwrap(); + let ps2 = 
PackedSeq::read_from_binary(&mut buf.as_slice()).unwrap(); + assert_eq!(ps, ps2, "PackedSeq mismatch for len={len}"); + } } #[test] -fn canonical_kmer_errors_propagate() { +fn binary_size_is_compact() { + // seql=4 (1 byte packed): varint(count=1, 1 byte) + varint((1<<2)|0=4, 1 byte) + 1 byte = 3 bytes let sk = SuperKmer::from_ascii(b"ACGT"); - assert!(sk.canonical_kmer(2, 4).is_err()); // out of bounds - assert!(sk.canonical_kmer(0, 0).is_err()); // invalid k + let mut buf = Vec::new(); + sk.write_to_binary(&mut buf).unwrap(); + assert_eq!(buf.len(), 3, "expected 3 bytes for 4-nt superkmer"); } -// ── count ───────────────────────────────────────────────────────────────── +// ── count ───────────────────────────────────────────────────────────────────── #[test] fn count_starts_at_one() { @@ -144,30 +140,15 @@ fn set_count_overwrites() { } #[test] -fn increment_preserves_seql() { +fn count_operations_preserve_seql() { for len in all_lengths() { let mut sk = SuperKmer::from_ascii(&make_seq(len)); sk.increment(); - assert_eq!(sk.len(), len, "increment altered seql for len={len}"); - } -} - -#[test] -fn add_preserves_seql() { - for len in all_lengths() { - let mut sk = SuperKmer::from_ascii(&make_seq(len)); + assert_eq!(sk.seql(), len, "increment altered seql for len={len}"); sk.add(1000); - assert_eq!(sk.len(), len, "add altered seql for len={len}"); - } -} - -#[test] -fn set_count_preserves_seql() { - for len in all_lengths() { - let mut sk = SuperKmer::from_ascii(&make_seq(len)); + assert_eq!(sk.seql(), len, "add altered seql for len={len}"); sk.set_count(999); - assert_eq!(sk.len(), len, "set_count altered seql for len={len}"); - assert_eq!(sk.count(), 999); + assert_eq!(sk.seql(), len, "set_count altered seql for len={len}"); } } @@ -179,247 +160,136 @@ fn count_does_not_affect_sequence() { assert_eq!(sk.to_ascii(), ascii); } -// ── seql encoding ───────────────────────────────────────────────────────── +// ── kmer extraction 
─────────────────────────────────────────────────────────── #[test] -fn seql_roundtrip() { - for len in all_lengths() { - let sk = SuperKmer::from_ascii(&make_seq(len)); - assert_eq!(sk.len(), len, "seql() wrong for len={len}"); +fn kmer_first_matches_from_ascii() { + set_k(4); + let k = crate::params::k(); + let ascii = b"ACGTACGT"; + let sk = SuperKmer::from_ascii(ascii); + let kmer = sk.kmer(0).unwrap(); + let expected = crate::kmer::Kmer::from_ascii(&ascii[..k]).unwrap(); + assert_eq!(kmer, expected); +} + +#[test] +fn kmer_all_positions() { + set_k(4); + let k = crate::params::k(); + let ascii = b"ACGTACGTACGT"; + let sk = SuperKmer::from_ascii(ascii); + for i in 0..=ascii.len() - k { + let kmer = sk.kmer(i).unwrap(); + let expected = crate::kmer::Kmer::from_ascii(&ascii[i..i + k]).unwrap(); + assert_eq!(kmer, expected, "mismatch at position {i}"); } } #[test] -fn seql_256_stored_as_zero() { - let sk = SuperKmer::from_ascii(&make_seq(256)); - assert_eq!(sk.header.seql(), 0u8); - assert_eq!(sk.len(), 256); +fn kmer_out_of_bounds() { + set_k(4); + let sk = SuperKmer::from_ascii(b"ACGT"); // seql=4, k=4 + assert!(sk.kmer(1).is_err()); // 1 + 4 > 4 } -// ── from_ascii / to_ascii roundtrip ─────────────────────────────────────── +// ── canonical_kmer ──────────────────────────────────────────────────────────── #[test] -fn ascii_roundtrip_all_lengths() { - for len in all_lengths() { - let ascii = make_seq(len); - let sk = SuperKmer::from_ascii(&ascii); - assert_eq!(sk.to_ascii(), ascii, "roundtrip failed for len={len}"); +fn canonical_kmer_is_min_of_kmer_and_revcomp() { + set_k(4); + let k = crate::params::k(); + let sk = SuperKmer::from_ascii(b"ACGTACGT"); + for i in 0..=(sk.seql() - k) { + let ck = sk.canonical_kmer(i).unwrap(); + let fwd = sk.kmer(i).unwrap(); + assert_eq!(ck, fwd.canonical()); } } #[test] -fn ascii_roundtrip_all_bases() { - // Canonical form: min(seq, revcomp). G×4 flips to C×4, T×4 flips to A×4. 
- for (base, expected) in [(b'A', b'A'), (b'C', b'C'), (b'G', b'C'), (b'T', b'A')] { - let ascii = vec![base; 4]; - let sk = SuperKmer::from_ascii(&ascii); - assert_eq!(sk.to_ascii(), vec![expected; 4]); - } -} - -// ── revcomp correctness ─────────────────────────────────────────────────── - -/// Known (seq, expected_revcomp) pairs — one per shift value × two byte counts. -#[test] -fn revcomp_known_values() { - let cases = [ - // shift=6 - ("A", "T"), - ("ACGTA", "TACGT"), - // shift=4 - ("AC", "GT"), - ("ACGTAC", "GTACGT"), - // shift=2 - ("ACG", "CGT"), - ("ACGTACG", "CGTACGT"), - // shift=0 - ("ACGT", "ACGT"), - ("ACGTACGT", "ACGTACGT"), - ]; - for (seq, expected) in cases { - let mut sk = SuperKmer::from_ascii(seq.as_bytes()); - sk.revcomp(); - assert_eq!( - sk.to_ascii(), - expected.as_bytes(), - "revcomp wrong for \"{seq}\"" - ); - } +fn canonical_kmer_palindrome_unchanged() { + set_k(4); + let sk = SuperKmer::from_ascii(b"ACGT"); // ACGT is its own revcomp + let ck = sk.canonical_kmer(0).unwrap(); + let fwd = sk.kmer(0).unwrap(); + assert_eq!(ck.into_kmer(), fwd); } #[test] -fn revcomp_vs_reference_all_lengths() { - for len in all_lengths() { - let ascii = make_seq(len); - let expected = ascii_revcomp(&ascii); - let mut sk = SuperKmer::from_ascii(&ascii); - sk.revcomp(); - assert_eq!(sk.to_ascii(), expected, "revcomp wrong for len={len}"); - } +fn canonical_kmer_errors_propagate() { + set_k(4); + let sk = SuperKmer::from_ascii(b"ACGT"); + assert!(sk.canonical_kmer(1).is_err()); // out of bounds: 1 + 4 > 4 } -#[test] -fn revcomp_involution_all_lengths() { - for len in all_lengths() { - let ascii = make_seq(len); - let mut sk = SuperKmer::from_ascii(&ascii); - sk.revcomp(); - sk.revcomp(); - assert_eq!(sk.to_ascii(), ascii, "revcomp∘revcomp≠id for len={len}"); - } -} - -// ── canonical ───────────────────────────────────────────────────────────── - -#[test] -fn canonical_palindrome_unchanged() { - // ACGT is its own revcomp - let mut sk = 
SuperKmer::from_ascii(b"ACGT"); - sk.canonical(); - assert_eq!(sk.to_ascii(), b"ACGT"); -} - -#[test] -fn canonical_chooses_forward() { - // "AAAA" < "TTTT" → stays as-is - let mut sk = SuperKmer::from_ascii(b"AAAA"); - sk.canonical(); - assert_eq!(sk.to_ascii(), b"AAAA"); -} - -#[test] -fn canonical_chooses_revcomp() { - // "TTTT" > "AAAA" → flipped - let mut sk = SuperKmer::from_ascii(b"TTTT"); - sk.canonical(); - assert_eq!(sk.to_ascii(), b"AAAA"); -} - -#[test] -fn canonical_is_minimal_all_lengths() { - for len in all_lengths() { - let ascii = make_seq(len); - let mut sk = SuperKmer::from_ascii(&ascii); - sk.canonical(); - let fwd = sk.to_ascii(); - let rev = ascii_revcomp(&fwd); - assert!(fwd <= rev, "canonical not minimal for len={len}"); - } -} - -// ── iter_kmers ──────────────────────────────────────────────────────────── +// ── iter_kmers ──────────────────────────────────────────────────────────────── #[test] fn iter_kmers_count() { + set_k(4); + let k = crate::params::k(); let ascii = b"ACGTACGTACGT"; let sk = SuperKmer::from_ascii(ascii); - for k in [1usize, 3, 4, 5, 8, 12] { - let n = sk.iter_kmers(k).count(); - assert_eq!(n, ascii.len() - k + 1, "count mismatch for k={k}"); - } + assert_eq!(sk.iter_kmers().count(), ascii.len() - k + 1); } #[test] fn iter_kmers_first_is_kmer_0() { + set_k(4); let ascii = b"ACGTACGT"; let sk = SuperKmer::from_ascii(ascii); - for k in 1..=ascii.len() { - let first = sk.iter_kmers(k).next().unwrap(); - assert_eq!(first, sk.kmer(0, k).unwrap(), "k={k}"); - } + let first = sk.iter_kmers().next().unwrap(); + assert_eq!(first, sk.kmer(0).unwrap()); } #[test] fn iter_kmers_matches_kmer_at_each_position() { + set_k(4); let ascii = b"ACGTACGTACGT"; let sk = SuperKmer::from_ascii(ascii); - let k = 4; - let kmers: Vec = sk.iter_kmers(k).collect(); - assert_eq!(kmers.len(), ascii.len() - k + 1); + let kmers: Vec = sk.iter_kmers().collect(); for (i, &km) in kmers.iter().enumerate() { - assert_eq!(km, sk.kmer(i, k).unwrap(), 
"mismatch at pos {i}"); + assert_eq!(km, sk.kmer(i).unwrap(), "mismatch at pos {i}"); } } #[test] fn iter_kmers_single_when_seql_eq_k() { - let ascii = b"ACGTACGT"; - let sk = SuperKmer::from_ascii(ascii); - let k = ascii.len(); - let kmers: Vec = sk.iter_kmers(k).collect(); - assert_eq!(kmers.len(), 1); - assert_eq!(kmers[0], sk.kmer(0, k).unwrap()); -} - -#[test] -fn iter_kmers_two_when_seql_eq_k_plus_one() { - let ascii = b"ACGTACGT"; - let sk = SuperKmer::from_ascii(ascii); - let k = ascii.len() - 1; - let kmers: Vec = sk.iter_kmers(k).collect(); - assert_eq!(kmers.len(), 2); - assert_eq!(kmers[0], sk.kmer(0, k).unwrap()); - assert_eq!(kmers[1], sk.kmer(1, k).unwrap()); -} - -#[test] -fn iter_kmers_all_k_values() { - // For every valid k, each yielded kmer must match kmer(i, k). - let ascii = b"ACGTACGTACGT"; - let sk = SuperKmer::from_ascii(ascii); - let seql = ascii.len(); - for k in 1..=seql { - let kmers: Vec = sk.iter_kmers(k).collect(); - assert_eq!(kmers.len(), seql - k + 1, "k={k}"); - for (i, &km) in kmers.iter().enumerate() { - assert_eq!(km, sk.kmer(i, k).unwrap(), "k={k}, pos={i}"); - } - } + set_k(4); + let k = crate::params::k(); + let ascii = make_seq(k); + let sk = SuperKmer::from_ascii(&ascii); + assert_eq!(sk.iter_kmers().count(), 1); + assert_eq!(sk.iter_kmers().next().unwrap(), sk.kmer(0).unwrap()); } #[test] fn iter_kmers_crosses_byte_boundary() { - // Positions 3→4 and 7→8 cross a 4-nucleotide byte boundary. 
+ set_k(4); let ascii = b"ACGTACGTACGT"; let sk = SuperKmer::from_ascii(ascii); - let k = 3; - let kmers: Vec = sk.iter_kmers(k).collect(); + let kmers: Vec = sk.iter_kmers().collect(); for boundary in [3usize, 4, 7, 8] { if boundary + 1 < kmers.len() { assert_eq!( kmers[boundary], - sk.kmer(boundary, k).unwrap(), + sk.kmer(boundary).unwrap(), "pos={boundary}" ); - assert_eq!( - kmers[boundary + 1], - sk.kmer(boundary + 1, k).unwrap(), - "pos={}", - boundary + 1 - ); } } } -#[test] -fn iter_kmers_k1_yields_all_nucleotides() { - let ascii = b"ACGT"; - let sk = SuperKmer::from_ascii(ascii); - let kmers: Vec = sk.iter_kmers(1).collect(); - assert_eq!(kmers.len(), 4); - for (i, &km) in kmers.iter().enumerate() { - assert_eq!(km, sk.kmer(i, 1).unwrap(), "pos={i}"); - } -} - #[test] fn iter_kmers_long_sequence() { - let ascii = make_seq(20); + set_k(4); + let k = crate::params::k(); + let ascii = make_seq(200); let sk = SuperKmer::from_ascii(&ascii); - let k = 7; - let kmers: Vec = sk.iter_kmers(k).collect(); - assert_eq!(kmers.len(), ascii.len() - k + 1); + let kmers: Vec = sk.iter_kmers().collect(); + assert_eq!(kmers.len(), 200 - k + 1); for (i, &km) in kmers.iter().enumerate() { - assert_eq!(km, sk.kmer(i, k).unwrap(), "pos={i}"); + assert_eq!(km, sk.kmer(i).unwrap(), "pos={i}"); } } diff --git a/src/obikseq/src/tests/unitig.rs b/src/obikseq/src/tests/unitig.rs new file mode 100644 index 0000000..7e246c9 --- /dev/null +++ b/src/obikseq/src/tests/unitig.rs @@ -0,0 +1,171 @@ +// ── tests ───────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use crate::packed_seq::PackedSeq as Unitig; + use crate::set_k; + + fn make_seq(len: usize) -> Vec { + (0..len).map(|i| b"ACGT"[i % 4]).collect() + } + + fn ascii_revcomp(seq: &[u8]) -> Vec { + seq.iter() + .rev() + .map(|&b| match b { + b'A' => b'T', + b'T' => b'A', + b'C' => b'G', + b'G' => b'C', + _ => b'A', + }) + .collect() + } + + fn test_lengths() -> impl Iterator { + 
(1..=9).chain([255, 256, 257, 1000, 10_000]) + } + + // ── from_ascii / to_ascii ───────────────────────────────────────────────── + + #[test] + fn ascii_roundtrip_all_lengths() { + for len in test_lengths() { + let ascii = make_seq(len); + let u = Unitig::from_ascii(&ascii); + assert_eq!(u.to_ascii(), ascii, "roundtrip failed for len={len}"); + } + } + + // ── seql ────────────────────────────────────────────────────────────────── + + #[test] + fn seql_roundtrip() { + for len in test_lengths() { + let u = Unitig::from_ascii(&make_seq(len)); + assert_eq!(u.seql(), len); + } + } + + // ── revcomp ─────────────────────────────────────────────────────────────── + + #[test] + fn revcomp_known_values() { + let cases = [ + ("A", "T"), + ("AC", "GT"), + ("ACG", "CGT"), + ("ACGT", "ACGT"), + ("ACGTA", "TACGT"), + ]; + for (seq, expected) in cases { + let mut u = Unitig::from_ascii(seq.as_bytes()); + u.revcomp_inplace(); + assert_eq!( + u.to_ascii(), + expected.as_bytes(), + "revcomp wrong for \"{seq}\"" + ); + } + } + + #[test] + fn revcomp_vs_reference_all_lengths() { + for len in test_lengths() { + let ascii = make_seq(len); + let expected = ascii_revcomp(&ascii); + let mut u = Unitig::from_ascii(&ascii); + u.revcomp_inplace(); + assert_eq!(u.to_ascii(), expected, "revcomp wrong for len={len}"); + } + } + + #[test] + fn revcomp_involution_all_lengths() { + for len in test_lengths() { + let ascii = make_seq(len); + let mut u = Unitig::from_ascii(&ascii); + u.revcomp_inplace(); + u.revcomp_inplace(); + assert_eq!(u.to_ascii(), ascii, "revcomp∘revcomp≠id for len={len}"); + } + } + + // ── canonicalize ────────────────────────────────────────────────────────── + + #[test] + fn canonical_palindrome_unchanged() { + let mut u = Unitig::from_ascii(b"ACGT"); + u.canonicalize(); + assert_eq!(u.to_ascii(), b"ACGT"); + } + + #[test] + fn canonical_chooses_revcomp() { + let mut u = Unitig::from_ascii(b"TTTT"); + u.canonicalize(); + assert_eq!(u.to_ascii(), b"AAAA"); + } + + #[test] + 
fn canonical_is_minimal_all_lengths() { + for len in test_lengths() { + let ascii = make_seq(len); + let mut u = Unitig::from_ascii(&ascii); + u.canonicalize(); + let fwd = u.to_ascii(); + let rev = ascii_revcomp(&fwd); + assert!(fwd <= rev, "canonical not minimal for len={len}"); + } + } + + // ── kmer extraction ─────────────────────────────────────────────────────── + + #[test] + fn kmer_all_positions() { + set_k(4); + let k = crate::params::k(); + let ascii = b"ACGTACGTACGT"; + let u = Unitig::from_ascii(ascii); + for i in 0..=ascii.len() - k { + let kmer = u.kmer(i).unwrap(); + let expected = crate::kmer::Kmer::from_ascii(&ascii[i..i + k]).unwrap(); + assert_eq!(kmer, expected, "mismatch at position {i}"); + } + } + + // ── iter_kmers ──────────────────────────────────────────────────────────── + + #[test] + fn iter_kmers_matches_kmer_at_each_position() { + set_k(4); + let ascii = make_seq(20); + let u = Unitig::from_ascii(&ascii); + let kmers: Vec = u.iter_kmers().collect(); + for (i, &km) in kmers.iter().enumerate() { + assert_eq!(km, u.kmer(i).unwrap(), "pos={i}"); + } + } + + #[test] + fn iter_kmers_long_unitig() { + set_k(4); + let k = crate::params::k(); + let ascii = make_seq(10_000); + let u = Unitig::from_ascii(&ascii); + assert_eq!(u.iter_kmers().count(), 10_000 - k + 1); + } + + // ── binary serialisation ────────────────────────────────────────────────── + + #[test] + fn binary_roundtrip_all_lengths() { + for len in test_lengths() { + let u = Unitig::from_ascii(&make_seq(len)); + let mut buf = Vec::new(); + u.write_to_binary(&mut buf).unwrap(); + let u2 = Unitig::read_from_binary(&mut buf.as_slice()).unwrap(); + assert_eq!(u, u2, "binary roundtrip failed for len={len}"); + } + } +} diff --git a/src/obikseq/src/unitig.rs b/src/obikseq/src/unitig.rs index d543df2..4a6a2c0 100644 --- a/src/obikseq/src/unitig.rs +++ b/src/obikseq/src/unitig.rs @@ -1,424 +1,10 @@ -//! Compact 2-bit DNA unitig with in-place reverse complement and canonical form. +//! 
Unitig: a 2-bit packed DNA sequence without metadata. //! -//! Same encoding as [`SuperKmer`](crate::superkmer::SuperKmer) — nucleotide 0 -//! at the MSB of `seq[0]`, 4 bases per byte — but without the 256-nucleotide -//! length cap and without the scatter/count header payload. +//! [`Unitig`] is a type alias for [`PackedSeq`] — all sequence operations, +//! binary serialisation, and k-mer iteration are available directly. -use std::io::{self, Write}; - -use crate::encoding::{DEC4, encode_base}; -use crate::kmer::{CanonicalKmer, Kmer, KmerError}; -use crate::revcomp_lookup::REVCOMP4; -use bitvec::prelude::*; - -// ── Unitig ──────────────────────────────────────────────────────────────────── - -/// Compact unitig: sequence length (usize) + byte-aligned 2-bit nucleotide sequence. -/// -/// Encoding: A=00, C=01, G=10, T=11. Nucleotide 0 occupies bits 7–6 of `seq[0]`, -/// nucleotide i occupies bits `7 − 2*(i%4)` and `6 − 2*(i%4)` of `seq[i/4]`. -/// Padding bits in the last byte are always 0. -#[derive(Debug, Clone)] -pub struct Unitig { - seql: usize, - seq: Box<[u8]>, -} - -impl PartialEq for Unitig { - fn eq(&self, other: &Self) -> bool { - self.seql == other.seql && self.seq == other.seq - } -} - -impl Eq for Unitig {} - -impl std::hash::Hash for Unitig { - fn hash(&self, state: &mut H) { - self.seql.hash(state); - self.seq.hash(state); - } -} - -impl Unitig { - /// Create from a pre-packed 2-bit byte slice and explicit length. - /// `seq.len()` must equal `(seql + 3) / 4`. - pub fn new(seql: usize, seq: Box<[u8]>) -> Self { - debug_assert_eq!(seq.len(), byte_len(seql)); - Self { seql, seq } - } - - /// Encode a slice of 2-bit nucleotide values (0=A, 1=C, 2=G, 3=T, any length ≥ 1). - /// More efficient than `from_ascii` when nucleotides are already 2-bit encoded. 
- pub fn from_nucleotides(nucs: &[u8]) -> Self { - let seql = nucs.len(); - debug_assert!(seql >= 1, "unitig length must be ≥ 1"); - let n = byte_len(seql); - let mut seq = vec![0u8; n]; - for (i, &nuc) in nucs.iter().enumerate() { - seq[i / 4] |= (nuc & 0b11) << (6 - 2 * (i % 4)); - } - Self::new(seql, seq.into_boxed_slice()) - } - - /// Encode an ASCII nucleotide slice (ACGT, any length ≥ 1) into a new Unitig. - /// The result is not yet in canonical form; call `.canonical()` if needed. - pub fn from_ascii(ascii: &[u8]) -> Self { - let seql = ascii.len(); - debug_assert!(seql >= 1, "unitig length must be ≥ 1"); - let n = byte_len(seql); - let mut seq = vec![0u8; n]; - - let full = seql / 4; - for i in 0..full { - seq[i] = encode_base(ascii[i * 4]) << 6 - | encode_base(ascii[i * 4 + 1]) << 4 - | encode_base(ascii[i * 4 + 2]) << 2 - | encode_base(ascii[i * 4 + 3]); - } - let rem = seql % 4; - if rem > 0 { - let mut last = 0u8; - for j in 0..rem { - last |= encode_base(ascii[full * 4 + j]) << (6 - 2 * j); - } - seq[full] = last; - } - - Self::new(seql, seq.into_boxed_slice()) - } - - /// Returns the sequence length in nucleotides. - pub fn seql(&self) -> usize { - self.seql - } - - /// Returns a read-only view of the packed 2-bit sequence bytes. - /// Length is always `(seql() + 3) / 4`. - pub fn seq_bytes(&self) -> &[u8] { - &self.seq - } - - /// Extract nucleotide i (0-based from 5′ end) as a 2-bit value. - pub fn nucleotide(&self, i: usize) -> u8 { - (self.seq[i / 4] >> (6 - 2 * (i % 4))) & 0b11 - } - - /// Decode into ASCII nucleotides, writing into `writer`. - pub fn write_ascii(&self, writer: &mut W) -> io::Result<()> { - let full = self.seql / 4; - for i in 0..full { - writer.write_all(&DEC4[self.seq[i] as usize].to_be_bytes())?; - } - let rem = self.seql % 4; - if rem > 0 { - let bytes = DEC4[self.seq[full] as usize].to_be_bytes(); - writer.write_all(&bytes[..rem])?; - } - Ok(()) - } - - /// Decode into a fresh ASCII `Vec`. 
- pub fn to_ascii(&self) -> Vec { - let mut buf = Vec::with_capacity(self.seql); - self.write_ascii(&mut buf).unwrap(); - buf - } - - /// Reverse-complement this unitig in place. - pub fn revcomp(&mut self) { - let n = byte_len(self.seql); - - // Step 1: swap bytes outside-in, complementing each 4-base chunk via lookup. - { - let bytes = &mut self.seq[..n]; - let (mut lo, mut hi) = (0, n - 1); - while lo < hi { - (bytes[lo], bytes[hi]) = - (REVCOMP4[bytes[hi] as usize], REVCOMP4[bytes[lo] as usize]); - lo += 1; - hi -= 1; - } - if lo == hi { - bytes[lo] = REVCOMP4[bytes[lo] as usize]; - } - } - - // Step 2: left-shift to flush the padding T's produced by complementing padding A's. - let shift = n * 8 - self.seql * 2; - if shift > 0 { - let bits = self.seq[..n].view_bits_mut::(); - bits.rotate_left(shift); - let len = bits.len(); - bits[len - shift..].fill(false); - } - } - - /// Returns `true` if this unitig is in canonical form (lexicographic minimum - /// of forward and reverse complement). - pub fn is_canonical(&self) -> bool { - for i in 0..self.seql { - let fwd = self.nucleotide(i); - let rev = complement(self.nucleotide(self.seql - 1 - i)); - if fwd < rev { - return true; - } - if fwd > rev { - return false; - } - } - true - } - - /// Put this unitig in canonical form in place. - /// - /// Returns `true` if already canonical (no change), `false` if revcomp was applied. - pub fn canonical(&mut self) -> bool { - if self.is_canonical() { - return true; - } - self.revcomp(); - false - } - - /// Extract the kmer of length `k` starting at nucleotide position `i` (0-based). 
- pub fn kmer(&self, i: usize, k: usize) -> Result { - if k == 0 || k > 32 { - return Err(KmerError::InvalidK { k }); - } - if i + k > self.seql { - return Err(KmerError::OutOfBounds { - position: i, - k, - seql: self.seql, - }); - } - let bits = self.seq.view_bits::(); - let raw: u64 = bits[i * 2..(i + k) * 2].load_be(); - Ok(Kmer::from_raw(raw << (64 - 2 * k))) - } - - /// Extract the canonical kmer of length `k` starting at position `i`. - pub fn canonical_kmer(&self, i: usize, k: usize) -> Result { - Ok(self.kmer(i, k)?.canonical(k)) - } - - /// Iterate over all kmers of length `k` in order, yielding each as a [`Kmer`]. - pub fn iter_kmers(&self, k: usize) -> impl Iterator + '_ { - UnitigKmerIter::new(self, k) - } - - /// Iterate over all canonical kmers of length `k` in order. - pub fn iter_canonical_kmers(&self, k: usize) -> impl Iterator + '_ { - self.iter_kmers(k).map(move |km| km.canonical(k)) - } -} - -// ── UnitigKmerIter ──────────────────────────────────────────────────────────── - -struct UnitigKmerIter<'a> { - unitig: &'a Unitig, - mask: u64, - lshift: usize, - current: u64, - pos: usize, - max_pos: usize, -} - -impl<'a> UnitigKmerIter<'a> { - fn new(unitig: &'a Unitig, k: usize) -> Self { - let seql = unitig.seql(); - let lshift = 64 - k * 2; - let mask = ((!0u128) << (lshift + 2)) as u64; - Self { - unitig, - mask, - lshift, - current: if seql >= k { unitig.kmer(0, k).unwrap().raw() } else { 0 }, - pos: k, - max_pos: seql, - } - } -} - -impl<'a> Iterator for UnitigKmerIter<'a> { - type Item = Kmer; - - fn next(&mut self) -> Option { - if self.pos > self.max_pos { - return None; - } - let result = Kmer::from_raw(self.current); - if self.pos < self.max_pos { - let byte_pos = self.pos / 4; - // nucleotide at position p within its byte occupies bits 7−2*(p%4) and 6−2*(p%4) - let inner_shift = 6 - 2 * (self.pos & 3); - let nuc = (((self.unitig.seq[byte_pos] >> inner_shift) & 3) as u64) << self.lshift; - self.current = ((self.current << 2) & self.mask) | 
nuc; - } - self.pos += 1; - Some(result) - } -} - -// ── helpers ─────────────────────────────────────────────────────────────────── - -fn complement(base: u8) -> u8 { - !base & 0b11 -} - -fn byte_len(seql: usize) -> usize { - (seql + 3) / 4 -} - -// ── tests ───────────────────────────────────────────────────────────────────── +pub use crate::packed_seq::PackedSeq as Unitig; #[cfg(test)] -mod tests { - use super::*; - - fn make_seq(len: usize) -> Vec { - (0..len).map(|i| b"ACGT"[i % 4]).collect() - } - - fn ascii_revcomp(seq: &[u8]) -> Vec { - seq.iter() - .rev() - .map(|&b| match b { - b'A' => b'T', - b'T' => b'A', - b'C' => b'G', - b'G' => b'C', - _ => b'A', - }) - .collect() - } - - fn test_lengths() -> impl Iterator { - (1..=9).chain([255, 256, 257, 1000, 10_000]) - } - - // ── from_ascii / to_ascii ───────────────────────────────────────────────── - - #[test] - fn ascii_roundtrip_all_lengths() { - for len in test_lengths() { - let ascii = make_seq(len); - let u = Unitig::from_ascii(&ascii); - assert_eq!(u.to_ascii(), ascii, "roundtrip failed for len={len}"); - } - } - - // ── seql ────────────────────────────────────────────────────────────────── - - #[test] - fn seql_roundtrip() { - for len in test_lengths() { - let u = Unitig::from_ascii(&make_seq(len)); - assert_eq!(u.seql(), len); - } - } - - // ── revcomp ─────────────────────────────────────────────────────────────── - - #[test] - fn revcomp_known_values() { - let cases = [ - ("A", "T"), - ("AC", "GT"), - ("ACG", "CGT"), - ("ACGT", "ACGT"), - ("ACGTA", "TACGT"), - ]; - for (seq, expected) in cases { - let mut u = Unitig::from_ascii(seq.as_bytes()); - u.revcomp(); - assert_eq!(u.to_ascii(), expected.as_bytes(), "revcomp wrong for \"{seq}\""); - } - } - - #[test] - fn revcomp_vs_reference_all_lengths() { - for len in test_lengths() { - let ascii = make_seq(len); - let expected = ascii_revcomp(&ascii); - let mut u = Unitig::from_ascii(&ascii); - u.revcomp(); - assert_eq!(u.to_ascii(), expected, "revcomp 
wrong for len={len}"); - } - } - - #[test] - fn revcomp_involution_all_lengths() { - for len in test_lengths() { - let ascii = make_seq(len); - let mut u = Unitig::from_ascii(&ascii); - u.revcomp(); - u.revcomp(); - assert_eq!(u.to_ascii(), ascii, "revcomp∘revcomp≠id for len={len}"); - } - } - - // ── canonical ───────────────────────────────────────────────────────────── - - #[test] - fn canonical_palindrome_unchanged() { - let mut u = Unitig::from_ascii(b"ACGT"); - u.canonical(); - assert_eq!(u.to_ascii(), b"ACGT"); - } - - #[test] - fn canonical_chooses_revcomp() { - let mut u = Unitig::from_ascii(b"TTTT"); - u.canonical(); - assert_eq!(u.to_ascii(), b"AAAA"); - } - - #[test] - fn canonical_is_minimal_all_lengths() { - for len in test_lengths() { - let ascii = make_seq(len); - let mut u = Unitig::from_ascii(&ascii); - u.canonical(); - let fwd = u.to_ascii(); - let rev = ascii_revcomp(&fwd); - assert!(fwd <= rev, "canonical not minimal for len={len}"); - } - } - - // ── kmer extraction ─────────────────────────────────────────────────────── - - #[test] - fn kmer_all_positions() { - let ascii = b"ACGTACGTACGT"; - let k = 4; - let u = Unitig::from_ascii(ascii); - for i in 0..=ascii.len() - k { - let kmer = u.kmer(i, k).unwrap(); - let expected = Kmer::from_ascii(&ascii[i..i + k], k).unwrap(); - assert_eq!(kmer, expected, "mismatch at position {i}"); - } - } - - // ── iter_kmers ──────────────────────────────────────────────────────────── - - #[test] - fn iter_kmers_matches_kmer_at_each_position() { - let ascii = make_seq(20); - let k = 7; - let u = Unitig::from_ascii(&ascii); - let kmers: Vec = u.iter_kmers(k).collect(); - assert_eq!(kmers.len(), ascii.len() - k + 1); - for (i, &km) in kmers.iter().enumerate() { - assert_eq!(km, u.kmer(i, k).unwrap(), "pos={i}"); - } - } - - #[test] - fn iter_kmers_long_unitig() { - let ascii = make_seq(10_000); - let k = 11; - let u = Unitig::from_ascii(&ascii); - assert_eq!(u.iter_kmers(k).count(), 10_000 - k + 1); - } -} +#[path 
= "tests/unitig.rs"] +mod tests; diff --git a/src/obiskbuilder/Cargo.toml b/src/obiskbuilder/Cargo.toml index 3b2cb51..0b68fb8 100644 --- a/src/obiskbuilder/Cargo.toml +++ b/src/obiskbuilder/Cargo.toml @@ -7,3 +7,6 @@ edition = "2024" obikseq = { path = "../obikseq" } obikrope = { path = "../obikrope" } lazy_static = "1.5.0" + +[dev-dependencies] +obikseq = { path = "../obikseq", features = ["test-utils"] } diff --git a/src/obiskbuilder/src/entropy_table.rs b/src/obiskbuilder/src/entropy_table.rs index 6aeb213..4132ef7 100644 --- a/src/obiskbuilder/src/entropy_table.rs +++ b/src/obiskbuilder/src/entropy_table.rs @@ -21,7 +21,6 @@ pub(crate) static LN_CARD_ROT5: LazyLock<[f64; 1024]> = pub(crate) static LN_CARD_ROT6: LazyLock<[f64; 4096]> = LazyLock::new(|| build_log_class_size::<4096>(&NORMK6)); - fn ln0(x: f64) -> f64 { if x == 0.0 { 0.0 } else { x.ln() } } @@ -47,7 +46,7 @@ fn build_normalized_kmer() -> [u64; N] { for i in 0..N { let la = (i as u64) << shift; let ra = i as u64; - let rc_ra = Kmer::from_raw(la).revcomp(k).raw() >> shift; + let rc_ra = Kmer::from_raw(la).revcomp().raw() >> shift; let circ = normalize_circular(ra, k); let circ_rc = normalize_circular(rc_ra, k); result[i] = circ.min(circ_rc); @@ -107,12 +106,10 @@ pub(crate) const K_MAX: usize = 32; pub(crate) const WS_MAX: usize = 6; /// n·ln(n), with n_log_n[0] = 0. Indexed by n = 0..=K_MAX. -pub(crate) static N_LOG_N: LazyLock<[f64; K_MAX + 1]> = - LazyLock::new(|| build_n_log_n()); +pub(crate) static N_LOG_N: LazyLock<[f64; K_MAX + 1]> = LazyLock::new(|| build_n_log_n()); /// H_max[k][ws]: maximum entropy for kmer length k and word size ws. -pub(crate) static EMAX: LazyLock<[[f64; WS_MAX + 1]; K_MAX + 1]> = - LazyLock::new(|| build_emax()); +pub(crate) static EMAX: LazyLock<[[f64; WS_MAX + 1]; K_MAX + 1]> = LazyLock::new(|| build_emax()); /// ln(k − ws + 1): log of the number of ws-words in a kmer of length k. 
pub(crate) static LOG_NWORDS: LazyLock<[[f64; WS_MAX + 1]; K_MAX + 1]> = diff --git a/src/obiskbuilder/src/iter.rs b/src/obiskbuilder/src/iter.rs index 7c09dfa..1505203 100644 --- a/src/obiskbuilder/src/iter.rs +++ b/src/obiskbuilder/src/iter.rs @@ -16,8 +16,8 @@ //! | super-kmer length = 256| k | use obikrope::{ForwardCursor, Rope, RopeCursor}; -use obikseq::kmer::CanonicalKmer; use obikseq::RoutableSuperKmer; +use obikseq::kmer::Minimizer; use crate::rolling_stat::RollingStat; use crate::scratch::SuperKmerScratch; @@ -26,11 +26,10 @@ use crate::scratch::SuperKmerScratch; pub struct SuperKmerIter<'a> { cursor: ForwardCursor<'a>, k: usize, - m: usize, theta: f64, scratch: SuperKmerScratch, stat: RollingStat, - prev_min: Option, + prev_min: Option, prev_min_pos: usize, } @@ -41,14 +40,13 @@ impl<'a> SuperKmerIter<'a> { /// - `m`: minimizer size (1 < m < k) /// - `level_max`: maximum sub-word size for entropy (1–6) /// - `theta`: entropy threshold; k-mers with score ≤ theta are rejected - pub fn new(rope: &'a Rope, k: usize, m: usize, level_max: usize, theta: f64) -> Self { + pub fn new(rope: &'a Rope, k: usize, level_max: usize, theta: f64) -> Self { Self { cursor: rope.fw_cursor(), k, - m, theta, scratch: SuperKmerScratch::new(), - stat: RollingStat::new(k, m, level_max), + stat: RollingStat::new(level_max), prev_min: None, prev_min_pos: 0, } @@ -66,7 +64,7 @@ impl<'a> SuperKmerIter<'a> { return None; } self.prev_min?; - Some(self.scratch.emit(self.prev_min_pos, self.m)) + Some(self.scratch.emit(self.prev_min_pos)) } } @@ -149,26 +147,31 @@ mod tests { use super::*; use obikrope::Rope; + fn setup() { + obikseq::params::set_k(K); + obikseq::params::set_m(5); + } + fn make_rope(data: &[u8]) -> Rope { let mut r = Rope::new(None); r.push(data.to_vec()); r } - fn run_nofilter(data: &[u8], k: usize, m: usize) -> Vec> { + fn run_nofilter(data: &[u8], k: usize) -> Vec> { let rope = make_rope(data); - SuperKmerIter::new(&rope, k, m, 1, 0.0) + SuperKmerIter::new(&rope, k, 1, 
0.0) .map(|rsk| rsk.superkmer().to_ascii()) .collect() } // k=11, m=5 — valeurs minimales du projet (k ∈ [11,31]) const K: usize = 11; - const M: usize = 5; #[test] fn single_segment_one_superkmer() { - let out = run_nofilter(b"ACGTACGTACGTACGTACGT\x00", K, M); + setup(); + let out = run_nofilter(b"ACGTACGTACGTACGTACGT\x00", K); assert!(!out.is_empty()); let total: Vec = out.into_iter().flatten().collect(); assert!(total.len() >= K); @@ -176,29 +179,33 @@ mod tests { #[test] fn segment_shorter_than_k_emits_nothing() { - let out = run_nofilter(b"ACGTACGT\x00", K, M); + setup(); + let out = run_nofilter(b"ACGTACGT\x00", K); assert_eq!(out, Vec::>::new()); } #[test] fn empty_input_emits_nothing() { - let out = run_nofilter(b"", K, M); + setup(); + let out = run_nofilter(b"", K); assert_eq!(out, Vec::>::new()); } #[test] fn two_segments_both_emitted() { - let out = run_nofilter(b"ACGTACGTACGTACGT\x00TGCATGCATGCATGCA\x00", K, M); + setup(); + let out = run_nofilter(b"ACGTACGTACGTACGT\x00TGCATGCATGCATGCA\x00", K); assert!(!out.is_empty()); } #[test] fn low_complexity_kmer_is_rejected() { - let out_pass = run_nofilter(b"AAAAAAAAAAAACGTACGTACGT\x00", K, M); + setup(); + let out_pass = run_nofilter(b"AAAAAAAAAAAACGTACGTACGT\x00", K); assert!(!out_pass.is_empty()); let rope = make_rope(b"AAAAAAAAAAAAAAAAAAAA\x00"); - let out_reject: Vec> = SuperKmerIter::new(&rope, K, M, 6, 0.9) + let out_reject: Vec> = SuperKmerIter::new(&rope, K, 6, 0.9) .map(|rsk| rsk.superkmer().to_ascii()) .collect(); assert!(out_reject.is_empty()); @@ -206,12 +213,13 @@ mod tests { #[test] fn multi_slice_rope() { + setup(); let data = b"ACGTACGTACGTACGTACGT\x00"; let mid = data.len() / 2; let mut rope = Rope::new(None); rope.push(data[..mid].to_vec()); rope.push(data[mid..].to_vec()); - let out: Vec> = SuperKmerIter::new(&rope, K, M, 1, 0.0) + let out: Vec> = SuperKmerIter::new(&rope, K, 1, 0.0) .map(|rsk| rsk.superkmer().to_ascii()) .collect(); assert!(!out.is_empty()); @@ -219,8 +227,9 @@ mod tests { 
#[test] fn yields_minimizer_value() { + setup(); let rope = make_rope(b"ACGTACGTACGTACGTACGT\x00"); - let results: Vec = SuperKmerIter::new(&rope, K, M, 1, 0.0).collect(); + let results: Vec = SuperKmerIter::new(&rope, K, 1, 0.0).collect(); assert!(!results.is_empty()); } } diff --git a/src/obiskbuilder/src/lib.rs b/src/obiskbuilder/src/lib.rs index 9a75f4b..89ea349 100644 --- a/src/obiskbuilder/src/lib.rs +++ b/src/obiskbuilder/src/lib.rs @@ -19,6 +19,11 @@ use obikrope::Rope; use obikseq::RoutableSuperKmer; /// Collect all super-kmers from a normalised rope chunk. -pub fn build_superkmers(rope: Rope, k: usize, m: usize, level_max: usize, theta: f64) -> Vec { - SuperKmerIter::new(&rope, k, m, level_max, theta).collect() +pub fn build_superkmers( + rope: Rope, + k: usize, + level_max: usize, + theta: f64, +) -> Vec { + SuperKmerIter::new(&rope, k, level_max, theta).collect() } diff --git a/src/obiskbuilder/src/rolling_stat.rs b/src/obiskbuilder/src/rolling_stat.rs index a4e4b05..fb61d20 100644 --- a/src/obiskbuilder/src/rolling_stat.rs +++ b/src/obiskbuilder/src/rolling_stat.rs @@ -1,4 +1,5 @@ -use obikseq::kmer::{CanonicalKmer, Kmer}; +use obikseq::kmer::{Minimizer, hash_kmer}; +use obikseq::params; use crate::encoding::encode_nuc; use crate::entropy_table::{WS_MAX, emax, entropy_norm_kmer, ln_class_size, log_nwords, n_log_n}; @@ -13,22 +14,7 @@ struct MmerItem { hash: u64, } -/// Bijective hash used to randomise the minimizer ordering. -/// The XOR seed (2^64/φ) breaks the mix64 fixed point at 0, -/// preventing poly-A/T kmers (canonical = 0) from always winning. 
-#[inline(always)] -fn hash_mmer(canonical: u64) -> u64 { - let x = canonical ^ 0x9e3779b97f4a7c15; - let x = x ^ (x >> 30); - let x = x.wrapping_mul(0xbf58476d1ce4e5b9); - let x = x ^ (x >> 27); - let x = x.wrapping_mul(0x94d049bb133111eb); - x ^ (x >> 31) -} - pub struct RollingStat { - k: usize, - m: usize, entropy_max_k: usize, rolling_k: u64, rolling_rck: u64, @@ -53,15 +39,15 @@ pub struct RollingStat { } impl RollingStat { - pub fn new(k: usize, m: usize, entropy_max_k: usize) -> Self { + pub fn new(entropy_max_k: usize) -> Self { + let k = params::k(); + let m = params::m(); Self { - k, - m, entropy_max_k, rolling_k: 0, rolling_rck: 0, - k_mask: (!0) >> (64 - k * 2), - m_mask: (!0) >> (64 - m * 2), + k_mask: (!0u64) >> (64 - k * 2), + m_mask: (!0u64) >> (64 - m * 2), received: 0, k1q: std::collections::VecDeque::with_capacity(k), k2q: std::collections::VecDeque::with_capacity(k - 1), @@ -85,12 +71,24 @@ impl RollingStat { self.rolling_k = 0; self.rolling_rck = 0; self.received = 0; - for &i in &self.k1q { self.k1c[i as usize] = 0; } - for &i in &self.k2q { self.k2c[i as usize] = 0; } - for &i in &self.k3q { self.k3c[i as usize] = 0; } - for &i in &self.k4q { self.k4c[i as usize] = 0; } - for &i in &self.k5q { self.k5c[i as usize] = 0; } - for &i in &self.k6q { self.k6c[i as usize] = 0; } + for &i in &self.k1q { + self.k1c[i as usize] = 0; + } + for &i in &self.k2q { + self.k2c[i as usize] = 0; + } + for &i in &self.k3q { + self.k3c[i as usize] = 0; + } + for &i in &self.k4q { + self.k4c[i as usize] = 0; + } + for &i in &self.k5q { + self.k5c[i as usize] = 0; + } + for &i in &self.k6q { + self.k6c[i as usize] = 0; + } self.k1q.clear(); self.k2q.clear(); self.k3q.clear(); @@ -127,12 +125,15 @@ impl RollingStat { } pub fn push(&mut self, nuc: u8) { + let k = params::k(); + let m = params::m(); + let bnuc = encode_nuc(nuc); let cnuc = bnuc ^ 3; self.rolling_k = ((self.rolling_k << 2) | (bnuc as u64)) & self.k_mask; self.rolling_rck = - ((self.rolling_rck >> 2) 
| ((cnuc as u64) << ((self.k - 1) * 2))) & self.k_mask; + ((self.rolling_rck >> 2) | ((cnuc as u64) << ((k - 1) * 2))) & self.k_mask; let canonical_k1 = entropy_norm_kmer(self.rolling_k & 3, 1, false); let canonical_k2 = entropy_norm_kmer(self.rolling_k & 15, 2, false); @@ -143,30 +144,37 @@ impl RollingStat { self.received += 1; - if self.received >= self.m { + if self.received >= m { let possible_canonical_m = - (self.rolling_k & self.m_mask).min(self.rolling_rck >> ((self.k - self.m) * 2)); - let possible_hash_m = hash_mmer(possible_canonical_m); - let possible_pos_m = self.received - self.m; + (self.rolling_k & self.m_mask).min(self.rolling_rck >> ((k - m) * 2)); + let possible_hash_m = hash_kmer(possible_canonical_m << 64 - m * 2); + let possible_pos_m = self.received - m; - while self.minimier.back().map_or(false, |it| it.hash >= possible_hash_m) { + while self + .minimier + .back() + .map_or(false, |it| it.hash >= possible_hash_m) + { self.minimier.pop_back(); } - self.minimier - .push_back(MmerItem { position: possible_pos_m, canonical: possible_canonical_m, hash: possible_hash_m }); + self.minimier.push_back(MmerItem { + position: possible_pos_m, + canonical: possible_canonical_m, + hash: possible_hash_m, + }); - if self.received > self.k { + if self.received > k { while self .minimier .front() - .map_or(false, |it| it.position + self.k < self.received) + .map_or(false, |it| it.position + k < self.received) { self.minimier.pop_front(); } } } - if self.received > self.k { + if self.received > k { let old1 = self.k1q.pop_front().unwrap(); let f1 = self.k1c[old1 as usize]; Self::update_sums_decrement(&mut self.sum_f_log_f, &mut self.sum_f_log_s, 1, old1, f1); @@ -199,37 +207,73 @@ impl RollingStat { } let g1 = self.k1c[canonical_k1 as usize]; - Self::update_sums_increment(&mut self.sum_f_log_f, &mut self.sum_f_log_s, 1, canonical_k1, g1); + Self::update_sums_increment( + &mut self.sum_f_log_f, + &mut self.sum_f_log_s, + 1, + canonical_k1, + g1, + ); 
self.k1c[canonical_k1 as usize] += 1; self.k1q.push_back(canonical_k1); if self.received >= 2 { let g2 = self.k2c[canonical_k2 as usize]; - Self::update_sums_increment(&mut self.sum_f_log_f, &mut self.sum_f_log_s, 2, canonical_k2, g2); + Self::update_sums_increment( + &mut self.sum_f_log_f, + &mut self.sum_f_log_s, + 2, + canonical_k2, + g2, + ); self.k2c[canonical_k2 as usize] += 1; self.k2q.push_back(canonical_k2); if self.received >= 3 { let g3 = self.k3c[canonical_k3 as usize]; - Self::update_sums_increment(&mut self.sum_f_log_f, &mut self.sum_f_log_s, 3, canonical_k3, g3); + Self::update_sums_increment( + &mut self.sum_f_log_f, + &mut self.sum_f_log_s, + 3, + canonical_k3, + g3, + ); self.k3c[canonical_k3 as usize] += 1; self.k3q.push_back(canonical_k3); if self.received >= 4 { let g4 = self.k4c[canonical_k4 as usize]; - Self::update_sums_increment(&mut self.sum_f_log_f, &mut self.sum_f_log_s, 4, canonical_k4, g4); + Self::update_sums_increment( + &mut self.sum_f_log_f, + &mut self.sum_f_log_s, + 4, + canonical_k4, + g4, + ); self.k4c[canonical_k4 as usize] += 1; self.k4q.push_back(canonical_k4); if self.received >= 5 { let g5 = self.k5c[canonical_k5 as usize]; - Self::update_sums_increment(&mut self.sum_f_log_f, &mut self.sum_f_log_s, 5, canonical_k5, g5); + Self::update_sums_increment( + &mut self.sum_f_log_f, + &mut self.sum_f_log_s, + 5, + canonical_k5, + g5, + ); self.k5c[canonical_k5 as usize] += 1; self.k5q.push_back(canonical_k5); if self.received >= 6 { let g6 = self.k6c[canonical_k6 as usize]; - Self::update_sums_increment(&mut self.sum_f_log_f, &mut self.sum_f_log_s, 6, canonical_k6, g6); + Self::update_sums_increment( + &mut self.sum_f_log_f, + &mut self.sum_f_log_s, + 6, + canonical_k6, + g6, + ); self.k6c[canonical_k6 as usize] += 1; self.k6q.push_back(canonical_k6); } @@ -240,31 +284,7 @@ impl RollingStat { } pub fn ready(&self) -> bool { - self.received >= self.k - } - - pub fn kmer(&self) -> Option { - if self.ready() { - 
Some(Kmer::from_raw_right(self.rolling_k, self.k)) - } else { - None - } - } - - pub fn revcomp_kmer(&self) -> Option { - if self.ready() { - Some(Kmer::from_raw_right(self.rolling_rck, self.k)) - } else { - None - } - } - - pub fn canonical_kmer(&self) -> Option { - if self.ready() { - Some(Kmer::from_raw_right(self.rolling_k.min(self.rolling_rck), self.k)) - } else { - None - } + self.received >= params::k() } pub fn minimizer_position(&self) -> Option { @@ -283,22 +303,22 @@ impl RollingStat { } } - pub fn canonical_minimizer(&self) -> Option { - self.canonical_minimizer_raw().map(|raw| { - CanonicalKmer::from_raw_unchecked(Kmer::from_raw_right(raw, self.m).raw()) - }) + pub fn canonical_minimizer(&self) -> Option { + self.canonical_minimizer_raw() + .map(|raw| Minimizer::from_raw_unchecked(raw << (64 - params::m() * 2))) } pub fn entropy(&self, order: usize) -> Option { if !self.ready() { return None; } - let em = emax(self.k, order); + let k = params::k(); + let em = emax(k, order); if em <= 0.0 { return Some(1.0); } - let nwords = self.k - order + 1; - let log_nw = log_nwords(self.k, order); + let nwords = k - order + 1; + let log_nw = log_nwords(k, order); let nw_f = nwords as f64; let h_corr = log_nw + (self.sum_f_log_s[order] - self.sum_f_log_f[order]) / nw_f; Some((h_corr / em).max(0.0)) diff --git a/src/obiskbuilder/src/scratch.rs b/src/obiskbuilder/src/scratch.rs index ec8a6e4..1f7b7bf 100644 --- a/src/obiskbuilder/src/scratch.rs +++ b/src/obiskbuilder/src/scratch.rs @@ -56,14 +56,14 @@ impl SuperKmerScratch { /// /// The heap allocation (`Box<[u8]>`) is exactly sized to the sequence. /// Resets the buffer to empty afterward. 
- pub fn emit(&mut self, min_pos: usize, m: usize) -> RoutableSuperKmer { + pub fn emit(&mut self, min_pos: usize) -> RoutableSuperKmer { let seql = self.len; debug_assert!(seql >= 1 && seql <= MAX_SUPERKMER_LEN); let n = (seql + 3) / 4; let seq: Box<[u8]> = self.buf[..n].into(); self.buf[..n].fill(0); self.len = 0; - RoutableSuperKmer::build(min_pos, m, seql as u8, seq) + RoutableSuperKmer::build(min_pos, seql, seq) } /// Discard all accumulated nucleotides without producing a [`SuperKmer`]. pub fn reset(&mut self) { diff --git a/src/obiskio/Cargo.toml b/src/obiskio/Cargo.toml index 75d142c..4d60b3f 100644 --- a/src/obiskio/Cargo.toml +++ b/src/obiskio/Cargo.toml @@ -14,3 +14,4 @@ obikseq = { path = "../obikseq" } [dev-dependencies] tempfile = "3" +obikseq = { path = "../obikseq", features = ["test-utils"] } diff --git a/src/obiskio/src/codec.rs b/src/obiskio/src/codec.rs index 9274e65..b2b5b44 100644 --- a/src/obiskio/src/codec.rs +++ b/src/obiskio/src/codec.rs @@ -1,46 +1,19 @@ -use obikseq::superkmer::SuperKmer; +use obikseq::SuperKmer; use std::io::{self, Read, Write}; /// Serialise one SuperKmer into `w` (uncompressed; caller must wrap with a compressor). -/// -/// Bits [7:0] of the header store `n_kmers = seql - k + 1` (kmer units, 1–255), -/// not the raw nucleotide length. This removes the 0=256 wrapping convention. #[inline] -pub(crate) fn write_superkmer(w: &mut W, sk: &SuperKmer, k: usize) -> io::Result<()> { - let n_kmers = sk.len() - k + 1; - let new_bits = (sk.header_bits() & !0xFF) | (n_kmers as u32); - w.write_all(&new_bits.to_le_bytes())?; - w.write_all(sk.seq_bytes()) +pub(crate) fn write_superkmer(w: &mut W, sk: &SuperKmer) -> io::Result<()> { + sk.write_to_binary(w) } /// Deserialise one SuperKmer from `r`. Returns `None` on clean EOF. -/// `seq_buf` is a reusable scratch buffer to avoid per-record allocation. -/// Bits [7:0] of the on-disk header contain `n_kmers`; nucleotide length is -/// reconstructed as `n_kmers + k - 1`. 
-pub(crate) fn read_superkmer( - r: &mut R, - seq_buf: &mut Vec, - k: usize, -) -> io::Result> { - let mut hdr = [0u8; 4]; - match r.read_exact(&mut hdr) { - Ok(()) => {} - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e), +pub(crate) fn read_superkmer(r: &mut R) -> io::Result> { + match SuperKmer::read_from_binary(r) { + Ok(sk) => Ok(Some(sk)), + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None), + Err(e) => Err(e), } - let bits = u32::from_le_bytes(hdr); - let n_kmers = (bits & 0xFF) as usize; - let nt_len = n_kmers + k - 1; - let byte_len = (nt_len + 3) / 4; - seq_buf.resize(byte_len, 0); - r.read_exact(seq_buf)?; - // Reconstruct the in-memory seql byte (0 encodes 256, 1-255 direct). - let seql_byte = if nt_len == 256 { 0u8 } else { nt_len as u8 }; - let mem_bits = (bits & !0xFF) | (seql_byte as u32); - Ok(Some(SuperKmer::from_header_bits( - mem_bits, - seq_buf.as_slice().into(), - ))) } #[cfg(test)] @@ -54,32 +27,29 @@ mod tests { #[test] fn roundtrip_single() { - let k = 4; let sk = make_sk(b"ACGTACGT"); let mut buf = Vec::new(); - write_superkmer(&mut buf, &sk, k).unwrap(); + write_superkmer(&mut buf, &sk).unwrap(); let mut cur = Cursor::new(&buf); - let mut seq_buf = Vec::new(); - let got = read_superkmer(&mut cur, &mut seq_buf, k).unwrap().unwrap(); + let got = read_superkmer(&mut cur).unwrap().unwrap(); assert_eq!(sk.to_ascii(), got.to_ascii()); - assert_eq!(sk.len(), got.len()); + assert_eq!(sk.seql(), got.seql()); } #[test] fn roundtrip_all_lengths() { - let bases: Vec = (0..256).map(|i| b"ACGT"[i % 4]).collect(); - // k=11 is the project minimum; test seql from k to 256. 
+ let bases: Vec = (0..300).map(|i| b"ACGT"[i % 4]).collect(); let k = 11; - for len in (k..=k + 8).chain([255, 256]) { + for len in (k..=k + 8).chain([255, 256, 257]) { let sk = make_sk(&bases[..len]); let mut buf = Vec::new(); - write_superkmer(&mut buf, &sk, k).unwrap(); + write_superkmer(&mut buf, &sk).unwrap(); let mut cur = Cursor::new(&buf); - let mut seq_buf = Vec::new(); - let got = read_superkmer(&mut cur, &mut seq_buf, k).unwrap().unwrap(); + let got = read_superkmer(&mut cur).unwrap().unwrap(); assert_eq!(sk.to_ascii(), got.to_ascii(), "len={len}"); + assert_eq!(sk.seql(), got.seql(), "len={len}"); } } @@ -87,26 +57,23 @@ mod tests { fn eof_returns_none() { let buf: Vec = vec![]; let mut cur = Cursor::new(&buf); - let mut seq_buf = Vec::new(); - assert!(read_superkmer(&mut cur, &mut seq_buf, 4).unwrap().is_none()); + assert!(read_superkmer(&mut cur).unwrap().is_none()); } #[test] fn multiple_records() { - let k = 4; let seqs: &[&[u8]] = &[b"AAAA", b"CCCC", b"GGGG", b"TTTT"]; let mut buf = Vec::new(); for s in seqs { - write_superkmer(&mut buf, &make_sk(s), k).unwrap(); + write_superkmer(&mut buf, &make_sk(s)).unwrap(); } let mut cur = Cursor::new(&buf); - let mut seq_buf = Vec::new(); for s in seqs { - let got = read_superkmer(&mut cur, &mut seq_buf, k).unwrap().unwrap(); + let got = read_superkmer(&mut cur).unwrap().unwrap(); let expected = make_sk(s); assert_eq!(expected.to_ascii(), got.to_ascii()); } - assert!(read_superkmer(&mut cur, &mut seq_buf, k).unwrap().is_none()); + assert!(read_superkmer(&mut cur).unwrap().is_none()); } } diff --git a/src/obiskio/src/pool.rs b/src/obiskio/src/pool.rs index 19c3fa9..54c3d9c 100644 --- a/src/obiskio/src/pool.rs +++ b/src/obiskio/src/pool.rs @@ -5,7 +5,7 @@ use crate::meta::SKFileMeta; use lru::LruCache; use niffler::Level; use niffler::send::compression::Format; -use obikseq::superkmer::SuperKmer; +use obikseq::SuperKmer; use std::fs::{File, OpenOptions}; use std::io::{BufWriter, Write}; use 
std::num::NonZeroUsize; @@ -222,7 +222,6 @@ pub struct SKFileWriter { id: usize, pool: Arc>, path: PathBuf, - k: usize, pending: Vec, flush_threshold: usize, logically_closed: bool, @@ -230,15 +229,14 @@ pub struct SKFileWriter { } /// Create a `SKFileWriter` for a new file (Zstd, level 3). -pub fn create_token(pool: &SharedPool, path: PathBuf, k: usize) -> SKResult { - create_token_with(pool, path, k, Format::Zstd, Level::Three) +pub fn create_token(pool: &SharedPool, path: PathBuf) -> SKResult { + create_token_with(pool, path, Format::Zstd, Level::Three) } /// Create a `SKFileWriter` for a new file with explicit format and level. pub fn create_token_with( pool: &SharedPool, path: PathBuf, - k: usize, format: Format, level: Level, ) -> SKResult { @@ -247,7 +245,6 @@ pub fn create_token_with( id, pool: Arc::clone(pool), path, - k, pending: Vec::with_capacity(DEFAULT_FLUSH_THRESHOLD + 128), flush_threshold: DEFAULT_FLUSH_THRESHOLD, logically_closed: false, @@ -258,18 +255,13 @@ pub fn create_token_with( impl SKFileWriter { /// Create a standalone file writer (Zstd, level 3). /// The pool is created internally and is not accessible to the caller. - pub fn create>(path: P, k: usize) -> SKResult { - Self::create_with(path, k, Format::Zstd, Level::Three) + pub fn create>(path: P) -> SKResult { + Self::create_with(path, Format::Zstd, Level::Three) } /// Create a standalone file writer with explicit format and level. - pub fn create_with>( - path: P, - k: usize, - format: Format, - level: Level, - ) -> SKResult { - create_token_with(global_pool(), path.as_ref().to_owned(), k, format, level) + pub fn create_with>(path: P, format: Format, level: Level) -> SKResult { + create_token_with(global_pool(), path.as_ref().to_owned(), format, level) } /// `true` if the underlying fd is currently open in the pool. @@ -280,10 +272,10 @@ impl SKFileWriter { /// Accumulate one SuperKmer. Drains to fd when `pending ≥ flush_threshold`. 
pub fn write(&mut self, sk: &SuperKmer) -> SKResult<()> { self.check_not_closed()?; - write_superkmer(&mut self.pending, sk, self.k)?; + write_superkmer(&mut self.pending, sk)?; self.meta.instances += 1; self.meta.count_sum += sk.count() as u64; - self.meta.length_sum += sk.len() as u64; + self.meta.length_sum += sk.seql() as u64; if self.pending.len() >= self.flush_threshold { self.drain()?; } @@ -294,10 +286,10 @@ impl SKFileWriter { pub fn write_batch(&mut self, sks: &[SuperKmer]) -> SKResult<()> { self.check_not_closed()?; for sk in sks { - write_superkmer(&mut self.pending, sk, self.k)?; + write_superkmer(&mut self.pending, sk)?; self.meta.instances += 1; self.meta.count_sum += sk.count() as u64; - self.meta.length_sum += sk.len() as u64; + self.meta.length_sum += sk.seql() as u64; if self.pending.len() >= self.flush_threshold { self.drain()?; } @@ -439,7 +431,7 @@ impl Drop for SKFileWriter { mod tests { use super::*; use crate::reader::SKFileReader; - use obikseq::superkmer::SuperKmer; + use obikseq::{SuperKmer, set_k}; use tempfile::{NamedTempFile, TempDir}; const TEST_K: usize = 4; @@ -460,22 +452,24 @@ mod tests { #[test] fn creation_holds_no_fd() { + set_k(TEST_K); let dir = TempDir::new().unwrap(); let p = pool(3); for i in 0..10 { - create_token(&p, dir.path().join(format!("p{i}.zst")), TEST_K).unwrap(); + create_token(&p, dir.path().join(format!("p{i}.zst"))).unwrap(); } assert_eq!(p.lock().unwrap().open_count(), 0); } #[test] fn pool_limits_open_fds() { + set_k(TEST_K); let dir = TempDir::new().unwrap(); let p = pool(3); let sk = make_sk(0); let mut tokens: Vec = (0..6) - .map(|i| create_token(&p, dir.path().join(format!("p{i}.zst")), TEST_K).unwrap()) + .map(|i| create_token(&p, dir.path().join(format!("p{i}.zst"))).unwrap()) .collect(); for t in tokens.iter_mut() { @@ -491,12 +485,13 @@ mod tests { #[test] fn evicted_token_stays_logically_open() { + set_k(TEST_K); let dir = TempDir::new().unwrap(); let p = pool(1); let sk = make_sk(0); - let mut t0 
= create_token(&p, dir.path().join("a.zst"), TEST_K).unwrap(); - let mut t1 = create_token(&p, dir.path().join("b.zst"), TEST_K).unwrap(); + let mut t0 = create_token(&p, dir.path().join("a.zst")).unwrap(); + let mut t1 = create_token(&p, dir.path().join("b.zst")).unwrap(); open_token(&mut t0, &sk); // t0 fd open, pool full open_token(&mut t1, &sk); // evicts t0, t1 fd open @@ -507,12 +502,13 @@ mod tests { #[test] fn evicted_data_readable_after_close_all() { + set_k(TEST_K); let dir = TempDir::new().unwrap(); let p = pool(1); let sk = make_sk(0); - let mut t0 = create_token(&p, dir.path().join("a.zst"), TEST_K).unwrap(); - let mut t1 = create_token(&p, dir.path().join("b.zst"), TEST_K).unwrap(); + let mut t0 = create_token(&p, dir.path().join("a.zst")).unwrap(); + let mut t1 = create_token(&p, dir.path().join("b.zst")).unwrap(); t0.set_flush_threshold(1); t0.write(&sk).unwrap(); // t0 fd open, pool full @@ -528,7 +524,7 @@ mod tests { p.lock().unwrap().close_all().unwrap(); for name in &["a.zst", "b.zst"] { - let mut r = SKFileReader::open(dir.path().join(name), TEST_K).unwrap(); + let mut r = SKFileReader::open(dir.path().join(name)).unwrap(); let got = r.read_batch(10).unwrap(); assert_eq!(got.len(), 1, "{name}: expected 1 record"); } @@ -536,13 +532,14 @@ mod tests { #[test] fn touch_moves_to_mru_so_lru_is_evicted() { + set_k(TEST_K); let dir = TempDir::new().unwrap(); let p = pool(2); let sk = make_sk(0); - let mut t0 = create_token(&p, dir.path().join("a.zst"), TEST_K).unwrap(); - let mut t1 = create_token(&p, dir.path().join("b.zst"), TEST_K).unwrap(); - let mut t2 = create_token(&p, dir.path().join("c.zst"), TEST_K).unwrap(); + let mut t0 = create_token(&p, dir.path().join("a.zst")).unwrap(); + let mut t1 = create_token(&p, dir.path().join("b.zst")).unwrap(); + let mut t2 = create_token(&p, dir.path().join("c.zst")).unwrap(); open_token(&mut t0, &sk); // t0 open open_token(&mut t1, &sk); // t1 open, t0 LRU @@ -560,6 +557,7 @@ mod tests { #[test] fn 
close_all_produces_readable_files() { + set_k(TEST_K); let dir = TempDir::new().unwrap(); let p = pool(8); let paths: Vec<_> = (0..4) @@ -568,7 +566,7 @@ mod tests { let mut tokens: Vec = paths .iter() - .map(|path| create_token(&p, path.clone(), TEST_K).unwrap()) + .map(|path| create_token(&p, path.clone()).unwrap()) .collect(); for (i, t) in tokens.iter_mut().enumerate() { @@ -581,7 +579,7 @@ mod tests { p.lock().unwrap().close_all().unwrap(); for path in &paths { - let mut r = SKFileReader::open(path, TEST_K).unwrap(); + let mut r = SKFileReader::open(path).unwrap(); let got = r.read_batch(10).unwrap(); assert_eq!(got.len(), 1); } @@ -589,16 +587,17 @@ mod tests { #[test] fn write_batch_roundtrip() { + set_k(TEST_K); let dir = TempDir::new().unwrap(); let p = pool(4); let sks: Vec<_> = (0..50).map(make_sk).collect(); let path = dir.path().join("batch.zst"); - let mut t = create_token(&p, path.clone(), TEST_K).unwrap(); + let mut t = create_token(&p, path.clone()).unwrap(); t.write_batch(&sks).unwrap(); t.close().unwrap(); - let mut r = SKFileReader::open(&path, TEST_K).unwrap(); + let mut r = SKFileReader::open(&path).unwrap(); let got = r.read_batch(100).unwrap(); assert_eq!(got.len(), 50); for (a, b) in sks.iter().zip(got.iter()) { @@ -608,6 +607,7 @@ mod tests { #[test] fn from_system_limits_bounded() { + set_k(TEST_K); let pool = SKFilePool::from_system_limits(); assert!(pool.max_open() >= 16); assert!(pool.max_open() <= MAX_POOL_SIZE); @@ -615,14 +615,15 @@ mod tests { #[test] fn standalone_roundtrip_zstd() { + set_k(TEST_K); let tmp = NamedTempFile::new().unwrap(); let sks: Vec<_> = (0..100).map(make_sk).collect(); { - let mut w = SKFileWriter::create(tmp.path(), TEST_K).unwrap(); + let mut w = SKFileWriter::create(tmp.path()).unwrap(); w.write_batch(&sks).unwrap(); w.close().unwrap(); } - let mut r = SKFileReader::open(tmp.path(), TEST_K).unwrap(); + let mut r = SKFileReader::open(tmp.path()).unwrap(); let got = r.read_batch(200).unwrap(); 
assert_eq!(got.len(), 100); for (a, b) in sks.iter().zip(got.iter()) { @@ -632,8 +633,9 @@ mod tests { #[test] fn standalone_close_prevents_write() { + set_k(TEST_K); let tmp = NamedTempFile::new().unwrap(); - let mut w = SKFileWriter::create(tmp.path(), TEST_K).unwrap(); + let mut w = SKFileWriter::create(tmp.path()).unwrap(); w.close().unwrap(); assert!(!w.is_open()); assert!(w.write(&make_sk(0)).is_err()); @@ -641,8 +643,9 @@ mod tests { #[test] fn standalone_is_physically_open() { + set_k(TEST_K); let tmp = NamedTempFile::new().unwrap(); - let mut w = SKFileWriter::create(tmp.path(), TEST_K).unwrap(); + let mut w = SKFileWriter::create(tmp.path()).unwrap(); assert!(!w.is_physically_open()); // fd deferred until first drain w.set_flush_threshold(1); w.write(&make_sk(0)).unwrap(); // triggers drain → fd opened diff --git a/src/obiskio/src/reader.rs b/src/obiskio/src/reader.rs index 9f62ec9..d523ab8 100644 --- a/src/obiskio/src/reader.rs +++ b/src/obiskio/src/reader.rs @@ -15,25 +15,20 @@ use std::path::{Path, PathBuf}; /// that it can fast-forward on next open. pub struct SKFileReader { path: PathBuf, - k: usize, reader: Option>, - /// Reusable scratch buffer for the `seq` bytes of each record. - seq_buf: Vec, /// Number of SuperKmers successfully read so far (for eviction recovery). consumed: u64, } impl SKFileReader { /// Open a file for reading. Format is auto-detected from magic bytes. - /// `k` is the kmer size of the partition; required to decode the on-disk n_kmers field. 
- pub fn open>(path: P, k: usize) -> SKResult { + pub fn open>(path: P) -> SKResult { let path = path.as_ref().to_owned(); - let (reader, _fmt) = niffler::send::get_reader(Box::new(BufReader::new(File::open(&path)?)))?; + let (reader, _fmt) = + niffler::send::get_reader(Box::new(BufReader::new(File::open(&path)?)))?; Ok(Self { path, - k, reader: Some(reader), - seq_buf: Vec::with_capacity(64), consumed: 0, }) } @@ -46,7 +41,7 @@ impl SKFileReader { "read from physically closed SKFileReader", ) })?; - let result = read_superkmer(r, &mut self.seq_buf, self.k)?; + let result = read_superkmer(r)?; if result.is_some() { self.consumed += 1; } @@ -87,7 +82,10 @@ impl SKFileReader { /// Return an iterator over this reader. pub fn iter(&mut self) -> SKFileIter<'_> { - SKFileIter { reader: self, error: None } + SKFileIter { + reader: self, + error: None, + } } // ── pool-internal helpers ───────────────────────────────────────────────── @@ -103,7 +101,7 @@ impl SKFileReader { let target = self.consumed; self.consumed = 0; for _ in 0..target { - match read_superkmer(self.reader.as_mut().unwrap(), &mut self.seq_buf, self.k)? { + match read_superkmer(self.reader.as_mut().unwrap())? 
{ Some(_) => self.consumed += 1, None => break, } @@ -152,6 +150,10 @@ mod tests { const TEST_K: usize = 4; // test sequences are 8 bases; k=4 gives n_kmers=5 + fn setup() { + obikseq::params::set_k(TEST_K); + } + fn make_sks(n: usize) -> Vec { (0..n) .map(|i| { @@ -163,15 +165,16 @@ mod tests { #[test] fn iter_all() { + setup(); let tmp = NamedTempFile::new().unwrap(); let sks = make_sks(50); { - let mut w = SKFileWriter::create(tmp.path(), TEST_K).unwrap(); + let mut w = SKFileWriter::create(tmp.path()).unwrap(); w.write_batch(&sks).unwrap(); } - let mut r = SKFileReader::open(tmp.path(), TEST_K).unwrap(); + let mut r = SKFileReader::open(tmp.path()).unwrap(); let got: Vec<_> = r.iter().collect(); assert_eq!(got.len(), 50); for (a, b) in sks.iter().zip(got.iter()) { @@ -181,15 +184,16 @@ mod tests { #[test] fn reopen_and_seek() { + setup(); let tmp = NamedTempFile::new().unwrap(); let sks = make_sks(20); { - let mut w = SKFileWriter::create(tmp.path(), TEST_K).unwrap(); + let mut w = SKFileWriter::create(tmp.path()).unwrap(); w.write_batch(&sks).unwrap(); } - let mut r = SKFileReader::open(tmp.path(), TEST_K).unwrap(); + let mut r = SKFileReader::open(tmp.path()).unwrap(); // Read 10, then simulate pool eviction + re-access let first = r.read_batch(10).unwrap(); r.close();