242 lines
8.3 KiB
Rust
242 lines
8.3 KiB
Rust
|
|
use std::fs::{self, File, OpenOptions};
|
||
|
|
use std::io::{self, Seek, SeekFrom, Write as _};
|
||
|
|
use std::path::{Path, PathBuf};
|
||
|
|
|
||
|
|
use memmap2::{Mmap, MmapMut};
|
||
|
|
|
||
|
|
use crate::reader::PersistentCompactIntVec;
|
||
|
|
|
||
|
|
/// File signature stored in the first four bytes of every PBIV file.
const MAGIC: [u8; 4] = [b'P', b'B', b'I', b'V'];

// Header layout: magic (4 bytes) + padding (4 bytes) + n (8 bytes) = 16 bytes.
// The bit payload therefore begins at offset 16. Because the mmap base is
// page-aligned and 16 % 8 == 0, the payload is u64-aligned.
const HEADER_SIZE: usize = 4 + 4 + 8;
|
||
|
|
|
||
|
|
/// Number of 64-bit words needed to hold `n` bits (rounded up).
#[inline]
fn n_words(n: usize) -> usize {
    let full_words = n / 64;
    let has_partial_word = n % 64 != 0;
    full_words + usize::from(has_partial_word)
}
|
||
|
|
|
||
|
|
#[inline]
|
||
|
|
fn n_bytes_for_words(n: usize) -> usize { n_words(n) * 8 }
|
||
|
|
|
||
|
|
// ── Reader ────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
pub struct PersistentBitVec {
|
||
|
|
mmap: Mmap,
|
||
|
|
n: usize,
|
||
|
|
path: PathBuf,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PersistentBitVec {
|
||
|
|
pub fn open(path: &Path) -> io::Result<Self> {
|
||
|
|
let mmap = unsafe { Mmap::map(&File::open(path)?)? };
|
||
|
|
if mmap.len() < HEADER_SIZE {
|
||
|
|
return Err(io::Error::new(io::ErrorKind::InvalidData, "PBIV file too short"));
|
||
|
|
}
|
||
|
|
if &mmap[0..4] != &MAGIC {
|
||
|
|
return Err(io::Error::new(io::ErrorKind::InvalidData, "bad PBIV magic"));
|
||
|
|
}
|
||
|
|
let n = u64::from_le_bytes(mmap[8..16].try_into().unwrap()) as usize;
|
||
|
|
Ok(Self { mmap, n, path: path.to_path_buf() })
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn path(&self) -> &Path { &self.path }
|
||
|
|
pub fn len(&self) -> usize { self.n }
|
||
|
|
pub fn is_empty(&self) -> bool { self.n == 0 }
|
||
|
|
|
||
|
|
pub fn get(&self, slot: usize) -> bool {
|
||
|
|
(self.mmap[HEADER_SIZE + (slot >> 3)] >> (slot & 7)) & 1 != 0
|
||
|
|
}
|
||
|
|
|
||
|
|
// Used by iter() and get(): exact byte window, no padding.
|
||
|
|
fn data_bytes(&self) -> &[u8] {
|
||
|
|
&self.mmap[HEADER_SIZE..HEADER_SIZE + self.n.div_ceil(8)]
|
||
|
|
}
|
||
|
|
|
||
|
|
// Bulk word view. SAFETY: mmap is page-aligned, HEADER_SIZE=16 is divisible by 8,
|
||
|
|
// so &mmap[HEADER_SIZE] is u64-aligned. Slice length is n_words * 8 bytes.
|
||
|
|
fn data_words(&self) -> &[u64] {
|
||
|
|
let nw = n_words(self.n);
|
||
|
|
let ptr = self.mmap[HEADER_SIZE..].as_ptr() as *const u64;
|
||
|
|
unsafe { std::slice::from_raw_parts(ptr, nw) }
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn count_ones(&self) -> u64 {
|
||
|
|
// Padding bits in the last word are 0, so no masking needed.
|
||
|
|
self.data_words().iter().map(|w| w.count_ones() as u64).sum()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn count_zeros(&self) -> u64 {
|
||
|
|
self.n as u64 - self.count_ones()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn jaccard_dist(&self, other: &PersistentBitVec) -> f64 {
|
||
|
|
assert_eq!(self.n, other.n, "length mismatch");
|
||
|
|
let (inter, union) = self.data_words().iter().zip(other.data_words()).fold(
|
||
|
|
(0u64, 0u64),
|
||
|
|
|(i, u), (&a, &b)| (i + (a & b).count_ones() as u64, u + (a | b).count_ones() as u64),
|
||
|
|
);
|
||
|
|
if union == 0 { return 0.0; }
|
||
|
|
1.0 - inter as f64 / union as f64
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn hamming_dist(&self, other: &PersistentBitVec) -> u64 {
|
||
|
|
assert_eq!(self.n, other.n, "length mismatch");
|
||
|
|
self.data_words().iter().zip(other.data_words())
|
||
|
|
.map(|(&a, &b)| (a ^ b).count_ones() as u64)
|
||
|
|
.sum()
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn iter(&self) -> BitIter<'_> {
|
||
|
|
BitIter { bytes: self.data_bytes(), slot: 0, n: self.n }
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
impl<'a> IntoIterator for &'a PersistentBitVec {
|
||
|
|
type Item = bool;
|
||
|
|
type IntoIter = BitIter<'a>;
|
||
|
|
fn into_iter(self) -> BitIter<'a> { self.iter() }
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Iterator over the bits of a byte slice, LSB-first within each byte.
pub struct BitIter<'a> {
    /// Packed bits; bit `i` lives in `bytes[i / 8]` at position `i % 8`.
    bytes: &'a [u8],
    /// Index of the next bit to yield.
    slot: usize,
    /// Total number of bits to yield.
    n: usize,
}

impl Iterator for BitIter<'_> {
    type Item = bool;

    fn next(&mut self) -> Option<bool> {
        let index = self.slot;
        if index < self.n {
            let byte = self.bytes[index / 8];
            self.slot = index + 1;
            Some((byte & (1u8 << (index % 8))) != 0)
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.n - self.slot;
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so the default `len()` is correct.
impl ExactSizeIterator for BitIter<'_> {}
|
||
|
|
|
||
|
|
// ── Builder ───────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
pub struct PersistentBitVecBuilder {
|
||
|
|
mmap: MmapMut,
|
||
|
|
n: usize,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl PersistentBitVecBuilder {
|
||
|
|
pub fn new(n: usize, path: &Path) -> io::Result<Self> {
|
||
|
|
let file_size = HEADER_SIZE + n_bytes_for_words(n);
|
||
|
|
let mut file = OpenOptions::new()
|
||
|
|
.read(true).write(true).create(true).truncate(true)
|
||
|
|
.open(path)?;
|
||
|
|
file.write_all(&MAGIC)?;
|
||
|
|
file.write_all(&[0u8; 4])?; // padding
|
||
|
|
file.write_all(&(n as u64).to_le_bytes())?;
|
||
|
|
file.seek(SeekFrom::Start(0))?;
|
||
|
|
file.set_len(file_size as u64)?;
|
||
|
|
let mmap = unsafe { MmapMut::map_mut(&file)? };
|
||
|
|
Ok(Self { mmap, n })
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn build_from(source: &PersistentBitVec, path: &Path) -> io::Result<Self> {
|
||
|
|
fs::copy(source.path(), path)?;
|
||
|
|
let file = OpenOptions::new().read(true).write(true).open(path)?;
|
||
|
|
let mmap = unsafe { MmapMut::map_mut(&file)? };
|
||
|
|
let n = source.len();
|
||
|
|
Ok(Self { mmap, n })
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn len(&self) -> usize { self.n }
|
||
|
|
pub fn is_empty(&self) -> bool { self.n == 0 }
|
||
|
|
|
||
|
|
pub fn get(&self, slot: usize) -> bool {
|
||
|
|
(self.mmap[HEADER_SIZE + (slot >> 3)] >> (slot & 7)) & 1 != 0
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn set(&mut self, slot: usize, value: bool) {
|
||
|
|
let byte = HEADER_SIZE + (slot >> 3);
|
||
|
|
let bit = 1u8 << (slot & 7);
|
||
|
|
if value { self.mmap[byte] |= bit; } else { self.mmap[byte] &= !bit; }
|
||
|
|
}
|
||
|
|
|
||
|
|
// SAFETY: same alignment argument as PersistentBitVec::data_words.
|
||
|
|
fn data_words_mut(&mut self) -> &mut [u64] {
|
||
|
|
let nw = n_words(self.n);
|
||
|
|
let ptr = self.mmap[HEADER_SIZE..].as_mut_ptr() as *mut u64;
|
||
|
|
unsafe { std::slice::from_raw_parts_mut(ptr, nw) }
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn and(&mut self, other: &PersistentBitVec) {
|
||
|
|
assert_eq!(self.n, other.n, "length mismatch");
|
||
|
|
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) { *sw &= ow; }
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn or(&mut self, other: &PersistentBitVec) {
|
||
|
|
assert_eq!(self.n, other.n, "length mismatch");
|
||
|
|
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) { *sw |= ow; }
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn xor(&mut self, other: &PersistentBitVec) {
|
||
|
|
assert_eq!(self.n, other.n, "length mismatch");
|
||
|
|
for (sw, &ow) in self.data_words_mut().iter_mut().zip(other.data_words()) { *sw ^= ow; }
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn not(&mut self) {
|
||
|
|
let rem = self.n % 64;
|
||
|
|
let words = self.data_words_mut();
|
||
|
|
for w in words.iter_mut() { *w ^= u64::MAX; }
|
||
|
|
// Zero padding bits in the last word so count_ones / jaccard remain correct.
|
||
|
|
if rem != 0 {
|
||
|
|
if let Some(last) = words.last_mut() { *last &= (1u64 << rem) - 1; }
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Convert a count vector to a bit vector: bit set iff count >= threshold.
|
||
|
|
/// Fills u64 words directly from the count iterator — O(n), no bit-level set() overhead.
|
||
|
|
pub fn build_from_counts(
|
||
|
|
source: &PersistentCompactIntVec,
|
||
|
|
threshold: u32,
|
||
|
|
path: &Path,
|
||
|
|
) -> io::Result<Self> {
|
||
|
|
let n = source.len();
|
||
|
|
let file_size = HEADER_SIZE + n_bytes_for_words(n);
|
||
|
|
let mut file = OpenOptions::new()
|
||
|
|
.read(true).write(true).create(true).truncate(true)
|
||
|
|
.open(path)?;
|
||
|
|
file.write_all(&MAGIC)?;
|
||
|
|
file.write_all(&[0u8; 4])?;
|
||
|
|
file.write_all(&(n as u64).to_le_bytes())?;
|
||
|
|
file.seek(SeekFrom::Start(0))?;
|
||
|
|
file.set_len(file_size as u64)?;
|
||
|
|
let mut mmap = unsafe { MmapMut::map_mut(&file)? };
|
||
|
|
|
||
|
|
{
|
||
|
|
let nw = n_words(n);
|
||
|
|
let ptr = mmap[HEADER_SIZE..].as_mut_ptr() as *mut u64;
|
||
|
|
let words = unsafe { std::slice::from_raw_parts_mut(ptr, nw) };
|
||
|
|
for (slot, count) in source.iter().enumerate() {
|
||
|
|
if count >= threshold {
|
||
|
|
words[slot >> 6] |= 1u64 << (slot & 63);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
Ok(Self { mmap, n })
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Convert a count vector to a presence/absence bit vector (threshold = 1).
|
||
|
|
pub fn build_from_presence(source: &PersistentCompactIntVec, path: &Path) -> io::Result<Self> {
|
||
|
|
Self::build_from_counts(source, 1, path)
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn close(self) -> io::Result<()> {
|
||
|
|
self.mmap.flush()
|
||
|
|
}
|
||
|
|
}
|