Files
obikmer/benchmark/build_reference_dist.py
T
Eric Coissac 469e53b6f5 Add genomic distance benchmarking suite and test data
Introduces scripts to compute and validate pairwise genomic distance matrices across multiple metrics. Updates the Makefile with build and comparison targets, adds .gitignore rules for generated outputs, and includes test CSV matrices and a Newick phylogenetic tree for validating the distance computation pipeline.
2026-06-22 18:24:30 +02:00

227 lines
9.7 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Compute reference pairwise distance matrices from per-specimen .npz kmer indexes.
Reads all .npz files in reference_index/ (each containing sorted uint64 `kmers`
and uint32 `counts`), computes all distance metrics supported by `obikmer distance`,
and writes one CSV per metric to reference_dist/.
Output CSV format matches `obikmer distance --output`:
- first row: "genome", then specimen names
- subsequent rows: specimen name, then float or int values
Metrics written
jaccard_dist.csv Jaccard distance (presence/absence)
shared_kmers.csv Shared-kmer count matrix (intersection size)
hamming_dist.csv Hamming distance (symmetric-difference size |A Δ B|)
bray_curtis_dist.csv Bray-Curtis dissimilarity (raw counts)
relfreq_bray_curtis_dist.csv Bray-Curtis on relative frequencies
euclidean_dist.csv Euclidean distance (raw counts)
relfreq_euclidean_dist.csv Euclidean distance on relative frequencies
hellinger_dist.csv Hellinger distance
hellinger_euclidean_dist.csv Euclidean distance in Hellinger space
"""
import argparse
import sys
from pathlib import Path
import numpy as np
# ── pairwise helpers ──────────────────────────────────────────────────────────
def shared_indices(a_kmers: np.ndarray, b_kmers: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return index arrays (idx_a, idx_b) for kmers present in both sets.

    Both arrays are expected to be sorted uint64 (``a_kmers`` must be sorted
    for the binary search to be valid). Uses searchsorted: O(|B| log |A|).

    Returns:
        A pair ``(idx_a, idx_b)`` of equal-length integer index arrays such
        that ``a_kmers[idx_a] == b_kmers[idx_b]`` element-wise. Both are
        empty when the two sets have nothing in common, including when
        either input is empty.
    """
    if len(a_kmers) == 0:
        # Nothing can match. Without this guard the clip below would use an
        # upper bound of -1 and the lookup would index into an empty array.
        no_match = np.empty(0, dtype=np.intp)
        return no_match, no_match.copy()

    # Insertion point of every b-kmer inside a_kmers (may equal len(a_kmers)).
    pos = np.searchsorted(a_kmers, b_kmers)
    # Clamp so that b-kmers larger than max(a_kmers) index a valid slot; the
    # equality test below rejects those false candidates.
    pos = np.clip(pos, 0, len(a_kmers) - 1)
    mask = a_kmers[pos] == b_kmers
    idx_b = np.where(mask)[0]
    idx_a = pos[idx_b]
    return idx_a, idx_b
def pairwise_stats(specimens: list[dict]) -> dict[str, np.ndarray]:
    """Build every pairwise matrix in a single sweep over specimen pairs.

    Each entry of ``specimens`` is a dict with keys ``name``, ``kmers`` and
    ``counts``. The result maps a metric name to an n×n symmetric ndarray:
    uint64 for ``shared_kmers``, float64 for all the distances. Only the
    off-diagonal cells are filled; the diagonal keeps its initial zeros.
    """
    n_spec = len(specimens)

    # Per-specimen scalars, computed once up front.
    sizes = np.array([len(sp['kmers']) for sp in specimens], dtype=np.uint64)
    totals = np.array([sp['counts'].sum() for sp in specimens], dtype=np.uint64)
    # Sum of squared counts, used by the Euclidean decompositions below.
    squares = np.array([(sp['counts'].astype(np.float64) ** 2).sum() for sp in specimens])

    # Result matrices, keyed by metric name (insertion order is the output order).
    out = {'shared_kmers': np.zeros((n_spec, n_spec), dtype=np.uint64)}
    for metric in (
        'hamming_dist',
        'jaccard_dist',
        'bray_curtis_dist',
        'relfreq_bray_curtis_dist',
        'euclidean_dist',
        'relfreq_euclidean_dist',
        'hellinger_dist',
        'hellinger_euclidean_dist',
    ):
        out[metric] = np.zeros((n_spec, n_spec), dtype=np.float64)

    root_two = np.sqrt(2.0)

    for i in range(n_spec):
        kmers_i = specimens[i]['kmers']
        counts_i = specimens[i]['counts'].astype(np.float64)
        total_i = float(totals[i])
        size_i = int(sizes[i])

        for j in range(i + 1, n_spec):
            kmers_j = specimens[j]['kmers']
            counts_j = specimens[j]['counts'].astype(np.float64)
            total_j = float(totals[j])
            size_j = int(sizes[j])

            hits_i, hits_j = shared_indices(kmers_i, kmers_j)
            n_shared = len(hits_i)
            shared_i = counts_i[hits_i]
            shared_j = counts_j[hits_j]

            # Presence/absence metrics.
            n_union = size_i + size_j - n_shared
            if n_union:
                jaccard = 1.0 - n_shared / n_union
            else:
                jaccard = 0.0
            sym_diff = float(size_i + size_j - 2 * n_shared)  # |A Δ B|

            # Bray-Curtis on raw counts: 1 - 2*Σmin(a,b) / (Σa + Σb).
            overlap = np.minimum(shared_i, shared_j).sum()
            grand_total = total_i + total_j
            if grand_total:
                bray = 1.0 - 2.0 * overlap / grand_total
            else:
                bray = 0.0

            # Euclidean on raw counts: √(Σa² + Σb² - 2·Σ(a·b) over shared kmers).
            dot = (shared_i * shared_j).sum()
            euclid = np.sqrt(max(squares[i] + squares[j] - 2.0 * dot, 0.0))

            # Metrics on relative frequencies need both totals to be non-zero.
            if total_i and total_j:
                # 1 - Σmin(a/sa, b/sb); only shared kmers contribute.
                rf_bray = 1.0 - np.minimum(shared_i / total_i, shared_j / total_j).sum()

                # √(Σa²/sa² + Σb²/sb² - 2·Σ(a·b)_shared/(sa·sb))
                rf_dot = (shared_i / total_i * (shared_j / total_j)).sum()
                rf_sq = (squares[i] / total_i**2
                         + squares[j] / total_j**2
                         - 2.0 * rf_dot)
                rf_euclid = np.sqrt(max(rf_sq, 0.0))

                # Σ(√(a/sa) - √(b/sb))² = 2 - 2·Σ√(a·b)_shared / √(sa·sb)
                affinity = np.sqrt(shared_i * shared_j).sum() / np.sqrt(total_i * total_j)
                hell_sq = max(2.0 - 2.0 * affinity, 0.0)
            else:
                rf_bray = 0.0
                rf_euclid = 0.0
                hell_sq = 0.0

            hell_euclid = np.sqrt(hell_sq)
            hellinger = np.sqrt(hell_sq) / root_two

            # Mirror each value into both triangles of its matrix.
            pair_values = {
                'shared_kmers': n_shared,
                'hamming_dist': sym_diff,
                'jaccard_dist': jaccard,
                'bray_curtis_dist': bray,
                'relfreq_bray_curtis_dist': rf_bray,
                'euclidean_dist': euclid,
                'relfreq_euclidean_dist': rf_euclid,
                'hellinger_dist': hellinger,
                'hellinger_euclidean_dist': hell_euclid,
            }
            for metric, value in pair_values.items():
                out[metric][i, j] = value
                out[metric][j, i] = value

    return out
# ── I/O ───────────────────────────────────────────────────────────────────────
def write_csv(path: Path, labels: list[str], mat: np.ndarray, fmt: str) -> None:
    """Write a square labelled matrix to ``path`` as comma-separated text.

    The header row is ``genome`` followed by every label; each following row
    is a label followed by that row's cells rendered with the format spec
    ``fmt``. The written path is reported on stderr.
    """
    n_cols = len(labels)
    with path.open('w') as out:
        header = 'genome,' + ','.join(labels)
        out.write(header + '\n')
        for row_idx, row_name in enumerate(labels):
            cells = [format(mat[row_idx, col_idx], fmt) for col_idx in range(n_cols)]
            joined = ','.join(cells)
            out.write(f'{row_name},{joined}\n')
    print(f' → {path}', file=sys.stderr)
# ── main ─────────────────────────────────────────────────────────────────────
def main() -> None:
    """Command-line entry point: load .npz indexes, compute and write matrices.

    Reads every ``*.npz`` file of ``--ref-dir`` (sorted by path), using the
    file stem as the specimen name and the ``kmers`` / ``counts`` arrays as
    its data, computes all pairwise matrices, and writes one CSV per metric
    into ``--out-dir``. Progress messages go to stderr. Exits with status 1
    when no ``.npz`` file is found.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument('--ref-dir', default='reference_index',
                    help='Directory with per-specimen .npz files (default: reference_index)')
    ap.add_argument('--out-dir', default='reference_dist',
                    help='Output directory for CSV files (default: reference_dist)')
    args = ap.parse_args()

    ref_dir = Path(args.ref_dir)
    out_dir = Path(args.out_dir)
    # parents=True lets --out-dir point at a nested path that does not exist yet.
    out_dir.mkdir(parents=True, exist_ok=True)

    npz_files = sorted(ref_dir.glob('*.npz'))
    if not npz_files:
        print(f'ERROR: no .npz files found in {ref_dir}', file=sys.stderr)
        sys.exit(1)

    print(f'Loading {len(npz_files)} specimen(s) from {ref_dir}/', file=sys.stderr)
    specimens = []
    for f in npz_files:
        # The archive object keeps its file open until closed, and every
        # key lookup re-reads the member, so read each array exactly once
        # and release the handle right away.
        with np.load(f) as data:
            kmers = data['kmers']
            counts = data['counts']
        specimens.append({
            'name': f.stem,
            'kmers': kmers,
            'counts': counts,
        })
        print(f' {f.stem}: {len(kmers):,} kmers', file=sys.stderr)

    labels = [s['name'] for s in specimens]
    n = len(labels)
    print(f'\nComputing pairwise distances for {n} specimens…', file=sys.stderr)
    matrices = pairwise_stats(specimens)

    print(f'\nWriting CSVs to {out_dir}/', file=sys.stderr)
    # (metric name, cell format): the shared-kmer counts are integers, every
    # distance is written with six decimals.
    outputs = [
        ('shared_kmers', 'd'),
        ('hamming_dist', '.6f'),
        ('jaccard_dist', '.6f'),
        ('bray_curtis_dist', '.6f'),
        ('relfreq_bray_curtis_dist', '.6f'),
        ('euclidean_dist', '.6f'),
        ('relfreq_euclidean_dist', '.6f'),
        ('hellinger_dist', '.6f'),
        ('hellinger_euclidean_dist', '.6f'),
    ]
    for metric, cell_fmt in outputs:
        write_csv(out_dir / f'{metric}.csv', labels, matrices[metric], cell_fmt)
    print('\nDone.', file=sys.stderr)


if __name__ == '__main__':
    main()