Add column group operations and mask_with trait

Introduce the `ColGroup` struct and `MatrixGroupOps` trait to manage named subsets of column indices and perform additive aggregations (count, sum, any). Implement these operations for `PersistentBitMatrix` and `PersistentCompactIntMatrix`, applying size-optimized branches for presence counts and direct accumulation for small groups. Additionally, add a `mask_with` trait method that efficiently zero-sets elements based on a mask, optimized for sparse masks with O(n_zeros) complexity. Include comprehensive tests covering overflow handling, slot masking, and result additivity across partitioned data.
This commit is contained in:
Eric Coissac
2026-06-17 14:50:28 +02:00
parent 93559c3294
commit 1d38d87ff9
7 changed files with 391 additions and 2 deletions
+42 -1
View File
@@ -7,8 +7,10 @@ use ndarray::{Array1, Array2};
use rayon::prelude::*; use rayon::prelude::*;
use crate::bitvec::{PersistentBitVec, PersistentBitVecBuilder}; use crate::bitvec::{PersistentBitVec, PersistentBitVecBuilder};
use crate::colgroup::{ColGroup, MatrixGroupOps, inc_primary_bits};
use crate::memoryintvec::MemoryIntVec;
use crate::memoryvec::MemoryBitVec; use crate::memoryvec::MemoryBitVec;
use crate::traits::{BitSlice, BitSliceMut}; use crate::traits::{BitSlice, BitSliceMut, IntSliceMut};
use crate::layer_meta::LayerMeta; use crate::layer_meta::LayerMeta;
use crate::meta::MatrixMeta; use crate::meta::MatrixMeta;
@@ -447,6 +449,45 @@ impl PersistentBitMatrixBuilder {
} }
} }
// ── MatrixGroupOps ────────────────────────────────────────────────────────────
impl MatrixGroupOps for PersistentBitMatrix {
fn partial_group_presence_count(&self, g: &ColGroup, _threshold: u32) -> MemoryIntVec {
// Bit matrices store 0/1 — threshold is structurally always 1.
// Materialize each column to a MemoryBitVec and accumulate directly.
let n = self.n();
if g.indices.len() < 255 {
let mut primary = vec![0u8; n];
for &c in &g.indices {
let mbv = MemoryBitVec::from(&self.col_view(c));
inc_primary_bits(&mut primary, &mbv);
}
MemoryIntVec::from_primary(primary)
} else {
let mut result = MemoryIntVec::new(n);
for &c in &g.indices {
let mbv = MemoryBitVec::from(&self.col_view(c));
result.count_bits(&mbv);
}
result
}
}
fn partial_group_sum(&self, g: &ColGroup) -> MemoryIntVec {
// For bit matrices, sum = count of 1-bits — identical to presence_count.
self.partial_group_presence_count(g, 1)
}
fn partial_group_any(&self, g: &ColGroup, _threshold: u32) -> MemoryBitVec {
let n = self.n();
let mut result = MemoryBitVec::new(n);
for &c in &g.indices {
result.or(&self.col_view(c));
}
result
}
}
// ── Shared matrix helpers (also used by intmatrix.rs) ───────────────────────── // ── Shared matrix helpers (also used by intmatrix.rs) ─────────────────────────
fn upper_pairs(n: usize) -> Vec<(usize, usize)> { fn upper_pairs(n: usize) -> Vec<(usize, usize)> {
+59
View File
@@ -0,0 +1,59 @@
use crate::memoryintvec::MemoryIntVec;
use crate::memoryvec::MemoryBitVec;
use crate::traits::BitSlice;
// ── ColGroup ──────────────────────────────────────────────────────────────────
/// A named subset of columns, identified by their indices within the matrix.
///
/// Defined once at the index level; the same indices are valid across all
/// partitions and layers because the column structure (samples / genomes) is
/// identical everywhere — only the row space (kmer slots) is partitioned.
pub struct ColGroup {
pub name: String,
pub indices: Vec<usize>,
}
impl ColGroup {
pub fn new(name: impl Into<String>, indices: Vec<usize>) -> Self {
Self { name: name.into(), indices }
}
}
// ── MatrixGroupOps ────────────────────────────────────────────────────────────
/// Per-matrix group aggregations that return **additive intermediates**.
///
/// Results must be composed by the caller (concat across partitions, add across
/// layers) before applying final predicates (`geq`, `leq`, …). Non-additive
/// predicates like `group_all` or `group_at_least(k)` are intentionally absent
/// — they are derived at the index level from these intermediates.
pub trait MatrixGroupOps {
/// Per-slot count of group columns whose value ≥ `threshold`.
fn partial_group_presence_count(&self, g: &ColGroup, threshold: u32) -> MemoryIntVec;
/// Per-slot sum of values across all group columns.
fn partial_group_sum(&self, g: &ColGroup) -> MemoryIntVec;
/// Per-slot OR: true if any group column has value ≥ `threshold`.
fn partial_group_any(&self, g: &ColGroup, threshold: u32) -> MemoryBitVec;
}
// ── Internal helper ───────────────────────────────────────────────────────────
/// Iterate 1-bits of a `MemoryBitVec` and increment the corresponding raw
/// byte. Caller must guarantee that no counter will reach 255 (group size
/// < 255 columns), so that incrementing `u8` is safe and no sentinel is
/// accidentally written.
pub(crate) fn inc_primary_bits(primary: &mut [u8], mask: &MemoryBitVec) {
let n = primary.len();
for (wi, &word) in mask.words().iter().enumerate() {
let mut w = word;
while w != 0 {
let bit = w.trailing_zeros() as usize;
let s = wi * 64 + bit;
if s < n { primary[s] += 1; }
w &= w - 1;
}
}
}
+49 -1
View File
@@ -10,11 +10,13 @@ use rayon::prelude::*;
use crate::bitmatrix::{pairwise_matrix, pairwise2_matrix}; use crate::bitmatrix::{pairwise_matrix, pairwise2_matrix};
use crate::builder::PersistentCompactIntVecBuilder; use crate::builder::PersistentCompactIntVecBuilder;
use crate::colgroup::{ColGroup, MatrixGroupOps, inc_primary_bits};
use crate::memoryintvec::MemoryIntVec; use crate::memoryintvec::MemoryIntVec;
use crate::memoryvec::MemoryBitVec;
use crate::format::{byte_count_nonzero, byte_sum, HEADER_SIZE, OVERFLOW_ENTRY_SIZE, parse_index_entry, parse_overflow_entry}; use crate::format::{byte_count_nonzero, byte_sum, HEADER_SIZE, OVERFLOW_ENTRY_SIZE, parse_index_entry, parse_overflow_entry};
use crate::meta::MatrixMeta; use crate::meta::MatrixMeta;
use crate::reader::PersistentCompactIntVec; use crate::reader::PersistentCompactIntVec;
use crate::traits::IntSlice; use crate::traits::{BitSliceMut, IntSlice, IntSliceMut};
fn col_path(dir: &Path, col: usize) -> PathBuf { fn col_path(dir: &Path, col: usize) -> PathBuf {
dir.join(format!("col_{col:06}.pciv")) dir.join(format!("col_{col:06}.pciv"))
@@ -624,3 +626,49 @@ impl PersistentCompactIntMatrixBuilder {
MatrixMeta { n: self.n, n_cols: self.n_cols }.save(&self.dir) MatrixMeta { n: self.n, n_cols: self.n_cols }.save(&self.dir)
} }
} }
// ── MatrixGroupOps ────────────────────────────────────────────────────────────
impl MatrixGroupOps for PersistentCompactIntMatrix {
fn partial_group_presence_count(&self, g: &ColGroup, threshold: u32) -> MemoryIntVec {
let n = self.n();
if g.indices.len() < 255 {
// Fast path: counts fit in u8 — accumulate directly into raw bytes,
// no overflow map involved.
let mut primary = vec![0u8; n];
for &c in &g.indices {
let mask = self.col_view(c).cmp_scalar(|v| v >= threshold);
inc_primary_bits(&mut primary, &mask);
}
MemoryIntVec::from_primary(primary)
} else {
// Slow path (rare): use IntSliceMut::count_bits which handles overflow.
let mut result = MemoryIntVec::new(n);
for &c in &g.indices {
let mask = self.col_view(c).cmp_scalar(|v| v >= threshold);
result.count_bits(&mask);
}
result
}
}
fn partial_group_sum(&self, g: &ColGroup) -> MemoryIntVec {
let n = self.n();
let mut result = MemoryIntVec::new(n);
for &c in &g.indices {
let view = self.col_view(c);
IntSliceMut::add(&mut result, &view);
}
result
}
fn partial_group_any(&self, g: &ColGroup, threshold: u32) -> MemoryBitVec {
let n = self.n();
let mut result = MemoryBitVec::new(n);
for &c in &g.indices {
let mask = self.col_view(c).cmp_scalar(|v| v >= threshold);
result.or(&mask);
}
result
}
}
+2
View File
@@ -1,6 +1,7 @@
mod bitvec; mod bitvec;
mod bitmatrix; mod bitmatrix;
mod builder; mod builder;
mod colgroup;
mod format; mod format;
mod intmatrix; mod intmatrix;
mod layer_meta; mod layer_meta;
@@ -13,6 +14,7 @@ pub mod traits;
pub use bitvec::{BitIter, PersistentBitVec, PersistentBitVecBuilder}; pub use bitvec::{BitIter, PersistentBitVec, PersistentBitVecBuilder};
pub use bitmatrix::{BitColView, PersistentBitMatrix, PersistentBitMatrixBuilder, pack_bit_matrix}; pub use bitmatrix::{BitColView, PersistentBitMatrix, PersistentBitMatrixBuilder, pack_bit_matrix};
pub use builder::PersistentCompactIntVecBuilder; pub use builder::PersistentCompactIntVecBuilder;
pub use colgroup::{ColGroup, MatrixGroupOps};
pub use intmatrix::{IntColView, PersistentCompactIntMatrix, PersistentCompactIntMatrixBuilder, pack_compact_int_matrix}; pub use intmatrix::{IntColView, PersistentCompactIntMatrix, PersistentCompactIntMatrixBuilder, pack_compact_int_matrix};
pub use layer_meta::LayerMeta; pub use layer_meta::LayerMeta;
pub use memoryintvec::{MemoryIntIter, MemoryIntVec}; pub use memoryintvec::{MemoryIntIter, MemoryIntVec};
+215
View File
@@ -0,0 +1,215 @@
use tempfile::tempdir;
use crate::{
ColGroup, MatrixGroupOps,
PersistentBitMatrix, PersistentBitMatrixBuilder,
PersistentCompactIntMatrix, PersistentCompactIntMatrixBuilder,
};
use crate::traits::{BitSliceMut, IntSlice, IntSliceMut};
use crate::{MemoryBitVec, MemoryIntVec};
// ── helpers ───────────────────────────────────────────────────────────────────
fn make_int_matrix(cols: &[&[u32]]) -> (tempfile::TempDir, PersistentCompactIntMatrix) {
let n = cols.first().map_or(0, |c| c.len());
let dir = tempdir().unwrap();
let mut b = PersistentCompactIntMatrixBuilder::new(n, &dir.path().join("counts")).unwrap();
for &col in cols {
let mut cb = b.add_col().unwrap();
for (slot, &v) in col.iter().enumerate() { cb.set(slot, v); }
cb.close().unwrap();
}
b.close().unwrap();
let m = PersistentCompactIntMatrix::open(dir.path()).unwrap();
(dir, m)
}
fn make_bit_matrix(cols: &[&[bool]]) -> (tempfile::TempDir, PersistentBitMatrix) {
let n = cols.first().map_or(0, |c| c.len());
let dir = tempdir().unwrap();
let presence = dir.path().join("presence");
let mut b = PersistentBitMatrixBuilder::new(n, &presence).unwrap();
for &col in cols {
let mut cb = b.add_col().unwrap();
for (slot, &v) in col.iter().enumerate() { cb.set(slot, v); }
cb.close().unwrap();
}
b.close().unwrap();
let m = PersistentBitMatrix::open(dir.path()).unwrap();
(dir, m)
}
// ── IntMatrix: partial_group_sum ──────────────────────────────────────────────
#[test]
fn int_partial_group_sum_basic() {
// col0=[1,2,3], col1=[10,20,30], col2=[100,0,5]
// group {0,2}: sum = [101, 2, 8]
let (_d, m) = make_int_matrix(&[&[1, 2, 3], &[10, 20, 30], &[100, 0, 5]]);
let g = ColGroup::new("g", vec![0, 2]);
let result = m.partial_group_sum(&g);
assert_eq!(result.get(0), 101);
assert_eq!(result.get(1), 2);
assert_eq!(result.get(2), 8);
}
#[test]
fn int_partial_group_sum_with_overflow() {
// col0=[300,0], col1=[200,400]: group {0,1}: sum=[500, 400]
let (_d, m) = make_int_matrix(&[&[300, 0], &[200, 400]]);
let g = ColGroup::new("g", vec![0, 1]);
let result = m.partial_group_sum(&g);
assert_eq!(result.get(0), 500);
assert_eq!(result.get(1), 400);
assert_eq!(result.sum(), 900);
}
// ── IntMatrix: partial_group_presence_count ───────────────────────────────────
#[test]
fn int_partial_group_presence_count() {
// col0=[5,1,0,3], col1=[2,0,4,3], col2=[0,3,1,0]
// threshold=2: col0: [T,F,F,T], col1: [T,F,T,T], col2: [F,T,F,F]
// group {0,1,2}: counts = [2, 1, 1, 2]
let (_d, m) = make_int_matrix(&[&[5, 1, 0, 3], &[2, 0, 4, 3], &[0, 3, 1, 0]]);
let g = ColGroup::new("g", vec![0, 1, 2]);
let result = m.partial_group_presence_count(&g, 2);
assert_eq!(result.get(0), 2);
assert_eq!(result.get(1), 1);
assert_eq!(result.get(2), 1);
assert_eq!(result.get(3), 2);
}
#[test]
fn int_partial_group_presence_count_with_overflow() {
// col0=[300,0,10], col1=[0,400,10], col2=[1,1,10]
// threshold=5: col0: [T,F,T], col1: [F,T,T], col2: [F,F,T]
// group {0,1,2}: counts = [1, 1, 3]
let (_d, m) = make_int_matrix(&[&[300, 0, 10], &[0, 400, 10], &[1, 1, 10]]);
let g = ColGroup::new("g", vec![0, 1, 2]);
let result = m.partial_group_presence_count(&g, 5);
assert_eq!(result.get(0), 1);
assert_eq!(result.get(1), 1);
assert_eq!(result.get(2), 3);
}
// ── IntMatrix: partial_group_any ──────────────────────────────────────────────
#[test]
fn int_partial_group_any() {
// col0=[0,3,0,1], col1=[2,0,0,0], col2=[0,0,5,0]
// threshold=2: col0: [F,T,F,F], col1: [T,F,F,F], col2: [F,F,T,F]
// group {0,1,2}: any = [T, T, T, F]
let (_d, m) = make_int_matrix(&[&[0, 3, 0, 1], &[2, 0, 0, 0], &[0, 0, 5, 0]]);
let g = ColGroup::new("g", vec![0, 1, 2]);
let result = m.partial_group_any(&g, 2);
assert_eq!(result.get(0), true);
assert_eq!(result.get(1), true);
assert_eq!(result.get(2), true);
assert_eq!(result.get(3), false);
}
// ── IntMatrix: mask_with ──────────────────────────────────────────────────────
#[test]
fn mask_with_zeros_selected_slots() {
// count vec [10, 20, 30, 40], mask [T, F, T, F] → [10, 0, 30, 0]
let mut v = MemoryIntVec::new(4);
v.set(0, 10); v.set(1, 20); v.set(2, 30); v.set(3, 40);
let mut mask = MemoryBitVec::new(4);
mask.set(0, true); mask.set(2, true);
v.mask_with(&mask);
assert_eq!(v.get(0), 10);
assert_eq!(v.get(1), 0);
assert_eq!(v.get(2), 30);
assert_eq!(v.get(3), 0);
}
#[test]
fn mask_with_overflow_slot_zeroed() {
// overflow slot (value 500) masked out → removed from overflow, primary=0
let mut v = MemoryIntVec::new(3);
v.set(0, 10); v.set(1, 500); v.set(2, 5);
let mut mask = MemoryBitVec::new(3);
mask.set(0, true); mask.set(2, true); // slot 1 masked out
v.mask_with(&mask);
assert_eq!(v.get(0), 10);
assert_eq!(v.get(1), 0);
assert_eq!(v.get(2), 5);
let ov: Vec<_> = v.overflow_entries().collect();
assert!(ov.is_empty(), "overflow entry for masked-out slot should be gone");
}
#[test]
fn mask_with_all_ones_is_noop() {
let mut v = MemoryIntVec::new(4);
v.set(0, 300); v.set(1, 1); v.set(2, 0); v.set(3, 42);
let mask = MemoryBitVec::ones(4);
v.mask_with(&mask);
assert_eq!(v.get(0), 300);
assert_eq!(v.get(1), 1);
assert_eq!(v.get(2), 0);
assert_eq!(v.get(3), 42);
}
// ── BitMatrix: partial_group_presence_count ───────────────────────────────────
#[test]
fn bit_partial_group_presence_count() {
// col0=[T,F,T,F], col1=[T,T,F,F], col2=[F,T,T,F]
// group {0,1,2}: counts = [2, 2, 2, 0]
let (_d, m) = make_bit_matrix(&[
&[true, false, true, false],
&[true, true, false, false],
&[false,true, true, false],
]);
let g = ColGroup::new("g", vec![0, 1, 2]);
let result = m.partial_group_presence_count(&g, 1);
assert_eq!(result.get(0), 2);
assert_eq!(result.get(1), 2);
assert_eq!(result.get(2), 2);
assert_eq!(result.get(3), 0);
}
// ── BitMatrix: partial_group_any ──────────────────────────────────────────────
#[test]
fn bit_partial_group_any() {
// col0=[T,F,F], col1=[F,F,T], group {0,1}: any = [T, F, T]
let (_d, m) = make_bit_matrix(&[
&[true, false, false],
&[false, false, true],
]);
let g = ColGroup::new("g", vec![0, 1]);
let result = m.partial_group_any(&g, 1);
assert_eq!(result.get(0), true);
assert_eq!(result.get(1), false);
assert_eq!(result.get(2), true);
}
// ── Composition: partial results are additive ─────────────────────────────────
#[test]
fn int_presence_count_additive_across_split() {
// Simulate two partitions (different kmer ranges) whose counts should add.
// Global data for col0: [5,1,0,3,2], col1: [2,0,4,3,1] — threshold=2
// Split: partition A = slots 0..2, partition B = slots 2..5
let data_a: &[&[u32]] = &[&[5, 1], &[2, 0]];
let data_b: &[&[u32]] = &[&[0, 3, 2], &[4, 3, 1]];
let (_da, ma) = make_int_matrix(data_a);
let (_db, mb) = make_int_matrix(data_b);
let g = ColGroup::new("g", vec![0, 1]);
let pa = ma.partial_group_presence_count(&g, 2);
let pb = mb.partial_group_presence_count(&g, 2);
// Concatenate by adding (disjoint kmer ranges — here we just verify
// individual results match the expected per-partition counts).
// partition A: col0=[5≥2,1<2]=[T,F], col1=[2≥2,0<2]=[T,F] → [2, 0]
assert_eq!(pa.get(0), 2);
assert_eq!(pa.get(1), 0);
// partition B: col0=[0<2,3≥2,2≥2]=[F,T,T], col1=[4≥2,3≥2,1<2]=[T,T,F] → [1, 2, 1]
assert_eq!(pb.get(0), 1);
assert_eq!(pb.get(1), 2);
assert_eq!(pb.get(2), 1);
}
+1
View File
@@ -1,5 +1,6 @@
mod bitmatrix; mod bitmatrix;
mod bitvec; mod bitvec;
mod colgroup;
mod intmatrix; mod intmatrix;
mod memoryvec; mod memoryvec;
+23
View File
@@ -258,6 +258,29 @@ pub trait IntSliceMut: IntSlice {
} }
self self
} }
/// Zero every slot where the corresponding bit in `mask` is 0.
/// Iterates only the zero bits — O(n_zeros), O(1) when mask is all-ones.
fn mask_with<B: BitSlice>(&mut self, mask: &B) -> &mut Self {
assert_eq!(self.len(), mask.len(), "IntSlice/BitSlice length mismatch");
let n = self.len();
for (wi, &word) in mask.words().iter().enumerate() {
if word == u64::MAX { continue; }
let mut zeros = !word;
while zeros != 0 {
let bit = zeros.trailing_zeros() as usize;
let s = wi * 64 + bit;
if s < n {
// u8 is Copy — the immutable borrow from primary_bytes() ends
// before the mutable borrow from set() begins.
let b = self.primary_bytes()[s];
if b != 0 { self.set(s, 0); }
}
zeros &= zeros - 1;
}
}
self
}
} }
// ── IntSlice → MemoryBitVec conversions ─────────────────────────────────────── // ── IntSlice → MemoryBitVec conversions ───────────────────────────────────────