Add column group operations and mask_with trait
Introduce the `ColGroup` struct and `MatrixGroupOps` trait to manage named subsets of column indices and perform additive aggregations (count, sum, any). Implement these operations for `PersistentBitMatrix` and `PersistentCompactIntMatrix`, applying size-optimized branches for presence counts and direct accumulation for small groups. Additionally, add a `mask_with` trait method that efficiently zero-sets elements based on a mask, optimized for sparse masks with O(n_zeros) complexity. Include comprehensive tests covering overflow handling, slot masking, and result additivity across partitioned data.
This commit is contained in:
@@ -7,8 +7,10 @@ use ndarray::{Array1, Array2};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use crate::bitvec::{PersistentBitVec, PersistentBitVecBuilder};
|
||||
use crate::colgroup::{ColGroup, MatrixGroupOps, inc_primary_bits};
|
||||
use crate::memoryintvec::MemoryIntVec;
|
||||
use crate::memoryvec::MemoryBitVec;
|
||||
use crate::traits::{BitSlice, BitSliceMut};
|
||||
use crate::traits::{BitSlice, BitSliceMut, IntSliceMut};
|
||||
use crate::layer_meta::LayerMeta;
|
||||
use crate::meta::MatrixMeta;
|
||||
|
||||
@@ -447,6 +449,45 @@ impl PersistentBitMatrixBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
// ── MatrixGroupOps ────────────────────────────────────────────────────────────
|
||||
|
||||
impl MatrixGroupOps for PersistentBitMatrix {
|
||||
fn partial_group_presence_count(&self, g: &ColGroup, _threshold: u32) -> MemoryIntVec {
|
||||
// Bit matrices store 0/1 — threshold is structurally always 1.
|
||||
// Materialize each column to a MemoryBitVec and accumulate directly.
|
||||
let n = self.n();
|
||||
if g.indices.len() < 255 {
|
||||
let mut primary = vec![0u8; n];
|
||||
for &c in &g.indices {
|
||||
let mbv = MemoryBitVec::from(&self.col_view(c));
|
||||
inc_primary_bits(&mut primary, &mbv);
|
||||
}
|
||||
MemoryIntVec::from_primary(primary)
|
||||
} else {
|
||||
let mut result = MemoryIntVec::new(n);
|
||||
for &c in &g.indices {
|
||||
let mbv = MemoryBitVec::from(&self.col_view(c));
|
||||
result.count_bits(&mbv);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
fn partial_group_sum(&self, g: &ColGroup) -> MemoryIntVec {
|
||||
// For bit matrices, sum = count of 1-bits — identical to presence_count.
|
||||
self.partial_group_presence_count(g, 1)
|
||||
}
|
||||
|
||||
fn partial_group_any(&self, g: &ColGroup, _threshold: u32) -> MemoryBitVec {
|
||||
let n = self.n();
|
||||
let mut result = MemoryBitVec::new(n);
|
||||
for &c in &g.indices {
|
||||
result.or(&self.col_view(c));
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
// ── Shared matrix helpers (also used by intmatrix.rs) ─────────────────────────
|
||||
|
||||
fn upper_pairs(n: usize) -> Vec<(usize, usize)> {
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
use crate::memoryintvec::MemoryIntVec;
|
||||
use crate::memoryvec::MemoryBitVec;
|
||||
use crate::traits::BitSlice;
|
||||
|
||||
// ── ColGroup ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// A named subset of columns, identified by their indices within the matrix.
|
||||
///
|
||||
/// Defined once at the index level; the same indices are valid across all
|
||||
/// partitions and layers because the column structure (samples / genomes) is
|
||||
/// identical everywhere — only the row space (kmer slots) is partitioned.
|
||||
pub struct ColGroup {
|
||||
pub name: String,
|
||||
pub indices: Vec<usize>,
|
||||
}
|
||||
|
||||
impl ColGroup {
|
||||
pub fn new(name: impl Into<String>, indices: Vec<usize>) -> Self {
|
||||
Self { name: name.into(), indices }
|
||||
}
|
||||
}
|
||||
|
||||
// ── MatrixGroupOps ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Per-matrix group aggregations that return **additive intermediates**.
|
||||
///
|
||||
/// Results must be composed by the caller (concat across partitions, add across
|
||||
/// layers) before applying final predicates (`geq`, `leq`, …). Non-additive
|
||||
/// predicates like `group_all` or `group_at_least(k)` are intentionally absent
|
||||
/// — they are derived at the index level from these intermediates.
|
||||
pub trait MatrixGroupOps {
|
||||
/// Per-slot count of group columns whose value ≥ `threshold`.
|
||||
fn partial_group_presence_count(&self, g: &ColGroup, threshold: u32) -> MemoryIntVec;
|
||||
|
||||
/// Per-slot sum of values across all group columns.
|
||||
fn partial_group_sum(&self, g: &ColGroup) -> MemoryIntVec;
|
||||
|
||||
/// Per-slot OR: true if any group column has value ≥ `threshold`.
|
||||
fn partial_group_any(&self, g: &ColGroup, threshold: u32) -> MemoryBitVec;
|
||||
}
|
||||
|
||||
// ── Internal helper ───────────────────────────────────────────────────────────
|
||||
|
||||
/// Iterate 1-bits of a `MemoryBitVec` and increment the corresponding raw
|
||||
/// byte. Caller must guarantee that no counter will reach 255 (group size
|
||||
/// < 255 columns), so that incrementing `u8` is safe and no sentinel is
|
||||
/// accidentally written.
|
||||
pub(crate) fn inc_primary_bits(primary: &mut [u8], mask: &MemoryBitVec) {
|
||||
let n = primary.len();
|
||||
for (wi, &word) in mask.words().iter().enumerate() {
|
||||
let mut w = word;
|
||||
while w != 0 {
|
||||
let bit = w.trailing_zeros() as usize;
|
||||
let s = wi * 64 + bit;
|
||||
if s < n { primary[s] += 1; }
|
||||
w &= w - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,11 +10,13 @@ use rayon::prelude::*;
|
||||
|
||||
use crate::bitmatrix::{pairwise_matrix, pairwise2_matrix};
|
||||
use crate::builder::PersistentCompactIntVecBuilder;
|
||||
use crate::colgroup::{ColGroup, MatrixGroupOps, inc_primary_bits};
|
||||
use crate::memoryintvec::MemoryIntVec;
|
||||
use crate::memoryvec::MemoryBitVec;
|
||||
use crate::format::{byte_count_nonzero, byte_sum, HEADER_SIZE, OVERFLOW_ENTRY_SIZE, parse_index_entry, parse_overflow_entry};
|
||||
use crate::meta::MatrixMeta;
|
||||
use crate::reader::PersistentCompactIntVec;
|
||||
use crate::traits::IntSlice;
|
||||
use crate::traits::{BitSliceMut, IntSlice, IntSliceMut};
|
||||
|
||||
fn col_path(dir: &Path, col: usize) -> PathBuf {
|
||||
dir.join(format!("col_{col:06}.pciv"))
|
||||
@@ -624,3 +626,49 @@ impl PersistentCompactIntMatrixBuilder {
|
||||
MatrixMeta { n: self.n, n_cols: self.n_cols }.save(&self.dir)
|
||||
}
|
||||
}
|
||||
|
||||
// ── MatrixGroupOps ────────────────────────────────────────────────────────────
|
||||
|
||||
impl MatrixGroupOps for PersistentCompactIntMatrix {
|
||||
fn partial_group_presence_count(&self, g: &ColGroup, threshold: u32) -> MemoryIntVec {
|
||||
let n = self.n();
|
||||
if g.indices.len() < 255 {
|
||||
// Fast path: counts fit in u8 — accumulate directly into raw bytes,
|
||||
// no overflow map involved.
|
||||
let mut primary = vec![0u8; n];
|
||||
for &c in &g.indices {
|
||||
let mask = self.col_view(c).cmp_scalar(|v| v >= threshold);
|
||||
inc_primary_bits(&mut primary, &mask);
|
||||
}
|
||||
MemoryIntVec::from_primary(primary)
|
||||
} else {
|
||||
// Slow path (rare): use IntSliceMut::count_bits which handles overflow.
|
||||
let mut result = MemoryIntVec::new(n);
|
||||
for &c in &g.indices {
|
||||
let mask = self.col_view(c).cmp_scalar(|v| v >= threshold);
|
||||
result.count_bits(&mask);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
fn partial_group_sum(&self, g: &ColGroup) -> MemoryIntVec {
|
||||
let n = self.n();
|
||||
let mut result = MemoryIntVec::new(n);
|
||||
for &c in &g.indices {
|
||||
let view = self.col_view(c);
|
||||
IntSliceMut::add(&mut result, &view);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn partial_group_any(&self, g: &ColGroup, threshold: u32) -> MemoryBitVec {
|
||||
let n = self.n();
|
||||
let mut result = MemoryBitVec::new(n);
|
||||
for &c in &g.indices {
|
||||
let mask = self.col_view(c).cmp_scalar(|v| v >= threshold);
|
||||
result.or(&mask);
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod bitvec;
|
||||
mod bitmatrix;
|
||||
mod builder;
|
||||
mod colgroup;
|
||||
mod format;
|
||||
mod intmatrix;
|
||||
mod layer_meta;
|
||||
@@ -13,6 +14,7 @@ pub mod traits;
|
||||
pub use bitvec::{BitIter, PersistentBitVec, PersistentBitVecBuilder};
|
||||
pub use bitmatrix::{BitColView, PersistentBitMatrix, PersistentBitMatrixBuilder, pack_bit_matrix};
|
||||
pub use builder::PersistentCompactIntVecBuilder;
|
||||
pub use colgroup::{ColGroup, MatrixGroupOps};
|
||||
pub use intmatrix::{IntColView, PersistentCompactIntMatrix, PersistentCompactIntMatrixBuilder, pack_compact_int_matrix};
|
||||
pub use layer_meta::LayerMeta;
|
||||
pub use memoryintvec::{MemoryIntIter, MemoryIntVec};
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
use tempfile::tempdir;
|
||||
|
||||
use crate::{
|
||||
ColGroup, MatrixGroupOps,
|
||||
PersistentBitMatrix, PersistentBitMatrixBuilder,
|
||||
PersistentCompactIntMatrix, PersistentCompactIntMatrixBuilder,
|
||||
};
|
||||
use crate::traits::{BitSliceMut, IntSlice, IntSliceMut};
|
||||
use crate::{MemoryBitVec, MemoryIntVec};
|
||||
|
||||
// ── helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn make_int_matrix(cols: &[&[u32]]) -> (tempfile::TempDir, PersistentCompactIntMatrix) {
|
||||
let n = cols.first().map_or(0, |c| c.len());
|
||||
let dir = tempdir().unwrap();
|
||||
let mut b = PersistentCompactIntMatrixBuilder::new(n, &dir.path().join("counts")).unwrap();
|
||||
for &col in cols {
|
||||
let mut cb = b.add_col().unwrap();
|
||||
for (slot, &v) in col.iter().enumerate() { cb.set(slot, v); }
|
||||
cb.close().unwrap();
|
||||
}
|
||||
b.close().unwrap();
|
||||
let m = PersistentCompactIntMatrix::open(dir.path()).unwrap();
|
||||
(dir, m)
|
||||
}
|
||||
|
||||
fn make_bit_matrix(cols: &[&[bool]]) -> (tempfile::TempDir, PersistentBitMatrix) {
|
||||
let n = cols.first().map_or(0, |c| c.len());
|
||||
let dir = tempdir().unwrap();
|
||||
let presence = dir.path().join("presence");
|
||||
let mut b = PersistentBitMatrixBuilder::new(n, &presence).unwrap();
|
||||
for &col in cols {
|
||||
let mut cb = b.add_col().unwrap();
|
||||
for (slot, &v) in col.iter().enumerate() { cb.set(slot, v); }
|
||||
cb.close().unwrap();
|
||||
}
|
||||
b.close().unwrap();
|
||||
let m = PersistentBitMatrix::open(dir.path()).unwrap();
|
||||
(dir, m)
|
||||
}
|
||||
|
||||
// ── IntMatrix: partial_group_sum ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn int_partial_group_sum_basic() {
|
||||
// col0=[1,2,3], col1=[10,20,30], col2=[100,0,5]
|
||||
// group {0,2}: sum = [101, 2, 8]
|
||||
let (_d, m) = make_int_matrix(&[&[1, 2, 3], &[10, 20, 30], &[100, 0, 5]]);
|
||||
let g = ColGroup::new("g", vec![0, 2]);
|
||||
let result = m.partial_group_sum(&g);
|
||||
assert_eq!(result.get(0), 101);
|
||||
assert_eq!(result.get(1), 2);
|
||||
assert_eq!(result.get(2), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn int_partial_group_sum_with_overflow() {
|
||||
// col0=[300,0], col1=[200,400]: group {0,1}: sum=[500, 400]
|
||||
let (_d, m) = make_int_matrix(&[&[300, 0], &[200, 400]]);
|
||||
let g = ColGroup::new("g", vec![0, 1]);
|
||||
let result = m.partial_group_sum(&g);
|
||||
assert_eq!(result.get(0), 500);
|
||||
assert_eq!(result.get(1), 400);
|
||||
assert_eq!(result.sum(), 900);
|
||||
}
|
||||
|
||||
// ── IntMatrix: partial_group_presence_count ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn int_partial_group_presence_count() {
|
||||
// col0=[5,1,0,3], col1=[2,0,4,3], col2=[0,3,1,0]
|
||||
// threshold=2: col0: [T,F,F,T], col1: [T,F,T,T], col2: [F,T,F,F]
|
||||
// group {0,1,2}: counts = [2, 1, 1, 2]
|
||||
let (_d, m) = make_int_matrix(&[&[5, 1, 0, 3], &[2, 0, 4, 3], &[0, 3, 1, 0]]);
|
||||
let g = ColGroup::new("g", vec![0, 1, 2]);
|
||||
let result = m.partial_group_presence_count(&g, 2);
|
||||
assert_eq!(result.get(0), 2);
|
||||
assert_eq!(result.get(1), 1);
|
||||
assert_eq!(result.get(2), 1);
|
||||
assert_eq!(result.get(3), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn int_partial_group_presence_count_with_overflow() {
|
||||
// col0=[300,0,10], col1=[0,400,10], col2=[1,1,10]
|
||||
// threshold=5: col0: [T,F,T], col1: [F,T,T], col2: [F,F,T]
|
||||
// group {0,1,2}: counts = [1, 1, 3]
|
||||
let (_d, m) = make_int_matrix(&[&[300, 0, 10], &[0, 400, 10], &[1, 1, 10]]);
|
||||
let g = ColGroup::new("g", vec![0, 1, 2]);
|
||||
let result = m.partial_group_presence_count(&g, 5);
|
||||
assert_eq!(result.get(0), 1);
|
||||
assert_eq!(result.get(1), 1);
|
||||
assert_eq!(result.get(2), 3);
|
||||
}
|
||||
|
||||
// ── IntMatrix: partial_group_any ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn int_partial_group_any() {
|
||||
// col0=[0,3,0,1], col1=[2,0,0,0], col2=[0,0,5,0]
|
||||
// threshold=2: col0: [F,T,F,F], col1: [T,F,F,F], col2: [F,F,T,F]
|
||||
// group {0,1,2}: any = [T, T, T, F]
|
||||
let (_d, m) = make_int_matrix(&[&[0, 3, 0, 1], &[2, 0, 0, 0], &[0, 0, 5, 0]]);
|
||||
let g = ColGroup::new("g", vec![0, 1, 2]);
|
||||
let result = m.partial_group_any(&g, 2);
|
||||
assert_eq!(result.get(0), true);
|
||||
assert_eq!(result.get(1), true);
|
||||
assert_eq!(result.get(2), true);
|
||||
assert_eq!(result.get(3), false);
|
||||
}
|
||||
|
||||
// ── IntMatrix: mask_with ──────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn mask_with_zeros_selected_slots() {
|
||||
// count vec [10, 20, 30, 40], mask [T, F, T, F] → [10, 0, 30, 0]
|
||||
let mut v = MemoryIntVec::new(4);
|
||||
v.set(0, 10); v.set(1, 20); v.set(2, 30); v.set(3, 40);
|
||||
let mut mask = MemoryBitVec::new(4);
|
||||
mask.set(0, true); mask.set(2, true);
|
||||
v.mask_with(&mask);
|
||||
assert_eq!(v.get(0), 10);
|
||||
assert_eq!(v.get(1), 0);
|
||||
assert_eq!(v.get(2), 30);
|
||||
assert_eq!(v.get(3), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_with_overflow_slot_zeroed() {
|
||||
// overflow slot (value 500) masked out → removed from overflow, primary=0
|
||||
let mut v = MemoryIntVec::new(3);
|
||||
v.set(0, 10); v.set(1, 500); v.set(2, 5);
|
||||
let mut mask = MemoryBitVec::new(3);
|
||||
mask.set(0, true); mask.set(2, true); // slot 1 masked out
|
||||
v.mask_with(&mask);
|
||||
assert_eq!(v.get(0), 10);
|
||||
assert_eq!(v.get(1), 0);
|
||||
assert_eq!(v.get(2), 5);
|
||||
let ov: Vec<_> = v.overflow_entries().collect();
|
||||
assert!(ov.is_empty(), "overflow entry for masked-out slot should be gone");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mask_with_all_ones_is_noop() {
|
||||
let mut v = MemoryIntVec::new(4);
|
||||
v.set(0, 300); v.set(1, 1); v.set(2, 0); v.set(3, 42);
|
||||
let mask = MemoryBitVec::ones(4);
|
||||
v.mask_with(&mask);
|
||||
assert_eq!(v.get(0), 300);
|
||||
assert_eq!(v.get(1), 1);
|
||||
assert_eq!(v.get(2), 0);
|
||||
assert_eq!(v.get(3), 42);
|
||||
}
|
||||
|
||||
// ── BitMatrix: partial_group_presence_count ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn bit_partial_group_presence_count() {
|
||||
// col0=[T,F,T,F], col1=[T,T,F,F], col2=[F,T,T,F]
|
||||
// group {0,1,2}: counts = [2, 2, 2, 0]
|
||||
let (_d, m) = make_bit_matrix(&[
|
||||
&[true, false, true, false],
|
||||
&[true, true, false, false],
|
||||
&[false,true, true, false],
|
||||
]);
|
||||
let g = ColGroup::new("g", vec![0, 1, 2]);
|
||||
let result = m.partial_group_presence_count(&g, 1);
|
||||
assert_eq!(result.get(0), 2);
|
||||
assert_eq!(result.get(1), 2);
|
||||
assert_eq!(result.get(2), 2);
|
||||
assert_eq!(result.get(3), 0);
|
||||
}
|
||||
|
||||
// ── BitMatrix: partial_group_any ──────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn bit_partial_group_any() {
|
||||
// col0=[T,F,F], col1=[F,F,T], group {0,1}: any = [T, F, T]
|
||||
let (_d, m) = make_bit_matrix(&[
|
||||
&[true, false, false],
|
||||
&[false, false, true],
|
||||
]);
|
||||
let g = ColGroup::new("g", vec![0, 1]);
|
||||
let result = m.partial_group_any(&g, 1);
|
||||
assert_eq!(result.get(0), true);
|
||||
assert_eq!(result.get(1), false);
|
||||
assert_eq!(result.get(2), true);
|
||||
}
|
||||
|
||||
// ── Composition: partial results are additive ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn int_presence_count_additive_across_split() {
|
||||
// Simulate two partitions (different kmer ranges) whose counts should add.
|
||||
// Global data for col0: [5,1,0,3,2], col1: [2,0,4,3,1] — threshold=2
|
||||
// Split: partition A = slots 0..2, partition B = slots 2..5
|
||||
let data_a: &[&[u32]] = &[&[5, 1], &[2, 0]];
|
||||
let data_b: &[&[u32]] = &[&[0, 3, 2], &[4, 3, 1]];
|
||||
let (_da, ma) = make_int_matrix(data_a);
|
||||
let (_db, mb) = make_int_matrix(data_b);
|
||||
let g = ColGroup::new("g", vec![0, 1]);
|
||||
|
||||
let pa = ma.partial_group_presence_count(&g, 2);
|
||||
let pb = mb.partial_group_presence_count(&g, 2);
|
||||
|
||||
// Concatenate by adding (disjoint kmer ranges — here we just verify
|
||||
// individual results match the expected per-partition counts).
|
||||
// partition A: col0=[5≥2,1<2]=[T,F], col1=[2≥2,0<2]=[T,F] → [2, 0]
|
||||
assert_eq!(pa.get(0), 2);
|
||||
assert_eq!(pa.get(1), 0);
|
||||
// partition B: col0=[0<2,3≥2,2≥2]=[F,T,T], col1=[4≥2,3≥2,1<2]=[T,T,F] → [1, 2, 1]
|
||||
assert_eq!(pb.get(0), 1);
|
||||
assert_eq!(pb.get(1), 2);
|
||||
assert_eq!(pb.get(2), 1);
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
mod bitmatrix;
|
||||
mod bitvec;
|
||||
mod colgroup;
|
||||
mod intmatrix;
|
||||
mod memoryvec;
|
||||
|
||||
|
||||
@@ -258,6 +258,29 @@ pub trait IntSliceMut: IntSlice {
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Zero every slot where the corresponding bit in `mask` is 0.
|
||||
/// Iterates only the zero bits — O(n_zeros), O(1) when mask is all-ones.
|
||||
fn mask_with<B: BitSlice>(&mut self, mask: &B) -> &mut Self {
|
||||
assert_eq!(self.len(), mask.len(), "IntSlice/BitSlice length mismatch");
|
||||
let n = self.len();
|
||||
for (wi, &word) in mask.words().iter().enumerate() {
|
||||
if word == u64::MAX { continue; }
|
||||
let mut zeros = !word;
|
||||
while zeros != 0 {
|
||||
let bit = zeros.trailing_zeros() as usize;
|
||||
let s = wi * 64 + bit;
|
||||
if s < n {
|
||||
// u8 is Copy — the immutable borrow from primary_bytes() ends
|
||||
// before the mutable borrow from set() begins.
|
||||
let b = self.primary_bytes()[s];
|
||||
if b != 0 { self.set(s, 0); }
|
||||
}
|
||||
zeros &= zeros - 1;
|
||||
}
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// ── IntSlice → MemoryBitVec conversions ───────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user