feat(bitvec): add partial Jaccard, fix padding, optimize constructor
Introduces `partial_jaccard_dist` to return raw intersection and union counts, improving Jaccard distance flexibility. Corrects `not()` to explicitly zero padding bits in the final word, ensuring accurate bit-counting for partially-filled words. Adds an optimized `build_from_counts` constructor.
This commit is contained in:
+104
-39
@@ -14,16 +14,20 @@ const MAGIC: [u8; 4] = *b"PBIV";
|
||||
const HEADER_SIZE: usize = 16;
|
||||
|
||||
#[inline]
|
||||
fn n_words(n: usize) -> usize { n.div_ceil(64) }
|
||||
fn n_words(n: usize) -> usize {
|
||||
n.div_ceil(64)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn n_bytes_for_words(n: usize) -> usize { n_words(n) * 8 }
|
||||
fn n_bytes_for_words(n: usize) -> usize {
|
||||
n_words(n) * 8
|
||||
}
|
||||
|
||||
// ── Reader ────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub struct PersistentBitVec {
|
||||
mmap: Mmap,
|
||||
n: usize,
|
||||
n: usize,
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
@@ -31,18 +35,31 @@ impl PersistentBitVec {
|
||||
pub fn open(path: &Path) -> io::Result<Self> {
|
||||
let mmap = unsafe { Mmap::map(&File::open(path)?)? };
|
||||
if mmap.len() < HEADER_SIZE {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "PBIV file too short"));
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"PBIV file too short",
|
||||
));
|
||||
}
|
||||
if &mmap[0..4] != &MAGIC {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "bad PBIV magic"));
|
||||
}
|
||||
let n = u64::from_le_bytes(mmap[8..16].try_into().unwrap()) as usize;
|
||||
Ok(Self { mmap, n, path: path.to_path_buf() })
|
||||
Ok(Self {
|
||||
mmap,
|
||||
n,
|
||||
path: path.to_path_buf(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &Path { &self.path }
|
||||
pub fn len(&self) -> usize { self.n }
|
||||
pub fn is_empty(&self) -> bool { self.n == 0 }
|
||||
pub fn path(&self) -> &Path {
|
||||
&self.path
|
||||
}
|
||||
pub fn len(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.n == 0
|
||||
}
|
||||
|
||||
pub fn get(&self, slot: usize) -> bool {
|
||||
(self.mmap[HEADER_SIZE + (slot >> 3)] >> (slot & 7)) & 1 != 0
|
||||
@@ -56,14 +73,17 @@ impl PersistentBitVec {
|
||||
// Bulk word view. SAFETY: mmap is page-aligned, HEADER_SIZE=16 is divisible by 8,
|
||||
// so &mmap[HEADER_SIZE] is u64-aligned. Slice length is n_words * 8 bytes.
|
||||
fn data_words(&self) -> &[u64] {
|
||||
let nw = n_words(self.n);
|
||||
let nw = n_words(self.n);
|
||||
let ptr = self.mmap[HEADER_SIZE..].as_ptr() as *const u64;
|
||||
unsafe { std::slice::from_raw_parts(ptr, nw) }
|
||||
}
|
||||
|
||||
pub fn count_ones(&self) -> u64 {
|
||||
// Padding bits in the last word are 0, so no masking needed.
|
||||
self.data_words().iter().map(|w| w.count_ones() as u64).sum()
|
||||
self.data_words()
|
||||
.iter()
|
||||
.map(|w| w.count_ones() as u64)
|
||||
.sum()
|
||||
}
|
||||
|
||||
pub fn count_zeros(&self) -> u64 {
|
||||
@@ -71,37 +91,56 @@ impl PersistentBitVec {
|
||||
}
|
||||
|
||||
pub fn jaccard_dist(&self, other: &PersistentBitVec) -> f64 {
|
||||
assert_eq!(self.n, other.n, "length mismatch");
|
||||
let (inter, union) = self.data_words().iter().zip(other.data_words()).fold(
|
||||
(0u64, 0u64),
|
||||
|(i, u), (&a, &b)| (i + (a & b).count_ones() as u64, u + (a | b).count_ones() as u64),
|
||||
);
|
||||
if union == 0 { return 0.0; }
|
||||
let (inter, union) = self.partial_jaccard_dist(other);
|
||||
if union == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
1.0 - inter as f64 / union as f64
|
||||
}
|
||||
|
||||
pub fn partial_jaccard_dist(&self, other: &PersistentBitVec) -> (u64, u64) {
|
||||
assert_eq!(self.n, other.n, "length mismatch");
|
||||
self.data_words()
|
||||
.iter()
|
||||
.zip(other.data_words())
|
||||
.fold((0u64, 0u64), |(i, u), (&a, &b)| {
|
||||
(
|
||||
i + (a & b).count_ones() as u64,
|
||||
u + (a | b).count_ones() as u64,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn hamming_dist(&self, other: &PersistentBitVec) -> u64 {
|
||||
assert_eq!(self.n, other.n, "length mismatch");
|
||||
self.data_words().iter().zip(other.data_words())
|
||||
self.data_words()
|
||||
.iter()
|
||||
.zip(other.data_words())
|
||||
.map(|(&a, &b)| (a ^ b).count_ones() as u64)
|
||||
.sum()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> BitIter<'_> {
|
||||
BitIter { bytes: self.data_bytes(), slot: 0, n: self.n }
|
||||
BitIter {
|
||||
bytes: self.data_bytes(),
|
||||
slot: 0,
|
||||
n: self.n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> IntoIterator for &'a PersistentBitVec {
|
||||
type Item = bool;
|
||||
type IntoIter = BitIter<'a>;
|
||||
fn into_iter(self) -> BitIter<'a> { self.iter() }
|
||||
fn into_iter(self) -> BitIter<'a> {
|
||||
self.iter()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BitIter<'a> {
|
||||
bytes: &'a [u8],
|
||||
slot: usize,
|
||||
n: usize,
|
||||
slot: usize,
|
||||
n: usize,
|
||||
}
|
||||
|
||||
impl ExactSizeIterator for BitIter<'_> {}
|
||||
@@ -110,7 +149,9 @@ impl Iterator for BitIter<'_> {
|
||||
type Item = bool;
|
||||
|
||||
fn next(&mut self) -> Option<bool> {
|
||||
if self.slot >= self.n { return None; }
|
||||
if self.slot >= self.n {
|
||||
return None;
|
||||
}
|
||||
let v = (self.bytes[self.slot >> 3] >> (self.slot & 7)) & 1 != 0;
|
||||
self.slot += 1;
|
||||
Some(v)
|
||||
@@ -126,17 +167,20 @@ impl Iterator for BitIter<'_> {
|
||||
|
||||
pub struct PersistentBitVecBuilder {
|
||||
mmap: MmapMut,
|
||||
n: usize,
|
||||
n: usize,
|
||||
}
|
||||
|
||||
impl PersistentBitVecBuilder {
|
||||
pub fn new(n: usize, path: &Path) -> io::Result<Self> {
|
||||
let file_size = HEADER_SIZE + n_bytes_for_words(n);
|
||||
let mut file = OpenOptions::new()
|
||||
.read(true).write(true).create(true).truncate(true)
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(path)?;
|
||||
file.write_all(&MAGIC)?;
|
||||
file.write_all(&[0u8; 4])?; // padding
|
||||
file.write_all(&[0u8; 4])?; // padding
|
||||
file.write_all(&(n as u64).to_le_bytes())?;
|
||||
file.seek(SeekFrom::Start(0))?;
|
||||
file.set_len(file_size as u64)?;
|
||||
@@ -152,8 +196,12 @@ impl PersistentBitVecBuilder {
|
||||
Ok(Self { mmap, n })
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize { self.n }
|
||||
pub fn is_empty(&self) -> bool { self.n == 0 }
|
||||
pub fn len(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.n == 0
|
||||
}
|
||||
|
||||
pub fn get(&self, slot: usize) -> bool {
|
||||
(self.mmap[HEADER_SIZE + (slot >> 3)] >> (slot & 7)) & 1 != 0
|
||||
@@ -161,53 +209,70 @@ impl PersistentBitVecBuilder {
|
||||
|
||||
pub fn set(&mut self, slot: usize, value: bool) {
|
||||
let byte = HEADER_SIZE + (slot >> 3);
|
||||
let bit = 1u8 << (slot & 7);
|
||||
if value { self.mmap[byte] |= bit; } else { self.mmap[byte] &= !bit; }
|
||||
let bit = 1u8 << (slot & 7);
|
||||
if value {
|
||||
self.mmap[byte] |= bit;
|
||||
} else {
|
||||
self.mmap[byte] &= !bit;
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: same alignment argument as PersistentBitVec::data_words.
|
||||
fn data_words_mut(&mut self) -> &mut [u64] {
|
||||
let nw = n_words(self.n);
|
||||
let nw = n_words(self.n);
|
||||
let ptr = self.mmap[HEADER_SIZE..].as_mut_ptr() as *mut u64;
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr, nw) }
|
||||
}
|
||||
|
||||
pub fn and(&mut self, other: &PersistentBitVec) {
|
||||
assert_eq!(self.n, other.n, "length mismatch");
|
||||
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) { *sw &= ow; }
|
||||
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) {
|
||||
*sw &= ow;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn or(&mut self, other: &PersistentBitVec) {
|
||||
assert_eq!(self.n, other.n, "length mismatch");
|
||||
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) { *sw |= ow; }
|
||||
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) {
|
||||
*sw |= ow;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn xor(&mut self, other: &PersistentBitVec) {
|
||||
assert_eq!(self.n, other.n, "length mismatch");
|
||||
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) { *sw ^= ow; }
|
||||
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) {
|
||||
*sw ^= ow;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn not(&mut self) {
|
||||
let rem = self.n % 64;
|
||||
let words = self.data_words_mut();
|
||||
for w in words.iter_mut() { *w ^= u64::MAX; }
|
||||
for w in words.iter_mut() {
|
||||
*w ^= u64::MAX;
|
||||
}
|
||||
// Zero padding bits in the last word so count_ones / jaccard remain correct.
|
||||
if rem != 0 {
|
||||
if let Some(last) = words.last_mut() { *last &= (1u64 << rem) - 1; }
|
||||
if let Some(last) = words.last_mut() {
|
||||
*last &= (1u64 << rem) - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a count vector to a bit vector: bit set iff count >= threshold.
|
||||
/// Fills u64 words directly from the count iterator — O(n), no bit-level set() overhead.
|
||||
pub fn build_from_counts(
|
||||
source: &PersistentCompactIntVec,
|
||||
source: &PersistentCompactIntVec,
|
||||
threshold: u32,
|
||||
path: &Path,
|
||||
path: &Path,
|
||||
) -> io::Result<Self> {
|
||||
let n = source.len();
|
||||
let file_size = HEADER_SIZE + n_bytes_for_words(n);
|
||||
let mut file = OpenOptions::new()
|
||||
.read(true).write(true).create(true).truncate(true)
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(path)?;
|
||||
file.write_all(&MAGIC)?;
|
||||
file.write_all(&[0u8; 4])?;
|
||||
@@ -217,7 +282,7 @@ impl PersistentBitVecBuilder {
|
||||
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
|
||||
|
||||
{
|
||||
let nw = n_words(n);
|
||||
let nw = n_words(n);
|
||||
let ptr = mmap[HEADER_SIZE..].as_mut_ptr() as *mut u64;
|
||||
let words = unsafe { std::slice::from_raw_parts_mut(ptr, nw) };
|
||||
for (slot, count) in source.iter().enumerate() {
|
||||
|
||||
Reference in New Issue
Block a user