refactor: replace manual bit ops with BitSlice traits

Refactors bit manipulation and distance calculations to leverage standardized `BitSlice` traits, replacing manual byte/word logic with safer, reusable methods. Extends `IntSlice` and `IntSliceMut` traits to expose direct memory-mapped access and overflow management, enabling efficient bulk data extraction and serialization. Replaces manual bit-shifting loops with optimized table-based unpacking and adds population count and distance metric methods for improved performance. Updates `PersistentBitVecBuilder` with file tracking and safe flushing, and aligns test imports with new trait bounds.
This commit is contained in:
Eric Coissac
2026-06-17 09:19:30 +02:00
parent df7b400fda
commit 5ff5b04d2d
9 changed files with 128 additions and 133 deletions
+35 -65
View File
@@ -8,7 +8,7 @@ use rayon::prelude::*;
use crate::bitvec::{PersistentBitVec, PersistentBitVecBuilder}; use crate::bitvec::{PersistentBitVec, PersistentBitVecBuilder};
use crate::memoryvec::MemoryBitVec; use crate::memoryvec::MemoryBitVec;
use crate::traits::BitSliceMut; use crate::traits::{BitSlice, BitSliceMut};
use crate::layer_meta::LayerMeta; use crate::layer_meta::LayerMeta;
use crate::meta::MatrixMeta; use crate::meta::MatrixMeta;
@@ -126,11 +126,22 @@ impl PackedBitMatrix {
}).collect() }).collect()
} }
#[inline]
fn col_bytes(&self, c: usize) -> &[u8] { fn col_bytes(&self, c: usize) -> &[u8] {
let start = self.data_offsets[c]; let start = self.data_offsets[c];
let len = (self.n_rows + 7) / 8; &self.mmap[start..start + self.n_rows.div_ceil(8)]
&self.mmap[start..start + len] }
fn col_words(&self, c: usize) -> &[u64] {
let nw = self.n_rows.div_ceil(64);
// SAFETY: data_offsets[c] is always 8-byte aligned.
// PBMX header = 24 + n_cols×8 (multiple of 8); each PBIV blob =
// 16 + nwords×8 (multiple of 8); mmap base is page-aligned.
let ptr = self.mmap[self.data_offsets[c]..].as_ptr() as *const u64;
unsafe { std::slice::from_raw_parts(ptr, nw) }
}
pub(crate) fn col_slice(&self, c: usize) -> PackedCol<'_> {
PackedCol { words: self.col_words(c), n: self.n_rows }
} }
pub(crate) fn col_persist(&self, c: usize, path: &Path) -> io::Result<PersistentBitVecBuilder> { pub(crate) fn col_persist(&self, c: usize, path: &Path) -> io::Result<PersistentBitVecBuilder> {
@@ -138,81 +149,40 @@ impl PackedBitMatrix {
} }
pub(crate) fn col_as_memory(&self, c: usize) -> MemoryBitVec { pub(crate) fn col_as_memory(&self, c: usize) -> MemoryBitVec {
let bytes = self.col_bytes(c); MemoryBitVec::from(&self.col_slice(c))
let n = self.n_rows;
let n_words = n.div_ceil(64);
let mut words = vec![0u64; n_words];
let full = bytes.len() / 8;
for (i, chunk) in bytes[..full * 8].chunks_exact(8).enumerate() {
words[i] = u64::from_le_bytes(chunk.try_into().unwrap());
}
let rem = bytes.len() % 8;
if rem > 0 {
let mut last = [0u8; 8];
last[..rem].copy_from_slice(&bytes[full * 8..]);
words[full] = u64::from_le_bytes(last);
}
MemoryBitVec::from_words(words, n)
}
fn count_ones_col(&self, c: usize) -> u64 {
let bytes = self.col_bytes(c);
let full = self.n_rows / 8;
let rem = self.n_rows % 8;
let mut n: u64 = bytes[..full].iter().map(|b| b.count_ones() as u64).sum();
if rem > 0 { n += (bytes[full] & ((1u8 << rem) - 1)).count_ones() as u64; }
n
}
fn pair_op(&self, i: usize, j: usize, and_or: bool) -> u64 {
let ai = self.col_bytes(i);
let aj = self.col_bytes(j);
let full = self.n_rows / 8;
let rem = self.n_rows % 8;
let mut n: u64 = ai[..full].iter().zip(aj[..full].iter())
.map(|(a, b)| if and_or { a & b } else { a ^ b }.count_ones() as u64)
.sum();
if rem > 0 {
let mask = (1u8 << rem) - 1;
let last = if and_or { ai[full] & aj[full] } else { ai[full] ^ aj[full] };
n += (last & mask).count_ones() as u64;
}
n
}
fn partial_jaccard_col(&self, i: usize, j: usize) -> (u64, u64) {
let ai = self.col_bytes(i);
let aj = self.col_bytes(j);
let full = self.n_rows / 8;
let rem = self.n_rows % 8;
let (mut inter, mut union) = ai[..full].iter().zip(aj[..full].iter())
.fold((0u64, 0u64), |(inter, union), (a, b)| {
(inter + (a & b).count_ones() as u64,
union + (a | b).count_ones() as u64)
});
if rem > 0 {
let mask = (1u8 << rem) - 1;
inter += ((ai[full] & aj[full]) & mask).count_ones() as u64;
union += ((ai[full] | aj[full]) & mask).count_ones() as u64;
}
(inter, union)
} }
pub(crate) fn count_ones(&self) -> Array1<u64> { pub(crate) fn count_ones(&self) -> Array1<u64> {
Array1::from_vec( Array1::from_vec(
(0..self.n_cols).into_par_iter().map(|c| self.count_ones_col(c)).collect() (0..self.n_cols).into_par_iter()
.map(|c| self.col_slice(c).count_ones())
.collect()
) )
} }
pub(crate) fn partial_jaccard_dist_matrix(&self) -> (Array2<u64>, Array2<u64>) { pub(crate) fn partial_jaccard_dist_matrix(&self) -> (Array2<u64>, Array2<u64>) {
pairwise2_matrix(self.n_cols, |i, j| self.partial_jaccard_col(i, j)) pairwise2_matrix(self.n_cols, |i, j| {
self.col_slice(i).partial_jaccard_dist(&self.col_slice(j))
})
} }
pub(crate) fn partial_hamming_dist_matrix(&self) -> Array2<u64> { pub(crate) fn partial_hamming_dist_matrix(&self) -> Array2<u64> {
pairwise_matrix(self.n_cols, |i, j| self.pair_op(i, j, false)) pairwise_matrix(self.n_cols, |i, j| {
self.col_slice(i).hamming_dist(&self.col_slice(j))
})
} }
} }
pub(crate) struct PackedCol<'a> {
words: &'a [u64],
n: usize,
}
impl BitSlice for PackedCol<'_> {
fn len(&self) -> usize { self.n }
fn words(&self) -> &[u64] { self.words }
}
/// Build `presence/matrix.pbmx` from existing `col_*.pbiv` files. /// Build `presence/matrix.pbmx` from existing `col_*.pbiv` files.
pub fn pack_bit_matrix(dir: &Path) -> io::Result<()> { pub fn pack_bit_matrix(dir: &Path) -> io::Result<()> {
let packed_path = dir.join("matrix.pbmx"); let packed_path = dir.join("matrix.pbmx");
+12 -46
View File
@@ -78,48 +78,6 @@ impl PersistentBitVec {
unsafe { std::slice::from_raw_parts(ptr, nw) } unsafe { std::slice::from_raw_parts(ptr, nw) }
} }
pub fn count_ones(&self) -> u64 {
// Padding bits in the last word are 0, so no masking needed.
self.data_words()
.iter()
.map(|w| w.count_ones() as u64)
.sum()
}
pub fn count_zeros(&self) -> u64 {
self.n as u64 - self.count_ones()
}
pub fn jaccard_dist(&self, other: &PersistentBitVec) -> f64 {
let (inter, union) = self.partial_jaccard_dist(other);
if union == 0 {
return 0.0;
}
1.0 - inter as f64 / union as f64
}
pub fn partial_jaccard_dist(&self, other: &PersistentBitVec) -> (u64, u64) {
assert_eq!(self.n, other.n, "length mismatch");
self.data_words()
.iter()
.zip(other.data_words())
.fold((0u64, 0u64), |(i, u), (&a, &b)| {
(
i + (a & b).count_ones() as u64,
u + (a | b).count_ones() as u64,
)
})
}
pub fn hamming_dist(&self, other: &PersistentBitVec) -> u64 {
assert_eq!(self.n, other.n, "length mismatch");
self.data_words()
.iter()
.zip(other.data_words())
.map(|(&a, &b)| (a ^ b).count_ones() as u64)
.sum()
}
pub fn iter(&self) -> BitIter<'_> { pub fn iter(&self) -> BitIter<'_> {
BitIter { BitIter {
bytes: self.data_bytes(), bytes: self.data_bytes(),
@@ -168,6 +126,7 @@ impl Iterator for BitIter<'_> {
pub struct PersistentBitVecBuilder { pub struct PersistentBitVecBuilder {
mmap: MmapMut, mmap: MmapMut,
n: usize, n: usize,
path: PathBuf,
} }
impl PersistentBitVecBuilder { impl PersistentBitVecBuilder {
@@ -185,7 +144,7 @@ impl PersistentBitVecBuilder {
file.seek(SeekFrom::Start(0))?; file.seek(SeekFrom::Start(0))?;
file.set_len(file_size as u64)?; file.set_len(file_size as u64)?;
let mmap = unsafe { MmapMut::map_mut(&file)? }; let mmap = unsafe { MmapMut::map_mut(&file)? };
Ok(Self { mmap, n }) Ok(Self { mmap, n, path: path.to_path_buf() })
} }
/// Create a PBIV file from raw packed bit-bytes, zero-padding to the next word boundary. /// Create a PBIV file from raw packed bit-bytes, zero-padding to the next word boundary.
@@ -200,7 +159,7 @@ impl PersistentBitVecBuilder {
mmap[0..4].copy_from_slice(&MAGIC); mmap[0..4].copy_from_slice(&MAGIC);
mmap[8..16].copy_from_slice(&(n as u64).to_le_bytes()); mmap[8..16].copy_from_slice(&(n as u64).to_le_bytes());
mmap[HEADER_SIZE..HEADER_SIZE + bytes.len()].copy_from_slice(bytes); mmap[HEADER_SIZE..HEADER_SIZE + bytes.len()].copy_from_slice(bytes);
Ok(Self { mmap, n }) Ok(Self { mmap, n, path: path.to_path_buf() })
} }
pub fn build_from(source: &PersistentBitVec, path: &Path) -> io::Result<Self> { pub fn build_from(source: &PersistentBitVec, path: &Path) -> io::Result<Self> {
@@ -208,7 +167,7 @@ impl PersistentBitVecBuilder {
let file = OpenOptions::new().read(true).write(true).open(path)?; let file = OpenOptions::new().read(true).write(true).open(path)?;
let mmap = unsafe { MmapMut::map_mut(&file)? }; let mmap = unsafe { MmapMut::map_mut(&file)? };
let n = source.len(); let n = source.len();
Ok(Self { mmap, n }) Ok(Self { mmap, n, path: path.to_path_buf() })
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
@@ -268,7 +227,7 @@ impl PersistentBitVecBuilder {
} }
} }
Ok(Self { mmap, n }) Ok(Self { mmap, n, path: path.to_path_buf() })
} }
/// Convert a count vector to a presence/absence bit vector (threshold = 1). /// Convert a count vector to a presence/absence bit vector (threshold = 1).
@@ -279,6 +238,13 @@ impl PersistentBitVecBuilder {
pub fn close(self) -> io::Result<()> { pub fn close(self) -> io::Result<()> {
self.mmap.flush() self.mmap.flush()
} }
/// Flush, close, and reopen as a read-only `PersistentBitVec`.
pub fn finish(self) -> io::Result<PersistentBitVec> {
let path = self.path.clone();
self.close()?;
PersistentBitVec::open(&path)
}
} }
// ── BitSlice / BitSliceMut impls ────────────────────────────────────────────── // ── BitSlice / BitSliceMut impls ──────────────────────────────────────────────
+6
View File
@@ -144,8 +144,14 @@ use crate::traits::{IntSlice, IntSliceMut};
impl IntSlice for PersistentCompactIntVecBuilder { impl IntSlice for PersistentCompactIntVecBuilder {
fn len(&self) -> usize { self.n } fn len(&self) -> usize { self.n }
fn get(&self, slot: usize) -> u32 { self.get(slot) } fn get(&self, slot: usize) -> u32 { self.get(slot) }
fn primary_bytes(&self) -> &[u8] { &self.mmap[HEADER_SIZE..HEADER_SIZE + self.n] }
fn overflow_entries(&self) -> impl Iterator<Item = (usize, u32)> + '_ {
self.overflow.iter().map(|(&k, &v)| (k, v))
}
} }
impl IntSliceMut for PersistentCompactIntVecBuilder { impl IntSliceMut for PersistentCompactIntVecBuilder {
fn set(&mut self, slot: usize, value: u32) { self.set(slot, value); } fn set(&mut self, slot: usize, value: u32) { self.set(slot, value); }
fn primary_bytes_mut(&mut self) -> &mut [u8] { &mut self.mmap[HEADER_SIZE..HEADER_SIZE + self.n] }
fn clear_overflow(&mut self) { self.overflow.clear(); }
} }
+20 -6
View File
@@ -76,6 +76,10 @@ impl MemoryIntVec {
impl IntSlice for MemoryIntVec { impl IntSlice for MemoryIntVec {
fn len(&self) -> usize { self.n } fn len(&self) -> usize { self.n }
fn get(&self, slot: usize) -> u32 { self.get(slot) } fn get(&self, slot: usize) -> u32 { self.get(slot) }
fn primary_bytes(&self) -> &[u8] { &self.primary }
fn overflow_entries(&self) -> impl Iterator<Item = (usize, u32)> + '_ {
self.overflow.iter().map(|(&k, &v)| (k, v))
}
fn sum(&self) -> u64 { self.sum() } fn sum(&self) -> u64 { self.sum() }
fn count_nonzero(&self) -> u64 { self.count_nonzero() } fn count_nonzero(&self) -> u64 { self.count_nonzero() }
} }
@@ -90,18 +94,28 @@ impl IntSliceMut for MemoryIntVec {
self.overflow.insert(slot, value); self.overflow.insert(slot, value);
} }
} }
fn primary_bytes_mut(&mut self) -> &mut [u8] { &mut self.primary }
fn clear_overflow(&mut self) { self.overflow.clear(); }
} }
// ── From conversions ────────────────────────────────────────────────────────── // ── From conversions ──────────────────────────────────────────────────────────
impl MemoryIntVec {
/// Bulk copy from another `MemoryIntVec`: memcpy for the primary bytes,
/// clone for the overflow map.
pub fn copy_from_memory(&mut self, src: &MemoryIntVec) {
assert_eq!(self.n, src.n, "MemoryIntVec length mismatch");
self.primary.copy_from_slice(&src.primary);
self.overflow = src.overflow.clone();
}
}
impl<S: IntSlice> From<&S> for MemoryIntVec { impl<S: IntSlice> From<&S> for MemoryIntVec {
fn from(src: &S) -> Self { fn from(src: &S) -> Self {
let mut v = Self::new(src.len()); Self::from_primary_and_overflow(
for slot in 0..src.len() { src.primary_bytes().to_vec(),
let val = src.get(slot); src.overflow_entries().collect(),
if val != 0 { v.set(slot, val); } )
}
v
} }
} }
-6
View File
@@ -38,12 +38,6 @@ impl MemoryBitVec {
(self.words[slot >> 6] >> (slot & 63)) & 1 != 0 (self.words[slot >> 6] >> (slot & 63)) & 1 != 0
} }
pub fn count_ones(&self) -> u64 {
self.words.iter().map(|w| w.count_ones() as u64).sum()
}
pub fn count_zeros(&self) -> u64 { self.n as u64 - self.count_ones() }
/// Write to disk and return a writable builder positioned at the same path. /// Write to disk and return a writable builder positioned at the same path.
pub fn persist(&self, path: &Path) -> io::Result<PersistentBitVecBuilder> { pub fn persist(&self, path: &Path) -> io::Result<PersistentBitVecBuilder> {
let mut b = PersistentBitVecBuilder::new(self.n, path)?; let mut b = PersistentBitVecBuilder::new(self.n, path)?;
+6
View File
@@ -357,6 +357,12 @@ use crate::traits::IntSlice;
impl IntSlice for PersistentCompactIntVec { impl IntSlice for PersistentCompactIntVec {
fn len(&self) -> usize { self.n } fn len(&self) -> usize { self.n }
fn get(&self, slot: usize) -> u32 { self.get(slot) } fn get(&self, slot: usize) -> u32 { self.get(slot) }
fn primary_bytes(&self) -> &[u8] {
&self.mmap[self.primary_offset..self.primary_offset + self.n]
}
fn overflow_entries(&self) -> impl Iterator<Item = (usize, u32)> + '_ {
(0..self.n_overflow).map(|i| (self.data_slot(i), self.data_value(i)))
}
} }
impl<'a> IntoIterator for &'a PersistentCompactIntVec { impl<'a> IntoIterator for &'a PersistentCompactIntVec {
+1 -1
View File
@@ -1,7 +1,7 @@
use tempfile::tempdir; use tempfile::tempdir;
use crate::{PersistentBitMatrix, PersistentBitMatrixBuilder}; use crate::{PersistentBitMatrix, PersistentBitMatrixBuilder};
use crate::traits::{BitPartials, BitSliceMut}; use crate::traits::{BitPartials, BitSlice, BitSliceMut};
fn make_matrix(cols: &[&[bool]]) -> (tempfile::TempDir, PersistentBitMatrix) { fn make_matrix(cols: &[&[bool]]) -> (tempfile::TempDir, PersistentBitMatrix) {
let n = cols.first().map_or(0, |c| c.len()); let n = cols.first().map_or(0, |c| c.len());
+1 -1
View File
@@ -1,6 +1,6 @@
use tempfile::tempdir; use tempfile::tempdir;
use crate::traits::BitSliceMut; use crate::traits::{BitSlice, BitSliceMut};
use crate::{PersistentBitVec, PersistentBitVecBuilder, PersistentCompactIntVec, PersistentCompactIntVecBuilder}; use crate::{PersistentBitVec, PersistentBitVecBuilder, PersistentCompactIntVec, PersistentCompactIntVecBuilder};
fn make_bv(bits: &[bool]) -> (tempfile::TempDir, PersistentBitVec) { fn make_bv(bits: &[bool]) -> (tempfile::TempDir, PersistentBitVec) {
+47 -8
View File
@@ -13,6 +13,27 @@ pub trait BitSlice {
fn get(&self, slot: usize) -> bool { fn get(&self, slot: usize) -> bool {
(self.words()[slot >> 6] >> (slot & 63)) & 1 != 0 (self.words()[slot >> 6] >> (slot & 63)) & 1 != 0
} }
fn count_ones(&self) -> u64 {
self.words().iter().map(|w| w.count_ones() as u64).sum()
}
fn count_zeros(&self) -> u64 { self.len() as u64 - self.count_ones() }
fn partial_jaccard_dist<S: BitSlice>(&self, other: &S) -> (u64, u64) {
assert_eq!(self.len(), other.len(), "length mismatch");
self.words().iter().zip(other.words())
.fold((0u64, 0u64), |(i, u), (&a, &b)| {
(i + (a & b).count_ones() as u64, u + (a | b).count_ones() as u64)
})
}
fn jaccard_dist<S: BitSlice>(&self, other: &S) -> f64 {
let (inter, union) = self.partial_jaccard_dist(other);
if union == 0 { 0.0 } else { 1.0 - inter as f64 / union as f64 }
}
fn hamming_dist<S: BitSlice>(&self, other: &S) -> u64 {
assert_eq!(self.len(), other.len(), "length mismatch");
self.words().iter().zip(other.words())
.map(|(&a, &b)| (a ^ b).count_ones() as u64)
.sum()
}
} }
/// Mutable view over a bit-vector word array; default methods maintain the /// Mutable view over a bit-vector word array; default methods maintain the
@@ -66,6 +87,10 @@ pub trait BitSliceMut: BitSlice {
pub trait IntSlice { pub trait IntSlice {
fn len(&self) -> usize; fn len(&self) -> usize;
fn get(&self, slot: usize) -> u32; fn get(&self, slot: usize) -> u32;
/// Raw primary byte slice (sentinel 255 marks overflow slots).
fn primary_bytes(&self) -> &[u8];
/// Iterator over `(slot, true_value)` pairs for all overflow entries (value >= 255).
fn overflow_entries(&self) -> impl Iterator<Item = (usize, u32)> + '_;
fn is_empty(&self) -> bool { self.len() == 0 } fn is_empty(&self) -> bool { self.len() == 0 }
fn iter(&self) -> impl Iterator<Item = u32> + '_ { (0..self.len()).map(|i| self.get(i)) } fn iter(&self) -> impl Iterator<Item = u32> + '_ { (0..self.len()).map(|i| self.get(i)) }
fn sum(&self) -> u64 { self.iter().map(|v| v as u64).sum() } fn sum(&self) -> u64 { self.iter().map(|v| v as u64).sum() }
@@ -90,6 +115,8 @@ pub trait IntSlice {
/// compact encoding invariants on the implementor's side. /// compact encoding invariants on the implementor's side.
pub trait IntSliceMut: IntSlice { pub trait IntSliceMut: IntSlice {
fn set(&mut self, slot: usize, value: u32); fn set(&mut self, slot: usize, value: u32);
fn primary_bytes_mut(&mut self) -> &mut [u8];
fn clear_overflow(&mut self);
fn inc(&mut self, slot: usize) -> &mut Self { fn inc(&mut self, slot: usize) -> &mut Self {
let v = self.get(slot); let v = self.get(slot);
@@ -111,7 +138,9 @@ pub trait IntSliceMut: IntSlice {
fn copy_from<S: IntSlice>(&mut self, src: &S) -> &mut Self { fn copy_from<S: IntSlice>(&mut self, src: &S) -> &mut Self {
assert_eq!(self.len(), src.len(), "IntSlice length mismatch"); assert_eq!(self.len(), src.len(), "IntSlice length mismatch");
for s in 0..src.len() { self.set(s, src.get(s)); } self.primary_bytes_mut().copy_from_slice(src.primary_bytes());
self.clear_overflow();
for (slot, val) in src.overflow_entries() { self.set(slot, val); }
self self
} }
@@ -176,13 +205,26 @@ impl<T: IntSlice> IntToBit for T {}
use crate::memoryintvec::MemoryIntVec; use crate::memoryintvec::MemoryIntVec;
// Maps each byte value to its 8 constituent bits as individual u8 (0 or 1).
static EXPAND_BYTE: [[u8; 8]; 256] = {
let mut table = [[0u8; 8]; 256];
let mut b = 0usize;
while b < 256 {
let mut bit = 0usize;
while bit < 8 {
table[b][bit] = ((b >> bit) & 1) as u8;
bit += 1;
}
b += 1;
}
table
};
pub trait BitToInt: BitSlice { pub trait BitToInt: BitSlice {
fn to_intvec(&self) -> MemoryIntVec { fn to_intvec(&self) -> MemoryIntVec {
let n = self.len(); let n = self.len();
let mut primary = vec![0u8; n]; let mut primary = vec![0u8; n];
// Unpack u64 words: each byte within a word yields 8 output bytes.
// Values are always 0 or 1 → no overflow entries needed.
let words = self.words(); let words = self.words();
let full_words = n / 64; let full_words = n / 64;
@@ -190,14 +232,11 @@ pub trait BitToInt: BitSlice {
let base = w_idx * 64; let base = w_idx * 64;
for byte_off in 0..8usize { for byte_off in 0..8usize {
let byte = (word >> (byte_off * 8)) as u8; let byte = (word >> (byte_off * 8)) as u8;
let out = &mut primary[base + byte_off * 8..base + byte_off * 8 + 8]; primary[base + byte_off * 8..base + byte_off * 8 + 8]
for bit in 0..8usize { .copy_from_slice(&EXPAND_BYTE[byte as usize]);
out[bit] = (byte >> bit) & 1;
}
} }
} }
// Remaining bits in the last partial word
let rem = n % 64; let rem = n % 64;
if rem > 0 { if rem > 0 {
let word = words[full_words]; let word = words[full_words];