feat: add vectorized column filters and optimize partitioner iteration
Adds `FilterMask` and conditional bitwise methods (`*_where`) to `obicompactvec` for composable column-based slot filtering. Extends `obikpartitionner` with a `MatrixGroupOps` trait and `column_mask_expr` method to express aggregate constraints as vectorized masks. Refactors matrix builder management into a unified `Builders` enum and introduces `try_compute_combined_mask`, enabling O(1) slot checks and skipping unnecessary row reads during partitioning and rebuilding passes.
This commit is contained in:
@@ -5,7 +5,7 @@ use std::path::{Path, PathBuf};
|
||||
use memmap2::{Mmap, MmapMut};
|
||||
|
||||
use crate::reader::PersistentCompactIntVec;
|
||||
use crate::views::{BitSliceView, BitSliceIter};
|
||||
use crate::views::{BitSliceIter, BitSliceView, IntSliceView};
|
||||
|
||||
const MAGIC: [u8; 4] = *b"PBIV";
|
||||
|
||||
@@ -241,6 +241,72 @@ impl PersistentBitVecBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// OR in bits at slots where `pred(col[slot])` is true.
|
||||
pub fn or_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
|
||||
assert_eq!(self.n, col.len(), "IntSliceView length mismatch");
|
||||
let n = self.n;
|
||||
let primary = col.primary_bytes();
|
||||
let words = self.data_words_mut();
|
||||
let nw = n_words(n);
|
||||
for wi in 0..nw {
|
||||
let base = wi * 64;
|
||||
let limit = (base + 64).min(n);
|
||||
let mut mask = 0u64;
|
||||
for bit in 0..(limit - base) {
|
||||
let b = primary[base + bit];
|
||||
if b < 255 && pred(b as u32) { mask |= 1u64 << bit; }
|
||||
}
|
||||
words[wi] |= mask;
|
||||
}
|
||||
for (slot, val) in col.overflow_entries() {
|
||||
if pred(val) { words[slot >> 6] |= 1u64 << (slot & 63); }
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear bits at slots where `pred(col[slot])` is false.
|
||||
pub fn and_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
|
||||
assert_eq!(self.n, col.len(), "IntSliceView length mismatch");
|
||||
let n = self.n;
|
||||
let primary = col.primary_bytes();
|
||||
let words = self.data_words_mut();
|
||||
let nw = n_words(n);
|
||||
for wi in 0..nw {
|
||||
let base = wi * 64;
|
||||
let limit = (base + 64).min(n);
|
||||
let mut mask = 0u64;
|
||||
for bit in 0..(limit - base) {
|
||||
let b = primary[base + bit];
|
||||
if b < 255 && !pred(b as u32) { mask |= 1u64 << bit; }
|
||||
}
|
||||
words[wi] &= !mask;
|
||||
}
|
||||
for (slot, val) in col.overflow_entries() {
|
||||
if !pred(val) { words[slot >> 6] &= !(1u64 << (slot & 63)); }
|
||||
}
|
||||
}
|
||||
|
||||
/// Toggle bits at slots where `pred(col[slot])` is true.
|
||||
pub fn xor_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
|
||||
assert_eq!(self.n, col.len(), "IntSliceView length mismatch");
|
||||
let n = self.n;
|
||||
let primary = col.primary_bytes();
|
||||
let words = self.data_words_mut();
|
||||
let nw = n_words(n);
|
||||
for wi in 0..nw {
|
||||
let base = wi * 64;
|
||||
let limit = (base + 64).min(n);
|
||||
let mut mask = 0u64;
|
||||
for bit in 0..(limit - base) {
|
||||
let b = primary[base + bit];
|
||||
if b < 255 && pred(b as u32) { mask |= 1u64 << bit; }
|
||||
}
|
||||
words[wi] ^= mask;
|
||||
}
|
||||
for (slot, val) in col.overflow_entries() {
|
||||
if pred(val) { words[slot >> 6] ^= 1u64 << (slot & 63); }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> BitSliceIter<'_> {
|
||||
self.view().iter()
|
||||
}
|
||||
|
||||
@@ -70,3 +70,73 @@ pub trait MatrixGroupOps {
|
||||
b.freeze()
|
||||
}
|
||||
}
|
||||
|
||||
// ── FilterMask — expression tree for column-based slot filters ────────────────
|
||||
|
||||
/// A composable filter expression that can be evaluated against a matrix
|
||||
/// using only column operations (no MPHF lookup per kmer).
|
||||
///
|
||||
/// `threshold` semantics follow [`MatrixGroupOps::partial_group_presence_count`]:
|
||||
/// a slot contributes to the count when its value is **≥ threshold**.
|
||||
/// To match the row-level filter (`value > t`), callers should pass `t + 1`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum FilterMask {
|
||||
/// Slot passes if count of columns in `indices` with value ≥ `threshold` is ≥ `min_count`.
|
||||
PresenceGeq { indices: Vec<usize>, threshold: u32, min_count: usize },
|
||||
/// Slot passes if count of columns in `indices` with value ≥ `threshold` is ≤ `max_count`.
|
||||
PresenceLeq { indices: Vec<usize>, threshold: u32, max_count: usize },
|
||||
/// Slot passes if sum of values across `indices` columns is ≥ `min_sum`.
|
||||
SumGeq { indices: Vec<usize>, min_sum: u32 },
|
||||
/// Slot passes if sum of values across `indices` columns is ≤ `max_sum`.
|
||||
SumLeq { indices: Vec<usize>, max_sum: u32 },
|
||||
/// Slot passes if it passes all sub-expressions. Empty `And` is always true.
|
||||
And(Vec<FilterMask>),
|
||||
}
|
||||
|
||||
/// Evaluate a [`FilterMask`] against `mat`, returning a per-slot `TempBitVec`
|
||||
/// where bit=1 means the slot passes the filter.
|
||||
pub fn eval_filter_mask(expr: &FilterMask, mat: &dyn MatrixGroupOps, n: usize) -> io::Result<TempBitVec> {
|
||||
match expr {
|
||||
FilterMask::PresenceGeq { indices, threshold, min_count } => {
|
||||
let g = ColGroup::new("", indices.clone());
|
||||
let counts = mat.partial_group_presence_count(&g, *threshold)?;
|
||||
let mut b = TempBitVecBuilder::new(n)?;
|
||||
let mc = *min_count as u32;
|
||||
b.or_where(counts.view(), |v| v >= mc);
|
||||
b.freeze()
|
||||
}
|
||||
FilterMask::PresenceLeq { indices, threshold, max_count } => {
|
||||
let g = ColGroup::new("", indices.clone());
|
||||
let counts = mat.partial_group_presence_count(&g, *threshold)?;
|
||||
let mut b = TempBitVecBuilder::new(n)?;
|
||||
let mc = *max_count as u32;
|
||||
b.or_where(counts.view(), |v| v <= mc);
|
||||
b.freeze()
|
||||
}
|
||||
FilterMask::SumGeq { indices, min_sum } => {
|
||||
let g = ColGroup::new("", indices.clone());
|
||||
let sums = mat.partial_group_sum(&g)?;
|
||||
let mut b = TempBitVecBuilder::new(n)?;
|
||||
let ms = *min_sum;
|
||||
b.or_where(sums.view(), |v| v >= ms);
|
||||
b.freeze()
|
||||
}
|
||||
FilterMask::SumLeq { indices, max_sum } => {
|
||||
let g = ColGroup::new("", indices.clone());
|
||||
let sums = mat.partial_group_sum(&g)?;
|
||||
let mut b = TempBitVecBuilder::new(n)?;
|
||||
let ms = *max_sum;
|
||||
b.or_where(sums.view(), |v| v <= ms);
|
||||
b.freeze()
|
||||
}
|
||||
FilterMask::And(parts) => {
|
||||
let mut b = TempBitVecBuilder::new(n)?;
|
||||
b.not(); // initialise à tout-1 (tout passe)
|
||||
for part in parts {
|
||||
let m = eval_filter_mask(part, mat, n)?;
|
||||
b.and(m.view());
|
||||
}
|
||||
b.freeze()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ pub mod traits;
|
||||
pub use bitvec::{BitIter, PersistentBitVec, PersistentBitVecBuilder};
|
||||
pub use bitmatrix::{PersistentBitMatrix, PersistentBitMatrixBuilder, pack_bit_matrix};
|
||||
pub use builder::PersistentCompactIntVecBuilder;
|
||||
pub use colgroup::{ColGroup, MatrixGroupOps};
|
||||
pub use colgroup::{ColGroup, FilterMask, MatrixGroupOps, eval_filter_mask};
|
||||
pub use intmatrix::{PersistentCompactIntMatrix, PersistentCompactIntMatrixBuilder, pack_compact_int_matrix};
|
||||
pub use layer_meta::LayerMeta;
|
||||
pub use reader::PersistentCompactIntVec;
|
||||
|
||||
@@ -73,18 +73,31 @@ impl TempBitVecBuilder {
|
||||
self.builder.or(other);
|
||||
}
|
||||
|
||||
/// Set self[slot] where pred(col[slot]) is true. Two-pass: primary then overflow.
|
||||
pub(crate) fn and(&mut self, other: BitSliceView<'_>) {
|
||||
self.builder.and(other);
|
||||
}
|
||||
|
||||
pub(crate) fn xor(&mut self, other: BitSliceView<'_>) {
|
||||
self.builder.xor(other);
|
||||
}
|
||||
|
||||
pub(crate) fn not(&mut self) {
|
||||
self.builder.not();
|
||||
}
|
||||
|
||||
pub(crate) fn copy_from(&mut self, src: BitSliceView<'_>) {
|
||||
self.builder.copy_from(src);
|
||||
}
|
||||
|
||||
pub fn or_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
|
||||
for slot in 0..col.len() {
|
||||
let b = col.primary_bytes()[slot];
|
||||
if b < 255 && pred(b as u32) {
|
||||
self.builder.set(slot, true);
|
||||
}
|
||||
}
|
||||
for (slot, val) in col.overflow_entries() {
|
||||
if pred(val) {
|
||||
self.builder.set(slot, true);
|
||||
}
|
||||
}
|
||||
self.builder.or_where(col, pred);
|
||||
}
|
||||
|
||||
pub(crate) fn and_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
|
||||
self.builder.and_where(col, pred);
|
||||
}
|
||||
|
||||
pub(crate) fn xor_where(&mut self, col: IntSliceView<'_>, pred: impl Fn(u32) -> bool) {
|
||||
self.builder.xor_where(col, pred);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user