Files
obikmer/benchmark/compare_all_dist.py
T
Eric Coissac 469e53b6f5 Add genomic distance benchmarking suite and test data
Introduces scripts to compute and validate pairwise genomic distance matrices across multiple metrics. Updates the Makefile with build and comparison targets, adds .gitignore rules for generated outputs, and includes test CSV matrices and a Newick phylogenetic tree for validating the distance computation pipeline.
2026-06-22 18:24:30 +02:00

183 lines
7.1 KiB
Python
Executable File

#!/usr/bin/env python3
"""Compare all reference distance matrices against obikmer distance outputs.
Reads from:
reference_dist/ — ground-truth matrices computed by build_reference_dist.py
obikmer_dist/ — matrices produced by `obikmer distance`
Handles label reordering: both matrices are sorted by genome label before
element-wise comparison, so column/row order differences are irrelevant.
Output: stats/dist_comparison/summary.csv
comparison,max_abs,mean_abs,rmse,n_pairs,status
"""
import csv
import sys
from pathlib import Path
import numpy as np
# ── CSV loading ───────────────────────────────────────────────────────────────
def load_matrix(path: Path) -> tuple[list[str], np.ndarray]:
    """Load a distance-matrix CSV and return it in label-sorted order.

    The first CSV row is a header: its first cell (the 'genome' column) is
    ignored and the remaining cells name the columns.  Every following row
    is ``label, v1, v2, ...`` with values parsed as floats.  Blank lines are
    skipped.  If a row label occurs more than once, the last occurrence wins.

    Args:
        path: Location of the CSV file.

    Returns:
        ``(labels, matrix)`` where ``labels`` is the sorted list of row
        labels and ``matrix`` is an ``n x n`` float64 array such that
        ``matrix[i, j]`` is the value found in row ``labels[i]`` under the
        header column named ``labels[j]``.

    Raises:
        ValueError: If the file has no header row, a row label has no
            column of the same name, a row is too short to supply every
            needed column, or a cell cannot be parsed as a float.
    """
    # newline='' lets the csv module do its own line-ending handling.
    with path.open(newline='') as fh:
        reader = csv.reader(fh)
        try:
            header = next(reader)[1:]  # skip 'genome' column
        except StopIteration:
            raise ValueError(f'{path}: empty distance-matrix CSV') from None
        raw: dict[str, list[float]] = {}
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing one).
                continue
            raw[row[0]] = [float(x) for x in row[1:]]

    label_to_col = {name: idx for idx, name in enumerate(header)}
    labels = sorted(raw)

    orphans = [lab for lab in labels if lab not in label_to_col]
    if orphans:
        raise ValueError(
            f'{path}: row labels without a matching column: {orphans}')

    # Column positions, in sorted-label order, reused for every row.
    cols = [label_to_col[lab] for lab in labels]
    needed = max(cols) + 1 if cols else 0

    n = len(labels)
    mat = np.zeros((n, n), dtype=np.float64)
    for i, lab in enumerate(labels):
        values = raw[lab]
        if len(values) < needed:
            raise ValueError(
                f'{path}: row {lab!r} has {len(values)} values, '
                f'expected at least {needed}')
        mat[i] = np.asarray(values, dtype=np.float64)[cols]
    return labels, mat
# ── comparison ────────────────────────────────────────────────────────────────
def compare(label: str,
            ref_path: Path,
            obi_path: Path,
            tol: float = 1e-4) -> dict:
    """Compare one reference matrix against one obikmer matrix.

    Both CSVs are read with ``load_matrix``, which returns them in
    label-sorted order, so row/column order in the files is irrelevant.
    Only off-diagonal cells are compared.  A one-line report (or the reason
    for skipping) is printed to stderr.

    Args:
        label: Name of the comparison, copied into the result row.
        ref_path: CSV holding the reference matrix.
        obi_path: CSV holding the obikmer matrix.
        tol: Largest absolute difference still reported as ``PASS``.

    Returns:
        A dict with the keys ``comparison``, ``max_abs``, ``mean_abs``,
        ``rmse``, ``n_pairs`` and ``status``.  ``status`` is ``PASS`` or
        ``FAIL`` when the matrices were compared.  Otherwise it is
        ``REF_MISSING``, ``OBI_MISSING``, ``LABEL_MISMATCH`` or ``NO_PAIRS``
        and the numeric fields are empty strings (``n_pairs`` is ``'0'``
        for ``NO_PAIRS``).
    """
    def skipped(status: str, n_pairs: str = '') -> dict:
        # Row for a comparison that produced no statistics.
        return {'comparison': label, 'status': status,
                'max_abs': '', 'mean_abs': '', 'rmse': '', 'n_pairs': n_pairs}

    if not ref_path.exists():
        return skipped('REF_MISSING')
    if not obi_path.exists():
        return skipped('OBI_MISSING')

    ref_labels, ref_mat = load_matrix(ref_path)
    obi_labels, obi_mat = load_matrix(obi_path)
    if ref_labels != obi_labels:
        only_ref = sorted(set(ref_labels) - set(obi_labels))
        only_obi = sorted(set(obi_labels) - set(ref_labels))
        print(f' [{label}] label mismatch — '
              f'only_ref={only_ref} only_obi={only_obi}', file=sys.stderr)
        return skipped('LABEL_MISMATCH')

    n = len(ref_labels)
    # Off-diagonal mask
    mask = ~np.eye(n, dtype=bool)
    diff = np.abs(ref_mat[mask] - obi_mat[mask])
    # Counts every off-diagonal cell, i.e. both (i, j) and (j, i).
    n_pairs = diff.size
    if n_pairs == 0:
        # Fewer than two genomes: max()/mean() of an empty array would
        # raise or yield NaN, so report the case instead of crashing.
        print(f' [{label}] no off-diagonal pairs to compare (n_genomes={n})',
              file=sys.stderr)
        return skipped('NO_PAIRS', '0')

    max_abs = float(diff.max())
    mean_abs = float(diff.mean())
    rmse = float(np.sqrt((diff ** 2).mean()))
    status = 'PASS' if max_abs <= tol else 'FAIL'
    print(f' [{label}] n={n_pairs} '
          f'max={max_abs:.3e} mean={mean_abs:.3e} rmse={rmse:.3e} {status}',
          file=sys.stderr)
    return {
        'comparison': label,
        'max_abs': f'{max_abs:.6e}',
        'mean_abs': f'{mean_abs:.6e}',
        'rmse': f'{rmse:.6e}',
        'n_pairs': str(n_pairs),
        'status': status,
    }
# ── comparison table ──────────────────────────────────────────────────────────
# COMPARISONS holds one (label, ref_csv, obikmer_csv) tuple per check.
# The reference jaccard/shared matrices are presence-based, so they are
# expected to match both presence/jaccard and count/jaccard (threshold=1);
# that is why the same reference file appears under both indexes.
_REFERENCE_DIR = 'reference_dist'
_OBIKMER_DIR = 'obikmer_dist'

# (obikmer index, obikmer metric, reference file stem)
_COMPARISON_SPECS = (
    # ── presence index ───────────────────────────────────────────────────────
    ('presence', 'jaccard_dist', 'jaccard_dist'),
    ('presence', 'jaccard_shared', 'shared_kmers'),
    ('presence', 'hamming_dist', 'hamming_dist'),
    # ── count index (jaccard cross-check) ────────────────────────────────────
    ('count', 'jaccard_dist', 'jaccard_dist'),
    ('count', 'jaccard_shared', 'shared_kmers'),
    # ── count index (count-based metrics) ────────────────────────────────────
    ('count', 'bray_curtis_dist', 'bray_curtis_dist'),
    ('count', 'relfreq_bray_curtis_dist', 'relfreq_bray_curtis_dist'),
    ('count', 'euclidean_dist', 'euclidean_dist'),
    ('count', 'relfreq_euclidean_dist', 'relfreq_euclidean_dist'),
    ('count', 'hellinger_dist', 'hellinger_dist'),
    ('count', 'hellinger_euclidean_dist', 'hellinger_euclidean_dist'),
)

COMPARISONS = [
    (f'{index}/{metric}',
     f'{_REFERENCE_DIR}/{ref_stem}.csv',
     f'{_OBIKMER_DIR}/{index}/{metric}.csv')
    for index, metric, ref_stem in _COMPARISON_SPECS
]
# ── main ─────────────────────────────────────────────────────────────────────
def main() -> None:
    """Run every comparison in COMPARISONS and write the summary CSV.

    Command-line options: ``--tol`` (PASS/FAIL threshold on the maximum
    absolute difference) and ``--out`` (path of the summary CSV; its parent
    directory is created if needed).  A PASS/FAIL/SKIP tally is printed to
    stderr, and the process exits with status 1 when any comparison FAILs.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        '--tol', type=float, default=1e-4,
        help='Max abs diff threshold for PASS/FAIL (default 1e-4)',
    )
    parser.add_argument(
        '--out', default='stats/dist_comparison/summary.csv',
        help='Output summary CSV path',
    )
    opts = parser.parse_args()

    summary_path = Path(opts.out)
    summary_path.parent.mkdir(parents=True, exist_ok=True)

    print(f'Comparing {len(COMPARISONS)} matrix pairs…', file=sys.stderr)
    results = [
        compare(name, Path(ref_csv), Path(obi_csv), tol=opts.tol)
        for name, ref_csv, obi_csv in COMPARISONS
    ]

    columns = ['comparison', 'max_abs', 'mean_abs', 'rmse', 'n_pairs', 'status']
    with summary_path.open('w', newline='') as out:
        writer = csv.DictWriter(out, fieldnames=columns)
        writer.writeheader()
        for record in results:
            writer.writerow(record)
    print(f'\n{summary_path}', file=sys.stderr)

    # Anything that is neither PASS nor FAIL is reported as SKIP.
    statuses = [record.get('status') for record in results]
    passed = statuses.count('PASS')
    failed = statuses.count('FAIL')
    skipped = len(results) - passed - failed
    print(f'Summary: {passed} PASS {failed} FAIL {skipped} SKIP',
          file=sys.stderr)

    if failed:
        sys.exit(1)
# Run the comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()