#!/usr/bin/env python3
"""Compare all reference distance matrices against obikmer distance outputs.

Reads from:
  reference_dist/   — ground-truth matrices computed by build_reference_dist.py
  obikmer_dist/     — matrices produced by `obikmer distance`

Handles label reordering: both matrices are sorted by genome label before
element-wise comparison, so column/row order differences are irrelevant.

Output:
  stats/dist_comparison/summary.csv
    comparison,max_abs,mean_abs,rmse,n_pairs,status
"""

import argparse
import csv
import sys
from pathlib import Path

import numpy as np


# ── CSV loading ───────────────────────────────────────────────────────────────

def load_matrix(path: Path) -> tuple[list[str], np.ndarray]:
    """Load a distance-matrix CSV; return (sorted_labels, matrix_float64).

    The first CSV row is the header; its first cell is ignored and the
    remaining cells name the columns. Every following row is
    ``label, value, value, ...``. Rows and columns of the returned matrix
    are both ordered by the sorted row labels, so the on-disk ordering of
    the file does not matter.

    Blank lines are skipped, and an empty file yields an empty label list
    with a (0, 0) matrix. If a row label appears more than once, the last
    occurrence wins.

    Raises:
        KeyError: a row label has no matching column in the header.
        IndexError: a row has fewer values than the header requires.
        ValueError: a cell cannot be parsed as a float.
    """
    # newline='' lets the csv module do its own line-ending handling.
    with path.open(newline='') as fh:
        reader = csv.reader(fh)
        # Drop the leading 'genome' column; tolerate a completely empty file.
        header = next(reader, [])[1:]
        raw: dict[str, list[float]] = {}
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing one).
                continue
            raw[row[0]] = [float(x) for x in row[1:]]

    label_to_col = {name: idx for idx, name in enumerate(header)}
    labels = sorted(raw)
    n = len(labels)

    # Column permutation is the same for every row: compute it once.
    col_order = [label_to_col[name] for name in labels]

    mat = np.zeros((n, n), dtype=np.float64)
    for i, row_label in enumerate(labels):
        values = raw[row_label]
        mat[i, :] = [values[c] for c in col_order]
    return labels, mat


# ── comparison ────────────────────────────────────────────────────────────────

def _skipped_row(label: str, status: str, n_pairs: str = '') -> dict:
    """Build a summary row for a comparison that produced no statistics."""
    return {
        'comparison': label,
        'status': status,
        'max_abs': '',
        'mean_abs': '',
        'rmse': '',
        'n_pairs': n_pairs,
    }


def compare(label: str, ref_path: Path, obi_path: Path,
            tol: float = 1e-4) -> dict:
    """Compare one reference matrix with one obikmer matrix.

    Only off-diagonal cells are compared. A one-line report is written to
    stderr for every comparison that actually loads both files.

    Args:
        label: name of the comparison, copied into the result row.
        ref_path: CSV holding the reference matrix.
        obi_path: CSV holding the obikmer matrix.
        tol: largest absolute difference still reported as PASS.

    Returns:
        A dict with the keys ``comparison``, ``max_abs``, ``mean_abs``,
        ``rmse``, ``n_pairs`` and ``status``. ``status`` is one of:

        * ``REF_MISSING`` / ``OBI_MISSING`` — an input file does not exist;
        * ``LABEL_MISMATCH`` — the two files do not hold the same labels;
        * ``NO_PAIRS`` — fewer than two genomes, nothing to compare;
        * ``PASS`` / ``FAIL`` — max absolute difference <= / > ``tol``.

        For every status other than PASS/FAIL the statistic fields are
        empty strings.
    """
    if not ref_path.exists():
        return _skipped_row(label, 'REF_MISSING')
    if not obi_path.exists():
        return _skipped_row(label, 'OBI_MISSING')

    ref_labels, ref_mat = load_matrix(ref_path)
    obi_labels, obi_mat = load_matrix(obi_path)

    if ref_labels != obi_labels:
        only_ref = sorted(set(ref_labels) - set(obi_labels))
        only_obi = sorted(set(obi_labels) - set(ref_labels))
        print(f' [{label}] label mismatch — '
              f'only_ref={only_ref} only_obi={only_obi}', file=sys.stderr)
        return _skipped_row(label, 'LABEL_MISMATCH')

    n = len(ref_labels)
    # Off-diagonal mask
    mask = ~np.eye(n, dtype=bool)
    diff = np.abs(ref_mat[mask] - obi_mat[mask])
    n_pairs = diff.size

    if n_pairs == 0:
        # With 0 or 1 genome there is no off-diagonal cell; max()/mean() on
        # an empty array would raise or yield NaN, so report a skip instead.
        print(f' [{label}] n=0 NO_PAIRS', file=sys.stderr)
        return _skipped_row(label, 'NO_PAIRS', n_pairs='0')

    max_abs = float(diff.max())
    mean_abs = float(diff.mean())
    rmse = float(np.sqrt((diff ** 2).mean()))
    # A NaN difference compares False here and is therefore reported as FAIL.
    status = 'PASS' if max_abs <= tol else 'FAIL'

    print(f' [{label}] n={n_pairs} '
          f'max={max_abs:.3e} mean={mean_abs:.3e} rmse={rmse:.3e} {status}',
          file=sys.stderr)

    return {
        'comparison': label,
        'max_abs': f'{max_abs:.6e}',
        'mean_abs': f'{mean_abs:.6e}',
        'rmse': f'{rmse:.6e}',
        'n_pairs': str(n_pairs),
        'status': status,
    }


# ── comparison table ──────────────────────────────────────────────────────────
# (label, ref_csv, obikmer_csv)
# The reference jaccard/shared is presence-based, which should match both
# presence/jaccard and count/jaccard (threshold=1).

COMPARISONS = [
    # ── presence index ────────────────────────────────────────────────────────
    ('presence/jaccard_dist',
     'reference_dist/jaccard_dist.csv',
     'obikmer_dist/presence/jaccard_dist.csv'),
    ('presence/jaccard_shared',
     'reference_dist/shared_kmers.csv',
     'obikmer_dist/presence/jaccard_shared.csv'),
    ('presence/hamming_dist',
     'reference_dist/hamming_dist.csv',
     'obikmer_dist/presence/hamming_dist.csv'),

    # ── count index (jaccard cross-check) ─────────────────────────────────────
    ('count/jaccard_dist',
     'reference_dist/jaccard_dist.csv',
     'obikmer_dist/count/jaccard_dist.csv'),
    ('count/jaccard_shared',
     'reference_dist/shared_kmers.csv',
     'obikmer_dist/count/jaccard_shared.csv'),

    # ── count index (count-based metrics) ────────────────────────────────────
    ('count/bray_curtis_dist',
     'reference_dist/bray_curtis_dist.csv',
     'obikmer_dist/count/bray_curtis_dist.csv'),
    ('count/relfreq_bray_curtis_dist',
     'reference_dist/relfreq_bray_curtis_dist.csv',
     'obikmer_dist/count/relfreq_bray_curtis_dist.csv'),
    ('count/euclidean_dist',
     'reference_dist/euclidean_dist.csv',
     'obikmer_dist/count/euclidean_dist.csv'),
    ('count/relfreq_euclidean_dist',
     'reference_dist/relfreq_euclidean_dist.csv',
     'obikmer_dist/count/relfreq_euclidean_dist.csv'),
    ('count/hellinger_dist',
     'reference_dist/hellinger_dist.csv',
     'obikmer_dist/count/hellinger_dist.csv'),
    ('count/hellinger_euclidean_dist',
     'reference_dist/hellinger_euclidean_dist.csv',
     'obikmer_dist/count/hellinger_euclidean_dist.csv'),
]


# ── main ─────────────────────────────────────────────────────────────────────

def main() -> None:
    """Run every entry of COMPARISONS and write the summary CSV.

    Progress and the final tally go to stderr. The process exits with
    status 1 when at least one comparison is FAIL; comparisons with any
    status other than PASS or FAIL are tallied as SKIP and do not affect
    the exit status.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument('--tol', type=float, default=1e-4,
                    help='Max abs diff threshold for PASS/FAIL (default 1e-4)')
    ap.add_argument('--out', default='stats/dist_comparison/summary.csv',
                    help='Output summary CSV path')
    args = ap.parse_args()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print(f'Comparing {len(COMPARISONS)} matrix pairs…', file=sys.stderr)
    rows = [compare(label, Path(ref), Path(obi), tol=args.tol)
            for label, ref, obi in COMPARISONS]

    fields = ['comparison', 'max_abs', 'mean_abs', 'rmse', 'n_pairs', 'status']
    with out_path.open('w', newline='') as fh:
        w = csv.DictWriter(fh, fieldnames=fields)
        w.writeheader()
        w.writerows(rows)

    print(f'\n→ {out_path}', file=sys.stderr)

    n_fail = sum(1 for r in rows if r.get('status') == 'FAIL')
    n_pass = sum(1 for r in rows if r.get('status') == 'PASS')
    print(f'Summary: {n_pass} PASS {n_fail} FAIL '
          f'{len(rows) - n_pass - n_fail} SKIP', file=sys.stderr)
    if n_fail:
        sys.exit(1)


if __name__ == '__main__':
    main()