Files
obikmer/src/obicompactvec/src/bitvec.rs
T

242 lines
8.3 KiB
Rust
Raw Normal View History

use std::fs::{self, File, OpenOptions};
use std::io::{self, Seek, SeekFrom, Write as _};
use std::path::{Path, PathBuf};
use memmap2::{Mmap, MmapMut};
use crate::reader::PersistentCompactIntVec;
/// Four-byte signature found at the very start of every file in this format.
const MAGIC: [u8; 4] = [b'P', b'B', b'I', b'V'];

/// Size in bytes of the fixed file header:
/// magic (4 bytes) + padding (4 bytes) + bit count `n` as a `u64` (8 bytes).
///
/// The bit payload begins right after the header. Because the header size is a
/// multiple of 8 and a memory mapping starts on a page boundary, the payload
/// can be viewed as `u64` words without any alignment fix-up.
const HEADER_SIZE: usize = 4 + 4 + 8;
/// Number of 64-bit words needed to hold `n` bits (rounded up).
#[inline]
fn n_words(n: usize) -> usize {
    let full_words = n / 64;
    let has_partial_word = n % 64 != 0;
    full_words + usize::from(has_partial_word)
}

/// Number of payload bytes occupied by `n` bits once padded to whole 64-bit words.
#[inline]
fn n_bytes_for_words(n: usize) -> usize {
    std::mem::size_of::<u64>() * n_words(n)
}
// ── Reader ────────────────────────────────────────────────────────────────────
/// Read-only, memory-mapped view of a bit-vector file in the `PBIV` format.
pub struct PersistentBitVec {
// Read-only mapping of the whole file, header included.
mmap: Mmap,
// Number of bits in the vector, as decoded from the file header.
n: usize,
// Path the mapping was opened from (kept so the file can be copied later).
path: PathBuf,
}
impl PersistentBitVec {
pub fn open(path: &Path) -> io::Result<Self> {
let mmap = unsafe { Mmap::map(&File::open(path)?)? };
if mmap.len() < HEADER_SIZE {
return Err(io::Error::new(io::ErrorKind::InvalidData, "PBIV file too short"));
}
if &mmap[0..4] != &MAGIC {
return Err(io::Error::new(io::ErrorKind::InvalidData, "bad PBIV magic"));
}
let n = u64::from_le_bytes(mmap[8..16].try_into().unwrap()) as usize;
Ok(Self { mmap, n, path: path.to_path_buf() })
}
pub fn path(&self) -> &Path { &self.path }
pub fn len(&self) -> usize { self.n }
pub fn is_empty(&self) -> bool { self.n == 0 }
pub fn get(&self, slot: usize) -> bool {
(self.mmap[HEADER_SIZE + (slot >> 3)] >> (slot & 7)) & 1 != 0
}
// Used by iter() and get(): exact byte window, no padding.
fn data_bytes(&self) -> &[u8] {
&self.mmap[HEADER_SIZE..HEADER_SIZE + self.n.div_ceil(8)]
}
// Bulk word view. SAFETY: mmap is page-aligned, HEADER_SIZE=16 is divisible by 8,
// so &mmap[HEADER_SIZE] is u64-aligned. Slice length is n_words * 8 bytes.
fn data_words(&self) -> &[u64] {
let nw = n_words(self.n);
let ptr = self.mmap[HEADER_SIZE..].as_ptr() as *const u64;
unsafe { std::slice::from_raw_parts(ptr, nw) }
}
pub fn count_ones(&self) -> u64 {
// Padding bits in the last word are 0, so no masking needed.
self.data_words().iter().map(|w| w.count_ones() as u64).sum()
}
pub fn count_zeros(&self) -> u64 {
self.n as u64 - self.count_ones()
}
pub fn jaccard_dist(&self, other: &PersistentBitVec) -> f64 {
assert_eq!(self.n, other.n, "length mismatch");
let (inter, union) = self.data_words().iter().zip(other.data_words()).fold(
(0u64, 0u64),
|(i, u), (&a, &b)| (i + (a & b).count_ones() as u64, u + (a | b).count_ones() as u64),
);
if union == 0 { return 0.0; }
1.0 - inter as f64 / union as f64
}
pub fn hamming_dist(&self, other: &PersistentBitVec) -> u64 {
assert_eq!(self.n, other.n, "length mismatch");
self.data_words().iter().zip(other.data_words())
.map(|(&a, &b)| (a ^ b).count_ones() as u64)
.sum()
}
pub fn iter(&self) -> BitIter<'_> {
BitIter { bytes: self.data_bytes(), slot: 0, n: self.n }
}
}
impl<'a> IntoIterator for &'a PersistentBitVec {
type Item = bool;
type IntoIter = BitIter<'a>;
fn into_iter(self) -> BitIter<'a> { self.iter() }
}
/// Iterator over the first `n` bits of a packed byte slice.
///
/// Bit `i` is read from byte `i / 8` at bit position `i % 8`
/// (least-significant bit first).
pub struct BitIter<'a> {
    /// Packed bit payload.
    bytes: &'a [u8],
    /// Index of the next bit to yield.
    slot: usize,
    /// Total number of bits to yield.
    n: usize,
}

// The remaining length is always known exactly (see `size_hint`).
impl ExactSizeIterator for BitIter<'_> {}

impl Iterator for BitIter<'_> {
    type Item = bool;

    /// Yields the next bit, or `None` once all `n` bits have been produced.
    fn next(&mut self) -> Option<bool> {
        let current = self.slot;
        if current < self.n {
            // Read first, advance second: a failed index leaves the cursor untouched.
            let byte = self.bytes[current / 8];
            self.slot = current + 1;
            let mask = 1u8 << (current % 8);
            Some((byte & mask) != 0)
        } else {
            None
        }
    }

    /// Exact number of bits still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.n - self.slot;
        (remaining, Some(remaining))
    }
}
// ── Builder ───────────────────────────────────────────────────────────────────
/// Writable, memory-mapped bit vector backed by a file in the `PBIV` format.
pub struct PersistentBitVecBuilder {
// Read-write mapping of the whole file, header included.
mmap: MmapMut,
// Number of bits in the vector.
n: usize,
}
impl PersistentBitVecBuilder {
    /// Creates (or truncates) the file at `path` as an `n`-bit vector with every
    /// bit cleared, and maps it read-write.
    ///
    /// The file consists of the 16-byte header followed by the payload padded to
    /// whole 64-bit words.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating, writing, resizing or mapping
    /// the file.
    pub fn new(n: usize, path: &Path) -> io::Result<Self> {
        let file_size = HEADER_SIZE + n_bytes_for_words(n);
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(true)
            .open(path)?;
        file.write_all(&MAGIC)?;
        file.write_all(&[0u8; 4])?; // padding
        file.write_all(&(n as u64).to_le_bytes())?;
        file.seek(SeekFrom::Start(0))?;
        // Growing the freshly truncated file zero-fills the payload, so all bits
        // (padding included) start at 0.
        file.set_len(file_size as u64)?;
        // SAFETY: the file was just created/truncated by this call; soundness relies
        // on no other party resizing or rewriting it while it is mapped.
        let mmap = unsafe { MmapMut::map_mut(&file)? };
        Ok(Self { mmap, n })
    }

    /// Copies the file behind `source` to `path` and opens the copy for editing.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while copying, opening or mapping, and an
    /// [`io::ErrorKind::InvalidData`] error if the copy is too short to hold
    /// `source.len()` bits.
    pub fn build_from(source: &PersistentBitVec, path: &Path) -> io::Result<Self> {
        fs::copy(source.path(), path)?;
        let file = OpenOptions::new().read(true).write(true).open(path)?;
        // SAFETY: same caveat as in `new` — the file must not be resized or
        // rewritten externally while mapped.
        let mmap = unsafe { MmapMut::map_mut(&file)? };
        let n = source.len();
        // `data_words_mut` reinterprets the payload as `n_words(n)` u64 values;
        // refuse a copy that is too short rather than build an out-of-bounds slice.
        if mmap.len() < HEADER_SIZE + n_bytes_for_words(n) {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "PBIV file truncated"));
        }
        Ok(Self { mmap, n })
    }

    /// Number of bits in the vector.
    pub fn len(&self) -> usize {
        self.n
    }

    /// Returns `true` when the vector holds no bits.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Returns the bit stored at `slot` (byte `slot / 8`, bit `slot % 8`).
    ///
    /// # Panics
    ///
    /// Panics if the addressed byte lies outside the mapped file.
    pub fn get(&self, slot: usize) -> bool {
        (self.mmap[HEADER_SIZE + (slot >> 3)] >> (slot & 7)) & 1 != 0
    }

    /// Sets the bit at `slot` to `value`.
    ///
    /// # Panics
    ///
    /// Panics if `slot >= len()`. Writing into the padding of the last word would
    /// silently break the "padding bits are zero" invariant that the word-level
    /// counting operations rely on.
    pub fn set(&mut self, slot: usize, value: bool) {
        assert!(
            slot < self.n,
            "slot {slot} out of range for bit vector of length {}",
            self.n
        );
        let byte = HEADER_SIZE + (slot >> 3);
        let bit = 1u8 << (slot & 7);
        if value {
            self.mmap[byte] |= bit;
        } else {
            self.mmap[byte] &= !bit;
        }
    }

    // Mutable bulk view of the payload as 64-bit words (padded last word included).
    fn data_words_mut(&mut self) -> &mut [u64] {
        let word_count = n_words(self.n);
        let payload = &mut self.mmap[HEADER_SIZE..];
        debug_assert!(payload.len() >= word_count * std::mem::size_of::<u64>());
        // SAFETY:
        // * alignment: the mapping base is page-aligned and HEADER_SIZE is a
        //   multiple of 8, so the payload pointer is u64-aligned;
        // * length: both constructors guarantee at least `word_count * 8` payload
        //   bytes (`new` sizes the file, `build_from` checks the copy);
        // * exclusivity: the slice borrows `self` mutably.
        unsafe { std::slice::from_raw_parts_mut(payload.as_mut_ptr().cast::<u64>(), word_count) }
    }

    // Applies `op` word by word: `self[i] = op(self[i], other[i])`.
    fn combine_words(&mut self, other: &PersistentBitVec, op: impl Fn(u64, u64) -> u64) {
        assert_eq!(self.n, other.n, "length mismatch");
        for (mine, &theirs) in self.data_words_mut().iter_mut().zip(other.data_words()) {
            *mine = op(*mine, theirs);
        }
    }

    /// In-place bitwise AND with `other`.
    ///
    /// # Panics
    ///
    /// Panics if the lengths differ.
    pub fn and(&mut self, other: &PersistentBitVec) {
        self.combine_words(other, |a, b| a & b);
    }

    /// In-place bitwise OR with `other`.
    ///
    /// # Panics
    ///
    /// Panics if the lengths differ.
    pub fn or(&mut self, other: &PersistentBitVec) {
        self.combine_words(other, |a, b| a | b);
    }

    /// In-place bitwise XOR with `other`.
    ///
    /// # Panics
    ///
    /// Panics if the lengths differ.
    pub fn xor(&mut self, other: &PersistentBitVec) {
        self.combine_words(other, |a, b| a ^ b);
    }

    /// In-place bitwise NOT of the first `len()` bits; padding bits stay zero.
    pub fn not(&mut self) {
        let rem = self.n % 64;
        let words = self.data_words_mut();
        for w in words.iter_mut() {
            *w = !*w;
        }
        // Zero padding bits in the last word so count_ones / jaccard remain correct.
        // `to_le` keeps the mask aligned with the byte-addressed layout used by
        // `get`/`set` on big-endian hosts (it is a no-op on little-endian ones).
        if rem != 0 {
            if let Some(last) = words.last_mut() {
                *last &= ((1u64 << rem) - 1).to_le();
            }
        }
    }

    /// Convert a count vector to a bit vector: bit set iff count >= threshold.
    ///
    /// Creates the file at `path` exactly as [`Self::new`] does, then fills the
    /// u64 words directly from the count iterator in a single pass.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while creating or mapping the file.
    pub fn build_from_counts(
        source: &PersistentCompactIntVec,
        threshold: u32,
        path: &Path,
    ) -> io::Result<Self> {
        let mut builder = Self::new(source.len(), path)?;
        let words = builder.data_words_mut();
        for (slot, count) in source.iter().enumerate() {
            if count >= threshold {
                // `to_le` makes the word-level write land on the same byte/bit that
                // `get` reads, regardless of host endianness.
                words[slot >> 6] |= (1u64 << (slot & 63)).to_le();
            }
        }
        Ok(builder)
    }

    /// Convert a count vector to a presence/absence bit vector (threshold = 1).
    pub fn build_from_presence(source: &PersistentCompactIntVec, path: &Path) -> io::Result<Self> {
        Self::build_from_counts(source, 1, path)
    }

    /// Consumes the builder and flushes the mapping, returning the flush result.
    pub fn close(self) -> io::Result<()> {
        self.mmap.flush()
    }
}