perf: optimize vec arithmetic and add overflow tests
Refactor `cmp_scalar`, `min`, `max`, `add`, and `diff` to operate directly on the primary byte array, deferring overflow slot resolution to a secondary pass. This eliminates HashMap lookups in the hot path and enables SIMD vectorization. Add six unit tests to validate correct promotion and demotion between storage slots when values cross the 255 threshold.
This commit is contained in:
@@ -305,6 +305,110 @@ fn count_bits_skips_zero_words() {
|
|||||||
assert_eq!(count.get(127), 1);
|
assert_eq!(count.get(127), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── min / max / add / diff — overflow edge cases ──────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn miv_min_overflow_edges() {
|
||||||
|
// [300, 50, 400, 300] min [50, 300, 500, 200]
|
||||||
|
// slot 0: self=overflow(300), other=primary(50) → 50 (overflow removed)
|
||||||
|
// slot 1: self=primary(50), other=overflow(300) → 50 (no overflow created)
|
||||||
|
// slot 2: self=overflow(400), other=overflow(500) → 400 (overflow updated)
|
||||||
|
// slot 3: self=overflow(300), other=primary(200) → 200 (overflow removed, 200 < 255)
|
||||||
|
let mut a = MemoryIntVec::new(4);
|
||||||
|
a.set(0, 300); a.set(1, 50); a.set(2, 400); a.set(3, 300);
|
||||||
|
let mut b = MemoryIntVec::new(4);
|
||||||
|
b.set(0, 50); b.set(1, 300); b.set(2, 500); b.set(3, 200);
|
||||||
|
IntSliceMut::min(&mut a, &b);
|
||||||
|
assert_eq!(a.get(0), 50);
|
||||||
|
assert_eq!(a.get(1), 50);
|
||||||
|
assert_eq!(a.get(2), 400);
|
||||||
|
assert_eq!(a.get(3), 200);
|
||||||
|
// Only slot 2 should still have an overflow entry.
|
||||||
|
let ov: std::collections::HashMap<usize, u32> = a.overflow_entries().collect();
|
||||||
|
assert_eq!(ov.len(), 1);
|
||||||
|
assert_eq!(ov[&2], 400);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn miv_max_overflow_edges() {
|
||||||
|
// [50, 300, 100, 400] max [300, 50, 500, 200]
|
||||||
|
// slot 0: self=primary(50), other=overflow(300) → 300 (overflow created)
|
||||||
|
// slot 1: self=overflow(300), other=primary(50) → 300 (overflow unchanged)
|
||||||
|
// slot 2: self=primary(100), other=overflow(500) → 500 (overflow created)
|
||||||
|
// slot 3: self=overflow(400), other=overflow(200) → 400 (overflow unchanged, 200 < 255 wait...)
|
||||||
|
// Wait — 200 < 255 so other slot 3 is NOT overflow. Correct: max(400, 200) = 400.
|
||||||
|
let mut a = MemoryIntVec::new(4);
|
||||||
|
a.set(0, 50); a.set(1, 300); a.set(2, 100); a.set(3, 400);
|
||||||
|
let mut b = MemoryIntVec::new(4);
|
||||||
|
b.set(0, 300); b.set(1, 50); b.set(2, 500); b.set(3, 200);
|
||||||
|
IntSliceMut::max(&mut a, &b);
|
||||||
|
assert_eq!(a.get(0), 300);
|
||||||
|
assert_eq!(a.get(1), 300);
|
||||||
|
assert_eq!(a.get(2), 500);
|
||||||
|
assert_eq!(a.get(3), 400);
|
||||||
|
let ov: std::collections::HashMap<usize, u32> = a.overflow_entries().collect();
|
||||||
|
assert_eq!(ov.len(), 4); // all four results >= 255
|
||||||
|
assert_eq!(ov[&0], 300);
|
||||||
|
assert_eq!(ov[&1], 300);
|
||||||
|
assert_eq!(ov[&2], 500);
|
||||||
|
assert_eq!(ov[&3], 400);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn miv_add_overflow_edges() {
|
||||||
|
// [300, 50, 400, 200] + [50, 300, 200, 200]
|
||||||
|
// slot 0: self=overflow(300), other=primary(50) → 350 (overflow updated)
|
||||||
|
// slot 1: self=primary(50), other=overflow(300) → 350 (overflow created from primary)
|
||||||
|
// slot 2: self=overflow(400), other=overflow(200... wait 200 < 255)
|
||||||
|
// other slot 2 is primary(200); 400+200=600 (overflow updated)
|
||||||
|
// slot 3: self=primary(200), other=primary(200) → 400 (overflow created, 400 >= 255)
|
||||||
|
let mut a = MemoryIntVec::new(4);
|
||||||
|
a.set(0, 300); a.set(1, 50); a.set(2, 400); a.set(3, 200);
|
||||||
|
let mut b = MemoryIntVec::new(4);
|
||||||
|
b.set(0, 50); b.set(1, 300); b.set(2, 200); b.set(3, 200);
|
||||||
|
a.add(&b);
|
||||||
|
assert_eq!(a.get(0), 350);
|
||||||
|
assert_eq!(a.get(1), 350);
|
||||||
|
assert_eq!(a.get(2), 600);
|
||||||
|
assert_eq!(a.get(3), 400);
|
||||||
|
let ov: std::collections::HashMap<usize, u32> = a.overflow_entries().collect();
|
||||||
|
assert_eq!(ov.len(), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn miv_add_both_overflow() {
|
||||||
|
// [300] + [400] = [700]
|
||||||
|
let mut a = MemoryIntVec::new(1);
|
||||||
|
a.set(0, 300);
|
||||||
|
let mut b = MemoryIntVec::new(1);
|
||||||
|
b.set(0, 400);
|
||||||
|
a.add(&b);
|
||||||
|
assert_eq!(a.get(0), 700);
|
||||||
|
let ov: std::collections::HashMap<usize, u32> = a.overflow_entries().collect();
|
||||||
|
assert_eq!(ov[&0], 700);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn miv_diff_overflow_edges() {
|
||||||
|
// [300, 400, 400, 50] - [100, 50, 350, 300]
|
||||||
|
// slot 0: self=overflow(300), other=primary(100) → 200 (overflow removed, 200 < 255)
|
||||||
|
// slot 1: self=overflow(400), other=primary(50) → 350 (overflow updated, 350 >= 255)
|
||||||
|
// slot 2: self=overflow(400), other=overflow(350) → 50 (overflow removed, 50 < 255)
|
||||||
|
// slot 3: self=primary(50), other=overflow(300) → 0 (saturating, stays primary)
|
||||||
|
let mut a = MemoryIntVec::new(4);
|
||||||
|
a.set(0, 300); a.set(1, 400); a.set(2, 400); a.set(3, 50);
|
||||||
|
let mut b = MemoryIntVec::new(4);
|
||||||
|
b.set(0, 100); b.set(1, 50); b.set(2, 350); b.set(3, 300);
|
||||||
|
a.diff(&b);
|
||||||
|
assert_eq!(a.get(0), 200);
|
||||||
|
assert_eq!(a.get(1), 350);
|
||||||
|
assert_eq!(a.get(2), 50);
|
||||||
|
assert_eq!(a.get(3), 0);
|
||||||
|
let ov: std::collections::HashMap<usize, u32> = a.overflow_entries().collect();
|
||||||
|
assert_eq!(ov.len(), 1); // only slot 1 remains overflow
|
||||||
|
assert_eq!(ov[&1], 350);
|
||||||
|
}
|
||||||
|
|
||||||
// ── Comparison operators ──────────────────────────────────────────────────────
|
// ── Comparison operators ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -340,6 +444,27 @@ fn cmp_leq() {
|
|||||||
assert!(bv.get(0)); assert!(bv.get(1)); assert!(!bv.get(2)); assert!(bv.get(3));
|
assert!(bv.get(0)); assert!(bv.get(1)); assert!(!bv.get(2)); assert!(bv.get(3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cmp_scalar_with_overflow() {
|
||||||
|
// Slots: [10, 1000, 50, 500, 0]
|
||||||
|
// geq(100): slots 1 (1000) and 3 (500) → both overflow, must qualify
|
||||||
|
// lt(500): slots 0 (10), 2 (50), 4 (0) → primary; slot 1 (1000) → no; slot 3 (500) → no
|
||||||
|
// geq(2000): only slot 1 (1000) fails, no slot qualifies
|
||||||
|
let mut v = MemoryIntVec::new(5);
|
||||||
|
v.set(0, 10); v.set(1, 1000); v.set(2, 50); v.set(3, 500); v.set(4, 0);
|
||||||
|
|
||||||
|
let bv = v.geq(100);
|
||||||
|
assert!(!bv.get(0)); assert!(bv.get(1)); assert!(!bv.get(2));
|
||||||
|
assert!(bv.get(3)); assert!(!bv.get(4));
|
||||||
|
|
||||||
|
let bv = v.lt(500);
|
||||||
|
assert!(bv.get(0)); assert!(!bv.get(1)); assert!(bv.get(2));
|
||||||
|
assert!(!bv.get(3)); assert!(bv.get(4));
|
||||||
|
|
||||||
|
let bv = v.geq(2000);
|
||||||
|
assert!(!(0..5).any(|s| bv.get(s)));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn filter_pattern() {
|
fn filter_pattern() {
|
||||||
// Typical filter: ingroup >= min_count AND outgroup <= max_outgroup
|
// Typical filter: ingroup >= min_count AND outgroup <= max_outgroup
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use ndarray::{Array1, Array2};
|
use ndarray::{Array1, Array2};
|
||||||
|
|
||||||
// ── BitSlice / BitSliceMut ────────────────────────────────────────────────────
|
// ── BitSlice / BitSliceMut ────────────────────────────────────────────────────
|
||||||
@@ -104,8 +106,18 @@ pub trait IntSlice {
|
|||||||
fn cmp_scalar(&self, pred: impl Fn(u32) -> bool) -> MemoryBitVec {
|
fn cmp_scalar(&self, pred: impl Fn(u32) -> bool) -> MemoryBitVec {
|
||||||
let n = self.len();
|
let n = self.len();
|
||||||
let mut words = vec![0u64; n.div_ceil(64)];
|
let mut words = vec![0u64; n.div_ceil(64)];
|
||||||
|
let primary = self.primary_bytes();
|
||||||
|
// Pass 1: byte scan — no HashMap access, vectorisable for simple predicates.
|
||||||
|
// Overflow slots (b == 255) are left as 0 and fixed in pass 2.
|
||||||
for s in 0..n {
|
for s in 0..n {
|
||||||
if pred(self.get(s)) { words[s >> 6] |= 1u64 << (s & 63); }
|
let b = primary[s];
|
||||||
|
if b < 255 && pred(b as u32) {
|
||||||
|
words[s >> 6] |= 1u64 << (s & 63);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Pass 2: fix up overflow slots — O(k), negligible.
|
||||||
|
for (s, val) in self.overflow_entries() {
|
||||||
|
if pred(val) { words[s >> 6] |= 1u64 << (s & 63); }
|
||||||
}
|
}
|
||||||
MemoryBitVec::from_words(words, n)
|
MemoryBitVec::from_words(words, n)
|
||||||
}
|
}
|
||||||
@@ -146,25 +158,86 @@ pub trait IntSliceMut: IntSlice {
|
|||||||
|
|
||||||
fn min<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
fn min<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
||||||
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
||||||
for s in 0..other.len() { self.set(s, self.get(s).min(other.get(s))); }
|
// Snapshot both overflow sets (O(k), tiny) before mutating self.
|
||||||
|
// 255 = +∞ on u8, so byte-level min is correct in all cases except
|
||||||
|
// both-overflow: only those slots need a fixup pass.
|
||||||
|
let self_ov: Vec<(usize, u32)> = self.overflow_entries().collect();
|
||||||
|
let other_ov: HashMap<usize, u32> = other.overflow_entries().collect();
|
||||||
|
self.clear_overflow();
|
||||||
|
// Pass 1 — SIMD-vectorizable byte min over the full primary array.
|
||||||
|
for (a, &b) in self.primary_bytes_mut().iter_mut().zip(other.primary_bytes()) {
|
||||||
|
if b < *a { *a = b; }
|
||||||
|
}
|
||||||
|
// Pass 2 — fixup slots where BOTH sides were overflow (primary = 255 after pass 1,
|
||||||
|
// but the overflow value may have changed). Slots where only self was overflow are
|
||||||
|
// already correct: pass 1 wrote other.primary[slot] < 255 and clear_overflow removed
|
||||||
|
// the stale entry.
|
||||||
|
for (slot, self_val) in self_ov {
|
||||||
|
if let Some(&other_val) = other_ov.get(&slot) {
|
||||||
|
self.set(slot, self_val.min(other_val));
|
||||||
|
}
|
||||||
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
fn max<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
||||||
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
||||||
for s in 0..other.len() { self.set(s, self.get(s).max(other.get(s))); }
|
// Pre-pass — process other's overflow entries BEFORE the byte pass.
|
||||||
|
// After the byte pass, self.primary[slot] = 255 for all slots in other_ov,
|
||||||
|
// making it impossible to recover the original self value; we need it now.
|
||||||
|
for (slot, other_val) in other.overflow_entries() {
|
||||||
|
let self_val = self.get(slot);
|
||||||
|
self.set(slot, self_val.max(other_val));
|
||||||
|
}
|
||||||
|
// Pass 1 — SIMD-vectorizable byte max over the full primary array.
|
||||||
|
// 255 = +∞ on u8 → max(a, 255) = 255 is the correct sentinel for all
|
||||||
|
// overflow slots, whether handled by the pre-pass or already in self.
|
||||||
|
for (a, &b) in self.primary_bytes_mut().iter_mut().zip(other.primary_bytes()) {
|
||||||
|
if b > *a { *a = b; }
|
||||||
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
fn add<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
||||||
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
||||||
for s in 0..other.len() { self.set(s, self.get(s).saturating_add(other.get(s))); }
|
let n = self.len();
|
||||||
|
for s in 0..n {
|
||||||
|
// Read both primary bytes first — u8 is Copy, borrows released immediately.
|
||||||
|
let sb = self.primary_bytes()[s];
|
||||||
|
let ob = other.primary_bytes()[s];
|
||||||
|
if sb < 255 && ob < 255 {
|
||||||
|
// Hot path: no overflow lookup, no HashMap write in the common case.
|
||||||
|
let sum = sb as u32 + ob as u32;
|
||||||
|
if sum < 255 { self.primary_bytes_mut()[s] = sum as u8; }
|
||||||
|
else { self.set(s, sum); }
|
||||||
|
} else {
|
||||||
|
// At least one side is in overflow — get() is unavoidable.
|
||||||
|
let self_val = self.get(s);
|
||||||
|
let other_val = other.get(s);
|
||||||
|
self.set(s, self_val + other_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
fn diff<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
fn diff<S: IntSlice>(&mut self, other: &S) -> &mut Self {
|
||||||
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
assert_eq!(self.len(), other.len(), "IntSlice length mismatch");
|
||||||
for s in 0..other.len() { self.set(s, self.get(s).saturating_sub(other.get(s))); }
|
let n = self.len();
|
||||||
|
for s in 0..n {
|
||||||
|
let sb = self.primary_bytes()[s];
|
||||||
|
let ob = other.primary_bytes()[s];
|
||||||
|
if sb < 255 {
|
||||||
|
// Result is always < 255 — no overflow created or consulted.
|
||||||
|
// ob == 255 means b ≥ 255 > a, so saturating result = 0.
|
||||||
|
self.primary_bytes_mut()[s] = if ob < 255 { sb.saturating_sub(ob) } else { 0 };
|
||||||
|
} else {
|
||||||
|
// sb == 255: self has overflow — get() unavoidable.
|
||||||
|
// other.get() only needed when ob == 255 too (both-overflow case).
|
||||||
|
let self_val = self.get(s);
|
||||||
|
let other_val = if ob < 255 { ob as u32 } else { other.get(s) };
|
||||||
|
self.set(s, self_val.saturating_sub(other_val));
|
||||||
|
}
|
||||||
|
}
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user