feat: add vectorized column filters and optimize partitioner iteration

Adds `FilterMask` and conditional bitwise methods (`*_where`) to `obicompactvec` for composable column-based slot filtering. Extends `obikpartitionner` with a `MatrixGroupOps` trait and `column_mask_expr` method to express aggregate constraints as vectorized masks. Refactors matrix builder management into a unified `Builders` enum and introduces `try_compute_combined_mask`, enabling O(1) slot checks and skipping unnecessary row reads during partitioning and rebuilding passes.
This commit is contained in:
Eric Coissac
2026-06-19 09:12:07 +02:00
parent 4c4524766c
commit 7c1efa9cbb
7 changed files with 462 additions and 65 deletions
+67 -1
View File
@@ -5,7 +5,7 @@ use std::path::{Path, PathBuf};
use memmap2::{Mmap, MmapMut};
use crate::reader::PersistentCompactIntVec;
use crate::views::{BitSliceView, BitSliceIter};
use crate::views::{BitSliceIter, BitSliceView, IntSliceView};
const MAGIC: [u8; 4] = *b"PBIV";
@@ -241,6 +241,72 @@ impl PersistentBitVecBuilder {
}
}
/// Sets every bit whose slot value in `col` satisfies `pred`; all other bits are left as they are.
///
/// The column is read in two passes. The first pass walks the primary byte
/// array 64 slots at a time and treats any byte below 255 as the slot value.
/// Bytes equal to 255 are skipped there (presumably they mark slots whose value
/// lives in the overflow table — confirm against `IntSliceView`). The second
/// pass applies `pred` to each `(slot, value)` pair of `col.overflow_entries()`.
///
/// # Panics
///
/// Panics if `col.len()` differs from this builder's bit count.
pub fn or_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
    assert_eq!(self.n, col.len(), "IntSliceView length mismatch");
    let total = self.n;
    let word_count = n_words(total);
    let bytes = col.primary_bytes();
    let words = self.data_words_mut();

    // Pass 1: one 64-bit selection mask per word, built from the inline values.
    for (word, chunk) in words[..word_count].iter_mut().zip(bytes[..total].chunks(64)) {
        let selected = chunk
            .iter()
            .enumerate()
            .filter(|&(_, &byte)| byte < 255 && pred(byte as u32))
            .fold(0u64, |acc, (bit, _)| acc | (1u64 << bit));
        *word |= selected;
    }

    // Pass 2: values reported through the overflow entries.
    for (slot, value) in col.overflow_entries() {
        if pred(value) {
            words[slot >> 6] |= 1u64 << (slot & 63);
        }
    }
}
/// Clears every bit whose slot value in `col` fails `pred`; bits of passing slots are untouched.
///
/// The column is read in two passes. The first pass walks the primary byte
/// array 64 slots at a time and treats any byte below 255 as the slot value.
/// Bytes equal to 255 are skipped there (presumably they mark slots whose value
/// lives in the overflow table — confirm against `IntSliceView`). The second
/// pass applies `pred` to each `(slot, value)` pair of `col.overflow_entries()`.
///
/// # Panics
///
/// Panics if `col.len()` differs from this builder's bit count.
pub fn and_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
    assert_eq!(self.n, col.len(), "IntSliceView length mismatch");
    let total = self.n;
    let word_count = n_words(total);
    let bytes = col.primary_bytes();
    let words = self.data_words_mut();

    // Pass 1: collect, per word, the bits of inline values rejected by `pred`.
    for (word, chunk) in words[..word_count].iter_mut().zip(bytes[..total].chunks(64)) {
        let rejected = chunk
            .iter()
            .enumerate()
            .filter(|&(_, &byte)| byte < 255 && !pred(byte as u32))
            .fold(0u64, |acc, (bit, _)| acc | (1u64 << bit));
        *word &= !rejected;
    }

    // Pass 2: values reported through the overflow entries.
    for (slot, value) in col.overflow_entries() {
        if !pred(value) {
            words[slot >> 6] &= !(1u64 << (slot & 63));
        }
    }
}
/// Flips every bit whose slot value in `col` satisfies `pred`; all other bits are left as they are.
///
/// The column is read in two passes. The first pass walks the primary byte
/// array 64 slots at a time and treats any byte below 255 as the slot value.
/// Bytes equal to 255 are skipped there (presumably they mark slots whose value
/// lives in the overflow table — confirm against `IntSliceView`). The second
/// pass applies `pred` to each `(slot, value)` pair of `col.overflow_entries()`.
///
/// # Panics
///
/// Panics if `col.len()` differs from this builder's bit count.
pub fn xor_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
    assert_eq!(self.n, col.len(), "IntSliceView length mismatch");
    let total = self.n;
    let word_count = n_words(total);
    let bytes = col.primary_bytes();
    let words = self.data_words_mut();

    // Pass 1: one 64-bit toggle mask per word, built from the inline values.
    for (word, chunk) in words[..word_count].iter_mut().zip(bytes[..total].chunks(64)) {
        let toggled = chunk
            .iter()
            .enumerate()
            .filter(|&(_, &byte)| byte < 255 && pred(byte as u32))
            .fold(0u64, |acc, (bit, _)| acc | (1u64 << bit));
        *word ^= toggled;
    }

    // Pass 2: values reported through the overflow entries.
    for (slot, value) in col.overflow_entries() {
        if pred(value) {
            words[slot >> 6] ^= 1u64 << (slot & 63);
        }
    }
}
/// Returns the `BitSliceIter` produced by iterating over `self.view()`.
pub fn iter(&self) -> BitSliceIter<'_> {
self.view().iter()
}
+70
View File
@@ -70,3 +70,73 @@ pub trait MatrixGroupOps {
b.freeze()
}
}
// ── FilterMask — expression tree for column-based slot filters ────────────────
/// Composable slot filter expressed purely in terms of column operations,
/// so evaluating it needs no per-kmer MPHF lookup.
///
/// Presence variants use the same `threshold` convention as
/// [`MatrixGroupOps::partial_group_presence_count`]: a column counts for a
/// slot when its value is **≥ threshold**. A row-level filter written as
/// `value > t` therefore corresponds to passing `t + 1` here.
#[derive(Debug, Clone)]
pub enum FilterMask {
    /// Passes when at least `min_count` of the `indices` columns hold a value ≥ `threshold`.
    PresenceGeq {
        indices: Vec<usize>,
        threshold: u32,
        min_count: usize,
    },
    /// Passes when at most `max_count` of the `indices` columns hold a value ≥ `threshold`.
    PresenceLeq {
        indices: Vec<usize>,
        threshold: u32,
        max_count: usize,
    },
    /// Passes when the values of the `indices` columns add up to at least `min_sum`.
    SumGeq {
        indices: Vec<usize>,
        min_sum: u32,
    },
    /// Passes when the values of the `indices` columns add up to at most `max_sum`.
    SumLeq {
        indices: Vec<usize>,
        max_sum: u32,
    },
    /// Passes when every sub-expression passes; an empty list accepts every slot.
    And(Vec<FilterMask>),
}
/// Evaluates a [`FilterMask`] against `mat` and returns a per-slot `TempBitVec`
/// in which a set bit means the slot passes the filter.
///
/// * `expr` — the filter expression; `And` nodes are evaluated recursively.
/// * `mat`  — source of the per-slot aggregates (`partial_group_presence_count`
///   and `partial_group_sum`), queried with an unnamed `ColGroup` built from
///   the variant's `indices`.
/// * `n`    — number of bits in the returned mask.
///   NOTE(review): `n` is expected to equal the length of the columns returned
///   by `mat`; the underlying `or_where` appears to assert this — confirm.
///
/// `min_count` / `max_count` are compared against the `u32` per-slot counts
/// without narrowing, so a bound above `u32::MAX` behaves as "never reached"
/// (`PresenceGeq`) or "always satisfied" (`PresenceLeq`) instead of wrapping.
///
/// # Errors
///
/// Propagates any `io::Error` raised while computing the aggregates,
/// allocating the temporary builder, or freezing it.
pub fn eval_filter_mask(expr: &FilterMask, mat: &dyn MatrixGroupOps, n: usize) -> io::Result<TempBitVec> {
    match expr {
        FilterMask::PresenceGeq { indices, threshold, min_count } => {
            let g = ColGroup::new("", indices.clone());
            let counts = mat.partial_group_presence_count(&g, *threshold)?;
            let mut b = TempBitVecBuilder::new(n)?;
            // Compare in u64: `usize as u32` would silently truncate large bounds.
            let min_needed = *min_count as u64;
            b.or_where(counts.view(), |v| u64::from(v) >= min_needed);
            b.freeze()
        }
        FilterMask::PresenceLeq { indices, threshold, max_count } => {
            let g = ColGroup::new("", indices.clone());
            let counts = mat.partial_group_presence_count(&g, *threshold)?;
            let mut b = TempBitVecBuilder::new(n)?;
            // Compare in u64: `usize as u32` would silently truncate large bounds.
            let max_allowed = *max_count as u64;
            b.or_where(counts.view(), |v| u64::from(v) <= max_allowed);
            b.freeze()
        }
        FilterMask::SumGeq { indices, min_sum } => {
            let g = ColGroup::new("", indices.clone());
            let sums = mat.partial_group_sum(&g)?;
            let mut b = TempBitVecBuilder::new(n)?;
            let ms = *min_sum;
            b.or_where(sums.view(), |v| v >= ms);
            b.freeze()
        }
        FilterMask::SumLeq { indices, max_sum } => {
            let g = ColGroup::new("", indices.clone());
            let sums = mat.partial_group_sum(&g)?;
            let mut b = TempBitVecBuilder::new(n)?;
            let ms = *max_sum;
            b.or_where(sums.view(), |v| v <= ms);
            b.freeze()
        }
        FilterMask::And(parts) => {
            let mut b = TempBitVecBuilder::new(n)?;
            // Start from an all-ones mask (every slot passes), then intersect
            // it with the mask of each sub-expression.
            b.not();
            for part in parts {
                let m = eval_filter_mask(part, mat, n)?;
                b.and(m.view());
            }
            b.freeze()
        }
    }
}
+1 -1
View File
@@ -15,7 +15,7 @@ pub mod traits;
pub use bitvec::{BitIter, PersistentBitVec, PersistentBitVecBuilder};
pub use bitmatrix::{PersistentBitMatrix, PersistentBitMatrixBuilder, pack_bit_matrix};
pub use builder::PersistentCompactIntVecBuilder;
pub use colgroup::{ColGroup, MatrixGroupOps};
pub use colgroup::{ColGroup, FilterMask, MatrixGroupOps, eval_filter_mask};
pub use intmatrix::{PersistentCompactIntMatrix, PersistentCompactIntMatrixBuilder, pack_compact_int_matrix};
pub use layer_meta::LayerMeta;
pub use reader::PersistentCompactIntVec;
+25 -12
View File
@@ -73,18 +73,31 @@ impl TempBitVecBuilder {
self.builder.or(other);
}
/// Forwards `other` to the wrapped builder's `and`
/// (presumably an in-place bitwise AND with `other` — confirm on the inner builder).
pub(crate) fn and(&mut self, other: BitSliceView<'_>) {
self.builder.and(other);
}
/// Forwards `other` to the wrapped builder's `xor`
/// (presumably an in-place bitwise XOR with `other` — confirm on the inner builder).
pub(crate) fn xor(&mut self, other: BitSliceView<'_>) {
self.builder.xor(other);
}
/// Forwards to the wrapped builder's `not`
/// (presumably inverts every bit in place — confirm on the inner builder).
pub(crate) fn not(&mut self) {
self.builder.not();
}
/// Forwards `src` to the wrapped builder's `copy_from`
/// (presumably overwrites this builder's bits with those of `src` — confirm on the inner builder).
pub(crate) fn copy_from(&mut self, src: BitSliceView<'_>) {
self.builder.copy_from(src);
}
/// Sets the bit of every slot whose value in `col` satisfies `pred`.
///
/// The work is done by the wrapped builder's `or_where`, so `pred` is
/// evaluated once per slot and no separate per-slot `set` pass is run here.
pub fn or_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
    self.builder.or_where(col, pred);
}
/// Forwards `col` and `pred` to the wrapped builder's `and_where`
/// (presumably clears bits at slots where `pred(col[slot])` is false — confirm on the inner builder).
pub(crate) fn and_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
self.builder.and_where(col, pred);
}
/// Forwards `col` and `pred` to the wrapped builder's `xor_where`
/// (presumably toggles bits at slots where `pred(col[slot])` is true — confirm on the inner builder).
pub(crate) fn xor_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
self.builder.xor_where(col, pred);
}
}